[Image: a picture of an android, hooked up to a series of tubes and wires]

Similarity search is better than most people give it credit for

Created: Jul 13, 2023

Last edited: Jul 16, 2023

Tags: #machine-learning #math #randomized-algorithms

On k-nearest neighbors

If you ever read an introductory machine learning textbook or take a course on the subject, one of the first classification algorithms that you are likely to learn about is k-nearest neighbors (kNN). The idea behind it is pretty straightforward: suppose that you have a dataset split into two different classes, and you are given a new point that you want to classify. To do so, you would find the \(k\) points that are closest to it, and classify the new point as belonging to the most common class among those points.
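To make that concrete, here is a minimal sketch of a brute-force kNN classifier in plain Python. The function name `knn_classify`, the toy data, and the default Euclidean metric are illustrative assumptions, not anything standardized.

```python
import math
from collections import Counter

def knn_classify(query, points, labels, k=3, dist=math.dist):
    """Classify `query` by majority vote among its k nearest labeled points."""
    # Rank every stored point by its distance to the query (a full linear scan).
    ranked = sorted(range(len(points)), key=lambda i: dist(query, points[i]))
    # Count the labels of the k closest points and return the most common one.
    votes = Counter(labels[i] for i in ranked[:k])
    return votes.most_common(1)[0][0]

# Toy dataset (made up for illustration): two well-separated clusters.
points = [(0.0, 0.0), (0.1, 0.2), (0.2, 0.1), (1.0, 1.0), (0.9, 1.1), (1.1, 0.9)]
labels = ["a", "a", "a", "b", "b", "b"]

print(knn_classify((0.05, 0.05), points, labels))
print(knn_classify((0.95, 1.0), points, labels))
```

Note that "training" here is nothing more than storing `points` and `labels`; all of the work happens at query time.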

kNN is a completely respectable algorithm, and still an active area of research. But in most ML education and discussions, kNN gets written off pretty quickly, because although training a kNN classifier is extremely cheap, classification takes (naively) \(O(Nd)\) time, where \(N\) is the size of the dataset and \(d\) is the dimensionality [1]. So it pleased me to see this paper, “‘Low-Resource’ Text Classification: A Parameter-Free Classification Method with Compressors”, making the rounds recently. The authors of this paper classify text using kNN over a compression-based string metric (the normalized compression distance), defined as

\[ \begin{aligned} d(x,y) = \frac{C(xy) - \min{\{C(x), C(y)\}}}{\max{\{C(x), C(y)\}}} \end{aligned} \]

where \(C(\cdot)\) is the length of its input after compression with gzip. The idea, roughly, is to measure how much more information it takes to describe the concatenated string \(xy\) than to describe \(x\) or \(y\) alone, in the sense of Kolmogorov complexity. Since Kolmogorov complexity isn’t actually computable, we use a lossless compression algorithm as a stand-in.
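The formula above translates almost directly into code. This is a minimal sketch using Python's standard `gzip` module; the helper names `clen` and `ncd` are made up for illustration, and the strings are joined by plain concatenation to match the formula as written here (the paper's own implementation may differ in such details).

```python
import gzip

def clen(s: str) -> int:
    """C(s): length in bytes of s after UTF-8 encoding and gzip compression."""
    return len(gzip.compress(s.encode("utf-8")))

def ncd(x: str, y: str) -> float:
    """Compression-based distance d(x, y) from the formula above."""
    cx, cy = clen(x), clen(y)
    cxy = clen(x + y)  # C(xy): compress the concatenation
    return (cxy - min(cx, cy)) / max(cx, cy)

a = "the quick brown fox jumps over the lazy dog"
b = "colorless green ideas sleep furiously at midnight"

# A string should be closer to itself than to an unrelated string.
print(ncd(a, a), ncd(a, b))
```

Because `ncd` is just a function of two strings, it can be plugged into any kNN routine that accepts a custom distance.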

It turns out that kNN, with this metric, is an extremely solid classifier on a range of tasks, even compared to many state-of-the-art BERT-based models. The authors demonstrate that, on the test set, it is quite competitive with most of the baselines it is compared against, and in fact it beats all of them on out-of-domain datasets [2].

One downside of the paper is that it never really addresses its original promise, which is specifically low-resource classification. The methods that are presented still take \(O(Nd)\) time per classification, which for a reasonably-sized dataset is still much more expensive than running a neural network. There are, however, a few different tricks that can be used to accelerate similarity search. k-d trees and ball trees reduce the impact of \(N\), while dimensionality reduction through the likes of PCA and random projections reduces the impact of \(d\).
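Of those tricks, random projection is the easiest to sketch in a few lines. The example below is a rough illustration, assuming a dense Gaussian projection matrix scaled so that squared lengths are preserved in expectation; the function names and the dimensions (100 down to 10) are arbitrary choices for the example.

```python
import math
import random

def random_projection_matrix(d, m, seed=0):
    """Sample an m x d matrix with i.i.d. Gaussian entries of variance 1/m."""
    rng = random.Random(seed)
    scale = 1.0 / math.sqrt(m)
    return [[rng.gauss(0.0, 1.0) * scale for _ in range(d)] for _ in range(m)]

def project(matrix, v):
    """Map a d-dimensional vector to m dimensions (one dot product per row)."""
    return [sum(a * b for a, b in zip(row, v)) for row in matrix]

R = random_projection_matrix(d=100, m=10)
x = [float(i % 7) for i in range(100)]
low_dim = project(R, x)
print(len(x), "->", len(low_dim))
```

Distances computed on the projected vectors only approximate the originals, so how small you can make the target dimension depends on how much distortion your classifier can tolerate.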

My personal favorite trick in this genre is locality-sensitive hashing, or LSH. An LSH family for a given similarity function is a family of randomized hash functions with the property that, for two inputs and a randomly-sampled hash function, the probability of a hash collision between those inputs increases the more similar they are to one another. Families of locality-sensitive hash functions are known to exist for a ton of different similarity functions, including:

  • Cosine similarity [3]
  • \(\ell^p\) distance (in particular, Manhattan and Euclidean distance) [4]
  • Jaccard similarity [5]
  • Inner product similarity [6]
  • Hamming distance [7]
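As a concrete instance of the defining property, here is a small sketch of SimHash, the family for cosine similarity. Each hash function is the sign of a dot product with a random Gaussian direction, and the experiment estimates the collision rate empirically so it can be compared against the \(1 - \theta/\pi\) figure quoted in the footnotes. The function names, the trial count, and the test vectors are assumptions made for the example.

```python
import math
import random

def make_simhash(dim, rng):
    """Sample one SimHash function: the sign bit of <r, v> for a random Gaussian r."""
    r = [rng.gauss(0.0, 1.0) for _ in range(dim)]
    return lambda v: int(sum(ri * vi for ri, vi in zip(r, v)) >= 0.0)

def collision_rate(x, y, trials=20000, seed=0):
    """Fraction of freshly sampled hash functions on which x and y collide."""
    rng = random.Random(seed)
    hits = 0
    for _ in range(trials):
        h = make_simhash(len(x), rng)
        if h(x) == h(y):
            hits += 1
    return hits / trials

x, y = (1.0, 0.0), (1.0, 1.0)  # 45 degrees apart
cos_xy = sum(a * b for a, b in zip(x, y)) / (math.hypot(*x) * math.hypot(*y))
predicted = 1.0 - math.acos(cos_xy) / math.pi  # 1 - theta / pi

print(collision_rate(x, y), predicted)
```

The empirical rate should land close to the predicted value, and it moves toward 1 as the two vectors are rotated toward each other.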

It’s also possible to use a hybrid approach to construct hash functions for other notions of similarity. For instance, you can use contrastive learning to train an embedding that maps inputs to some vector space, and then use SimHash to hash those vectors on their cosine similarity. You can also apply LSH even when your metric doesn’t have its own hash family, as long as you have another metric that’s correlated with the original. In that case, you’d use LSH with the second metric to reduce your original search space, and then run kNN with the original metric over the remaining points. A classic example of this is using Jaccard similarity with shingling as an initial, coarse-grained string similarity metric, and then Levenshtein distance as a fine(r)-grained metric for kNN.
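The coarse-then-fine pattern from that last example might look something like the sketch below. To keep it short, the first stage computes exact Jaccard similarity with a linear scan; in a real system that stage is where an LSH lookup (MinHash, for Jaccard) would go. The function names, the shingle size of 3, and the `min_jaccard` threshold are all assumptions made for illustration.

```python
def shingles(s, n=3):
    """Set of character n-grams of s (the whole string if it is shorter than n)."""
    return {s[i:i + n] for i in range(max(1, len(s) - n + 1))}

def jaccard(a, b):
    """Jaccard similarity |a & b| / |a | b| of two sets."""
    return len(a & b) / len(a | b) if (a or b) else 1.0

def levenshtein(a, b):
    """Edit distance via the standard row-by-row dynamic program."""
    prev = list(range(len(b) + 1))
    for i, ca in enumerate(a, 1):
        cur = [i]
        for j, cb in enumerate(b, 1):
            cur.append(min(prev[j] + 1,                 # delete ca
                           cur[j - 1] + 1,              # insert cb
                           prev[j - 1] + (ca != cb)))   # substitute or match
        prev = cur
    return prev[-1]

def two_stage_search(query, corpus, k=1, min_jaccard=0.2):
    q = shingles(query)
    # Stage 1 (coarse): drop strings whose shingle sets barely overlap the query's.
    # Stand-in for an LSH lookup; here it is an exact linear scan.
    candidates = [s for s in corpus if jaccard(q, shingles(s)) >= min_jaccard]
    # Stage 2 (fine): rank the survivors by edit distance and keep the k closest.
    return sorted(candidates, key=lambda s: levenshtein(query, s))[:k]

corpus = ["kitten", "sitting", "mitten", "banana", "bandana"]
print(two_stage_search("kitten", corpus, k=2))
```

The expensive metric only ever runs on the strings that survive the cheap filter, which is the whole point of the arrangement.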

The way that LSH works in practice is that you randomly generate a bunch of hash functions, construct a few hash tables (say, 8 tables with 16 hash functions each), and insert your database into each of those tables [8]. Then, to use kNN to classify a point, you would hash it with all of the hash functions you’ve constructed and look up the matching bucket in each table. At the end, your search space is reduced to just those points with which you got a hash collision in at least one of the tables you queried, and you run kNN against only those remaining candidates.
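Here is a rough sketch of that index for cosine similarity, using SimHash with 8 tables of 16 hash functions each. It assumes the usual construction: two points share a bucket in a table only if all 16 sign bits agree, and a point becomes a candidate if it shares a bucket with the query in at least one table. The class and method names are hypothetical, and `candidate_probability` additionally assumes the hash functions are independent.

```python
import math
import random
from collections import defaultdict

def candidate_probability(p, n_bits, n_tables):
    """Chance a point becomes a candidate, given single-hash collision probability p."""
    return 1.0 - (1.0 - p ** n_bits) ** n_tables

class SimHashIndex:
    def __init__(self, dim, n_tables=8, n_bits=16, seed=0):
        rng = random.Random(seed)
        # One list of n_bits random hyperplanes per table.
        self.planes = [
            [[rng.gauss(0.0, 1.0) for _ in range(dim)] for _ in range(n_bits)]
            for _ in range(n_tables)
        ]
        self.tables = [defaultdict(list) for _ in range(n_tables)]
        self.points = []

    @staticmethod
    def _key(planes, v):
        # Bucket key: the tuple of sign bits of v against each hyperplane.
        return tuple(sum(a * b for a, b in zip(plane, v)) >= 0.0 for plane in planes)

    def add(self, v):
        idx = len(self.points)
        self.points.append(v)
        for planes, table in zip(self.planes, self.tables):
            table[self._key(planes, v)].append(idx)
        return idx

    def candidates(self, q):
        # Union of the query's buckets across all tables.
        found = set()
        for planes, table in zip(self.planes, self.tables):
            found.update(table.get(self._key(planes, q), ()))
        return found

    def nearest(self, q, k=1):
        # Exact cosine-similarity ranking, restricted to the candidate set.
        def cossim(u, v):
            dot = sum(a * b for a, b in zip(u, v))
            return dot / (math.hypot(*u) * math.hypot(*v))
        ranked = sorted(self.candidates(q),
                        key=lambda i: cossim(q, self.points[i]), reverse=True)
        return ranked[:k]

index = SimHashIndex(dim=3)
a = index.add((1.0, 2.0, 3.0))
b = index.add((-1.0, -2.0, -3.0))
print(index.nearest((2.0, 4.0, 6.0), k=1))
```

Raising `n_bits` makes each table pickier, while raising `n_tables` gives similar points more chances to collide, so the two knobs together shape how sharply the filter cuts off.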

Comparing kNN with other learning algorithms

The resulting classifier that you get with kNN can be much faster and require less specialized hardware than what you’d get with e.g. a neural network. Moreover, the outputs of kNN are a lot more easily interpretable than what you’d get with other models; it’s easier to understand why your classifier came to a particular conclusion, as well as to characterize the failure conditions of your classifier.

The main downside, really, is that choosing a good similarity metric can be challenging for some datasets. It requires a lot more domain expertise than the more exotic classifiers, which are in large part plug-and-play. Choosing metrics that are easily compatible with LSH can be even more of a challenge, and if you’re at the point where you’re reaching for something like contrastive learning to generate an embedding where LSH would be feasible, you might (reasonably) want to try training a neural net anyway.

But at a fundamental level, you aren’t losing anything by using kNN; kNN is, at least in theory, equivalent (given the right choice of similarity function) to any other classifier in the machine learning grab-bag.

  1. In fact it can be much worse than that, depending on your choice of metric.

  2. See Tables 3 and 5 of the paper.

  3. Charikar, Moses S. (2002). Similarity estimation techniques from rounding algorithms. Proceedings of the Thirty-Fourth Annual ACM Symposium on Theory of Computing (STOC '02), pages 380–388. New York, NY, USA: Association for Computing Machinery. doi:10.1145/509907.509965.

  4. Datar, Mayur; Indyk, Piotr; Immorlica, Nicole; Mirrokni, Vahab. (2004). Locality-sensitive hashing scheme based on p-stable distributions. Proceedings of the Annual Symposium on Computational Geometry. doi:10.1145/997817.997857.

  5. Broder, A. (1997). On the resemblance and containment of documents. Compression and Complexity of Sequences: Proceedings, Positano, Amalfitan Coast, Salerno, Italy, June 11-13, 1997. doi:10.1109/SEQUEN.1997.666900.

  6. Shrivastava, Anshumali; Li, Ping. (2014). Improved asymmetric locality sensitive hashing (ALSH) for maximum inner product search (MIPS). arXiv preprint arXiv:1410.5410. https://arxiv.org/abs/1410.5410

  7. Indyk, Piotr; Motwani, Rajeev. (1998). Approximate Nearest Neighbors: Towards Removing the Curse of Dimensionality. Proceedings of 30th Symposium on Theory of Computing.

  8. The number of hash functions and tables you construct depends on how sharply you want the probability of a point passing the filter to fall off as similarity decreases. Each LSH family has some characterization of collision probability versus similarity which you can use to make this determination. For instance, for SimHash (the standard LSH family for cosine similarity), that probability for a single randomly-sampled hash function is \(1 - \theta/\pi\), where \(\theta = \arccos{(\text{cossim}(x,y))}\) is the angle between \(x\) and \(y\).