MLPACK  1.0.7
gmm.hpp
Go to the documentation of this file.
1 
23 #ifndef __MLPACK_METHODS_MOG_MOG_EM_HPP
24 #define __MLPACK_METHODS_MOG_MOG_EM_HPP
25 
26 #include <mlpack/core.hpp>
27 
28 // This is the default fitting method class.
29 #include "em_fit.hpp"
30 
31 namespace mlpack {
32 namespace gmm {
33 
88 template<typename FittingType = EMFit<> >
89 class GMM
90 {
91  private:
93  size_t gaussians;
97  std::vector<arma::vec> means;
99  std::vector<arma::mat> covariances;
101  arma::vec weights;
102 
103  public:
107  GMM() :
108  gaussians(0),
109  dimensionality(0),
110  localFitter(FittingType()),
112  {
113  // Warn the user. They probably don't want to do this. If this constructor
114  // is being used (because it is required by some template classes), the user
115  // should know that it is potentially dangerous.
116  Log::Debug << "GMM::GMM(): no parameters given; Estimate() may fail "
117  << "unless parameters are set." << std::endl;
118  }
119 
127  GMM(const size_t gaussians, const size_t dimensionality) :
128  gaussians(gaussians),
129  dimensionality(dimensionality),
130  means(gaussians, arma::vec(dimensionality)),
131  covariances(gaussians, arma::mat(dimensionality, dimensionality)),
132  weights(gaussians),
133  localFitter(FittingType()),
134  fitter(localFitter) { /* Nothing to do. */ }
135 
146  GMM(const size_t gaussians,
147  const size_t dimensionality,
148  FittingType& fitter) :
149  gaussians(gaussians),
150  dimensionality(dimensionality),
151  means(gaussians, arma::vec(dimensionality)),
152  covariances(gaussians, arma::mat(dimensionality, dimensionality)),
153  weights(gaussians),
154  fitter(fitter) { /* Nothing to do. */ }
155 
163  GMM(const std::vector<arma::vec>& means,
164  const std::vector<arma::mat>& covariances,
165  const arma::vec& weights) :
166  gaussians(means.size()),
167  dimensionality((!means.empty()) ? means[0].n_elem : 0),
168  means(means),
169  covariances(covariances),
170  weights(weights),
171  localFitter(FittingType()),
172  fitter(localFitter) { /* Nothing to do. */ }
173 
183  GMM(const std::vector<arma::vec>& means,
184  const std::vector<arma::mat>& covariances,
185  const arma::vec& weights,
186  FittingType& fitter) :
187  gaussians(means.size()),
188  dimensionality((!means.empty()) ? means[0].n_elem : 0),
189  means(means),
190  covariances(covariances),
191  weights(weights),
192  fitter(fitter) { /* Nothing to do. */ }
193 
197  template<typename OtherFittingType>
198  GMM(const GMM<OtherFittingType>& other);
199 
204  GMM(const GMM& other);
205 
209  template<typename OtherFittingType>
210  GMM& operator=(const GMM<OtherFittingType>& other);
211 
216  GMM& operator=(const GMM& other);
217 
224  void Load(const std::string& filename);
225 
231  void Save(const std::string& filename) const;
232 
234  size_t Gaussians() const { return gaussians; }
237  size_t& Gaussians() { return gaussians; }
238 
240  size_t Dimensionality() const { return dimensionality; }
243  size_t& Dimensionality() { return dimensionality; }
244 
246  const std::vector<arma::vec>& Means() const { return means; }
248  std::vector<arma::vec>& Means() { return means; }
249 
251  const std::vector<arma::mat>& Covariances() const { return covariances; }
253  std::vector<arma::mat>& Covariances() { return covariances; }
254 
256  const arma::vec& Weights() const { return weights; }
258  arma::vec& Weights() { return weights; }
259 
261  const FittingType& Fitter() const { return fitter; }
263  FittingType& Fitter() { return fitter; }
264 
271  double Probability(const arma::vec& observation) const;
272 
280  double Probability(const arma::vec& observation,
281  const size_t component) const;
282 
289  arma::vec Random() const;
290 
306  double Estimate(const arma::mat& observations,
307  const size_t trials = 1);
308 
326  double Estimate(const arma::mat& observations,
327  const arma::vec& probabilities,
328  const size_t trials = 1);
329 
346  void Classify(const arma::mat& observations,
347  arma::Col<size_t>& labels) const;
348 
349  private:
359  double LogLikelihood(const arma::mat& dataPoints,
360  const std::vector<arma::vec>& means,
361  const std::vector<arma::mat>& covars,
362  const arma::vec& weights) const;
363 
365  FittingType localFitter;
366 
368  FittingType& fitter;
369 };
370 
371 }; // namespace gmm
372 }; // namespace mlpack
373 
374 // Include implementation.
375 #include "gmm_impl.hpp"
376 
377 #endif