Skip to content
2 changes: 2 additions & 0 deletions pyrit/memory/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@

from pyrit.memory.azure_sql_memory import AzureSQLMemory
from pyrit.memory.central_memory import CentralMemory
from pyrit.memory.identifier_filters import IdentifierFilter
from pyrit.memory.memory_embedding import MemoryEmbedding
from pyrit.memory.memory_exporter import MemoryExporter
from pyrit.memory.memory_interface import MemoryInterface
Expand All @@ -26,4 +27,5 @@
"MemoryExporter",
"PromptMemoryEntry",
"SeedEntry",
"IdentifierFilter",
]
266 changes: 118 additions & 148 deletions pyrit/memory/azure_sql_memory.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,22 +250,6 @@ def _get_message_pieces_memory_label_conditions(self, *, memory_labels: dict[str
condition = text(conditions).bindparams(**{key: str(value) for key, value in memory_labels.items()})
return [condition]

def _get_message_pieces_attack_conditions(self, *, attack_id: str) -> Any:
"""
Generate SQL condition for filtering message pieces by attack ID.

Uses JSON_VALUE() function specific to SQL Azure to query the attack identifier.

Args:
attack_id (str): The attack identifier to filter by.

Returns:
Any: SQLAlchemy text condition with bound parameter.
"""
return text("ISJSON(attack_identifier) = 1 AND JSON_VALUE(attack_identifier, '$.hash') = :json_id").bindparams(
json_id=str(attack_id)
)

def _get_metadata_conditions(self, *, prompt_metadata: dict[str, Union[str, int]]) -> list[TextClause]:
"""
Generate SQL conditions for filtering by prompt metadata.
Expand Down Expand Up @@ -321,6 +305,109 @@ def _get_seed_metadata_conditions(self, *, metadata: dict[str, Union[str, int]])
"""
return self._get_metadata_conditions(prompt_metadata=metadata)[0]

def _get_condition_json_property_match(
self,
*,
json_column: Any,
property_path: str,
value_to_match: str,
partial_match: bool = False,
) -> Any:
uid = self._uid()
table_name = json_column.class_.__tablename__
column_name = json_column.key
pp_param = f"pp_{uid}"
mv_param = f"mv_{uid}"

return text(
f"""ISJSON("{table_name}".{column_name}) = 1
AND LOWER(JSON_VALUE("{table_name}".{column_name}, :{pp_param})) {"LIKE" if partial_match else "="} :{mv_param}""" # noqa: E501
).bindparams(
**{
pp_param: property_path,
mv_param: f"%{value_to_match.lower()}%" if partial_match else value_to_match.lower(),
}
)

def _get_condition_json_array_match(
self,
*,
json_column: Any,
property_path: str,
sub_path: str | None = None,
array_to_match: Sequence[str],
) -> Any:
uid = self._uid()
table_name = json_column.class_.__tablename__
column_name = json_column.key
pp_param = f"pp_{uid}"
sp_param = f"sp_{uid}"

if len(array_to_match) == 0:
return text(
f"""("{table_name}".{column_name} IS NULL
OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) IS NULL
OR JSON_QUERY("{table_name}".{column_name}, :{pp_param}) = '[]')"""
).bindparams(**{pp_param: property_path})

value_expression = f"LOWER(JSON_VALUE(value, :{sp_param}))" if sub_path else "LOWER(value)"

conditions = []
bindparams_dict: dict[str, str] = {pp_param: property_path}
if sub_path:
bindparams_dict[sp_param] = sub_path

for index, match_value in enumerate(array_to_match):
mv_param = f"mv_{uid}_{index}"
conditions.append(
f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("{table_name}".{column_name},
:{pp_param}))
WHERE {value_expression} = :{mv_param})"""
)
bindparams_dict[mv_param] = match_value.lower()

combined = " AND ".join(conditions)
return text(f"""ISJSON("{table_name}".{column_name}) = 1 AND {combined}""").bindparams(**bindparams_dict)

def _get_unique_json_array_values(
self,
*,
json_column: Any,
path_to_array: str,
sub_path: str | None = None,
) -> list[str]:
uid = self._uid()
pa_param = f"pa_{uid}"
sp_param = f"sp_{uid}"
table_name = json_column.class_.__tablename__
column_name = json_column.key
with closing(self.get_session()) as session:
if sub_path is None:
rows = session.execute(
text(
f"""SELECT DISTINCT JSON_VALUE("{table_name}".{column_name}, :{pa_param}) AS value
FROM "{table_name}"
WHERE ISJSON("{table_name}".{column_name}) = 1
AND JSON_VALUE("{table_name}".{column_name}, :{pa_param}) IS NOT NULL"""
).bindparams(**{pa_param: path_to_array})
).fetchall()
else:
rows = session.execute(
text(
f"""SELECT DISTINCT JSON_VALUE(items.value, :{sp_param}) AS value
FROM "{table_name}"
CROSS APPLY OPENJSON(JSON_QUERY("{table_name}".{column_name}, :{pa_param})) AS items
WHERE ISJSON("{table_name}".{column_name}) = 1
AND JSON_VALUE(items.value, :{sp_param}) IS NOT NULL"""
).bindparams(
**{
pa_param: path_to_array,
sp_param: sub_path,
}
)
).fetchall()
return sorted(row[0] for row in rows)

def _get_attack_result_harm_category_condition(self, *, targeted_harm_categories: Sequence[str]) -> Any:
"""
Get the SQL Azure implementation for filtering AttackResults by targeted harm categories.
Expand Down Expand Up @@ -388,110 +475,6 @@ def _get_attack_result_label_condition(self, *, labels: dict[str, str]) -> Any:
)
)

