Stan  1.0
probability, sampling & optimization
statement_grammar_def.hpp
Go to the documentation of this file.
1 #ifndef __STAN__GM__PARSER__STATEMENT_GRAMMAR_DEF__HPP__
2 #define __STAN__GM__PARSER__STATEMENT_GRAMMAR_DEF__HPP__
3 
4 #include <cstddef>
5 #include <iomanip>
6 #include <iostream>
7 #include <istream>
8 #include <map>
9 #include <set>
10 #include <sstream>
11 #include <string>
12 #include <utility>
13 #include <vector>
14 #include <stdexcept>
15 
16 #include <boost/spirit/include/qi.hpp>
17 // FIXME: get rid of unused include
18 #include <boost/spirit/include/phoenix_core.hpp>
19 #include <boost/spirit/include/phoenix_function.hpp>
20 #include <boost/spirit/include/phoenix_fusion.hpp>
21 #include <boost/spirit/include/phoenix_object.hpp>
22 #include <boost/spirit/include/phoenix_operator.hpp>
23 #include <boost/spirit/include/phoenix_stl.hpp>
24 
25 #include <boost/lexical_cast.hpp>
26 #include <boost/fusion/include/adapt_struct.hpp>
27 #include <boost/fusion/include/std_pair.hpp>
28 #include <boost/config/warning_disable.hpp>
29 #include <boost/spirit/include/qi.hpp>
30 #include <boost/spirit/include/qi_numeric.hpp>
31 #include <boost/spirit/include/classic_position_iterator.hpp>
32 #include <boost/spirit/include/phoenix_core.hpp>
33 #include <boost/spirit/include/phoenix_function.hpp>
34 #include <boost/spirit/include/phoenix_fusion.hpp>
35 #include <boost/spirit/include/phoenix_object.hpp>
36 #include <boost/spirit/include/phoenix_operator.hpp>
37 #include <boost/spirit/include/phoenix_stl.hpp>
38 #include <boost/spirit/include/support_multi_pass.hpp>
39 #include <boost/tuple/tuple.hpp>
40 #include <boost/variant/apply_visitor.hpp>
41 #include <boost/variant/recursive_variant.hpp>
42 
43 #include <stan/gm/ast.hpp>
49 
51  (stan::gm::variable_dims, var_dims_)
52  (stan::gm::expression, expr_) )
53 
55  (std::string, name_)
56  (std::vector<stan::gm::expression>, dims_) )
57 
59  (std::string, family_)
60  (std::vector<stan::gm::expression>, args_) )
61 
63  (std::string, variable_)
64  (stan::gm::range, range_)
65  (stan::gm::statement, statement_) )
66 
68  (std::vector<stan::gm::printable>, printables_) )
69 
71  (stan::gm::expression, expr_)
72  (stan::gm::distribution, dist_)
73  (stan::gm::range, truncation_) )
74 
76  (std::vector<stan::gm::var_decl>, local_decl_)
77  (std::vector<stan::gm::statement>, statements_) )
78 
79 namespace stan {
80 
81  namespace gm {
82 
83  struct validate_assignment {
84  template <typename T1, typename T2, typename T3, typename T4>
85  struct result { typedef bool type; };
86 
87  bool operator()(assignment& a,
88  const var_origin& origin_allowed,
89  variable_map& vm,
90  std::stringstream& error_msgs) const {
91 
92  // validate existence
93  std::string name = a.var_dims_.name_;
94  if (!vm.exists(name)) {
95  error_msgs << "unknown variable in assignment"
96  << "; lhs variable=" << a.var_dims_.name_
97  << std::endl;
98  return false;
99  }
100 
101  // validate origin
102  var_origin lhs_origin = vm.get_origin(name);
103  if (lhs_origin != local_origin
104  && lhs_origin != origin_allowed) {
105  error_msgs << "attempt to assign variable in wrong block."
106  << " left-hand-side variable origin=";
107  print_var_origin(error_msgs,lhs_origin);
108  error_msgs << std::endl;
109  return false;
110  }
111 
112  // validate types
113  a.var_type_ = vm.get(name);
114  size_t lhs_var_num_dims = a.var_type_.dims_.size();
115  size_t num_index_dims = a.var_dims_.dims_.size();
116 
117  expr_type lhs_type = infer_type_indexing(a.var_type_.base_type_,
118  lhs_var_num_dims,
119  num_index_dims);
120 
121  if (lhs_type.is_ill_formed()) {
122  error_msgs << "too many indexes for variable "
123  << "; variable name = " << name
124  << "; num dimensions given = " << num_index_dims
125  << "; variable array dimensions = " << lhs_var_num_dims;
126  return false;
127  }
128  if (lhs_type.num_dims_ != a.expr_.expression_type().num_dims_) {
129  error_msgs << "mismatched dimensions on left- and right-hand side of assignment"
130  << "; left dims=" << lhs_type.num_dims_
131  << "; right dims=" << a.expr_.expression_type().num_dims_
132  << std::endl;
133  return false;
134  }
135 
136  base_expr_type lhs_base_type = lhs_type.base_type_;
137  base_expr_type rhs_base_type = a.expr_.expression_type().base_type_;
138  // int -> double promotion
139  bool types_compatible
140  = lhs_base_type == rhs_base_type
141  || ( lhs_base_type == DOUBLE_T && rhs_base_type == INT_T );
142  if (!types_compatible) {
143  error_msgs << "base type mismatch in assignment"
144  << "; left variable=" << a.var_dims_.name_
145  << "; left base type=";
146  write_base_expr_type(error_msgs,lhs_base_type);
147  error_msgs << "; right base type=";
148  write_base_expr_type(error_msgs,rhs_base_type);
149  error_msgs << std::endl;
150  return false;
151  }
152  return true;
153  }
154  };
155  boost::phoenix::function<validate_assignment> validate_assignment_f;
156 
157  struct validate_sample {
158  template <typename T1, typename T2>
159  struct result { typedef bool type; };
160 
161  bool is_double_return(const std::string& function_name,
162  const std::vector<expr_type>& arg_types,
163  std::ostream& error_msgs) const {
165  .get_result_type(function_name,arg_types,error_msgs)
167  }
168  bool operator()(const sample& s,
169  std::ostream& error_msgs) const {
170  std::vector<expr_type> arg_types;
171  arg_types.push_back(s.expr_.expression_type());
172  for (size_t i = 0; i < s.dist_.args_.size(); ++i)
173  arg_types.push_back(s.dist_.args_[i].expression_type());
174  std::string function_name(s.dist_.family_);
175  function_name += "_log";
176  // expr_type result_type
177  // = function_signatures::instance()
178  // .get_result_type(function_name,arg_types,error_msgs);
179  // if (!result_type.is_primitive_double()) {
180  if (!is_double_return(function_name,arg_types,error_msgs)) {
181  error_msgs << "unknown distribution=" << s.dist_.family_ << std::endl;
182  return false;
183  }
184  if (s.truncation_.has_low()) {
185  std::vector<expr_type> arg_types_trunc(arg_types);
186  arg_types_trunc[0] = s.truncation_.low_.expression_type();
187  std::string function_name_cdf(s.dist_.family_);
188  function_name_cdf += "_cdf";
189  if (!is_double_return(function_name_cdf,arg_types_trunc,error_msgs)) {
190  error_msgs << "lower truncation not defined for specified arguments to "
191  << s.dist_.family_ << std::endl;
192  return false;
193  }
194  if (!is_double_return(function_name_cdf,arg_types,error_msgs)) {
195  error_msgs << "lower bound in truncation type does not match"
196  << " sampled variate in distribution's type"
197  << std::endl;
198  return false;
199  }
200  }
201  if (s.truncation_.has_high()) {
202  std::vector<expr_type> arg_types_trunc(arg_types);
203  arg_types_trunc[0] = s.truncation_.high_.expression_type();
204  std::string function_name_cdf(s.dist_.family_);
205  function_name_cdf += "_cdf";
206  if (!is_double_return(function_name_cdf,arg_types_trunc,error_msgs)) {
207  error_msgs << "upper truncation not defined for specified arguments to "
208  << s.dist_.family_ << std::endl;
209  return false;
210  }
211  if (!is_double_return(function_name_cdf,arg_types,error_msgs)) {
212  error_msgs << "upper bound in truncation type does not match"
213  << " sampled variate in distribution's type"
214  << std::endl;
215  return false;
216  }
217  }
218  return true;
219 
220  }
221  };
222  boost::phoenix::function<validate_sample> validate_sample_f;
223 
224  struct unscope_locals {
225  template <typename T1, typename T2>
226  struct result { typedef void type; };
227  void operator()(const std::vector<var_decl>& var_decls,
228  variable_map& vm) const {
229  for (size_t i = 0; i < var_decls.size(); ++i)
230  vm.remove(var_decls[i].name());
231  }
232  };
233  boost::phoenix::function<unscope_locals> unscope_locals_f;
234 
235  // struct add_conditional_condition {
236  // template <typename T1, typename T2, typename T3>
237  // struct result { typedef bool type; };
238  // bool operator()(conditional_statement& cs,
239  // const expression& e,
240  // std::stringstream& error_msgs) const {
241  // if (!e.expression_type().is_primitive()) {
242  // error_msgs << "conditions in if-else statement must be primitive int or real;"
243  // << " found type=" << e.expression_type() << std::endl;
244  // return false;
245  // }
246  // cs.conditions_.push_back(e);
247  // return true;
248  // }
249  // };
250  // boost::phoenix::function<add_conditional_condition> add_conditional_condition_f;
251 
252  // struct add_conditional_body {
253  // template <typename T1, typename T2>
254  // struct result { typedef void type; };
255  // void operator()(conditional_statement& cs,
256  // const statement& s) const {
257  // cs.bodies_.push_back(s);
258  // }
259  // };
260  // boost::phoenix::function<add_conditional_body> add_conditional_body_f;
261 
262  struct add_while_condition {
263  template <typename T1, typename T2, typename T3>
264  struct result { typedef bool type; };
265  bool operator()(while_statement& ws,
266  const expression& e,
267  std::stringstream& error_msgs) const {
268  if (!e.expression_type().is_primitive()) {
269  error_msgs << "conditions in while statement must be primitive int or real;"
270  << " found type=" << e.expression_type() << std::endl;
271  return false;
272  }
273  ws.condition_ = e;
274  return true;
275  }
276  };
277  boost::phoenix::function<add_while_condition> add_while_condition_f;
278 
279  struct add_while_body {
280  template <typename T1, typename T2>
281  struct result { typedef void type; };
282  void operator()(while_statement& ws,
283  const statement& s) const {
284  ws.body_ = s;
285  }
286  };
287  boost::phoenix::function<add_while_body> add_while_body_f;
288 
289  struct add_loop_identifier {
290  template <typename T1, typename T2, typename T3, typename T4>
291  struct result { typedef bool type; };
292  bool operator()(const std::string& name,
293  std::string& name_local,
294  variable_map& vm,
295  std::stringstream& error_msgs) const {
296  name_local = name;
297  if (vm.exists(name)) {
298  error_msgs << "ERROR: loop variable already declared."
299  << " variable name=\"" << name << "\"" << std::endl;
300  return false; // variable exists
301  }
302  vm.add(name,
303  base_var_decl(name,std::vector<expression>(),
304  INT_T),
305  local_origin); // loop var acts like local
306  return true;
307  }
308  };
309  boost::phoenix::function<add_loop_identifier> add_loop_identifier_f;
310 
311  struct remove_loop_identifier {
312  template <typename T1, typename T2>
313  struct result { typedef void type; };
314  void operator()(const std::string& name,
315  variable_map& vm) const {
316  vm.remove(name);
317  }
318  };
319  boost::phoenix::function<remove_loop_identifier> remove_loop_identifier_f;
320 
321  struct validate_int_expr2 {
322  template <typename T1, typename T2>
323  struct result { typedef bool type; };
324 
325  bool operator()(const expression& expr,
326  std::stringstream& error_msgs) const {
327  if (!expr.expression_type().is_primitive_int()) {
328  error_msgs << "expression denoting integer required; found type="
329  << expr.expression_type() << std::endl;
330  return false;
331  }
332  return true;
333  }
334  };
335  boost::phoenix::function<validate_int_expr2> validate_int_expr2_f;
336 
// Phoenix-callable functor that gates sampling statements on a flag.
// Returns allow_sample; when it is false a diagnostic line is appended to
// error_msgs first.
struct validate_allow_sample {
  // Result-type protocol for the phoenix wrapper: always bool.
  template <typename T1, typename T2>
  struct result { typedef bool type; };

  bool operator()(const bool& allow_sample,
                  std::stringstream& error_msgs) const {
    if (allow_sample)
      return true;

    error_msgs << "ERROR: sampling only allowed in model."
               << std::endl;
    return false;
  }
};
351  boost::phoenix::function<validate_allow_sample> validate_allow_sample_f;
352 
353 
354  template <typename Iterator>
356  std::stringstream& error_msgs)
357  : statement_grammar::base_type(statement_r),
358  var_map_(var_map),
359  error_msgs_(error_msgs),
360  expression_g(var_map,error_msgs),
361  var_decls_g(var_map,error_msgs),
362  statement_2_g(var_map,error_msgs,*this)
363  {
364  using boost::spirit::qi::_1;
365  using boost::spirit::qi::char_;
366  using boost::spirit::qi::eps;
367  using boost::spirit::qi::lexeme;
368  using boost::spirit::qi::lit;
369  using boost::spirit::qi::_pass;
370  using boost::spirit::qi::_val;
371 
372  using boost::spirit::qi::labels::_a;
373  using boost::spirit::qi::labels::_r1;
374  using boost::spirit::qi::labels::_r2;
375 
376  // _r1 true if sample_r allowed (inherited)
377  // _r2 source of variables allowed for assignments
378  // set to true if sample_r are allowed
379  statement_r.name("statement");
380  statement_r
381  %= statement_seq_r(_r1,_r2)
382  | for_statement_r(_r1,_r2)
383  | while_statement_r(_r1,_r2)
384  | statement_2_g(_r1,_r2)
385  | print_statement_r
386  | assignment_r
387  [_pass
388  = validate_assignment_f(_1,_r2,boost::phoenix::ref(var_map_),
389  boost::phoenix::ref(error_msgs_))]
390  | sample_r(_r1) [_pass = validate_sample_f(_1,
391  boost::phoenix::ref(error_msgs_))]
392  | no_op_statement_r
393  ;
394 
395  // _r1, _r2 same as statement_r
396  statement_seq_r.name("sequence of statements");
397  statement_seq_r
398  %= lit('{')
399  > local_var_decls_r[_a = _1]
400  > *statement_r(_r1,_r2)
401  > lit('}')
402  > eps[unscope_locals_f(_a,boost::phoenix::ref(var_map_))]
403  ;
404 
405  local_var_decls_r
406  %= var_decls_g(false,local_origin); // - constants
407 
408  while_statement_r.name("while statement");
409  while_statement_r
410  = lit("while")
411  > lit('(')
412  > expression_g
413  [_pass = add_while_condition_f(_val,_1,
414  boost::phoenix::ref(error_msgs_))]
415  > lit(')')
416  > statement_r(_r1,_r2)
417  [add_while_body_f(_val,_1)]
418  ;
419 
420  // conditional_statement_r.name("if-else statement");
421  // conditional_statement_r
422  // = lit("if")
423  // > lit('(')
424  // > expression_g
425  // [_pass = add_conditional_condition_f(_val,_1,
426  // boost::phoenix::ref(error_msgs_))]
427  // > lit(')')
428  // > statement_r(_r1,_r2)
429  // [add_conditional_body_f(_val,_1)]
430  // > * (lit("else")
431  // >> lit("if")
432  // > lit('(')
433  // > expression_g
434  // [_pass = add_conditional_condition_f(_val,_1,
435  // boost::phoenix::ref(error_msgs_))]
436  // > lit(')')
437  // > statement_r(_r1,_r2)
438  // [add_conditional_body_f(_val,_1)]
439  // )
440  // > - (lit("else")
441  // > statement_r(_r1,_r2)
442  // [add_conditional_body_f(_val,_1)]
443  // )
444  // ;
445 
446  // _r1, _r2 same as statement_r
447  for_statement_r.name("for statement");
448  for_statement_r
449  %= lit("for")
450  > lit('(')
451  > identifier_r [_pass
452  = add_loop_identifier_f(_1,_a,
453  boost::phoenix::ref(var_map_),
454  boost::phoenix::ref(error_msgs_))]
455  > lit("in")
456  > range_r
457  > lit(')')
458  > statement_r(_r1,_r2)
459  > eps
460  [remove_loop_identifier_f(_a,boost::phoenix::ref(var_map_))];
461  ;
462 
463  print_statement_r.name("print statement");
464  print_statement_r
465  %= lit("print")
466  > lit('(')
467  > (printable_r % ',')
468  // > (expression_g % ',')
469  > lit(')');
470 
471  printable_r.name("printable");
472  printable_r
473  %= printable_string_r
474  | expression_g;
475 
476  printable_string_r.name("printable quoted string");
477  printable_string_r
478  %= lit('"')
479  > lexeme[*char_("a-zA-Z0-9/~!@#$%^&*()_+`-={}|[]:;'<>?,./ ")]
480  > lit('"');
481 
482  identifier_r.name("identifier");
483  identifier_r
484  %= (lexeme[char_("a-zA-Z")
485  >> *char_("a-zA-Z0-9_.")]);
486 
487  range_r.name("range expression pair, colon");
488  range_r
489  %= expression_g
490  [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(error_msgs_))]
491  >> lit(':')
492  >> expression_g
493  [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(error_msgs_))];
494 
495  assignment_r.name("variable assignment by expression");
496  assignment_r
497  %= var_lhs_r
498  >> lit("<-")
499  > expression_g
500  > lit(';')
501  ;
502 
503  var_lhs_r.name("variable and array dimensions");
504  var_lhs_r
505  %= identifier_r
506  >> opt_dims_r;
507 
508  opt_dims_r.name("array dimensions (optional)");
509  opt_dims_r
510  %= - dims_r;
511 
512  dims_r.name("array dimensions");
513  dims_r
514  %= lit('[')
515  > (expression_g
516  [_pass = validate_int_expr2_f(_1,boost::phoenix::ref(error_msgs_))]
517  % ',')
518  > lit(']')
519  ;
520 
521  // inherited _r1 = true if samples allowed as statements
522  sample_r.name("distribution of expression");
523  sample_r
524  %= expression_g
525  >> lit('~')
526  > eps
527  [_pass
528  = validate_allow_sample_f(_r1,boost::phoenix::ref(error_msgs_))]
529  > distribution_r
530  > -truncation_range_r
531  > lit(';');
532 
533  distribution_r.name("distribution and parameters");
534  distribution_r
535  %= identifier_r
536  >> lit('(')
537  >> -(expression_g % ',')
538  > lit(')');
539 
540  truncation_range_r.name("range pair");
541  truncation_range_r
542  %= lit('T')
543  > lit('[')
544  > -expression_g
545  > lit(',')
546  > -expression_g
547  > lit(']');
548 
549  no_op_statement_r.name("no op statement");
550  no_op_statement_r
551  %= lit(';') [_val = no_op_statement()]; // ok to re-use instance
552 
553  }
554 
555  }
556 }
557 #endif
static function_signatures & instance()
Definition: ast_def.cpp:112
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs)
Definition: ast_def.cpp:218
expr_type infer_type_indexing(const base_expr_type &expr_base_type, size_t num_expr_dims, size_t num_index_dims)
Definition: ast_def.cpp:479
std::ostream & write_base_expr_type(std::ostream &o, base_expr_type type)
Definition: ast_def.cpp:23
const int INT_T
Definition: ast.hpp:54
int var_origin
Definition: ast.hpp:358
const int DOUBLE_T
Definition: ast.hpp:55
int base_expr_type
Definition: ast.hpp:50
const int local_origin
Definition: ast.hpp:364
void print_var_origin(std::ostream &o, const var_origin &vo)
Definition: ast_def.cpp:550
double e()
Return the base of the natural logarithm.
Probability, optimization and sampling library.
Definition: agrad.cpp:6
BOOST_FUSION_ADAPT_STRUCT(stan::gm::program,(std::vector< stan::gm::var_decl >, data_decl_)(DUMMY_STRUCT::type, derived_data_decl_)(std::vector< stan::gm::var_decl >, parameter_decl_)(DUMMY_STRUCT::type, derived_decl_)(stan::gm::statement, statement_)(DUMMY_STRUCT::type, generated_decl_)) namespace stan
bool is_primitive_double() const
Definition: ast_def.cpp:78
statement_grammar(variable_map &var_map, std::stringstream &error_msgs)

     [ Stan Home Page ] © 2011–2012, Stan Development Team.