43 #ifndef MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
44 #define MLPACK_METHODS_NEIGHBOR_SEARCH_LSH_SEARCH_HPP
61 template<
typename SortPolicy = NearestNeighborSort>
86 const arma::cube& projections,
87 const double hashWidth = 0.0,
88 const size_t secondHashSize = 99901,
89 const size_t bucketSize = 500);
113 const size_t numProj,
114 const size_t numTables,
115 const double hashWidth = 0.0,
116 const size_t secondHashSize = 99901,
117 const size_t bucketSize = 500);
154 void Train(
const arma::mat& referenceSet,
155 const size_t numProj,
156 const size_t numTables,
157 const double hashWidth = 0.0,
158 const size_t secondHashSize = 99901,
159 const size_t bucketSize = 500,
160 const arma::cube& projection = arma::cube());
183 void Search(
const arma::mat& querySet,
185 arma::Mat<size_t>& resultingNeighbors,
186 arma::mat& distances,
187 const size_t numTablesToSearch = 0,
208 void Search(
const size_t k,
209 arma::Mat<size_t>& resultingNeighbors,
210 arma::mat& distances,
211 const size_t numTablesToSearch = 0,
223 static double ComputeRecall(
const arma::Mat<size_t>& foundNeighbors,
224 const arma::Mat<size_t>& realNeighbors);
/**
 * Serialize the LSH model.
 *
 * @param ar Archive to serialize to or from.
 * @param version Serialization version of the archive.
 */
template<typename Archive>
void Serialize(Archive& ar, const unsigned int version);
246 const arma::mat&
Offsets()
const {
return offsets; }
256 {
return secondHashTable; }
265 Train(*referenceSet, numProj, numTables, hashWidth, secondHashSize,
266 bucketSize, projTables);
285 template<
typename VecType>
286 void ReturnIndicesFromTable(
const VecType& queryPoint,
287 arma::uvec& referenceIndices,
288 size_t numTablesToSearch,
289 const size_t T)
const;
// NOTE(review): this declaration was damaged by documentation extraction.
// The leading integers (304, 305, ...) appear to be source-listing line
// numbers, not code, and the numbering skips from 305 to 307, so one
// parameter line seems to be missing. Restore from the original header
// before compiling.
//
// Compares a single query point (queryIndex) against the candidate reference
// points in referenceIndices and writes the results into neighbors and
// distances (both non-const references). NOTE(review): the exact update rule
// is not visible here -- confirm in lsh_search_impl.hpp.
304 void BaseCase(
const size_t queryIndex,
305 const arma::uvec& referenceIndices,
// NOTE(review): listing line 306 (a parameter) is missing here.
307 arma::Mat<size_t>& neighbors,
308 arma::mat& distances)
const;
// NOTE(review): this declaration was damaged by documentation extraction.
// The leading integers (324, 325, ...) appear to be source-listing line
// numbers, not code, and the numbering skips from 325 to 327, so one
// parameter line seems to be missing. Restore from the original header
// before compiling.
//
// Overload of BaseCase that takes an explicit querySet: compares query point
// queryIndex of querySet against the candidate reference points in
// referenceIndices and writes results into neighbors and distances.
// NOTE(review): the exact update rule is not visible here -- confirm.
324 void BaseCase(
const size_t queryIndex,
325 const arma::uvec& referenceIndices,
// NOTE(review): listing line 326 (a parameter) is missing here.
327 const arma::mat& querySet,
328 arma::Mat<size_t>& neighbors,
329 arma::mat& distances)
const;
// NOTE(review): this declaration was damaged by documentation extraction.
// The leading integers (345, 346, 348) appear to be source-listing line
// numbers, not code, and the numbering skips from 346 to 348, so one
// parameter line seems to be missing. Restore from the original header
// before compiling.
//
// Computes extra hash bins to probe for a query, writing them into
// additionalProbingBins (a non-const reference). queryCode and
// queryCodeNotFloored are presumably the query's hash code after and before
// flooring -- confirm in lsh_search_impl.hpp.
345 void GetAdditionalProbingBins(
const arma::vec& queryCode,
346 const arma::vec& queryCodeNotFloored,
// NOTE(review): listing line 347 (a parameter) is missing here.
348 arma::mat& additionalProbingBins)
const;
357 double PerturbationScore(
const std::vector<bool>& A,
358 const arma::vec& scores)
const;
366 bool PerturbationShift(std::vector<bool>& A)
const;
375 bool PerturbationExpand(std::vector<bool>& A)
const;
383 bool PerturbationValid(
const std::vector<bool>& A)
const;
388 const arma::mat* referenceSet;
398 arma::cube projections;
407 size_t secondHashSize;
410 arma::vec secondHashWeights;
417 std::vector<arma::Col<size_t>> secondHashTable;
421 arma::Col<size_t> bucketContentSize;
425 arma::Col<size_t> bucketRowInHashTable;
428 size_t distanceEvaluations;
//! A candidate neighbor.  NOTE(review): presumably (distance, reference
//! point index); `first` is what the candidate comparator orders on.
typedef std::pair<double, size_t> Candidate;
434 struct CandidateCmp {
435 bool operator()(
const Candidate& c1,
const Candidate& c2)
437 return !SortPolicy::IsBetter(c2.first, c1.first);
442 typedef std::priority_queue<Candidate, std::vector<Candidate>, CandidateCmp>
455 #include "lsh_search_impl.hpp"
void Search(const arma::mat &querySet, const size_t k, arma::Mat< size_t > &resultingNeighbors, arma::mat &distances, const size_t numTablesToSearch=0, const size_t T=0)
Compute the nearest neighbors of the points in the given query set and store the output in the given matrices.
#define BOOST_TEMPLATE_CLASS_VERSION(SIGNATURE, T, N)
Use this like BOOST_CLASS_VERSION(), but for templated classes.
~LSHSearch()
Clean up memory.
void Train(const arma::mat &referenceSet, const size_t numProj, const size_t numTables, const double hashWidth=0.0, const size_t secondHashSize=99901, const size_t bucketSize=500, const arma::cube &projection=arma::cube())
Train the LSH model on the given dataset.
const arma::cube & Projections()
Get the projection tables.
The core includes that mlpack expects; standard C++ includes and Armadillo.
LSHSearch()
Create an untrained LSH model.
const std::vector< arma::Col< size_t > > & SecondHashTable() const
Get the second hash table.
The LSHSearch class; this class builds a hash on the reference set and uses this hash to compute the approximate nearest neighbors of the given queries.
size_t BucketSize() const
Get the bucket size of the second hash.
size_t DistanceEvaluations() const
Return the number of distance evaluations performed.
size_t NumProjections() const
Get the number of projections.
const arma::mat & Offsets() const
Get the offsets 'b' for each of the projections. (One 'b' per column.)
static double ComputeRecall(const arma::Mat< size_t > &foundNeighbors, const arma::Mat< size_t > &realNeighbors)
Compute the recall (% of neighbors found) given the neighbors returned by LSHSearch::Search and a "ground truth" set of neighbors.
const arma::mat & ReferenceSet() const
Return the reference dataset.
void Projections(const arma::cube &projTables)
Change the projection tables (this retrains the LSH model).
void Serialize(Archive &ar, const unsigned int version)
Serialize the LSH model.
const arma::vec & SecondHashWeights() const
Get the weights of the second hash.
size_t & DistanceEvaluations()
Modify the number of distance evaluations performed.