1 #ifndef __STAN__AGRAD__AGRAD_HPP__
2 #define __STAN__AGRAD__AGRAD_HPP__
19 struct var_allocator {
22 inline void* alloc(
size_t nbytes) {
25 inline void recover() {
33 #ifdef AGRAD_THREAD_SAFE
36 var_allocator* allocator_;
84 allocator_->var_stack_.push_back(
this);
104 static inline void*
operator new(
size_t nbytes) {
106 allocator_ =
new var_allocator();
107 return allocator_->alloc(nbytes);
114 return allocator_->recover();
137 std::vector<vari*>::iterator it = allocator_->var_stack_.end();
138 std::vector<vari*>::iterator begin = allocator_->var_stack_.begin();
140 for (; (it >= begin) && (*it != vi); --it)
144 for (; it >= begin; --it)
152 class op_v_vari :
public vari {
156 op_v_vari(
double f, vari* avi) :
162 class op_vv_vari :
public vari {
167 op_vv_vari(
double f, vari* avi, vari* bvi):
174 class op_vd_vari :
public vari {
179 op_vd_vari(
double f, vari* avi,
double b) :
186 class op_dv_vari :
public vari {
191 op_dv_vari(
double f,
double a, vari* bvi) :
198 class op_vvv_vari :
public vari {
204 op_vvv_vari(
double f, vari* avi, vari* bvi, vari* cvi) :
212 class op_vvd_vari :
public vari {
218 op_vvd_vari(
double f, vari* avi, vari* bvi,
double c) :
226 class op_vdv_vari :
public vari {
232 op_vdv_vari(
double f, vari* avi,
double b, vari* cvi) :
240 class op_vdd_vari :
public vari {
246 op_vdd_vari(
double f, vari* avi,
double b,
double c) :
254 class op_dvv_vari :
public vari {
260 op_dvv_vari(
double f,
double a, vari* bvi, vari* cvi) :
268 class op_dvd_vari :
public vari {
274 op_dvd_vari(
double f,
double a, vari* bvi,
double c) :
282 class op_ddv_vari :
public vari {
288 op_ddv_vari(
double f,
double a,
double b, vari* cvi) :
296 class neg_vari :
public op_v_vari {
298 neg_vari(vari* avi) :
299 op_v_vari(-(avi->val_), avi) {
307 class add_vv_vari :
public op_vv_vari {
309 add_vv_vari(vari* avi, vari* bvi) :
310 op_vv_vari(avi->val_ + bvi->val_, avi, bvi) {
318 class add_vd_vari :
public op_vd_vari {
320 add_vd_vari(vari* avi,
double b) :
321 op_vd_vari(avi->val_ + b, avi, b) {
328 class increment_vari :
public op_v_vari {
330 increment_vari(vari* avi) :
331 op_v_vari(avi->val_ + 1.0, avi) {
338 class decrement_vari :
public op_v_vari {
340 decrement_vari(vari* avi) :
341 op_v_vari(avi->val_ - 1.0, avi) {
348 class subtract_vv_vari :
public op_vv_vari {
350 subtract_vv_vari(vari* avi, vari* bvi) :
351 op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
359 class subtract_vd_vari :
public op_vd_vari {
361 subtract_vd_vari(vari* avi,
double b) :
362 op_vd_vari(avi->val_ - b, avi, b) {
369 class subtract_dv_vari :
public op_dv_vari {
371 subtract_dv_vari(
double a, vari* bvi) :
372 op_dv_vari(a - bvi->val_, a, bvi) {
379 class multiply_vv_vari :
public op_vv_vari {
381 multiply_vv_vari(vari* avi, vari* bvi) :
382 op_vv_vari(avi->val_ * bvi->val_, avi, bvi) {
390 class multiply_vd_vari :
public op_vd_vari {
392 multiply_vd_vari(vari* avi,
double b) :
393 op_vd_vari(avi->val_ * b, avi, b) {
401 class divide_vv_vari :
public op_vv_vari {
403 divide_vv_vari(vari* avi, vari* bvi) :
404 op_vv_vari(avi->val_ / bvi->val_, avi, bvi) {
412 class divide_vd_vari :
public op_vd_vari {
414 divide_vd_vari(vari* avi,
double b) :
415 op_vd_vari(avi->val_ / b, avi, b) {
422 class divide_dv_vari :
public op_dv_vari {
424 divide_dv_vari(
double a, vari* bvi) :
425 op_dv_vari(a / bvi->val_, a, bvi) {
432 class exp_vari :
public op_v_vari {
434 exp_vari(vari* avi) :
435 op_v_vari(
std::
exp(avi->val_),avi) {
438 avi_->adj_ += adj_ * val_;
442 class log_vari :
public op_v_vari {
444 log_vari(vari* avi) :
445 op_v_vari(
std::
log(avi->val_),avi) {
454 class log10_vari :
public op_v_vari {
457 log10_vari(vari* avi) :
458 op_v_vari(
std::
log10(avi->val_),avi),
466 class sqrt_vari :
public op_v_vari {
468 sqrt_vari(vari* avi) :
469 op_v_vari(
std::
sqrt(avi->val_),avi) {
472 avi_->adj_ += adj_ / (2.0 * val_);
476 class pow_vv_vari :
public op_vv_vari {
478 pow_vv_vari(vari* avi, vari* bvi) :
479 op_vv_vari(
std::
pow(avi->val_,bvi->val_),avi,bvi) {
482 if (
avi_->val_ == 0.0)
return;
488 class pow_vd_vari :
public op_vd_vari {
490 pow_vd_vari(vari* avi,
double b) :
491 op_vd_vari(
std::
pow(avi->val_,b),avi,b) {
494 if (
avi_->val_ == 0.0)
return;
499 class pow_dv_vari :
public op_dv_vari {
501 pow_dv_vari(
double a, vari* bvi) :
502 op_dv_vari(
std::
pow(a,bvi->val_),a,bvi) {
505 if (
ad_ == 0.0)
return;
510 class cos_vari :
public op_v_vari {
512 cos_vari(vari* avi) :
513 op_v_vari(
std::
cos(avi->val_),avi) {
520 class sin_vari :
public op_v_vari {
522 sin_vari(vari* avi) :
523 op_v_vari(
std::
sin(avi->val_),avi) {
530 class tan_vari :
public op_v_vari {
532 tan_vari(vari* avi) :
533 op_v_vari(
std::
tan(avi->val_),avi) {
536 avi_->adj_ += adj_ * (1.0 + val_ * val_);
540 class acos_vari :
public op_v_vari {
542 acos_vari(vari* avi) :
543 op_v_vari(
std::
acos(avi->val_),avi) {
550 class asin_vari :
public op_v_vari {
552 asin_vari(vari* avi) :
553 op_v_vari(
std::
asin(avi->val_),avi) {
560 class atan_vari :
public op_v_vari {
562 atan_vari(vari* avi) :
563 op_v_vari(
std::
atan(avi->val_),avi) {
570 class atan2_vv_vari :
public op_vv_vari {
572 atan2_vv_vari(vari* avi, vari* bvi) :
573 op_vv_vari(
std::
atan2(avi->val_,bvi->val_),avi,bvi) {
576 double a_sq_plus_b_sq = (
avi_->val_ *
avi_->val_) + (
bvi_->val_ *
bvi_->val_);
577 avi_->adj_ +=
bvi_->val_ / a_sq_plus_b_sq;
578 bvi_->adj_ -=
avi_->val_ / a_sq_plus_b_sq;
582 class atan2_vd_vari :
public op_vd_vari {
584 atan2_vd_vari(vari* avi,
double b) :
585 op_vd_vari(
std::
atan2(avi->val_,b),avi,b) {
589 avi_->adj_ +=
bd_ / a_sq_plus_b_sq;
593 class atan2_dv_vari :
public op_dv_vari {
595 atan2_dv_vari(
double a, vari* bvi) :
596 op_dv_vari(
std::
atan2(a,bvi->val_),a,bvi) {
600 bvi_->adj_ -=
ad_ / a_sq_plus_b_sq;
604 class cosh_vari :
public op_v_vari {
606 cosh_vari(vari* avi) :
607 op_v_vari(
std::
cosh(avi->val_),avi) {
614 class sinh_vari :
public op_v_vari {
616 sinh_vari(vari* avi) :
617 op_v_vari(
std::
sinh(avi->val_),avi) {
624 class tanh_vari :
public op_v_vari {
626 tanh_vari(vari* avi) :
627 op_v_vari(
std::
tanh(avi->val_),avi) {
636 class floor_vari :
public vari {
638 floor_vari(vari* avi) :
643 class ceil_vari :
public vari {
645 ceil_vari(vari* avi) :
650 class fmod_vv_vari :
public op_vv_vari {
652 fmod_vv_vari(vari* avi, vari* bvi) :
653 op_vv_vari(
std::
fmod(avi->val_,bvi->val_),avi,bvi) {
657 bvi_->adj_ -= adj_ *
static_cast<int>(
avi_->val_ /
bvi_->val_);
661 class fmod_vd_vari :
public op_v_vari {
663 fmod_vd_vari(vari* avi,
double b) :
664 op_v_vari(
std::
fmod(avi->val_,b),avi) {
671 class fmod_dv_vari :
public op_dv_vari {
673 fmod_dv_vari(
double a, vari* bvi) :
674 op_dv_vari(
std::
fmod(a,bvi->val_),a,bvi) {
677 int d =
static_cast<int>(
ad_ /
bvi_->val_);
678 bvi_->adj_ -= adj_ * d;
741 vi_(new
vari(static_cast<double>(b))) {
751 vi_(new
vari(static_cast<double>(c))) {
761 vi_(new
vari(static_cast<double>(n))) {
771 vi_(new
vari(static_cast<double>(n))) {
781 vi_(new
vari(static_cast<double>(n))) {
791 vi_(new
vari(static_cast<double>(n))) {
801 vi_(new
vari(static_cast<double>(n))) {
810 var(
unsigned long int n) :
811 vi_(new
vari(static_cast<double>(n))) {
821 vi_(new
vari(static_cast<double>(x))) {
840 vi_(new
vari(static_cast<double>(x))) {
848 inline double val()
const {
865 std::vector<double>& g) {
868 for (
size_t i = 0U; i < x.size(); ++i)
907 vi_ =
new add_vv_vari(
vi_,b.vi_);
922 vi_ =
new add_vd_vari(
vi_,b);
938 vi_ =
new subtract_vv_vari(
vi_,b.vi_);
954 vi_ =
new subtract_vd_vari(
vi_,b);
970 vi_ =
new multiply_vv_vari(
vi_,b.vi_);
986 vi_ =
new multiply_vd_vari(
vi_,b);
1001 vi_ =
new divide_vv_vari(
vi_,b.vi_);
1017 vi_ =
new divide_vd_vari(
vi_,b);
1034 inline bool operator==(
const var& a,
const var& b) {
1035 return a.val() == b.val();
1048 return a.val() == b;
1060 return a == b.val();
1071 inline bool operator!=(
const var& a,
const var& b) {
1072 return a.val() != b.val();
1085 return a.val() != b;
1098 return a != b.val();
1108 inline bool operator<(
const var& a,
const var& b) {
1109 return a.val() < b.val();
1143 inline bool operator>(
const var& a,
const var& b) {
1144 return a.val() > b.val();
1180 inline bool operator<=(
const var& a,
const var& b) {
1181 return a.val() <= b.val();
1194 return a.val() <= b;
1207 return a <= b.val();
1219 inline bool operator>=(
const var& a,
const var& b) {
1220 return a.val() >= b.val();
1233 return a.val() >= b;
1246 return a >= b.val();
1298 return var(
new neg_vari(a.vi_));
1314 inline var
operator+(
const var& a,
const var& b) {
1315 return var(
new add_vv_vari(a.vi_,b.vi_));
1331 return var(
new add_vd_vari(a.vi_,b));
1346 return var(
new add_vd_vari(b.vi_,a));
1363 inline var
operator-(
const var& a,
const var& b) {
1364 return var(
new subtract_vv_vari(a.vi_,b.vi_));
1379 return var(
new subtract_vd_vari(a.vi_,b));
1394 return var(
new subtract_dv_vari(a,b.vi_));
1410 inline var
operator*(
const var& a,
const var& b) {
1411 return var(
new multiply_vv_vari(a.vi_,b.vi_));
1426 return var(
new multiply_vd_vari(a.vi_,b));
1441 return var(
new multiply_vd_vari(b.vi_,a));
1458 inline var
operator/(
const var& a,
const var& b) {
1459 return var(
new divide_vv_vari(a.vi_,b.vi_));
1474 return var(
new divide_vd_vari(a.vi_,b));
1489 return var(
new divide_dv_vari(a,b.vi_));
1502 a.
vi_ =
new increment_vari(a.vi_);
1521 a.vi_ =
new increment_vari(a.vi_);
1539 a.
vi_ =
new decrement_vari(a.vi_);
1558 a.vi_ =
new decrement_vari(a.vi_);
1570 inline var
exp(
const var& a) {
1571 return var(
new exp_vari(a.vi_));
1584 inline var
log(
const var& a) {
1585 return var(
new log_vari(a.vi_));
1598 inline var
log10(
const var& a) {
1599 return var(
new log10_vari(a.vi_));
1615 inline var
sqrt(
const var& a) {
1616 return var(
new sqrt_vari(a.vi_));
1632 inline var
pow(
const var& base,
const var& exponent) {
1633 return var(
new pow_vv_vari(base.vi_,exponent.vi_));
1649 return var(
new pow_vd_vari(base.vi_,exponent));
1665 return var(
new pow_dv_vari(base,exponent.vi_));
1681 inline var
cos(
const var& a) {
1682 return var(
new cos_vari(a.vi_));
1695 inline var
sin(
const var& a) {
1696 return var(
new sin_vari(a.vi_));
1709 inline var
tan(
const var& a) {
1710 return var(
new tan_vari(a.vi_));
1724 inline var
acos(
const var& a) {
1725 return var(
new acos_vari(a.vi_));
1739 inline var
asin(
const var& a) {
1740 return var(
new asin_vari(a.vi_));
1754 inline var
atan(
const var& a) {
1755 return var(
new atan_vari(a.vi_));
1773 inline var
atan2(
const var& a,
const var& b) {
1774 return var(
new atan2_vv_vari(a.vi_,b.vi_));
1790 return var(
new atan2_vd_vari(a.vi_,b));
1806 return var(
new atan2_dv_vari(a,b.vi_));
1821 inline var
cosh(
const var& a) {
1822 return var(
new cosh_vari(a.vi_));
1835 inline var
sinh(
const var& a) {
1836 return var(
new sinh_vari(a.vi_));
1849 inline var
tanh(
const var& a) {
1850 return var(
new tanh_vari(a.vi_));
1872 inline var
fabs(
const var& a) {
1877 return var(
new neg_vari(a.vi_));
1878 return var(
new vari(0.0));
1899 inline var
floor(
const var& a) {
1900 return var(
new floor_vari(a.vi_));
1921 inline var
ceil(
const var& a) {
1922 return var(
new ceil_vari(a.vi_));
1942 inline var
fmod(
const var& a,
const var& b) {
1943 return var(
new fmod_vv_vari(a.vi_,b.vi_));
1960 return var(
new fmod_vd_vari(a.vi_,b));
1977 return var(
new fmod_dv_vari(a,b.vi_));
1997 inline var
abs(
const var& a) {
2002 return var(
new neg_vari(a.vi_));
2003 return var(
new vari(0.0));
Independent (input) and dependent (output) variables for gradients.
void grad()
Compute gradients of this dependent variable with respect to all variables on which it depends.
var & operator*=(const double &b)
The compound multiply/assignment operator for scalars (C++).
var(unsigned int n)
Construct a variable by static casting the specified value to double.
var()
Construct a variable for later assignment.
var(char c)
Construct a variable by static casting the specified value to double.
var(float x)
Construct a variable by static casting the specified value to double.
var & operator*=(const var &b)
The compound multiply/assignment operator for variables (C++).
var(long int n)
Construct a variable by static casting the specified value to double.
var(double x)
Construct a variable with the specified value.
var & operator+=(const double &b)
The compound add/assignment operator for scalars (C++).
var & operator-=(const var &b)
The compound subtract/assignment operator for variables (C++).
var & operator+=(const var &b)
The compound add/assignment operator for variables (C++).
var(int n)
Construct a variable by static casting the specified value to double.
var(long double x)
Construct a variable by static casting the specified value to double.
var(unsigned short n)
Construct a variable by static casting the specified value to double.
var(unsigned long int n)
Construct a variable by static casting the specified value to double.
var & operator/=(const double &b)
The compound divide/assignment operator for scalars (C++).
void grad(std::vector< var > &x, std::vector< double > &g)
Compute the gradient of this dependent variable with respect to the specified vector of independent variables, storing the result in the specified vector of doubles.
double val() const
Return the value of this variable.
var(short n)
Construct a variable by static casting the specified value to double.
var(vari *vi)
Construct a variable from a pointer to a variable implementation.
var & operator/=(const var &b)
The compound divide/assignment operator for variables (C++).
var & operator-=(const double &b)
The compound subtract/assignment operator for scalars (C++).
var(bool b)
Construct a variable by static casting the specified value to double.
vari * vi_
Pointer to the implementation of this variable.
The variable implementation base class.
const double val_
The value of this variable.
double adj_
The adjoint of this variable, which is the partial derivative of the root variable with respect to this variable (accumulated during the reverse pass).
static void recover_memory()
Recover memory used for all variables for reuse.
static void free_memory()
Return all memory used for gradients back to the system.
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends.
vari(const double x)
Construct a variable implementation from a value.
void free_all()
Free all memory used by the stack allocator, other than the initial block allocation, back to the system.
void recover_all()
Recover all the memory used by the stack allocator.
void * alloc(size_t len)
Return a newly allocated block of memory of the appropriate size managed by the stack allocator.
std::vector< chainable * > var_stack_
var cosh(const var &a)
Return the hyperbolic cosine of the specified variable (cmath).
var atan(const var &a)
Return the principal value of the arc tangent, in radians, of the specified variable (cmath).
var sqrt(const var &a)
Return the square root of the specified variable (cmath).
var fabs(const var &a)
Return the absolute value of the variable (cmath).
bool operator<=(const var &a, const var &b)
Less than or equal operator comparing two variables' values (C++).
static void grad(chainable *vi)
Compute the gradient for all variables starting from the specified root variable implementation.
var abs(const var &a)
Return the absolute value of the variable (std).
bool operator>=(const var &a, const var &b)
Greater than or equal operator comparing two variables' values (C++).
var operator-(const var &a)
Unary negation operator for variables (C++).
var asin(const var &a)
Return the principal value of the arc sine, in radians, of the specified variable (cmath).
var operator+(const var &a)
Unary plus operator for variables (C++).
var acos(const var &a)
Return the principal value of the arc cosine of a variable, in radians (cmath).
var ceil(const var &a)
Return the ceiling of the specified variable (cmath).
bool operator!=(const var &a, const var &b)
Inequality operator comparing two variables' values (C++).
var log10(const var &a)
Return the base 10 log of the specified variable (cmath).
var cos(const var &a)
Return the cosine of a radian-scaled variable (cmath).
var operator/(const var &a, const var &b)
Division operator for two variables (C++).
var sin(const var &a)
Return the sine of a radian-scaled variable (cmath).
bool operator<(const var &a, const var &b)
Less than operator comparing variables' values (C++).
bool operator==(const var &a, const var &b)
Equality operator comparing two variables' values (C++).
var tanh(const var &a)
Return the hyperbolic tangent of the specified variable (cmath).
var floor(const var &a)
Return the floor of the specified variable (cmath).
memory::stack_alloc memalloc_
var pow(const var &base, const var &exponent)
Return the base raised to the power of the exponent (cmath).
var operator*(const var &a, const var &b)
Multiplication operator for two variables (C++).
var & operator++(var &a)
Prefix increment operator for variables (C++).
var log(const var &a)
Return the natural log of the specified variable (cmath).
var fmod(const var &a, const var &b)
Return the floating point remainder after dividing the first variable by the second (cmath).
bool operator>(const var &a, const var &b)
Greater than operator comparing variables' values (C++).
var tan(const var &a)
Return the tangent of a radian-scaled variable (cmath).
var atan2(const var &a, const var &b)
Return the principal value of the arc tangent, in radians, of the first variable divided by the second variable (cmath).
bool operator!(const var &a)
Prefix logical negation for the value of variables (C++).
var exp(const var &a)
Return the exponentiation of the specified variable (cmath).
var sinh(const var &a)
Return the hyperbolic sine of the specified variable (cmath).
var & operator--(var &a)
Prefix decrement operator for variables (C++).
const double LOG_10
The natural logarithm of 10, $\log 10$.
Probability, optimization and sampling library.
Template specification of functions in std for Stan.