MLPACK  1.0.7
em_fit.hpp
Go to the documentation of this file.
1 
23 #ifndef __MLPACK_METHODS_GMM_EM_FIT_HPP
24 #define __MLPACK_METHODS_GMM_EM_FIT_HPP
25 
26 #include <mlpack/core.hpp>
27 
28 // Default clustering mechanism.
30 // Default covariance matrix constraint.
32 
33 namespace mlpack {
34 namespace gmm {
35 
49 template<typename InitialClusteringType = kmeans::KMeans<>,
50  typename CovarianceConstraintPolicy = PositiveDefiniteConstraint>
51 class EMFit
52 {
53  public:
71  EMFit(const size_t maxIterations = 300,
72  const double tolerance = 1e-10,
73  InitialClusteringType clusterer = InitialClusteringType(),
74  CovarianceConstraintPolicy constraint = CovarianceConstraintPolicy());
75 
86  void Estimate(const arma::mat& observations,
87  std::vector<arma::vec>& means,
88  std::vector<arma::mat>& covariances,
89  arma::vec& weights);
90 
103  void Estimate(const arma::mat& observations,
104  const arma::vec& probabilities,
105  std::vector<arma::vec>& means,
106  std::vector<arma::mat>& covariances,
107  arma::vec& weights);
108 
110  const InitialClusteringType& Clusterer() const { return clusterer; }
112  InitialClusteringType& Clusterer() { return clusterer; }
113 
115  const CovarianceConstraintPolicy& Constraint() const { return constraint; }
117  CovarianceConstraintPolicy& Constraint() { return constraint; }
118 
120  size_t MaxIterations() const { return maxIterations; }
122  size_t& MaxIterations() { return maxIterations; }
123 
125  double Tolerance() const { return tolerance; }
127  double& Tolerance() { return tolerance; }
128 
129  private:
140  void InitialClustering(const arma::mat& observations,
141  std::vector<arma::vec>& means,
142  std::vector<arma::mat>& covariances,
143  arma::vec& weights);
144 
155  double LogLikelihood(const arma::mat& data,
156  const std::vector<arma::vec>& means,
157  const std::vector<arma::mat>& covariances,
158  const arma::vec& weights) const;
159 
163  double tolerance;
165  InitialClusteringType clusterer;
167  CovarianceConstraintPolicy constraint;
168 };
169 
170 }; // namespace gmm
171 }; // namespace mlpack
172 
173 // Include implementation.
174 #include "em_fit_impl.hpp"
175 
176 #endif