Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 40 additions & 2 deletions src/autointent/context/data_handler/_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,36 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
train_labels = [lab for lab in train_labels if lab is not None]
yield train_utterances, train_labels, val_utterances, val_labels # type: ignore[misc]

def _has_oos_samples(self, split_name: str) -> bool:
    """Check whether a split contains out-of-scope (OOS) samples.

    A sample counts as OOS when its value in the dataset's label column is ``None``.

    Args:
        split_name: Name of the split to inspect.

    Returns:
        ``True`` if the split exists and at least one of its labels is ``None``;
        ``False`` if the split is missing or every sample carries a label.
    """
    if split_name not in self.dataset:
        return False
    labels = self.dataset[split_name][self.dataset.label_feature]
    # Scan only the label column and stop at the first OOS sample, rather than
    # building a filtered copy of the whole split just to measure its length.
    return any(label is None for label in labels)

def _duplicate_split_for_scoring_and_decision(self, split_name: str) -> None:
    """Replace a split with an in-domain-only copy and a full copy.

    Used in hold-out mode when OOS samples are present but no separation ratio is given:
    - `{split_name}_0` holds only samples whose label is not ``None`` (for scoring)
    - `{split_name}_1` holds the untouched split, OOS included (for decision)

    The original `{split_name}` entry is removed afterwards. Nothing happens when the
    split is not present in the dataset.

    Raises:
        ValueError: If no labelled sample remains after dropping the OOS ones.
    """
    if split_name not in self.dataset:
        return

    full_split = self.dataset[split_name]
    label_column = self.dataset.label_feature

    def is_labelled(sample) -> bool:  # noqa: ANN001
        return sample[label_column] is not None

    scoring_split = full_split.filter(is_labelled)
    if not len(scoring_split):
        msg = f"Split '{split_name}' contains only OOS samples; cannot prepare scoring split."
        raise ValueError(msg)

    scoring_key, decision_key = f"{split_name}_0", f"{split_name}_1"
    self.dataset[scoring_key] = scoring_split
    self.dataset[decision_key] = full_split
    # Drop the source entry only after both replacements are in place.
    self.dataset.pop(split_name)

def _split_ho(
self,
separation_ratio: FloatFromZeroToOne | None,
Expand All @@ -185,8 +215,16 @@ def _split_ho(
) -> None:
has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)

if separation_ratio is not None and Split.TRAIN in self.dataset:
self._split_train(separation_ratio)
if Split.TRAIN in self.dataset:
if separation_ratio is not None:
self._split_train(separation_ratio)
elif self._has_oos_samples(Split.TRAIN):
# When OOS exists and separation_ratio is not set, keep the same in-domain pool
# for scoring and decision, but exclude OOS from scoring split.
self._duplicate_split_for_scoring_and_decision(Split.TRAIN)
# If user provided a single validation split containing OOS, make scoring validation OOS-free.
if Split.VALIDATION in self.dataset and self._has_oos_samples(Split.VALIDATION):
self._duplicate_split_for_scoring_and_decision(Split.VALIDATION)

if not has_validation_split:
self._split_validation_from_train(validation_size, is_few_shot, examples_per_intent)
Expand Down
64 changes: 64 additions & 0 deletions tests/data/test_data_handler.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from autointent import Dataset
from autointent.configs import DataConfig
from autointent.context.data_handler import DataHandler
from autointent.custom_types import Split
from autointent.schemas import Sample


Expand Down Expand Up @@ -246,3 +247,66 @@ def test_few_shot_split(dataset):
assert Counter(dh.dataset[data_split][dh.dataset.label_feature]) == desired_specs[data_split], (
f"Failed for {data_split}"
)


def _make_multiclass_mapping_with_oos(*, with_validation: bool) -> dict:
    """Build a two-class dataset mapping whose train data also contains OOS samples.

    The train list holds 50 utterances for each of the labels 0 and 1, followed by
    20 utterances without a label. When `with_validation` is true, a small
    validation list (two utterances per class plus two without a label) is added.
    """
    # Plenty of samples per class so that stratified splitting has enough to work with.
    train_samples: list[dict] = []
    train_samples.extend({"utterance": f"c0_{i}", "label": 0} for i in range(50))
    train_samples.extend({"utterance": f"c1_{i}", "label": 1} for i in range(50))
    # OOS samples simply have no "label" key.
    train_samples.extend({"utterance": f"oos_{i}"} for i in range(20))

    result: dict = {
        "train": train_samples,
        "intents": [{"id": 0}, {"id": 1}],
    }
    if not with_validation:
        return result

    labelled = [("val_c0_0", 0), ("val_c0_1", 0), ("val_c1_0", 1), ("val_c1_1", 1)]
    validation_samples = [{"utterance": text, "label": label} for text, label in labelled]
    validation_samples += [{"utterance": text} for text in ("val_oos_0", "val_oos_1")]
    result["validation"] = validation_samples
    return result


def _split_has_oos_labels(dh: DataHandler, split_name: str) -> bool:
    """Tell whether any label in the handler's `split_name` split is ``None``."""
    label_column = dh.dataset[split_name][dh.dataset.label_feature]
    for label in label_column:
        if label is None:
            return True
    return False


def test_ho_oos_without_separation_ratio_duplicates_and_filters_scoring_splits():
    """With OOS in train and no separation_ratio, the `_0` splits must contain no OOS samples."""
    mapping = _make_multiclass_mapping_with_oos(with_validation=False)
    dataset = Dataset.from_dict(mapping)
    config = DataConfig(scheme="ho", separation_ratio=None)
    dh = DataHandler(dataset, config=config, random_seed=42)

    # Both train and validation are expected in duplicated form only.
    for created in ("train_0", "train_1", "validation_0", "validation_1"):
        assert created in dh.dataset
    for removed in (Split.TRAIN, Split.VALIDATION):
        assert removed not in dh.dataset

    # `_0` splits feed scoring (no OOS), `_1` splits feed decision (OOS kept).
    for scoring_split in ("train_0", "validation_0"):
        assert _split_has_oos_labels(dh, scoring_split) is False
    for decision_split in ("train_1", "validation_1"):
        assert _split_has_oos_labels(dh, decision_split) is True


def test_ho_oos_with_user_validation_duplicates_validation_when_needed():
    """A user-supplied validation split with OOS is duplicated, with the `_0` copy free of OOS."""
    mapping = _make_multiclass_mapping_with_oos(with_validation=True)
    dataset = Dataset.from_dict(mapping)
    config = DataConfig(scheme="ho", separation_ratio=None)
    dh = DataHandler(dataset, config=config, random_seed=42)

    for created in ("train_0", "train_1", "validation_0", "validation_1"):
        assert created in dh.dataset
    assert Split.VALIDATION not in dh.dataset

    # Scoring copy has no OOS; decision copy keeps the user's OOS samples.
    assert _split_has_oos_labels(dh, "validation_0") is False
    assert _split_has_oos_labels(dh, "validation_1") is True
Loading