From e7eb657ea6bbff5e140b0fde03414014b2cf96fa Mon Sep 17 00:00:00 2001
From: IsolatedSushi <simen.vanherpt@gmail.com>
Date: Sat, 27 Feb 2021 23:42:39 +0100
Subject: [PATCH] Almost finished chunked knn

---
 backend/grpcKNN/knnServer.py                  | 17 +++---
 backend/grpcKNN/knn_pb2.py                    | 61 ++++++++++++++++---
 backend/grpcKNN/knn_pb2_grpc.py               |  6 +-
 backend/webSocketGateway/protos/v3/knn.proto  |  8 ++-
 .../webSocketGateway/src/webSocketGateway.js  | 20 ++++--
 protos/protos/knn.proto                       |  8 ++-
 6 files changed, 94 insertions(+), 26 deletions(-)

diff --git a/backend/grpcKNN/knnServer.py b/backend/grpcKNN/knnServer.py
index 63298ca..632b4fe 100644
--- a/backend/grpcKNN/knnServer.py
+++ b/backend/grpcKNN/knnServer.py
@@ -16,7 +16,7 @@ class Data():
         self.index = None
 
 class KNNService(rpc.KNNServicer):
-        def __init__(self):
+    def __init__(self):
         self.clientList = []
         self.pointID = 0
 
@@ -84,7 +84,6 @@ class KNNService(rpc.KNNServicer):
 
     #Perform actual KNN
     def knnVector(self,vector,k,dataObject):
-        print("vector",vector)
         D, I = dataObject.index.search(np.asarray([vector]), k)
         words = [dataObject.ids[x] for x in I[0]]
         distances = D[0]
@@ -94,12 +93,16 @@ class KNNService(rpc.KNNServicer):
 
     def sendProjectionPoints(self, request_iterator, context):
         print("Connected")
-        for message in request_iterator:
-            self.storePoint(message,context)
+        try:
+            for trainingChunk in request_iterator:
+                for row in trainingChunk.rows:
+                    self.storePoint(row,context)
 
-            if self.pointID % 1000 == 0:
-                print("Received {} points!".format(self.pointID))
-            self.pointID+= 1
+                    if self.pointID % 1000 == 0:
+                        print("Received {} points!".format(self.pointID))
+                    self.pointID+= 1
+        except Exception as e:
+            print("Error",e)
 
 def serveServer():
     port = '[::]:50052'
