1 #ifndef __STAN__GM__AST_DEF_HPP__
2 #define __STAN__GM__AST_DEF_HPP__
4 #include <boost/variant/apply_visitor.hpp>
5 #include <boost/variant/recursive_variant.hpp>
55 : base_type_(base_type),
60 : base_type_(base_type),
68 return !(*
this == et);
120 const std::vector<expr_type>& arg_types) {
126 std::vector<expr_type> arg_types;
127 add(name,result_type,arg_types);
132 std::vector<expr_type> arg_types;
133 arg_types.push_back(arg_type);
134 add(name,result_type,arg_types);
140 std::vector<expr_type> arg_types;
141 arg_types.push_back(arg_type1);
142 arg_types.push_back(arg_type2);
143 add(name,result_type,arg_types);
150 std::vector<expr_type> arg_types;
151 arg_types.push_back(arg_type1);
152 arg_types.push_back(arg_type2);
153 arg_types.push_back(arg_type3);
154 add(name,result_type,arg_types);
162 std::vector<expr_type> arg_types;
163 arg_types.push_back(arg_type1);
164 arg_types.push_back(arg_type2);
165 arg_types.push_back(arg_type3);
166 arg_types.push_back(arg_type4);
167 add(name,result_type,arg_types);
176 std::vector<expr_type> arg_types;
177 arg_types.push_back(arg_type1);
178 arg_types.push_back(arg_type2);
179 arg_types.push_back(arg_type3);
180 arg_types.push_back(arg_type4);
181 arg_types.push_back(arg_type5);
182 add(name,result_type,arg_types);
200 const std::vector<expr_type>& call_args,
201 const std::vector<expr_type>& sig_args) {
202 if (call_args.size() != sig_args.size()) {
206 for (
size_t i = 0; i < call_args.size(); ++i) {
207 if (call_args[i] == sig_args[i]) {
209 }
else if (call_args[i].is_primitive_int()
210 && sig_args[i].is_primitive_double()) {
219 const std::string& name,
220 const std::vector<expr_type>& args,
221 std::ostream& error_msgs) {
222 std::vector<function_signature_t> signatures = sigs_map_[name];
223 size_t match_index = 0;
225 size_t num_matches = 0;
227 for (
size_t i = 0; i < signatures.size(); ++i) {
229 if (promotions < 0)
continue;
230 size_t promotions_ui =
static_cast<size_t>(promotions);
231 if (promotions_ui < min_promotions) {
232 min_promotions = promotions_ui;
235 }
else if (promotions_ui == min_promotions) {
240 if (num_matches == 1) {
241 return signatures[match_index].first;
242 }
else if (num_matches == 0) {
243 error_msgs <<
"no matches for function name=\"" << name <<
"\""
246 error_msgs << num_matches <<
" matches with "
247 << min_promotions <<
" integer promotions "
248 <<
"for function name=\"" << name <<
"\"" << std::endl;
250 for (
size_t i = 0; i < args.size(); ++i)
251 error_msgs <<
" arg " << i <<
" type=" << args[i] << std::endl;
254 function_signatures::function_signatures() {
257 function_signatures* function_signatures::sigs_ = 0;
263 const std::vector<statement>& stmts)
264 : local_decl_(local_decl),
308 return boost::apply_visitor(vis,
expr_);
394 return boost::apply_visitor(ino,
e.expr_);
399 std::vector<expression>
const& dims)
461 std::vector<expression>
const& args)
471 size_t total_dims(
const std::vector<std::vector<expression> >& dimss) {
473 for (
size_t i = 0; i < dimss.size(); ++i)
474 total += dimss[i].size();
480 size_t num_expr_dims,
481 size_t num_index_dims) {
482 if (num_index_dims <= num_expr_dims)
483 return expr_type(expr_base_type,num_expr_dims - num_index_dims);
484 if (num_index_dims == (num_expr_dims + 1)) {
490 if (num_index_dims == (num_expr_dims + 2))
499 size_t num_index_dims) {
508 const std::vector<std::vector<expression> >& dimss)
519 const std::string& op,
525 right.expression_type())) {
554 o <<
"transformed data";
558 o <<
"transformed parameter";
564 o <<
"UNKNOWN ORIGIN";
570 : base_type_(base_type) {
573 const std::vector<expression>& dims,
577 base_type_(base_type) {
581 return map_.find(name) !=
map_.end();
585 throw std::invalid_argument(
"variable does not exist");
586 return map_.find(name)->second.first;
596 throw std::invalid_argument(
"variable does not exist");
597 return map_.find(name)->second.second;
613 std::string
const& name,
614 std::vector<expression>
const& dims)
626 std::string
const& name,
627 std::vector<expression>
const& dims)
637 std::string
const& name,
638 std::vector<expression>
const& dims)
648 std::string
const& name,
649 std::vector<expression>
const& dims)
659 std::string
const& name,
660 std::vector<expression>
const& dims)
669 std::string
const& name,
670 std::vector<expression>
const& dims)
679 std::string
const& name,
680 std::vector<expression>
const& dims)
690 std::string
const& name,
691 std::vector<expression>
const& dims)
702 std::string
const& name,
703 std::vector<expression>
const& dims)
710 std::string
const& name,
711 std::vector<expression>
const& dims)
818 : condition_(condition),
826 const std::vector<statement>& bodies)
827 : conditions_(conditions),
834 : printables_(printables) {
839 const std::pair<std::vector<var_decl>,
840 std::vector<statement> >& derived_data_decl,
841 const std::vector<var_decl>& parameter_decl,
842 const std::pair<std::vector<var_decl>,
843 std::vector<statement> >& derived_decl,
845 const std::pair<std::vector<var_decl>,
846 std::vector<statement> >& generated_decl)
847 : data_decl_(data_decl),
848 derived_data_decl_(derived_data_decl),
849 parameter_decl_(parameter_decl),
850 derived_decl_(derived_decl),
852 generated_decl_(generated_decl) {
876 : var_dims_(var_dims),
void add_ternary(const ::std::string &name)
void add(const std::string &name, const expr_type &result_type, const std::vector< expr_type > &arg_types)
void add_binary(const ::std::string &name)
void add_nullary(const ::std::string &name)
void add_quaternary(const ::std::string &name)
static function_signatures & instance()
int num_promotions(const std::vector< expr_type > &call_args, const std::vector< expr_type > &sig_args)
void add_unary(const ::std::string &name)
expr_type get_result_type(const std::string &name, const std::vector< expr_type > &args, std::ostream &error_msgs)
size_t total_dims(const std::vector< std::vector< expression > > &dimss)
const int parameter_origin
expr_type infer_type_indexing(const base_expr_type &expr_base_type, size_t num_expr_dims, size_t num_index_dims)
std::ostream & write_base_expr_type(std::ostream &o, base_expr_type type)
expr_type promote_primitive(const expr_type &et)
const int transformed_data_origin
bool is_nil(const expression &e)
std::pair< expr_type, std::vector< expr_type > > function_signature_t
const int transformed_parameter_origin
std::ostream & operator<<(std::ostream &o, const expr_type &et)
void print_var_origin(std::ostream &o, const var_origin &vo)
double e()
Return the base of the natural logarithm.
int max(const std::vector< int > &x)
Returns the maximum coefficient in the specified column vector.
double dist(const std::vector< double > &x, const std::vector< double > &y)
Probability, optimization and sampling library.
std::vector< expression > args_
array_literal & operator=(const array_literal &al)
std::vector< expression > dims_
base_expr_type base_type_
double_literal & operator=(const double_literal &dl)
bool is_primitive() const
bool operator==(const expr_type &et) const
bool operator!=(const expr_type &et) const
bool is_primitive_int() const
base_expr_type type() const
bool is_primitive_double() const
bool is_ill_formed() const
base_expr_type base_type_
expr_type operator()(const nil &e) const
expression & operator*=(const expression &rhs)
expr_type expression_type() const
expression & operator+=(const expression &rhs)
expression & operator/=(const expression &rhs)
expression & operator-=(const expression &rhs)
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< int_literal >, boost::recursive_wrapper< double_literal >, boost::recursive_wrapper< array_literal >, boost::recursive_wrapper< variable >, boost::recursive_wrapper< fun >, boost::recursive_wrapper< index_op >, boost::recursive_wrapper< binary_op >, boost::recursive_wrapper< unary_op > > expression_t
std::vector< std::vector< expression > > dimss_
int_literal & operator=(const int_literal &il)
bool operator()(const nil &x) const
std::string operator()(const nil &x) const
Placeholder struct for boost::variant default ctors.
positive_ordered_var_decl()
boost::variant< boost::recursive_wrapper< std::string >, boost::recursive_wrapper< expression > > printable_t
bool is_ill_formed() const
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< assignment >, boost::recursive_wrapper< sample >, boost::recursive_wrapper< statements >, boost::recursive_wrapper< for_statement >, boost::recursive_wrapper< conditional_statement >, boost::recursive_wrapper< while_statement >, boost::recursive_wrapper< print_statement >, boost::recursive_wrapper< no_op_statement > > statement_t
unary_op(char op, expression const &subject)
boost::variant< boost::recursive_wrapper< nil >, boost::recursive_wrapper< int_var_decl >, boost::recursive_wrapper< double_var_decl >, boost::recursive_wrapper< vector_var_decl >, boost::recursive_wrapper< row_vector_var_decl >, boost::recursive_wrapper< matrix_var_decl >, boost::recursive_wrapper< simplex_var_decl >, boost::recursive_wrapper< ordered_var_decl >, boost::recursive_wrapper< positive_ordered_var_decl >, boost::recursive_wrapper< cov_matrix_var_decl >, boost::recursive_wrapper< corr_matrix_var_decl > > var_decl_t
base_expr_type get_base_type(const std::string &name) const
std::map< std::string, range_t > map_
void add(const std::string &name, const base_var_decl &base_decl, const var_origin &vo)
std::pair< base_var_decl, var_origin > range_t
base_var_decl get(const std::string &name) const
var_origin get_origin(const std::string &name) const
void remove(const std::string &name)
bool exists(const std::string &name) const
size_t get_num_dims(const std::string &name) const
void set_type(const base_expr_type &base_type, size_t num_dims)