Stan  1.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Friends Macros Pages
agrad.hpp
Go to the documentation of this file.
1 #ifndef __STAN__AGRAD__AGRAD_HPP__
2 #define __STAN__AGRAD__AGRAD_HPP__
3 
4 #include <cmath>
5 #include <cstddef>
6 #include <limits>
7 #include <stdexcept>
8 #include <vector>
9 #include <ostream>
10 #include <iostream>
11 
13 
14 namespace stan {
15 
16  namespace agrad {
17 
18  class chainable;
19  class vari;
20  class var;
21 
22  // FIXME: manage all this as a single singleton (thread local)
23  extern std::vector<chainable*> var_stack_;
24  extern memory::stack_alloc memalloc_;
25 
26  static void recover_memory();
27 
28  static void grad(chainable* vi);
29 
34  class chainable {
35 
36  public:
37 
42  chainable() { }
43 
50  // handled automatically
51  }
52 
58  virtual void chain() {
59  }
60 
65  virtual void init_dependent() {
66  }
67 
72  virtual void set_zero_adjoint() {
73  }
74 
75 
85  static inline void* operator new(size_t nbytes) {
86  return memalloc_.alloc(nbytes);
87  }
88 
89  };
90 
104  class vari : public chainable {
105  private:
106  friend class var;
107 
108  public:
109 
113  const double val_;
114 
119  double adj_;
120 
133  vari(const double x):
134  val_(x),
135  adj_(0.0) {
136  var_stack_.push_back(this);
137  }
138 
146  ~vari() {
147  throw std::logic_error("vari destruction handled automatically");
148  }
149 
155  virtual void init_dependent() {
156  adj_ = 1.0; // droot/droot = 1
157  }
158 
162  virtual void set_zero_adjoint() {
163  adj_ = 0.0;
164  }
165 
175  friend std::ostream& operator<<(std::ostream& os, const vari* v) {
176  return os << v << " " << v->val_ << " : " << v->adj_;
177  }
178 
179  };
180 
190  inline void print_stack(std::ostream& o) {
191  o << "STACK, size=" << var_stack_.size() << std::endl;
192  for (size_t i = 0; i < var_stack_.size(); ++i)
193  o << i
194  << " " << var_stack_[i]
195  << " " << (static_cast<vari*>(var_stack_[i]))->val_
196  << " : " << (static_cast<vari*>(var_stack_[i]))->adj_
197  << std::endl;
198  }
199 
200 
214  class var {
215  public:
216 
217  // FIXME: doc what this is for
218  typedef double Scalar;
219 
228 
239  return (vi_ == static_cast<vari*>(0U));
240  }
241 
247  explicit var(vari* vi)
248  : vi_(vi)
249  { }
250 
258  var()
259  : vi_(static_cast<vari*>(0U))
260  { }
261 
268  var(bool b) :
269  vi_(new vari(static_cast<double>(b))) {
270  }
271 
278  var(char c) :
279  vi_(new vari(static_cast<double>(c))) {
280  }
281 
288  var(short n) :
289  vi_(new vari(static_cast<double>(n))) {
290  }
291 
298  var(unsigned short n) :
299  vi_(new vari(static_cast<double>(n))) {
300  }
301 
308  var(int n) :
309  vi_(new vari(static_cast<double>(n))) {
310  }
311 
318  var(unsigned int n) :
319  vi_(new vari(static_cast<double>(n))) {
320  }
321 
328  var(long int n) :
329  vi_(new vari(static_cast<double>(n))) {
330  }
331 
338  var(unsigned long int n) :
339  vi_(new vari(static_cast<double>(n))) {
340  }
341 
348  var(unsigned long long n) :
349  vi_(new vari(static_cast<double>(n))) {
350  }
351 
358  var(long long n) :
359  vi_(new vari(static_cast<double>(n))) {
360  }
361 
368  var(float x) :
369  vi_(new vari(static_cast<double>(x))) {
370  }
371 
377  var(double x) :
378  vi_(new vari(x)) {
379  }
380 
387  var(long double x) :
388  vi_(new vari(static_cast<double>(x))) {
389  }
390 
396  inline double val() const {
397  return vi_->val_;
398  }
399 
408  inline double adj() const {
409  return vi_->adj_;
410  }
411 
424  void grad(std::vector<var>& x,
425  std::vector<double>& g) {
427  g.resize(x.size());
428  for (size_t i = 0; i < x.size(); ++i)
429  g[i] = x[i].vi_->adj_;
430  recover_memory();
431  }
432 
449  void grad() {
451  recover_memory();
452  }
453 
454  // POINTER OVERRIDES
455 
468  inline vari& operator*() {
469  return *vi_;
470  }
471 
482  inline vari* operator->() {
483  return vi_;
484  }
485 
486  // COMPOUND ASSIGNMENT OPERATORS
487 
498  inline var& operator+=(const var& b);
499 
510  inline var& operator+=(const double b);
511 
523  inline var& operator-=(const var& b);
524 
536  inline var& operator-=(const double b);
537 
549  inline var& operator*=(const var& b);
550 
562  inline var& operator*=(const double b);
563 
574  inline var& operator/=(const var& b);
575 
587  inline var& operator/=(const double b);
588 
597  friend std::ostream& operator<<(std::ostream& os, const var& v) {
598  return os << v.val() << ':' << v.adj();
599  }
600  };
601 
602  namespace {
603 
604  class op_v_vari : public vari {
605  protected:
606  vari* avi_;
607  public:
608  op_v_vari(double f, vari* avi) :
609  vari(f),
610  avi_(avi) {
611  }
612  };
613 
614  class op_vv_vari : public vari {
615  protected:
616  vari* avi_;
617  vari* bvi_;
618  public:
619  op_vv_vari(double f, vari* avi, vari* bvi):
620  vari(f),
621  avi_(avi),
622  bvi_(bvi) {
623  }
624  };
625 
626  class op_vd_vari : public vari {
627  protected:
628  vari* avi_;
629  double bd_;
630  public:
631  op_vd_vari(double f, vari* avi, double b) :
632  vari(f),
633  avi_(avi),
634  bd_(b) {
635  }
636  };
637 
638  class op_dv_vari : public vari {
639  protected:
640  double ad_;
641  vari* bvi_;
642  public:
643  op_dv_vari(double f, double a, vari* bvi) :
644  vari(f),
645  ad_(a),
646  bvi_(bvi) {
647  }
648  };
649 
650  class op_vvv_vari : public vari {
651  protected:
652  vari* avi_;
653  vari* bvi_;
654  vari* cvi_;
655  public:
656  op_vvv_vari(double f, vari* avi, vari* bvi, vari* cvi) :
657  vari(f),
658  avi_(avi),
659  bvi_(bvi),
660  cvi_(cvi) {
661  }
662  };
663 
664  class op_vvd_vari : public vari {
665  protected:
666  vari* avi_;
667  vari* bvi_;
668  double cd_;
669  public:
670  op_vvd_vari(double f, vari* avi, vari* bvi, double c) :
671  vari(f),
672  avi_(avi),
673  bvi_(bvi),
674  cd_(c) {
675  }
676  };
677 
678  class op_vdv_vari : public vari {
679  protected:
680  vari* avi_;
681  double bd_;
682  vari* cvi_;
683  public:
684  op_vdv_vari(double f, vari* avi, double b, vari* cvi) :
685  vari(f),
686  avi_(avi),
687  bd_(b),
688  cvi_(cvi) {
689  }
690  };
691 
692  class op_vdd_vari : public vari {
693  protected:
694  vari* avi_;
695  double bd_;
696  double cd_;
697  public:
698  op_vdd_vari(double f, vari* avi, double b, double c) :
699  vari(f),
700  avi_(avi),
701  bd_(b),
702  cd_(c) {
703  }
704  };
705 
706  class op_dvv_vari : public vari {
707  protected:
708  double ad_;
709  vari* bvi_;
710  vari* cvi_;
711  public:
712  op_dvv_vari(double f, double a, vari* bvi, vari* cvi) :
713  vari(f),
714  ad_(a),
715  bvi_(bvi),
716  cvi_(cvi) {
717  }
718  };
719 
720  class op_dvd_vari : public vari {
721  protected:
722  double ad_;
723  vari* bvi_;
724  double cd_;
725  public:
726  op_dvd_vari(double f, double a, vari* bvi, double c) :
727  vari(f),
728  ad_(a),
729  bvi_(bvi),
730  cd_(c) {
731  }
732  };
733 
734  class op_ddv_vari : public vari {
735  protected:
736  double ad_;
737  double bd_;
738  vari* cvi_;
739  public:
740  op_ddv_vari(double f, double a, double b, vari* cvi) :
741  vari(f),
742  ad_(a),
743  bd_(b),
744  cvi_(cvi) {
745  }
746  };
747 
748  // FIXME: memory leak -- copy vector to local memory
749  class op_vector_vari : public vari {
750  protected:
751  const size_t size_;
752  vari** vis_;
753  public:
754  op_vector_vari(double f, const std::vector<stan::agrad::var>& vs) :
755  vari(f),
756  size_(vs.size()) {
757  vis_ = (vari**) operator new(sizeof(vari*[vs.size()]));
758  for (size_t i = 0; i < vs.size(); ++i)
759  vis_[i] = vs[i].vi_;
760  }
761  vari* operator[](size_t n) const {
762  return vis_[n];
763  }
764  size_t size() {
765  return size_;
766  }
767  };
768 
769  class neg_vari : public op_v_vari {
770  public:
771  neg_vari(vari* avi) :
772  op_v_vari(-(avi->val_), avi) {
773  }
774  void chain() {
775  avi_->adj_ -= adj_;
776  }
777  };
778 
779 
780  class add_vv_vari : public op_vv_vari {
781  public:
782  add_vv_vari(vari* avi, vari* bvi) :
783  op_vv_vari(avi->val_ + bvi->val_, avi, bvi) {
784  }
785  void chain() {
786  avi_->adj_ += adj_;
787  bvi_->adj_ += adj_;
788  }
789  };
790 
791  class add_vd_vari : public op_vd_vari {
792  public:
793  add_vd_vari(vari* avi, double b) :
794  op_vd_vari(avi->val_ + b, avi, b) {
795  }
796  void chain() {
797  avi_->adj_ += adj_;
798  }
799  };
800 
801  class increment_vari : public op_v_vari {
802  public:
803  increment_vari(vari* avi) :
804  op_v_vari(avi->val_ + 1.0, avi) {
805  }
806  void chain() {
807  avi_->adj_ += adj_;
808  }
809  };
810 
811  class decrement_vari : public op_v_vari {
812  public:
813  decrement_vari(vari* avi) :
814  op_v_vari(avi->val_ - 1.0, avi) {
815  }
816  void chain() {
817  avi_->adj_ += adj_;
818  }
819  };
820 
821  class subtract_vv_vari : public op_vv_vari {
822  public:
823  subtract_vv_vari(vari* avi, vari* bvi) :
824  op_vv_vari(avi->val_ - bvi->val_, avi, bvi) {
825  }
826  void chain() {
827  avi_->adj_ += adj_;
828  bvi_->adj_ -= adj_;
829  }
830  };
831 
832  class subtract_vd_vari : public op_vd_vari {
833  public:
834  subtract_vd_vari(vari* avi, double b) :
835  op_vd_vari(avi->val_ - b, avi, b) {
836  }
837  void chain() {
838  avi_->adj_ += adj_;
839  }
840  };
841 
842  class subtract_dv_vari : public op_dv_vari {
843  public:
844  subtract_dv_vari(double a, vari* bvi) :
845  op_dv_vari(a - bvi->val_, a, bvi) {
846  }
847  void chain() {
848  bvi_->adj_ -= adj_;
849  }
850  };
851 
852  class multiply_vv_vari : public op_vv_vari {
853  public:
854  multiply_vv_vari(vari* avi, vari* bvi) :
855  op_vv_vari(avi->val_ * bvi->val_, avi, bvi) {
856  }
857  void chain() {
858  avi_->adj_ += bvi_->val_ * adj_;
859  bvi_->adj_ += avi_->val_ * adj_;
860  }
861  };
862 
863  class multiply_vd_vari : public op_vd_vari {
864  public:
865  multiply_vd_vari(vari* avi, double b) :
866  op_vd_vari(avi->val_ * b, avi, b) {
867  }
868  void chain() {
869  avi_->adj_ += adj_ * bd_;
870  }
871  };
872 
873  // (a/b)' = a' * (1 / b) - b' * (a / [b * b])
874  class divide_vv_vari : public op_vv_vari {
875  public:
876  divide_vv_vari(vari* avi, vari* bvi) :
877  op_vv_vari(avi->val_ / bvi->val_, avi, bvi) {
878  }
879  void chain() {
880  avi_->adj_ += adj_ / bvi_->val_;
881  bvi_->adj_ -= adj_ * avi_->val_ / (bvi_->val_ * bvi_->val_);
882  }
883  };
884 
885  class divide_vd_vari : public op_vd_vari {
886  public:
887  divide_vd_vari(vari* avi, double b) :
888  op_vd_vari(avi->val_ / b, avi, b) {
889  }
890  void chain() {
891  avi_->adj_ += adj_ / bd_;
892  }
893  };
894 
895  class divide_dv_vari : public op_dv_vari {
896  public:
897  divide_dv_vari(double a, vari* bvi) :
898  op_dv_vari(a / bvi->val_, a, bvi) {
899  }
900  void chain() {
901  bvi_->adj_ -= adj_ * ad_ / (bvi_->val_ * bvi_->val_);
902  }
903  };
904 
905  class exp_vari : public op_v_vari {
906  public:
907  exp_vari(vari* avi) :
908  op_v_vari(std::exp(avi->val_),avi) {
909  }
910  void chain() {
911  avi_->adj_ += adj_ * val_;
912  }
913  };
914 
915  class log_vari : public op_v_vari {
916  public:
917  log_vari(vari* avi) :
918  op_v_vari(std::log(avi->val_),avi) {
919  }
920  void chain() {
921  avi_->adj_ += adj_ / avi_->val_;
922  }
923  };
924 
925  double LOG_10 = std::log(10.0);
926 
927  class log10_vari : public op_v_vari {
928  public:
929  const double exp_val_;
930  log10_vari(vari* avi) :
931  op_v_vari(std::log10(avi->val_),avi),
932  exp_val_(avi->val_) {
933  }
934  void chain() {
935  avi_->adj_ += adj_ / (LOG_10 * exp_val_);
936  }
937  };
938 
939  class sqrt_vari : public op_v_vari {
940  public:
941  sqrt_vari(vari* avi) :
942  op_v_vari(std::sqrt(avi->val_),avi) {
943  }
944  void chain() {
945  avi_->adj_ += adj_ / (2.0 * val_);
946  }
947  };
948 
949  class pow_vv_vari : public op_vv_vari {
950  public:
951  pow_vv_vari(vari* avi, vari* bvi) :
952  op_vv_vari(std::pow(avi->val_,bvi->val_),avi,bvi) {
953  }
954  void chain() {
955  if (avi_->val_ == 0.0) return; // partials zero, avoids 0 & log(0)
956  avi_->adj_ += adj_ * bvi_->val_ * val_ / avi_->val_;
957  bvi_->adj_ += adj_ * std::log(avi_->val_) * val_;
958  }
959  };
960 
961  class pow_vd_vari : public op_vd_vari {
962  public:
963  pow_vd_vari(vari* avi, double b) :
964  op_vd_vari(std::pow(avi->val_,b),avi,b) {
965  }
966  void chain() {
967  if (avi_->val_ == 0.0) return; // partials zero, avoids 0 & log(0)
968  avi_->adj_ += adj_ * bd_ * val_ / avi_->val_;
969  }
970  };
971 
972  class pow_dv_vari : public op_dv_vari {
973  public:
974  pow_dv_vari(double a, vari* bvi) :
975  op_dv_vari(std::pow(a,bvi->val_),a,bvi) {
976  }
977  void chain() {
978  if (ad_ == 0.0) return; // partials zero, avoids 0 & log(0)
979  bvi_->adj_ += adj_ * std::log(ad_) * val_;
980  }
981  };
982 
983  class cos_vari : public op_v_vari {
984  public:
985  cos_vari(vari* avi) :
986  op_v_vari(std::cos(avi->val_),avi) {
987  }
988  void chain() {
989  avi_->adj_ -= adj_ * std::sin(avi_->val_);
990  }
991  };
992 
993  class sin_vari : public op_v_vari {
994  public:
995  sin_vari(vari* avi) :
996  op_v_vari(std::sin(avi->val_),avi) {
997  }
998  void chain() {
999  avi_->adj_ += adj_ * std::cos(avi_->val_);
1000  }
1001  };
1002 
1003  class tan_vari : public op_v_vari {
1004  public:
1005  tan_vari(vari* avi) :
1006  op_v_vari(std::tan(avi->val_),avi) {
1007  }
1008  void chain() {
1009  avi_->adj_ += adj_ * (1.0 + val_ * val_);
1010  }
1011  };
1012 
1013  class acos_vari : public op_v_vari {
1014  public:
1015  acos_vari(vari* avi) :
1016  op_v_vari(std::acos(avi->val_),avi) {
1017  }
1018  void chain() {
1019  avi_->adj_ -= adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
1020  }
1021  };
1022 
1023  class asin_vari : public op_v_vari {
1024  public:
1025  asin_vari(vari* avi) :
1026  op_v_vari(std::asin(avi->val_),avi) {
1027  }
1028  void chain() {
1029  avi_->adj_ += adj_ / std::sqrt(1.0 - (avi_->val_ * avi_->val_));
1030  }
1031  };
1032 
1033  class atan_vari : public op_v_vari {
1034  public:
1035  atan_vari(vari* avi) :
1036  op_v_vari(std::atan(avi->val_),avi) {
1037  }
1038  void chain() {
1039  avi_->adj_ += adj_ / (1.0 + (avi_->val_ * avi_->val_));
1040  }
1041  };
1042 
1043  class atan2_vv_vari : public op_vv_vari {
1044  public:
1045  atan2_vv_vari(vari* avi, vari* bvi) :
1046  op_vv_vari(std::atan2(avi->val_,bvi->val_),avi,bvi) {
1047  }
1048  void chain() {
1049  double a_sq_plus_b_sq = (avi_->val_ * avi_->val_) + (bvi_->val_ * bvi_->val_);
1050  avi_->adj_ += bvi_->val_ / a_sq_plus_b_sq;
1051  bvi_->adj_ -= avi_->val_ / a_sq_plus_b_sq;
1052  }
1053  };
1054 
1055  class atan2_vd_vari : public op_vd_vari {
1056  public:
1057  atan2_vd_vari(vari* avi, double b) :
1058  op_vd_vari(std::atan2(avi->val_,b),avi,b) {
1059  }
1060  void chain() {
1061  double a_sq_plus_b_sq = (avi_->val_ * avi_->val_) + (bd_ * bd_);
1062  avi_->adj_ += bd_ / a_sq_plus_b_sq;
1063  }
1064  };
1065 
1066  class atan2_dv_vari : public op_dv_vari {
1067  public:
1068  atan2_dv_vari(double a, vari* bvi) :
1069  op_dv_vari(std::atan2(a,bvi->val_),a,bvi) {
1070  }
1071  void chain() {
1072  double a_sq_plus_b_sq = (ad_ * ad_) + (bvi_->val_ * bvi_->val_);
1073  bvi_->adj_ -= ad_ / a_sq_plus_b_sq;
1074  }
1075  };
1076 
1077  class cosh_vari : public op_v_vari {
1078  public:
1079  cosh_vari(vari* avi) :
1080  op_v_vari(std::cosh(avi->val_),avi) {
1081  }
1082  void chain() {
1083  avi_->adj_ += adj_ * std::sinh(avi_->val_);
1084  }
1085  };
1086 
1087  class sinh_vari : public op_v_vari {
1088  public:
1089  sinh_vari(vari* avi) :
1090  op_v_vari(std::sinh(avi->val_),avi) {
1091  }
1092  void chain() {
1093  avi_->adj_ += adj_ * std::cosh(avi_->val_);
1094  }
1095  };
1096 
1097  class tanh_vari : public op_v_vari {
1098  public:
1099  tanh_vari(vari* avi) :
1100  op_v_vari(std::tanh(avi->val_),avi) {
1101  }
1102  void chain() {
1103  double cosh = std::cosh(avi_->val_);
1104  avi_->adj_ += adj_ / (cosh * cosh);
1105  }
1106  };
1107 
1108 
1109  class floor_vari : public vari {
1110  public:
1111  floor_vari(vari* avi) :
1112  vari(std::floor(avi->val_)) {
1113  }
1114  };
1115 
1116  class ceil_vari : public vari {
1117  public:
1118  ceil_vari(vari* avi) :
1119  vari(std::ceil(avi->val_)) {
1120  }
1121  };
1122 
1123  class fmod_vv_vari : public op_vv_vari {
1124  public:
1125  fmod_vv_vari(vari* avi, vari* bvi) :
1126  op_vv_vari(std::fmod(avi->val_,bvi->val_),avi,bvi) {
1127  }
1128  void chain() {
1129  avi_->adj_ += adj_;
1130  bvi_->adj_ -= adj_ * static_cast<int>(avi_->val_ / bvi_->val_);
1131  }
1132  };
1133 
1134  class fmod_vd_vari : public op_v_vari {
1135  public:
1136  fmod_vd_vari(vari* avi, double b) :
1137  op_v_vari(std::fmod(avi->val_,b),avi) {
1138  }
1139  void chain() {
1140  avi_->adj_ += adj_;
1141  }
1142  };
1143 
1144  class fmod_dv_vari : public op_dv_vari {
1145  public:
1146  fmod_dv_vari(double a, vari* bvi) :
1147  op_dv_vari(std::fmod(a,bvi->val_),a,bvi) {
1148  }
1149  void chain() {
1150  int d = static_cast<int>(ad_ / bvi_->val_);
1151  bvi_->adj_ -= adj_ * d;
1152  }
1153  };
1154  }
1155 
1156 
1157  // COMPARISON OPERATORS
1158 
1167  inline bool operator==(const var& a, const var& b) {
1168  return a.val() == b.val();
1169  }
1170 
1180  inline bool operator==(const var& a, const double b) {
1181  return a.val() == b;
1182  }
1183 
1192  inline bool operator==(const double a, const var& b) {
1193  return a == b.val();
1194  }
1195 
1204  inline bool operator!=(const var& a, const var& b) {
1205  return a.val() != b.val();
1206  }
1207 
1217  inline bool operator!=(const var& a, const double b) {
1218  return a.val() != b;
1219  }
1220 
1230  inline bool operator!=(const double a, const var& b) {
1231  return a != b.val();
1232  }
1233 
1241  inline bool operator<(const var& a, const var& b) {
1242  return a.val() < b.val();
1243  }
1244 
1253  inline bool operator<(const var& a, const double b) {
1254  return a.val() < b;
1255  }
1256 
1265  inline bool operator<(const double a, const var& b) {
1266  return a < b.val();
1267  }
1268 
1276  inline bool operator>(const var& a, const var& b) {
1277  return a.val() > b.val();
1278  }
1279 
1288  inline bool operator>(const var& a, const double b) {
1289  return a.val() > b;
1290  }
1291 
1300  inline bool operator>(const double a, const var& b) {
1301  return a > b.val();
1302  }
1303 
1313  inline bool operator<=(const var& a, const var& b) {
1314  return a.val() <= b.val();
1315  }
1316 
1326  inline bool operator<=(const var& a, const double b) {
1327  return a.val() <= b;
1328  }
1329 
1339  inline bool operator<=(const double a, const var& b) {
1340  return a <= b.val();
1341  }
1342 
1352  inline bool operator>=(const var& a, const var& b) {
1353  return a.val() >= b.val();
1354  }
1355 
1365  inline bool operator>=(const var& a, const double b) {
1366  return a.val() >= b;
1367  }
1368 
1378  inline bool operator>=(const double a, const var& b) {
1379  return a >= b.val();
1380  }
1381 
1382  // LOGICAL OPERATORS
1383 
1397  inline bool operator!(const var& a) {
1398  return !a.val();
1399  }
1400 
1401 
1402  // ARITHMETIC OPERATORS
1403 
1419  inline var operator+(const var& a) {
1420  return a;
1421  }
1422 
1431  inline var operator-(const var& a) {
1432  return var(new neg_vari(a.vi_));
1433  }
1434 
1448  inline var operator+(const var& a, const var& b) {
1449  return var(new add_vv_vari(a.vi_,b.vi_));
1450  }
1451 
1452 
1464  inline var operator+(const var& a, const double b) {
1465  if (b == 0.0)
1466  return a;
1467  return var(new add_vd_vari(a.vi_,b));
1468  }
1469 
1481  inline var operator+(const double a, const var& b) {
1482  if (a == 0.0)
1483  return b;
1484  return var(new add_vd_vari(b.vi_,a)); // by symmetry
1485  }
1486 
1501  inline var operator-(const var& a, const var& b) {
1502  return var(new subtract_vv_vari(a.vi_,b.vi_));
1503  }
1504 
1516  inline var operator-(const var& a, const double b) {
1517  if (b == 0.0)
1518  return a;
1519  return var(new subtract_vd_vari(a.vi_,b));
1520  }
1521 
1533  inline var operator-(const double a, const var& b) {
1534  return var(new subtract_dv_vari(a,b.vi_));
1535  }
1536 
1550  inline var operator*(const var& a, const var& b) {
1551  return var(new multiply_vv_vari(a.vi_,b.vi_));
1552  }
1553 
1565  inline var operator*(const var& a, const double b) {
1566  if (b == 1.0)
1567  return a;
1568  return var(new multiply_vd_vari(a.vi_,b));
1569  }
1570 
1582  inline var operator*(const double a, const var& b) {
1583  if (a == 1.0)
1584  return b;
1585  return var(new multiply_vd_vari(b.vi_,a)); // by symmetry
1586  }
1587 
1602  inline var operator/(const var& a, const var& b) {
1603  return var(new divide_vv_vari(a.vi_,b.vi_));
1604  }
1605 
1617  inline var operator/(const var& a, const double b) {
1618  if (b == 1.0)
1619  return a;
1620  return var(new divide_vd_vari(a.vi_,b));
1621  }
1622 
1634  inline var operator/(const double a, const var& b) {
1635  return var(new divide_dv_vari(a,b.vi_));
1636  }
1637 
1647  inline var& operator++(var& a) {
1648  a.vi_ = new increment_vari(a.vi_);
1649  return a;
1650  }
1651 
1665  inline var operator++(var& a, int dummy) {
1666  var temp(a);
1667  a.vi_ = new increment_vari(a.vi_);
1668  return temp;
1669  }
1670 
1684  inline var& operator--(var& a) {
1685  a.vi_ = new decrement_vari(a.vi_);
1686  return a;
1687  }
1688 
1702  inline var operator--(var& a, int dummy) {
1703  var temp(a);
1704  a.vi_ = new decrement_vari(a.vi_);
1705  return temp;
1706  }
1707 
1708  // CMATH EXP AND LOG
1709 
1716  inline var exp(const var& a) {
1717  return var(new exp_vari(a.vi_));
1718  }
1719 
1730  inline var log(const var& a) {
1731  return var(new log_vari(a.vi_));
1732  }
1733 
1744  inline var log10(const var& a) {
1745  return var(new log10_vari(a.vi_));
1746  }
1747 
1748 
1749  // POWER FUNCTIONS
1750 
1761  inline var sqrt(const var& a) {
1762  return var(new sqrt_vari(a.vi_));
1763  }
1764 
1778  inline var pow(const var& base, const var& exponent) {
1779  return var(new pow_vv_vari(base.vi_,exponent.vi_));
1780  }
1781 
1794  inline var pow(const var& base, const double exponent) {
1795  if (exponent == 0.5)
1796  return sqrt(base);
1797  if (exponent == 1.0)
1798  return base;
1799  if (exponent == 2.0)
1800  return base * base; // FIXME: square() functionality from special_functions
1801  return var(new pow_vd_vari(base.vi_,exponent));
1802  }
1803 
1816  inline var pow(const double base, const var& exponent) {
1817  return var(new pow_dv_vari(base,exponent.vi_));
1818  }
1819 
1820 
1821  // TRIG FUNCTIONS
1822 
1833  inline var cos(const var& a) {
1834  return var(new cos_vari(a.vi_));
1835  }
1836 
1847  inline var sin(const var& a) {
1848  return var(new sin_vari(a.vi_));
1849  }
1850 
1861  inline var tan(const var& a) {
1862  return var(new tan_vari(a.vi_));
1863  }
1864 
1876  inline var acos(const var& a) {
1877  return var(new acos_vari(a.vi_));
1878  }
1879 
1891  inline var asin(const var& a) {
1892  return var(new asin_vari(a.vi_));
1893  }
1894 
1906  inline var atan(const var& a) {
1907  return var(new atan_vari(a.vi_));
1908  }
1909 
1924  inline var atan2(const var& a, const var& b) {
1925  return var(new atan2_vv_vari(a.vi_,b.vi_));
1926  }
1927 
1940  inline var atan2(const var& a, const double b) {
1941  return var(new atan2_vd_vari(a.vi_,b));
1942  }
1943 
1956  inline var atan2(const double a, const var& b) {
1957  return var(new atan2_dv_vari(a,b.vi_));
1958  }
1959 
1960  // HYPERBOLIC FUNCTIONS
1961 
1972  inline var cosh(const var& a) {
1973  return var(new cosh_vari(a.vi_));
1974  }
1975 
1986  inline var sinh(const var& a) {
1987  return var(new sinh_vari(a.vi_));
1988  }
1989 
2000  inline var tanh(const var& a) {
2001  return var(new tanh_vari(a.vi_));
2002  }
2003 
2004 
2005  // ROUNDING FUNCTIONS
2006 
2023  inline var fabs(const var& a) {
2024  // cut-and-paste from abs()
2025  if (a.val() > 0.0)
2026  return a;
2027  if (a.val() < 0.0)
2028  return var(new neg_vari(a.vi_));
2029  // FIXME: is this right? breaks connection to a
2030  return var(new vari(0.0));
2031  }
2032 
2051  inline var floor(const var& a) {
2052  return var(new floor_vari(a.vi_));
2053  }
2054 
2073  inline var ceil(const var& a) {
2074  return var(new ceil_vari(a.vi_));
2075  }
2076 
2094  inline var fmod(const var& a, const var& b) {
2095  return var(new fmod_vv_vari(a.vi_,b.vi_));
2096  }
2097 
2111  inline var fmod(const var& a, const double b) {
2112  return var(new fmod_vd_vari(a.vi_,b));
2113  }
2114 
2128  inline var fmod(const double a, const var& b) {
2129  return var(new fmod_dv_vari(a,b.vi_));
2130  }
2131 
2132 
2133  // STD LIB FUNCTIONS
2134 
2149  inline var abs(const var& a) {
2150  // cut-and-paste from fabs()
2151  if (a.val() > 0.0)
2152  return a;
2153  if (a.val() < 0.0)
2154  return var(new neg_vari(a.vi_));
2155  // FIXME: same as fabs() -- is this right?
2156  return var(new vari(0.0));
2157  }
2158 
2159 
2163  static void recover_memory() {
2164  var_stack_.resize(0);
2166  }
2167 
2171  static void free_memory() {
2172  memalloc_.free_all();
2173  }
2174 
2187  static void grad(chainable* vi) {
2188  // old with subtle *2 bug
2189  // std::vector<chainable*>::iterator begin = var_stack_.begin();
2190  // std::vector<chainable*>::iterator it = var_stack_.end();
2191  // for (; (it >= begin) && (*it != vi); --it) ;
2192 
2193  std::vector<chainable*>::iterator begin = var_stack_.begin();
2194  std::vector<chainable*>::iterator it = var_stack_.end();
2195  if (begin == it) return; // nothing on stack
2196  for (--it; (it >= begin) && (*it != vi); --it) ;
2197 
2198  vi->init_dependent();
2199  // propagate derivates for remaining vars
2200  for (; it >= begin; --it)
2201  (*it)->chain();
2202  }
2203 
2207  static void set_zero_all_adjoints() {
2208  for (size_t i = 0; i < var_stack_.size(); ++i)
2209  var_stack_[i]->set_zero_adjoint();
2210  }
2211 
2238  inline void jacobian(std::vector<var>& dependents,
2239  std::vector<var>& independents,
2240  std::vector<std::vector<double> >& jacobian) {
2241  jacobian.resize(dependents.size());
2242  for (size_t i = 0; i < dependents.size(); ++i) {
2243  jacobian[i].resize(independents.size());
2244  if (i > 0)
2246  jacobian.push_back(std::vector<double>(0));
2247  grad(dependents[i].vi_);
2248  for (size_t j = 0; j < independents.size(); ++j)
2249  jacobian[i][j] = independents[j].adj();
2250  }
2251  }
2252 
2253 
2254  inline var& var::operator+=(const var& b) {
2255  vi_ = new add_vv_vari(vi_,b.vi_);
2256  return *this;
2257  }
2258 
2259  inline var& var::operator+=(const double b) {
2260  if (b == 0.0)
2261  return *this;
2262  vi_ = new add_vd_vari(vi_,b);
2263  return *this;
2264  }
2265 
2266  inline var& var::operator-=(const var& b) {
2267  vi_ = new subtract_vv_vari(vi_,b.vi_);
2268  return *this;
2269  }
2270 
2271  inline var& var::operator-=(const double b) {
2272  if (b == 0.0)
2273  return *this;
2274  vi_ = new subtract_vd_vari(vi_,b);
2275  return *this;
2276  }
2277 
2278  inline var& var::operator*=(const var& b) {
2279  vi_ = new multiply_vv_vari(vi_,b.vi_);
2280  return *this;
2281  }
2282 
2283  inline var& var::operator*=(const double b) {
2284  if (b == 1.0)
2285  return *this;
2286  vi_ = new multiply_vd_vari(vi_,b);
2287  return *this;
2288  }
2289 
2290  inline var& var::operator/=(const var& b) {
2291  vi_ = new divide_vv_vari(vi_,b.vi_);
2292  return *this;
2293  }
2294 
2295  inline var& var::operator/=(const double b) {
2296  if (b == 1.0)
2297  return *this;
2298  vi_ = new divide_vd_vari(vi_,b);
2299  return *this;
2300  }
2301  }
2302 
2303 }
2304 
2305 
2306 namespace std {
2307 
2314  template<>
2315  struct numeric_limits<stan::agrad::var> {
2316  static const bool is_specialized = true;
2319  static const int digits = numeric_limits<double>::digits;
2320  static const int digits10 = numeric_limits<double>::digits10;
2321  static const bool is_signed = numeric_limits<double>::is_signed;
2322  static const bool is_integer = numeric_limits<double>::is_integer;
2323  static const bool is_exact = numeric_limits<double>::is_exact;
2324  static const int radix = numeric_limits<double>::radix;
2326  static stan::agrad::var round_error() { return numeric_limits<double>::round_error(); }
2327 
2328  static const int min_exponent = numeric_limits<double>::min_exponent;
2329  static const int min_exponent10 = numeric_limits<double>::min_exponent10;
2330  static const int max_exponent = numeric_limits<double>::max_exponent;
2331  static const int max_exponent10 = numeric_limits<double>::max_exponent10;
2332 
2333  static const bool has_infinity = numeric_limits<double>::has_infinity;
2334  static const bool has_quiet_NaN = numeric_limits<double>::has_quiet_NaN;
2335  static const bool has_signaling_NaN = numeric_limits<double>::has_signaling_NaN;
2336  static const float_denorm_style has_denorm = numeric_limits<double>::has_denorm;
2337  static const bool has_denorm_loss = numeric_limits<double>::has_denorm_loss;
2338  static stan::agrad::var infinity() { return numeric_limits<double>::infinity(); }
2339  static stan::agrad::var quiet_NaN() { return numeric_limits<double>::quiet_NaN(); }
2340  static stan::agrad::var signaling_NaN() { return numeric_limits<double>::signaling_NaN(); }
2341  static stan::agrad::var denorm_min() { return numeric_limits<double>::denorm_min(); }
2342 
2343  static const bool is_iec559 = numeric_limits<double>::is_iec559;
2344  static const bool is_bounded = numeric_limits<double>::is_bounded;
2345  static const bool is_modulo = numeric_limits<double>::is_modulo;
2346 
2347  static const bool traps = numeric_limits<double>::traps;
2348  static const bool tinyness_before = numeric_limits<double>::tinyness_before;
2349  static const float_round_style round_style = numeric_limits<double>::round_style;
2350  };
2351 
2361  inline int isnan(const stan::agrad::var& a) {
2362  return isnan(a.val());
2363  }
2364 
2374  inline int isinf(const stan::agrad::var& a) {
2375  return isinf(a.val());
2376  }
2377 
2378 }
2379 
2380 #endif

     [ Stan Home Page ] © 2011–2012, Stan Development Team.