1 #ifndef __STAN__AGRAD__AGRAD_SPECIAL_FUNCTIONS_HPP__
2 #define __STAN__AGRAD__AGRAD_SPECIAL_FUNCTIONS_HPP__
7 #include <boost/math/special_functions/acosh.hpp>
8 #include <boost/math/special_functions/asinh.hpp>
9 #include <boost/math/special_functions/atanh.hpp>
10 #include <boost/math/special_functions/digamma.hpp>
11 #include <boost/math/special_functions/hypot.hpp>
19 class lgamma_vari :
public op_v_vari {
21 lgamma_vari(
double value, vari* avi) :
22 op_v_vari(value, avi) {
25 avi_->adj_ += adj_ * boost::math::digamma(
avi_->val_);
29 class tgamma_vari :
public op_v_vari {
31 tgamma_vari(vari* avi) :
35 avi_->adj_ += adj_ * val_ * boost::math::digamma(
avi_->val_);
39 class log1p_vari :
public op_v_vari {
41 log1p_vari(vari* avi) :
42 op_v_vari(
stan::math::
log1p(avi->val_),avi) {
45 avi_->adj_ += adj_ / (1 +
avi_->val_);
49 class log1m_vari :
public op_v_vari {
51 log1m_vari(vari* avi) :
52 op_v_vari(
stan::math::
log1p(-avi->val_),avi) {
55 avi_->adj_ += adj_ / (
avi_->val_ - 1);
59 class binary_log_loss_1_vari :
public op_v_vari {
61 binary_log_loss_1_vari(vari* avi) :
62 op_v_vari(-
std::
log(avi->val_),avi) {
69 class binary_log_loss_0_vari :
public op_v_vari {
71 binary_log_loss_0_vari(vari* avi) :
72 op_v_vari(-
stan::math::
log1p(-avi->val_),avi) {
75 avi_->adj_ += adj_ / (1.0 -
avi_->val_);
79 class fdim_vv_vari :
public op_vv_vari {
81 fdim_vv_vari(vari* avi, vari* bvi) :
82 op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
90 class fdim_vd_vari :
public op_v_vari {
92 fdim_vd_vari(vari* avi,
double b) :
93 op_v_vari(avi->val_ - b, avi) {
100 class fdim_dv_vari :
public op_v_vari {
102 fdim_dv_vari(
double a, vari* bvi) :
103 op_v_vari(a - bvi->val_, bvi) {
111 class fma_vvv_vari :
public op_vvv_vari {
113 fma_vvv_vari(vari* avi, vari* bvi, vari* cvi) :
114 op_vvv_vari(avi->val_ * bvi->val_ + cvi->val_,
124 class fma_vvd_vari :
public op_vv_vari {
126 fma_vvd_vari(vari* avi, vari* bvi,
double c) :
127 op_vv_vari(avi->val_ * bvi->val_ + c,
136 class fma_vdv_vari :
public op_vdv_vari {
138 fma_vdv_vari(vari* avi,
double b, vari* cvi) :
139 op_vdv_vari(avi->val_ * b + cvi->val_,
148 class fma_vdd_vari :
public op_vd_vari {
150 fma_vdd_vari(vari* avi,
double b,
double c) :
151 op_vd_vari(avi->val_ * b + c,
159 class fma_ddv_vari :
public op_v_vari {
161 fma_ddv_vari(
double a,
double b, vari* cvi) :
162 op_v_vari(a * b + cvi->val_,
171 class inv_logit_vari :
public op_v_vari {
173 inv_logit_vari(vari* avi) :
174 op_v_vari(math::
inv_logit(avi->val_),avi) {
177 avi_->adj_ += adj_ * val_ * (1.0 - val_);
181 class acosh_vari :
public op_v_vari {
183 acosh_vari(vari* avi) :
184 op_v_vari(
boost::math::
acosh(avi->val_),avi) {
191 class asinh_vari :
public op_v_vari {
193 asinh_vari(vari* avi) :
194 op_v_vari(
boost::math::
asinh(avi->val_),avi) {
201 class atanh_vari :
public op_v_vari {
203 atanh_vari(vari* avi) :
204 op_v_vari(
boost::math::
atanh(avi->val_),avi) {
211 const double TWO_OVER_SQRT_PI = 2.0 /
std::sqrt(boost::math::constants::pi<double>());
213 class erf_vari :
public op_v_vari {
215 erf_vari(vari* avi) :
216 op_v_vari(
boost::math::
erf(avi->val_),avi) {
223 const double NEG_TWO_OVER_SQRT_PI = - TWO_OVER_SQRT_PI;
225 class erfc_vari :
public op_v_vari {
227 erfc_vari(vari* avi) :
228 op_v_vari(
boost::math::
erfc(avi->val_),avi) {
237 class exp2_vari :
public op_v_vari {
239 exp2_vari(vari* avi) :
240 op_v_vari(
std::
pow(2.0,avi->val_),avi) {
247 class expm1_vari :
public op_v_vari {
249 expm1_vari(vari* avi) :
250 op_v_vari(
std::
exp(avi->val_) - 1.0,avi) {
253 avi_->adj_ += adj_ * val_;
257 class hypot_vv_vari :
public op_vv_vari {
259 hypot_vv_vari(vari* avi, vari* bvi) :
260 op_vv_vari(
boost::math::
hypot(avi->val_,bvi->val_),
264 avi_->adj_ += adj_ *
avi_->val_ / val_;
265 bvi_->adj_ += adj_ *
bvi_->val_ / val_;
269 class hypot_vd_vari :
public op_v_vari {
271 hypot_vd_vari(vari* avi,
double b) :
276 avi_->adj_ += adj_ *
avi_->val_ / val_;
282 class log2_vari :
public op_v_vari {
284 log2_vari(vari* avi) :
285 op_v_vari(
stan::math::
log2(avi->val_),avi) {
288 avi_->adj_ += adj_ / (LOG2 *
avi_->val_);
292 class cbrt_vari :
public op_v_vari {
294 cbrt_vari(vari* avi) :
295 op_v_vari(
boost::math::
cbrt(avi->val_),avi) {
298 avi_->adj_ += adj_ / (3.0 * val_ * val_);
303 class round_vari :
public vari {
305 round_vari(vari* avi) :
311 class trunc_vari :
public vari {
313 trunc_vari(vari* avi) :
318 class inv_cloglog_vari :
public op_v_vari {
320 inv_cloglog_vari(vari* avi) :
328 class Phi_vari :
public op_v_vari {
330 Phi_vari(vari* avi) :
331 op_v_vari(
stan::math::
Phi(avi->val_), avi) {
334 static const double NEG_HALF = -0.5;
335 static const double INV_SQRT_TWO_PI
336 = 1.0 /
std::sqrt(2.0 * boost::math::constants::pi<double>());
341 inline double calculate_chain(
const double& x,
const double& val) {
345 double log_sum_exp_as_double(
const std::vector<var>& x) {
346 using std::numeric_limits;
349 double max = -numeric_limits<double>::infinity();
350 for (
size_t i = 0; i < x.size(); ++i)
354 for (
size_t i = 0; i < x.size(); ++i)
355 if (x[i] != -numeric_limits<double>::infinity())
360 class log1p_exp_v_vari :
public op_v_vari {
362 log1p_exp_v_vari(vari* avi) :
367 avi_->adj_ += adj_ * calculate_chain(
avi_->val_, val_);
370 class log_sum_exp_vv_vari :
public op_vv_vari {
372 log_sum_exp_vv_vari(vari* avi, vari* bvi) :
377 avi_->adj_ += adj_ * calculate_chain(
avi_->val_, val_);
378 bvi_->adj_ += adj_ * calculate_chain(
bvi_->val_, val_);
381 class log_sum_exp_vd_vari :
public op_vd_vari {
383 log_sum_exp_vd_vari(vari* avi,
double b) :
388 avi_->adj_ += adj_ * calculate_chain(
avi_->val_, val_);
391 class log_sum_exp_dv_vari :
public op_dv_vari {
393 log_sum_exp_dv_vari(
double a, vari* bvi) :
398 bvi_->adj_ += adj_ * calculate_chain(
bvi_->val_, val_);
401 class log_sum_exp_vector_vari :
public op_vector_vari {
403 log_sum_exp_vector_vari(
const std::vector<var>& x) :
404 op_vector_vari(log_sum_exp_as_double(x), x) {
407 for (
size_t i = 0; i <
size_; ++i) {
408 vis_[i]->adj_ += adj_ * calculate_chain(
vis_[i]->val_, val_);
413 class square_vari :
public op_v_vari {
415 square_vari(vari* avi) :
416 op_v_vari(avi->val_ * avi->val_,avi) {
419 avi_->adj_ += adj_ * 2.0 *
avi_->val_;
423 class multiply_log_vv_vari :
public op_vv_vari {
425 multiply_log_vv_vari(vari* avi, vari* bvi) :
431 if (
bvi_->val_ == 0.0 &&
avi_->val_ == 0)
432 bvi_->adj_ += adj_ * std::numeric_limits<double>::infinity();
437 class multiply_log_vd_vari :
public op_vd_vari {
439 multiply_log_vd_vari(vari* avi,
double b) :
447 class multiply_log_dv_vari :
public op_dv_vari {
449 multiply_log_dv_vari(
double a, vari* bvi) :
453 if (
bvi_->val_ == 0.0 &&
ad_ == 0.0)
454 bvi_->adj_ += adj_ * std::numeric_limits<double>::infinity();
465 double ibeta_hypergeometric_helper(
double a,
double b,
double z,
double precision=1
e-8,
double max_steps=1000) {
479 class ibeta_vvv_vari :
public op_vvv_vari {
481 ibeta_vvv_vari(vari* avi, vari* bvi, vari* xvi) :
482 op_vvv_vari(
stan::math::
ibeta(avi->val_,bvi->val_,xvi->val_),avi,bvi,xvi) {
485 double a =
avi_->val_;
486 double b =
bvi_->val_;
487 double c =
cvi_->val_;
494 using boost::math::digamma;
496 using stan::agrad::ibeta_hypergeometric_helper;
498 (
log(c) - digamma(a) + digamma(a+b)) * val_ -
502 +
ibeta(b, a, 1-c) * (digamma(b) - digamma(a+b) -
log(1-c)));
504 boost::math::ibeta_derivative(a,b,c);
507 class ibeta_vvd_vari :
public op_vvd_vari {
509 ibeta_vvd_vari(vari* avi, vari* bvi,
double x) :
510 op_vvd_vari(
stan::math::
ibeta(avi->val_,bvi->val_,x),avi,bvi,x) {
513 double a =
avi_->val_;
514 double b =
bvi_->val_;
522 using boost::math::digamma;
524 using stan::agrad::ibeta_hypergeometric_helper;
526 (
log(c) - digamma(a) + digamma(a+b)) * val_ -
530 +
ibeta(b, a, 1-c) * (digamma(b) - digamma(a+b) -
log(1-c)));
533 class ibeta_vdv_vari :
public op_vdv_vari {
535 ibeta_vdv_vari(vari* avi,
double b, vari* xvi) :
536 op_vdv_vari(
stan::math::
ibeta(avi->val_,b,xvi->val_),avi,b,xvi) {
539 double a =
avi_->val_;
541 double c =
cvi_->val_;
548 using boost::math::digamma;
550 using stan::agrad::ibeta_hypergeometric_helper;
552 (
log(c) - digamma(a) + digamma(a+b)) * val_ -
555 boost::math::ibeta_derivative(a,b,c);
558 class ibeta_vdd_vari :
public op_vdd_vari {
560 ibeta_vdd_vari(vari* avi,
double b,
double x) :
561 op_vdd_vari(
stan::math::
ibeta(avi->val_,b,x),avi,b,x) {
564 double a =
avi_->val_;
573 using boost::math::digamma;
575 using stan::agrad::ibeta_hypergeometric_helper;
577 (
log(c) - digamma(a) + digamma(a+b)) * val_ -
581 class ibeta_dvv_vari :
public op_dvv_vari {
583 ibeta_dvv_vari(
double a, vari* bvi, vari* xvi) :
584 op_dvv_vari(
stan::math::
ibeta(a,bvi->val_,xvi->val_),a,bvi,xvi) {
588 double b =
bvi_->val_;
589 double c =
cvi_->val_;
596 using boost::math::digamma;
598 using stan::agrad::ibeta_hypergeometric_helper;
601 +
ibeta(b, a, 1-c) * (digamma(b) - digamma(a+b) -
log(1-c)));
603 boost::math::ibeta_derivative(a,b,c);
606 class ibeta_dvd_vari :
public op_dvd_vari {
608 ibeta_dvd_vari(
double a, vari* bvi,
double x) :
609 op_dvd_vari(
stan::math::
ibeta(a,bvi->val_,x),a,bvi,x) {
613 double b =
bvi_->val_;
621 using boost::math::digamma;
623 using stan::agrad::ibeta_hypergeometric_helper;
626 +
ibeta(b, a, 1-c) * (digamma(b) - digamma(a+b) -
log(1-c)));
629 class ibeta_ddv_vari :
public op_ddv_vari {
631 ibeta_ddv_vari(
double a,
double b, vari* xvi) :
632 op_ddv_vari(
stan::math::
ibeta(a,b,xvi->val_),a,b,xvi) {
637 double c =
cvi_->val_;
640 boost::math::ibeta_derivative(a,b,c);
662 return var(
new acosh_vari(a.
vi_));
678 return var(
new asinh_vari(a.
vi_));
694 return var(
new atanh_vari(a.
vi_));
710 return var(
new erf_vari(a.
vi_));
726 return var(
new erfc_vari(a.
vi_));
742 return var(
new exp2_vari(a.
vi_));
758 return var(
new expm1_vari(a.
vi_));
773 return var(
new lgamma_vari(lgamma_a, a.
vi_));
787 return var(
new log1p_vari(a.
vi_));
801 return var(
new log1m_vari(a.
vi_));
851 return var(
new fma_vvd_vari(a.
vi_,b.
vi_,c));
875 return var(
new fma_vdv_vari(a.
vi_,b,c.
vi_));
897 return var(
new fma_vdd_vari(a.
vi_,b,c));
919 return var(
new fma_vdd_vari(b.
vi_,a,c));
941 return var(
new fma_ddv_vari(a,b,c.
vi_));
965 return var(
new fma_vdv_vari(b.
vi_,a,c.
vi_));
1109 return var(
new hypot_vv_vari(a.
vi_,b.
vi_));
1128 return var(
new hypot_vd_vari(a.
vi_,b));
1147 return var(
new hypot_vd_vari(b.
vi_,a));
1163 return var(
new log2_vari(a.
vi_));
1179 return var(
new cbrt_vari(a.
vi_));
1197 return var(
new round_vari(a.
vi_));
1214 return var(
new trunc_vari(a.
vi_));
1244 return var(
new fdim_vv_vari(a.
vi_,b.
vi_));
1269 ?
var(
new fdim_dv_vari(a,b.
vi_))
1292 ?
var(
new fdim_vd_vari(a.
vi_,b))
1313 return var(
new tgamma_vari(a.
vi_));
1351 return var(
new inv_cloglog_vari(a.
vi_));
1367 return var(
new Phi_vari(a.
vi_));
1383 return var(
new inv_logit_vari(a.
vi_));
1404 ?
var(
new binary_log_loss_0_vari(y_hat.
vi_))
1405 :
var(
new binary_log_loss_1_vari(y_hat.
vi_));
1413 return var(
new log1p_exp_v_vari(a.
vi_));
1421 return var(
new log_sum_exp_vv_vari(a.
vi_, b.
vi_));
1428 return var(
new log_sum_exp_vd_vari(a.
vi_, b));
1435 return var(
new log_sum_exp_dv_vari(a, b.
vi_));
1441 return var(
new log_sum_exp_vector_vari(x));
1454 return var(
new square_vari(x.
vi_));
1471 return var(
new multiply_log_vv_vari(a.
vi_,b.
vi_));
1484 return var(
new multiply_log_vd_vari(a.
vi_,b));
1500 return var(
new multiply_log_dv_vari(a,b.
vi_));
1513 return c ? y_true : y_false;
1543 return var(y_false);
Independent (input) and dependent (output) variables for gradients.
double val() const
Return the value of this variable.
vari * vi_
Pointer to the implementation of this variable.
The variable implementation base class.
const double val_
The value of this variable.
Reimplementing boost functionality.
var inv_logit(const stan::agrad::var &a)
The inverse logit function for variables (stan).
var expm1(const stan::agrad::var &a)
The exponentiation of the specified variable minus 1 (C99).
int as_bool(const agrad::var &v)
Return 1 if the argument is unequal to zero and 0 otherwise.
var multiply_log(const var &a, const var &b)
Return the value of a*log(b).
var asinh(const stan::agrad::var &a)
The inverse hyperbolic sine function for variables (C99).
var log1p_exp(const stan::agrad::var &a)
Return the log of 1 plus the exponential of the specified variable.
var sqrt(const var &a)
Return the square root of the specified variable (cmath).
var square(const var &x)
Return the square of the input variable.
var fma(const stan::agrad::var &a, const stan::agrad::var &b, const stan::agrad::var &c)
The fused multiply-add function for three variables (C99).
var if_else(bool c, const var &y_true, const var &y_false)
If the specified condition is true, return the first variable, otherwise return the second variable.
var abs(const var &a)
Return the absolute value of the variable (std).
var exp2(const stan::agrad::var &a)
Base-2 exponentiation function for variables (C99).
var step(const stan::agrad::var &a)
Return the step (Heaviside) function applied to the specified variable (stan).
var log1m(const stan::agrad::var &a)
The log (1 - x) function for variables.
var ibeta(const var &a, const var &b, const var &x)
The normalized incomplete beta function of a, b, and x.
var acosh(const stan::agrad::var &a)
The inverse hyperbolic cosine function for variables (C99).
var cbrt(const stan::agrad::var &a)
Returns the cube root of the specified variable (C99).
var erfc(const stan::agrad::var &a)
The complementary error function for variables (C99).
var log1p(const stan::agrad::var &a)
The log (1 + x) function for variables (C99).
var fmax(const stan::agrad::var &a, const stan::agrad::var &b)
Returns the maximum of the two variable arguments (C99).
var log_loss(const int &y, const stan::agrad::var &y_hat)
The log loss function for variables (stan).
var sin(const var &a)
Return the sine of a radian-scaled variable (cmath).
var fdim(const stan::agrad::var &a, const stan::agrad::var &b)
Return the positive difference between the first variable's value and the second's (C99).
double value_of(const agrad::var &v)
Return the value of the specified variable.
var lgamma(const stan::agrad::var &a)
The log gamma function for variables (C99).
var erf(const stan::agrad::var &a)
The error function for variables (C99).
var Phi(const stan::agrad::var &a)
The unit normal cumulative distribution function for variables (stan).
var pow(const var &base, const var &exponent)
Return the base raised to the power of the exponent (cmath).
var log_sum_exp(const stan::agrad::var &a, const stan::agrad::var &b)
Returns the log sum of exponentials.
var log(const var &a)
Return the natural log of the specified variable (cmath).
var log2(const stan::agrad::var &a)
Returns the base 2 logarithm of the specified variable (C99).
var hypot(const stan::agrad::var &a, const stan::agrad::var &b)
Returns the length of the hypotenuse of a right triangle with sides of the specified lengths (C99).
var tgamma(const stan::agrad::var &a)
Return the Gamma function applied to the specified variable (C99).
var round(const stan::agrad::var &a)
Returns the rounded form of the specified variable (C99).
var sum(const Eigen::Matrix< var, R, C > &m)
Returns the sum of the coefficients of the specified matrix, column vector or row vector.
var trunc(const stan::agrad::var &a)
Returns the truncation of the specified variable (C99).
var atanh(const stan::agrad::var &a)
The inverse hyperbolic tangent function for variables (C99).
var exp(const var &a)
Return the exponentiation of the specified variable (cmath).
var fmin(const stan::agrad::var &a, const stan::agrad::var &b)
Returns the minimum of the two variable arguments (C99).
var inv_cloglog(const stan::agrad::var &a)
Return the inverse complementary log-log function applied to the specified variable (stan).
double e()
Return the base of the natural logarithm.
double ibeta(const double &a, const double &b, const double &x)
The normalized incomplete beta function of a, b, and x.
const double LOG_2
The natural logarithm of 2, $\log 2$.
bool check_greater_or_equal(const char *function, const T_y &y, const T_low &low, const char *name, T_result *result, const Policy &)
int max(const std::vector< int > &x)
Returns the maximum element in the specified vector.
bool check_not_nan(const char *function, const T_y &y, const char *name, T_result *result, const Policy &)
Checks that the variable y is not NaN.
double pi()
Return the value of pi.
Probability, optimization and sampling library.
Template specification of functions in std for Stan.
int isnan(const stan::agrad::var &a)
Checks if the given number is NaN.