From e9b98acf3654050be68924b6815c3887c2e45fd1 Mon Sep 17 00:00:00 2001
From: Leonardo <leomilho@gmail.com>
Date: Tue, 25 Jun 2024 16:03:57 +0200
Subject: [PATCH] fix: cypher query where clause on grouping

---
 cypherv2/convertQuery.go      | 35 ++++++++++++++++-----
 cypherv2/convertQuery_test.go | 59 +++++++++++++++++++++++++++++++++--
 2 files changed, 84 insertions(+), 10 deletions(-)

diff --git a/cypherv2/convertQuery.go b/cypherv2/convertQuery.go
index a08c94e..df7e9ed 100755
--- a/cypherv2/convertQuery.go
+++ b/cypherv2/convertQuery.go
@@ -183,12 +183,31 @@ func extractFilterCypher(JSONQuery *entityv2.NodeStruct) ([]string, error) {
 	return filters, nil
 }
 
-func createWhereLogic(op string, left string, whereLogic string) (interface{}, string, error) {
+func createWhereLogic(op string, left string, whereLogic string, cacheData queryCacheData) (interface{}, string, error) {
 	newWhereLogic := fmt.Sprintf("%s_%s", strings.ReplaceAll(left, ".", "_"), op)
 	if whereLogic != "" {
 		whereLogic += ", "
 	}
-	return newWhereLogic, fmt.Sprintf("%s%s(%s) AS %s", whereLogic, op, left, newWhereLogic), nil
+	remainingNodes := []string{}
+
+	for _, entity := range cacheData.entities {
+		if entity.id != left {
+			remainingNodes = append(remainingNodes, entity.id)
+		}
+	}
+	for _, relation := range cacheData.relations {
+		if relation.id != left {
+			remainingNodes = append(remainingNodes, relation.id)
+		}
+	}
+
+	remainingNodesStr := ""
+	if len(remainingNodes) > 0 {
+		remainingNodesStr = ", " + strings.Join(remainingNodes, ", ")
+	}
+	newWithLogic := fmt.Sprintf("%s%s(%s) AS %s%s", whereLogic, op, left, newWhereLogic, remainingNodesStr)
+
+	return newWhereLogic, newWithLogic, nil
 }
 
 func extractLogicCypher(LogicQuery interface{}, cacheData queryCacheData) (interface{}, string, error) {
@@ -213,15 +232,15 @@ func extractLogicCypher(LogicQuery interface{}, cacheData queryCacheData) (inter
 		case "upper":
 			return fmt.Sprintf("toUpper(%s)", left), whereLogic, nil
 		case "avg":
-			return createWhereLogic(op, left.(string), whereLogic)
+			return createWhereLogic(op, left.(string), whereLogic, cacheData)
 		case "count":
-			return createWhereLogic(op, left.(string), whereLogic)
+			return createWhereLogic(op, left.(string), whereLogic, cacheData)
 		case "max":
-			return createWhereLogic(op, left.(string), whereLogic)
+			return createWhereLogic(op, left.(string), whereLogic, cacheData)
 		case "min":
-			return createWhereLogic(op, left.(string), whereLogic)
+			return createWhereLogic(op, left.(string), whereLogic, cacheData)
 		case "sum":
-			return createWhereLogic(op, left.(string), whereLogic)
+			return createWhereLogic(op, left.(string), whereLogic, cacheData)
 		}
 		if len(v) > 2 {
 			right, whereLogicRight, err := extractLogicCypher(v[2], cacheData)
@@ -354,7 +373,7 @@ func formQuery(JSONQuery *entityv2.IncomingQueryJSON) (*string, error) {
 		}
 		totalQueryWithLogic := totalQuery
 		if whereLogic != "" {
-			totalQueryWithLogic += fmt.Sprintf("WITH %s\n", strings.Replace(whereLogic, ", ", "", 1))
+			totalQueryWithLogic += fmt.Sprintf("WITH %s\n", whereLogic)
 			totalQuery = totalQueryWithLogic + totalQuery
 		}
 		totalQuery += fmt.Sprintf("WHERE %s\n", logic)
diff --git a/cypherv2/convertQuery_test.go b/cypherv2/convertQuery_test.go
index 0da5530..bf2495c 100755
--- a/cypherv2/convertQuery_test.go
+++ b/cypherv2/convertQuery_test.go
@@ -377,7 +377,7 @@ func TestV2WithAverage(t *testing.T) {
 	t.Log(*cypher)
 
 	answer := `MATCH path1 = ((p1:Person)-[acted:ACTED_IN*1..1]->(movie:Movie))
-	WITH avg(p1.age) AS p1_age_avg
+	WITH avg(p1.age) AS p1_age_avg, p1, movie, acted
 	MATCH path1 = ((p1:Person)-[acted:ACTED_IN*1..1]->(movie:Movie))
 	WHERE (p1.age <  p1_age_avg)
 	RETURN * LIMIT 5000`
@@ -449,7 +449,7 @@ func TestV2WithAverage2Paths(t *testing.T) {
 
 	answer := `MATCH path1 = ((p1:Person)-[acted:ACTED_IN*1..1]->(movie:Movie))
 	MATCH path2 = ((p2:Person)-[acted:ACTED_IN*1..1]->(movie:Movie))
-	WITH avg(p1.age) AS p1_age_avg
+	WITH avg(p1.age) AS p1_age_avg, p2, movie, acted
 	MATCH path1 = ((p1:Person)-[acted:ACTED_IN*1..1]->(movie:Movie))
 	MATCH path2 = ((p2:Person)-[acted:ACTED_IN*1..1]->(movie:Movie))
 	WHERE (p1.age <  p1_age_avg)
@@ -743,3 +743,58 @@ func TestV2RelationLogic2(t *testing.T) {
 
 	assert.Equal(t, trimmedAnswer, trimmedCypher)
 }
+
+func TestV2Count(t *testing.T) {
+	query := []byte(`{
+		"databaseName": "Movies3",
+		"return": ["*"],
+		"logic": [">", ["Count", "@p1"], "1"],
+		"query": [
+		  {
+			"id": "path1",
+			"node": {
+			  "label": "Person",
+			  "id": "p1",
+			  "relation": {
+				"label": "DIRECTED",
+				"direction": "TO",
+				"depth": { "min": 1, "max": 1 },
+				"node": {
+				  "label": "Movie",
+				  "id": "m1"
+				}
+			  }
+			}
+		  }
+		],
+		"limit": 5000
+	  }
+	  `)
+
+	var JSONQuery entityv2.IncomingQueryJSON
+	err := json.Unmarshal(query, &JSONQuery)
+	if err != nil {
+		fmt.Println(err)
+		t.Log(err)
+	}
+
+	s := NewService()
+	cypher, _, err := s.ConvertQuery(&JSONQuery)
+	if err != nil {
+		fmt.Println(err)
+		t.Log(err)
+	}
+	t.Log(*cypher)
+
+	answer := `MATCH path1 = ((p1:Person)-[:DIRECTED*1..1]->(m1:Movie))
+	WITH count(p1) AS p1_count, m1
+	MATCH path1 = ((p1:Person)-[:DIRECTED*1..1]->(m1:Movie))
+	WHERE (p1_count > 1)
+	RETURN * LIMIT 5000`
+
+	fmt.Printf("Cypher: %s\n", answer)
+	trimmedCypher := fixCypherSpaces(cypher)
+	trimmedAnswer := fixCypherSpaces(&answer)
+
+	assert.Equal(t, trimmedAnswer, trimmedCypher)
+}
-- 
GitLab