mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
linear_regression.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
14 #define MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
15 
16 #include <mlpack/prereqs.hpp>
17 
18 namespace mlpack {
19 namespace regression {
20 
27 {
28  public:
38  LinearRegression(const arma::mat& predictors,
39  const arma::vec& responses,
40  const double lambda = 0,
41  const bool intercept = true,
42  const arma::vec& weights = arma::vec());
43 
49  LinearRegression(const LinearRegression& linearRegression);
50 
56  LinearRegression() : lambda(0.0), intercept(true) { }
57 
70  void Train(const arma::mat& predictors,
71  const arma::vec& responses,
72  const bool intercept = true,
73  const arma::vec& weights = arma::vec());
74 
81  void Predict(const arma::mat& points, arma::vec& predictions) const;
82 
100  double ComputeError(const arma::mat& points,
101  const arma::vec& responses) const;
102 
104  const arma::vec& Parameters() const { return parameters; }
106  arma::vec& Parameters() { return parameters; }
107 
109  double Lambda() const { return lambda; }
111  double& Lambda() { return lambda; }
112 
114  bool Intercept() const { return intercept; }
115 
119  template<typename Archive>
120  void Serialize(Archive& ar, const unsigned int /* version */)
121  {
122  ar & data::CreateNVP(parameters, "parameters");
123  ar & data::CreateNVP(lambda, "lambda");
124  ar & data::CreateNVP(intercept, "intercept");
125  }
126 
127  private:
132  arma::vec parameters;
133 
138  double lambda;
139 
141  bool intercept;
142 };
143 
144 } // namespace regression
145 } // namespace mlpack
146 
147 #endif // MLPACK_METHODS_LINEAR_REGRESSION_LINEAR_REGRESSION_HPP
const arma::vec & Parameters() const
Return the parameters (the b vector).
arma::vec & Parameters()
Modify the parameters (the b vector).
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair. It is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
A simple linear regression algorithm using ordinary least squares.
The core includes that mlpack expects; standard C++ includes and Armadillo.
void Serialize(Archive &ar, const unsigned int)
Serialize the model.
double Lambda() const
Return the Tikhonov regularization parameter for ridge regression.
double & Lambda()
Modify the Tikhonov regularization parameter for ridge regression.
void Train(const arma::mat &predictors, const arma::vec &responses, const bool intercept=true, const arma::vec &weights=arma::vec())
Train the LinearRegression model on the given data.
bool Intercept() const
Return whether or not an intercept term is used in the model.
void Predict(const arma::mat &points, arma::vec &predictions) const
Calculate y_i for each data point in points.
double ComputeError(const arma::mat &points, const arma::vec &responses) const
Calculate the L2 squared error on the given predictors and responses using this linear regression model.