Better than KNN: Approximate Nearest Neighbor (Introduction)

Goal

I assume you have heard of the k-Nearest Neighbor algorithm for classification problems (see Tutorial: K-Nearest Neighbor Model).

  • It’s one of the simplest classification algorithms
  • It does not require training time
  • It’s very slow at prediction time
  • It does not scale well to large training sets because every prediction requires scanning the entire data set.

There are very advanced algorithms for nearest neighbor search, such as Google’s ScaNN and Facebook’s Faiss.

These are approximate nearest neighbor (ANN) search algorithms: they speed up the search by sacrificing some margin of accuracy. The general approach is to build up an index ahead of time so that the system never has to scan the entire data set.

This article will go over a simple idea for implementing ANN search.

Hash tables

The first idea you need to understand is hash tables. A hash table is a kind of data structure where you can map a high-dimensional record (rich information) to a hash value.

For example, if you are familiar with the Iris Flower Dataset, the input dimension is 4 because there are 4 features:

  • sepal length in cm
  • sepal width in cm
  • petal length in cm
  • petal width in cm

If we have a function f(x1, x2, x3, x4) that returns a single int32, then it is considered a hash function.
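As a minimal sketch (my own illustration, not code from any library), here is one way such an f could look for the 4 Iris features. The quantize-and-pack scheme is just an assumption for the example; any deterministic mapping to an integer would qualify.

```python
def iris_hash(x1, x2, x3, x4):
    """Map a 4-feature Iris sample to a single 32-bit value by quantizing
    each feature (in cm) to 8 bits and packing the bits together.
    Purely illustrative -- any deterministic mapping to an int would do."""
    code = 0
    for value in (x1, x2, x3, x4):
        # Assume each measurement falls in [0, 10) cm, so 8 bits per feature.
        q = min(max(int(value / 10.0 * 256), 0), 255)
        code = (code << 8) | q
    return code & 0x7FFFFFFF  # stay within the positive int32 range

print(iris_hash(5.1, 3.5, 1.4, 0.2))
```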

Separating Hyperplanes

The next idea we need is the separating hyperplane. In the Iris Flower example, we can think of a hyperplane in the 4-dimensional space that separates the space into two halves.

Each hyperplane can be represented by its normal (orthogonal) vector.

By taking the dot product of a vector with the normal vector and checking the sign, we can tell whether the vector falls on the positive or negative side of the hyperplane:

  • dot_product(normal vector, v) > 0: v falls on the positive side of the hyperplane
  • dot_product(normal vector, v) < 0: v falls on the negative side of the hyperplane
  • dot_product(normal vector, v) = 0: v lies exactly on the hyperplane

The takeaway here is that checking which side of a hyperplane a vector falls on is quite straightforward: just take the dot product.
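Here is a small NumPy sketch of that check; the random normal vector and the helper name side_of_hyperplane are my own assumptions for illustration.

```python
import numpy as np

rng = np.random.default_rng(0)
normal = rng.normal(size=4)           # normal vector of one hyperplane through the origin

def side_of_hyperplane(x, normal):
    """Return +1, -1, or 0 for the positive side, negative side, or on the plane."""
    return int(np.sign(np.dot(x, normal)))

x = np.array([5.1, 3.5, 1.4, 0.2])    # one Iris sample
print(side_of_hyperplane(x, normal))  # +1 or -1 depending on the random hyperplane
```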

Using hyperplanes as a hashing mechanism

Since each hyperplane returns a binary value for a 4-dimensional input vector, we can use multiple hyperplanes to encode the individual bits of an integer. An int32, for example, can represent 32 hyperplanes.

Further, we can use multiple sets of hyperplanes to represent multiple hash functions.

Each region carved out by the separating hyperplanes is called a bucket under a given hash function.
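A minimal sketch of that encoding, assuming 32 random hyperplanes through the origin: each hyperplane contributes one bit, and the packed bits form the bucket id. The helper name make_hash_function is something I made up for illustration.

```python
import numpy as np

def make_hash_function(dim, n_planes=32, seed=0):
    """One hash function = one set of random hyperplane normal vectors."""
    rng = np.random.default_rng(seed)
    normals = rng.normal(size=(n_planes, dim))

    def hash_fn(x):
        # One bit per hyperplane: 1 on the positive side, 0 otherwise.
        bits = (normals @ np.asarray(x, dtype=float)) > 0
        code = 0
        for bit in bits:
            code = (code << 1) | int(bit)
        return code  # fits in 32 bits when n_planes == 32

    return hash_fn

h = make_hash_function(dim=4)
print(h([5.1, 3.5, 1.4, 0.2]))  # the bucket id for this Iris sample under this hash function
```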

Locality Sensitive Hashing (LSH)

Based on the hashing scheme above, points in the 4-dimensional input space that are close to each other are likely to fall in the same bucket. This means the hashing scheme is locality sensitive. We can leverage this kind of hashing scheme to build an approximate nearest neighbor search index. Note that the separating hyperplane is just one type of locality sensitive hashing; more sophisticated systems can use much more complicated schemes.

Approximate Nearest Neighbor Search

Combining all the ideas above, we can devise the algorithm by randomly generating a set of normal vectors for these hyperplanes, and thus constructing a set of hash functions.

To search for the nearest neighbor of a point in the high-dimensional input space, we calculate its hash value under each hash function and check only the other points in the data set that share the same hash value.

Note that the points sharing the same hash value are already precomputed in some index (e.g. a Python dictionary), so you don’t have to iterate over the entire data set to find them.
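A sketch of such an index with a plain Python dictionary (an assumed structure for illustration, not the actual code of ScaNN or Faiss): build time groups every row by its hash value, and query time scans only the matching bucket.

```python
import numpy as np
from collections import defaultdict

rng = np.random.default_rng(0)
data = rng.normal(size=(150, 4))     # stand-in for the 150 Iris rows
normals = rng.normal(size=(16, 4))   # 16 random hyperplanes = one hash function

def hash_fn(x):
    return sum(int(b) << i for i, b in enumerate((normals @ x) > 0))

# Precompute the index: bucket id -> list of row indices in that bucket.
index = defaultdict(list)
for i, x in enumerate(data):
    index[hash_fn(x)].append(i)

def ann_query(q):
    """Scan only the rows that share the query's bucket."""
    candidates = index.get(hash_fn(q), [])
    if not candidates:
        return None                  # the bucket can be empty for an unlucky query
    dists = [np.linalg.norm(data[i] - q) for i in candidates]
    return candidates[int(np.argmin(dists))]

print(ann_query(data[0] + 0.01))     # usually returns 0, the row closest to the query
```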

Since the number of vectors that share the same hash value is relatively small, this greatly speeds up the nearest neighbor search.

However, it’s not guaranteed to find the best neighbor, because the hash functions are generated randomly and they are not perfect. But by using multiple randomized hash functions, we should not be too far off.
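One common way to apply that idea, sketched here under my own assumptions, is to keep several independently generated hash tables and union their candidate buckets before the final distance check.

```python
import numpy as np
from collections import defaultdict

rng = np.random.default_rng(42)
data = rng.normal(size=(150, 4))                   # stand-in for the data set

def make_table(seed, n_planes=16):
    """Build one hash table: its random hyperplanes plus its bucket dictionary."""
    normals = np.random.default_rng(seed).normal(size=(n_planes, 4))

    def hash_fn(x):
        return sum(int(b) << i for i, b in enumerate((normals @ x) > 0))

    buckets = defaultdict(list)
    for i, x in enumerate(data):
        buckets[hash_fn(x)].append(i)
    return hash_fn, buckets

tables = [make_table(seed) for seed in range(5)]   # 5 independent hash functions

def ann_query(q):
    # Union the candidates across tables: the more tables, the higher the chance
    # that the true nearest neighbor shares a bucket with the query in at least one.
    candidates = set()
    for hash_fn, buckets in tables:
        candidates.update(buckets.get(hash_fn(q), []))
    if not candidates:
        return None
    return min(candidates, key=lambda i: np.linalg.norm(data[i] - q))

print(ann_query(data[0] + 0.01))                   # almost always returns 0 now
```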

Conclusion

We showed the fundamental idea of how ANN speeds up nearest neighbor search. This has wide applications when working with embedding spaces in ML and in production ML systems that need to serve in real time.
