mlpack  3.4.2
linear_regression.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
14 #define MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace regression {
20 
27 {
28  public:
/**
 * Creates the model.
 *
 * @param predictors X, the matrix of data points (NOTE(review): presumably
 *     one point per column, per mlpack convention — confirm).
 * @param responses y, the responses to fit (presumably one entry per point
 *     in predictors — confirm).
 * @param lambda Tikhonov regularization parameter for ridge regression
 *     (default 0).
 * @param intercept Whether or not to include an intercept term in the model
 *     (default true).
 */
37  LinearRegression(const arma::mat& predictors,
38  const arma::rowvec& responses,
39  const double lambda = 0,
40  const bool intercept = true);
41 
/**
 * Creates the model with weighted learning.
 *
 * @param predictors X, the matrix of data points (NOTE(review): presumably
 *     one point per column, per mlpack convention — confirm).
 * @param responses y, the responses to fit (presumably one entry per point
 *     in predictors — confirm).
 * @param weights Per-observation weights (presumably one entry per point —
 *     confirm).
 * @param lambda Tikhonov regularization parameter for ridge regression
 *     (default 0).
 * @param intercept Whether or not to include an intercept term in the model
 *     (default true).
 */
51  LinearRegression(const arma::mat& predictors,
52  const arma::rowvec& responses,
53  const arma::rowvec& weights,
54  const double lambda = 0,
55  const bool intercept = true);
56 
62  LinearRegression() : lambda(0.0), intercept(true) { }
63 
/**
 * Train the LinearRegression model on the given data.
 *
 * NOTE(review): this overload takes no lambda argument; presumably the
 * regularization value held by this object (see Lambda()) is used, and the
 * intercept argument presumably replaces the value reported by Intercept() —
 * confirm against the implementation.
 *
 * @param predictors X, the matrix of data points.
 * @param responses y, the responses to fit.
 * @param intercept Whether or not to include an intercept term (default true).
 * @return A double whose meaning is not visible in this header (presumably the
 *     training error, cf. ComputeError() — confirm).
 */
76  double Train(const arma::mat& predictors,
77  const arma::rowvec& responses,
78  const bool intercept = true);
79 
/**
 * Train the LinearRegression model on the given data and weights.
 *
 * NOTE(review): as with the unweighted overload, no lambda argument is taken;
 * presumably the value held by this object (see Lambda()) is used — confirm.
 *
 * @param predictors X, the matrix of data points.
 * @param responses y, the responses to fit.
 * @param weights Per-observation weights (presumably one entry per point —
 *     confirm).
 * @param intercept Whether or not to include an intercept term (default true).
 * @return A double whose meaning is not visible in this header (presumably the
 *     training error, cf. ComputeError() — confirm).
 */
93  double Train(const arma::mat& predictors,
94  const arma::rowvec& responses,
95  const arma::rowvec& weights,
96  const bool intercept = true);
97 
/**
 * Calculate y_i for each data point in points.
 *
 * @param points The data points to predict responses for.
 * @param predictions Output row vector receiving one prediction per point
 *     (NOTE(review): presumably resized by this call — confirm).
 */
104  void Predict(const arma::mat& points, arma::rowvec& predictions) const;
105 
/**
 * Calculate the L2 squared error on the given predictors and responses using
 * this linear regression model.
 *
 * @param points The data points to evaluate the model on.
 * @param responses The true responses for those points.
 * @return The L2 squared error (NOTE(review): whether it is summed or averaged
 *     over points is not visible in this header — confirm).
 */
123  double ComputeError(const arma::mat& points,
124  const arma::rowvec& responses) const;
125 
127  const arma::vec& Parameters() const { return parameters; }
129  arma::vec& Parameters() { return parameters; }
130 
132  double Lambda() const { return lambda; }
134  double& Lambda() { return lambda; }
135 
137  bool Intercept() const { return intercept; }
138 
142  template<typename Archive>
143  void serialize(Archive& ar, const unsigned int /* version */)
144  {
145  ar & BOOST_SERIALIZATION_NVP(parameters);
146  ar & BOOST_SERIALIZATION_NVP(lambda);
147  ar & BOOST_SERIALIZATION_NVP(intercept);
148  }
149 
150  private:
//! The parameters (the b vector); exposed through Parameters().
155  arma::vec parameters;
156 
//! The Tikhonov regularization parameter for ridge regression; 0.0 after the
//! empty constructor, exposed through Lambda().
161  double lambda;
162 
//! Whether or not an intercept term is used in the model; true after the
//! empty constructor, reported by Intercept().
164  bool intercept;
165 };
166 
167 } // namespace regression
168 } // namespace mlpack
169 
170 #endif // MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
prereqs.hpp
The core includes that mlpack expects; standard C++ includes and Armadillo.
mlpack::regression::LinearRegression::Parameters
arma::vec & Parameters()
Modify the parameters (the b vector).
Definition: linear_regression.hpp:129
mlpack::regression::LinearRegression::Parameters
const arma::vec & Parameters() const
Return the parameters (the b vector).
Definition: linear_regression.hpp:127
mlpack
Linear algebra utility functions, generally performed on matrices or vectors.
Definition: add_to_cli11.hpp:21
mlpack::regression::LinearRegression::Lambda
double Lambda() const
Return the Tikhonov regularization parameter for ridge regression.
Definition: linear_regression.hpp:132
mlpack::regression::LinearRegression::Train
double Train(const arma::mat &predictors, const arma::rowvec &responses, const arma::rowvec &weights, const bool intercept=true)
Train the LinearRegression model on the given data and weights.
mlpack::regression::LinearRegression::Intercept
bool Intercept() const
Return whether or not an intercept term is used in the model.
Definition: linear_regression.hpp:137
mlpack::regression::LinearRegression::serialize
void serialize(Archive &ar, const unsigned int)
Serialize the model.
Definition: linear_regression.hpp:143
mlpack::regression::LinearRegression::LinearRegression
LinearRegression(const arma::mat &predictors, const arma::rowvec &responses, const arma::rowvec &weights, const double lambda=0, const bool intercept=true)
Creates the model with weighted learning.
mlpack::regression::LinearRegression::Lambda
double & Lambda()
Modify the Tikhonov regularization parameter for ridge regression.
Definition: linear_regression.hpp:134
mlpack::regression::LinearRegression::Train
double Train(const arma::mat &predictors, const arma::rowvec &responses, const bool intercept=true)
Train the LinearRegression model on the given data.
mlpack::regression::LinearRegression::Predict
void Predict(const arma::mat &points, arma::rowvec &predictions) const
Calculate y_i for each data point in points.
mlpack::regression::LinearRegression::LinearRegression
LinearRegression(const arma::mat &predictors, const arma::rowvec &responses, const double lambda=0, const bool intercept=true)
Creates the model.
mlpack::regression::LinearRegression::ComputeError
double ComputeError(const arma::mat &points, const arma::rowvec &responses) const
Calculate the L2 squared error on the given predictors and responses using this linear regression model.
mlpack::regression::LinearRegression
A simple linear regression algorithm using ordinary least squares, with optional ridge (Tikhonov) regularization and optional per-observation weights.
Definition: linear_regression.hpp:27
mlpack::regression::LinearRegression::LinearRegression
LinearRegression()
Empty constructor.
Definition: linear_regression.hpp:62