mlpack  2.2.5
 All Classes Namespaces Files Functions Variables Typedefs Enumerations Enumerator Friends Macros Pages
lsh_search.hpp
Go to the documentation of this file.
1 
43 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
44 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
45 
46 #include <mlpack/prereqs.hpp>
47 
50 
51 namespace mlpack {
52 namespace neighbor {
53 
61 template<typename SortPolicy = NearestNeighborSort>
62 class LSHSearch
63 {
64  public:
85  LSHSearch(const arma::mat& referenceSet,
86  const arma::cube& projections,
87  const double hashWidth = 0.0,
88  const size_t secondHashSize = 99901,
89  const size_t bucketSize = 500);
90 
112  LSHSearch(const arma::mat& referenceSet,
113  const size_t numProj,
114  const size_t numTables,
115  const double hashWidth = 0.0,
116  const size_t secondHashSize = 99901,
117  const size_t bucketSize = 500);
118 
123  LSHSearch();
124 
128  ~LSHSearch();
129 
154  void Train(const arma::mat& referenceSet,
155  const size_t numProj,
156  const size_t numTables,
157  const double hashWidth = 0.0,
158  const size_t secondHashSize = 99901,
159  const size_t bucketSize = 500,
160  const arma::cube& projection = arma::cube());
161 
183  void Search(const arma::mat& querySet,
184  const size_t k,
185  arma::Mat<size_t>& resultingNeighbors,
186  arma::mat& distances,
187  const size_t numTablesToSearch = 0,
188  const size_t T = 0);
189 
208  void Search(const size_t k,
209  arma::Mat<size_t>& resultingNeighbors,
210  arma::mat& distances,
211  const size_t numTablesToSearch = 0,
212  size_t T = 0);
213 
223  static double ComputeRecall(const arma::Mat<size_t>& foundNeighbors,
224  const arma::Mat<size_t>& realNeighbors);
225 
231  template<typename Archive>
232  void Serialize(Archive& ar, const unsigned int version);
233 
235  size_t DistanceEvaluations() const { return distanceEvaluations; }
237  size_t& DistanceEvaluations() { return distanceEvaluations; }
238 
240  const arma::mat& ReferenceSet() const { return *referenceSet; }
241 
243  size_t NumProjections() const { return projections.n_slices; }
244 
246  const arma::mat& Offsets() const { return offsets; }
247 
249  const arma::vec& SecondHashWeights() const { return secondHashWeights; }
250 
252  size_t BucketSize() const { return bucketSize; }
253 
255  const std::vector<arma::Col<size_t>>& SecondHashTable() const
256  { return secondHashTable; }
257 
259  const arma::cube& Projections() { return projections; }
260 
262  void Projections(const arma::cube& projTables)
263  {
264  // Simply call Train() with the given projection tables.
265  Train(*referenceSet, numProj, numTables, hashWidth, secondHashSize,
266  bucketSize, projTables);
267  }
268 
269  private:
285  template<typename VecType>
286  void ReturnIndicesFromTable(const VecType& queryPoint,
287  arma::uvec& referenceIndices,
288  size_t numTablesToSearch,
289  const size_t T) const;
290 
304  void BaseCase(const size_t queryIndex,
305  const arma::uvec& referenceIndices,
306  const size_t k,
307  arma::Mat<size_t>& neighbors,
308  arma::mat& distances) const;
309 
324  void BaseCase(const size_t queryIndex,
325  const arma::uvec& referenceIndices,
326  const size_t k,
327  const arma::mat& querySet,
328  arma::Mat<size_t>& neighbors,
329  arma::mat& distances) const;
330 
345  void GetAdditionalProbingBins(const arma::vec& queryCode,
346  const arma::vec& queryCodeNotFloored,
347  const size_t T,
348  arma::mat& additionalProbingBins) const;
349 
357  double PerturbationScore(const std::vector<bool>& A,
358  const arma::vec& scores) const;
359 
366  bool PerturbationShift(std::vector<bool>& A) const;
367 
375  bool PerturbationExpand(std::vector<bool>& A) const;
376 
383  bool PerturbationValid(const std::vector<bool>& A) const;
384 
385 
386 
388  const arma::mat* referenceSet;
390  bool ownsSet;
391 
393  size_t numProj;
395  size_t numTables;
396 
398  arma::cube projections; // should be [numProj x dims] x numTables slices
399 
401  arma::mat offsets; // should be numProj x numTables
402 
404  double hashWidth;
405 
407  size_t secondHashSize;
408 
410  arma::vec secondHashWeights;
411 
413  size_t bucketSize;
414 
417  std::vector<arma::Col<size_t>> secondHashTable;
418 
421  arma::Col<size_t> bucketContentSize;
422 
425  arma::Col<size_t> bucketRowInHashTable;
426 
428  size_t distanceEvaluations;
429 
431  typedef std::pair<double, size_t> Candidate;
432 
434  struct CandidateCmp {
435  bool operator()(const Candidate& c1, const Candidate& c2)
436  {
437  return !SortPolicy::IsBetter(c2.first, c1.first);
438  };
439  };
440 
442  typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
443  CandidateList;
444 
445 }; // class LSHSearch
446 
447 } // namespace neighbor
448 } // namespace mlpack
449 
451 BOOST_TEMPLATE_CLASS_VERSION(template<typename SortPolicy>,
453 
454 // Include implementation.
455 #include "lsh_search_impl.hpp"
456 
457 #endif
void Search(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given ...
#define BOOST_TEMPLATE_CLASS_VERSION(SIGNATURE, T, N)
Use this like BOOST_CLASS_VERSION(), but for templated classes.
~LSHSearch()
Clean memory.
void Train(const arma::mat &referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
const arma::cube & Projections()
Get the projection tables.
Definition: lsh_search.hpp:259
The core includes that mlpack expects; standard C++ includes and Armadillo.
LSHSearch()
Create an untrained LSH model.
const std::vector< arma::Col< size_t > > & SecondHashTable() const
Get the second hash table.
Definition: lsh_search.hpp:255
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the ...
Definition: lsh_search.hpp:62
size_t BucketSize() const
Get the bucket size of the second hash.
Definition: lsh_search.hpp:252
size_t DistanceEvaluations() const
Return the number of distance evaluations performed.
Definition: lsh_search.hpp:235
size_t NumProjections() const
Get the number of projections.
Definition: lsh_search.hpp:243
const arma::mat & Offsets() const
Get the offsets 'b' for each of the projections. (One 'b' per column.)
Definition: lsh_search.hpp:246
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "ground truth" set of neighbors.
const arma::mat & ReferenceSet() const
Return the reference dataset.
Definition: lsh_search.hpp:240
void Projections(const arma::cube &projTables)
Change the projection tables (this retrains the LSH model).
Definition: lsh_search.hpp:262
void Serialize(Archive &ar, const unsigned int version)
Serialize the LSH model.
const arma::vec & SecondHashWeights() const
Get the weights of the second hash.
Definition: lsh_search.hpp:249
size_t & DistanceEvaluations()
Modify the number of distance evaluations performed.
Definition: lsh_search.hpp:237