mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
ns_model.hpp
Go to the documentation of this file.
1 
15 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
16 #define MLPACK_METHODS_NEIGHBOR_SEARCH_NS_MODEL_HPP
17 
23 #include <boost/variant.hpp>
24 #include "neighbor_search.hpp"
25 
26 namespace mlpack {
27 namespace neighbor {
28 
32 template<typename SortPolicy,
33  template<typename TreeMetricType,
34  typename TreeStatType,
35  typename TreeMatType> class TreeType>
36 using NSType = NeighborSearch<SortPolicy,
38  arma::mat,
39  TreeType,
41  NeighborSearchStat<SortPolicy>,
42  arma::mat>::template DualTreeTraverser>;
43 
44 template<typename SortPolicy>
46 {
47  static const std::string Name() { return "neighbor_search_model"; }
48 };
49 
50 template<>
52 {
53  static const std::string Name() { return "nearest_neighbor_search_model"; }
54 };
55 
56 template<>
58 {
59  static const std::string Name() { return "furthest_neighbor_search_model"; }
60 };
61 
66 class MonoSearchVisitor : public boost::static_visitor<void>
67 {
68  private:
70  const size_t k;
72  arma::Mat<size_t>& neighbors;
74  arma::mat& distances;
75 
76  public:
78  template<typename NSType>
79  void operator()(NSType* ns) const;
80 
82  MonoSearchVisitor(const size_t k,
83  arma::Mat<size_t>& neighbors,
84  arma::mat& distances) :
85  k(k),
86  neighbors(neighbors),
87  distances(distances)
88  {};
89 };
90 
97 template<typename SortPolicy>
98 class BiSearchVisitor : public boost::static_visitor<void>
99 {
100  private:
102  const arma::mat& querySet;
104  const size_t k;
106  arma::Mat<size_t>& neighbors;
108  arma::mat& distances;
110  const size_t leafSize;
112  const double tau;
114  const double rho;
115 
117  template<typename NSType>
118  void SearchLeaf(NSType* ns) const;
119 
120  public:
122  template<template<typename TreeMetricType,
123  typename TreeStatType,
124  typename TreeMatType> class TreeType>
126 
128  template<template<typename TreeMetricType,
129  typename TreeStatType,
130  typename TreeMatType> class TreeType>
131  void operator()(NSTypeT<TreeType>* ns) const;
132 
134  void operator()(NSTypeT<tree::KDTree>* ns) const;
135 
137  void operator()(NSTypeT<tree::BallTree>* ns) const;
138 
140  void operator()(SpillKNN* ns) const;
141 
143  void operator()(NSTypeT<tree::Octree>* ns) const;
144 
146  BiSearchVisitor(const arma::mat& querySet,
147  const size_t k,
148  arma::Mat<size_t>& neighbors,
149  arma::mat& distances,
150  const size_t leafSize,
151  const double tau,
152  const double rho);
153 };
154 
161 template<typename SortPolicy>
162 class TrainVisitor : public boost::static_visitor<void>
163 {
164  private:
166  arma::mat&& referenceSet;
168  size_t leafSize;
170  const double tau;
172  const double rho;
173 
175  template<typename NSType>
176  void TrainLeaf(NSType* ns) const;
177 
178  public:
180  template<template<typename TreeMetricType,
181  typename TreeStatType,
182  typename TreeMatType> class TreeType>
184 
186  template<template<typename TreeMetricType,
187  typename TreeStatType,
188  typename TreeMatType> class TreeType>
189  void operator()(NSTypeT<TreeType>* ns) const;
190 
192  void operator()(NSTypeT<tree::KDTree>* ns) const;
193 
195  void operator()(NSTypeT<tree::BallTree>* ns) const;
196 
198  void operator()(SpillKNN* ns) const;
199 
201  void operator()(NSTypeT<tree::Octree>* ns) const;
202 
205  TrainVisitor(arma::mat&& referenceSet,
206  const size_t leafSize,
207  const double tau,
208  const double rho);
209 };
210 
214 class SearchModeVisitor : public boost::static_visitor<NeighborSearchMode>
215 {
216  public:
218  template<typename NSType>
220 };
221 
225 class SetSearchModeVisitor : public boost::static_visitor<void>
226 {
227  NeighborSearchMode searchMode;
228  public:
231  searchMode(searchMode)
232  {};
233 
235  template<typename NSType>
236  void operator()(NSType* ns) const;
237 };
238 
242 class EpsilonVisitor : public boost::static_visitor<double&>
243 {
244  public:
246  template<typename NSType>
247  double& operator()(NSType *ns) const;
248 };
249 
253 class ReferenceSetVisitor : public boost::static_visitor<const arma::mat&>
254 {
255  public:
257  template<typename NSType>
258  const arma::mat& operator()(NSType *ns) const;
259 };
260 
264 class DeleteVisitor : public boost::static_visitor<void>
265 {
266  public:
268  template<typename NSType>
269  void operator()(NSType *ns) const;
270 };
271 
282 template<typename SortPolicy>
283 class NSModel
284 {
285  public:
288  {
304  };
305 
306  private:
308  TreeTypes treeType;
309 
311  size_t leafSize;
312 
314  double tau;
316  double rho;
317 
319  bool randomBasis;
321  arma::mat q;
322 
328  boost::variant<NSType<SortPolicy, tree::KDTree>*,
340  SpillKNN*,
343 
344  public:
349  NSModel(TreeTypes treeType = TreeTypes::KD_TREE, bool randomBasis = false);
350 
352  ~NSModel();
353 
355  template<typename Archive>
356  void Serialize(Archive& ar, const unsigned int /* version */);
357 
359  const arma::mat& Dataset() const;
360 
364  void SetSearchMode(const NeighborSearchMode mode);
365 
367  double Epsilon() const;
368  double& Epsilon();
369 
371  size_t LeafSize() const { return leafSize; }
372  size_t& LeafSize() { return leafSize; }
373 
375  double Tau() const { return tau; }
376  double& Tau() { return tau; }
377 
379  double Rho() const { return rho; }
380  double& Rho() { return rho; }
381 
383  TreeTypes TreeType() const { return treeType; }
384  TreeTypes& TreeType() { return treeType; }
385 
387  bool RandomBasis() const { return randomBasis; }
388  bool& RandomBasis() { return randomBasis; }
389 
391  void BuildModel(arma::mat&& referenceSet,
392  const size_t leafSize,
393  const NeighborSearchMode searchMode,
394  const double epsilon = 0);
395 
397  void Search(arma::mat&& querySet,
398  const size_t k,
399  arma::Mat<size_t>& neighbors,
400  arma::mat& distances);
401 
403  void Search(const size_t k,
404  arma::Mat<size_t>& neighbors,
405  arma::mat& distances);
406 
408  std::string TreeName() const;
409 };
410 
411 } // namespace neighbor
412 } // namespace mlpack
413 
415 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
417 
418 // Include implementation.
419 #include "ns_model_impl.hpp"
420 
421 #endif
TreeTypes TreeType() const
Expose treeType.
Definition: ns_model.hpp:383
double Epsilon() const
Expose Epsilon.
#define BOOST_TEMPLATE_CLASS_VERSION(SIGNATURE, T, N)
Use this like BOOST_CLASS_VERSION(), but for templated classes.
MonoSearchVisitor(const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Construct the MonoSearchVisitor object with the given parameters.
Definition: ns_model.hpp:82
EpsilonVisitor exposes the Epsilon method of the given NSType.
Definition: ns_model.hpp:242
std::string TreeName() const
Return a string representation of the current tree type.
TrainVisitor(arma::mat &&referenceSet, const size_t leafSize, const double tau, const double rho)
Construct the TrainVisitor object with the given reference set, leafSize for BinarySpaceTrees, and tau and rho for spill trees.
void operator()(NSType *ns) const
Delete the NSType object.
TreeTypes
Enum type to identify each accepted tree type.
Definition: ns_model.hpp:287
SetSearchModeVisitor modifies the search mode of the given NSType.
Definition: ns_model.hpp:225
ReferenceSetVisitor exposes the referenceSet of the given NSType.
Definition: ns_model.hpp:253
SearchModeVisitor exposes the SearchMode() method of the given NSType.
Definition: ns_model.hpp:214
The NeighborSearch class is a template class for performing distance-based neighbor searches.
void operator()(NSTypeT< TreeType > *ns) const
Default training on the given NSType instance.
double Rho() const
Expose rho.
Definition: ns_model.hpp:379
NeighborSearchMode SearchMode() const
Access the search mode.
NeighborSearch< SortPolicy, metric::EuclideanDistance, arma::mat, TreeType, TreeType< metric::EuclideanDistance, NeighborSearchStat< SortPolicy >, arma::mat >::template DualTreeTraverser > NSType
Alias template for euclidean neighbor search.
Definition: ns_model.hpp:42
void operator()(NSType *ns) const
Perform monochromatic nearest neighbor search.
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
~NSModel()
Clean memory, if necessary.
This class implements the necessary methods for the SortPolicy template parameter of the NeighborSearch class.
static const std::string Name()
Definition: ns_model.hpp:47
BiSearchVisitor executes a bichromatic neighbor search on the given NSType.
Definition: ns_model.hpp:98
const arma::mat & operator()(NSType *ns) const
Return the reference set.
double & operator()(NSType *ns) const
Return epsilon, the approximation parameter.
void Serialize(Archive &ar, const unsigned int)
Serialize the neighbor search model.
TreeTypes & TreeType()
Definition: ns_model.hpp:384
The NSModel class provides an easy way to serialize a model, abstracts away the different types of trees, and also reflects the NeighborSearch API.
Definition: ns_model.hpp:283
NSModel(TreeTypes treeType=TreeTypes::KD_TREE, bool randomBasis=false)
Initialize the NSModel with the given type and whether or not a random basis should be used...
SetSearchModeVisitor(const NeighborSearchMode searchMode)
Construct the SetSearchModeVisitor object with the given mode.
Definition: ns_model.hpp:230
double Tau() const
Expose tau.
Definition: ns_model.hpp:375
TrainVisitor replaces the reference set of the given NSType with a new reference set.
void operator()(NSTypeT< TreeType > *ns) const
Default bichromatic neighbor search on the given NSType instance.
void Search(arma::mat &&querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances)
Perform neighbor search. The query set will be reordered.
BiSearchVisitor(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &neighbors, arma::mat &distances, const size_t leafSize, const double tau, const double rho)
Construct the BiSearchVisitor.
NeighborSearchMode operator()(NSType *ns) const
Return the search mode.
MonoSearchVisitor executes a monochromatic neighbor search on the given NSType.
Definition: ns_model.hpp:66
void BuildModel(arma::mat &&referenceSet, const size_t leafSize, const NeighborSearchMode searchMode, const double epsilon=0)
Build the reference tree.
size_t LeafSize() const
Expose leafSize.
Definition: ns_model.hpp:371
DeleteVisitor deletes the given NSType instance.
Definition: ns_model.hpp:264
LMetric< 2, true > EuclideanDistance
The Euclidean (L2) distance.
Definition: lmetric.hpp:112
void SetSearchMode(const NeighborSearchMode mode)
Modify the search mode.
NeighborSearchMode
NeighborSearchMode represents the different neighbor search modes available.
const arma::mat & Dataset() const
Expose the dataset.
bool RandomBasis() const
Expose randomBasis.
Definition: ns_model.hpp:387
void operator()(NSType *ns) const
Set the search mode.