Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 16 additions & 9 deletions drivers/python/age/age.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,7 +144,7 @@ def setUpAge(conn:psycopg.connection, graphName:str, load_from_plugins:bool=Fals
else:
cursor.execute("LOAD 'age';")

cursor.execute("SET search_path = ag_catalog, '$user', public;")
cursor.execute('SET search_path = ag_catalog, "$user", public;')

ag_info = TypeInfo.fetch(conn, 'agtype')

Expand Down Expand Up @@ -208,14 +208,19 @@ def _validate_column(col: str) -> str:
name, type_name = parts
validate_identifier(name, "Column name")
validate_identifier(type_name, "Column type")
return f"{name} {type_name}"
# Only the column name is double-quoted. The type name is left
# unquoted so PostgreSQL applies its default identifier folding
# for type names in column definitions. Double-quoting would
# make the type name case-sensitive and could change type
# resolution in surprising ways for user-defined types.
return f'"{name}" {type_name}'
else:
validate_identifier(col, "Column name")
return f"{col} agtype"
return f'"{col}" agtype'


def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
if graphName == None:
if graphName is None:
raise _EXCEPTION_GraphNotSet

columnExp=[]
Expand All @@ -225,16 +230,18 @@ def buildCypher(graphName:str, cypherStmt:str, columns:list) ->str:
if validated:
columnExp.append(validated)
else:
columnExp.append('v agtype')
columnExp.append('"v" agtype')

# Design note: String concatenation is used here instead of
# psycopg.sql.Identifier() because column specifications are
# "name type" pairs (e.g. "v agtype") that don't map directly to
# "name type" pairs (e.g. '"v" agtype') that don't map directly to
# sql.Identifier(). Each component has already been validated by
# _validate_column() → validate_identifier(), which restricts
# names to ^[A-Za-z_][A-Za-z0-9_]*$ and max 63 chars. The
# graphName and cypherStmt are NOT embedded here — this template
# only contains the validated column list and static SQL keywords.
# names to ^[A-Za-z_][A-Za-z0-9_]*$ and max 63 chars. Column names
# are always double-quoted to avoid conflicts with PostgreSQL
# reserved words (e.g. "count", "order", "type"). The graphName
# and cypherStmt are NOT embedded here — this template only
# contains the validated column list and static SQL keywords.
stmtArr = []
stmtArr.append("SELECT * from cypher(NULL,NULL) as (")
stmtArr.append(','.join(columnExp))
Expand Down
72 changes: 72 additions & 0 deletions drivers/python/test_age_py.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,17 @@
# specific language governing permissions and limitations
# under the License.
import json
import re

from age.models import Vertex
import unittest
import decimal
import age
# _validate_column is private but tested directly because its quoting
# behavior is security-relevant and the public surface (buildCypher)
# makes it difficult to isolate quoting assertions.
from age.age import buildCypher, _validate_column
from age.exceptions import InvalidIdentifier
import argparse

TEST_HOST = "localhost"
Expand All @@ -28,6 +34,72 @@
TEST_GRAPH_NAME = "test_graph"


class TestBuildCypher(unittest.TestCase):
"""Unit tests for buildCypher() and _validate_column() — no DB required."""

def test_simple_column(self):
result = buildCypher("g", "MATCH (n) RETURN n", ["n"])
self.assertIn('"n" agtype', result)

def test_column_with_type(self):
result = buildCypher("g", "MATCH (n) RETURN n", ["n agtype"])
self.assertIn('"n" agtype', result)

def test_reserved_word_count(self):
"""Issue #2370: 'count' is a PostgreSQL reserved word."""
result = buildCypher("g", "MATCH (n) RETURN count(n)", ["count"])
self.assertIn('"count" agtype', result)
# Verify 'count' never appears unquoted as a column name
self.assertIsNone(
re.search(r'(?<!")\bcount\s+agtype\b', result),
f"'count' must be quoted in: {result}"
)

def test_reserved_word_order(self):
"""Issue #2370: 'order' is a PostgreSQL reserved word."""
result = buildCypher("g", "MATCH (n) RETURN n.order", ["order"])
self.assertIn('"order" agtype', result)

def test_reserved_word_type(self):
"""Issue #2370: 'type' is a PostgreSQL reserved word."""
result = buildCypher("g", "MATCH ()-[r]->() RETURN type(r)", ["type"])
self.assertIn('"type" agtype', result)

def test_reserved_word_select(self):
"""Issue #2370: 'select' is a PostgreSQL reserved word."""
result = buildCypher("g", "MATCH (n) RETURN n", ["select"])
self.assertIn('"select" agtype', result)

def test_reserved_word_group(self):
"""Issue #2370: 'group' is a PostgreSQL reserved word."""
result = buildCypher("g", "MATCH (n) RETURN n", ["group"])
self.assertIn('"group" agtype', result)

def test_multiple_columns(self):
result = buildCypher("g", "MATCH (n) RETURN n.name, count(n)", ["name", "count"])
self.assertIn('"name" agtype', result)
self.assertIn('"count" agtype', result)

def test_default_column(self):
result = buildCypher("g", "MATCH (n) RETURN n", None)
self.assertIn('"v" agtype', result)

def test_invalid_column_rejected(self):
with self.assertRaises(InvalidIdentifier):
buildCypher("g", "MATCH (n) RETURN n", ["invalid;col"])

def test_reserved_word_in_name_type_pair(self):
"""Quoting applies even when the column is specified as 'name type'."""
result = buildCypher("g", "MATCH (n) RETURN n.order", ["order agtype"])
self.assertIn('"order" agtype', result)

def test_validate_column_quoting(self):
self.assertEqual(_validate_column("v"), '"v" agtype')
self.assertEqual(_validate_column("v agtype"), '"v" agtype')
self.assertEqual(_validate_column("count"), '"count" agtype')
self.assertEqual(_validate_column("my_col"), '"my_col" agtype')


class TestAgeBasic(unittest.TestCase):
ag = None
args: argparse.Namespace = argparse.Namespace(
Expand Down
14 changes: 7 additions & 7 deletions drivers/python/test_security.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,10 +167,10 @@ class TestColumnValidation(unittest.TestCase):
"""Test _validate_column prevents injection through column specs."""

def test_plain_column_name(self):
self.assertEqual(_validate_column('v'), 'v agtype')
self.assertEqual(_validate_column('v'), '"v" agtype')

def test_column_with_type(self):
self.assertEqual(_validate_column('n agtype'), 'n agtype')
self.assertEqual(_validate_column('n agtype'), '"n" agtype')

def test_empty_column(self):
self.assertEqual(_validate_column(''), '')
Expand Down Expand Up @@ -198,20 +198,20 @@ class TestBuildCypher(unittest.TestCase):

def test_default_column(self):
result = buildCypher('test_graph', 'MATCH (n) RETURN n', None)
self.assertIn('v agtype', result)
self.assertIn('"v" agtype', result)

def test_single_column(self):
result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['n'])
self.assertIn('n agtype', result)
self.assertIn('"n" agtype', result)

def test_typed_column(self):
result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['n agtype'])
self.assertIn('n agtype', result)
self.assertIn('"n" agtype', result)

def test_multiple_columns(self):
result = buildCypher('test_graph', 'MATCH (n) RETURN n', ['a', 'b'])
self.assertIn('a agtype', result)
self.assertIn('b agtype', result)
self.assertIn('"a" agtype', result)
self.assertIn('"b" agtype', result)

def test_rejects_injection_in_column(self):
with self.assertRaises(InvalidIdentifier):
Expand Down