1 #ifndef __STAN__AGRAD__AGRAD_HPP__
2 #define __STAN__AGRAD__AGRAD_HPP__
28 static void grad(chainable* vi);
85 static inline void*
operator new(
size_t nbytes) {
147 throw std::logic_error(
"vari destruction handled automatically");
176 return os << v <<
" " << v->
val_ <<
" : " << v->
adj_;
191 o <<
"STACK, size=" <<
var_stack_.size() << std::endl;
192 for (
size_t i = 0; i <
var_stack_.size(); ++i)
239 return (
vi_ ==
static_cast<vari*
>(0U));
269 vi_(new
vari(static_cast<double>(b))) {
279 vi_(new
vari(static_cast<double>(c))) {
289 vi_(new
vari(static_cast<double>(n))) {
299 vi_(new
vari(static_cast<double>(n))) {
309 vi_(new
vari(static_cast<double>(n))) {
319 vi_(new
vari(static_cast<double>(n))) {
329 vi_(new
vari(static_cast<double>(n))) {
338 var(
unsigned long int n) :
339 vi_(new
vari(static_cast<double>(n))) {
348 var(
unsigned long long n) :
349 vi_(new
vari(static_cast<double>(n))) {
359 vi_(new
vari(static_cast<double>(n))) {
369 vi_(new
vari(static_cast<double>(x))) {
388 vi_(new
vari(static_cast<double>(x))) {
396 inline double val()
const {
408 inline double adj()
const {
425 std::vector<double>& g) {
428 for (
size_t i = 0; i < x.size(); ++i)
598 return os << v.
val() <<
':' << v.
adj();
604 class op_v_vari :
public vari {
608 op_v_vari(
double f, vari* avi) :
614 class op_vv_vari :
public vari {
619 op_vv_vari(
double f, vari* avi, vari* bvi):
626 class op_vd_vari :
public vari {
631 op_vd_vari(
double f, vari* avi,
double b) :
638 class op_dv_vari :
public vari {
643 op_dv_vari(
double f,
double a, vari* bvi) :
650 class op_vvv_vari :
public vari {
656 op_vvv_vari(
double f, vari* avi, vari* bvi, vari* cvi) :
664 class op_vvd_vari :
public vari {
670 op_vvd_vari(
double f, vari* avi, vari* bvi,
double c) :
678 class op_vdv_vari :
public vari {
684 op_vdv_vari(
double f, vari* avi,
double b, vari* cvi) :
692 class op_vdd_vari :
public vari {
698 op_vdd_vari(
double f, vari* avi,
double b,
double c) :
706 class op_dvv_vari :
public vari {
712 op_dvv_vari(
double f,
double a, vari* bvi, vari* cvi) :
720 class op_dvd_vari :
public vari {
726 op_dvd_vari(
double f,
double a, vari* bvi,
double c) :
734 class op_ddv_vari :
public vari {
740 op_ddv_vari(
double f,
double a,
double b, vari* cvi) :
749 class op_vector_vari :
public vari {
754 op_vector_vari(
double f,
const std::vector<stan::agrad::var>& vs) :
757 vis_ = (vari**)
operator new(
sizeof(vari*[vs.size()]));
758 for (
size_t i = 0; i < vs.size(); ++i)
761 vari* operator[](
size_t n)
const {
769 class neg_vari :
public op_v_vari {
771 neg_vari(vari* avi) :
772 op_v_vari(-(avi->val_), avi) {
780 class add_vv_vari :
public op_vv_vari {
782 add_vv_vari(vari* avi, vari* bvi) :
783 op_vv_vari(avi->val_ + bvi->val_, avi, bvi) {
791 class add_vd_vari :
public op_vd_vari {
793 add_vd_vari(vari* avi,
double b) :
794 op_vd_vari(avi->val_ + b, avi, b) {
801 class increment_vari :
public op_v_vari {
803 increment_vari(vari* avi) :
804 op_v_vari(avi->val_ + 1.0, avi) {
811 class decrement_vari :
public op_v_vari {
813 decrement_vari(vari* avi) :
814 op_v_vari(avi->val_ - 1.0, avi) {
821 class subtract_vv_vari :
public op_vv_vari {
823 subtract_vv_vari(vari* avi, vari* bvi) :
824 op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
832 class subtract_vd_vari :
public op_vd_vari {
834 subtract_vd_vari(vari* avi,
double b) :
835 op_vd_vari(avi->val_ - b, avi, b) {
842 class subtract_dv_vari :
public op_dv_vari {
844 subtract_dv_vari(
double a, vari* bvi) :
845 op_dv_vari(a - bvi->val_, a, bvi) {
852 class multiply_vv_vari :
public op_vv_vari {
854 multiply_vv_vari(vari* avi, vari* bvi) :
855 op_vv_vari(avi->val_ * bvi->val_, avi, bvi) {
863 class multiply_vd_vari :
public op_vd_vari {
865 multiply_vd_vari(vari* avi,
double b) :
866 op_vd_vari(avi->val_ * b, avi, b) {
874 class divide_vv_vari :
public op_vv_vari {
876 divide_vv_vari(vari* avi, vari* bvi) :
877 op_vv_vari(avi->val_ / bvi->val_, avi, bvi) {
885 class divide_vd_vari :
public op_vd_vari {
887 divide_vd_vari(vari* avi,
double b) :
888 op_vd_vari(avi->val_ / b, avi, b) {
895 class divide_dv_vari :
public op_dv_vari {
897 divide_dv_vari(
double a, vari* bvi) :
898 op_dv_vari(a / bvi->val_, a, bvi) {
905 class exp_vari :
public op_v_vari {
907 exp_vari(vari* avi) :
908 op_v_vari(
std::
exp(avi->val_),avi) {
911 avi_->adj_ += adj_ * val_;
915 class log_vari :
public op_v_vari {
917 log_vari(vari* avi) :
918 op_v_vari(
std::
log(avi->val_),avi) {
927 class log10_vari :
public op_v_vari {
930 log10_vari(vari* avi) :
931 op_v_vari(
std::
log10(avi->val_),avi),
939 class sqrt_vari :
public op_v_vari {
941 sqrt_vari(vari* avi) :
942 op_v_vari(
std::
sqrt(avi->val_),avi) {
945 avi_->adj_ += adj_ / (2.0 * val_);
949 class pow_vv_vari :
public op_vv_vari {
951 pow_vv_vari(vari* avi, vari* bvi) :
952 op_vv_vari(
std::
pow(avi->val_,bvi->val_),avi,bvi) {
955 if (
avi_->val_ == 0.0)
return;
961 class pow_vd_vari :
public op_vd_vari {
963 pow_vd_vari(vari* avi,
double b) :
964 op_vd_vari(
std::
pow(avi->val_,b),avi,b) {
967 if (
avi_->val_ == 0.0)
return;
972 class pow_dv_vari :
public op_dv_vari {
974 pow_dv_vari(
double a, vari* bvi) :
975 op_dv_vari(
std::
pow(a,bvi->val_),a,bvi) {
978 if (
ad_ == 0.0)
return;
983 class cos_vari :
public op_v_vari {
985 cos_vari(vari* avi) :
986 op_v_vari(
std::
cos(avi->val_),avi) {
993 class sin_vari :
public op_v_vari {
995 sin_vari(vari* avi) :
996 op_v_vari(
std::
sin(avi->val_),avi) {
1003 class tan_vari :
public op_v_vari {
1005 tan_vari(vari* avi) :
1006 op_v_vari(
std::
tan(avi->val_),avi) {
1009 avi_->adj_ += adj_ * (1.0 + val_ * val_);
1013 class acos_vari :
public op_v_vari {
1015 acos_vari(vari* avi) :
1016 op_v_vari(
std::
acos(avi->val_),avi) {
1023 class asin_vari :
public op_v_vari {
1025 asin_vari(vari* avi) :
1026 op_v_vari(
std::
asin(avi->val_),avi) {
1033 class atan_vari :
public op_v_vari {
1035 atan_vari(vari* avi) :
1036 op_v_vari(
std::
atan(avi->val_),avi) {
1039 avi_->adj_ += adj_ / (1.0 + (
avi_->val_ *
avi_->val_));
1043 class atan2_vv_vari :
public op_vv_vari {
1045 atan2_vv_vari(vari* avi, vari* bvi) :
1046 op_vv_vari(
std::
atan2(avi->val_,bvi->val_),avi,bvi) {
1049 double a_sq_plus_b_sq = (
avi_->val_ *
avi_->val_) + (
bvi_->val_ *
bvi_->val_);
1050 avi_->adj_ +=
bvi_->val_ / a_sq_plus_b_sq;
1051 bvi_->adj_ -=
avi_->val_ / a_sq_plus_b_sq;
1055 class atan2_vd_vari :
public op_vd_vari {
1057 atan2_vd_vari(vari* avi,
double b) :
1058 op_vd_vari(
std::
atan2(avi->val_,b),avi,b) {
1062 avi_->adj_ +=
bd_ / a_sq_plus_b_sq;
1066 class atan2_dv_vari :
public op_dv_vari {
1068 atan2_dv_vari(
double a, vari* bvi) :
1069 op_dv_vari(
std::
atan2(a,bvi->val_),a,bvi) {
1073 bvi_->adj_ -=
ad_ / a_sq_plus_b_sq;
1077 class cosh_vari :
public op_v_vari {
1079 cosh_vari(vari* avi) :
1080 op_v_vari(
std::
cosh(avi->val_),avi) {
1087 class sinh_vari :
public op_v_vari {
1089 sinh_vari(vari* avi) :
1090 op_v_vari(
std::
sinh(avi->val_),avi) {
1097 class tanh_vari :
public op_v_vari {
1099 tanh_vari(vari* avi) :
1100 op_v_vari(
std::
tanh(avi->val_),avi) {
1109 class floor_vari :
public vari {
1111 floor_vari(vari* avi) :
1116 class ceil_vari :
public vari {
1118 ceil_vari(vari* avi) :
1123 class fmod_vv_vari :
public op_vv_vari {
1125 fmod_vv_vari(vari* avi, vari* bvi) :
1126 op_vv_vari(
std::
fmod(avi->val_,bvi->val_),avi,bvi) {
1130 bvi_->adj_ -= adj_ *
static_cast<int>(
avi_->val_ /
bvi_->val_);
1134 class fmod_vd_vari :
public op_v_vari {
1136 fmod_vd_vari(vari* avi,
double b) :
1137 op_v_vari(
std::
fmod(avi->val_,b),avi) {
1144 class fmod_dv_vari :
public op_dv_vari {
1146 fmod_dv_vari(
double a, vari* bvi) :
1147 op_dv_vari(
std::
fmod(a,bvi->val_),a,bvi) {
1150 int d =
static_cast<int>(
ad_ /
bvi_->val_);
1151 bvi_->adj_ -= adj_ * d;
1168 return a.
val() == b.
val();
1181 return a.
val() == b;
1193 return a == b.
val();
1205 return a.
val() != b.
val();
1218 return a.
val() != b;
1231 return a != b.
val();
1242 return a.
val() < b.
val();
1277 return a.
val() > b.
val();
1314 return a.
val() <= b.
val();
1327 return a.
val() <= b;
1340 return a <= b.
val();
1353 return a.
val() >= b.
val();
1366 return a.
val() >= b;
1379 return a >= b.
val();
1432 return var(
new neg_vari(a.
vi_));
1449 return var(
new add_vv_vari(a.
vi_,b.
vi_));
1467 return var(
new add_vd_vari(a.
vi_,b));
1484 return var(
new add_vd_vari(b.
vi_,a));
1502 return var(
new subtract_vv_vari(a.
vi_,b.
vi_));
1519 return var(
new subtract_vd_vari(a.
vi_,b));
1534 return var(
new subtract_dv_vari(a,b.
vi_));
1551 return var(
new multiply_vv_vari(a.
vi_,b.
vi_));
1568 return var(
new multiply_vd_vari(a.
vi_,b));
1585 return var(
new multiply_vd_vari(b.
vi_,a));
1603 return var(
new divide_vv_vari(a.
vi_,b.
vi_));
1620 return var(
new divide_vd_vari(a.
vi_,b));
1635 return var(
new divide_dv_vari(a,b.
vi_));
1648 a.
vi_ =
new increment_vari(a.
vi_);
1667 a.
vi_ =
new increment_vari(a.
vi_);
1685 a.
vi_ =
new decrement_vari(a.
vi_);
1704 a.
vi_ =
new decrement_vari(a.
vi_);
1717 return var(
new exp_vari(a.
vi_));
1731 return var(
new log_vari(a.
vi_));
1745 return var(
new log10_vari(a.
vi_));
1762 return var(
new sqrt_vari(a.
vi_));
1779 return var(
new pow_vv_vari(base.
vi_,exponent.
vi_));
1795 if (exponent == 0.5)
1797 if (exponent == 1.0)
1799 if (exponent == 2.0)
1801 return var(
new pow_vd_vari(base.
vi_,exponent));
1817 return var(
new pow_dv_vari(base,exponent.
vi_));
1834 return var(
new cos_vari(a.
vi_));
1848 return var(
new sin_vari(a.
vi_));
1862 return var(
new tan_vari(a.
vi_));
1877 return var(
new acos_vari(a.
vi_));
1892 return var(
new asin_vari(a.
vi_));
1907 return var(
new atan_vari(a.
vi_));
1925 return var(
new atan2_vv_vari(a.
vi_,b.
vi_));
1941 return var(
new atan2_vd_vari(a.
vi_,b));
1957 return var(
new atan2_dv_vari(a,b.
vi_));
1973 return var(
new cosh_vari(a.
vi_));
1987 return var(
new sinh_vari(a.
vi_));
2001 return var(
new tanh_vari(a.
vi_));
2028 return var(
new neg_vari(a.
vi_));
2052 return var(
new floor_vari(a.
vi_));
2074 return var(
new ceil_vari(a.
vi_));
2095 return var(
new fmod_vv_vari(a.
vi_,b.
vi_));
2112 return var(
new fmod_vd_vari(a.
vi_,b));
2129 return var(
new fmod_dv_vari(a,b.
vi_));
2154 return var(
new neg_vari(a.
vi_));
2193 std::vector<chainable*>::iterator begin =
var_stack_.begin();
2194 std::vector<chainable*>::iterator it =
var_stack_.end();
2195 if (begin == it)
return;
2196 for (--it; (it >= begin) && (*it != vi); --it) ;
2200 for (; it >= begin; --it)
2208 for (
size_t i = 0; i <
var_stack_.size(); ++i)
2239 std::vector<var>& independents,
2240 std::vector<std::vector<double> >&
jacobian) {
2241 jacobian.resize(dependents.size());
2242 for (
size_t i = 0; i < dependents.size(); ++i) {
2243 jacobian[i].resize(independents.size());
2246 jacobian.push_back(std::vector<double>(0));
2247 grad(dependents[i].vi_);
2248 for (
size_t j = 0; j < independents.size(); ++j)
2249 jacobian[i][j] = independents[j].adj();
2262 vi_ =
new add_vd_vari(
vi_,b);
2274 vi_ =
new subtract_vd_vari(
vi_,b);
2286 vi_ =
new multiply_vd_vari(
vi_,b);
2298 vi_ =
new divide_vd_vari(
vi_,b);
2315 struct numeric_limits<
stan::agrad::var> {
2316 static const bool is_specialized =
true;
2319 static const int digits = numeric_limits<double>::digits;
2320 static const int digits10 = numeric_limits<double>::digits10;
2321 static const bool is_signed = numeric_limits<double>::is_signed;
2322 static const bool is_integer = numeric_limits<double>::is_integer;
2323 static const bool is_exact = numeric_limits<double>::is_exact;
2324 static const int radix = numeric_limits<double>::radix;
2328 static const int min_exponent = numeric_limits<double>::min_exponent;
2329 static const int min_exponent10 = numeric_limits<double>::min_exponent10;
2330 static const int max_exponent = numeric_limits<double>::max_exponent;
2331 static const int max_exponent10 = numeric_limits<double>::max_exponent10;
2333 static const bool has_infinity = numeric_limits<double>::has_infinity;
2334 static const bool has_quiet_NaN = numeric_limits<double>::has_quiet_NaN;
2335 static const bool has_signaling_NaN = numeric_limits<double>::has_signaling_NaN;
2336 static const float_denorm_style has_denorm = numeric_limits<double>::has_denorm;
2337 static const bool has_denorm_loss = numeric_limits<double>::has_denorm_loss;
2343 static const bool is_iec559 = numeric_limits<double>::is_iec559;
2344 static const bool is_bounded = numeric_limits<double>::is_bounded;
2345 static const bool is_modulo = numeric_limits<double>::is_modulo;
2347 static const bool traps = numeric_limits<double>::traps;
2348 static const bool tinyness_before = numeric_limits<double>::tinyness_before;
2349 static const float_round_style round_style = numeric_limits<double>::round_style;
Abstract base class for variable implementations that handles memory management and application of the chain rule.
~chainable()
Throws a logic exception.
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends.
virtual void set_zero_adjoint()
Set the value of the adjoint for this chainable to its initial value.
chainable()
Construct a chainable object.
virtual void init_dependent()
Initialize this chainable's adjoint value to make it the dependent variable in a gradient calculation.
Independent (input) and dependent (output) variables for gradients.
var(unsigned long long n)
Construct a variable by static casting the specified value to double.
var(long long n)
Construct a variable by static casting the specified value to double.
double adj() const
Return the derivative of the root expression with respect to this expression.
void grad()
Compute gradients of this dependent variable with respect to all variables on which it depends.
var(unsigned int n)
Construct a variable by static casting the specified value to double.
var()
Construct a variable for later assignment.
friend std::ostream & operator<<(std::ostream &os, const var &v)
Write the value of this auto-diff variable and its adjoint to the specified output stream.
var(char c)
Construct a variable by static casting the specified value to double.
var(float x)
Construct a variable by static casting the specified value to double.
var(long int n)
Construct a variable by static casting the specified value to double.
var(double x)
Construct a variable with the specified value.
var & operator+=(const var &b)
The compound add/assignment operator for variables (C++).
var(int n)
Construct a variable by static casting the specified value to double.
var & operator*=(const var &b)
The compound multiply/assignment operator for variables (C++).
var & operator/=(const var &b)
The compound divide/assignment operator for variables (C++).
var(long double x)
Construct a variable by static casting the specified value to double.
bool is_uninitialized()
Return true if this variable has been declared but has not been defined.
vari & operator*()
Return a reference to underlying implementation of this variable.
var(unsigned short n)
Construct a variable by static casting the specified value to double.
var(unsigned long int n)
Construct a variable by static casting the specified value to double.
void grad(std::vector< var > &x, std::vector< double > &g)
Compute the gradient of this (dependent) variable with respect to the specified vector of (independent) variables.
double val() const
Return the value of this variable.
var & operator-=(const var &b)
The compound subtract/assignment operator for variables (C++).
var(short n)
Construct a variable by static casting the specified value to double.
var(vari *vi)
Construct a variable from a pointer to a variable implementation.
var(bool b)
Construct a variable by static casting the specified value to double.
vari * vi_
Pointer to the implementation of this variable.
vari * operator->()
Return a pointer to the underlying implementation of this variable.
The variable implementation base class.
const double val_
The value of this variable.
virtual void set_zero_adjoint()
Set the adjoint value of this variable to 0.
double adj_
The adjoint of this variable, which is the partial derivative of the root variable with respect to this variable.
~vari()
Throw a logic error exception (std::logic_error); vari destruction is handled automatically.
friend std::ostream & operator<<(std::ostream &os, const vari *v)
Insertion operator for vari.
virtual void init_dependent()
Initialize the adjoint for this (dependent) variable to 1.
vari(const double x)
Construct a variable implementation from a value.
void free_all()
Return all memory used by the stack allocator, other than the initial block allocation, back to the system.
void recover_all()
Recover all the memory used by the stack allocator.
void * alloc(size_t len)
Return a newly allocated block of memory of the appropriate size managed by the stack allocator.
std::vector< chainable * > var_stack_
var cosh(const var &a)
Return the hyperbolic cosine of the specified variable (cmath).
var atan(const var &a)
Return the principal value of the arc tangent, in radians, of the specified variable (cmath).
var sqrt(const var &a)
Return the square root of the specified variable (cmath).
var fabs(const var &a)
Return the absolute value of the variable (cmath).
bool operator<=(const var &a, const var &b)
Less than or equal operator comparing two variables' values (C++).
static void grad(chainable *vi)
Compute the gradient for all variables starting from the specified root variable implementation.
var abs(const var &a)
Return the absolute value of the variable (std).
bool operator>=(const var &a, const var &b)
Greater than or equal operator comparing two variables' values (C++).
var operator-(const var &a)
Unary negation operator for variables (C++).
static void recover_memory()
Recover memory used for all variables for reuse.
var asin(const var &a)
Return the principal value of the arc sine, in radians, of the specified variable (cmath).
var operator+(const var &a)
Unary plus operator for variables (C++).
var acos(const var &a)
Return the principal value of the arc cosine of a variable, in radians (cmath).
var ceil(const var &a)
Return the ceiling of the specified variable (cmath).
void jacobian(std::vector< var > &dependents, std::vector< var > &independents, std::vector< std::vector< double > > &jacobian)
Return the Jacobian of the function producing the specified dependent variables with respect to the specified independent variables.
bool operator!=(const var &a, const var &b)
Inequality operator comparing two variables' values (C++).
var log10(const var &a)
Return the base 10 log of the specified variable (cmath).
var cos(const var &a)
Return the cosine of a radian-scaled variable (cmath).
var operator/(const var &a, const var &b)
Division operator for two variables (C++).
void print_stack(std::ostream &o)
Prints the auto-diff variable stack.
var sin(const var &a)
Return the sine of a radian-scaled variable (cmath).
bool operator<(const var &a, const var &b)
Less than operator comparing variables' values (C++).
bool operator==(const var &a, const var &b)
Equality operator comparing two variables' values (C++).
static void free_memory()
Return all memory used for gradients back to the system.
var tanh(const var &a)
Return the hyperbolic tangent of the specified variable (cmath).
var floor(const var &a)
Return the floor of the specified variable (cmath).
memory::stack_alloc memalloc_
static void set_zero_all_adjoints()
Reset all adjoint values in the stack to zero.
var pow(const var &base, const var &exponent)
Return the base raised to the power of the exponent (cmath).
var operator*(const var &a, const var &b)
Multiplication operator for two variables (C++).
var & operator++(var &a)
Prefix increment operator for variables (C++).
var log(const var &a)
Return the natural log of the specified variable (cmath).
var fmod(const var &a, const var &b)
Return the floating point remainder after dividing the first variable by the second (cmath).
bool operator>(const var &a, const var &b)
Greater than operator comparing variables' values (C++).
var tan(const var &a)
Return the tangent of a radian-scaled variable (cmath).
var atan2(const var &a, const var &b)
Return the principal value of the arc tangent, in radians, of the first variable divided by the second.
bool operator!(const var &a)
Prefix logical negation for the value of variables (C++).
var exp(const var &a)
Return the exponentiation of the specified variable (cmath).
var sinh(const var &a)
Return the hyperbolic sine of the specified variable (cmath).
var & operator--(var &a)
Prefix decrement operator for variables (C++).
double epsilon()
Return minimum positive number representable.
int min(const std::vector< int > &x)
Returns the minimum element of the specified vector.
int max(const std::vector< int > &x)
Returns the maximum element of the specified vector.
const double LOG_10
The natural logarithm of 10, $\log 10$.
Probability, optimization and sampling library.
Template specialization of functions in std for Stan.
int isnan(const stan::agrad::var &a)
Checks if the given number is NaN.
int isinf(const stan::agrad::var &a)
Checks if the given number is infinite.
static stan::agrad::var max()
static stan::agrad::var epsilon()
static stan::agrad::var min()
static stan::agrad::var signaling_NaN()
static stan::agrad::var round_error()
static stan::agrad::var quiet_NaN()
static stan::agrad::var infinity()
static stan::agrad::var denorm_min()