1 #ifndef __STAN__AGRAD__AGRAD_SPECIAL_FUNCTIONS_HPP__
2 #define __STAN__AGRAD__AGRAD_SPECIAL_FUNCTIONS_HPP__
7 #include <boost/math/special_functions/acosh.hpp>
8 #include <boost/math/special_functions/asinh.hpp>
9 #include <boost/math/special_functions/atanh.hpp>
10 #include <boost/math/special_functions/digamma.hpp>
11 #include <boost/math/special_functions/hypot.hpp>
19 class lgamma_vari :
public op_v_vari {
21 lgamma_vari(
double value, vari* avi) :
22 op_v_vari(value, avi) {
25 avi_->adj_ += adj_ * boost::math::digamma(
avi_->val_);
29 class tgamma_vari :
public op_v_vari {
31 tgamma_vari(vari* avi) :
32 op_v_vari(boost::math::
tgamma(avi->val_), avi) {
35 avi_->adj_ += adj_ * val_ * boost::math::digamma(
avi_->val_);
39 class log1p_vari :
public op_v_vari {
41 log1p_vari(vari* avi) :
42 op_v_vari(stan::math::
log1p(avi->val_),avi) {
45 avi_->adj_ += adj_ / (1 +
avi_->val_);
49 class log1m_vari :
public op_v_vari {
51 log1m_vari(vari* avi) :
52 op_v_vari(stan::math::
log1p(-avi->val_),avi) {
55 avi_->adj_ += adj_ / (
avi_->val_ - 1);
59 class binary_log_loss_1_vari :
public op_v_vari {
61 binary_log_loss_1_vari(vari* avi) :
62 op_v_vari(-std::
log(avi->val_),avi) {
69 class binary_log_loss_0_vari :
public op_v_vari {
71 binary_log_loss_0_vari(vari* avi) :
72 op_v_vari(-stan::math::
log1p(-avi->val_),avi) {
75 avi_->adj_ += adj_ / (1.0 -
avi_->val_);
79 class fdim_vv_vari :
public op_vv_vari {
81 fdim_vv_vari(vari* avi, vari* bvi) :
82 op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
90 class fdim_vd_vari :
public op_v_vari {
92 fdim_vd_vari(vari* avi,
double b) :
93 op_v_vari(avi->val_ - b, avi) {
100 class fdim_dv_vari :
public op_v_vari {
102 fdim_dv_vari(
double a, vari* bvi) :
103 op_v_vari(a - bvi->val_, bvi) {
111 class fma_vvv_vari :
public op_vvv_vari {
113 fma_vvv_vari(vari* avi, vari* bvi, vari* cvi) :
114 op_vvv_vari(avi->val_ * bvi->val_ + cvi->val_,
124 class fma_vvd_vari :
public op_vv_vari {
126 fma_vvd_vari(vari* avi, vari* bvi,
double c) :
127 op_vv_vari(avi->val_ * bvi->val_ + c,
136 class fma_vdv_vari :
public op_vdv_vari {
138 fma_vdv_vari(vari* avi,
double b, vari* cvi) :
139 op_vdv_vari(avi->val_ * b + cvi->val_,
148 class fma_vdd_vari :
public op_vd_vari {
150 fma_vdd_vari(vari* avi,
double b,
double c) :
151 op_vd_vari(avi->val_ * b + c,
159 class fma_ddv_vari :
public op_v_vari {
161 fma_ddv_vari(
double a,
double b, vari* cvi) :
162 op_v_vari(a * b + cvi->val_,
171 class inv_logit_vari :
public op_v_vari {
173 inv_logit_vari(vari* avi) :
174 op_v_vari(math::
inv_logit(avi->val_),avi) {
177 avi_->adj_ += adj_ * val_ * (1.0 - val_);
181 class acosh_vari :
public op_v_vari {
183 acosh_vari(vari* avi) :
184 op_v_vari(boost::math::
acosh(avi->val_),avi) {
191 class asinh_vari :
public op_v_vari {
193 asinh_vari(vari* avi) :
194 op_v_vari(boost::math::
asinh(avi->val_),avi) {
201 class atanh_vari :
public op_v_vari {
203 atanh_vari(vari* avi) :
204 op_v_vari(boost::math::
atanh(avi->val_),avi) {
211 const double TWO_OVER_SQRT_PI = 2.0 /
std::sqrt(boost::math::constants::pi<double>());
213 class erf_vari :
public op_v_vari {
215 erf_vari(vari* avi) :
216 op_v_vari(boost::math::
erf(avi->val_),avi) {
223 const double NEG_TWO_OVER_SQRT_PI = - TWO_OVER_SQRT_PI;
225 class erfc_vari :
public op_v_vari {
227 erfc_vari(vari* avi) :
228 op_v_vari(boost::math::
erfc(avi->val_),avi) {
237 class exp2_vari :
public op_v_vari {
239 exp2_vari(vari* avi) :
240 op_v_vari(std::
pow(2.0,avi->val_),avi) {
247 class expm1_vari :
public op_v_vari {
249 expm1_vari(vari* avi) :
250 op_v_vari(std::
exp(avi->val_) - 1.0,avi) {
253 avi_->adj_ += adj_ * val_;
257 class hypot_vv_vari :
public op_vv_vari {
259 hypot_vv_vari(vari* avi, vari* bvi) :
260 op_vv_vari(boost::math::
hypot(avi->val_,bvi->val_),
264 avi_->adj_ += adj_ *
avi_->val_ / val_;
265 bvi_->adj_ += adj_ *
bvi_->val_ / val_;
269 class hypot_vd_vari :
public op_v_vari {
271 hypot_vd_vari(vari* avi,
double b) :
272 op_v_vari(boost::math::
hypot(avi->val_,b),
276 avi_->adj_ += adj_ *
avi_->val_ / val_;
282 class log2_vari :
public op_v_vari {
284 log2_vari(vari* avi) :
285 op_v_vari(stan::math::
log2(avi->val_),avi) {
288 avi_->adj_ += adj_ / (LOG2 *
avi_->val_);
292 class cbrt_vari :
public op_v_vari {
294 cbrt_vari(vari* avi) :
295 op_v_vari(boost::math::
cbrt(avi->val_),avi) {
298 avi_->adj_ += adj_ / (3.0 * val_ * val_);
303 class round_vari :
public vari {
305 round_vari(vari* avi) :
306 vari(boost::math::
round(avi->val_)) {
311 class trunc_vari :
public vari {
313 trunc_vari(vari* avi) :
314 vari(boost::math::
trunc(avi->val_)) {
318 class inv_cloglog_vari :
public op_v_vari {
320 inv_cloglog_vari(vari* avi) :
321 op_v_vari(stan::math::
inv_cloglog(avi->val_), avi) {
328 class Phi_vari :
public op_v_vari {
330 Phi_vari(vari* avi) :
331 op_v_vari(stan::math::
Phi(avi->val_), avi) {
334 static const double NEG_HALF = -0.5;
335 static const double INV_SQRT_TWO_PI
336 = 1.0 /
std::sqrt(2.0 * boost::math::constants::pi<double>());
341 inline double calculate_chain(
const double& x,
const double& val) {
345 double log_sum_exp_as_double(
const std::vector<var>& x) {
346 using std::numeric_limits;
349 double max = -numeric_limits<double>::infinity();
350 for (
size_t i = 0; i < x.size(); ++i)
354 for (
size_t i = 0; i < x.size(); ++i)
355 if (x[i] != -numeric_limits<double>::infinity())
356 sum +=
exp(x[i].val() -
max);
357 return max +
log(sum);
360 class log1p_exp_v_vari :
public op_v_vari {
362 log1p_exp_v_vari(vari* avi) :
363 op_v_vari(stan::math::
log1p_exp(avi->val_),
367 avi_->adj_ += adj_ * calculate_chain(
avi_->val_, val_);
370 class log_sum_exp_vv_vari :
public op_vv_vari {
372 log_sum_exp_vv_vari(vari* avi, vari* bvi) :
373 op_vv_vari(stan::math::
log_sum_exp(avi->val_, bvi->val_),
377 avi_->adj_ += adj_ * calculate_chain(
avi_->val_, val_);
378 bvi_->adj_ += adj_ * calculate_chain(
bvi_->val_, val_);
381 class log_sum_exp_vd_vari :
public op_vd_vari {
383 log_sum_exp_vd_vari(vari* avi,
double b) :
388 avi_->adj_ += adj_ * calculate_chain(
avi_->val_, val_);
391 class log_sum_exp_dv_vari :
public op_dv_vari {
393 log_sum_exp_dv_vari(
double a, vari* bvi) :
398 bvi_->adj_ += adj_ * calculate_chain(
bvi_->val_, val_);
401 class log_sum_exp_vector_vari :
public op_vector_vari {
403 log_sum_exp_vector_vari(
const std::vector<var>& x) :
404 op_vector_vari(log_sum_exp_as_double(x), x) {
407 for (
size_t i = 0; i <
size_; ++i) {
408 vis_[i]->adj_ += adj_ * calculate_chain(
vis_[i]->val_, val_);
413 class square_vari :
public op_v_vari {
415 square_vari(vari* avi) :
416 op_v_vari(avi->val_ * avi->val_,avi) {
419 avi_->adj_ += adj_ * 2.0 *
avi_->val_;
423 class multiply_log_vv_vari :
public op_vv_vari {
425 multiply_log_vv_vari(vari* avi, vari* bvi) :
426 op_vv_vari(stan::math::
multiply_log(avi->val_,bvi->val_),avi,bvi) {
431 if (
bvi_->val_ == 0.0 &&
avi_->val_ == 0)
432 bvi_->adj_ += adj_ * std::numeric_limits<double>::infinity();
437 class multiply_log_vd_vari :
public op_vd_vari {
439 multiply_log_vd_vari(vari* avi,
double b) :
440 op_vd_vari(stan::math::
multiply_log(avi->val_,b),avi,b) {
447 class multiply_log_dv_vari :
public op_dv_vari {
449 multiply_log_dv_vari(
double a, vari* bvi) :
450 op_dv_vari(stan::math::
multiply_log(a,bvi->val_),a,bvi) {
453 if (
bvi_->val_ == 0.0 &&
ad_ == 0.0)
454 bvi_->adj_ += adj_ * std::numeric_limits<double>::infinity();
465 double ibeta_hypergeometric_helper(
double a,
double b,
double z,
double precision=1
e-8,
double max_steps=1000) {
479 class ibeta_vvv_vari :
public op_vvv_vari {
481 ibeta_vvv_vari(vari* avi, vari* bvi, vari* xvi) :
482 op_vvv_vari(stan::math::
ibeta(avi->val_,bvi->val_,xvi->val_),avi,bvi,xvi) {
485 double a =
avi_->val_;
486 double b =
bvi_->val_;
487 double c =
cvi_->val_;
494 using boost::math::digamma;
496 using stan::agrad::ibeta_hypergeometric_helper;
498 (
log(c) - digamma(a) + digamma(a+b)) * val_ -
502 +
ibeta(b, a, 1-c) * (digamma(b) - digamma(a+b) -
log(1-c)));
504 boost::math::ibeta_derivative(a,b,c);
507 class ibeta_vvd_vari :
public op_vvd_vari {
509 ibeta_vvd_vari(vari* avi, vari* bvi,
double x) :
510 op_vvd_vari(stan::math::
ibeta(avi->val_,bvi->val_,x),avi,bvi,x) {
513 double a =
avi_->val_;
514 double b =
bvi_->val_;
522 using boost::math::digamma;
524 using stan::agrad::ibeta_hypergeometric_helper;
526 (
log(c) - digamma(a) + digamma(a+b)) * val_ -
530 +
ibeta(b, a, 1-c) * (digamma(b) - digamma(a+b) -
log(1-c)));
533 class ibeta_vdv_vari :
public op_vdv_vari {
535 ibeta_vdv_vari(vari* avi,
double b, vari* xvi) :
536 op_vdv_vari(stan::math::
ibeta(avi->val_,b,xvi->val_),avi,b,xvi) {
539 double a =
avi_->val_;
541 double c =
cvi_->val_;
548 using boost::math::digamma;
550 using stan::agrad::ibeta_hypergeometric_helper;
552 (
log(c) - digamma(a) + digamma(a+b)) * val_ -
555 boost::math::ibeta_derivative(a,b,c);
558 class ibeta_vdd_vari :
public op_vdd_vari {
560 ibeta_vdd_vari(vari* avi,
double b,
double x) :
561 op_vdd_vari(stan::math::
ibeta(avi->val_,b,x),avi,b,x) {
564 double a =
avi_->val_;
573 using boost::math::digamma;
575 using stan::agrad::ibeta_hypergeometric_helper;
577 (
log(c) - digamma(a) + digamma(a+b)) * val_ -
581 class ibeta_dvv_vari :
public op_dvv_vari {
583 ibeta_dvv_vari(
double a, vari* bvi, vari* xvi) :
584 op_dvv_vari(stan::math::
ibeta(a,bvi->val_,xvi->val_),a,bvi,xvi) {
588 double b =
bvi_->val_;
589 double c =
cvi_->val_;
596 using boost::math::digamma;
598 using stan::agrad::ibeta_hypergeometric_helper;
601 +
ibeta(b, a, 1-c) * (digamma(b) - digamma(a+b) -
log(1-c)));
603 boost::math::ibeta_derivative(a,b,c);
606 class ibeta_dvd_vari :
public op_dvd_vari {
608 ibeta_dvd_vari(
double a, vari* bvi,
double x) :
609 op_dvd_vari(stan::math::
ibeta(a,bvi->val_,x),a,bvi,x) {
613 double b =
bvi_->val_;
621 using boost::math::digamma;
623 using stan::agrad::ibeta_hypergeometric_helper;
626 +
ibeta(b, a, 1-c) * (digamma(b) - digamma(a+b) -
log(1-c)));
629 class ibeta_ddv_vari :
public op_ddv_vari {
631 ibeta_ddv_vari(
double a,
double b, vari* xvi) :
632 op_ddv_vari(stan::math::
ibeta(a,b,xvi->val_),a,b,xvi) {
637 double c =
cvi_->val_;
640 boost::math::ibeta_derivative(a,b,c);
662 return var(
new acosh_vari(a.
vi_));
678 return var(
new asinh_vari(a.
vi_));
694 return var(
new atanh_vari(a.
vi_));
710 return var(
new erf_vari(a.
vi_));
726 return var(
new erfc_vari(a.
vi_));
742 return var(
new exp2_vari(a.
vi_));
758 return var(
new expm1_vari(a.
vi_));
773 return var(
new lgamma_vari(lgamma_a, a.
vi_));
787 return var(
new log1p_vari(a.
vi_));
801 return var(
new log1m_vari(a.
vi_));
851 return var(
new fma_vvd_vari(a.
vi_,b.
vi_,c));
875 return var(
new fma_vdv_vari(a.
vi_,b,c.
vi_));
897 return var(
new fma_vdd_vari(a.
vi_,b,c));
919 return var(
new fma_vdd_vari(b.
vi_,a,c));
941 return var(
new fma_ddv_vari(a,b,c.
vi_));
965 return var(
new fma_vdv_vari(b.
vi_,a,c.
vi_));
1109 return var(
new hypot_vv_vari(a.
vi_,b.
vi_));
1128 return var(
new hypot_vd_vari(a.
vi_,b));
1147 return var(
new hypot_vd_vari(b.
vi_,a));
1163 return var(
new log2_vari(a.
vi_));
1179 return var(
new cbrt_vari(a.
vi_));
1197 return var(
new round_vari(a.
vi_));
1214 return var(
new trunc_vari(a.
vi_));
1244 return var(
new fdim_vv_vari(a.
vi_,b.
vi_));
1269 ?
var(
new fdim_dv_vari(a,b.
vi_))
1292 ?
var(
new fdim_vd_vari(a.
vi_,b))
1293 : var(
new vari(0.0));
1313 return var(
new tgamma_vari(a.
vi_));
1351 return var(
new inv_cloglog_vari(a.
vi_));
1367 return var(
new Phi_vari(a.
vi_));
1383 return var(
new inv_logit_vari(a.
vi_));
1404 ?
var(
new binary_log_loss_0_vari(y_hat.
vi_))
1405 :
var(
new binary_log_loss_1_vari(y_hat.
vi_));
1413 return var(
new log1p_exp_v_vari(a.
vi_));
1421 return var(
new log_sum_exp_vv_vari(a.
vi_, b.
vi_));
1428 return var(
new log_sum_exp_vd_vari(a.
vi_, b));
1435 return var(
new log_sum_exp_dv_vari(a, b.
vi_));
1441 return var(
new log_sum_exp_vector_vari(x));
1454 return var(
new square_vari(x.
vi_));
1471 return var(
new multiply_log_vv_vari(a.
vi_,b.
vi_));
1484 return var(
new multiply_log_vd_vari(a.
vi_,b));
1500 return var(
new multiply_log_dv_vari(a,b.
vi_));
1513 return c ? y_true : y_false;
1543 return var(y_false);