mlpack  2.0.1
softmax_regression.hpp
Go to the documentation of this file.
1 
14 #ifndef __MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
15 #define __MLPACK_METHODS_SOFTMAX_REGRESSION_SOFTMAX_REGRESSION_HPP
16 
17 #include <mlpack/core.hpp>
19 
21 
22 namespace mlpack {
23 namespace regression {
24 
// Softmax regression classifier, templated on the optimizer used for training.
// OptimizerType is a template taking one type argument; it defaults to
// mlpack::optimization::L_BFGS and is instantiated below as
// OptimizerType<SoftmaxRegressionFunction>.
//
// NOTE(review): this listing carries the documentation page's own line numbers
// at the start of each line, and the class-head line (presumably
// `class SoftmaxRegression`, judging by the constructor names) is not present
// between the template header and the opening brace -- confirm against the
// real header before editing the code itself.
64 template<
65  template<typename> class OptimizerType = mlpack::optimization::L_BFGS
66 >
68 {
69  public:
// Initialize the SoftmaxRegression without performing training.
//
// inputSize, numClasses: sizes supplied by the caller (presumably the input
//     dimensionality and the number of classes -- TODO confirm in the impl).
// fitIntercept: intercept term flag; defaults to false.
79  SoftmaxRegression(const size_t inputSize,
80  const size_t numClasses,
81  const bool fitIntercept = false);
82 
// Construct from a data matrix and its labels.
// NOTE(review): presumably trains the model on the given data; the definition
// is in the implementation file and is not visible here -- confirm.
//
// data: feature matrix.
// labels: one size_t label per data point (row vector).
// numClasses: number of classes.
// lambda: L2-regularization constant; defaults to 0.0001.
// fitIntercept: intercept term flag; defaults to false.
96  SoftmaxRegression(const arma::mat& data,
97  const arma::Row<size_t>& labels,
98  const size_t numClasses,
99  const double lambda = 0.0001,
100  const bool fitIntercept = false);
101 
// Construct from an already-built optimizer, taken by non-const reference.
// NOTE(review): presumably runs the optimizer to obtain the parameters --
// confirm in the implementation file. Not marked explicit, so this is also a
// converting constructor.
111  SoftmaxRegression(OptimizerType<SoftmaxRegressionFunction>& optimizer);
112 
// Predict the class labels for the provided feature points.
//
// testData: feature points to classify.
// predictions: output parameter (non-const reference) for the predicted labels.
121  void Predict(const arma::mat& testData, arma::Row<size_t>& predictions) const;
122 
// Computes accuracy of the learned model given the feature data and the labels
// associated with it. Returns a double; the scale (fraction vs. percentage) is
// not visible in this declaration. Note that, unlike Predict(), this member is
// not declared const.
131  double ComputeAccuracy(const arma::mat& testData, const arma::Row<size_t>& labels);
132 
// Train the softmax regression model with the given optimizer.
// Returns a double (presumably the final objective value -- TODO confirm).
141  double Train(OptimizerType<SoftmaxRegressionFunction>& optimizer);
142 
// Train overload taking the data matrix, its labels, and the number of classes
// directly instead of an optimizer. Returns a double (presumably the final
// objective value -- TODO confirm).
150  double Train(const arma::mat &data, const arma::Row<size_t>& labels,
151  const size_t numClasses);
152 
// Sets the number of classes (returns a modifiable reference).
154  size_t& NumClasses() { return numClasses; }
// Gets the number of classes.
156  size_t NumClasses() const { return numClasses; }
157 
// Sets the regularization parameter (returns a modifiable reference).
159  double& Lambda() { return lambda; }
// Gets the regularization parameter.
161  double Lambda() const { return lambda; }
162 
// Gets the intercept term flag. There is deliberately no setter: it can't be
// changed after training.
164  bool FitIntercept() const { return fitIntercept; }
165 
// Get the model parameters (modifiable reference).
167  arma::mat& Parameters() { return parameters; }
// Get the model parameters (read-only reference).
169  const arma::mat& Parameters() const { return parameters; }
170 
// Gets the feature size of the training data, derived from the column count of
// the parameter matrix. When fitIntercept is set, one column is excluded from
// the count (presumably the column holding the intercept term).
172  size_t FeatureSize() const
173  { return fitIntercept ? parameters.n_cols - 1 :
174  parameters.n_cols; }
175 
// Serialize the SoftmaxRegression model: archives parameters, numClasses,
// lambda and fitIntercept, in that order, each wrapped in a name-value pair.
// The version argument is unused.
// NOTE(review): CreateNVP is called unqualified and no using-declaration for it
// is visible in this listing -- confirm how it is brought into scope.
179  template<typename Archive>
180  void Serialize(Archive& ar, const unsigned int /* version */)
181  {
183 
184  ar & CreateNVP(parameters, "parameters");
185  ar & CreateNVP(numClasses, "numClasses");
186  ar & CreateNVP(lambda, "lambda");
187  ar & CreateNVP(fitIntercept, "fitIntercept");
188  }
189 
190  private:
// Parameters after optimization.
192  arma::mat parameters;
// Number of classes.
194  size_t numClasses;
// L2-regularization constant.
196  double lambda;
// NOTE(review): fitIntercept is read by FitIntercept(), FeatureSize() and
// Serialize() above, but its declaration does not appear in this listing
// (presumably `bool fitIntercept;` belongs here) -- confirm against the real
// header.
199 };
200 
201 } // namespace regression
202 } // namespace mlpack
203 
204 // Include implementation.
205 #include "softmax_regression_impl.hpp"
206 
207 #endif
double & Lambda()
Sets the regularization parameter.
Linear algebra utility functions, generally performed on matrices or vectors.
FirstShim< T > CreateNVP(T &t, const std::string &name, typename boost::enable_if< HasSerialize< T >>::type *=0)
Call this function to produce a name-value pair; this is similar to BOOST_SERIALIZATION_NVP(), but should be used for types that have a Serialize() function (or contain a type that has a Serialize() function) instead of a serialize() function.
arma::mat & Parameters()
Get the model parameters.
double ComputeAccuracy(const arma::mat &testData, const arma::Row< size_t > &labels)
Computes accuracy of the learned model given the feature data and the labels associated with each data point.
const arma::mat & Parameters() const
Get the model parameters.
double Train(OptimizerType< SoftmaxRegressionFunction > &optimizer)
Train the softmax regression model with the given optimizer.
Softmax Regression is a classifier which can be used for classification when the data available can take two or more class values.
size_t & NumClasses()
Sets the number of classes.
size_t NumClasses() const
Gets the number of classes.
double Lambda() const
Gets the regularization parameter.
bool FitIntercept() const
Gets the intercept term flag. We can't change this after training.
Include all of the base components required to write MLPACK methods, and the main MLPACK Doxygen documentation.
SoftmaxRegression(const size_t inputSize, const size_t numClasses, const bool fitIntercept=false)
Initialize the SoftmaxRegression without performing training.
void Serialize(Archive &ar, const unsigned int)
Serialize the SoftmaxRegression model.
double lambda
L2-regularization constant.
size_t FeatureSize() const
Gets the feature size of the training data.
void Predict(const arma::mat &testData, arma::Row< size_t > &predictions) const
Predict the class labels for the provided feature points.
arma::mat parameters
Parameters after optimization.
The generic L-BFGS optimizer, which uses a back-tracking line search algorithm to minimize a function.
Definition: lbfgs.hpp:36