def _get_attack_result_attack_class_condition(self, *, attack_class: str) -> Any:
"""
Azure SQL implementation for filtering AttackResults by attack class.
Uses JSON_VALUE() on the atomic_attack_identifier JSON column.

Args:
attack_class (str): Exact attack class name to match.

Returns:
Any: SQLAlchemy text condition with bound parameter.
"""
return text(
"""ISJSON("AttackResultEntries".atomic_attack_identifier) = 1
AND JSON_VALUE("AttackResultEntries".atomic_attack_identifier,
'$.children.attack.class_name') = :attack_class"""
).bindparams(attack_class=attack_class)

def _get_attack_result_converter_classes_condition(self, *, converter_classes: Sequence[str]) -> Any:
"""
Azure SQL implementation for filtering AttackResults by converter classes.

Uses JSON_VALUE()/JSON_QUERY()/OPENJSON() on the atomic_attack_identifier
JSON column.

When converter_classes is empty, matches attacks with no converters.
When non-empty, uses OPENJSON() to check all specified classes are present
(AND logic, case-insensitive).

Args:
converter_classes (Sequence[str]): List of converter class names. Empty list means no converters.

Returns:
Any: SQLAlchemy combined condition with bound parameters.
"""
if len(converter_classes) == 0:
# Explicitly "no converters": match attacks where the converter list
# is absent, null, or empty in the stored JSON.
return text(
"""("AttackResultEntries".atomic_attack_identifier IS NULL
OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier,
'$.children.attack.children.request_converters') IS NULL
OR JSON_QUERY("AttackResultEntries".atomic_attack_identifier,
'$.children.attack.children.request_converters') = '[]')"""
)

conditions = []
bindparams_dict: dict[str, str] = {}
for i, cls in enumerate(converter_classes):
param_name = f"conv_cls_{i}"
conditions.append(
f"""EXISTS(SELECT 1 FROM OPENJSON(JSON_QUERY("AttackResultEntries".atomic_attack_identifier,
'$.children.attack.children.request_converters'))
WHERE LOWER(JSON_VALUE(value, '$.class_name')) = :{param_name})"""
)
bindparams_dict[param_name] = cls.lower()

combined = " AND ".join(conditions)
return text(f"""ISJSON("AttackResultEntries".atomic_attack_identifier) = 1 AND {combined}""").bindparams(
**bindparams_dict
)

