"""k-NN lookups over word2vec embeddings using a FAISS flat L2 index.

Loads a whitespace-separated embedding file (one word followed by its
vector components per line), builds an exact (brute-force) L2 index,
and exposes nearest-neighbour and word-analogy ("king - man + woman")
queries over it.
"""
import numpy as np
import faiss
from tabulate import tabulate

# Embedding file: each line is "<word> <f1> <f2> ... <fN>".
fileName = "word2vec_reddit_300_10000.txt"

# --- Load embeddings at import time (module-level state used by every query).
# NOTE(review): some word2vec text dumps begin with a "<count> <dim>" header
# line; this parser assumes there is none — confirm against the data file.
ids = []       # row index -> word
_rows = []     # row index -> vector components (as strings, converted below)
with open(fileName) as f:
    for line in f:
        # split() (no argument) handles the trailing newline, runs of
        # spaces, and lets us skip blank lines instead of crashing.
        parts = line.split()
        if not parts:
            continue
        ids.append(parts[0])
        _rows.append(parts[1:])

# word -> row index, for O(1) lookups.
wordToID = {word: row for row, word in enumerate(ids)}

# One vectorized conversion; FAISS requires float32, C-contiguous input.
allData = np.asarray(_rows, dtype=np.float32)

index = faiss.IndexFlatL2(allData.shape[1])  # build the exact L2 index
index.add(allData)                           # add all vectors to the index


def printClosest(indices, distances):
    """Pretty-print one query's neighbours.

    `indices` and `distances` are the 2-D arrays returned by
    `index.search`; only the first row (first query) is printed.
    """
    words = [ids[i] for i in indices[0]]
    print(tabulate(zip(words, distances[0]), headers=['Word', 'Distance']))


def knn(word, k):
    """Print the k nearest neighbours of `word`, or a notice if unknown."""
    if word not in wordToID:
        print(word, "not in the set")
        return
    knnVector(allData[wordToID[word]], k)


def knnVector(vector, k):
    """Print the k nearest neighbours of a raw embedding vector."""
    # search() expects a 2-D float32 batch; wrap the single query vector
    # and pin the dtype so callers may pass float64 arithmetic results.
    D, I = index.search(np.asarray([vector], dtype=np.float32), k)
    printClosest(I, D)


def knnSemantic(word1, word2, word3, k):
    """Analogy query: neighbours of word1 - word2 + word3.

    Classic example: king - man + woman ~ queen. Prints a notice and
    returns early if any of the three words is missing from the vocab.
    """
    if word1 not in wordToID or word2 not in wordToID or word3 not in wordToID:
        print("a word is not in the set")
        return
    vector1 = allData[wordToID[word1]]
    vector2 = allData[wordToID[word2]]
    vector3 = allData[wordToID[word3]]
    knnVector(vector1 - vector2 + vector3, k)


if __name__ == "__main__":
    # Demo only when run as a script — importing this module no longer
    # triggers a query (it still builds the index, as before).
    knnSemantic("king", "man", "woman", 10)