From cbe200d3968a87b59520b1cfa391d0bfec5712f8 Mon Sep 17 00:00:00 2001
From: Leonardo <lchristino@graphpolaris.com>
Date: Wed, 19 Mar 2025 15:28:22 +0100
Subject: [PATCH] feat: split count and normal query and handle abort
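
Split the combined graph/count query into two services: queryService now runs
only the graph query, and the new queryCountService runs the count query and
publishes its result to the frontend as a separate count message. Queries run
through runWithAbort when a session and call ID are provided, so an abort is
reported with an 'aborted' status instead of an error, and other failures are
wrapped in GPError before being published to the frontend.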

---
 src/readers/insightProcessor.ts      |   2 +-
 src/readers/queryService.ts          | 202 +++++++++++++++++++++------
 src/utils/queryPublisher.ts          |  37 ++++-
 src/utils/reactflow/query2backend.ts |   2 +-
 4 files changed, 189 insertions(+), 54 deletions(-)

diff --git a/src/readers/insightProcessor.ts b/src/readers/insightProcessor.ts
index 71014ea..d8fdc09 100644
--- a/src/readers/insightProcessor.ts
+++ b/src/readers/insightProcessor.ts
@@ -94,7 +94,7 @@ export const insightProcessor = async () => {
       const cypher = query2Cypher(convertedQuery);
       if (cypher == null) return;
       try {
-        const result = await queryService(ss.dbConnections[0], cypher, true);
+        const result = await queryService(ss.dbConnections[0], cypher.query, true);
 
         insight.status = false;
 
diff --git a/src/readers/queryService.ts b/src/readers/queryService.ts
index 0973c02..d2e384f 100644
--- a/src/readers/queryService.ts
+++ b/src/readers/queryService.ts
@@ -1,4 +1,4 @@
-import { graphQueryBackend2graphQuery, type DbConnection, type QueryRequest } from 'ts-common';
+import { AbortedError, getGraphStatistics, GPError, graphQueryBackend2graphQuery, type DbConnection, type QueryRequest } from 'ts-common';
 
 import { QUERY_CACHE_DURATION, rabbitMq, redis, ums, type QueryExecutionTypes } from '../variables';
 import { log } from '../logger';
@@ -7,16 +7,12 @@ import { query2Cypher } from '../utils/cypher/converter';
 import { parseCountCypherQuery, parseCypherQuery } from '../utils/cypher/queryParser';
 import { formatTimeDifference } from 'ts-common/src/logger/logger';
 import { Query2BackendQuery } from '../utils/reactflow/query2backend';
-import type {
-  CountQueryResultFromBackend,
-  GraphQueryResultFromBackend,
-  GraphQueryResultMetaFromBackend,
-} from 'ts-common/src/model/webSocket/graphResult';
+import type { GraphQueryCountResultFromBackend, GraphQueryResultMetaFromBackend } from 'ts-common/src/model/webSocket/graphResult';
 import { RabbitMqBroker } from 'ts-common/rabbitMq';
 import { Neo4jConnection } from 'ts-common/neo4j';
-import type { QueryCypher } from '../utils/cypher/converter/queryConverter';
+import { Neo4jError } from 'neo4j-driver';
 
-async function cacheCheck(cacheKey: string): Promise<GraphQueryResultMetaFromBackend | undefined> {
+async function cacheCheck<T>(cacheKey: string): Promise<T | undefined> {
   log.debug('Checking cache for query, with cache ttl', QUERY_CACHE_DURATION, 'seconds');
   const cached = await redis.client.get(cacheKey);
   if (cached) {
@@ -24,14 +20,20 @@ async function cacheCheck(cacheKey: string): Promise<GraphQueryResultMetaFromBac
     const buf = Buffer.from(cached, 'base64');
     const inflated = Bun.gunzipSync(new Uint8Array(buf));
     const dec = new TextDecoder();
-    const cachedMessage = JSON.parse(dec.decode(inflated)) as GraphQueryResultMetaFromBackend;
+    const cachedMessage = JSON.parse(dec.decode(inflated)) as T;
     return cachedMessage;
   }
 }
 
-export const queryService = async (db: DbConnection, cypher: QueryCypher, useCached: boolean): Promise<GraphQueryResultMetaFromBackend> => {
+export const queryService = async (
+  db: DbConnection,
+  cypher: string,
+  useCached: boolean,
+  sessionID: string | undefined = undefined,
+  callID: string | undefined = undefined,
+): Promise<GraphQueryResultMetaFromBackend> => {
   let index = 0;
-  const disambiguatedQuery = cypher.query.replace(/\d{13}/g, () => (index++).toString());
+  const disambiguatedQuery = cypher.replace(/\d{13}/g, () => (index++).toString());
   const cacheKey = Bun.hash(JSON.stringify({ db: db, query: disambiguatedQuery })).toString();
 
   if (QUERY_CACHE_DURATION === '') {
@@ -39,7 +41,7 @@ export const queryService = async (db: DbConnection, cypher: QueryCypher, useCac
   } else if (!useCached) {
     log.info('Skipping cache check for query due to parameter', useCached);
   } else {
-    const cachedMessage = await cacheCheck(cacheKey);
+    const cachedMessage = await cacheCheck<GraphQueryResultMetaFromBackend>(cacheKey);
     if (cachedMessage) {
       log.debug('Cache hit for query', disambiguatedQuery);
       return cachedMessage;
@@ -49,20 +51,20 @@ export const queryService = async (db: DbConnection, cypher: QueryCypher, useCac
   // TODO: only neo4j is supported for now
   const connection = new Neo4jConnection(db);
   try {
-    const [neo4jResult, neo4jCountResult] = await connection.run([cypher.query, cypher.countQuery]);
-    const graph = parseCypherQuery(neo4jResult.records);
-    const countGraph = parseCountCypherQuery(neo4jCountResult.records);
+    const [neo4jResult] =
+      sessionID != null && callID != null
+        ? await connection.runWithAbort(redis, [cypher], sessionID, callID, 5000, false)
+        : await connection.run([cypher]);
+    const graph = graphQueryBackend2graphQuery(parseCypherQuery(neo4jResult.records));
 
     // calculate metadata
-    const result = graphQueryBackend2graphQuery(graph, countGraph);
-    result.nodeCounts.updatedAt = Date.now();
+    const result: GraphQueryResultMetaFromBackend = { ...graph, metaData: getGraphStatistics(graph) };
 
     // Force garbage collection
     neo4jResult.records = [];
     Bun.gc(true);
 
     // cache result
-
     if (QUERY_CACHE_DURATION !== '') {
       log.info('Started gzipping...');
       const compressedMessage = Bun.gzipSync(JSON.stringify(result));
@@ -77,8 +79,75 @@ export const queryService = async (db: DbConnection, cypher: QueryCypher, useCac
 
     return result;
   } catch (error) {
-    log.error('Error parsing query result:', cypher, error);
-    throw new Error('Error parsing query result');
+    if (error instanceof AbortedError) {
+      log.info('Query aborted:', cypher);
+      throw error;
+    } else if (error instanceof Neo4jError) {
+      log.error('Error in queryService:', error.message, error);
+      throw new GPError('Error querying neo4j', error);
+    } else {
+      log.error('Error parsing query result:', cypher, error);
+      throw new GPError('Error parsing query result', error);
+    }
+  } finally {
+    connection.close();
+  }
+};
+
+export const queryCountService = async (
+  db: DbConnection,
+  cypher: string,
+  useCached: boolean,
+  sessionID: string,
+  callID: string,
+): Promise<GraphQueryCountResultFromBackend> => {
+  let index = 0;
+  const disambiguatedQuery = cypher.replace(/\d{13}/g, () => (index++).toString());
+  const cacheKey = Bun.hash(JSON.stringify({ db: db, query: disambiguatedQuery })).toString();
+
+  if (QUERY_CACHE_DURATION === '') {
+    log.info('Query cache disabled, skipping cache check');
+  } else if (!useCached) {
+    log.info('Skipping cache check for query due to parameter', useCached);
+  } else {
+    const cachedMessage = await cacheCheck<GraphQueryCountResultFromBackend>(cacheKey);
+    if (cachedMessage) {
+      log.debug('Cache hit for query', disambiguatedQuery);
+      return cachedMessage;
+    }
+  }
+
+  // TODO: only neo4j is supported for now
+  const connection = new Neo4jConnection(db);
+  try {
+    const [neo4jCountResult] = await connection.runWithAbort(redis, [cypher], sessionID, callID, 5000, false);
+    const countGraph = { nodeCounts: parseCountCypherQuery(neo4jCountResult.records) };
+
+    // cache result
+    if (QUERY_CACHE_DURATION !== '') {
+      log.info('Started gzipping...');
+      const compressedMessage = Bun.gzipSync(JSON.stringify(countGraph));
+      log.info('Done gzipping, started encoding to base64...');
+      const base64Message = Buffer.from(compressedMessage).toString('base64');
+      log.info('Done encoding, sending to redis...');
+
+      // if cache enabled, cache the result
+      await redis.setWithExpire(cacheKey, base64Message, QUERY_CACHE_DURATION); // ttl in seconds
+      log.info('cached in redis');
+    }
+
+    return countGraph;
+  } catch (error) {
+    if (error instanceof AbortedError) {
+      log.info('Count query aborted:', cypher);
+      throw error;
+    } else if (error instanceof Neo4jError) {
+      log.error('Error in queryCountService:', error.message, error);
+      throw new GPError('Error querying neo4j', error);
+    } else {
+      log.error('Error parsing count query result:', cypher, error);
+      throw new GPError('Error parsing count query result', error);
+    }
   } finally {
     connection.close();
   }
@@ -109,13 +178,11 @@ export const queryServiceReader = async (frontendPublisher: RabbitMqBroker, mlPu
     const publisher = new QueryPublisher(frontendPublisher, mlPublisher, headers, message.queryID);
 
     try {
-      const startTime = Date.now();
       const ss = await ums.getUserSaveState(headers.message.sessionData.userID, message.saveStateID);
 
       if (!ss) {
         log.error('Invalid SaveState received in queryServiceConsumer:', ss);
-        publisher.publishErrorToFrontend('Invalid SaveState');
-        return;
+        throw new GPError('Invalid SaveState');
       }
 
       log.debug('Received query request:', message, headers, ss);
@@ -125,31 +192,27 @@ export const queryServiceReader = async (frontendPublisher: RabbitMqBroker, mlPu
 
       if (ss == null || ss.dbConnections == null || ss.dbConnections[0] == null || ss.dbConnections.length === 0) {
         log.error('Invalid SaveState received in queryServiceConsumer:', ss);
-        publisher.publishErrorToFrontend('Invalid SaveState');
-        return;
+        throw new GPError('Invalid SaveState');
       }
 
       let activeQuery = ss.queryStates.activeQueryId;
       if (message.queryID) {
         if (ss.queryStates.openQueryArray.find(q => q.id === message.queryID) == null) {
           log.error('Query not found in SaveState:', message.queryID, ss.queryStates.openQueryArray);
-          publisher.publishErrorToFrontend('Query not found');
-          return;
+          throw new GPError('Query not found');
         }
         activeQuery = message.queryID;
       }
 
       if (activeQuery == null || activeQuery == -1) {
         log.error('No active query found in SaveState:', ss);
-        publisher.publishErrorToFrontend('No active query found');
-        return;
+        throw new GPError('No active query found');
       }
 
       const activeQueryInfo = ss.queryStates.openQueryArray.find(q => q.id === activeQuery);
       if (activeQueryInfo == null) {
         log.error('Active query not found in SaveState:', ss.queryStates.activeQueryId, ss.queryStates.openQueryArray);
-        publisher.publishErrorToFrontend('Active query not found');
-        return;
+        throw new GPError('Active query not found');
       }
 
       const visualQuery = activeQueryInfo.graph; //ss.queries[0].graph;
@@ -159,19 +222,21 @@ export const queryServiceReader = async (frontendPublisher: RabbitMqBroker, mlPu
         publisher.publishResultToFrontend({
           nodes: [],
           edges: [],
-          nodeCounts: { updatedAt: 0 },
           metaData: {
             topological: { density: 0, self_loops: 0 },
             nodes: { count: 0, labels: [], types: {} },
             edges: { count: 0, labels: [], types: {} },
           },
         });
+        publisher.publishCountResultToFrontend({ nodeCounts: { updatedAt: 0 } });
         return;
       }
 
       const queryBuilderSettings = activeQueryInfo.settings; //ss.queries[0].settings;
       const ml = message.ml;
+      let startTime = Date.now();
       const convertedQuery = Query2BackendQuery(ss.id, visualQuery, queryBuilderSettings, ml);
+      log.info(`Query converted in ${formatTimeDifference(Date.now() - startTime)}`);
 
       log.debug('translating query:', convertedQuery);
       publisher.publishStatusToFrontend('Translating');
@@ -179,9 +244,7 @@ export const queryServiceReader = async (frontendPublisher: RabbitMqBroker, mlPu
       const cypher = query2Cypher(convertedQuery);
       const query = cypher.query;
       if (query == null) {
-        log.error('Error translating query:', convertedQuery);
-        publisher.publishErrorToFrontend('Error translating query');
-        return;
+        throw new GPError(`Error translating query: ${JSON.stringify(convertedQuery)}`);
       }
 
       log.debug('Translated query FROM:', convertedQuery);
@@ -190,40 +253,89 @@ export const queryServiceReader = async (frontendPublisher: RabbitMqBroker, mlPu
       publisher.publishTranslationResultToFrontend(query);
 
       for (let i = 0; i < ss.dbConnections.length; i++) {
+        let result: GraphQueryResultMetaFromBackend;
+        try {
+          log.info('Executing query on database');
+          startTime = Date.now();
+          result = await queryService(
+            ss.dbConnections[i],
+            cypher.query,
+            message.useCached,
+            headers.message.sessionData.sessionID,
+            headers.callID,
+          );
+          publisher.publishResultToFrontend(result);
+          log.info(`Query executed in ${formatTimeDifference(Date.now() - startTime)}`);
+        } catch (error) {
+          if (error instanceof AbortedError) {
+            publisher.publishResultToFrontend({} as any, 'aborted');
+            return;
+          } else {
+            throw new GPError('Error querying database', error);
+          }
+        }
+
+        let countResult: GraphQueryCountResultFromBackend;
         try {
-          const result = await queryService(ss.dbConnections[i], cypher, message.useCached);
+          startTime = Date.now();
+          countResult = await queryCountService(
+            ss.dbConnections[i],
+            cypher.countQuery,
+            message.useCached,
+            headers.message.sessionData.sessionID,
+            headers.callID,
+          );
+          publisher.publishCountResultToFrontend(countResult);
+          log.info(`Count query executed in ${formatTimeDifference(Date.now() - startTime)}`);
+        } catch (error) {
+          if (error instanceof AbortedError) {
+            publisher.publishCountResultToFrontend({ nodeCounts: { updatedAt: 0 } }, 'aborted');
+            return;
+          } else {
+            throw new GPError('Error querying database', error);
+          }
+        }
 
+        try {
           // Cache nodeCounts such that we can display differentiation for each query
           await ums.updateQuery(ss.userId, message.saveStateID, {
             ...activeQueryInfo,
             graph: {
               ...activeQueryInfo.graph,
-              nodeCounts: result.nodeCounts,
             },
+            graphCounts: countResult,
           });
 
-          publisher.publishResultToFrontend(result);
-          log.debug('Query result!');
-          log.info(`Query executed in ${formatTimeDifference(Date.now() - startTime)}`);
-
           if (convertedQuery.machineLearning && convertedQuery.machineLearning.length > 0) {
             for (let i = 0; i < convertedQuery.machineLearning.length; i++) {
               try {
                 publisher.publishMachineLearningRequest(result, convertedQuery.machineLearning[i], headers);
                 log.debug('Published machine learning request', convertedQuery.machineLearning[i]);
               } catch (error) {
-                log.error('Error publishing machine learning request', error);
-                publisher.publishErrorToFrontend('Error publishing machine learning request');
+                throw new GPError('Error publishing machine learning request', error);
               }
             }
           }
         } catch (error) {
-          log.error('Error querying database', error);
-          publisher.publishErrorToFrontend('Error querying database');
+          throw new GPError('Error during database query execution', error);
         }
 
         Bun.gc(true);
       }
+    } catch (e: any) {
+      // Resolve cleanly: report the error to the frontend exactly once
+      if (e instanceof GPError) {
+        log.error('Error in queryServiceReader', e.message, e);
+        e.log();
+        publisher.publishErrorToFrontend(e);
+      } else if (e instanceof Neo4jError) {
+        // Wrap driver errors so the frontend receives a consistent GPError payload
+        log.error('Error in queryServiceReader', e.message, e);
+        publisher.publishErrorToFrontend(new GPError(e.message, e));
+      } else {
+        log.error('Unexpected error in queryServiceReader:', e.message, e);
+        publisher.publishErrorToFrontend(new GPError(e.message, e));
+      }
     } finally {
       setTimeout(() => Bun.gc(true), 100);
     }
diff --git a/src/utils/queryPublisher.ts b/src/utils/queryPublisher.ts
index d10bea7..94bec42 100644
--- a/src/utils/queryPublisher.ts
+++ b/src/utils/queryPublisher.ts
@@ -1,6 +1,10 @@
-import { wsReturnKey, type BackendMessageHeader, type MachineLearning, type ToMLMessage } from 'ts-common';
+import { GPError, wsReturnKey, type BackendMessageHeader, type MachineLearning, type ToMLMessage } from 'ts-common';
 import { log } from '../logger';
-import type { GraphQueryResultFromBackend, GraphQueryResultMetaFromBackend } from 'ts-common/src/model/webSocket/graphResult';
+import type {
+  GraphQueryCountResultFromBackend,
+  GraphQueryResultFromBackend,
+  GraphQueryResultMetaFromBackend,
+} from 'ts-common/src/model/webSocket/graphResult';
 import type { RabbitMqBroker } from 'ts-common/rabbitMq';
 
 export class QueryPublisher {
@@ -31,13 +35,13 @@ export class QueryPublisher {
     );
   }
 
-  publishErrorToFrontend(reason: string) {
+  publishErrorToFrontend(error: GPError) {
     this.frontendPublisher.publishMessageToFrontend(
       {
         type: wsReturnKey.queryStatusError,
         callID: this.headers.callID,
-        value: this.queryID,
-        status: reason,
+        value: { ...error.toJSON(), extra: { queryID: this.queryID } },
+        status: error.message,
       },
       this.routingKey,
       this.headers,
@@ -60,7 +64,7 @@ export class QueryPublisher {
     );
   }
 
-  publishResultToFrontend(result: GraphQueryResultMetaFromBackend) {
+  publishResultToFrontend(result: GraphQueryResultMetaFromBackend, status: string = 'success') {
     this.frontendPublisher.publishMessageToFrontend(
       {
         type: wsReturnKey.queryStatusResult,
@@ -72,7 +76,26 @@ export class QueryPublisher {
           },
           queryID: this.queryID,
         },
-        status: 'success',
+        status: status,
+      },
+      this.routingKey,
+      this.headers,
+    );
+  }
+
+  publishCountResultToFrontend(result: GraphQueryCountResultFromBackend, status: string = 'success') {
+    this.frontendPublisher.publishMessageToFrontend(
+      {
+        type: wsReturnKey.queryCountResult,
+        callID: this.headers.callID,
+        value: {
+          result: {
+            type: 'count',
+            payload: result,
+          },
+          queryID: this.queryID,
+        },
+        status: status,
       },
       this.routingKey,
       this.headers,
diff --git a/src/utils/reactflow/query2backend.ts b/src/utils/reactflow/query2backend.ts
index 7b797d2..ef71c5c 100644
--- a/src/utils/reactflow/query2backend.ts
+++ b/src/utils/reactflow/query2backend.ts
@@ -44,7 +44,7 @@ const traverseEntityRelationPaths = (
           x: node.attributes.x,
           y: node.attributes.x,
           depth: { min: settings.depth.min, max: settings.depth.max },
-          direction: 'both',
+          direction: QueryRelationDirection.BOTH,
           attributes: [],
         });
       } else {
-- 
GitLab