12 #ifndef MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
13 #define MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
21 namespace regression {
78 const size_t numClasses,
79 const bool fitIntercept =
false);
95 const arma::Row<size_t>& labels,
96 const size_t numClasses,
97 const double lambda = 0.0001,
98 const bool fitIntercept =
false);
123 arma::Row<size_t>& predictions)
const;
134 void Classify(
const arma::mat& dataset, arma::Row<size_t>& labels)
const;
145 const arma::Row<size_t>& labels)
const;
155 double Train(OptimizerType<SoftmaxRegressionFunction>& optimizer);
164 double Train(
const arma::mat &data,
const arma::Row<size_t>& labels,
165 const size_t numClasses);
187 {
return fitIntercept ? parameters.n_cols - 1 :
193 template<
typename Archive>
198 ar &
CreateNVP(parameters,
"parameters");
199 ar &
CreateNVP(numClasses,
"numClasses");
201 ar &
CreateNVP(fitIntercept,
"fitIntercept");
206 arma::mat parameters;
219 #include "softmax_regression_impl.hpp"
void Serialize(Archive &ar, const unsigned int)
Serialize the SoftmaxRegression model.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
The core includes that mlpack expects: standard C++ includes and Armadillo.
#define mlpack_deprecated
bool FitIntercept() const
Gets the intercept term flag. We can't change this after training.
mlpack_deprecated void Predict(const arma::mat &testData, arma::Row< size_t > &predictions) const
Predict the class labels for the provided feature points.
size_t FeatureSize() const
Gets the feature size (number of features) of the training data.
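Softmax Regression is a classifier which can be used for classification when the data available can take two or more class values. It is a generalization of Logistic Regression, which handles only binary classification.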
const arma::mat & Parameters() const
Get the model parameters.
arma::mat & Parameters()
Get the model parameters.
double ComputeAccuracy(const arma::mat &testData, const arma::Row< size_t > &labels) const
Computes the accuracy of the learned model given the feature data and the labels associated with each data point.
size_t NumClasses() const
Gets the number of classes.
double Lambda() const
Gets the regularization parameter.
double & Lambda()
Sets the regularization parameter.
SoftmaxRegression(const size_t inputSize, const size_t numClasses, const bool fitIntercept=false)
Initialize the SoftmaxRegression without performing training.
size_t & NumClasses()
Sets the number of classes.
The generic L-BFGS optimizer, which uses a back-tracking line search algorithm to minimize a function.
double Train(OptimizerType< SoftmaxRegressionFunction > &optimizer)
Train the softmax regression model with the given optimizer.
void Classify(const arma::mat &dataset, arma::Row< size_t > &labels) const
Classify the given points, returning the predicted labels for each point.