1 #ifndef __STAN__GM__AST_HPP__
2 #define __STAN__GM__AST_HPP__
7 #include <boost/variant/recursive_variant.hpp>
92 void add(
const std::string& name,
94 const std::vector<expr_type>& arg_types);
95 void add(
const std::string& name,
97 void add(
const std::string& name,
100 void add(
const std::string& name,
104 void add(
const std::string& name,
109 void add(
const std::string& name,
115 void add(
const std::string& name,
123 void add_unary(const::std::string& name);
128 const std::vector<expr_type>& sig_args);
130 const std::vector<expr_type>& args,
131 std::ostream& error_msgs);
135 std::map<std::string, std::vector<function_signature_t> > sigs_map_;
143 statements(
const std::vector<var_decl>& local_decl,
144 const std::vector<statement>& stmts);
172 typedef boost::variant<boost::recursive_wrapper<nil>,
173 boost::recursive_wrapper<int_literal>,
174 boost::recursive_wrapper<double_literal>,
175 boost::recursive_wrapper<array_literal>,
176 boost::recursive_wrapper<variable>,
177 boost::recursive_wrapper<fun>,
178 boost::recursive_wrapper<index_op>,
179 boost::recursive_wrapper<binary_op>,
180 boost::recursive_wrapper<unary_op> >
223 typedef boost::variant<boost::recursive_wrapper<std::string>,
224 boost::recursive_wrapper<expression> >
258 std::vector<expression>
const& dims);
302 fun(std::string
const& name,
303 std::vector<expression>
const& args);
307 size_t total_dims(
const std::vector<std::vector<expression> >& dimss);
310 size_t num_expr_dims,
311 size_t num_index_dims);
314 size_t num_index_dims);
319 std::vector<std::vector<expression> >
dimss_;
324 const std::vector<std::vector<expression> >& dimss);
336 const std::string&
op,
377 const std::vector<expression>& dims,
382 typedef std::pair<base_var_decl,var_origin>
range_t;
383 std::map<std::string, range_t>
map_;
384 bool exists(
const std::string& name)
const;
389 void add(
const std::string& name,
392 void remove(
const std::string& name);
399 std::string
const& name,
400 std::vector<expression>
const& dims);
408 std::string
const& name,
409 std::vector<expression>
const& dims);
416 std::string
const& name,
417 std::vector<expression>
const& dims);
424 std::string
const& name,
425 std::vector<expression>
const& dims);
432 std::string
const& name,
433 std::vector<expression>
const& dims);
442 std::string
const& name,
443 std::vector<expression>
const& dims);
452 std::string
const& name,
453 std::vector<expression>
const& dims);
464 std::string
const& name,
465 std::vector<expression>
const& dims);
476 std::string
const& name,
477 std::vector<expression>
const& dims);
486 std::string
const& name,
487 std::vector<expression>
const& dims);
492 struct name_vis :
public boost::static_visitor<std::string> {
511 typedef boost::variant<boost::recursive_wrapper<nil>,
512 boost::recursive_wrapper<int_var_decl>,
513 boost::recursive_wrapper<double_var_decl>,
514 boost::recursive_wrapper<vector_var_decl>,
515 boost::recursive_wrapper<row_vector_var_decl>,
516 boost::recursive_wrapper<matrix_var_decl>,
517 boost::recursive_wrapper<simplex_var_decl>,
518 boost::recursive_wrapper<ordered_var_decl>,
519 boost::recursive_wrapper<positive_ordered_var_decl>,
520 boost::recursive_wrapper<cov_matrix_var_decl>,
521 boost::recursive_wrapper<corr_matrix_var_decl> >
543 std::string
name()
const;
547 typedef boost::variant<boost::recursive_wrapper<nil>,
548 boost::recursive_wrapper<assignment>,
549 boost::recursive_wrapper<sample>,
550 boost::recursive_wrapper<statements>,
551 boost::recursive_wrapper<for_statement>,
552 boost::recursive_wrapper<conditional_statement>,
553 boost::recursive_wrapper<while_statement>,
554 boost::recursive_wrapper<print_statement>,
555 boost::recursive_wrapper<no_op_statement> >
619 std::pair<std::vector<var_decl>,std::vector<statement> >
622 std::pair<std::vector<var_decl>,std::vector<statement> >
627 program(
const std::vector<var_decl>& data_decl,
628 const std::pair<std::vector<var_decl>,
629 std::vector<statement> >& derived_data_decl,
630 const std::vector<var_decl>& parameter_decl,
631 const std::pair<std::vector<var_decl>,
632 std::vector<statement> >& derived_decl,
634 const std::pair<std::vector<var_decl>,
635 std::vector<statement> >& generated_decl);
void add_ternary(const ::std::string &name)
void add(const std::string &name, const expr_type &result_type, const std::vector< expr_type > &arg_types)
void add_binary(const ::std::string &name)
void add_nullary(const ::std::string &name)
void add_quaternary(const ::std::string &name)
static function_signatures & instance()
int num_promotions(const std::vector< expr_type > &call_args, const std::vector< expr_type > &sig_args)
void add_unary(const ::std::string &name)
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs)
size_t total_dims(const std::vector< std::vector< expression > > &dimss)
const int parameter_origin
expr_type infer_type_indexing(const base_expr_type &expr_base_type, size_t num_expr_dims, size_t num_index_dims)
std::ostream & write_base_expr_type(std::ostream &o, base_expr_type type)
expr_type promote_primitive(const expr_type &et)
void generate_expression(const expression &e, std::ostream &o)
const int transformed_data_origin
bool is_nil(const expression &e)
std::pair< expr_type, std::vector< expr_type > > function_signature_t
const int transformed_parameter_origin
std::ostream & operator<<(std::ostream &o, const expr_type &et)
void print_var_origin(std::ostream &o, const var_origin &vo)
double e()
Return the base of the natural logarithm.
double dist(const std::vector< double > &x, const std::vector< double > &y)
Probability, optimization and sampling library.
std::vector< expression > args_
array_literal & operator=(const array_literal &al)
std::vector< expression > dims_
base_expr_type base_type_
std::vector< expression > conditions_
std::vector< statement > bodies_
std::vector< expression > args_
double_literal & operator=(const double_literal &dl)
bool is_primitive() const
bool operator==(const expr_type &et) const
bool operator!=(const expr_type &et) const
bool is_primitive_int() const
base_expr_type type() const
bool is_primitive_double() const
bool is_ill_formed() const
base_expr_type base_type_
expr_type operator()(const nil &e) const
expression & operator*=(const expression &rhs)
expr_type expression_type() const
expression & operator+=(const expression &rhs)
expression & operator/=(const expression &rhs)
expression & operator-=(const expression &rhs)
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< int_literal >, boost::recursive_wrapper< double_literal >, boost::recursive_wrapper< array_literal >, boost::recursive_wrapper< variable >, boost::recursive_wrapper< fun >, boost::recursive_wrapper< index_op >, boost::recursive_wrapper< binary_op >, boost::recursive_wrapper< unary_op > > expression_t
std::vector< expression > args_
std::vector< std::vector< expression > > dimss_
int_literal & operator=(const int_literal &il)
bool operator()(const nil &x) const
std::string operator()(const nil &x) const
Placeholder struct for boost::variant default ctors.
positive_ordered_var_decl()
std::vector< printable > printables_
boost::variant< boost::recursive_wrapper< std::string >, boost::recursive_wrapper< expression > > printable_t
std::pair< std::vector< var_decl >, std::vector< statement > > derived_decl_
std::vector< var_decl > data_decl_
std::pair< std::vector< var_decl >, std::vector< statement > > generated_decl_
std::vector< var_decl > parameter_decl_
std::pair< std::vector< var_decl >, std::vector< statement > > derived_data_decl_
bool is_ill_formed() const
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< assignment >, boost::recursive_wrapper< sample >, boost::recursive_wrapper< statements >, boost::recursive_wrapper< for_statement >, boost::recursive_wrapper< conditional_statement >, boost::recursive_wrapper< while_statement >, boost::recursive_wrapper< print_statement >, boost::recursive_wrapper< no_op_statement > > statement_t
std::vector< var_decl > local_decl_
std::vector< statement > statements_
unary_op(char op, expression const &subject)
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< int_var_decl >, boost::recursive_wrapper< double_var_decl >, boost::recursive_wrapper< vector_var_decl >, boost::recursive_wrapper< row_vector_var_decl >, boost::recursive_wrapper< matrix_var_decl >, boost::recursive_wrapper< simplex_var_decl >, boost::recursive_wrapper< ordered_var_decl >, boost::recursive_wrapper< positive_ordered_var_decl >, boost::recursive_wrapper< cov_matrix_var_decl >, boost::recursive_wrapper< corr_matrix_var_decl > > var_decl_t
std::vector< expression > dims_
base_expr_type get_base_type(const std::string &name) const
std::map< std::string, range_t > map_
void add(const std::string &name, const base_var_decl &base_decl, const var_origin &vo)
std::pair< base_var_decl, var_origin > range_t
base_var_decl get(const std::string &name) const
var_origin get_origin(const std::string &name) const
void remove(const std::string &name)
bool exists(const std::string &name) const
size_t get_num_dims(const std::string &name) const
void set_type(const base_expr_type &base_type, size_t num_dims)