1 #ifndef __STAN__MCMC__NUTS_DIAG_H__
2 #define __STAN__MCMC__NUTS_DIAG_H__
9 #include <boost/random/normal_distribution.hpp>
10 #include <boost/random/mersenne_twister.hpp>
11 #include <boost/random/variate_generator.hpp>
12 #include <boost/random/uniform_01.hpp>
32 template <
class BaseRNG = boost::mt19937>
37 const double _maxchange;
46 std::vector<double> _step_sizes;
48 std::vector<double> _x_sum;
49 std::vector<double> _xsq_sum;
61 inline static bool compute_criterion(std::vector<double>& xplus,
62 std::vector<double>& xminus,
63 std::vector<double>& mplus,
64 std::vector<double>& mminus) {
65 std::vector<double> total_direction;
102 double epsilon_pm = 0.0,
103 bool epsilon_adapt =
true,
106 BaseRNG base_rng = BaseRNG(std::time(0)),
107 const std::vector<double>* params_r = 0,
108 const std::vector<int>* params_i = 0)
123 _step_sizes(model.num_params_r(), 1.0),
124 _x_sum(model.num_params_r(), 0),
125 _xsq_sum(model.num_params_r(), 0),
148 for (
size_t i = 0; i < mminus.size(); ++i)
150 std::vector<double> mplus(mminus);
155 std::vector<double> gradminus(this->
_g);
156 std::vector<double> gradplus(this->
_g);
157 std::vector<double> xminus(this->
_x);
158 std::vector<double> xplus(this->
_x);
164 bool criterion =
true;
167 std::vector<double> newx, newgrad, dummy1, dummy2, dummy3;
169 double prob_sum = -1;
171 int n_considered = 0;
180 double range = high - low;
185 while (criterion && (_maxdepth < 0 || depth <= _maxdepth)) {
188 build_tree(xminus, mminus, gradminus, u, direction, depth,
189 H0, xminus, mminus, gradminus, dummy1, dummy2, dummy3,
190 newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
193 build_tree(xplus, mplus, gradplus, u, direction, depth,
194 H0, dummy1, dummy2, dummy3, xplus, mplus, gradplus,
195 newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
200 criterion = compute_criterion(xplus, xminus, mplus, mminus);
206 this->
_logp = newlogp;
215 double adapt_stat = prob_sum / float(n_considered);
218 double adapt_g = adapt_stat - this->
_delta;
219 std::vector<double> gvec(1, -adapt_g);
220 std::vector<double> result;
225 _next_diag_adapt *= 2;
226 double step_size_sq_sum = 0;
227 for (
size_t i = 0; i < _step_sizes.size(); i++) {
228 double Ex = _x_sum[i] / _x_sum_n;
229 double Exsq = _xsq_sum[i] / _x_sum_n;
232 _step_sizes[i] =
sqrt(Exsq - Ex*Ex);
233 step_size_sq_sum += _step_sizes[i] * _step_sizes[i];
235 if (step_size_sq_sum > 0.0) {
237 double normalizer =
sqrt((
double)_step_sizes.size())
238 /
sqrt(step_size_sq_sum);
239 for (
size_t i = 0; i < _step_sizes.size(); i++)
240 _step_sizes[i] *= normalizer;
242 for (
size_t i = 0; i < _step_sizes.size(); i++)
243 _step_sizes[i] = 1.0;
247 std::vector<double> result;
249 double avg_eta = 1.0 / this->
n_steps();
262 o << _lastdepth <<
',';
268 o <<
"# (mcmc::nuts_diag) adaptation finished" <<
'\n';
269 o <<
"# step size=" << this->
_epsilon <<
'\n';
270 o <<
"# parameter step size multipliers:\n";
272 for (
size_t k = 0; k < _step_sizes.size(); ++k) {
281 names.push_back(
"treedepth__");
283 names.push_back(
"stepsize__");
288 values.push_back(_lastdepth);
328 const std::vector<double>& m,
329 const std::vector<double>&
grad,
334 std::vector<double>& xminus,
335 std::vector<double>& mminus,
336 std::vector<double>& gradminus,
337 std::vector<double>& xplus,
338 std::vector<double>& mplus,
339 std::vector<double>& gradplus,
340 std::vector<double>& newx,
341 std::vector<double>& newgrad,
352 xminus, mminus, gradminus,
359 gradplus = gradminus;
362 newH = -std::numeric_limits<double>::infinity();
364 criterion = newH - u > _maxchange;
371 for (
size_t i = 0; i < newx.size(); i++) {
372 _x_sum[i] += newx[i];
373 _xsq_sum[i] += newx[i] * newx[i];
377 build_tree(x, m, grad, u, direction, depth-1, H0, xminus, mminus,
378 gradminus, xplus, mplus, gradplus, newx, newgrad, newlogp,
379 nvalid, criterion, prob_sum, n_considered);
381 std::vector<double> dummy1, dummy2, dummy3;
382 std::vector<double> newx2;
383 std::vector<double> newgrad2;
390 build_tree(xminus, mminus, gradminus, u, direction, depth-1, H0,
391 xminus, mminus, gradminus, dummy1, dummy2, dummy3,
392 newx2, newgrad2, newlogp2, nvalid2, criterion2,
393 prob_sum2, n_considered2);
395 build_tree(xplus, mplus, gradplus, u, direction, depth-1, H0,
396 dummy1, dummy2, dummy3, xplus, mplus, gradplus,
397 newx2, newgrad2, newlogp2, nvalid2, criterion2,
398 prob_sum2, n_considered2);
401 <
float(nvalid2) /
float(nvalid+nvalid2))){
406 n_considered += n_considered2;
407 prob_sum += prob_sum2;
408 criterion &= criterion2;
411 criterion &= compute_criterion(xplus, xminus, mplus, mminus);