1 #ifndef __STAN__MCMC__NUTS_H__
2 #define __STAN__MCMC__NUTS_H__
9 #include <boost/random/normal_distribution.hpp>
10 #include <boost/random/mersenne_twister.hpp>
11 #include <boost/random/variate_generator.hpp>
12 #include <boost/random/uniform_01.hpp>
31 template <
class BaseRNG = boost::mt19937>
36 const double _maxchange;
51 inline static bool compute_criterion(std::vector<double>& xplus,
52 std::vector<double>& xminus,
53 std::vector<double>& mplus,
54 std::vector<double>& mminus) {
55 std::vector<double> total_direction;
92 double epsilon_pm = 0.0,
93 bool epsilon_adapt =
true,
96 BaseRNG base_rng = BaseRNG(std::time(0)),
97 const std::vector<double>* params_r = 0,
98 const std::vector<int>* params_i = 0)
131 for (
size_t i = 0; i < mminus.size(); ++i)
133 std::vector<double> mplus(mminus);
138 std::vector<double> gradminus(this->
_g);
139 std::vector<double> gradplus(this->
_g);
140 std::vector<double> xminus(this->
_x);
141 std::vector<double> xplus(this->
_x);
147 bool criterion =
true;
150 std::vector<double> newx, newgrad, dummy1, dummy2, dummy3;
152 double prob_sum = -1;
154 int n_considered = 0;
163 double range = high - low;
168 while (criterion && (_maxdepth < 0 || depth <= _maxdepth)) {
171 build_tree(xminus, mminus, gradminus, u, direction, depth,
172 H0, xminus, mminus, gradminus, dummy1, dummy2, dummy3,
173 newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
176 build_tree(xplus, mplus, gradplus, u, direction, depth,
177 H0, dummy1, dummy2, dummy3, xplus, mplus, gradplus,
178 newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
183 criterion = compute_criterion(xplus, xminus, mplus, mminus);
189 this->
_logp = newlogp;
198 double adapt_stat = prob_sum / float(n_considered);
200 double adapt_g = adapt_stat - this->
_delta;
201 std::vector<double> gvec(1, -adapt_g);
202 std::vector<double> result;
206 std::vector<double> result;
209 double avg_eta = 1.0 / this->
n_steps();
223 o << _lastdepth <<
',';
230 names.push_back(
"treedepth__");
232 names.push_back(
"stepsize__");
236 values.push_back(_lastdepth);
276 const std::vector<double>& m,
277 const std::vector<double>&
grad,
282 std::vector<double>& xminus,
283 std::vector<double>& mminus,
284 std::vector<double>& gradminus,
285 std::vector<double>& xplus,
286 std::vector<double>& mplus,
287 std::vector<double>& gradplus,
288 std::vector<double>& newx,
289 std::vector<double>& newgrad,
307 gradplus = gradminus;
310 newH = -std::numeric_limits<double>::infinity();
312 criterion = newH - u > _maxchange;
317 build_tree(x, m, grad, u, direction, depth-1, H0, xminus, mminus,
318 gradminus, xplus, mplus, gradplus, newx, newgrad, newlogp,
319 nvalid, criterion, prob_sum, n_considered);
321 std::vector<double> dummy1, dummy2, dummy3;
322 std::vector<double> newx2;
323 std::vector<double> newgrad2;
330 build_tree(xminus, mminus, gradminus, u, direction, depth-1, H0,
331 xminus, mminus, gradminus, dummy1, dummy2, dummy3,
332 newx2, newgrad2, newlogp2, nvalid2, criterion2,
333 prob_sum2, n_considered2);
335 build_tree(xplus, mplus, gradplus, u, direction, depth-1, H0,
336 dummy1, dummy2, dummy3, xplus, mplus, gradplus,
337 newx2, newgrad2, newlogp2, nvalid2, criterion2,
338 prob_sum2, n_considered2);
341 <
float(nvalid2) /
float(nvalid+nvalid2))) {
346 n_considered += n_considered2;
347 prob_sum += prob_sum2;
348 criterion &= criterion2;
351 criterion &= compute_criterion(xplus, xminus, mplus, mminus);