From e9b98acf3654050be68924b6815c3887c2e45fd1 Mon Sep 17 00:00:00 2001 From: Leonardo <leomilho@gmail.com> Date: Tue, 25 Jun 2024 16:03:57 +0200 Subject: [PATCH] fix: cypher query where clause on grouping --- cypherv2/convertQuery.go | 35 ++++++++++++++++----- cypherv2/convertQuery_test.go | 59 +++++++++++++++++++++++++++++++++-- 2 files changed, 84 insertions(+), 10 deletions(-) diff --git a/cypherv2/convertQuery.go b/cypherv2/convertQuery.go index a08c94e..df7e9ed 100755 --- a/cypherv2/convertQuery.go +++ b/cypherv2/convertQuery.go @@ -183,12 +183,31 @@ func extractFilterCypher(JSONQuery *entityv2.NodeStruct) ([]string, error) { return filters, nil } -func createWhereLogic(op string, left string, whereLogic string) (interface{}, string, error) { +func createWhereLogic(op string, left string, whereLogic string, cacheData queryCacheData) (interface{}, string, error) { newWhereLogic := fmt.Sprintf("%s_%s", strings.ReplaceAll(left, ".", "_"), op) if whereLogic != "" { whereLogic += ", " } - return newWhereLogic, fmt.Sprintf("%s%s(%s) AS %s", whereLogic, op, left, newWhereLogic), nil + remainingNodes := []string{} + + for _, entity := range cacheData.entities { + if entity.id != left { + remainingNodes = append(remainingNodes, entity.id) + } + } + for _, relation := range cacheData.relations { + if relation.id != left { + remainingNodes = append(remainingNodes, relation.id) + } + } + + remainingNodesStr := "" + if len(remainingNodes) > 0 { + remainingNodesStr = ", " + strings.Join(remainingNodes, ", ") + } + newWithLogic := fmt.Sprintf("%s%s(%s) AS %s%s", whereLogic, op, left, newWhereLogic, remainingNodesStr) + + return newWhereLogic, newWithLogic, nil } func extractLogicCypher(LogicQuery interface{}, cacheData queryCacheData) (interface{}, string, error) { @@ -213,15 +232,15 @@ func extractLogicCypher(LogicQuery interface{}, cacheData queryCacheData) (inter case "upper": return fmt.Sprintf("toUpper(%s)", left), whereLogic, nil case "avg": - return createWhereLogic(op, left.(string), whereLogic) + return createWhereLogic(op, left.(string), whereLogic, cacheData) case "count": - return createWhereLogic(op, left.(string), whereLogic) + return createWhereLogic(op, left.(string), whereLogic, cacheData) case "max": - return createWhereLogic(op, left.(string), whereLogic) + return createWhereLogic(op, left.(string), whereLogic, cacheData) case "min": - return createWhereLogic(op, left.(string), whereLogic) + return createWhereLogic(op, left.(string), whereLogic, cacheData) case "sum": - return createWhereLogic(op, left.(string), whereLogic) + return createWhereLogic(op, left.(string), whereLogic, cacheData) } if len(v) > 2 { right, whereLogicRight, err := extractLogicCypher(v[2], cacheData) @@ -354,7 +373,7 @@ func formQuery(JSONQuery *entityv2.IncomingQueryJSON) (*string, error) { } totalQueryWithLogic := totalQuery if whereLogic != "" { - totalQueryWithLogic += fmt.Sprintf("WITH %s\n", strings.Replace(whereLogic, ", ", "", 1)) + totalQueryWithLogic += fmt.Sprintf("WITH %s\n", whereLogic) totalQuery = totalQueryWithLogic + totalQuery } totalQuery += fmt.Sprintf("WHERE %s\n", logic) diff --git a/cypherv2/convertQuery_test.go b/cypherv2/convertQuery_test.go index 0da5530..bf2495c 100755 --- a/cypherv2/convertQuery_test.go +++ b/cypherv2/convertQuery_test.go @@ -377,7 +377,7 @@ func TestV2WithAverage(t *testing.T) { t.Log(*cypher) answer := `MATCH path1 = ((p1:Person)-[acted:ACTED_IN*1..1]->(movie:Movie)) - WITH avg(p1.age) AS p1_age_avg + WITH avg(p1.age) AS p1_age_avg, p1, movie, acted MATCH path1 = ((p1:Person)-[acted:ACTED_IN*1..1]->(movie:Movie)) WHERE (p1.age < p1_age_avg) RETURN * LIMIT 5000` @@ -449,7 +449,7 @@ func TestV2WithAverage2Paths(t *testing.T) { answer := `MATCH path1 = ((p1:Person)-[acted:ACTED_IN*1..1]->(movie:Movie)) MATCH path2 = ((p2:Person)-[acted:ACTED_IN*1..1]->(movie:Movie)) - WITH avg(p1.age) AS p1_age_avg + WITH avg(p1.age) AS p1_age_avg, p2, movie, acted MATCH path1 = ((p1:Person)-[acted:ACTED_IN*1..1]->(movie:Movie)) MATCH path2 = ((p2:Person)-[acted:ACTED_IN*1..1]->(movie:Movie)) WHERE (p1.age < p1_age_avg) @@ -743,3 +743,58 @@ func TestV2RelationLogic2(t *testing.T) { assert.Equal(t, trimmedAnswer, trimmedCypher) } + +func TestV2Count(t *testing.T) { + query := []byte(`{ + "databaseName": "Movies3", + "return": ["*"], + "logic": [">", ["Count", "@p1"], "1"], + "query": [ + { + "id": "path1", + "node": { + "label": "Person", + "id": "p1", + "relation": { + "label": "DIRECTED", + "direction": "TO", + "depth": { "min": 1, "max": 1 }, + "node": { + "label": "Movie", + "id": "m1" + } + } + } + } + ], + "limit": 5000 + } + `) + + var JSONQuery entityv2.IncomingQueryJSON + err := json.Unmarshal(query, &JSONQuery) + if err != nil { + fmt.Println(err) + t.Log(err) + } + + s := NewService() + cypher, _, err := s.ConvertQuery(&JSONQuery) + if err != nil { + fmt.Println(err) + t.Log(err) + } + t.Log(*cypher) + + answer := `MATCH path1 = ((p1:Person)-[:DIRECTED*1..1]->(m1:Movie)) + WITH count(p1) AS p1_count, m1 + MATCH path1 = ((p1:Person)-[:DIRECTED*1..1]->(m1:Movie)) + WHERE (p1_count > 1) + RETURN * LIMIT 5000` + + fmt.Printf("Cypher: %s\n", answer) + trimmedCypher := fixCypherSpaces(cypher) + trimmedAnswer := fixCypherSpaces(&answer) + + assert.Equal(t, trimmedAnswer, trimmedCypher) +} -- GitLab