Stan  1.0
probability, sampling & optimization
term_grammar_def.hpp
Go to the documentation of this file.
1 #ifndef __STAN__GM__PARSER__TERM_GRAMMAR_DEF__HPP__
2 #define __STAN__GM__PARSER__TERM_GRAMMAR_DEF__HPP__
3 
4 #include <cstddef>
5 #include <iomanip>
6 #include <iostream>
7 #include <istream>
8 #include <map>
9 #include <set>
10 #include <sstream>
11 #include <string>
12 #include <utility>
13 #include <vector>
14 #include <stdexcept>
15 
16 #include <boost/spirit/include/qi.hpp>
17 // FIXME: get rid of unused include
18 #include <boost/spirit/include/phoenix_core.hpp>
19 #include <boost/spirit/include/phoenix_function.hpp>
20 #include <boost/spirit/include/phoenix_fusion.hpp>
21 #include <boost/spirit/include/phoenix_object.hpp>
22 #include <boost/spirit/include/phoenix_operator.hpp>
23 #include <boost/spirit/include/phoenix_stl.hpp>
24 
25 #include <boost/lexical_cast.hpp>
26 #include <boost/fusion/include/adapt_struct.hpp>
27 #include <boost/fusion/include/std_pair.hpp>
28 #include <boost/config/warning_disable.hpp>
29 #include <boost/spirit/include/qi.hpp>
30 #include <boost/spirit/include/qi_numeric.hpp>
31 #include <boost/spirit/include/classic_position_iterator.hpp>
32 #include <boost/spirit/include/phoenix_core.hpp>
33 #include <boost/spirit/include/phoenix_function.hpp>
34 #include <boost/spirit/include/phoenix_fusion.hpp>
35 #include <boost/spirit/include/phoenix_object.hpp>
36 #include <boost/spirit/include/phoenix_operator.hpp>
37 #include <boost/spirit/include/phoenix_stl.hpp>
38 #include <boost/spirit/include/support_multi_pass.hpp>
39 #include <boost/tuple/tuple.hpp>
40 #include <boost/variant/apply_visitor.hpp>
41 #include <boost/variant/recursive_variant.hpp>
42 
43 #include <stan/gm/ast.hpp>
47 
49  (stan::gm::expression, expr_)
50  (std::vector<std::vector<stan::gm::expression> >,
51  dimss_) )
52 
54  (std::string, name_)
55  (std::vector<stan::gm::expression>, args_) )
56 
58  (int,val_)
59  (stan::gm::expr_type,type_))
60 
62  (double,val_)
63  (stan::gm::expr_type,type_) )
64 
65 
66 
67 
68 namespace stan {
69 
70  namespace gm {
71 
72 
73  struct set_fun_type {
74  template <typename T1, typename T2>
75  struct result { typedef fun type; };
76 
77  fun operator()(fun& fun,
78  std::ostream& error_msgs) const {
79  std::vector<expr_type> arg_types;
80  for (size_t i = 0; i < fun.args_.size(); ++i)
81  arg_types.push_back(fun.args_[i].expression_type());
82  fun.type_ = function_signatures::instance().get_result_type(fun.name_,
83  arg_types,
84  error_msgs);
85  return fun;
86  }
87  };
88  boost::phoenix::function<set_fun_type> set_fun_type_f;
89 
90 
91 
92  struct multiplication_expr {
93  template <typename T1, typename T2, typename T3>
94  struct result { typedef expression type; };
95 
96  expression operator()(expression& expr1,
97  const expression& expr2,
98  std::ostream& error_msgs) const {
99 
100  if (expr1.expression_type().is_primitive()
101  && expr2.expression_type().is_primitive()) {
102  return expr1 *= expr2;
103  }
104  std::vector<expression> args;
105  args.push_back(expr1);
106  args.push_back(expr2);
107  set_fun_type sft;
108  fun f("multiply",args);
109  sft(f,error_msgs);
110  return expression(f);
111  }
112  };
113  boost::phoenix::function<multiplication_expr> multiplication;
114 
115  void generate_expression(const expression& e, std::ostream& o);
116 
117  struct division_expr {
118  template <typename T1, typename T2, typename T3>
119  struct result { typedef expression type; };
120 
121  expression operator()(expression& expr1,
122  const expression& expr2,
123  std::ostream& error_msgs) const {
124  if (expr1.expression_type().is_primitive_int()
125  && expr2.expression_type().is_primitive_int()) {
126  // getting here, but not printing? only print error if problems?
127  error_msgs << "Warning: integer division implicitly rounds to integer."
128  << " Found int division: ";
129  generate_expression(expr1.expr_,error_msgs);
130  error_msgs << " / ";
131  generate_expression(expr2.expr_,error_msgs);
132  error_msgs << std::endl
133  << " Positive values rounded down, negative values rounded up or down"
134  << " in platform-dependent way."
135  << std::endl;
136  }
137 
138  if (expr1.expression_type().is_primitive()
139  && expr2.expression_type().is_primitive()) {
140  return expr1 /= expr2;
141  }
142  std::vector<expression> args;
143  args.push_back(expr1);
144  args.push_back(expr2);
145  set_fun_type sft;
146  if ((expr1.expression_type().type() == MATRIX_T
147  || expr1.expression_type().type() == ROW_VECTOR_T)
148  && expr2.expression_type().type() == MATRIX_T) {
149  fun f("mdivide_right",args);
150  sft(f,error_msgs);
151  return expression(f);
152  }
153 
154  fun f("divide",args);
155  sft(f,error_msgs);
156  return expression(f);
157  }
158  };
159  boost::phoenix::function<division_expr> division;
160 
161  struct left_division_expr {
162  template <typename T1, typename T2, typename T3>
163  struct result { typedef expression type; };
164 
165  expression operator()(expression& expr1,
166  const expression& expr2,
167  std::ostream& error_msgs) const {
168  if (expr1.expression_type().is_primitive()
169  && expr2.expression_type().is_primitive()) {
170  return expr1 /= expr2;
171  }
172  std::vector<expression> args;
173  args.push_back(expr1);
174  args.push_back(expr2);
175  set_fun_type sft;
176  if (expr1.expression_type().type() == MATRIX_T
177  && (expr2.expression_type().type() == VECTOR_T
178  || expr2.expression_type().type() == MATRIX_T)) {
179  fun f("mdivide_left",args);
180  sft(f,error_msgs);
181  return expression(f);
182  }
183  fun f("divide_left",args);
184  sft(f,error_msgs);
185  return expression(f);
186  }
187  };
188  boost::phoenix::function<left_division_expr> left_division;
189 
190  struct elt_multiplication_expr {
191  template <typename T1, typename T2, typename T3>
192  struct result { typedef expression type; };
193 
194  expression operator()(expression& expr1,
195  const expression& expr2,
196  std::ostream& error_msgs) const {
197 
198  if (expr1.expression_type().is_primitive()
199  && expr2.expression_type().is_primitive()) {
200  return expr1 *= expr2;
201  }
202  std::vector<expression> args;
203  args.push_back(expr1);
204  args.push_back(expr2);
205  set_fun_type sft;
206  fun f("elt_multiply",args);
207  sft(f,error_msgs);
208  return expression(f);
209  return expr1 += expr2;
210  }
211  };
212  boost::phoenix::function<elt_multiplication_expr> elt_multiplication;
213 
214  struct elt_division_expr {
215  template <typename T1, typename T2, typename T3>
216  struct result { typedef expression type; };
217 
218  expression operator()(expression& expr1,
219  const expression& expr2,
220  std::ostream& error_msgs) const {
221 
222  if (expr1.expression_type().is_primitive()
223  && expr2.expression_type().is_primitive()) {
224  return expr1 /= expr2;
225  }
226  std::vector<expression> args;
227  args.push_back(expr1);
228  args.push_back(expr2);
229  set_fun_type sft;
230  fun f("elt_divide",args);
231  sft(f,error_msgs);
232  return expression(f);
233  return expr1 += expr2;
234  }
235  };
236  boost::phoenix::function<elt_division_expr> elt_division;
237 
238  // Cut-and-Paste from Spirit examples, including comment: We
239  // should be using expression::operator-. There's a bug in phoenix
240  // type deduction mechanism that prevents us from doing
241  // so. Phoenix will be switching to BOOST_TYPEOF. In the meantime,
242  // we will use a phoenix::function below:
243  struct negate_expr {
244  template <typename T1, typename T2>
245  struct result { typedef expression type; };
246 
247  expression operator()(const expression& expr,
248  std::ostream& error_msgs) const {
249  if (expr.expression_type().is_primitive()) {
250  return expression(unary_op('-', expr));
251  }
252  std::vector<expression> args;
253  args.push_back(expr);
254  set_fun_type sft;
255  fun f("minus",args);
256  sft(f,error_msgs);
257  return expression(f);
258  }
259  };
260  boost::phoenix::function<negate_expr> negate_expr_f;
261 
262  struct logical_negate_expr {
263  template <typename T1, typename T2>
264  struct result { typedef expression type; };
265 
266  expression operator()(const expression& expr,
267  std::ostream& error_msgs) const {
268  if (!expr.expression_type().is_primitive()) {
269  error_msgs << "logical negation operator ! only applies to int or real types; ";
270  return expression();
271  }
272  std::vector<expression> args;
273  args.push_back(expr);
274  set_fun_type sft;
275  fun f("logical_negation",args);
276  sft(f,error_msgs);
277  return expression(f);
278  }
279  };
280  boost::phoenix::function<logical_negate_expr> logical_negate_expr_f;
281 
282  struct transpose_expr {
283  template <typename T1, typename T2>
284  struct result { typedef expression type; };
285 
286  expression operator()(const expression& expr,
287  std::ostream& error_msgs) const {
288  if (expr.expression_type().is_primitive()) {
289  return expr; // transpose of basic is self -- works?
290  }
291  std::vector<expression> args;
292  args.push_back(expr);
293  set_fun_type sft;
294  fun f("transpose",args);
295  sft(f,error_msgs);
296  return expression(f);
297  }
298  };
299  boost::phoenix::function<transpose_expr> transpose_f;
300 
301  struct add_expression_dimss {
302  template <typename T1, typename T2, typename T3, typename T4>
303  struct result { typedef T1 type; };
304  expression operator()(expression& expression,
305  std::vector<std::vector<stan::gm::expression> >& dimss,
306  bool& pass,
307  std::ostream& error_msgs) const {
308  index_op iop(expression,dimss);
309  iop.infer_type();
310  if (iop.type_.is_ill_formed()) {
311  error_msgs << "indexes inappropriate for expression." << std::endl;
312  pass = false;
313  } else {
314  pass = true;
315  }
316  return iop;
317  }
318  };
319  boost::phoenix::function<add_expression_dimss> add_expression_dimss_f;
320 
321  struct set_var_type {
322  template <typename T1, typename T2, typename T3, typename T4>
323  struct result { typedef variable type; };
324  variable operator()(variable& var_expr,
325  variable_map& vm,
326  std::ostream& error_msgs,
327  bool& pass) const {
328  std::string name = var_expr.name_;
329  if (!vm.exists(name)) {
330  pass = false;
331  error_msgs << "variable \"" << name << '"' << " does not exist."
332  << std::endl;
333  return var_expr;
334  }
335  pass = true;
336  var_expr.set_type(vm.get_base_type(name),vm.get_num_dims(name));
337  return var_expr;
338  }
339  };
340  boost::phoenix::function<set_var_type> set_var_type_f;
341 
342  struct validate_int_expr3 {
343  template <typename T1, typename T2>
344  struct result { typedef bool type; };
345 
346  bool operator()(const expression& expr,
347  std::stringstream& error_msgs) const {
348  if (!expr.expression_type().is_primitive_int()) {
349  error_msgs << "expression denoting integer required; found type="
350  << expr.expression_type() << std::endl;
351  return false;
352  }
353  return true;
354  }
355  };
356  boost::phoenix::function<validate_int_expr3> validate_int_expr3_f;
357 
358 
359  struct validate_expr_type {
360  template <typename T1, typename T2>
361  struct result { typedef bool type; };
362 
363  bool operator()(const expression& expr,
364  std::ostream& error_msgs) const {
365  if (expr.expression_type().is_ill_formed()) {
366  error_msgs << "expression is ill formed" << std::endl;
367  return false;
368  }
369  return true;
370  }
371  };
372  boost::phoenix::function<validate_expr_type> validate_expr_type_f;
373 
374 
375 
376 
377 
// Constructs the grammar for multiplicative-precedence terms:
//   term           := negated_factor (('*' | '/' | '\' | '.*' | './') negated_factor)*
//   negated_factor := ('-' | '!' | '+') negated_factor | indexed_factor
//   indexed_factor := factor ( ('[' ints ']')+ | "'" )*
//   factor         := int_literal | double_literal | fun | variable | '(' expression ')'
// Semantic actions build and type-check AST nodes through the Phoenix
// functors defined above; diagnostics go to error_msgs_.
template <typename Iterator>
term_grammar<Iterator>::term_grammar(variable_map& var_map,
                                     std::stringstream& error_msgs,
                                     expression_grammar<Iterator>& eg)
  : term_grammar::base_type(term_r),
    var_map_(var_map),
    error_msgs_(error_msgs),
    expression_g(eg)
{
  using boost::spirit::qi::_1;
  using boost::spirit::qi::char_;
  using boost::spirit::qi::double_;
  using boost::spirit::qi::eps;
  using boost::spirit::qi::int_;
  using boost::spirit::qi::lexeme;
  using boost::spirit::qi::lit;
  using boost::spirit::qi::_pass;
  using boost::spirit::qi::_val;

  // Left-associative chain of multiplicative operators. Each operator is
  // followed by '>' (expectation), so a missing operand after an operator
  // is a hard parse error rather than a backtrack. '*' is tried before
  // ".*", which cannot conflict because ".*" starts with '.'.
  term_r.name("term");
  term_r
    = ( negated_factor_r
        [_val = _1]
        >> *( (lit('*') > negated_factor_r
               [_val = multiplication(_val,_1,
                                      boost::phoenix::ref(error_msgs_))])
              | (lit('/') > negated_factor_r
                 [_val = division(_val,_1,boost::phoenix::ref(error_msgs_))])
              | (lit('\\') > negated_factor_r
                 [_val = left_division(_val,_1,
                                       boost::phoenix::ref(error_msgs_))])
              | (lit(".*") > negated_factor_r
                 [_val = elt_multiplication(_val,_1,
                                            boost::phoenix::ref(error_msgs_))])
              | (lit("./") > negated_factor_r
                 [_val = elt_division(_val,_1,
                                      boost::phoenix::ref(error_msgs_))])
              )
        )
    ;

  // Prefix unary operators, applied recursively (e.g. "--x"). Unary '+'
  // is the identity. NOTE(review): this rule is given no .name(), so
  // error reports will show a generic name for it.
  negated_factor_r
    = lit('-') >> negated_factor_r
                  [_val = negate_expr_f(_1,boost::phoenix::ref(error_msgs_))]
    | lit('!') >> negated_factor_r
                  [_val = logical_negate_expr_f(_1,boost::phoenix::ref(error_msgs_))]
    | lit('+') >> negated_factor_r [_val = _1]
    | indexed_factor_r [_val = _1];

  // Postfix operations on a factor: one or more bracketed index groups
  // (typed by add_expression_dimss_f, which sets _pass) or a transpose
  // quote. The kleene star always succeeds, so the '>' here never throws.
  indexed_factor_r.name("(optionally) indexed factor [sub]");
  indexed_factor_r
    = factor_r [_val = _1]
    > * (
         (+dims_r)
         [_val = add_expression_dimss_f(_val, _1, _pass,
                                        boost::phoenix::ref(error_msgs_))]
         |
         lit("'")
         [_val = transpose_f(_val, boost::phoenix::ref(error_msgs_))]
        )
    ;

  // Alternative order matters: int literals before reals, so "3" stays an
  // int. Function calls come before variables, so "f(...)" is a call. An
  // undeclared variable sets _pass to false through set_var_type_f.
  factor_r.name("factor");
  factor_r
    = int_literal_r [_val = _1]
    | double_literal_r [_val = _1]
    | fun_r [_val = set_fun_type_f(_1,boost::phoenix::ref(error_msgs_))]
    | variable_r
      [_val = set_var_type_f(_1,boost::phoenix::ref(var_map_),
                             boost::phoenix::ref(error_msgs_),
                             _pass)]
    | ( lit('(')
        > expression_g [_val = _1]
        > lit(')') )
    ;

  // Integer literal; the negative lookahead rejects "1." / "1e5" / "1E5"
  // so those fall through to double_literal_r.
  int_literal_r.name("integer literal");
  int_literal_r
    %= int_
    >> !( lit('.')
          | lit('e')
          | lit('E') );

  double_literal_r.name("real literal");
  double_literal_r
    %= double_;

  // Function call: name followed by a parenthesized argument list.
  fun_r.name("function and argument expressions");
  fun_r
    %= identifier_r // name is not checked here; set_fun_type_f resolves the signature
    >> args_r;

  // Identifier: a letter followed by letters, digits, '_' or '.'.
  // NOTE(review): because '.' is an identifier character, "a.*b" or "a./b"
  // written without whitespace lexes "a." as the identifier. Confirm this
  // is intended.
  identifier_r.name("identifier (expression grammar)");
  identifier_r
    %= lexeme[char_("a-zA-Z")
              >> *char_("a-zA-Z0-9_.")];

  // Argument list: "()" or "(" expr (',' expr)* ")".
  args_r.name("function argument expressions");
  args_r
    %= (lit('(') >> lit(')'))
    | ( lit('(')
        >> (expression_g % ',')
        > lit(')') )
    ;

  // One bracketed, comma-separated group of index expressions. Each index
  // must be int-typed; validate_int_expr3_f sets _pass accordingly.
  dims_r.name("array dimensions");
  dims_r
    %= lit('[')
    > (expression_g
       [_pass = validate_int_expr3_f(_1,boost::phoenix::ref(error_msgs_))]
       % ',')
    > lit(']')
    ;

  variable_r.name("variable expression");
  variable_r
    %= identifier_r;

}
507  }
508 }
509 
510 #endif
static function_signatures & instance()
Definition: ast_def.cpp:112
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs)
Definition: ast_def.cpp:218
const int ROW_VECTOR_T
Definition: ast.hpp:57
void generate_expression(const expression &e, std::ostream &o)
Definition: generator.hpp:191
const int VECTOR_T
Definition: ast.hpp:56
const int MATRIX_T
Definition: ast.hpp:58
double e()
Return the base of the natural logarithm.
Probability, optimization and sampling library.
Definition: agrad.cpp:6
BOOST_FUSION_ADAPT_STRUCT(stan::gm::program,(std::vector< stan::gm::var_decl >, data_decl_)(DUMMY_STRUCT::type, derived_data_decl_)(std::vector< stan::gm::var_decl >, parameter_decl_)(DUMMY_STRUCT::type, derived_decl_)(stan::gm::statement, statement_)(DUMMY_STRUCT::type, generated_decl_)) namespace stan
term_grammar(variable_map &var_map, std::stringstream &error_msgs, expression_grammar< Iterator > &eg)

     [ Stan Home Page ] © 2011–2012, Stan Development Team.