mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
mean_shift.hpp
Go to the documentation of this file.
1 
13 #ifndef MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
14 #define MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
15 
16 #include <mlpack/prereqs.hpp>
20 #include <boost/utility.hpp>
21 
22 namespace mlpack {
23 namespace meanshift {
24 
/**
 * This class implements mean shift clustering.
 *
 * @tparam UseKernel Default value of the ApplyKernel template parameter that
 *     selects between the two private CalculateCentroid() overloads: the
 *     overload enabled when it is true receives the neighbor distances, the
 *     overload enabled when it is false ignores them.
 * @tparam KernelType Kernel type stored in this object and exposed through
 *     Kernel(); defaults to kernel::GaussianKernel (the standard Gaussian
 *     kernel).
 * @tparam MatType Matrix type of the input data; defaults to arma::mat.
 */
46 template<bool UseKernel = false,
47  typename KernelType = kernel::GaussianKernel,
48  typename MatType = arma::mat>
49 class MeanShift
50 {
51  public:
  /**
   * Create a mean shift object and set the parameters with which mean shift
   * will be run.
   *
   * @param radius Radius to use (default 0). NOTE(review): what a radius of 0
   *     means is decided in mean_shift_impl.hpp (possibly "estimate it from
   *     the data") -- confirm there.
   * @param maxIterations Maximum number of iterations (default 1000).
   * @param kernel Kernel object to store; taken by value, defaults to a
   *     default-constructed KernelType.
   */
63  MeanShift(const double radius = 0,
64  const size_t maxIterations = 1000,
65  const KernelType kernel = KernelType());
66 
  /**
   * Give an estimate of the radius based on the given dataset.
   *
   * @param data Dataset to estimate the radius from.
   * @param ratio Tuning ratio (default 0.2); how it is used is defined in
   *     mean_shift_impl.hpp -- TODO confirm its exact semantics.
   * @return The estimated radius.
   */
73  double EstimateRadius(const MatType& data, const double ratio = 0.2);
74 
  /**
   * Perform mean shift clustering on the data, returning a list of cluster
   * assignments and centroids.
   *
   * @param data Dataset to cluster.
   * @param assignments Output: cluster assignments (presumably one entry per
   *     data point -- confirm against the implementation).
   * @param centroids Output: the cluster centroids.
   * @param useSeeds Whether to start from seeds (default true); presumably
   *     the seeds produced by GenSeeds() -- confirm in mean_shift_impl.hpp.
   */
84  void Cluster(const MatType& data,
85  arma::Col<size_t>& assignments,
86  arma::mat& centroids,
87  bool useSeeds = true);
88 
  //! Get the maximum number of iterations.
90  size_t MaxIterations() const { return maxIterations; }
  //! Modify the maximum number of iterations.
92  size_t& MaxIterations() { return maxIterations; }
93 
  //! Get the radius.
95  double Radius() const { return radius; }
  //! Set the radius (defined out of line, not in this class body).
97  void Radius(double radius);
98 
  //! Get the kernel.
100  const KernelType& Kernel() const { return kernel; }
  //! Modify the kernel.
102  KernelType& Kernel() { return kernel; }
103 
104  private:
  /**
   * Generate seed points from the data, written to seeds.
   *
   * NOTE(review): judging only by the parameter names, this presumably bins
   * the points into bins of size binSize and keeps bins containing at least
   * minFreq points -- confirm against mean_shift_impl.hpp.
   *
   * @param data Dataset to generate seeds from.
   * @param binSize Bin size.
   * @param minFreq Minimum frequency threshold.
   * @param seeds Output matrix of seeds.
   */
118  void GenSeeds(const MatType& data,
119  const double binSize,
120  const int minFreq,
121  MatType& seeds);
122 
  /**
   * Compute a centroid from the given neighbors. This overload exists only
   * when ApplyKernel (default: UseKernel) is true, via std::enable_if, and it
   * receives the neighbor distances (presumably to weight each neighbor by
   * the kernel evaluated at its distance -- confirm in the implementation).
   *
   * @param data Dataset that the neighbor indices refer to.
   * @param neighbors Indices of the neighboring points.
   * @param distances Distances of the neighbors (presumably parallel to
   *     neighbors -- confirm).
   * @param centroid Output: the computed centroid.
   * @return A flag whose meaning is defined in mean_shift_impl.hpp
   *     (presumably whether a centroid could be computed) -- confirm.
   */
131  template<bool ApplyKernel = UseKernel>
132  typename std::enable_if<ApplyKernel, bool>::type
133  CalculateCentroid(const MatType& data,
134  const std::vector<size_t>& neighbors,
135  const std::vector<double>& distances,
136  arma::colvec& centroid);
137 
  /**
   * Compute a centroid from the given neighbors without using their
   * distances. This overload exists only when ApplyKernel (default:
   * UseKernel) is false, via std::enable_if. It keeps the same parameter
   * list as the other overload, but its distances argument is unnamed and
   * unused.
   *
   * @param data Dataset that the neighbor indices refer to.
   * @param neighbors Indices of the neighboring points.
   * @param centroid Output: the computed centroid.
   * @return A flag whose meaning is defined in mean_shift_impl.hpp
   *     (presumably whether a centroid could be computed) -- confirm.
   */
146  template<bool ApplyKernel = UseKernel>
147  typename std::enable_if<!ApplyKernel, bool>::type
148  CalculateCentroid(const MatType& data,
149  const std::vector<size_t>& neighbors,
150  const std::vector<double>&, /*unused*/
151  arma::colvec& centroid);
152 
  //! The radius, as returned by Radius().
158  double radius;
159 
  //! Maximum number of iterations, as returned by MaxIterations().
161  size_t maxIterations;
162 
  //! The kernel instance, as returned by Kernel().
164  KernelType kernel;
165 };
166 
167 } // namespace meanshift
168 } // namespace mlpack
169 
170 // Include implementation.
171 #include "mean_shift_impl.hpp"
172 
173 #endif // MLPACK_METHODS_MEAN_SHIFT_MEAN_SHIFT_HPP
The core includes that mlpack expects: standard C++ includes and Armadillo.
size_t & MaxIterations()
Set the maximum number of iterations.
Definition: mean_shift.hpp:92
This class implements mean shift clustering.
Definition: mean_shift.hpp:49
double EstimateRadius(const MatType &data, const double ratio=0.2)
Give an estimate of the radius based on the given dataset.
MeanShift(const double radius=0, const size_t maxIterations=1000, const KernelType kernel=KernelType())
Create a mean shift object and set the parameters with which mean shift will be run.
size_t MaxIterations() const
Get the maximum number of iterations.
Definition: mean_shift.hpp:90
void Cluster(const MatType &data, arma::Col< size_t > &assignments, arma::mat &centroids, bool useSeeds=true)
Perform mean shift clustering on the data, returning a list of cluster assignments and centroids...
const KernelType & Kernel() const
Get the kernel.
Definition: mean_shift.hpp:100
The standard Gaussian kernel.
KernelType & Kernel()
Modify the kernel.
Definition: mean_shift.hpp:102
double Radius() const
Get the radius.
Definition: mean_shift.hpp:95