diff --git a/backend/grpcKNN/knn_pb2.py b/backend/grpcKNN/knn_pb2.py
index 4f830a6..55189de 100644
--- a/backend/grpcKNN/knn_pb2.py
+++ b/backend/grpcKNN/knn_pb2.py
@@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor(
   syntax='proto3',
   serialized_options=b'\n\017nl.uuvig.proveeB\022ProveProjectorGRCPP\001\242\002\006PROVEE',
   create_key=_descriptor._internal_create_key,
-  serialized_pb=b'\n\tknn.proto\x12\x06provee\x1a\x1bgoogle/protobuf/empty.proto\"\x10\n\x02ID\x12\n\n\x02id\x18\x01 \x01(\t\"&\n\nknnRequest\x12\t\n\x01k\x18\x01 \x01(\x05\x12\r\n\x05words\x18\x02 \x01(\t\"\'\n\nNeighbours\x12\x19\n\x04rows\x18\x01 \x03(\x0b\x32\x0b.provee.Row\"#\n\x03Row\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08\x64istance\x18\x02 \x01(\x02\"2\n\x0eTrainingSetRow\x12\n\n\x02id\x18\x01 \x01(\t\x12\x14\n\x08hdvector\x18\x02 \x03(\x01\x42\x02\x10\x01\x32\xc5\x01\n\x03KNN\x12J\n\x14sendProjectionPoints\x12\x16.provee.TrainingSetRow\x1a\x16.google.protobuf.Empty\"\x00(\x01\x12\x39\n\rgetKNNRequest\x12\x12.provee.knnRequest\x1a\x12.provee.Neighbours\"\x00\x12\x37\n\x0fgetIDfromServer\x12\x16.google.protobuf.Empty\x1a\n.provee.ID\"\x00\x42\x30\n\x0fnl.uuvig.proveeB\x12ProveProjectorGRCPP\x01\xa2\x02\x06PROVEEb\x06proto3'
+  serialized_pb=b'\n\tknn.proto\x12\x06provee\x1a\x1bgoogle/protobuf/empty.proto\"\x10\n\x02ID\x12\n\n\x02id\x18\x01 \x01(\t\"&\n\nknnRequest\x12\t\n\x01k\x18\x01 \x01(\x05\x12\r\n\x05words\x18\x02 \x01(\t\"5\n\rTrainingChunk\x12$\n\x04rows\x18\x02 \x03(\x0b\x32\x16.provee.TrainingSetRow\"\'\n\nNeighbours\x12\x19\n\x04rows\x18\x01 \x03(\x0b\x32\x0b.provee.Row\"#\n\x03Row\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08\x64istance\x18\x02 \x01(\x02\"2\n\x0eTrainingSetRow\x12\n\n\x02id\x18\x01 \x01(\t\x12\x14\n\x08hdvector\x18\x02 \x03(\x01\x42\x02\x10\x01\x32\xc4\x01\n\x03KNN\x12I\n\x14sendProjectionPoints\x12\x15.provee.TrainingChunk\x1a\x16.google.protobuf.Empty\"\x00(\x01\x12\x39\n\rgetKNNRequest\x12\x12.provee.knnRequest\x1a\x12.provee.Neighbours\"\x00\x12\x37\n\x0fgetIDfromServer\x12\x16.google.protobuf.Empty\x1a\n.provee.ID\"\x00\x42\x30\n\x0fnl.uuvig.proveeB\x12ProveProjectorGRCPP\x01\xa2\x02\x06PROVEEb\x06proto3'
   ,
   dependencies=[google_dot_protobuf_dot_empty__pb2.DESCRIPTOR,])
 
@@ -98,6 +98,38 @@ _KNNREQUEST = _descriptor.Descriptor(
 )
 
 
+_TRAININGCHUNK = _descriptor.Descriptor(
+  name='TrainingChunk',
+  full_name='provee.TrainingChunk',
+  filename=None,
+  file=DESCRIPTOR,
+  containing_type=None,
+  create_key=_descriptor._internal_create_key,
+  fields=[
+    _descriptor.FieldDescriptor(
+      name='rows', full_name='provee.TrainingChunk.rows', index=0,
+      number=2, type=11, cpp_type=10, label=3,
+      has_default_value=False, default_value=[],
+      message_type=None, enum_type=None, containing_type=None,
+      is_extension=False, extension_scope=None,
+      serialized_options=None, file=DESCRIPTOR,  create_key=_descriptor._internal_create_key),
+  ],
+  extensions=[
+  ],
+  nested_types=[],
+  enum_types=[
+  ],
+  serialized_options=None,
+  is_extendable=False,
+  syntax='proto3',
+  extension_ranges=[],
+  oneofs=[
+  ],
+  serialized_start=108,
+  serialized_end=161,
+)
+
+
 _NEIGHBOURS = _descriptor.Descriptor(
   name='Neighbours',
   full_name='provee.Neighbours',
@@ -125,8 +157,8 @@ _NEIGHBOURS = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=108,
-  serialized_end=147,
+  serialized_start=163,
+  serialized_end=202,
 )
 
 
@@ -164,8 +196,8 @@ _ROW = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=149,
-  serialized_end=184,
+  serialized_start=204,
+  serialized_end=239,
 )
 
 
@@ -203,13 +235,15 @@ _TRAININGSETROW = _descriptor.Descriptor(
   extension_ranges=[],
   oneofs=[
   ],
-  serialized_start=186,
-  serialized_end=236,
+  serialized_start=241,
+  serialized_end=291,
 )
 
