1 #ifndef __STAN__MCMC__NUTS_DIAG_H__
2 #define __STAN__MCMC__NUTS_DIAG_H__
9 #include <boost/random/normal_distribution.hpp>
10 #include <boost/random/mersenne_twister.hpp>
11 #include <boost/random/variate_generator.hpp>
12 #include <boost/random/uniform_01.hpp>
32 template <
class BaseRNG = boost::mt19937>
37 const double _maxchange;
46 std::vector<double> _step_sizes;
48 std::vector<double> _x_sum;
49 std::vector<double> _xsq_sum;
61 inline static bool compute_criterion(std::vector<double>& xplus,
62 std::vector<double>& xminus,
63 std::vector<double>& mplus,
64 std::vector<double>& mminus) {
65 std::vector<double> total_direction;
102 double epsilon_pm = 0.0,
103 bool epsilon_adapt =
true,
106 BaseRNG base_rng = BaseRNG(std::time(0)),
107 const std::vector<double>* params_r = 0,
108 const std::vector<int>* params_i = 0)
123 _step_sizes(model.num_params_r(), 1.0),
124 _x_sum(model.num_params_r(), 0),
125 _xsq_sum(model.num_params_r(), 0),
148 for (
size_t i = 0; i < mminus.size(); ++i)
150 std::vector<double> mplus(mminus);
155 std::vector<double> gradminus(this->
_g);
156 std::vector<double> gradplus(this->
_g);
157 std::vector<double> xminus(this->
_x);
158 std::vector<double> xplus(this->
_x);
164 bool criterion =
true;
167 std::vector<double> newx, newgrad, dummy1, dummy2, dummy3;
169 double prob_sum = -1;
171 int n_considered = 0;
180 double range = high - low;
185 while (criterion && (_maxdepth < 0 || depth <= _maxdepth)) {
188 build_tree(xminus, mminus, gradminus, u, direction, depth,
189 H0, xminus, mminus, gradminus, dummy1, dummy2, dummy3,
190 newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
193 build_tree(xplus, mplus, gradplus, u, direction, depth,
194 H0, dummy1, dummy2, dummy3, xplus, mplus, gradplus,
195 newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
200 criterion = compute_criterion(xplus, xminus, mplus, mminus);
206 this->
_logp = newlogp;
215 double adapt_stat = prob_sum / float(n_considered);
218 double adapt_g = adapt_stat - this->
_delta;
219 std::vector<double> gvec(1, -adapt_g);
220 std::vector<double> result;
225 _next_diag_adapt *= 2;
226 double step_size_sq_sum = 0;
227 for (
size_t i = 0; i < _step_sizes.size(); i++) {
228 double Ex = _x_sum[i] / _x_sum_n;
229 double Exsq = _xsq_sum[i] / _x_sum_n;
232 _step_sizes[i] =
sqrt(Exsq - Ex*Ex);
233 step_size_sq_sum += _step_sizes[i] * _step_sizes[i];
235 if (step_size_sq_sum > 0.0) {
237 double normalizer =
sqrt((
double)_step_sizes.size())
238 /
sqrt(step_size_sq_sum);
239 for (
size_t i = 0; i < _step_sizes.size(); i++)
240 _step_sizes[i] *= normalizer;
242 for (
size_t i = 0; i < _step_sizes.size(); i++)
243 _step_sizes[i] = 1.0;
247 std::vector<double> result;
249 double avg_eta = 1.0 / this->
n_steps();
262 o << _lastdepth <<
',';
268 o <<
"# (mcmc::nuts_diag) adaptation finished" <<
'\n';
269 o <<
"# step size=" << this->
_epsilon <<
'\n';
270 o <<
"# parameter step size multipliers:\n";
272 for (
size_t k = 0; k < _step_sizes.size(); ++k) {
281 names.push_back(
"treedepth__");
283 names.push_back(
"stepsize__");
288 values.push_back(_lastdepth);
328 const std::vector<double>& m,
329 const std::vector<double>&
grad,
334 std::vector<double>& xminus,
335 std::vector<double>& mminus,
336 std::vector<double>& gradminus,
337 std::vector<double>& xplus,
338 std::vector<double>& mplus,
339 std::vector<double>& gradplus,
340 std::vector<double>& newx,
341 std::vector<double>& newgrad,
352 xminus, mminus, gradminus,
359 gradplus = gradminus;
362 newH = -std::numeric_limits<double>::infinity();
364 criterion = newH - u > _maxchange;
371 for (
size_t i = 0; i < newx.size(); i++) {
372 _x_sum[i] += newx[i];
373 _xsq_sum[i] += newx[i] * newx[i];
378 gradminus, xplus, mplus, gradplus, newx, newgrad, newlogp,
379 nvalid, criterion, prob_sum, n_considered);
381 std::vector<double> dummy1, dummy2, dummy3;
382 std::vector<double> newx2;
383 std::vector<double> newgrad2;
390 build_tree(xminus, mminus, gradminus, u, direction, depth-1, H0,
391 xminus, mminus, gradminus, dummy1, dummy2, dummy3,
392 newx2, newgrad2, newlogp2, nvalid2, criterion2,
393 prob_sum2, n_considered2);
395 build_tree(xplus, mplus, gradplus, u, direction, depth-1, H0,
396 dummy1, dummy2, dummy3, xplus, mplus, gradplus,
397 newx2, newgrad2, newlogp2, nvalid2, criterion2,
398 prob_sum2, n_considered2);
401 <
float(nvalid2) /
float(nvalid+nvalid2))){
406 n_considered += n_considered2;
407 prob_sum += prob_sum2;
408 criterion &= criterion2;
411 criterion &= compute_criterion(xplus, xminus, mplus, mminus);
void update(const std::vector< double > &g, std::vector< double > &xk)
Produces the next iterate xk given the current gradient g.
void xbar(std::vector< double > &xbar)
Get the exponentially weighted moving average of all previous iterates.
std::ostream * _output_msgs
bool adapting()
Return whether or not parameter adaptation is on.
int n_steps()
Return the number of iterations for this sampler.
void update_mean_stat(double avg_eta, double adapt_stat)
Updates the mean statistic given the specified adaptation statistic and weighting.
std::ostream * _error_msgs
void nfevals_plus_eq(int n)
Add the specified number of evaluations to the number of function evaluations.
boost::uniform_01< boost::mt19937 & > _rand_uniform_01
stan::model::prob_grad & _model
void adaptation_init(double epsilon_scale)
boost::variate_generator< boost::mt19937 &, boost::normal_distribution<> > _rand_unit_norm
No-U-Turn Sampler (NUTS) with varying step sizes.
virtual void write_sampler_param_names(std::ostream &o)
Write out any sampler-specific parameter names for output.
void build_tree(const std::vector< double > &x, const std::vector< double > &m, const std::vector< double > &grad, double u, int direction, int depth, double H0, std::vector< double > &xminus, std::vector< double > &mminus, std::vector< double > &gradminus, std::vector< double > &xplus, std::vector< double > &mplus, std::vector< double > &gradplus, std::vector< double > &newx, std::vector< double > &newgrad, double &newlogp, int &nvalid, bool &criterion, double &prob_sum, int &n_considered)
The core recursion in NUTS.
virtual void get_sampler_param_names(std::vector< std::string > &names)
Get any sampler-specific parameter names.
virtual void write_sampler_params(std::ostream &o)
Write out any sampler-specific parameters for output.
~nuts_diag()
Destroy this sampler.
virtual void get_sampler_params(std::vector< double > &values)
Get any sampler-specific parameters.
virtual void write_adaptation_params(std::ostream &o)
Use this method to write the adaptation parameters into the output.
virtual sample next_impl()
Return the next sample.
nuts_diag(stan::model::prob_grad &model, int maxdepth=10, double epsilon=-1, double epsilon_pm=0.0, bool epsilon_adapt=true, double delta=0.6, double gamma=0.05, BaseRNG base_rng=BaseRNG(std::time(0)), const std::vector< double > *params_r=0, const std::vector< int > *params_i=0)
Construct a No-U-Turn Sampler (NUTS) for the specified model, using the specified step size and number of steps.
Representation of a MCMC sample.
The prob_grad class represents densities with fixed numbers of discrete and scalar parameters and the...
virtual size_t num_params_r()
var sqrt(const var &a)
Return the square root of the specified variable (cmath).
static void grad(chainable *vi)
Compute the gradient for all variables starting from the specified root variable implementation.
var log(const var &a)
Return the natural log of the specified variable (cmath).
var exp(const var &a)
Return the exponentiation of the specified variable (cmath).
double epsilon()
Return the minimum positive representable number.
double e()
Return the base of the natural logarithm.
int min(const std::vector< int > &x)
Returns the minimum coefficient in the specified column vector.
double dot(std::vector< double > &x, std::vector< double > &y)
void sub(std::vector< double > &x, std::vector< double > &y, std::vector< double > &result)
double dot_self(const Eigen::Matrix< double, R, C > &v)
Returns the dot product of the specified vector with itself.
double rescaled_leapfrog(stan::model::prob_grad &model, std::vector< int > z, const std::vector< double > &step_sizes, std::vector< double > &x, std::vector< double > &m, std::vector< double > &g, double epsilon, std::ostream *error_msgs=0, std::ostream *output_msgs=0)
Probability, optimization and sampling library.