1 #ifndef __STAN__AGRAD__AGRAD_HPP__
2 #define __STAN__AGRAD__AGRAD_HPP__
28 static void grad(chainable* vi);
85 static inline void*
operator new(
size_t nbytes) {
147 throw std::logic_error(
"vari destruction handled automatically");
176 return os << v <<
" " << v->
val_ <<
" : " << v->
adj_;
191 o <<
"STACK, size=" <<
var_stack_.size() << std::endl;
192 for (
size_t i = 0; i <
var_stack_.size(); ++i)
195 <<
" " << (static_cast<vari*>(
var_stack_[i]))->val_
196 <<
" : " << (static_cast<vari*>(
var_stack_[i]))->adj_
239 return (
vi_ == static_cast<vari*>(0U));
269 vi_(new
vari(static_cast<double>(b))) {
279 vi_(new
vari(static_cast<double>(c))) {
289 vi_(new
vari(static_cast<double>(n))) {
299 vi_(new
vari(static_cast<double>(n))) {
309 vi_(new
vari(static_cast<double>(n))) {
319 vi_(new
vari(static_cast<double>(n))) {
329 vi_(new
vari(static_cast<double>(n))) {
338 var(
unsigned long int n) :
339 vi_(new
vari(static_cast<double>(n))) {
348 var(
unsigned long long n) :
349 vi_(new
vari(static_cast<double>(n))) {
359 vi_(new
vari(static_cast<double>(n))) {
369 vi_(new
vari(static_cast<double>(x))) {
388 vi_(new
vari(static_cast<double>(x))) {
396 inline double val()
const {
408 inline double adj()
const {
425 std::vector<double>& g) {
428 for (
size_t i = 0; i < x.size(); ++i)
598 return os << v.
val() <<
':' << v.
adj();
604 class op_v_vari :
public vari {
608 op_v_vari(
double f, vari* avi) :
614 class op_vv_vari :
public vari {
619 op_vv_vari(
double f, vari* avi, vari* bvi):
626 class op_vd_vari :
public vari {
631 op_vd_vari(
double f, vari* avi,
double b) :
638 class op_dv_vari :
public vari {
643 op_dv_vari(
double f,
double a, vari* bvi) :
650 class op_vvv_vari :
public vari {
656 op_vvv_vari(
double f, vari* avi, vari* bvi, vari* cvi) :
664 class op_vvd_vari :
public vari {
670 op_vvd_vari(
double f, vari* avi, vari* bvi,
double c) :
678 class op_vdv_vari :
public vari {
684 op_vdv_vari(
double f, vari* avi,
double b, vari* cvi) :
692 class op_vdd_vari :
public vari {
698 op_vdd_vari(
double f, vari* avi,
double b,
double c) :
706 class op_dvv_vari :
public vari {
712 op_dvv_vari(
double f,
double a, vari* bvi, vari* cvi) :
720 class op_dvd_vari :
public vari {
726 op_dvd_vari(
double f,
double a, vari* bvi,
double c) :
734 class op_ddv_vari :
public vari {
740 op_ddv_vari(
double f,
double a,
double b, vari* cvi) :
749 class op_vector_vari :
public vari {
754 op_vector_vari(
double f,
const std::vector<stan::agrad::var>& vs) :
757 vis_ = (vari**)
operator new(
sizeof(vari*[vs.size()]));
758 for (
size_t i = 0; i < vs.size(); ++i)
761 vari* operator[](
size_t n)
const {
769 class neg_vari :
public op_v_vari {
771 neg_vari(vari* avi) :
772 op_v_vari(-(avi->val_), avi) {
780 class add_vv_vari :
public op_vv_vari {
782 add_vv_vari(vari* avi, vari* bvi) :
783 op_vv_vari(avi->val_ + bvi->val_, avi, bvi) {
791 class add_vd_vari :
public op_vd_vari {
793 add_vd_vari(vari* avi,
double b) :
794 op_vd_vari(avi->val_ + b, avi, b) {
801 class increment_vari :
public op_v_vari {
803 increment_vari(vari* avi) :
804 op_v_vari(avi->val_ + 1.0, avi) {
811 class decrement_vari :
public op_v_vari {
813 decrement_vari(vari* avi) :
814 op_v_vari(avi->val_ - 1.0, avi) {
821 class subtract_vv_vari :
public op_vv_vari {
823 subtract_vv_vari(vari* avi, vari* bvi) :
824 op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
832 class subtract_vd_vari :
public op_vd_vari {
834 subtract_vd_vari(vari* avi,
double b) :
835 op_vd_vari(avi->val_ - b, avi, b) {
842 class subtract_dv_vari :
public op_dv_vari {
844 subtract_dv_vari(
double a, vari* bvi) :
845 op_dv_vari(a - bvi->val_, a, bvi) {
852 class multiply_vv_vari :
public op_vv_vari {
854 multiply_vv_vari(vari* avi, vari* bvi) :
855 op_vv_vari(avi->val_ * bvi->val_, avi, bvi) {
863 class multiply_vd_vari :
public op_vd_vari {
865 multiply_vd_vari(vari* avi,
double b) :
866 op_vd_vari(avi->val_ * b, avi, b) {
874 class divide_vv_vari :
public op_vv_vari {
876 divide_vv_vari(vari* avi, vari* bvi) :
877 op_vv_vari(avi->val_ / bvi->val_, avi, bvi) {
885 class divide_vd_vari :
public op_vd_vari {
887 divide_vd_vari(vari* avi,
double b) :
888 op_vd_vari(avi->val_ / b, avi, b) {
895 class divide_dv_vari :
public op_dv_vari {
897 divide_dv_vari(
double a, vari* bvi) :
898 op_dv_vari(a / bvi->val_, a, bvi) {
905 class exp_vari :
public op_v_vari {
907 exp_vari(vari* avi) :
908 op_v_vari(std::
exp(avi->val_),avi) {
911 avi_->adj_ += adj_ * val_;
915 class log_vari :
public op_v_vari {
917 log_vari(vari* avi) :
918 op_v_vari(std::
log(avi->val_),avi) {
927 class log10_vari :
public op_v_vari {
930 log10_vari(vari* avi) :
931 op_v_vari(std::
log10(avi->val_),avi),
939 class sqrt_vari :
public op_v_vari {
941 sqrt_vari(vari* avi) :
942 op_v_vari(std::
sqrt(avi->val_),avi) {
945 avi_->adj_ += adj_ / (2.0 * val_);
949 class pow_vv_vari :
public op_vv_vari {
951 pow_vv_vari(vari* avi, vari* bvi) :
952 op_vv_vari(std::
pow(avi->val_,bvi->val_),avi,bvi) {
955 if (
avi_->val_ == 0.0)
return;
961 class pow_vd_vari :
public op_vd_vari {
963 pow_vd_vari(vari* avi,
double b) :
964 op_vd_vari(std::
pow(avi->val_,b),avi,b) {
967 if (
avi_->val_ == 0.0)
return;
972 class pow_dv_vari :
public op_dv_vari {
974 pow_dv_vari(
double a, vari* bvi) :
975 op_dv_vari(std::
pow(a,bvi->val_),a,bvi) {
978 if (
ad_ == 0.0)
return;
983 class cos_vari :
public op_v_vari {
985 cos_vari(vari* avi) :
986 op_v_vari(std::
cos(avi->val_),avi) {
993 class sin_vari :
public op_v_vari {
995 sin_vari(vari* avi) :
996 op_v_vari(std::
sin(avi->val_),avi) {
1003 class tan_vari :
public op_v_vari {
1005 tan_vari(vari* avi) :
1006 op_v_vari(std::
tan(avi->val_),avi) {
1009 avi_->adj_ += adj_ * (1.0 + val_ * val_);
1013 class acos_vari :
public op_v_vari {
1015 acos_vari(vari* avi) :
1016 op_v_vari(std::
acos(avi->val_),avi) {
1023 class asin_vari :
public op_v_vari {
1025 asin_vari(vari* avi) :
1026 op_v_vari(std::
asin(avi->val_),avi) {
1033 class atan_vari :
public op_v_vari {
1035 atan_vari(vari* avi) :
1036 op_v_vari(std::
atan(avi->val_),avi) {
1039 avi_->adj_ += adj_ / (1.0 + (
avi_->val_ *
avi_->val_));
1043 class atan2_vv_vari :
public op_vv_vari {
1045 atan2_vv_vari(vari* avi, vari* bvi) :
1046 op_vv_vari(std::
atan2(avi->val_,bvi->val_),avi,bvi) {
1049 double a_sq_plus_b_sq = (
avi_->val_ *
avi_->val_) + (
bvi_->val_ *
bvi_->val_);
1050 avi_->adj_ +=
bvi_->val_ / a_sq_plus_b_sq;
1051 bvi_->adj_ -=
avi_->val_ / a_sq_plus_b_sq;
1055 class atan2_vd_vari :
public op_vd_vari {
1057 atan2_vd_vari(vari* avi,
double b) :
1058 op_vd_vari(std::
atan2(avi->val_,b),avi,b) {
1062 avi_->adj_ += bd_ / a_sq_plus_b_sq;
1066 class atan2_dv_vari :
public op_dv_vari {
1068 atan2_dv_vari(
double a, vari* bvi) :
1069 op_dv_vari(std::
atan2(a,bvi->val_),a,bvi) {
1073 bvi_->adj_ -=
ad_ / a_sq_plus_b_sq;
1077 class cosh_vari :
public op_v_vari {
1079 cosh_vari(vari* avi) :
1080 op_v_vari(std::
cosh(avi->val_),avi) {
1087 class sinh_vari :
public op_v_vari {
1089 sinh_vari(vari* avi) :
1090 op_v_vari(std::
sinh(avi->val_),avi) {
1097 class tanh_vari :
public op_v_vari {
1099 tanh_vari(vari* avi) :
1100 op_v_vari(std::
tanh(avi->val_),avi) {
1104 avi_->adj_ += adj_ / (cosh *
cosh);
1109 class floor_vari :
public vari {
1111 floor_vari(vari* avi) :
1112 vari(std::
floor(avi->val_)) {
1116 class ceil_vari :
public vari {
1118 ceil_vari(vari* avi) :
1119 vari(std::
ceil(avi->val_)) {
1123 class fmod_vv_vari :
public op_vv_vari {
1125 fmod_vv_vari(vari* avi, vari* bvi) :
1126 op_vv_vari(std::
fmod(avi->val_,bvi->val_),avi,bvi) {
1130 bvi_->adj_ -= adj_ *
static_cast<int>(
avi_->val_ /
bvi_->val_);
1134 class fmod_vd_vari :
public op_v_vari {
1136 fmod_vd_vari(vari* avi,
double b) :
1137 op_v_vari(std::
fmod(avi->val_,b),avi) {
1144 class fmod_dv_vari :
public op_dv_vari {
1146 fmod_dv_vari(
double a, vari* bvi) :
1147 op_dv_vari(std::
fmod(a,bvi->val_),a,bvi) {
1150 int d =
static_cast<int>(
ad_ /
bvi_->val_);
1151 bvi_->adj_ -= adj_ * d;
1168 return a.
val() == b.
val();
1181 return a.
val() == b;
1193 return a == b.
val();
1205 return a.
val() != b.
val();
1218 return a.
val() != b;
1231 return a != b.
val();
1242 return a.
val() < b.
val();
1277 return a.
val() > b.
val();
1314 return a.
val() <= b.
val();
1327 return a.
val() <= b;
1340 return a <= b.
val();
1353 return a.
val() >= b.
val();
1366 return a.
val() >= b;
1379 return a >= b.
val();
1432 return var(
new neg_vari(a.
vi_));
1449 return var(
new add_vv_vari(a.
vi_,b.
vi_));
1467 return var(
new add_vd_vari(a.
vi_,b));
1484 return var(
new add_vd_vari(b.
vi_,a));
1502 return var(
new subtract_vv_vari(a.
vi_,b.
vi_));
1519 return var(
new subtract_vd_vari(a.
vi_,b));
1534 return var(
new subtract_dv_vari(a,b.
vi_));
1551 return var(
new multiply_vv_vari(a.
vi_,b.
vi_));
1568 return var(
new multiply_vd_vari(a.
vi_,b));
1585 return var(
new multiply_vd_vari(b.
vi_,a));
1603 return var(
new divide_vv_vari(a.
vi_,b.
vi_));
1620 return var(
new divide_vd_vari(a.
vi_,b));
1635 return var(
new divide_dv_vari(a,b.
vi_));
1648 a.
vi_ =
new increment_vari(a.
vi_);
1667 a.
vi_ =
new increment_vari(a.
vi_);
1685 a.
vi_ =
new decrement_vari(a.
vi_);
1704 a.
vi_ =
new decrement_vari(a.
vi_);
1717 return var(
new exp_vari(a.
vi_));
1731 return var(
new log_vari(a.
vi_));
1745 return var(
new log10_vari(a.
vi_));
1762 return var(
new sqrt_vari(a.
vi_));
1779 return var(
new pow_vv_vari(base.
vi_,exponent.
vi_));
1795 if (exponent == 0.5)
1797 if (exponent == 1.0)
1799 if (exponent == 2.0)
1801 return var(
new pow_vd_vari(base.vi_,exponent));
1817 return var(
new pow_dv_vari(base,exponent.
vi_));
1834 return var(
new cos_vari(a.
vi_));
1848 return var(
new sin_vari(a.
vi_));
1862 return var(
new tan_vari(a.
vi_));
1877 return var(
new acos_vari(a.
vi_));
1892 return var(
new asin_vari(a.
vi_));
1907 return var(
new atan_vari(a.
vi_));
1925 return var(
new atan2_vv_vari(a.
vi_,b.
vi_));
1941 return var(
new atan2_vd_vari(a.
vi_,b));
1957 return var(
new atan2_dv_vari(a,b.
vi_));
1973 return var(
new cosh_vari(a.
vi_));
1987 return var(
new sinh_vari(a.
vi_));
2001 return var(
new tanh_vari(a.
vi_));
2028 return var(
new neg_vari(a.
vi_));
2052 return var(
new floor_vari(a.
vi_));
2074 return var(
new ceil_vari(a.
vi_));
2095 return var(
new fmod_vv_vari(a.
vi_,b.
vi_));
2112 return var(
new fmod_vd_vari(a.
vi_,b));
2129 return var(
new fmod_dv_vari(a,b.
vi_));
2154 return var(
new neg_vari(a.
vi_));
2193 std::vector<chainable*>::iterator begin =
var_stack_.begin();
2194 std::vector<chainable*>::iterator it =
var_stack_.end();
2195 if (begin == it)
return;
2196 for (--it; (it >= begin) && (*it != vi); --it) ;
2200 for (; it >= begin; --it)
2208 for (
size_t i = 0; i <
var_stack_.size(); ++i)
2239 std::vector<var>& independents,
2240 std::vector<std::vector<double> >&
jacobian) {
2241 jacobian.resize(dependents.size());
2242 for (
size_t i = 0; i < dependents.size(); ++i) {
2243 jacobian[i].resize(independents.size());
2246 jacobian.push_back(std::vector<double>(0));
2247 grad(dependents[i].vi_);
2248 for (
size_t j = 0; j < independents.size(); ++j)
2249 jacobian[i][j] = independents[j].adj();
2262 vi_ =
new add_vd_vari(
vi_,b);
2274 vi_ =
new subtract_vd_vari(
vi_,b);
2286 vi_ =
new multiply_vd_vari(
vi_,b);
2298 vi_ =
new divide_vd_vari(
vi_,b);
2315 struct numeric_limits<stan::agrad::var> {
2316 static const bool is_specialized =
true;
2319 static const int digits = numeric_limits<double>::digits;
2320 static const int digits10 = numeric_limits<double>::digits10;
2321 static const bool is_signed = numeric_limits<double>::is_signed;
2322 static const bool is_integer = numeric_limits<double>::is_integer;
2323 static const bool is_exact = numeric_limits<double>::is_exact;
2324 static const int radix = numeric_limits<double>::radix;
2328 static const int min_exponent = numeric_limits<double>::min_exponent;
2329 static const int min_exponent10 = numeric_limits<double>::min_exponent10;
2330 static const int max_exponent = numeric_limits<double>::max_exponent;
2331 static const int max_exponent10 = numeric_limits<double>::max_exponent10;
2333 static const bool has_infinity = numeric_limits<double>::has_infinity;
2334 static const bool has_quiet_NaN = numeric_limits<double>::has_quiet_NaN;
2335 static const bool has_signaling_NaN = numeric_limits<double>::has_signaling_NaN;
2336 static const float_denorm_style has_denorm = numeric_limits<double>::has_denorm;
2337 static const bool has_denorm_loss = numeric_limits<double>::has_denorm_loss;
2343 static const bool is_iec559 = numeric_limits<double>::is_iec559;
2344 static const bool is_bounded = numeric_limits<double>::is_bounded;
2345 static const bool is_modulo = numeric_limits<double>::is_modulo;
2347 static const bool traps = numeric_limits<double>::traps;
2348 static const bool tinyness_before = numeric_limits<double>::tinyness_before;
2349 static const float_round_style round_style = numeric_limits<double>::round_style;