+_TRAININGCHUNK.fields_by_name['rows'].message_type = _TRAININGSETROW
 _NEIGHBOURS.fields_by_name['rows'].message_type = _ROW
 DESCRIPTOR.message_types_by_name['ID'] = _ID
 DESCRIPTOR.message_types_by_name['knnRequest'] = _KNNREQUEST
+DESCRIPTOR.message_types_by_name['TrainingChunk'] = _TRAININGCHUNK
 DESCRIPTOR.message_types_by_name['Neighbours'] = _NEIGHBOURS
 DESCRIPTOR.message_types_by_name['Row'] = _ROW
 DESCRIPTOR.message_types_by_name['TrainingSetRow'] = _TRAININGSETROW
@@ -229,6 +263,13 @@ knnRequest = _reflection.GeneratedProtocolMessageType('knnRequest', (_message.Me
   })
 _sym_db.RegisterMessage(knnRequest)
 
+TrainingChunk = _reflection.GeneratedProtocolMessageType('TrainingChunk', (_message.Message,), {
+  'DESCRIPTOR' : _TRAININGCHUNK,
+  '__module__' : 'knn_pb2'
+  # @@protoc_insertion_point(class_scope:provee.TrainingChunk)
+  })
+_sym_db.RegisterMessage(TrainingChunk)
+
 Neighbours = _reflection.GeneratedProtocolMessageType('Neighbours', (_message.Message,), {
   'DESCRIPTOR' : _NEIGHBOURS,
   '__module__' : 'knn_pb2'
@@ -261,15 +302,15 @@ _KNN = _descriptor.ServiceDescriptor(
   index=0,
   serialized_options=None,
   create_key=_descriptor._internal_create_key,
-  serialized_start=239,
-  serialized_end=436,
+  serialized_start=294,
+  serialized_end=490,
   methods=[
   _descriptor.MethodDescriptor(
     name='sendProjectionPoints',
     full_name='provee.KNN.sendProjectionPoints',
     index=0,
     containing_service=None,
-    input_type=_TRAININGSETROW,
+    input_type=_TRAININGCHUNK,
     output_type=google_dot_protobuf_dot_empty__pb2._EMPTY,
     serialized_options=None,
     create_key=_descriptor._internal_create_key,
diff --git a/backend/grpcKNN/knn_pb2_grpc.py b/backend/grpcKNN/knn_pb2_grpc.py
index 84f4bb8..3ebf004 100644
--- a/backend/grpcKNN/knn_pb2_grpc.py
+++ b/backend/grpcKNN/knn_pb2_grpc.py
@@ -18,7 +18,7 @@ class KNNStub(object):
         """
         self.sendProjectionPoints = channel.stream_unary(
                 '/provee.KNN/sendProjectionPoints',
-                request_serializer=knn__pb2.TrainingSetRow.SerializeToString,
+                request_serializer=knn__pb2.TrainingChunk.SerializeToString,
                 response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString,
                 )
         self.getKNNRequest = channel.unary_unary(
@@ -60,7 +60,7 @@ def add_KNNServicer_to_server(servicer, server):
     rpc_method_handlers = {
             'sendProjectionPoints': grpc.stream_unary_rpc_method_handler(
                     servicer.sendProjectionPoints,
-                    request_deserializer=knn__pb2.TrainingSetRow.FromString,
+                    request_deserializer=knn__pb2.TrainingChunk.FromString,
                     response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString,
             ),
             'getKNNRequest': grpc.unary_unary_rpc_method_handler(
@@ -96,7 +96,7 @@ class KNN(object):
             timeout=None,
             metadata=None):
         return grpc.experimental.stream_unary(request_iterator, target, '/provee.KNN/sendProjectionPoints',
-            knn__pb2.TrainingSetRow.SerializeToString,
+            knn__pb2.TrainingChunk.SerializeToString,
             google_dot_protobuf_dot_empty__pb2.Empty.FromString,
             options, channel_credentials,
             insecure, call_credentials, compression, wait_for_ready, timeout, metadata)
diff --git a/backend/webSocketGateway/protos/v3/knn.proto b/backend/webSocketGateway/protos/v3/knn.proto
index 2bdc51b..9189ac2 100644
--- a/backend/webSocketGateway/protos/v3/knn.proto
+++ b/backend/webSocketGateway/protos/v3/knn.proto
@@ -14,7 +14,7 @@ option objc_class_prefix = "PROVEE";
 service KNN {
 
 
-  rpc sendProjectionPoints(stream TrainingSetRow) returns (google.protobuf.Empty) {}
+  rpc sendProjectionPoints(stream TrainingChunk) returns (google.protobuf.Empty) {}
   rpc getKNNRequest(knnRequest) returns (Neighbours) {}
   rpc getIDfromServer(google.protobuf.Empty) returns (ID) {}
 }
@@ -28,6 +28,12 @@ message knnRequest{
   string words = 2;
 }
 
+
+// Needs description
+message TrainingChunk {
+  repeated TrainingSetRow rows = 2;
+}
+
 message Neighbours {
   repeated Row rows = 1;
 }
diff --git a/backend/webSocketGateway/src/webSocketGateway.js b/backend/webSocketGateway/src/webSocketGateway.js
index 09c06ce..d8dfca5 100644
--- a/backend/webSocketGateway/src/webSocketGateway.js
+++ b/backend/webSocketGateway/src/webSocketGateway.js
@@ -42,6 +42,13 @@ class GrpcConnection{
 wsServer.on('connection', function connection(ws) {
   console.log(`Client connected with websocket`);
   var currConnection = new ProjectionRequest(ws);
+
+
+  console.log("setup KNN");
+  currConnection.knnConn = getGRPCClient(knnTarget,KNNPackage,"KNN");      
+  getKNNConnection(currConnection)
+
+
   ws.on('message', function incoming(message) {
     parseMessage(message, currConnection);
   });
@@ -103,7 +110,12 @@ function getProjectorConnection(ws,client){
 
 function getKNNConnection(connection){
   connection.knnConn.client.getIDfromServer({},function(error,response){
+    if(error){
+      console.log("error")
+      console.log(error)
+    }
     if(!response){
+      
       return;
     }
     console.log(error);
@@ -116,6 +128,7 @@ function getKNNConnection(connection){
     });
 
     connection.knnConn.calls=[call]
+    console.log("Setup knn call")
   });
 
 }
@@ -168,9 +181,7 @@ function parseMessage(message, connection) {
       KNNNeighbourRequest(connection,words,k);
       break;
     case "getKNN":
-      console.log("setup KNN");
-      connection.knnConn = getGRPCClient(knnTarget,KNNPackage,"KNN");      
-      getKNNConnection(connection)
+      
       break;
     case "setProjectorAmount":
       var amount =  parseInt(jsonMessage["amount"]);
@@ -226,7 +237,8 @@ function sendRowToServer(allRows, lineIndex, connection) {
   }
   var trainingChunk = {rows: trainingRows};
   //Linearly distribute rows to the projectors
-  connection.projectorConn.calls[0].write(trainingChunk)
+  connection.projectorConn.calls[0].write(trainingChunk);
+  connection.knnConn.calls[0].write(trainingChunk);
 }
 
 //Send client to browser
diff --git a/protos/protos/knn.proto b/protos/protos/knn.proto
index 2bdc51b..9189ac2 100644
--- a/protos/protos/knn.proto
+++ b/protos/protos/knn.proto
@@ -14,7 +14,7 @@ option objc_class_prefix = "PROVEE";
 service KNN {
 
 
-  rpc sendProjectionPoints(stream TrainingSetRow) returns (google.protobuf.Empty) {}
+  rpc sendProjectionPoints(stream TrainingChunk) returns (google.protobuf.Empty) {}
   rpc getKNNRequest(knnRequest) returns (Neighbours) {}
   rpc getIDfromServer(google.protobuf.Empty) returns (ID) {}
 }
@@ -28,6 +28,12 @@ message knnRequest{
   string words = 2;
 }
 
+
+// Needs description
+message TrainingChunk {
+  repeated TrainingSetRow rows = 2;
+}
+
 message Neighbours {
   repeated Row rows = 1;
 }
-- 
GitLab