1 #ifndef __STAN__AGRAD__AGRAD_HPP__
2 #define __STAN__AGRAD__AGRAD_HPP__
19 struct var_allocator {
22 inline void* alloc(
size_t nbytes) {
25 inline void recover() {
33 #ifdef AGRAD_THREAD_SAFE
36 var_allocator* allocator_;
84 allocator_->var_stack_.push_back(
this);
104 static inline void*
operator new(
size_t nbytes) {
106 allocator_ =
new var_allocator();
107 return allocator_->alloc(nbytes);
114 return allocator_->recover();
137 std::vector<vari*>::iterator it = allocator_->var_stack_.end();
138 std::vector<vari*>::iterator begin = allocator_->var_stack_.begin();
140 for (; (it >= begin) && (*it != vi); --it)
144 for (; it >= begin; --it)
152 class op_v_vari :
public vari {
156 op_v_vari(
double f, vari* avi) :
162 class op_vv_vari :
public vari {
167 op_vv_vari(
double f, vari* avi, vari* bvi):
174 class op_vd_vari :
public vari {
179 op_vd_vari(
double f, vari* avi,
double b) :
186 class op_dv_vari :
public vari {
191 op_dv_vari(
double f,
double a, vari* bvi) :
198 class op_vvv_vari :
public vari {
204 op_vvv_vari(
double f, vari* avi, vari* bvi, vari* cvi) :
212 class op_vvd_vari :
public vari {
218 op_vvd_vari(
double f, vari* avi, vari* bvi,
double c) :
226 class op_vdv_vari :
public vari {
232 op_vdv_vari(
double f, vari* avi,
double b, vari* cvi) :
240 class op_vdd_vari :
public vari {
246 op_vdd_vari(
double f, vari* avi,
double b,
double c) :
254 class op_dvv_vari :
public vari {
260 op_dvv_vari(
double f,
double a, vari* bvi, vari* cvi) :
268 class op_dvd_vari :
public vari {
274 op_dvd_vari(
double f,
double a, vari* bvi,
double c) :
282 class op_ddv_vari :
public vari {
288 op_ddv_vari(
double f,
double a,
double b, vari* cvi) :
296 class neg_vari :
public op_v_vari {
298 neg_vari(vari* avi) :
299 op_v_vari(-(avi->val_), avi) {
307 class add_vv_vari :
public op_vv_vari {
309 add_vv_vari(vari* avi, vari* bvi) :
310 op_vv_vari(avi->val_ + bvi->val_, avi, bvi) {
318 class add_vd_vari :
public op_vd_vari {
320 add_vd_vari(vari* avi,
double b) :
321 op_vd_vari(avi->val_ + b, avi, b) {
328 class increment_vari :
public op_v_vari {
330 increment_vari(vari* avi) :
331 op_v_vari(avi->val_ + 1.0, avi) {
338 class decrement_vari :
public op_v_vari {
340 decrement_vari(vari* avi) :
341 op_v_vari(avi->val_ - 1.0, avi) {
348 class subtract_vv_vari :
public op_vv_vari {
350 subtract_vv_vari(vari* avi, vari* bvi) :
351 op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
359 class subtract_vd_vari :
public op_vd_vari {
361 subtract_vd_vari(vari* avi,
double b) :
362 op_vd_vari(avi->val_ - b, avi, b) {
369 class subtract_dv_vari :
public op_dv_vari {
371 subtract_dv_vari(
double a, vari* bvi) :
372 op_dv_vari(a - bvi->val_, a, bvi) {
379 class multiply_vv_vari :
public op_vv_vari {
381 multiply_vv_vari(vari* avi, vari* bvi) :
382 op_vv_vari(avi->val_ * bvi->val_, avi, bvi) {
390 class multiply_vd_vari :
public op_vd_vari {
392 multiply_vd_vari(vari* avi,
double b) :
393 op_vd_vari(avi->val_ * b, avi, b) {
401 class divide_vv_vari :
public op_vv_vari {
403 divide_vv_vari(vari* avi, vari* bvi) :
404 op_vv_vari(avi->val_ / bvi->val_, avi, bvi) {
412 class divide_vd_vari :
public op_vd_vari {
414 divide_vd_vari(vari* avi,
double b) :
415 op_vd_vari(avi->val_ / b, avi, b) {
422 class divide_dv_vari :
public op_dv_vari {
424 divide_dv_vari(
double a, vari* bvi) :
425 op_dv_vari(a / bvi->val_, a, bvi) {
432 class exp_vari :
public op_v_vari {
434 exp_vari(vari* avi) :
435 op_v_vari(std::
exp(avi->val_),avi) {
438 avi_->adj_ += adj_ * val_;
442 class log_vari :
public op_v_vari {
444 log_vari(vari* avi) :
445 op_v_vari(std::
log(avi->val_),avi) {
454 class log10_vari :
public op_v_vari {
457 log10_vari(vari* avi) :
458 op_v_vari(std::
log10(avi->val_),avi),
466 class sqrt_vari :
public op_v_vari {
468 sqrt_vari(vari* avi) :
469 op_v_vari(std::
sqrt(avi->val_),avi) {
472 avi_->adj_ += adj_ / (2.0 * val_);
476 class pow_vv_vari :
public op_vv_vari {
478 pow_vv_vari(vari* avi, vari* bvi) :
479 op_vv_vari(std::
pow(avi->val_,bvi->val_),avi,bvi) {
482 if (
avi_->val_ == 0.0)
return;
488 class pow_vd_vari :
public op_vd_vari {
490 pow_vd_vari(vari* avi,
double b) :
491 op_vd_vari(std::
pow(avi->val_,b),avi,b) {
494 if (
avi_->val_ == 0.0)
return;
499 class pow_dv_vari :
public op_dv_vari {
501 pow_dv_vari(
double a, vari* bvi) :
502 op_dv_vari(std::
pow(a,bvi->val_),a,bvi) {
505 if (
ad_ == 0.0)
return;
510 class cos_vari :
public op_v_vari {
512 cos_vari(vari* avi) :
513 op_v_vari(std::
cos(avi->val_),avi) {
520 class sin_vari :
public op_v_vari {
522 sin_vari(vari* avi) :
523 op_v_vari(std::
sin(avi->val_),avi) {
530 class tan_vari :
public op_v_vari {
532 tan_vari(vari* avi) :
533 op_v_vari(std::
tan(avi->val_),avi) {
536 avi_->adj_ += adj_ * (1.0 + val_ * val_);
540 class acos_vari :
public op_v_vari {
542 acos_vari(vari* avi) :
543 op_v_vari(std::
acos(avi->val_),avi) {
550 class asin_vari :
public op_v_vari {
552 asin_vari(vari* avi) :
553 op_v_vari(std::
asin(avi->val_),avi) {
560 class atan_vari :
public op_v_vari {
562 atan_vari(vari* avi) :
563 op_v_vari(std::
atan(avi->val_),avi) {
570 class atan2_vv_vari :
public op_vv_vari {
572 atan2_vv_vari(vari* avi, vari* bvi) :
573 op_vv_vari(std::
atan2(avi->val_,bvi->val_),avi,bvi) {
576 double a_sq_plus_b_sq = (
avi_->val_ *
avi_->val_) + (
bvi_->val_ *
bvi_->val_);
577 avi_->adj_ +=
bvi_->val_ / a_sq_plus_b_sq;
578 bvi_->adj_ -=
avi_->val_ / a_sq_plus_b_sq;
582 class atan2_vd_vari :
public op_vd_vari {
584 atan2_vd_vari(vari* avi,
double b) :
585 op_vd_vari(std::
atan2(avi->val_,b),avi,b) {
589 avi_->adj_ += bd_ / a_sq_plus_b_sq;
593 class atan2_dv_vari :
public op_dv_vari {
595 atan2_dv_vari(
double a, vari* bvi) :
596 op_dv_vari(std::
atan2(a,bvi->val_),a,bvi) {
600 bvi_->adj_ -=
ad_ / a_sq_plus_b_sq;
604 class cosh_vari :
public op_v_vari {
606 cosh_vari(vari* avi) :
607 op_v_vari(std::
cosh(avi->val_),avi) {
614 class sinh_vari :
public op_v_vari {
616 sinh_vari(vari* avi) :
617 op_v_vari(std::
sinh(avi->val_),avi) {
624 class tanh_vari :
public op_v_vari {
626 tanh_vari(vari* avi) :
627 op_v_vari(std::
tanh(avi->val_),avi) {
636 class floor_vari :
public vari {
638 floor_vari(vari* avi) :
639 vari(std::
floor(avi->val_)) {
643 class ceil_vari :
public vari {
645 ceil_vari(vari* avi) :
646 vari(std::
ceil(avi->val_)) {
650 class fmod_vv_vari :
public op_vv_vari {
652 fmod_vv_vari(vari* avi, vari* bvi) :
653 op_vv_vari(std::
fmod(avi->val_,bvi->val_),avi,bvi) {
657 bvi_->adj_ -= adj_ *
static_cast<int>(
avi_->val_ /
bvi_->val_);
661 class fmod_vd_vari :
public op_v_vari {
663 fmod_vd_vari(vari* avi,
double b) :
664 op_v_vari(std::
fmod(avi->val_,b),avi) {
671 class fmod_dv_vari :
public op_dv_vari {
673 fmod_dv_vari(
double a, vari* bvi) :
674 op_dv_vari(std::
fmod(a,bvi->val_),a,bvi) {
677 int d =
static_cast<int>(
ad_ /
bvi_->val_);
678 bvi_->adj_ -= adj_ * d;
741 vi_(new
vari(static_cast<double>(b))) {
751 vi_(new
vari(static_cast<double>(c))) {
761 vi_(new
vari(static_cast<double>(n))) {
771 vi_(new
vari(static_cast<double>(n))) {
781 vi_(new
vari(static_cast<double>(n))) {
791 vi_(new
vari(static_cast<double>(n))) {
801 vi_(new
vari(static_cast<double>(n))) {
810 var(
unsigned long int n) :
811 vi_(new
vari(static_cast<double>(n))) {
821 vi_(new
vari(static_cast<double>(x))) {
840 vi_(new
vari(static_cast<double>(x))) {
848 inline double val()
const {
865 std::vector<double>& g) {
868 for (
size_t i = 0U; i < x.size(); ++i)
907 vi_ =
new add_vv_vari(
vi_,b.vi_);
922 vi_ =
new add_vd_vari(
vi_,b);
938 vi_ =
new subtract_vv_vari(
vi_,b.vi_);
954 vi_ =
new subtract_vd_vari(
vi_,b);
970 vi_ =
new multiply_vv_vari(
vi_,b.vi_);
986 vi_ =
new multiply_vd_vari(
vi_,b);
1001 vi_ =
new divide_vv_vari(
vi_,b.vi_);
1017 vi_ =
new divide_vd_vari(
vi_,b);
1034 inline bool operator==(
const var& a,
const var& b) {
1035 return a.val() == b.val();
1048 return a.val() == b;
1060 return a == b.val();
1071 inline bool operator!=(
const var& a,
const var& b) {
1072 return a.val() != b.val();
1085 return a.val() != b;
1098 return a != b.val();
1108 inline bool operator<(
const var& a,
const var& b) {
1109 return a.val() < b.val();
1143 inline bool operator>(
const var& a,
const var& b) {
1144 return a.val() > b.val();
1180 inline bool operator<=(
const var& a,
const var& b) {
1181 return a.val() <= b.val();
1194 return a.val() <= b;
1207 return a <= b.val();
1219 inline bool operator>=(
const var& a,
const var& b) {
1220 return a.val() >= b.val();
1233 return a.val() >= b;
1246 return a >= b.val();
1298 return var(
new neg_vari(a.vi_));
1314 inline var
operator+(
const var& a,
const var& b) {
1315 return var(
new add_vv_vari(a.vi_,b.vi_));
1331 return var(
new add_vd_vari(a.vi_,b));
1346 return var(
new add_vd_vari(b.vi_,a));
1363 inline var
operator-(
const var& a,
const var& b) {
1364 return var(
new subtract_vv_vari(a.vi_,b.vi_));
1379 return var(
new subtract_vd_vari(a.vi_,b));
1394 return var(
new subtract_dv_vari(a,b.vi_));
1410 inline var
operator*(
const var& a,
const var& b) {
1411 return var(
new multiply_vv_vari(a.vi_,b.vi_));
1426 return var(
new multiply_vd_vari(a.vi_,b));
1441 return var(
new multiply_vd_vari(b.vi_,a));
1458 inline var
operator/(
const var& a,
const var& b) {
1459 return var(
new divide_vv_vari(a.vi_,b.vi_));
1474 return var(
new divide_vd_vari(a.vi_,b));
1489 return var(
new divide_dv_vari(a,b.vi_));
1502 a.vi_ =
new increment_vari(a.vi_);
1521 a.vi_ =
new increment_vari(a.vi_);
1539 a.vi_ =
new decrement_vari(a.vi_);
1558 a.vi_ =
new decrement_vari(a.vi_);
1570 inline var
exp(
const var& a) {
1571 return var(
new exp_vari(a.vi_));
1584 inline var
log(
const var& a) {
1585 return var(
new log_vari(a.vi_));
1598 inline var
log10(
const var& a) {
1599 return var(
new log10_vari(a.vi_));
1615 inline var
sqrt(
const var& a) {
1616 return var(
new sqrt_vari(a.vi_));
1632 inline var
pow(
const var& base,
const var& exponent) {
1633 return var(
new pow_vv_vari(base.vi_,exponent.vi_));
1649 return var(
new pow_vd_vari(base.vi_,exponent));
1665 return var(
new pow_dv_vari(base,exponent.vi_));
1681 inline var
cos(
const var& a) {
1682 return var(
new cos_vari(a.vi_));
1695 inline var
sin(
const var& a) {
1696 return var(
new sin_vari(a.vi_));
1709 inline var
tan(
const var& a) {
1710 return var(
new tan_vari(a.vi_));
1724 inline var
acos(
const var& a) {
1725 return var(
new acos_vari(a.vi_));
1739 inline var
asin(
const var& a) {
1740 return var(
new asin_vari(a.vi_));
1754 inline var
atan(
const var& a) {
1755 return var(
new atan_vari(a.vi_));
1773 inline var
atan2(
const var& a,
const var& b) {
1774 return var(
new atan2_vv_vari(a.vi_,b.vi_));
1790 return var(
new atan2_vd_vari(a.vi_,b));
1806 return var(
new atan2_dv_vari(a,b.vi_));
1821 inline var
cosh(
const var& a) {
1822 return var(
new cosh_vari(a.vi_));
1835 inline var
sinh(
const var& a) {
1836 return var(
new sinh_vari(a.vi_));
1849 inline var
tanh(
const var& a) {
1850 return var(
new tanh_vari(a.vi_));
1872 inline var
fabs(
const var& a) {
1877 return var(
new neg_vari(a.vi_));
1878 return var(
new vari(0.0));
1899 inline var
floor(
const var& a) {
1900 return var(
new floor_vari(a.vi_));
1921 inline var
ceil(
const var& a) {
1922 return var(
new ceil_vari(a.vi_));
1942 inline var
fmod(
const var& a,
const var& b) {
1943 return var(
new fmod_vv_vari(a.vi_,b.vi_));
1960 return var(
new fmod_vd_vari(a.vi_,b));
1977 return var(
new fmod_dv_vari(a,b.vi_));
1997 inline var
abs(
const var& a) {
2002 return var(
new neg_vari(a.vi_));
2003 return var(
new vari(0.0));