Stan 1.0
probability, sampling & optimization
 All Classes Namespaces Files Functions Variables Typedefs Enumerator Friends Macros Pages
nuts_diag.hpp
Go to the documentation of this file.
1 #ifndef __STAN__MCMC__NUTS_DIAG_H__
2 #define __STAN__MCMC__NUTS_DIAG_H__
3 
4 #include <ctime>
5 #include <cstddef>
6 #include <iostream>
7 #include <vector>
8 
9 #include <boost/random/normal_distribution.hpp>
10 #include <boost/random/mersenne_twister.hpp>
11 #include <boost/random/variate_generator.hpp>
12 #include <boost/random/uniform_01.hpp>
13 
14 #include <stan/math/util.hpp>
17 #include <stan/mcmc/hmc_base.hpp>
18 #include <stan/mcmc/util.hpp>
19 
20 namespace stan {
21 
22  namespace mcmc {
23 
32  template <class BaseRNG = boost::mt19937>
33  class nuts_diag : public hmc_base<BaseRNG> {
34  private:
35 
36  // Stop immediately if H < u - _maxchange
37  const double _maxchange;
38 
39  // Limit tree depth
40  const int _maxdepth;
41 
42  // Depth of last sample taken (-1 before any samples)
43  int _lastdepth;
44 
45  // Vector of per-parameter step sizes.
46  std::vector<double> _step_sizes;
47  // Running statistics to estimate per-coordinate std. deviations.
48  std::vector<double> _x_sum;
49  std::vector<double> _xsq_sum;
50  int _x_sum_n;
51  // Next time we should adapt the per-parameter step sizes.
52  int _next_diag_adapt;
53 
61  inline static bool compute_criterion(std::vector<double>& xplus,
62  std::vector<double>& xminus,
63  std::vector<double>& mplus,
64  std::vector<double>& mminus) {
65  std::vector<double> total_direction;
66  stan::math::sub(xplus, xminus, total_direction);
67  return stan::math::dot(total_direction, mminus) > 0
68  && stan::math::dot(total_direction, mplus) > 0;
69  }
70 
71  public:
72 
100  int maxdepth = 10,
101  double epsilon = -1,
102  double epsilon_pm = 0.0,
103  bool epsilon_adapt = true,
104  double delta = 0.6,
105  double gamma = 0.05,
106  BaseRNG base_rng = BaseRNG(std::time(0)),
107  const std::vector<double>* params_r = 0,
108  const std::vector<int>* params_i = 0)
109  : hmc_base<BaseRNG>(model,
110  epsilon,
111  epsilon_pm,
112  epsilon_adapt,
113  delta,
114  gamma,
115  base_rng,
116  params_r,
117  params_i),
118 
119  _maxchange(-1000),
120  _maxdepth(maxdepth),
121  _lastdepth(-1),
122 
123  _step_sizes(model.num_params_r(), 1.0),
124  _x_sum(model.num_params_r(), 0),
125  _xsq_sum(model.num_params_r(), 0),
126  _x_sum_n(0),
127  _next_diag_adapt(10)
128  {
129  // start at 10 * epsilon because NUTS cheaper for larger epsilon
130  this->adaptation_init(10.0);
131  }
132 
139 
/**
 * Draw the next sample with one NUTS transition: refresh the
 * momentum, sample the slice variable, repeatedly double the
 * trajectory until the U-turn criterion fails or the depth cap is
 * hit, then (while adapting) update the global step size via dual
 * averaging and, occasionally, the per-parameter step-size
 * multipliers from running first/second moments.
 *
 * @return A sample holding the (possibly updated) position,
 *         integer parameters, and log probability.
 */
virtual sample next_impl() {
  // Initialize the algorithm: fresh standard-normal momentum for the
  // backward end; the forward end starts as a copy.
  std::vector<double> mminus(this->_model.num_params_r());
  for (size_t i = 0; i < mminus.size(); ++i)
    mminus[i] = this->_rand_unit_norm();
  std::vector<double> mplus(mminus);
  // The log-joint probability of the momentum and position terms, i.e.
  // -(kinetic energy + potential energy)
  double H0 = -0.5 * stan::math::dot_self(mminus) + this->_logp;

  // Both trajectory ends start at the current state/gradient.
  std::vector<double> gradminus(this->_g);
  std::vector<double> gradplus(this->_g);
  std::vector<double> xminus(this->_x);
  std::vector<double> xplus(this->_x);

  // Sample the slice variable (in log space: u = H0 + log(Uniform(0,1))).
  double u = log(this->_rand_uniform_01()) + H0;
  int nvalid = 1;             // count of points inside the slice so far
  int direction = 2 * (this->_rand_uniform_01() > 0.5) - 1;  // +1 or -1
  bool criterion = true;

  // Outputs of each doubling; the dummies absorb the tree end that the
  // current doubling does not extend.
  std::vector<double> newx, newgrad, dummy1, dummy2, dummy3;
  double newlogp = -1;
  double prob_sum = -1;       // accumulates min(1, exp(H - H0)) over leaves
  int newnvalid = -1;
  int n_considered = 0;       // leaves visited, denominator for adapt_stat
  // for-loop with depth outside to set lastdepth
  int depth = 0;

  double epsilon = this->_epsilon;
  // only vary epsilon after done adapting: jitter uniformly within
  // +/- _epsilon_pm of the adapted value.
  if (!this->adapting() && this->varying_epsilon()) {
    double low = epsilon * (1.0 - this->_epsilon_pm);
    double high = epsilon * (1.0 + this->_epsilon_pm);
    double range = high - low;
    epsilon = low + (range * this->_rand_uniform_01());
  }
  this->_epsilon_last = epsilon; // use epsilon_last in tree build

  // Repeatedly double the set of points we've visited.
  // A negative _maxdepth means "no depth cap".
  while (criterion && (_maxdepth < 0 || depth <= _maxdepth)) {
    // Pick a fresh random direction for this doubling.
    direction = 2 * (this->_rand_uniform_01() > 0.5) - 1;
    // Note the deliberate aliasing: the end being extended is passed
    // both as the starting point and as the output for that end.
    if (direction == -1)
      build_tree(xminus, mminus, gradminus, u, direction, depth,
                 H0, xminus, mminus, gradminus, dummy1, dummy2, dummy3,
                 newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
                 n_considered);
    else
      build_tree(xplus, mplus, gradplus, u, direction, depth,
                 H0, dummy1, dummy2, dummy3, xplus, mplus, gradplus,
                 newx, newgrad, newlogp, newnvalid, criterion, prob_sum,
                 n_considered);
    // We can't look at the results of this last doubling if criterion==false
    if (!criterion)
      break;
    criterion = compute_criterion(xplus, xminus, mplus, mminus);
    // Metropolis-Hastings to determine if we can jump to a point in
    // the new half-tree.  The 1e-100 guards against division by zero
    // when no valid points have been seen yet.
    if (this->_rand_uniform_01() < float(newnvalid) / (1e-100+float(nvalid))) {
      this->_x = newx;
      this->_g = newgrad;
      this->_logp = newlogp;
    }
    nvalid += newnvalid;
    ++depth;
  }
  _lastdepth = depth;

  // Now we just have to update global (epsilon) and local
  // (step_sizes) step sizes, if adaptation is on.
  double adapt_stat = prob_sum / float(n_considered);
  if (this->adapting()) {
    // epsilon: dual-averaging update toward the target accept stat _delta.
    double adapt_g = adapt_stat - this->_delta;
    std::vector<double> gvec(1, -adapt_g);
    std::vector<double> result;
    this->_da.update(gvec, result);
    this->_epsilon = exp(result[0]);
    // step_sizes. Doesn't happen every step: the adaptation interval
    // doubles each time (10, 20, 40, ...).
    if (this->_n_adapt_steps == _next_diag_adapt) {
      _next_diag_adapt *= 2;
      double step_size_sq_sum = 0;
      for (size_t i = 0; i < _step_sizes.size(); i++) {
        // Per-coordinate sample std. deviation from running moments;
        // moments are reset for the next adaptation window.
        double Ex = _x_sum[i] / _x_sum_n;
        double Exsq = _xsq_sum[i] / _x_sum_n;
        _x_sum[i] = 0;
        _xsq_sum[i] = 0;
        _step_sizes[i] = sqrt(Exsq - Ex*Ex);
        step_size_sq_sum += _step_sizes[i] * _step_sizes[i];
      }
      if (step_size_sq_sum > 0.0) {
        // Normalize so the step-size vector has RMS 1, leaving the
        // overall scale to epsilon.
        _x_sum_n = 0;
        double normalizer = sqrt((double)_step_sizes.size())
          / sqrt(step_size_sq_sum);
        for (size_t i = 0; i < _step_sizes.size(); i++)
          _step_sizes[i] *= normalizer;
      } else {
        // Degenerate statistics (e.g. no variation seen): fall back
        // to unit multipliers.
        for (size_t i = 0; i < _step_sizes.size(); i++)
          _step_sizes[i] = 1.0;
      }
    }
  }
  std::vector<double> result;
  this->_da.xbar(result);
  double avg_eta = 1.0 / this->n_steps();
  this->update_mean_stat(avg_eta,adapt_stat);

  return mcmc::sample(this->_x, this->_z, this->_logp);
}
254 
255  virtual void write_sampler_param_names(std::ostream& o) {
256  o << "treedepth__,";
257  if (this->_epsilon_adapt || this->varying_epsilon())
258  o << "stepsize__,";
259  }
260 
261  virtual void write_sampler_params(std::ostream& o) {
262  o << _lastdepth << ',';
263  if (this->_epsilon_adapt || this->varying_epsilon())
264  o << this->_epsilon_last << ',';
265  }
266 
267  virtual void write_adaptation_params(std::ostream& o) {
268  o << "# (mcmc::nuts_diag) adaptation finished" << '\n';
269  o << "# step size=" << this->_epsilon << '\n';
270  o << "# parameter step size multipliers:\n"; // FIXME: names/delineation requires access to model
271  o << "# ";
272  for (size_t k = 0; k < _step_sizes.size(); ++k) {
273  if (k > 0) o << ',';
274  o << _step_sizes[k];
275  }
276  o << '\n';
277  }
278 
279  virtual void get_sampler_param_names(std::vector<std::string>& names) {
280  names.clear();
281  names.push_back("treedepth__");
282  if (this->_epsilon_adapt || this->varying_epsilon())
283  names.push_back("stepsize__");
284  }
285 
286  virtual void get_sampler_params(std::vector<double>& values) {
287  values.clear();
288  values.push_back(_lastdepth);
289  if (this->_epsilon_adapt || this->varying_epsilon())
290  values.push_back(this->_epsilon_last);
291  }
292 
/**
 * Recursively build one half of the NUTS trajectory of 2^depth
 * leapfrog steps starting from (x, m, grad), integrating in the
 * given direction with the per-coordinate step sizes.
 *
 * Callers deliberately alias arguments: the trajectory end being
 * extended is passed both as the input state and as the output for
 * that end, while the untouched end receives dummy vectors.
 *
 * @param x          Starting position.
 * @param m          Starting momentum.
 * @param grad       Gradient at the starting position.
 * @param u          Log slice variable.
 * @param direction  +1 to integrate forward, -1 backward.
 * @param depth      Tree depth; 0 is a single leapfrog step.
 * @param H0         Log joint at the start of this transition.
 * @param xminus,mminus,gradminus  Out: state at the backward end.
 * @param xplus,mplus,gradplus     Out: state at the forward end.
 * @param newx,newgrad,newlogp     Out: proposed point from this subtree.
 * @param nvalid       Out: number of leaves inside the slice (H > u).
 * @param criterion    Out: false once divergence or a U-turn is seen.
 * @param prob_sum     Out: sum of min(1, exp(H - H0)) over leaves.
 * @param n_considered Out: number of leaves visited.
 */
void build_tree(const std::vector<double>& x,
                const std::vector<double>& m,
                const std::vector<double>& grad,
                double u,
                int direction,
                int depth,
                double H0,
                std::vector<double>& xminus,
                std::vector<double>& mminus,
                std::vector<double>& gradminus,
                std::vector<double>& xplus,
                std::vector<double>& mplus,
                std::vector<double>& gradplus,
                std::vector<double>& newx,
                std::vector<double>& newgrad,
                double& newlogp,
                int& nvalid,
                bool& criterion,
                double& prob_sum,
                int& n_considered) {
  if (depth == 0) { // base case: one leapfrog step
    xminus = x;
    gradminus = grad;
    mminus = m;
    // One step scaled per-coordinate by _step_sizes, signed by direction.
    newlogp = rescaled_leapfrog(this->_model, this->_z, _step_sizes,
                                xminus, mminus, gradminus,
                                direction * this->_epsilon_last,
                                this->_error_msgs, this->_output_msgs);
    newx = xminus;
    newgrad = gradminus;
    // A single step: both ends of this subtree are the same point.
    xplus = xminus;
    mplus = mminus;
    gradplus = gradminus;
    double newH = newlogp - 0.5 * stan::math::dot_self(mminus);
    if (newH != newH) // treat nan as -inf
      newH = -std::numeric_limits<double>::infinity();
    nvalid = newH > u;                 // 1 if the leaf is inside the slice
    criterion = newH - u > _maxchange; // divergence check (see _maxchange)
    prob_sum = stan::math::min(1, exp(newH - H0));
    n_considered = 1;
    this->nfevals_plus_eq(1);
    // Update running statistics if point is in slice; these feed the
    // per-coordinate step-size adaptation in next_impl().
    if (nvalid) {
      _x_sum_n++;
      for (size_t i = 0; i < newx.size(); i++) {
        _x_sum[i] += newx[i];
        _xsq_sum[i] += newx[i] * newx[i];
      }
    }
  } else { // depth >= 1: two subtrees of depth-1
    // First half: fills both ends plus the first proposal.
    build_tree(x, m, grad, u, direction, depth-1, H0, xminus, mminus,
               gradminus, xplus, mplus, gradplus, newx, newgrad, newlogp,
               nvalid, criterion, prob_sum, n_considered);
    // Only build the second half if the first half did not terminate.
    if (criterion) {
      std::vector<double> dummy1, dummy2, dummy3;
      std::vector<double> newx2;
      std::vector<double> newgrad2;
      double newlogp2;
      int nvalid2;
      bool criterion2;
      double prob_sum2;
      int n_considered2;
      // Extend from the appropriate end; note the aliasing of that
      // end as both input and output, as in next_impl().
      if (direction == -1)
        build_tree(xminus, mminus, gradminus, u, direction, depth-1, H0,
                   xminus, mminus, gradminus, dummy1, dummy2, dummy3,
                   newx2, newgrad2, newlogp2, nvalid2, criterion2,
                   prob_sum2, n_considered2);
      else
        build_tree(xplus, mplus, gradplus, u, direction, depth-1, H0,
                   dummy1, dummy2, dummy3, xplus, mplus, gradplus,
                   newx2, newgrad2, newlogp2, nvalid2, criterion2,
                   prob_sum2, n_considered2);
      // Choose between the two subtrees' proposals in proportion to
      // their valid-point counts.  (NaN from 0/0 compares false, so
      // the first proposal is kept when both counts are zero.)
      if (criterion &&
          (this->_rand_uniform_01()
           < float(nvalid2) / float(nvalid+nvalid2))){
        newx = newx2;
        newgrad = newgrad2;
        newlogp = newlogp2;
      }
      // Merge the second subtree's statistics into the outputs.
      n_considered += n_considered2;
      prob_sum += prob_sum2;
      criterion &= criterion2;
      nvalid += nvalid2;
    }
    // U-turn check across this whole subtree's two ends.
    criterion &= compute_criterion(xplus, xminus, mplus, mminus);
  }
}
414 
415 
416  };
417 
418 
419  }
420 
421 }
422 
423 #endif

     [ Stan Home Page ] © 2011–2012, Stan Development Team.