From e7eb657ea6bbff5e140b0fde03414014b2cf96fa Mon Sep 17 00:00:00 2001 From: IsolatedSushi <simen.vanherpt@gmail.com> Date: Sat, 27 Feb 2021 23:42:39 +0100 Subject: [PATCH] Almost finished chunked knn --- backend/grpcKNN/knnServer.py | 17 +++--- backend/grpcKNN/knn_pb2.py | 61 ++++++++++++++++--- backend/grpcKNN/knn_pb2_grpc.py | 6 +- backend/webSocketGateway/protos/v3/knn.proto | 8 ++- .../webSocketGateway/src/webSocketGateway.js | 20 ++++-- protos/protos/knn.proto | 8 ++- 6 files changed, 94 insertions(+), 26 deletions(-) diff --git a/backend/grpcKNN/knnServer.py b/backend/grpcKNN/knnServer.py index 63298ca..632b4fe 100644 --- a/backend/grpcKNN/knnServer.py +++ b/backend/grpcKNN/knnServer.py @@ -16,7 +16,7 @@ class Data(): self.index = None class KNNService(rpc.KNNServicer): - def __init__(self): + def __init__(self): self.clientList = [] self.pointID = 0 @@ -84,7 +84,6 @@ class KNNService(rpc.KNNServicer): #Perform actual KNN def knnVector(self,vector,k,dataObject): - print("vector",vector) D, I = dataObject.index.search(np.asarray([vector]), k) words = [dataObject.ids[x] for x in I[0]] distances = D[0] @@ -94,12 +93,16 @@ class KNNService(rpc.KNNServicer): def sendProjectionPoints(self, request_iterator, context): print("Connected") - for message in request_iterator: - self.storePoint(message,context) + try: + for trainingChunk in request_iterator: + for row in trainingChunk.rows: + self.storePoint(row,context) - if self.pointID % 1000 == 0: - print("Received {} points!".format(self.pointID)) - self.pointID+= 1 + if self.pointID % 1000 == 0: + print("Received {} points!".format(self.pointID)) + self.pointID+= 1 + except Exception as e: + print("Error",e) def serveServer(): port = '[::]:50052' diff --git a/backend/grpcKNN/knn_pb2.py b/backend/grpcKNN/knn_pb2.py index 4f830a6..55189de 100644 --- a/backend/grpcKNN/knn_pb2.py +++ b/backend/grpcKNN/knn_pb2.py @@ -20,7 +20,7 @@ DESCRIPTOR = _descriptor.FileDescriptor( syntax='proto3', serialized_options=b'\n\017nl.uuvig.proveeB\022ProveProjectorGRCPP\001\242\002\006PROVEE', create_key=_descriptor._internal_create_key, - serialized_pb=b'\n\tknn.proto\x12\x06provee\x1a\x1bgoogle/protobuf/empty.proto\"\x10\n\x02ID\x12\n\n\x02id\x18\x01 \x01(\t\"&\n\nknnRequest\x12\t\n\x01k\x18\x01 \x01(\x05\x12\r\n\x05words\x18\x02 \x01(\t\"\'\n\nNeighbours\x12\x19\n\x04rows\x18\x01 \x03(\x0b\x32\x0b.provee.Row\"#\n\x03Row\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08\x64istance\x18\x02 \x01(\x02\"2\n\x0eTrainingSetRow\x12\n\n\x02id\x18\x01 \x01(\t\x12\x14\n\x08hdvector\x18\x02 \x03(\x01\x42\x02\x10\x01\x32\xc5\x01\n\x03KNN\x12J\n\x14sendProjectionPoints\x12\x16.provee.TrainingSetRow\x1a\x16.google.protobuf.Empty\"\x00(\x01\x12\x39\n\rgetKNNRequest\x12\x12.provee.knnRequest\x1a\x12.provee.Neighbours\"\x00\x12\x37\n\x0fgetIDfromServer\x12\x16.google.protobuf.Empty\x1a\n.provee.ID\"\x00\x42\x30\n\x0fnl.uuvig.proveeB\x12ProveProjectorGRCPP\x01\xa2\x02\x06PROVEEb\x06proto3' + serialized_pb=b'\n\tknn.proto\x12\x06provee\x1a\x1bgoogle/protobuf/empty.proto\"\x10\n\x02ID\x12\n\n\x02id\x18\x01 \x01(\t\"&\n\nknnRequest\x12\t\n\x01k\x18\x01 \x01(\x05\x12\r\n\x05words\x18\x02 \x01(\t\"5\n\rTrainingChunk\x12$\n\x04rows\x18\x02 \x03(\x0b\x32\x16.provee.TrainingSetRow\"\'\n\nNeighbours\x12\x19\n\x04rows\x18\x01 \x03(\x0b\x32\x0b.provee.Row\"#\n\x03Row\x12\n\n\x02id\x18\x01 \x01(\t\x12\x10\n\x08\x64istance\x18\x02 \x01(\x02\"2\n\x0eTrainingSetRow\x12\n\n\x02id\x18\x01 \x01(\t\x12\x14\n\x08hdvector\x18\x02 \x03(\x01\x42\x02\x10\x01\x32\xc4\x01\n\x03KNN\x12I\n\x14sendProjectionPoints\x12\x15.provee.TrainingChunk\x1a\x16.google.protobuf.Empty\"\x00(\x01\x12\x39\n\rgetKNNRequest\x12\x12.provee.knnRequest\x1a\x12.provee.Neighbours\"\x00\x12\x37\n\x0fgetIDfromServer\x12\x16.google.protobuf.Empty\x1a\n.provee.ID\"\x00\x42\x30\n\x0fnl.uuvig.proveeB\x12ProveProjectorGRCPP\x01\xa2\x02\x06PROVEEb\x06proto3' , dependencies=[google_dot_protobuf_dot_empty__pb2.DESCRIPTOR,]) @@ -98,6 +98,38 @@ _KNNREQUEST = _descriptor.Descriptor( ) +_TRAININGCHUNK = _descriptor.Descriptor( + name='TrainingChunk', + full_name='provee.TrainingChunk', + filename=None, + file=DESCRIPTOR, + containing_type=None, + create_key=_descriptor._internal_create_key, + fields=[ + _descriptor.FieldDescriptor( + name='rows', full_name='provee.TrainingChunk.rows', index=0, + number=2, type=11, cpp_type=10, label=3, + has_default_value=False, default_value=[], + message_type=None, enum_type=None, containing_type=None, + is_extension=False, extension_scope=None, + serialized_options=None, file=DESCRIPTOR, create_key=_descriptor._internal_create_key), + ], + extensions=[ + ], + nested_types=[], + enum_types=[ + ], + serialized_options=None, + is_extendable=False, + syntax='proto3', + extension_ranges=[], + oneofs=[ + ], + serialized_start=108, + serialized_end=161, +) + + _NEIGHBOURS = _descriptor.Descriptor( name='Neighbours', full_name='provee.Neighbours', @@ -125,8 +157,8 @@ _NEIGHBOURS = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=108, - serialized_end=147, + serialized_start=163, + serialized_end=202, ) @@ -164,8 +196,8 @@ _ROW = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=149, - serialized_end=184, + serialized_start=204, + serialized_end=239, ) @@ -203,13 +235,15 @@ _TRAININGSETROW = _descriptor.Descriptor( extension_ranges=[], oneofs=[ ], - serialized_start=186, - serialized_end=236, + serialized_start=241, + serialized_end=291, ) +_TRAININGCHUNK.fields_by_name['rows'].message_type = _TRAININGSETROW _NEIGHBOURS.fields_by_name['rows'].message_type = _ROW DESCRIPTOR.message_types_by_name['ID'] = _ID DESCRIPTOR.message_types_by_name['knnRequest'] = _KNNREQUEST +DESCRIPTOR.message_types_by_name['TrainingChunk'] = _TRAININGCHUNK DESCRIPTOR.message_types_by_name['Neighbours'] = _NEIGHBOURS DESCRIPTOR.message_types_by_name['Row'] = _ROW DESCRIPTOR.message_types_by_name['TrainingSetRow'] = _TRAININGSETROW @@ -229,6 +263,13 @@ knnRequest = _reflection.GeneratedProtocolMessageType('knnRequest', (_message.Me }) _sym_db.RegisterMessage(knnRequest) +TrainingChunk = _reflection.GeneratedProtocolMessageType('TrainingChunk', (_message.Message,), { + 'DESCRIPTOR' : _TRAININGCHUNK, + '__module__' : 'knn_pb2' + # @@protoc_insertion_point(class_scope:provee.TrainingChunk) + }) +_sym_db.RegisterMessage(TrainingChunk) + Neighbours = _reflection.GeneratedProtocolMessageType('Neighbours', (_message.Message,), { 'DESCRIPTOR' : _NEIGHBOURS, '__module__' : 'knn_pb2' @@ -261,15 +302,15 @@ _KNN = _descriptor.ServiceDescriptor( index=0, serialized_options=None, create_key=_descriptor._internal_create_key, - serialized_start=239, - serialized_end=436, + serialized_start=294, + serialized_end=490, methods=[ _descriptor.MethodDescriptor( name='sendProjectionPoints', full_name='provee.KNN.sendProjectionPoints', index=0, containing_service=None, - input_type=_TRAININGSETROW, + input_type=_TRAININGCHUNK, output_type=google_dot_protobuf_dot_empty__pb2._EMPTY, serialized_options=None, create_key=_descriptor._internal_create_key, diff --git a/backend/grpcKNN/knn_pb2_grpc.py b/backend/grpcKNN/knn_pb2_grpc.py index 84f4bb8..3ebf004 100644 --- a/backend/grpcKNN/knn_pb2_grpc.py +++ b/backend/grpcKNN/knn_pb2_grpc.py @@ -18,7 +18,7 @@ class KNNStub(object): """ self.sendProjectionPoints = channel.stream_unary( '/provee.KNN/sendProjectionPoints', - request_serializer=knn__pb2.TrainingSetRow.SerializeToString, + request_serializer=knn__pb2.TrainingChunk.SerializeToString, response_deserializer=google_dot_protobuf_dot_empty__pb2.Empty.FromString, ) self.getKNNRequest = channel.unary_unary( @@ -60,7 +60,7 @@ def add_KNNServicer_to_server(servicer, server): rpc_method_handlers = { 'sendProjectionPoints': grpc.stream_unary_rpc_method_handler( servicer.sendProjectionPoints, - request_deserializer=knn__pb2.TrainingSetRow.FromString, + request_deserializer=knn__pb2.TrainingChunk.FromString, response_serializer=google_dot_protobuf_dot_empty__pb2.Empty.SerializeToString, ), 'getKNNRequest': grpc.unary_unary_rpc_method_handler( @@ -96,7 +96,7 @@ class KNN(object): timeout=None, metadata=None): return grpc.experimental.stream_unary(request_iterator, target, '/provee.KNN/sendProjectionPoints', - knn__pb2.TrainingSetRow.SerializeToString, + knn__pb2.TrainingChunk.SerializeToString, google_dot_protobuf_dot_empty__pb2.Empty.FromString, options, channel_credentials, insecure, call_credentials, compression, wait_for_ready, timeout, metadata) diff --git a/backend/webSocketGateway/protos/v3/knn.proto b/backend/webSocketGateway/protos/v3/knn.proto index 2bdc51b..9189ac2 100644 --- a/backend/webSocketGateway/protos/v3/knn.proto +++ b/backend/webSocketGateway/protos/v3/knn.proto @@ -14,7 +14,7 @@ option objc_class_prefix = "PROVEE"; service KNN { - rpc sendProjectionPoints(stream TrainingSetRow) returns (google.protobuf.Empty) {} + rpc sendProjectionPoints(stream TrainingChunk) returns (google.protobuf.Empty) {} rpc getKNNRequest(knnRequest) returns (Neighbours) {} rpc getIDfromServer(google.protobuf.Empty) returns (ID) {} } @@ -28,6 +28,12 @@ message knnRequest{ string words = 2; } + +// Needs description +message TrainingChunk { + repeated TrainingSetRow rows = 2; +} + message Neighbours { repeated Row rows = 1; } diff --git a/backend/webSocketGateway/src/webSocketGateway.js b/backend/webSocketGateway/src/webSocketGateway.js index 09c06ce..d8dfca5 100644 --- a/backend/webSocketGateway/src/webSocketGateway.js +++ b/backend/webSocketGateway/src/webSocketGateway.js @@ -42,6 +42,13 @@ class GrpcConnection{ wsServer.on('connection', function connection(ws) { console.log(`Client connected with websocket`); var currConnection = new ProjectionRequest(ws); + + + console.log("setup KNN"); + currConnection.knnConn = getGRPCClient(knnTarget,KNNPackage,"KNN"); + getKNNConnection(currConnection) + + ws.on('message', function incoming(message) { parseMessage(message, currConnection); }); @@ -103,7 +110,12 @@ function getProjectorConnection(ws,client){ function getKNNConnection(connection){ connection.knnConn.client.getIDfromServer({},function(error,response){ + if(error){ + console.log("error") + console.log(error) + } if(!response){ + return; } console.log(error); @@ -116,6 +128,7 @@ function getKNNConnection(connection){ }); connection.knnConn.calls=[call] + console.log("Setup knn call") }); } @@ -168,9 +181,7 @@ function parseMessage(message, connection) { KNNNeighbourRequest(connection,words,k); break; case "getKNN": - console.log("setup KNN"); - connection.knnConn = getGRPCClient(knnTarget,KNNPackage,"KNN"); - getKNNConnection(connection) + break; case "setProjectorAmount": var amount = parseInt(jsonMessage["amount"]); @@ -226,7 +237,8 @@ function sendRowToServer(allRows, lineIndex, connection) { } var trainingChunk = {rows: trainingRows}; //Linearly distribute rows to the projectors - connection.projectorConn.calls[0].write(trainingChunk) + connection.projectorConn.calls[0].write(trainingChunk); + connection.knnConn.calls[0].write(trainingChunk); } //Send client to browser diff --git a/protos/protos/knn.proto b/protos/protos/knn.proto index 2bdc51b..9189ac2 100644 --- a/protos/protos/knn.proto +++ b/protos/protos/knn.proto @@ -14,7 +14,7 @@ option objc_class_prefix = "PROVEE"; service KNN { - rpc sendProjectionPoints(stream TrainingSetRow) returns (google.protobuf.Empty) {} + rpc sendProjectionPoints(stream TrainingChunk) returns (google.protobuf.Empty) {} rpc getKNNRequest(knnRequest) returns (Neighbours) {} rpc getIDfromServer(google.protobuf.Empty) returns (ID) {} } @@ -28,6 +28,12 @@ message knnRequest{ string words = 2; } + +// Needs description +message TrainingChunk { + repeated TrainingSetRow rows = 2; +} + message Neighbours { repeated Row rows = 1; } -- GitLab