def get_unique_attack_class_names(self) -> list[str]:
"""
Azure SQL implementation: extract unique class_name values from
the atomic_attack_identifier JSON column.

Returns:
Sorted list of unique attack class name strings.
"""
with closing(self.get_session()) as session:
rows = session.execute(
text(
"""SELECT DISTINCT JSON_VALUE(atomic_attack_identifier,
'$.children.attack.class_name') AS cls
FROM "AttackResultEntries"
WHERE ISJSON(atomic_attack_identifier) = 1
AND JSON_VALUE(atomic_attack_identifier,
'$.children.attack.class_name') IS NOT NULL"""
)
).fetchall()
return sorted(row[0] for row in rows)

def get_unique_converter_class_names(self) -> list[str]:
"""
Azure SQL implementation: extract unique converter class_name values
from the children.attack.children.request_converters array
in the atomic_attack_identifier JSON column.

Returns:
Sorted list of unique converter class name strings.
"""
with closing(self.get_session()) as session:
rows = session.execute(
text(
"""SELECT DISTINCT JSON_VALUE(c.value, '$.class_name') AS cls
FROM "AttackResultEntries"
CROSS APPLY OPENJSON(JSON_QUERY(atomic_attack_identifier,
'$.children.attack.children.request_converters')) AS c
WHERE ISJSON(atomic_attack_identifier) = 1
AND JSON_VALUE(c.value, '$.class_name') IS NOT NULL"""
)
).fetchall()
return sorted(row[0] for row in rows)

def get_conversation_stats(self, *, conversation_ids: Sequence[str]) -> dict[str, ConversationStats]:
"""
Azure SQL implementation: lightweight aggregate stats per conversation.
Expand Down Expand Up @@ -593,46 +576,33 @@ def _get_scenario_result_label_condition(self, *, labels: dict[str, str]) -> Any
conditions.append(condition)
return and_(*conditions)

def _get_scenario_result_target_endpoint_condition(self, *, endpoint: str) -> TextClause:
def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None:
"""
Get the SQL Azure implementation for filtering ScenarioResults by target endpoint.

Uses JSON_VALUE() function specific to SQL Azure.

Args:
endpoint (str): The endpoint URL substring to filter by (case-insensitive).
Insert a list of message pieces into the memory storage.

Returns:
Any: SQLAlchemy text condition with bound parameter.
"""
return text(
"""ISJSON(objective_target_identifier) = 1
AND LOWER(JSON_VALUE(objective_target_identifier, '$.endpoint')) LIKE :endpoint"""
).bindparams(endpoint=f"%{endpoint.lower()}%")
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces])

def _get_scenario_result_target_model_condition(self, *, model_name: str) -> TextClause:
def get_unique_attack_class_names(self) -> list[str]:
"""
Get the SQL Azure implementation for filtering ScenarioResults by target model name.

Uses JSON_VALUE() function specific to SQL Azure.

Args:
model_name (str): The model name substring to filter by (case-insensitive).
Azure SQL implementation: extract unique class_name values from
the atomic_attack_identifier JSON column.

Returns:
Any: SQLAlchemy text condition with bound parameter.
Sorted list of unique attack class name strings.
"""
return text(
"""ISJSON(objective_target_identifier) = 1
AND LOWER(JSON_VALUE(objective_target_identifier, '$.model_name')) LIKE :model_name"""
).bindparams(model_name=f"%{model_name.lower()}%")
return super().get_unique_attack_class_names()

def add_message_pieces_to_memory(self, *, message_pieces: Sequence[MessagePiece]) -> None:
def get_unique_converter_class_names(self) -> list[str]:
"""
Insert a list of message pieces into the memory storage.
Azure SQL implementation: extract unique converter class_name values
from the children.attack.children.request_converters array
in the atomic_attack_identifier JSON column.

Returns:
Sorted list of unique converter class name strings.
"""
self._insert_entries(entries=[PromptMemoryEntry(entry=piece) for piece in message_pieces])
return super().get_unique_converter_class_names()

def dispose_engine(self) -> None:
"""
Expand Down
13 changes: 13 additions & 0 deletions pyrit/memory/identifier_filters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
# Copyright (c) Microsoft Corporation.
# Licensed under the MIT license.

from dataclasses import dataclass


@dataclass(frozen=True)
class IdentifierFilter:
"""Immutable filter definition for matching JSON-backed identifier properties."""

property_path: str
value_to_match: str
partial_match: bool = False
Loading
Loading