Stan  1.0
probability, sampling & optimization
agrad.hpp
Go to the documentation of this file.
1 #ifndef __STAN__AGRAD__AGRAD_HPP__
2 #define __STAN__AGRAD__AGRAD_HPP__
3 
4 #include <cmath>
5 #include <cstddef>
6 #include <limits>
7 #include <stdexcept>
8 #include <vector>
9 #include <ostream>
10 #include <iostream>
11 
13 
14 namespace stan {
15 
16  namespace agrad {
17 
18  class chainable;
19  class vari;
20  class var;
21 
22  // FIXME: manage all this as a single singleton (thread local)
23  extern std::vector<chainable*> var_stack_;
24  extern memory::stack_alloc memalloc_;
25 
26  static void recover_memory();
27 
28  static void grad(chainable* vi);
29 
34  class chainable {
35 
36  public:
37 
42  chainable() { }
43 
50  // handled automatically
51  }
52 
58  virtual void chain() {
59  }
60 
65  virtual void init_dependent() {
66  }
67 
72  virtual void set_zero_adjoint() {
73  }
74 
75 
85  static inline void* operator new(size_t nbytes) {
86  return memalloc_.alloc(nbytes);
87  }
88 
89  };
90 
104  class vari : public chainable {
105  private:
106  friend class var;
107 
108  public:
109 
113  const double val_;
114 
119  double adj_;
120 
133  vari(const double x):
134  val_(x),
135  adj_(0.0) {
136  var_stack_.push_back(this);
137  }
138 
146  ~vari() {
147  throw std::logic_error("vari destruction handled automatically");
148  }
149 
155  virtual void init_dependent() {
156  adj_ = 1.0; // droot/droot = 1
157  }
158 
162  virtual void set_zero_adjoint() {
163  adj_ = 0.0;
164  }
165 
175  friend std::ostream& operator<<(std::ostream& os, const vari* v) {
176  return os << v << " " << v->val_ << " : " << v->adj_;
177  }
178 
179  };
180 
190  inline void print_stack(std::ostream& o) {
191  o << "STACK, size=" << var_stack_.size() << std::endl;
192  for (size_t i = 0; i < var_stack_.size(); ++i)
193  o << i
194  << " " << var_stack_[i]
195  << " " << (static_cast<vari*>(var_stack_[i]))->val_
196  << " : " << (static_cast<vari*>(var_stack_[i]))->adj_
197  << std::endl;
198  }
199 
200 
214  class var {
215  public:
216 
217  // FIXME: doc what this is for
218  typedef double Scalar;
219 
228 
239  return (vi_ == static_cast<vari*>(0U));
240  }
241 
247  explicit var(vari* vi)
248  : vi_(vi)
249  { }
250 
258  var()
259  : vi_(static_cast<vari*>(0U))
260  { }
261 
268  var(bool b) :
269  vi_(new vari(static_cast<double>(b))) {
270  }
271 
278  var(char c) :
279  vi_(new vari(static_cast<double>(c))) {
280  }
281 
288  var(short n) :
289  vi_(new vari(static_cast<double>(n))) {
290  }
291 
298  var(unsigned short n) :
299  vi_(new vari(static_cast<double>(n))) {
300  }
301 
308  var(int n) :
309  vi_(new vari(static_cast<double>(n))) {
310  }
311 
318  var(unsigned int n) :
319  vi_(new vari(static_cast<double>(n))) {
320  }
321 
328  var(long int n) :
329  vi_(new vari(static_cast<double>(n))) {
330  }
331 
338  var(unsigned long int n) :
339  vi_(new vari(static_cast<double>(n))) {
340  }
341 
348  var(unsigned long long n) :
349  vi_(new vari(static_cast<double>(n))) {
350  }
351 
358  var(long long n) :
359  vi_(new vari(static_cast<double>(n))) {
360  }
361 
368  var(float x) :
369  vi_(new vari(static_cast<double>(x))) {
370  }
371 
377  var(double x) :
378  vi_(new vari(x)) {
379  }
380 
387  var(long double x) :
388  vi_(new vari(static_cast<double>(x))) {
389  }
390 
396  inline double val() const {
397  return vi_->val_;
398  }
399 
408  inline double adj() const {
409  return vi_->adj_;
410  }
411 
424  void grad(std::vector<var>& x,
425  std::vector<double>& g) {
427  g.resize(x.size());
428  for (size_t i = 0; i < x.size(); ++i)
429  g[i] = x[i].vi_->adj_;
430  recover_memory();
431  }
432 
449  void grad() {
451  recover_memory();
452  }
453 
454  // POINTER OVERRIDES
455 
468  inline vari& operator*() {
469  return *vi_;
470  }
471 
482  inline vari* operator->() {
483  return vi_;
484  }
485 
486  // COMPOUND ASSIGNMENT OPERATORS
487 
498  inline var& operator+=(const var& b);
499 
510  inline var& operator+=(const double b);
511 
523  inline var& operator-=(const var& b);
524 
536  inline var& operator-=(const double b);
537 
549  inline var& operator*=(const var& b);
550 
562  inline var& operator*=(const double b);
563 
574  inline var& operator/=(const var& b);
575 
587  inline var& operator/=(const double b);
588 
597  friend std::ostream& operator<<(std::ostream& os, const var& v) {
598  return os << v.val() << ':' << v.adj();
599  }
600  };
601 
602  namespace {
603 
604  class op_v_vari : public vari {
605  protected:
606  vari* avi_;
607  public:
608  op_v_vari(double f, vari* avi) :
609  vari(f),
610  avi_(avi) {
611  }
612  };
613 
614  class op_vv_vari : public vari {
615  protected:
616  vari* avi_;
617  vari* bvi_;
618  public:
619  op_vv_vari(double f, vari* avi, vari* bvi):
620  vari(f),
621  avi_(avi),
622  bvi_(bvi) {
623  }
624  };
625 
626  class op_vd_vari : public vari {
627  protected:
628  vari* avi_;
629  double bd_;
630  public:
631  op_vd_vari(double f, vari* avi, double b) :
632  vari(f),
633  avi_(avi),
634  bd_(b) {
635  }
636  };
637 
638  class op_dv_vari : public vari {
639  protected:
640  double ad_;
641  vari* bvi_;
642  public:
643  op_dv_vari(double f, double a, vari* bvi) :
644  vari(f),
645  ad_(a),
646  bvi_(bvi) {
647  }
648  };
649 
650  class op_vvv_vari : public vari {
651  protected:
652  vari* avi_;
653  vari* bvi_;
654  vari* cvi_;
655  public:
656  op_vvv_vari(double f, vari* avi, vari* bvi, vari* cvi) :
657  vari(f),
658  avi_(avi),
659  bvi_(bvi),
660  cvi_(cvi) {
661  }
662  };
663 
664  class op_vvd_vari : public vari {
665  protected:
666  vari* avi_;
667  vari* bvi_;
668  double cd_;
669  public:
670  op_vvd_vari(double f, vari* avi, vari* bvi, double c) :
671  vari(f),
672  avi_(avi),
673  bvi_(bvi),
674  cd_(c) {
675  }
676  };
677 
678  class op_vdv_vari : public vari {
679  protected:
680  vari* avi_;
681  double bd_;
682  vari* cvi_;
683  public:
684  op_vdv_vari(double f, vari* avi, double b, vari* cvi) :
685  vari(f),
686  avi_(avi),
687  bd_(b),
688  cvi_(cvi) {
689  }
690  };
691 
692  class op_vdd_vari : public vari {
693  protected:
694  vari* avi_;
695  double bd_;
696  double cd_;
697  public:
698  op_vdd_vari(double f, vari* avi, double b, double c) :
699  vari(f),
700  avi_(avi),
701  bd_(b),
702  cd_(c) {
703  }
704  };
705 
706  class op_dvv_vari : public vari {
707  protected:
708  double ad_;
709  vari* bvi_;
710  vari* cvi_;
711  public:
712  op_dvv_vari(double f, double a, vari* bvi, vari* cvi) :
713  vari(f),
714  ad_(a),
715  bvi_(bvi),
716  cvi_(cvi) {
717  }
718  };
719 
720  class op_dvd_vari : public vari {
721  protected:
722  double ad_;
723  vari* bvi_;
724  double cd_;
725  public:
726  op_dvd_vari(double f, double a, vari* bvi, double c) :
727  vari(f),
728  ad_(a),
729  bvi_(bvi),
730  cd_(c) {
731  }
732  };
733 
734  class op_ddv_vari : public vari {
735  protected:
736  double ad_;
737  double bd_;
738  vari* cvi_;
739  public:
740  op_ddv_vari(double f, double a, double b, vari* cvi) :
741  vari(f),
742  ad_(a),
743  bd_(b),
744  cvi_(cvi) {
745  }
746  };
747 
748  // FIXME: memory leak -- copy vector to local memory
749  class op_vector_vari : public vari {
750  protected:
751  const size_t size_;
752  vari** vis_;
753  public:
754  op_vector_vari(double f, const std::vector<stan::agrad::var>& vs) :
755  vari(f),
756  size_(vs.size()) {
757  vis_ = (vari**) operator new(sizeof(vari*[vs.size()]));
758  for (size_t i = 0; i < vs.size(); ++i)
759  vis_[i] = vs[i].vi_;
760  }
761  vari* operator[](size_t n) const {
762  return vis_[n];
763  }
764  size_t size() {
765  return size_;
766  }
767  };
768 
769  class neg_vari : public op_v_vari {
770  public:
771  neg_vari(vari* avi) :
772  op_v_vari(-(avi->val_), avi) {
773  }
774  void chain() {
775  avi_->adj_ -= adj_;
776  }
777  };
778 
779 
780  class add_vv_vari : public op_vv_vari {
781  public:
782  add_vv_vari(vari* avi, vari* bvi) :
783  op_vv_vari(avi->val_ + bvi->val_, avi, bvi) {
784  }
785  void chain() {
786  avi_->adj_ += adj_;
787  bvi_->adj_ += adj_;
788  }
789  };
790 
791  class add_vd_vari : public op_vd_vari {
792  public:
793  add_vd_vari(vari* avi, double b) :
794  op_vd_vari(avi->val_ + b, avi, b) {
795  }
796  void chain() {
797  avi_->adj_ += adj_;
798  }
799  };
800 
801  class increment_vari : public op_v_vari {
802  public:
803  increment_vari(vari* avi) :
804  op_v_vari(avi->val_ + 1.0, avi) {
805  }
806  void chain() {
807  avi_->adj_ += adj_;
808  }
809  };
810 
811  class decrement_vari : public op_v_vari {
812  public:
813  decrement_vari(vari* avi) :
814  op_v_vari(avi->val_ - 1.0, avi) {
815  }
816  void chain() {
817  avi_->adj_ += adj_;
818  }
819  };
820 
821  class subtract_vv_vari : public op_vv_vari {
822  public:
823  subtract_vv_vari(vari* avi, vari* bvi) :
824  op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
825  }
826  void chain() {
827  avi_->adj_ += adj_;
828  bvi_->adj_ -= adj_;
829  }
830  };
831 
832  class subtract_vd_vari : public op_vd_vari {
833  public:
834  subtract_vd_vari(vari* avi, double b) :
835  op_vd_vari(avi->val_ - b, avi, b) {
836  }
837  void chain() {
838  avi_->adj_ += adj_;
839  }
840  };
841 
842  class subtract_dv_vari : public op_dv_vari {
843  public:
844  subtract_dv_vari(double a, vari* bvi) :
845  op_dv_vari(a - bvi->val_, a, bvi) {
846  }
847  void chain() {
848  bvi_->adj_ -= adj_;
849  }
850  };
851 
852  class multiply_vv_vari : public op_vv_vari {
853  public:
854  multiply_vv_vari(vari* avi, vari* bvi) :
855  op_vv_vari(avi->val_ * bvi->val_, avi, bvi) {
856  }
857  void chain() {
858  avi_->adj_ += bvi_->val_ * adj_;
859  bvi_->adj_ += avi_->val_ * adj_;
860  }
861  };
862 
863  class multiply_vd_vari : public op_vd_vari {
864  public:
865  multiply_vd_vari(vari* avi, double b) :
866  op_vd_vari(avi->val_ * b, avi, b) {
867  }
868  void chain() {
869  avi_->adj_ += adj_ * bd_;
870  }
871  };
872 
873  // (a/b)' = a' * (1 / b) - b' * (a / [b * b])
874  class divide_vv_vari : public op_vv_vari {
875  public:
876  divide_vv_vari(vari* avi, vari* bvi) :
877  op_vv_vari(avi->val_ / bvi->val_, avi, bvi) {
878  }
879  void chain() {
880  avi_->adj_ += adj_ / bvi_->val_;
881  bvi_->adj_ -= adj_ * avi_->val_ / (bvi_->val_ * bvi_->val_);
882  }
883  };
884 
885  class divide_vd_vari : public op_vd_vari {
886  public:
887  divide_vd_vari(vari* avi, double b) :
888  op_vd_vari(avi->val_ / b, avi, b) {
889  }
890  void chain() {
891  avi_->adj_ += adj_ / bd_;
892  }
893  };
894 
895  class divide_dv_vari : public op_dv_vari {
896  public:
897  divide_dv_vari(double a, vari* bvi) :
898  op_dv_vari(a / bvi->val_, a, bvi) {
899  }
900  void chain() {
901  bvi_->adj_ -= adj_ * ad_ / (bvi_->val_ * bvi_->val_);
902  }
903  };
904 
905  class exp_vari : public op_v_vari {
906  public:
907  exp_vari(vari* avi) :
908  op_v_vari(std::exp(avi->val_),avi) {
909  }
910  void chain() {
911  avi_->adj_ += adj_ * val_;
912  }
913  };
914 
915  class log_vari : public op_v_vari {
916  public:
917  log_vari(vari* avi) :
918  op_v_vari(std::log(avi->val_),avi) {
919  }
920  void chain() {
921  avi_->adj_ += adj_ / avi_->val_;
922  }
923  };
924 
// Natural logarithm of 10, used for the derivative of log10().
// Declared const so that no translation unit can modify it.
const double LOG_10 = std::log(10.0);
926 
927  class log10_vari : public op_v_vari {
928  public:
929  const double exp_val_;
930  log10_vari(vari* avi) :
931  op_v_vari(std::log10(avi->val_),avi),
932  exp_val_(avi->val_) {
933  }
934  void chain() {
935  avi_->adj_ += adj_ / (LOG_10 * exp_val_);
936  }
937  };
938 
939  class sqrt_vari : public op_v_vari {
940  public:
941  sqrt_vari(vari* avi) :
942  op_v_vari(std::sqrt(avi->val_),avi) {
943  }
944  void chain() {
945  avi_->adj_ += adj_ / (2.0 * val_);
946  }
947  };
948 
949  class pow_vv_vari : public op_vv_vari {
950  public:
951  pow_vv_vari(vari* avi, vari* bvi) :
952  op_vv_vari(std::pow(avi->val_,bvi->val_),avi,bvi) {
953  }
954  void chain() {
955  if (avi_->val_ == 0.0) return; // partials zero, avoids 0 & log(0)
956  avi_->adj_ += adj_ * bvi_->val_ * val_ / avi_->val_;
957  bvi_->adj_ += adj_ * std::log(avi_->val_) * val_;
958  }
959  };
960 
961  class pow_vd_vari : public op_vd_vari {
962  public:
963  pow_vd_vari(vari* avi, double b) :
964  op_vd_vari(std::pow(avi->val_,b),avi,b) {
965  }
966  void chain() {
967  if (avi_->val_ == 0.0) return; // partials zero, avoids 0 & log(0)
968  avi_->adj_ += adj_ * bd_ * val_ / avi_->val_;
969  }
970  };
971 
972  class pow_dv_vari : public op_dv_vari {
973  public:
974  pow_dv_vari(double a, vari* bvi) :
975  op_dv_vari(std::pow(a,bvi->val_),a,bvi) {
976  }
977  void chain() {
978  if (ad_ == 0.0) return; // partials zero, avoids 0 & log(0)
979  bvi_->adj_ += adj_ * std::log(ad_) * val_;
980  }
981  };
982 
983  class cos_vari : public op_v_vari {
984  public:
985  cos_vari(vari* avi) :
986  op_v_vari(std::cos(avi->val_),avi) {
987  }
988  void chain() {
989  avi_->adj_ -= adj_ * std::sin(avi_->val_);
990  }
991  };
992 
993  class sin_vari : public op_v_vari {
994  public:
995  sin_vari(vari* avi) :
996  op_v_vari(std::sin(avi->val_),avi) {
997  }
998  void chain() {
999  avi_->adj_ += adj_ * std::cos(avi_->val_);
1000  }
1001  };
1002 
1003  class tan_vari : public op_v_vari {
1004  public:
1005  tan_vari(vari* avi) :
1006  op_v_vari(std::tan(avi->val_),avi) {
1007  }
1008  void chain() {
1009  avi_->adj_ += adj_ * (1.0 + val_ * val_);
1010  }
1011  };
1012 
1013  class acos_vari : public op_v_vari {
1014  public:
1015  acos_vari(vari* avi) :
1016  op_v_vari(std::acos(avi->val_),avi) {
1017  }
1018  void chain() {
1019  avi_->adj_ -= adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
1020  }
1021  };
1022 
1023  class asin_vari : public op_v_vari {
1024  public:
1025  asin_vari(vari* avi) :
1026  op_v_vari(std::asin(avi->val_),avi) {
1027  }
1028  void chain() {
1029  avi_->adj_ += adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
1030  }
1031  };
1032 
1033  class atan_vari : public op_v_vari {
1034  public:
1035  atan_vari(vari* avi) :
1036  op_v_vari(std::atan(avi->val_),avi) {
1037  }
1038  void chain() {
1039  avi_->adj_ += adj_ / (1.0 + (avi_->val_ * avi_->val_));
1040  }
1041  };
1042 
1043  class atan2_vv_vari : public op_vv_vari {
1044  public:
1045  atan2_vv_vari(vari* avi, vari* bvi) :
1046  op_vv_vari(std::atan2(avi->val_,bvi->val_),avi,bvi) {
1047  }
1048  void chain() {
1049  double a_sq_plus_b_sq = (avi_->val_ * avi_->val_) + (bvi_->val_ * bvi_->val_);
1050  avi_->adj_ += bvi_->val_ / a_sq_plus_b_sq;
1051  bvi_->adj_ -= avi_->val_ / a_sq_plus_b_sq;
1052  }
1053  };
1054 
1055  class atan2_vd_vari : public op_vd_vari {
1056  public:
1057  atan2_vd_vari(vari* avi, double b) :
1058  op_vd_vari(std::atan2(avi->val_,b),avi,b) {
1059  }
1060  void chain() {
1061  double a_sq_plus_b_sq = (avi_->val_ * avi_->val_) + (bd_ * bd_);
1062  avi_->adj_ += bd_ / a_sq_plus_b_sq;
1063  }
1064  };
1065 
1066  class atan2_dv_vari : public op_dv_vari {
1067  public:
1068  atan2_dv_vari(double a, vari* bvi) :
1069  op_dv_vari(std::atan2(a,bvi->val_),a,bvi) {
1070  }
1071  void chain() {
1072  double a_sq_plus_b_sq = (ad_ * ad_) + (bvi_->val_ * bvi_->val_);
1073  bvi_->adj_ -= ad_ / a_sq_plus_b_sq;
1074  }
1075  };
1076 
1077  class cosh_vari : public op_v_vari {
1078  public:
1079  cosh_vari(vari* avi) :
1080  op_v_vari(std::cosh(avi->val_),avi) {
1081  }
1082  void chain() {
1083  avi_->adj_ += adj_ * std::sinh(avi_->val_);
1084  }
1085  };
1086 
1087  class sinh_vari : public op_v_vari {
1088  public:
1089  sinh_vari(vari* avi) :
1090  op_v_vari(std::sinh(avi->val_),avi) {
1091  }
1092  void chain() {
1093  avi_->adj_ += adj_ * std::cosh(avi_->val_);
1094  }
1095  };
1096 
1097  class tanh_vari : public op_v_vari {
1098  public:
1099  tanh_vari(vari* avi) :
1100  op_v_vari(std::tanh(avi->val_),avi) {
1101  }
1102  void chain() {
1103  double cosh = std::cosh(avi_->val_);
1104  avi_->adj_ += adj_ / (cosh * cosh);
1105  }
1106  };
1107 
1108 
1109  class floor_vari : public vari {
1110  public:
1111  floor_vari(vari* avi) :
1112  vari(std::floor(avi->val_)) {
1113  }
1114  };
1115 
1116  class ceil_vari : public vari {
1117  public:
1118  ceil_vari(vari* avi) :
1119  vari(std::ceil(avi->val_)) {
1120  }
1121  };
1122 
1123  class fmod_vv_vari : public op_vv_vari {
1124  public:
1125  fmod_vv_vari(vari* avi, vari* bvi) :
1126  op_vv_vari(std::fmod(avi->val_,bvi->val_),avi,bvi) {
1127  }
1128  void chain() {
1129  avi_->adj_ += adj_;
1130  bvi_->adj_ -= adj_ * static_cast<int>(avi_->val_ / bvi_->val_);
1131  }
1132  };
1133 
1134  class fmod_vd_vari : public op_v_vari {
1135  public:
1136  fmod_vd_vari(vari* avi, double b) :
1137  op_v_vari(std::fmod(avi->val_,b),avi) {
1138  }
1139  void chain() {
1140  avi_->adj_ += adj_;
1141  }
1142  };
1143 
1144  class fmod_dv_vari : public op_dv_vari {
1145  public:
1146  fmod_dv_vari(double a, vari* bvi) :
1147  op_dv_vari(std::fmod(a,bvi->val_),a,bvi) {
1148  }
1149  void chain() {
1150  int d = static_cast<int>(ad_ / bvi_->val_);
1151  bvi_->adj_ -= adj_ * d;
1152  }
1153  };
1154  }
1155 
1156 
1157  // COMPARISON OPERATORS
1158 
1167  inline bool operator==(const var& a, const var& b) {
1168  return a.val() == b.val();
1169  }
1170 
1180  inline bool operator==(const var& a, const double b) {
1181  return a.val() == b;
1182  }
1183 
1192  inline bool operator==(const double a, const var& b) {
1193  return a == b.val();
1194  }
1195 
1204  inline bool operator!=(const var& a, const var& b) {
1205  return a.val() != b.val();
1206  }
1207 
1217  inline bool operator!=(const var& a, const double b) {
1218  return a.val() != b;
1219  }
1220 
1230  inline bool operator!=(const double a, const var& b) {
1231  return a != b.val();
1232  }
1233 
1241  inline bool operator<(const var& a, const var& b) {
1242  return a.val() < b.val();
1243  }
1244 
1253  inline bool operator<(const var& a, const double b) {
1254  return a.val() < b;
1255  }
1256 
1265  inline bool operator<(const double a, const var& b) {
1266  return a < b.val();
1267  }
1268 
1276  inline bool operator>(const var& a, const var& b) {
1277  return a.val() > b.val();
1278  }
1279 
1288  inline bool operator>(const var& a, const double b) {
1289  return a.val() > b;
1290  }
1291 
1300  inline bool operator>(const double a, const var& b) {
1301  return a > b.val();
1302  }
1303 
1313  inline bool operator<=(const var& a, const var& b) {
1314  return a.val() <= b.val();
1315  }
1316 
1326  inline bool operator<=(const var& a, const double b) {
1327  return a.val() <= b;
1328  }
1329 
1339  inline bool operator<=(const double a, const var& b) {
1340  return a <= b.val();
1341  }
1342 
1352  inline bool operator>=(const var& a, const var& b) {
1353  return a.val() >= b.val();
1354  }
1355 
1365  inline bool operator>=(const var& a, const double b) {
1366  return a.val() >= b;
1367  }
1368 
1378  inline bool operator>=(const double a, const var& b) {
1379  return a >= b.val();
1380  }
1381 
1382  // LOGICAL OPERATORS
1383 
1397  inline bool operator!(const var& a) {
1398  return !a.val();
1399  }
1400 
1401 
1402  // ARITHMETIC OPERATORS
1403 
1419  inline var operator+(const var& a) {
1420  return a;
1421  }
1422 
1431  inline var operator-(const var& a) {
1432  return var(new neg_vari(a.vi_));
1433  }
1434 
1448  inline var operator+(const var& a, const var& b) {
1449  return var(new add_vv_vari(a.vi_,b.vi_));
1450  }
1451 
1452 
1464  inline var operator+(const var& a, const double b) {
1465  if (b == 0.0)
1466  return a;
1467  return var(new add_vd_vari(a.vi_,b));
1468  }
1469 
1481  inline var operator+(const double a, const var& b) {
1482  if (a == 0.0)
1483  return b;
1484  return var(new add_vd_vari(b.vi_,a)); // by symmetry
1485  }
1486 
1501  inline var operator-(const var& a, const var& b) {
1502  return var(new subtract_vv_vari(a.vi_,b.vi_));
1503  }
1504 
1516  inline var operator-(const var& a, const double b) {
1517  if (b == 0.0)
1518  return a;
1519  return var(new subtract_vd_vari(a.vi_,b));
1520  }
1521 
1533  inline var operator-(const double a, const var& b) {
1534  return var(new subtract_dv_vari(a,b.vi_));
1535  }
1536 
1550  inline var operator*(const var& a, const var& b) {
1551  return var(new multiply_vv_vari(a.vi_,b.vi_));
1552  }
1553 
1565  inline var operator*(const var& a, const double b) {
1566  if (b == 1.0)
1567  return a;
1568  return var(new multiply_vd_vari(a.vi_,b));
1569  }
1570 
1582  inline var operator*(const double a, const var& b) {
1583  if (a == 1.0)
1584  return b;
1585  return var(new multiply_vd_vari(b.vi_,a)); // by symmetry
1586  }
1587 
1602  inline var operator/(const var& a, const var& b) {
1603  return var(new divide_vv_vari(a.vi_,b.vi_));
1604  }
1605 
1617  inline var operator/(const var& a, const double b) {
1618  if (b == 1.0)
1619  return a;
1620  return var(new divide_vd_vari(a.vi_,b));
1621  }
1622 
1634  inline var operator/(const double a, const var& b) {
1635  return var(new divide_dv_vari(a,b.vi_));
1636  }
1637 
1647  inline var& operator++(var& a) {
1648  a.vi_ = new increment_vari(a.vi_);
1649  return a;
1650  }
1651 
1665  inline var operator++(var& a, int dummy) {
1666  var temp(a);
1667  a.vi_ = new increment_vari(a.vi_);
1668  return temp;
1669  }
1670 
1684  inline var& operator--(var& a) {
1685  a.vi_ = new decrement_vari(a.vi_);
1686  return a;
1687  }
1688 
1702  inline var operator--(var& a, int dummy) {
1703  var temp(a);
1704  a.vi_ = new decrement_vari(a.vi_);
1705  return temp;
1706  }
1707 
1708  // CMATH EXP AND LOG
1709 
1716  inline var exp(const var& a) {
1717  return var(new exp_vari(a.vi_));
1718  }
1719 
1730  inline var log(const var& a) {
1731  return var(new log_vari(a.vi_));
1732  }
1733 
1744  inline var log10(const var& a) {
1745  return var(new log10_vari(a.vi_));
1746  }
1747 
1748 
1749  // POWER FUNCTIONS
1750 
1761  inline var sqrt(const var& a) {
1762  return var(new sqrt_vari(a.vi_));
1763  }
1764 
1778  inline var pow(const var& base, const var& exponent) {
1779  return var(new pow_vv_vari(base.vi_,exponent.vi_));
1780  }
1781 
1794  inline var pow(const var& base, const double exponent) {
1795  if (exponent == 0.5)
1796  return sqrt(base);
1797  if (exponent == 1.0)
1798  return base;
1799  if (exponent == 2.0)
1800  return base * base; // FIXME: square() functionality from special_functions
1801  return var(new pow_vd_vari(base.vi_,exponent));
1802  }
1803 
1816  inline var pow(const double base, const var& exponent) {
1817  return var(new pow_dv_vari(base,exponent.vi_));
1818  }
1819 
1820 
1821  // TRIG FUNCTIONS
1822 
1833  inline var cos(const var& a) {
1834  return var(new cos_vari(a.vi_));
1835  }
1836 
1847  inline var sin(const var& a) {
1848  return var(new sin_vari(a.vi_));
1849  }
1850 
1861  inline var tan(const var& a) {
1862  return var(new tan_vari(a.vi_));
1863  }
1864 
1876  inline var acos(const var& a) {
1877  return var(new acos_vari(a.vi_));
1878  }
1879 
1891  inline var asin(const var& a) {
1892  return var(new asin_vari(a.vi_));
1893  }
1894 
1906  inline var atan(const var& a) {
1907  return var(new atan_vari(a.vi_));
1908  }
1909 
1924  inline var atan2(const var& a, const var& b) {
1925  return var(new atan2_vv_vari(a.vi_,b.vi_));
1926  }
1927 
1940  inline var atan2(const var& a, const double b) {
1941  return var(new atan2_vd_vari(a.vi_,b));
1942  }
1943 
1956  inline var atan2(const double a, const var& b) {
1957  return var(new atan2_dv_vari(a,b.vi_));
1958  }
1959 
1960  // HYPERBOLIC FUNCTIONS
1961 
1972  inline var cosh(const var& a) {
1973  return var(new cosh_vari(a.vi_));
1974  }
1975 
1986  inline var sinh(const var& a) {
1987  return var(new sinh_vari(a.vi_));
1988  }
1989 
2000  inline var tanh(const var& a) {
2001  return var(new tanh_vari(a.vi_));
2002  }
2003 
2004 
2005  // ROUNDING FUNCTIONS
2006 
2023  inline var fabs(const var& a) {
2024  // cut-and-paste from abs()
2025  if (a.val() > 0.0)
2026  return a;
2027  if (a.val() < 0.0)
2028  return var(new neg_vari(a.vi_));
2029  // FIXME: is this right? breaks connection to a
2030  return var(new vari(0.0));
2031  }
2032 
2051  inline var floor(const var& a) {
2052  return var(new floor_vari(a.vi_));
2053  }
2054 
2073  inline var ceil(const var& a) {
2074  return var(new ceil_vari(a.vi_));
2075  }
2076 
2094  inline var fmod(const var& a, const var& b) {
2095  return var(new fmod_vv_vari(a.vi_,b.vi_));
2096  }
2097 
2111  inline var fmod(const var& a, const double b) {
2112  return var(new fmod_vd_vari(a.vi_,b));
2113  }
2114 
2128  inline var fmod(const double a, const var& b) {
2129  return var(new fmod_dv_vari(a,b.vi_));
2130  }
2131 
2132 
2133  // STD LIB FUNCTIONS
2134 
2149  inline var abs(const var& a) {
2150  // cut-and-paste from fabs()
2151  if (a.val() > 0.0)
2152  return a;
2153  if (a.val() < 0.0)
2154  return var(new neg_vari(a.vi_));
2155  // FIXME: same as fabs() -- is this right?
2156  return var(new vari(0.0));
2157  }
2158 
2159 
2163  static void recover_memory() {
2164  var_stack_.resize(0);
2166  }
2167 
2171  static void free_memory() {
2172  memalloc_.free_all();
2173  }
2174 
2187  static void grad(chainable* vi) {
2188  // old with subtle *2 bug
2189  // std::vector<chainable*>::iterator begin = var_stack_.begin();
2190  // std::vector<chainable*>::iterator it = var_stack_.end();
2191  // for (; (it >= begin) && (*it != vi); --it) ;
2192 
2193  std::vector<chainable*>::iterator begin = var_stack_.begin();
2194  std::vector<chainable*>::iterator it = var_stack_.end();
2195  if (begin == it) return; // nothing on stack
2196  for (--it; (it >= begin) && (*it != vi); --it) ;
2197 
2198  vi->init_dependent();
2199  // propagate derivates for remaining vars
2200  for (; it >= begin; --it)
2201  (*it)->chain();
2202  }
2203 
2207  static void set_zero_all_adjoints() {
2208  for (size_t i = 0; i < var_stack_.size(); ++i)
2209  var_stack_[i]->set_zero_adjoint();
2210  }
2211 
2238  inline void jacobian(std::vector<var>& dependents,
2239  std::vector<var>& independents,
2240  std::vector<std::vector<double> >& jacobian) {
2241  jacobian.resize(dependents.size());
2242  for (size_t i = 0; i < dependents.size(); ++i) {
2243  jacobian[i].resize(independents.size());
2244  if (i > 0)
2246  jacobian.push_back(std::vector<double>(0));
2247  grad(dependents[i].vi_);
2248  for (size_t j = 0; j < independents.size(); ++j)
2249  jacobian[i][j] = independents[j].adj();
2250  }
2251  }
2252 
2253 
2254  inline var& var::operator+=(const var& b) {
2255  vi_ = new add_vv_vari(vi_,b.vi_);
2256  return *this;
2257  }
2258 
2259  inline var& var::operator+=(const double b) {
2260  if (b == 0.0)
2261  return *this;
2262  vi_ = new add_vd_vari(vi_,b);
2263  return *this;
2264  }
2265 
2266  inline var& var::operator-=(const var& b) {
2267  vi_ = new subtract_vv_vari(vi_,b.vi_);
2268  return *this;
2269  }
2270 
2271  inline var& var::operator-=(const double b) {
2272  if (b == 0.0)
2273  return *this;
2274  vi_ = new subtract_vd_vari(vi_,b);
2275  return *this;
2276  }
2277 
2278  inline var& var::operator*=(const var& b) {
2279  vi_ = new multiply_vv_vari(vi_,b.vi_);
2280  return *this;
2281  }
2282 
2283  inline var& var::operator*=(const double b) {
2284  if (b == 1.0)
2285  return *this;
2286  vi_ = new multiply_vd_vari(vi_,b);
2287  return *this;
2288  }
2289 
2290  inline var& var::operator/=(const var& b) {
2291  vi_ = new divide_vv_vari(vi_,b.vi_);
2292  return *this;
2293  }
2294 
2295  inline var& var::operator/=(const double b) {
2296  if (b == 1.0)
2297  return *this;
2298  vi_ = new divide_vd_vari(vi_,b);
2299  return *this;
2300  }
2301  }
2302 
2303 }
2304 
2305 
2306 namespace std {
2307 
2314  template<>
2315  struct numeric_limits<stan::agrad::var> {
2316  static const bool is_specialized = true;
2319  static const int digits = numeric_limits<double>::digits;
2320  static const int digits10 = numeric_limits<double>::digits10;
2321  static const bool is_signed = numeric_limits<double>::is_signed;
2322  static const bool is_integer = numeric_limits<double>::is_integer;
2323  static const bool is_exact = numeric_limits<double>::is_exact;
2324  static const int radix = numeric_limits<double>::radix;
2326  static stan::agrad::var round_error() { return numeric_limits<double>::round_error(); }
2327 
2328  static const int min_exponent = numeric_limits<double>::min_exponent;
2329  static const int min_exponent10 = numeric_limits<double>::min_exponent10;
2330  static const int max_exponent = numeric_limits<double>::max_exponent;
2331  static const int max_exponent10 = numeric_limits<double>::max_exponent10;
2332 
2333  static const bool has_infinity = numeric_limits<double>::has_infinity;
2334  static const bool has_quiet_NaN = numeric_limits<double>::has_quiet_NaN;
2335  static const bool has_signaling_NaN = numeric_limits<double>::has_signaling_NaN;
2336  static const float_denorm_style has_denorm = numeric_limits<double>::has_denorm;
2337  static const bool has_denorm_loss = numeric_limits<double>::has_denorm_loss;
2338  static stan::agrad::var infinity() { return numeric_limits<double>::infinity(); }
2339  static stan::agrad::var quiet_NaN() { return numeric_limits<double>::quiet_NaN(); }
2340  static stan::agrad::var signaling_NaN() { return numeric_limits<double>::signaling_NaN(); }
2341  static stan::agrad::var denorm_min() { return numeric_limits<double>::denorm_min(); }
2342 
2343  static const bool is_iec559 = numeric_limits<double>::is_iec559;
2344  static const bool is_bounded = numeric_limits<double>::is_bounded;
2345  static const bool is_modulo = numeric_limits<double>::is_modulo;
2346 
2347  static const bool traps = numeric_limits<double>::traps;
2348  static const bool tinyness_before = numeric_limits<double>::tinyness_before;
2349  static const float_round_style round_style = numeric_limits<double>::round_style;
2350  };
2351 
2361  inline int isnan(const stan::agrad::var& a) {
2362  return isnan(a.val());
2363  }
2364 
2374  inline int isinf(const stan::agrad::var& a) {
2375  return isinf(a.val());
2376  }
2377 
2378 }
2379 
2380 #endif
double cd_
Definition: agrad.hpp:668
vari ** vis_
Definition: agrad.hpp:752
double bd_
Definition: agrad.hpp:629
double ad_
Definition: agrad.hpp:640
vari * avi_
Definition: agrad.hpp:606
vari * cvi_
Definition: agrad.hpp:654
const size_t size_
Definition: agrad.hpp:751
const double exp_val_
Definition: agrad.hpp:929
vari * bvi_
Definition: agrad.hpp:617
Abstract base class for variable implementations that handles memory management and applying the chain rule.
Definition: agrad.hpp:34
~chainable()
Throws a logic exception.
Definition: agrad.hpp:49
virtual void chain()
Apply the chain rule to this variable based on the variables on which it depends.
Definition: agrad.hpp:58
virtual void set_zero_adjoint()
Set the value of the adjoint for this chainable to its initial value.
Definition: agrad.hpp:72
chainable()
Construct a chainable object.
Definition: agrad.hpp:42
virtual void init_dependent()
Initialize this chainable's adjoint value to make it the dependent variable in a gradient calculation.
Definition: agrad.hpp:65
Independent (input) and dependent (output) variables for gradients.
Definition: agrad.hpp:214
var(unsigned long long n)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:348
var(long long n)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:358
double adj() const
Return the derivative of the root expression with respect to this expression.
Definition: agrad.hpp:408
void grad()
Compute gradients of this dependent variable with respect to all variables on which it depends.
Definition: agrad.hpp:449
var(unsigned int n)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:318
var()
Construct a variable for later assignment.
Definition: agrad.hpp:258
friend std::ostream & operator<<(std::ostream &os, const var &v)
Write the value of this auto-diff variable and its adjoint to the specified output stream.
Definition: agrad.hpp:597
var(char c)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:278
var(float x)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:368
var(long int n)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:328
var(double x)
Construct a variable with the specified value.
Definition: agrad.hpp:377
var & operator+=(const var &b)
The compound add/assignment operator for variables (C++).
Definition: agrad.hpp:2254
var(int n)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:308
var & operator*=(const var &b)
The compound multiply/assignment operator for variables (C++).
Definition: agrad.hpp:2278
double Scalar
Definition: agrad.hpp:218
var & operator/=(const var &b)
The compound divide/assignment operator for variables (C++).
Definition: agrad.hpp:2290
var(long double x)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:387
bool is_uninitialized()
Return true if this variable has been declared, but not been defined.
Definition: agrad.hpp:238
vari & operator*()
Return a reference to underlying implementation of this variable.
Definition: agrad.hpp:468
var(unsigned short n)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:298
var(unsigned long int n)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:338
void grad(std::vector< var > &x, std::vector< double > &g)
Compute the gradient of this (dependent) variable with respect to the specified vector of (independent) variables.
Definition: agrad.hpp:424
double val() const
Return the value of this variable.
Definition: agrad.hpp:396
var & operator-=(const var &b)
The compound subtract/assignment operator for variables (C++).
Definition: agrad.hpp:2266
var(short n)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:288
var(vari *vi)
Construct a variable from a pointer to a variable implementation.
Definition: agrad.hpp:247
var(bool b)
Construct a variable by static casting the specified value to double.
Definition: agrad.hpp:268
vari * vi_
Pointer to the implementation of this variable.
Definition: agrad.hpp:227
vari * operator->()
Return a pointer to the underlying implementation of this variable.
Definition: agrad.hpp:482
The variable implementation base class.
Definition: agrad.hpp:104
const double val_
The value of this variable.
Definition: agrad.hpp:113
virtual void set_zero_adjoint()
Set the adjoint value of this variable to 0.
Definition: agrad.hpp:162
double adj_
The adjoint of this variable, which is the partial derivative of the root variable with respect to this variable.
Definition: agrad.hpp:119
~vari()
Throws a logic error exception (std::logic_error), because vari destruction is handled automatically.
Definition: agrad.hpp:146
friend std::ostream & operator<<(std::ostream &os, const vari *v)
Insertion operator for vari.
Definition: agrad.hpp:175
virtual void init_dependent()
Initialize the adjoint for this (dependent) variable to 1.
Definition: agrad.hpp:155
vari(const double x)
Construct a variable implementation from a value.
Definition: agrad.hpp:133
void free_all()
Free all memory used by the stack allocator other than the initial block allocation back to the system.
void recover_all()
Recover all the memory used by the stack allocator.
void * alloc(size_t len)
Return a newly allocated block of memory of the appropriate size managed by the stack allocator.
std::vector< chainable * > var_stack_
Definition: agrad.cpp:10
var cosh(const var &a)
Return the hyperbolic cosine of the specified variable (cmath).
Definition: agrad.hpp:1972
var atan(const var &a)
Return the principal value of the arc tangent, in radians, of the specified variable (cmath).
Definition: agrad.hpp:1906
var sqrt(const var &a)
Return the square root of the specified variable (cmath).
Definition: agrad.hpp:1761
var fabs(const var &a)
Return the absolute value of the variable (cmath).
Definition: agrad.hpp:2023
bool operator<=(const var &a, const var &b)
Less than or equal operator comparing two variables' values (C++).
Definition: agrad.hpp:1313
static void grad(chainable *vi)
Compute the gradient for all variables starting from the specified root variable implementation.
Definition: agrad.hpp:2187
var abs(const var &a)
Return the absolute value of the variable (std).
Definition: agrad.hpp:2149
bool operator>=(const var &a, const var &b)
Greater than or equal operator comparing two variables' values (C++).
Definition: agrad.hpp:1352
var operator-(const var &a)
Unary negation operator for variables (C++).
Definition: agrad.hpp:1431
static void recover_memory()
Recover memory used for all variables for reuse.
Definition: agrad.hpp:2163
var asin(const var &a)
Return the principal value of the arc sine, in radians, of the specified variable (cmath).
Definition: agrad.hpp:1891
var operator+(const var &a)
Unary plus operator for variables (C++).
Definition: agrad.hpp:1419
var acos(const var &a)
Return the principal value of the arc cosine of a variable, in radians (cmath).
Definition: agrad.hpp:1876
var ceil(const var &a)
Return the ceiling of the specified variable (cmath).
Definition: agrad.hpp:2073
void jacobian(std::vector< var > &dependents, std::vector< var > &independents, std::vector< std::vector< double > > &jacobian)
Return the Jacobian of the function producing the specified dependent variables with respect to the specified independent variables.
Definition: agrad.hpp:2238
bool operator!=(const var &a, const var &b)
Inequality operator comparing two variables' values (C++).
Definition: agrad.hpp:1204
var log10(const var &a)
Return the base 10 log of the specified variable (cmath).
Definition: agrad.hpp:1744
var cos(const var &a)
Return the cosine of a radian-scaled variable (cmath).
Definition: agrad.hpp:1833
var operator/(const var &a, const var &b)
Division operator for two variables (C++).
Definition: agrad.hpp:1602
void print_stack(std::ostream &o)
Prints the auto-dif variable stack.
Definition: agrad.hpp:190
var sin(const var &a)
Return the sine of a radian-scaled variable (cmath).
Definition: agrad.hpp:1847
bool operator<(const var &a, const var &b)
Less than operator comparing variables' values (C++).
Definition: agrad.hpp:1241
bool operator==(const var &a, const var &b)
Equality operator comparing two variables' values (C++).
Definition: agrad.hpp:1167
static void free_memory()
Return all memory used for gradients back to the system.
Definition: agrad.hpp:2171
var tanh(const var &a)
Return the hyperbolic tangent of the specified variable (cmath).
Definition: agrad.hpp:2000
var floor(const var &a)
Return the floor of the specified variable (cmath).
Definition: agrad.hpp:2051
memory::stack_alloc memalloc_
Definition: agrad.cpp:11
static void set_zero_all_adjoints()
Reset all adjoint values in the stack to zero.
Definition: agrad.hpp:2207
var pow(const var &base, const var &exponent)
Return the base raised to the power of the exponent (cmath).
Definition: agrad.hpp:1778
var operator*(const var &a, const var &b)
Multiplication operator for two variables (C++).
Definition: agrad.hpp:1550
var & operator++(var &a)
Prefix increment operator for variables (C++).
Definition: agrad.hpp:1647
var log(const var &a)
Return the natural log of the specified variable (cmath).
Definition: agrad.hpp:1730
var fmod(const var &a, const var &b)
Return the floating point remainder after dividing the first variable by the second (cmath).
Definition: agrad.hpp:2094
bool operator>(const var &a, const var &b)
Greater than operator comparing variables' values (C++).
Definition: agrad.hpp:1276
var tan(const var &a)
Return the tangent of a radian-scaled variable (cmath).
Definition: agrad.hpp:1861
var atan2(const var &a, const var &b)
Return the principal value of the arc tangent, in radians, of the first variable divided by the second (cmath).
Definition: agrad.hpp:1924
bool operator!(const var &a)
Prefix logical negation for the value of variables (C++).
Definition: agrad.hpp:1397
var exp(const var &a)
Return the exponentiation of the specified variable (cmath).
Definition: agrad.hpp:1716
var sinh(const var &a)
Return the hyperbolic sine of the specified variable (cmath).
Definition: agrad.hpp:1986
var & operator--(var &a)
Prefix decrement operator for variables (C++).
Definition: agrad.hpp:1684
double epsilon()
Return minimum positive number representable.
int min(const std::vector< int > &x)
Returns the minimum coefficient in the specified column vector.
Definition: matrix.hpp:917
int max(const std::vector< int > &x)
Returns the maximum coefficient in the specified column vector.
Definition: matrix.hpp:968
const double LOG_10
The natural logarithm of 10, $\log 10$.
Definition: constants.hpp:32
Probability, optimization and sampling library.
Definition: agrad.cpp:6
Template specializations of functions in std for Stan.
Definition: agrad.hpp:2306
int isnan(const stan::agrad::var &a)
Checks if the given number is NaN.
Definition: agrad.hpp:2361
int isinf(const stan::agrad::var &a)
Checks if the given number is infinite.
Definition: agrad.hpp:2374
static stan::agrad::var max()
Definition: agrad.hpp:2318
static stan::agrad::var epsilon()
Definition: agrad.hpp:2325
static stan::agrad::var min()
Definition: agrad.hpp:2317
static stan::agrad::var signaling_NaN()
Definition: agrad.hpp:2340
static stan::agrad::var round_error()
Definition: agrad.hpp:2326
static stan::agrad::var quiet_NaN()
Definition: agrad.hpp:2339
static stan::agrad::var infinity()
Definition: agrad.hpp:2338
static stan::agrad::var denorm_min()
Definition: agrad.hpp:2341

     [ Stan Home Page ] © 2011–2012, Stan Development Team.