Stan  1.0
probability, sampling & optimization
agrad_thread_safe.hpp
Go to the documentation of this file.
1 #ifndef __STAN__AGRAD__AGRAD_HPP__
2 #define __STAN__AGRAD__AGRAD_HPP__
3 
4 #include <stdio.h>
5 #include <stdlib.h>
6 #include <vector>
7 #include <cmath>
8 #include <cstddef>
10 
11 
12 namespace stan {
13 
14  namespace agrad {
15 
16  class vari;
17 
18  namespace {
19  struct var_allocator {
20  std::vector<vari*> var_stack_;
21  memory::stack_alloc memalloc_;
22  inline void* alloc(size_t nbytes) {
23  return memalloc_.alloc(nbytes);
24  }
25  inline void recover() {
26  var_stack_.resize(0);
28  }
29  inline void free() {
31  }
32  };
33 #ifdef AGRAD_THREAD_SAFE
34  __thread
35 #endif
36  var_allocator* allocator_;
37  }
38 
52  class vari {
53  private:
54  friend class var;
55 
56  public:
57 
61  const double val_;
62 
67  double adj_;
68 
81  vari(const double x):
82  val_(x),
83  adj_(0.0) {
84  allocator_->var_stack_.push_back(this);
85  }
86 
92  virtual void chain() {
93  }
94 
104  static inline void* operator new(size_t nbytes) {
105  if (allocator_ == 0)
106  allocator_ = new var_allocator();
107  return allocator_->alloc(nbytes);
108  }
109 
113  static void recover_memory() {
114  return allocator_->recover();
115  // allocator_.var_stack_.resize(0);
116  // allocator_.memalloc_.recover_all();
117  }
118 
122  static void free_memory() {
123  allocator_->free();
124  // allocator_.memalloc_.free_all();
125  }
126 
127  private:
136  static void grad(vari* vi) {
137  std::vector<vari*>::iterator it = allocator_->var_stack_.end();
138  std::vector<vari*>::iterator begin = allocator_->var_stack_.begin();
139  // skip to root variable
140  for (; (it >= begin) && (*it != vi); --it)
141  ;
142  vi->adj_ = 1.0; // droot/droot = 1
143  // propagate derivates for remaining vars
144  for (; it >= begin; --it)
145  (*it)->chain();
146  }
147 
148  };
149 
150  namespace {
151 
152  class op_v_vari : public vari {
153  protected:
154  vari* avi_;
155  public:
156  op_v_vari(double f, vari* avi) :
157  vari(f),
158  avi_(avi) {
159  }
160  };
161 
162  class op_vv_vari : public vari {
163  protected:
164  vari* avi_;
165  vari* bvi_;
166  public:
167  op_vv_vari(double f, vari* avi, vari* bvi):
168  vari(f),
169  avi_(avi),
170  bvi_(bvi) {
171  }
172  };
173 
174  class op_vd_vari : public vari {
175  protected:
176  vari* avi_;
177  double bd_;
178  public:
179  op_vd_vari(double f, vari* avi, double b) :
180  vari(f),
181  avi_(avi),
182  bd_(b) {
183  }
184  };
185 
186  class op_dv_vari : public vari {
187  protected:
188  double ad_;
189  vari* bvi_;
190  public:
191  op_dv_vari(double f, double a, vari* bvi) :
192  vari(f),
193  ad_(a),
194  bvi_(bvi) {
195  }
196  };
197 
198  class op_vvv_vari : public vari {
199  protected:
200  vari* avi_;
201  vari* bvi_;
202  vari* cvi_;
203  public:
204  op_vvv_vari(double f, vari* avi, vari* bvi, vari* cvi) :
205  vari(f),
206  avi_(avi),
207  bvi_(bvi),
208  cvi_(cvi) {
209  }
210  };
211 
212  class op_vvd_vari : public vari {
213  protected:
214  vari* avi_;
215  vari* bvi_;
216  double cd_;
217  public:
218  op_vvd_vari(double f, vari* avi, vari* bvi, double c) :
219  vari(f),
220  avi_(avi),
221  bvi_(bvi),
222  cd_(c) {
223  }
224  };
225 
226  class op_vdv_vari : public vari {
227  protected:
228  vari* avi_;
229  double bd_;
230  vari* cvi_;
231  public:
232  op_vdv_vari(double f, vari* avi, double b, vari* cvi) :
233  vari(f),
234  avi_(avi),
235  bd_(b),
236  cvi_(cvi) {
237  }
238  };
239 
240  class op_vdd_vari : public vari {
241  protected:
242  vari* avi_;
243  double bd_;
244  double cd_;
245  public:
246  op_vdd_vari(double f, vari* avi, double b, double c) :
247  vari(f),
248  avi_(avi),
249  bd_(b),
250  cd_(c) {
251  }
252  };
253 
254  class op_dvv_vari : public vari {
255  protected:
256  double ad_;
257  vari* bvi_;
258  vari* cvi_;
259  public:
260  op_dvv_vari(double f, double a, vari* bvi, vari* cvi) :
261  vari(f),
262  ad_(a),
263  bvi_(bvi),
264  cvi_(cvi) {
265  }
266  };
267 
268  class op_dvd_vari : public vari {
269  protected:
270  double ad_;
271  vari* bvi_;
272  double cd_;
273  public:
274  op_dvd_vari(double f, double a, vari* bvi, double c) :
275  vari(f),
276  ad_(a),
277  bvi_(bvi),
278  cd_(c) {
279  }
280  };
281 
282  class op_ddv_vari : public vari {
283  protected:
284  double ad_;
285  double bd_;
286  vari* cvi_;
287  public:
288  op_ddv_vari(double f, double a, double b, vari* cvi) :
289  vari(f),
290  ad_(a),
291  bd_(b),
292  cvi_(cvi) {
293  }
294  };
295 
296  class neg_vari : public op_v_vari {
297  public:
298  neg_vari(vari* avi) :
299  op_v_vari(-(avi->val_), avi) {
300  }
301  void chain() {
302  avi_->adj_ -= adj_;
303  }
304  };
305 
306 
307  class add_vv_vari : public op_vv_vari {
308  public:
309  add_vv_vari(vari* avi, vari* bvi) :
310  op_vv_vari(avi->val_ + bvi->val_, avi, bvi) {
311  }
312  void chain() {
313  avi_->adj_ += adj_;
314  bvi_->adj_ += adj_;
315  }
316  };
317 
318  class add_vd_vari : public op_vd_vari {
319  public:
320  add_vd_vari(vari* avi, double b) :
321  op_vd_vari(avi->val_ + b, avi, b) {
322  }
323  void chain() {
324  avi_->adj_ += adj_;
325  }
326  };
327 
328  class increment_vari : public op_v_vari {
329  public:
330  increment_vari(vari* avi) :
331  op_v_vari(avi->val_ + 1.0, avi) {
332  }
333  void chain() {
334  avi_->adj_ += adj_;
335  }
336  };
337 
338  class decrement_vari : public op_v_vari {
339  public:
340  decrement_vari(vari* avi) :
341  op_v_vari(avi->val_ - 1.0, avi) {
342  }
343  void chain() {
344  avi_->adj_ += adj_;
345  }
346  };
347 
348  class subtract_vv_vari : public op_vv_vari {
349  public:
350  subtract_vv_vari(vari* avi, vari* bvi) :
351  op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
352  }
353  void chain() {
354  avi_->adj_ += adj_;
355  bvi_->adj_ -= adj_;
356  }
357  };
358 
359  class subtract_vd_vari : public op_vd_vari {
360  public:
361  subtract_vd_vari(vari* avi, double b) :
362  op_vd_vari(avi->val_ - b, avi, b) {
363  }
364  void chain() {
365  avi_->adj_ += adj_;
366  }
367  };
368 
369  class subtract_dv_vari : public op_dv_vari {
370  public:
371  subtract_dv_vari(double a, vari* bvi) :
372  op_dv_vari(a - bvi->val_, a, bvi) {
373  }
374  void chain() {
375  bvi_->adj_ -= adj_;
376  }
377  };
378 
379  class multiply_vv_vari : public op_vv_vari {
380  public:
381  multiply_vv_vari(vari* avi, vari* bvi) :
382  op_vv_vari(avi->val_ * bvi->val_, avi, bvi) {
383  }
384  void chain() {
385  avi_->adj_ += bvi_->val_ * adj_;
386  bvi_->adj_ += avi_->val_ * adj_;
387  }
388  };
389 
390  class multiply_vd_vari : public op_vd_vari {
391  public:
392  multiply_vd_vari(vari* avi, double b) :
393  op_vd_vari(avi->val_ * b, avi, b) {
394  }
395  void chain() {
396  avi_->adj_ += adj_ * bd_;
397  }
398  };
399 
400  // (a/b)' = a' * (1 / b) - b' * (a / [b * b])
401  class divide_vv_vari : public op_vv_vari {
402  public:
403  divide_vv_vari(vari* avi, vari* bvi) :
404  op_vv_vari(avi->val_ / bvi->val_, avi, bvi) {
405  }
406  void chain() {
407  avi_->adj_ += adj_ / bvi_->val_;
408  bvi_->adj_ -= adj_ * avi_->val_ / (bvi_->val_ * bvi_->val_);
409  }
410  };
411 
412  class divide_vd_vari : public op_vd_vari {
413  public:
414  divide_vd_vari(vari* avi, double b) :
415  op_vd_vari(avi->val_ / b, avi, b) {
416  }
417  void chain() {
418  avi_->adj_ += adj_ / bd_;
419  }
420  };
421 
422  class divide_dv_vari : public op_dv_vari {
423  public:
424  divide_dv_vari(double a, vari* bvi) :
425  op_dv_vari(a / bvi->val_, a, bvi) {
426  }
427  void chain() {
428  bvi_->adj_ -= adj_ * ad_ / (bvi_->val_ * bvi_->val_);
429  }
430  };
431 
432  class exp_vari : public op_v_vari {
433  public:
434  exp_vari(vari* avi) :
435  op_v_vari(std::exp(avi->val_),avi) {
436  }
437  void chain() {
438  avi_->adj_ += adj_ * val_;
439  }
440  };
441 
442  class log_vari : public op_v_vari {
443  public:
444  log_vari(vari* avi) :
445  op_v_vari(std::log(avi->val_),avi) {
446  }
447  void chain() {
448  avi_->adj_ += adj_ / avi_->val_;
449  }
450  };
451 
452  double LOG_10 = std::log(10.0);
453 
454  class log10_vari : public op_v_vari {
455  public:
456  const double exp_val_;
457  log10_vari(vari* avi) :
458  op_v_vari(std::log10(avi->val_),avi),
459  exp_val_(avi->val_) {
460  }
461  void chain() {
462  avi_->adj_ += adj_ / (LOG_10 * exp_val_);
463  }
464  };
465 
466  class sqrt_vari : public op_v_vari {
467  public:
468  sqrt_vari(vari* avi) :
469  op_v_vari(std::sqrt(avi->val_),avi) {
470  }
471  void chain() {
472  avi_->adj_ += adj_ / (2.0 * val_);
473  }
474  };
475 
476  class pow_vv_vari : public op_vv_vari {
477  public:
478  pow_vv_vari(vari* avi, vari* bvi) :
479  op_vv_vari(std::pow(avi->val_,bvi->val_),avi,bvi) {
480  }
481  void chain() {
482  if (avi_->val_ == 0.0) return; // partials zero, avoids /0 & log(0)
483  avi_->adj_ += adj_ * bvi_->val_ * val_ / avi_->val_;
484  bvi_->adj_ += adj_ * std::log(avi_->val_) * val_;
485  }
486  };
487 
488  class pow_vd_vari : public op_vd_vari {
489  public:
490  pow_vd_vari(vari* avi, double b) :
491  op_vd_vari(std::pow(avi->val_,b),avi,b) {
492  }
493  void chain() {
494  if (avi_->val_ == 0.0) return; // partials zero, avoids /0 & log(0)
495  avi_->adj_ += adj_ * bd_ * val_ / avi_->val_;
496  }
497  };
498 
499  class pow_dv_vari : public op_dv_vari {
500  public:
501  pow_dv_vari(double a, vari* bvi) :
502  op_dv_vari(std::pow(a,bvi->val_),a,bvi) {
503  }
504  void chain() {
505  if (ad_ == 0.0) return; // partials zero, avoids /0 & log(0)
506  bvi_->adj_ += adj_ * std::log(ad_) * val_;
507  }
508  };
509 
510  class cos_vari : public op_v_vari {
511  public:
512  cos_vari(vari* avi) :
513  op_v_vari(std::cos(avi->val_),avi) {
514  }
515  void chain() {
516  avi_->adj_ -= adj_ * std::sin(avi_->val_);
517  }
518  };
519 
520  class sin_vari : public op_v_vari {
521  public:
522  sin_vari(vari* avi) :
523  op_v_vari(std::sin(avi->val_),avi) {
524  }
525  void chain() {
526  avi_->adj_ += adj_ * std::cos(avi_->val_);
527  }
528  };
529 
530  class tan_vari : public op_v_vari {
531  public:
532  tan_vari(vari* avi) :
533  op_v_vari(std::tan(avi->val_),avi) {
534  }
535  void chain() {
536  avi_->adj_ += adj_ * (1.0 + val_ * val_);
537  }
538  };
539 
540  class acos_vari : public op_v_vari {
541  public:
542  acos_vari(vari* avi) :
543  op_v_vari(std::acos(avi->val_),avi) {
544  }
545  void chain() {
546  avi_->adj_ -= adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
547  }
548  };
549 
550  class asin_vari : public op_v_vari {
551  public:
552  asin_vari(vari* avi) :
553  op_v_vari(std::asin(avi->val_),avi) {
554  }
555  void chain() {
556  avi_->adj_ += adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
557  }
558  };
559 
560  class atan_vari : public op_v_vari {
561  public:
562  atan_vari(vari* avi) :
563  op_v_vari(std::atan(avi->val_),avi) {
564  }
565  void chain() {
566  avi_->adj_ += adj_ / (1.0 + (avi_->val_ * avi_->val_));
567  }
568  };
569 
570  class atan2_vv_vari : public op_vv_vari {
571  public:
572  atan2_vv_vari(vari* avi, vari* bvi) :
573  op_vv_vari(std::atan2(avi->val_,bvi->val_),avi,bvi) {
574  }
575  void chain() {
576  double a_sq_plus_b_sq = (avi_->val_ * avi_->val_) + (bvi_->val_ * bvi_->val_);
577  avi_->adj_ += bvi_->val_ / a_sq_plus_b_sq;
578  bvi_->adj_ -= avi_->val_ / a_sq_plus_b_sq;
579  }
580  };
581 
582  class atan2_vd_vari : public op_vd_vari {
583  public:
584  atan2_vd_vari(vari* avi, double b) :
585  op_vd_vari(std::atan2(avi->val_,b),avi,b) {
586  }
587  void chain() {
588  double a_sq_plus_b_sq = (avi_->val_ * avi_->val_) + (bd_ * bd_);
589  avi_->adj_ += bd_ / a_sq_plus_b_sq;
590  }
591  };
592 
593  class atan2_dv_vari : public op_dv_vari {
594  public:
595  atan2_dv_vari(double a, vari* bvi) :
596  op_dv_vari(std::atan2(a,bvi->val_),a,bvi) {
597  }
598  void chain() {
599  double a_sq_plus_b_sq = (ad_ * ad_) + (bvi_->val_ * bvi_->val_);
600  bvi_->adj_ -= ad_ / a_sq_plus_b_sq;
601  }
602  };
603 
604  class cosh_vari : public op_v_vari {
605  public:
606  cosh_vari(vari* avi) :
607  op_v_vari(std::cosh(avi->val_),avi) {
608  }
609  void chain() {
610  avi_->adj_ += adj_ * std::sinh(avi_->val_);
611  }
612  };
613 
614  class sinh_vari : public op_v_vari {
615  public:
616  sinh_vari(vari* avi) :
617  op_v_vari(std::sinh(avi->val_),avi) {
618  }
619  void chain() {
620  avi_->adj_ += adj_ * std::cosh(avi_->val_);
621  }
622  };
623 
624  class tanh_vari : public op_v_vari {
625  public:
626  tanh_vari(vari* avi) :
627  op_v_vari(std::tanh(avi->val_),avi) {
628  }
629  void chain() {
630  double cosh = std::cosh(avi_->val_);
631  avi_->adj_ += adj_ / (cosh * cosh);
632  }
633  };
634 
635 
636  class floor_vari : public vari {
637  public:
638  floor_vari(vari* avi) :
639  vari(std::floor(avi->val_)) {
640  }
641  };
642 
643  class ceil_vari : public vari {
644  public:
645  ceil_vari(vari* avi) :
646  vari(std::ceil(avi->val_)) {
647  }
648  };
649 
650  class fmod_vv_vari : public op_vv_vari {
651  public:
652  fmod_vv_vari(vari* avi, vari* bvi) :
653  op_vv_vari(std::fmod(avi->val_,bvi->val_),avi,bvi) {
654  }
655  void chain() {
656  avi_->adj_ += adj_;
657  bvi_->adj_ -= adj_ * static_cast<int>(avi_->val_ / bvi_->val_);
658  }
659  };
660 
661  class fmod_vd_vari : public op_v_vari {
662  public:
663  fmod_vd_vari(vari* avi, double b) :
664  op_v_vari(std::fmod(avi->val_,b),avi) {
665  }
666  void chain() {
667  avi_->adj_ += adj_;
668  }
669  };
670 
671  class fmod_dv_vari : public op_dv_vari {
672  public:
673  fmod_dv_vari(double a, vari* bvi) :
674  op_dv_vari(std::fmod(a,bvi->val_),a,bvi) {
675  }
676  void chain() {
677  int d = static_cast<int>(ad_ / bvi_->val_);
678  bvi_->adj_ -= adj_ * d;
679  }
680  };
681 
682 
683 
684 
685 
686  }
687 
688  // ********************* vari UP, var DOWN ***********************************************************
689 
690 
700  class var {
701  public:
702 
703  typedef double Scalar;
704 
712  vari * vi_;
713 
719  explicit var(vari* vi) :
720  vi_(vi) {
721  }
722 
730  var() :
731  vi_(0) {
732  }
733 
740  var(bool b) :
741  vi_(new vari(static_cast<double>(b))) {
742  }
743 
750  var(char c) :
751  vi_(new vari(static_cast<double>(c))) {
752  }
753 
760  var(short n) :
761  vi_(new vari(static_cast<double>(n))) {
762  }
763 
770  var(unsigned short n) :
771  vi_(new vari(static_cast<double>(n))) {
772  }
773 
780  var(int n) :
781  vi_(new vari(static_cast<double>(n))) {
782  }
783 
790  var(unsigned int n) :
791  vi_(new vari(static_cast<double>(n))) {
792  }
793 
800  var(long int n) :
801  vi_(new vari(static_cast<double>(n))) {
802  }
803 
810  var(unsigned long int n) :
811  vi_(new vari(static_cast<double>(n))) {
812  }
813 
820  var(float x) :
821  vi_(new vari(static_cast<double>(x))) {
822  }
823 
829  var(double x) :
830  vi_(new vari(x)) {
831  }
832 
839  var(long double x) :
840  vi_(new vari(static_cast<double>(x))) {
841  }
842 
848  inline double val() const {
849  return vi_->val_;
850  }
851 
864  void grad(std::vector<var>& x,
865  std::vector<double>& g) {
866  vari::grad(vi_);
867  g.resize(x.size());
868  for (size_t i = 0U; i < x.size(); ++i)
869  g[i] = x[i].vi_->adj_;
871  }
872 
889  void grad() {
890  vari::grad(vi_);
892  }
893 
894  // COMPOUND ASSIGNMENT OPERATORS
895 
906  inline var& operator+=(const var& b) {
907  vi_ = new add_vv_vari(vi_,b.vi_);
908  return *this;
909  }
910 
921  inline var& operator+=(const double& b) {
922  vi_ = new add_vd_vari(vi_,b);
923  return *this;
924  }
925 
937  inline var& operator-=(const var& b) {
938  vi_ = new subtract_vv_vari(vi_,b.vi_);
939  return *this;
940  }
941 
953  inline var& operator-=(const double& b) {
954  vi_ = new subtract_vd_vari(vi_,b);
955  return *this;
956  }
957 
969  inline var& operator*=(const var& b) {
970  vi_ = new multiply_vv_vari(vi_,b.vi_);
971  return *this;
972  }
973 
985  inline var& operator*=(const double& b) {
986  vi_ = new multiply_vd_vari(vi_,b);
987  return *this;
988  }
989 
1000  inline var& operator/=(const var& b) {
1001  vi_ = new divide_vv_vari(vi_,b.vi_);
1002  return *this;
1003  }
1004 
1016  inline var& operator/=(const double& b) {
1017  vi_ = new divide_vd_vari(vi_,b);
1018  return *this;
1019  };
1020 
1021 
1022  };
1023 
1024  // COMPARISON OPERATORS
1025 
1034  inline bool operator==(const var& a, const var& b) {
1035  return a.val() == b.val();
1036  }
1037 
1047  inline bool operator==(const var& a, const double& b) {
1048  return a.val() == b;
1049  }
1050 
1059  inline bool operator==(const double& a, const var& b) {
1060  return a == b.val();
1061  }
1062 
1071  inline bool operator!=(const var& a, const var& b) {
1072  return a.val() != b.val();
1073  }
1074 
1084  inline bool operator!=(const var& a, const double& b) {
1085  return a.val() != b;
1086  }
1087 
1097  inline bool operator!=(const double& a, const var& b) {
1098  return a != b.val();
1099  }
1100 
1108  inline bool operator<(const var& a, const var& b) {
1109  return a.val() < b.val();
1110  }
1111 
1120  inline bool operator<(const var& a, const double& b) {
1121  return a.val() < b;
1122  }
1123 
1132  inline bool operator<(const double& a, const var& b) {
1133  return a < b.val();
1134  }
1135 
1143  inline bool operator>(const var& a, const var& b) {
1144  return a.val() > b.val();
1145  }
1146 
1155  inline bool operator>(const var& a, const double& b) {
1156  return a.val() > b;
1157  }
1158 
1167  inline bool operator>(const double& a, const var& b) {
1168  return a > b.val();
1169  }
1170 
1180  inline bool operator<=(const var& a, const var& b) {
1181  return a.val() <= b.val();
1182  }
1183 
1193  inline bool operator<=(const var& a, const double& b) {
1194  return a.val() <= b;
1195  }
1196 
1206  inline bool operator<=(const double& a, const var& b) {
1207  return a <= b.val();
1208  }
1209 
1219  inline bool operator>=(const var& a, const var& b) {
1220  return a.val() >= b.val();
1221  }
1222 
1232  inline bool operator>=(const var& a, const double& b) {
1233  return a.val() >= b;
1234  }
1235 
1245  inline bool operator>=(const double& a, const var& b) {
1246  return a >= b.val();
1247  }
1248 
1249  // LOGICAL OPERATORS
1250 
1264  inline bool operator!(const var& a) {
1265  return !a.val();
1266  }
1267 
1268  // ARITHMETIC OPERATORS
1269 
1285  inline var operator+(const var& a) {
1286  return a;
1287  }
1288 
1297  inline var operator-(const var& a) {
1298  return var(new neg_vari(a.vi_));
1299  }
1300 
1314  inline var operator+(const var& a, const var& b) {
1315  return var(new add_vv_vari(a.vi_,b.vi_));
1316  }
1317 
1318 
1330  inline var operator+(const var& a, const double& b) {
1331  return var(new add_vd_vari(a.vi_,b));
1332  }
1333 
1345  inline var operator+(const double& a, const var& b) {
1346  return var(new add_vd_vari(b.vi_,a)); // by symmetry
1347  }
1348 
1363  inline var operator-(const var& a, const var& b) {
1364  return var(new subtract_vv_vari(a.vi_,b.vi_));
1365  }
1366 
1378  inline var operator-(const var& a, const double& b) {
1379  return var(new subtract_vd_vari(a.vi_,b));
1380  }
1381 
1393  inline var operator-(const double& a, const var& b) {
1394  return var(new subtract_dv_vari(a,b.vi_));
1395  }
1396 
1410  inline var operator*(const var& a, const var& b) {
1411  return var(new multiply_vv_vari(a.vi_,b.vi_));
1412  }
1413 
1425  inline var operator*(const var& a, const double& b) {
1426  return var(new multiply_vd_vari(a.vi_,b));
1427  }
1428 
1440  inline var operator*(const double& a, const var& b) {
1441  return var(new multiply_vd_vari(b.vi_,a)); // by symmetry
1442  }
1443 
1458  inline var operator/(const var& a, const var& b) {
1459  return var(new divide_vv_vari(a.vi_,b.vi_));
1460  }
1461 
1473  inline var operator/(const var& a, const double& b) {
1474  return var(new divide_vd_vari(a.vi_,b));
1475  }
1476 
1488  inline var operator/(const double& a, const var& b) {
1489  return var(new divide_dv_vari(a,b.vi_));
1490  }
1491 
1501  inline var& operator++(var& a) {
1502  a.vi_ = new increment_vari(a.vi_);
1503  return a;
1504  }
1505 
1519  inline var operator++(var& a, int dummy) {
1520  var temp(a);
1521  a.vi_ = new increment_vari(a.vi_);
1522  return temp;
1523  }
1524 
1538  inline var& operator--(var& a) {
1539  a.vi_ = new decrement_vari(a.vi_);
1540  return a;
1541  }
1542 
1556  inline var operator--(var& a, int dummy) {
1557  var temp(a);
1558  a.vi_ = new decrement_vari(a.vi_);
1559  return temp;
1560  }
1561 
1562  // CMATH EXP AND LOG
1563 
1570  inline var exp(const var& a) {
1571  return var(new exp_vari(a.vi_));
1572  }
1573 
1584  inline var log(const var& a) {
1585  return var(new log_vari(a.vi_));
1586  }
1587 
1598  inline var log10(const var& a) {
1599  return var(new log10_vari(a.vi_));
1600  }
1601 
1602 
1603  // POWER FUNCTIONS
1604 
1615  inline var sqrt(const var& a) {
1616  return var(new sqrt_vari(a.vi_));
1617  }
1618 
1632  inline var pow(const var& base, const var& exponent) {
1633  return var(new pow_vv_vari(base.vi_,exponent.vi_));
1634  }
1635 
1648  inline var pow(const var& base, const double& exponent) {
1649  return var(new pow_vd_vari(base.vi_,exponent));
1650  }
1651 
1664  inline var pow(const double& base, const var& exponent) {
1665  return var(new pow_dv_vari(base,exponent.vi_));
1666  }
1667 
1668 
1669  // TRIG FUNCTIONS
1670 
1681  inline var cos(const var& a) {
1682  return var(new cos_vari(a.vi_));
1683  }
1684 
1695  inline var sin(const var& a) {
1696  return var(new sin_vari(a.vi_));
1697  }
1698 
1709  inline var tan(const var& a) {
1710  return var(new tan_vari(a.vi_));
1711  }
1712 
1724  inline var acos(const var& a) {
1725  return var(new acos_vari(a.vi_));
1726  }
1727 
1739  inline var asin(const var& a) {
1740  return var(new asin_vari(a.vi_));
1741  }
1742 
1754  inline var atan(const var& a) {
1755  return var(new atan_vari(a.vi_));
1756  }
1757 
1773  inline var atan2(const var& a, const var& b) {
1774  return var(new atan2_vv_vari(a.vi_,b.vi_));
1775  }
1776 
1789  inline var atan2(const var& a, const double& b) {
1790  return var(new atan2_vd_vari(a.vi_,b));
1791  }
1792 
1805  inline var atan2(const double& a, const var& b) {
1806  return var(new atan2_dv_vari(a,b.vi_));
1807  }
1808 
1809  // HYPERBOLIC FUNCTIONS
1810 
1821  inline var cosh(const var& a) {
1822  return var(new cosh_vari(a.vi_));
1823  }
1824 
1835  inline var sinh(const var& a) {
1836  return var(new sinh_vari(a.vi_));
1837  }
1838 
1849  inline var tanh(const var& a) {
1850  return var(new tanh_vari(a.vi_));
1851  }
1852 
1853 
1854  // ROUNDING FUNCTIONS
1855 
1872  inline var fabs(const var& a) {
1873  // cut-and-paste from abs()
1874  if (a.val() > 0.0)
1875  return a;
1876  if (a.val() < 0.0)
1877  return var(new neg_vari(a.vi_));
1878  return var(new vari(0.0));
1879  }
1880 
1899  inline var floor(const var& a) {
1900  return var(new floor_vari(a.vi_));
1901  }
1902 
1921  inline var ceil(const var& a) {
1922  return var(new ceil_vari(a.vi_));
1923  }
1924 
1942  inline var fmod(const var& a, const var& b) {
1943  return var(new fmod_vv_vari(a.vi_,b.vi_));
1944  }
1945 
1959  inline var fmod(const var& a, const double& b) {
1960  return var(new fmod_vd_vari(a.vi_,b));
1961  }
1962 
1976  inline var fmod(const double& a, const var& b) {
1977  return var(new fmod_dv_vari(a,b.vi_));
1978  }
1979 
1980 
1981  // STD LIB FUNCTIONS
1982 
1997  inline var abs(const var& a) {
1998  // cut-and-paste from fabs()
1999  if (a.val() > 0.0)
2000  return a;
2001  if (a.val() < 0.0)
2002  return var(new neg_vari(a.vi_));
2003  return var(new vari(0.0));
2004  }
2005 
2006  }
2007 
2008 }
2009 
2010 #endif
double cd_
double bd_
double ad_
vari * avi_
vari * cvi_
const double exp_val_
vari * bvi_
Independent (input) and dependent (output) variables for gradients.
Definition: agrad.hpp:214
void grad()
Compute gradients of this dependent variable with respect to all variables on which it depends.
var & operator*=(const double &b)
The compound multiply/assignment operator for scalars (C++).
var(unsigned int n)
Construct a variable by static casting the specified value to double.
var()
Construct a variable for later assignment.
var(char c)
Construct a variable by static casting the specified value to double.
var(float x)
Construct a variable by static casting the specified value to double.
var & operator*=(const var &b)
The compound multiply/assignment operator for variables (C++).
var(long int n)
Construct a variable by static casting the specified value to double.
var(double x)
Construct a variable with the specified value.
var & operator+=(const double &b)
The compound add/assignment operator for scalars (C++).
var & operator-=(const var &b)
The compound subtract/assignment operator for variables (C++).
var & operator+=(const var &b)
The compound add/assignment operator for variables (C++).
var(int n)
Construct a variable by static casting the specified value to double.
var(long double x)
Construct a variable by static casting the specified value to double.
var(unsigned short n)
Construct a variable by static casting the specified value to double.
var(unsigned long int n)
Construct a variable by static casting the specified value to double.
var & operator/=(const double &b)
The compound divide/assignment operator for scalars (C++).
void grad(std::vector< var > &x, std::vector< double > &g)
Compute the gradient of this dependent variable with respect to the specified vector of independent variables.
double val() const
Return the value of this variable.
var(short n)
Construct a variable by static casting the specified value to double.
var(vari *vi)
Construct a variable from a pointer to a variable implementation.
var & operator/=(const var &b)
The compound divide/assignment operator for variables (C++).
var & operator-=(const double &b)
The compound subtract/assignment operator for scalars (C++).
var(bool b)
Construct a variable by static casting the specified value to double.
vari * vi_
Pointer to the implementation of this variable.
Definition: agrad.hpp:227
The variable implementation base class.
Definition: agrad.hpp:104
const double val_
The value of this variable.
Definition: agrad.hpp:113
double adj_
The adjoint of this variable, which is the partial derivative of this variable with respect to the root variable.
Definition: agrad.hpp:119
friend class var
Definition: agrad.hpp:106
static void recover_memory()
Recover memory used for all variables for reuse.
static void free_memory()
Return all memory used for gradients back to the system.
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends.
vari(const double x)
Construct a variable implementation from a value.
void free_all()
Free all memory used by the stack allocator other than the initial block allocation back to the system.
void recover_all()
Recover all the memory used by the stack allocator.
void * alloc(size_t len)
Return a newly allocated block of memory of the appropriate size managed by the stack allocator.
std::vector< chainable * > var_stack_
Definition: agrad.cpp:10
var cosh(const var &a)
Return the hyperbolic cosine of the specified variable (cmath).
Definition: agrad.hpp:1972
var atan(const var &a)
Return the principal value of the arc tangent, in radians, of the specified variable (cmath).
Definition: agrad.hpp:1906
var sqrt(const var &a)
Return the square root of the specified variable (cmath).
Definition: agrad.hpp:1761
var fabs(const var &a)
Return the absolute value of the variable (cmath).
Definition: agrad.hpp:2023
bool operator<=(const var &a, const var &b)
Less than or equal operator comparing two variables' values (C++).
Definition: agrad.hpp:1313
static void grad(chainable *vi)
Compute the gradient for all variables starting from the specified root variable implementation.
Definition: agrad.hpp:2187
var abs(const var &a)
Return the absolute value of the variable (std).
Definition: agrad.hpp:2149
bool operator>=(const var &a, const var &b)
Greater than or equal operator comparing two variables' values (C++).
Definition: agrad.hpp:1352
var operator-(const var &a)
Unary negation operator for variables (C++).
Definition: agrad.hpp:1431
var asin(const var &a)
Return the principal value of the arc sine, in radians, of the specified variable (cmath).
Definition: agrad.hpp:1891
var operator+(const var &a)
Unary plus operator for variables (C++).
Definition: agrad.hpp:1419
var acos(const var &a)
Return the principal value of the arc cosine of a variable, in radians (cmath).
Definition: agrad.hpp:1876
var ceil(const var &a)
Return the ceiling of the specified variable (cmath).
Definition: agrad.hpp:2073
bool operator!=(const var &a, const var &b)
Inequality operator comparing two variables' values (C++).
Definition: agrad.hpp:1204
var log10(const var &a)
Return the base 10 log of the specified variable (cmath).
Definition: agrad.hpp:1744
var cos(const var &a)
Return the cosine of a radian-scaled variable (cmath).
Definition: agrad.hpp:1833
var operator/(const var &a, const var &b)
Division operator for two variables (C++).
Definition: agrad.hpp:1602
var sin(const var &a)
Return the sine of a radian-scaled variable (cmath).
Definition: agrad.hpp:1847
bool operator<(const var &a, const var &b)
Less than operator comparing variables' values (C++).
Definition: agrad.hpp:1241
bool operator==(const var &a, const var &b)
Equality operator comparing two variables' values (C++).
Definition: agrad.hpp:1167
var tanh(const var &a)
Return the hyperbolic tangent of the specified variable (cmath).
Definition: agrad.hpp:2000
var floor(const var &a)
Return the floor of the specified variable (cmath).
Definition: agrad.hpp:2051
memory::stack_alloc memalloc_
Definition: agrad.cpp:11
var pow(const var &base, const var &exponent)
Return the base raised to the power of the exponent (cmath).
Definition: agrad.hpp:1778
var operator*(const var &a, const var &b)
Multiplication operator for two variables (C++).
Definition: agrad.hpp:1550
var & operator++(var &a)
Prefix increment operator for variables (C++).
Definition: agrad.hpp:1647
var log(const var &a)
Return the natural log of the specified variable (cmath).
Definition: agrad.hpp:1730
var fmod(const var &a, const var &b)
Return the floating point remainder after dividing the first variable by the second (cmath).
Definition: agrad.hpp:2094
bool operator>(const var &a, const var &b)
Greater than operator comparing variables' values (C++).
Definition: agrad.hpp:1276
var tan(const var &a)
Return the tangent of a radian-scaled variable (cmath).
Definition: agrad.hpp:1861
var atan2(const var &a, const var &b)
Return the principal value of the arc tangent, in radians, of the first variable divided by the second (cmath).
Definition: agrad.hpp:1924
bool operator!(const var &a)
Prefix logical negation for the value of variables (C++).
Definition: agrad.hpp:1397
var exp(const var &a)
Return the exponentiation of the specified variable (cmath).
Definition: agrad.hpp:1716
var sinh(const var &a)
Return the hyperbolic sine of the specified variable (cmath).
Definition: agrad.hpp:1986
var & operator--(var &a)
Prefix decrement operator for variables (C++).
Definition: agrad.hpp:1684
const double LOG_10
The natural logarithm of 10, $\log(10)$.
Definition: constants.hpp:32
Probability, optimization and sampling library.
Definition: agrad.cpp:6
Template specification of functions in std for Stan.
Definition: agrad.hpp:2306

     [ Stan Home Page ] © 2011–2012, Stan Development Team.