diff --git a/src/autointent/context/data_handler/_data_handler.py b/src/autointent/context/data_handler/_data_handler.py
index b0a6e8ad..c3f607bc 100644
--- a/src/autointent/context/data_handler/_data_handler.py
+++ b/src/autointent/context/data_handler/_data_handler.py
@@ -176,6 +176,36 @@ def validation_iterator(self) -> Generator[tuple[list[str], ListOfLabels, list[s
         train_labels = [lab for lab in train_labels if lab is not None]
         yield train_utterances, train_labels, val_utterances, val_labels  # type: ignore[misc]
 
+    def _has_oos_samples(self, split_name: str) -> bool:
+        """Return True if the given split contains OOS (label is None) samples."""
+        if split_name not in self.dataset:
+            return False
+        hf_split = self.dataset[split_name]
+        label_feature = self.dataset.label_feature
+        oos_samples = hf_split.filter(lambda sample: sample[label_feature] is None)
+        return len(oos_samples) > 0
+
+    def _duplicate_split_for_scoring_and_decision(self, split_name: str) -> None:
+        """Duplicate split into _0/_1 where _0 is in-domain only.
+
+        Intended for hold-out mode when OOS is present but separation_ratio is not set:
+        - scoring uses `{split_name}_0` (no OOS)
+        - decision uses `{split_name}_1` (full, may include OOS)
+        """
+        if split_name not in self.dataset:
+            return
+        hf_split = self.dataset[split_name]
+        label_feature = self.dataset.label_feature
+
+        in_domain = hf_split.filter(lambda sample: sample[label_feature] is not None)
+        if len(in_domain) == 0:
+            msg = f"Split '{split_name}' contains only OOS samples; cannot prepare scoring split."
+            raise ValueError(msg)
+
+        self.dataset[f"{split_name}_0"] = in_domain
+        self.dataset[f"{split_name}_1"] = hf_split
+        self.dataset.pop(split_name)
+
     def _split_ho(
         self,
         separation_ratio: FloatFromZeroToOne | None,
@@ -185,8 +215,16 @@ def _split_ho(
     ) -> None:
         has_validation_split = any(split.startswith(Split.VALIDATION) for split in self.dataset)
 
-        if separation_ratio is not None and Split.TRAIN in self.dataset:
-            self._split_train(separation_ratio)
+        if Split.TRAIN in self.dataset:
+            if separation_ratio is not None:
+                self._split_train(separation_ratio)
+            elif self._has_oos_samples(Split.TRAIN):
+                # When OOS exists and separation_ratio is not set, keep the same in-domain pool
+                # for scoring and decision, but exclude OOS from scoring split.
+                self._duplicate_split_for_scoring_and_decision(Split.TRAIN)
+        # If user provided a single validation split containing OOS, make scoring validation OOS-free.
+        if Split.VALIDATION in self.dataset and self._has_oos_samples(Split.VALIDATION):
+            self._duplicate_split_for_scoring_and_decision(Split.VALIDATION)
 
         if not has_validation_split:
             self._split_validation_from_train(validation_size, is_few_shot, examples_per_intent)
diff --git a/tests/data/test_data_handler.py b/tests/data/test_data_handler.py
index d1fbf049..c441428e 100644
--- a/tests/data/test_data_handler.py
+++ b/tests/data/test_data_handler.py
@@ -5,6 +5,7 @@
 from autointent import Dataset
 from autointent.configs import DataConfig
 from autointent.context.data_handler import DataHandler
+from autointent.custom_types import Split
 from autointent.schemas import Sample
 
 
@@ -246,3 +247,66 @@ def test_few_shot_split(dataset):
         assert Counter(dh.dataset[data_split][dh.dataset.label_feature]) == desired_specs[data_split], (
             f"Failed for {data_split}"
         )
+
+
+def _make_multiclass_mapping_with_oos(*, with_validation: bool) -> dict:
+    # Ensure enough samples per class so stratified splitting doesn't fail.
+    in_domain = [{"utterance": f"c0_{i}", "label": 0} for i in range(50)] + [
+        {"utterance": f"c1_{i}", "label": 1} for i in range(50)
+    ]
+
+    oos = [{"utterance": f"oos_{i}"} for i in range(20)]
+
+    mapping: dict = {
+        "train": [*in_domain, *oos],
+        "intents": [{"id": 0}, {"id": 1}],
+    }
+
+    if with_validation:
+        mapping["validation"] = [
+            {"utterance": "val_c0_0", "label": 0},
+            {"utterance": "val_c0_1", "label": 0},
+            {"utterance": "val_c1_0", "label": 1},
+            {"utterance": "val_c1_1", "label": 1},
+            {"utterance": "val_oos_0"},
+            {"utterance": "val_oos_1"},
+        ]
+
+    return mapping
+
+
+def _split_has_oos_labels(dh: DataHandler, split_name: str) -> bool:
+    return any(lab is None for lab in dh.dataset[split_name][dh.dataset.label_feature])
+
+
+def test_ho_oos_without_separation_ratio_duplicates_and_filters_scoring_splits():
+    """If OOS exists and separation_ratio is None, scoring splits must be OOS-free."""
+    dataset = Dataset.from_dict(_make_multiclass_mapping_with_oos(with_validation=False))
+    dh = DataHandler(dataset, config=DataConfig(scheme="ho", separation_ratio=None), random_seed=42)
+
+    assert "train_0" in dh.dataset
+    assert "train_1" in dh.dataset
+    assert "validation_0" in dh.dataset
+    assert "validation_1" in dh.dataset
+    assert Split.TRAIN not in dh.dataset
+    assert Split.VALIDATION not in dh.dataset
+
+    assert _split_has_oos_labels(dh, "train_0") is False
+    assert _split_has_oos_labels(dh, "validation_0") is False
+    assert _split_has_oos_labels(dh, "train_1") is True
+    assert _split_has_oos_labels(dh, "validation_1") is True
+
+
+def test_ho_oos_with_user_validation_duplicates_validation_when_needed():
+    """If user provides validation with OOS, it should be duplicated and filtered for scoring."""
+    dataset = Dataset.from_dict(_make_multiclass_mapping_with_oos(with_validation=True))
+    dh = DataHandler(dataset, config=DataConfig(scheme="ho", separation_ratio=None), random_seed=42)
+
+    assert "train_0" in dh.dataset
+    assert "train_1" in dh.dataset
+    assert "validation_0" in dh.dataset
+    assert "validation_1" in dh.dataset
+    assert Split.VALIDATION not in dh.dataset
+
+    assert _split_has_oos_labels(dh, "validation_0") is False
+    assert _split_has_oos_labels(dh, "validation_1") is True