diff --git a/docs/api/datasets.rst b/docs/api/datasets.rst index 8d9a59d21..faf58d760 100644 --- a/docs/api/datasets.rst +++ b/docs/api/datasets.rst @@ -246,3 +246,4 @@ Available Datasets datasets/pyhealth.datasets.TCGAPRADDataset datasets/pyhealth.datasets.splitter datasets/pyhealth.datasets.utils + datasets/pyhealth.datasets.medlingo diff --git a/docs/api/datasets/pyhealth.datasets.medlingo.rst b/docs/api/datasets/pyhealth.datasets.medlingo.rst new file mode 100644 index 000000000..157fbb2be --- /dev/null +++ b/docs/api/datasets/pyhealth.datasets.medlingo.rst @@ -0,0 +1,22 @@ +pyhealth.datasets.medlingo +========================== + +Overview +-------- + +MedLingo-style dataset for clinical abbreviation expansion. + +API Reference +------------- + +.. autoclass:: pyhealth.datasets.MedLingoDataset + :members: + :undoc-members: + :show-inheritance: + + +References +---------- + +Diagnosing Our Datasets: How Does My Language Model Learn Clinical Information? +https://arxiv.org/abs/2505.15024 \ No newline at end of file diff --git a/docs/api/models/pyhealth.models.abbreviation_lookup.rst b/docs/api/models/pyhealth.models.abbreviation_lookup.rst new file mode 100644 index 000000000..d0de62a53 --- /dev/null +++ b/docs/api/models/pyhealth.models.abbreviation_lookup.rst @@ -0,0 +1,7 @@ +pyhealth.models.abbreviation_lookup +=================================== + +..
automodule:: pyhealth.models.abbreviation_lookup + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 23a4e06e5..80149d0f3 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -230,3 +230,5 @@ Available Tasks Mutation Pathogenicity (COSMIC) Cancer Survival Prediction (TCGA) Cancer Mutation Burden (TCGA) + Clinical Abbreviation Expansion + MedLingo Task \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.clinical_abbreviation.rst b/docs/api/tasks/pyhealth.tasks.clinical_abbreviation.rst new file mode 100644 index 000000000..faf36fbf3 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.clinical_abbreviation.rst @@ -0,0 +1,7 @@ +pyhealth.tasks.clinical_abbreviation +==================================== + +.. automodule:: pyhealth.tasks.clinical_abbreviation + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/docs/api/tasks/pyhealth.tasks.medlingo_task.rst b/docs/api/tasks/pyhealth.tasks.medlingo_task.rst new file mode 100644 index 000000000..ad646ae78 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.medlingo_task.rst @@ -0,0 +1,15 @@ +pyhealth.tasks.medlingo_task +============================ + +Overview +-------- +Task wrapper for MedLingo dataset that converts structured records +into model-ready input/target pairs. + +API Reference +------------- + +.. autoclass:: pyhealth.tasks.medlingo_task.MedLingoTask + :members: + :undoc-members: + :show-inheritance: \ No newline at end of file diff --git a/examples/medlingo_clinical_abbreviation_abbreviation_lookup.py b/examples/medlingo_clinical_abbreviation_abbreviation_lookup.py new file mode 100644 index 000000000..c4624f7c0 --- /dev/null +++ b/examples/medlingo_clinical_abbreviation_abbreviation_lookup.py @@ -0,0 +1,60 @@ +""" +This example demonstrates an evaluation of the Clinical Abbreviation Task using the MedLingo dataset. 
+It prints the task's input/label output under several input modifications; no model is trained or scored in this script. + +The printed conditions are: +1. Base: abbreviation-only input. +2. Ablation 1: lowercased abbreviation. +3. Ablation 2: short clinical context. +4. Ablation 3: noisy formatting (punctuation appended to the abbreviation). + +Paper: + Diagnosing Our Datasets: How Does My Language Model Learn Clinical Information? + https://arxiv.org/abs/2505.15024 + +""" + +from pyhealth.datasets.medlingo import MedLingoDataset +from pyhealth.tasks.clinical_abbreviation import ClinicalAbbreviationTask + +def main() -> None: + dataset = MedLingoDataset(root="test-resources") + records = dataset.process() + + samples = [] + for record in records: + for sample in record["medlingo"]: + samples.append(sample) + + print("=== Base Results: Abbreviation-Only ===") + base_task = ClinicalAbbreviationTask(use_context=False) + for sample in samples: + print(base_task(sample)) + + print("\n=== Ablation 1: Lowercase Input ===") + for sample in samples: + modified = { + **sample, + "abbr": sample["abbr"].lower(), + } + print(base_task(modified)) + + print("\n=== Ablation 2: Short Clinical Context ===") + context_task = ClinicalAbbreviationTask(use_context=True) + for sample in samples: + print(context_task(sample)) + + print("\n=== Ablation 3: Noisy Formatting ===") + noise_variants = ["!!!", "???", "..."] + + for sample in samples: + for noise in noise_variants: + noisy = { + **sample, + "abbr": sample["abbr"] + noise, + } + print(base_task(noisy)) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/medlingo_demo.py b/examples/medlingo_demo.py new file mode 100644 index 000000000..518061967 --- /dev/null +++ b/examples/medlingo_demo.py @@ -0,0 +1,5 @@ +from pyhealth.datasets.medlingo import MedLingoDataset + +dataset = MedLingoDataset(root="test-resources") + +print("Dataset loaded") \ No newline at end of file diff --git a/examples/medlingo_gpt_vs_lookup.py
b/examples/medlingo_gpt_vs_lookup.py new file mode 100644 index 000000000..22c718975 --- /dev/null +++ b/examples/medlingo_gpt_vs_lookup.py @@ -0,0 +1,233 @@ + +""" +GPT vs Lookup baseline evaluation for MedLingo-style clinical abbreviation interpretation. + +This script compares: +1. AbbreviationLookupModel baseline +2. GPT-based abbreviation expansion + +Evaluation conditions: +- abbreviation only +- lowercase abbreviation +- punctuation noise +- short clinical context + +IMPORTANT: +This script is intended to run only on cleaned, derived benchmark samples +(e.g., test-resources/medlingo_samples.json), not on raw MIMIC notes. + +This script requires an OpenAI API key; it is an optional, secondary evaluation using a modern LLM. +""" + +from __future__ import annotations + +import os +import re +from pathlib import Path +from typing import Any + +from dotenv import load_dotenv +from openai import OpenAI + +from pyhealth.datasets.medlingo import MedLingoDataset +from pyhealth.models.abbreviation_lookup import AbbreviationLookupModel +from pyhealth.tasks.clinical_abbreviation import ClinicalAbbreviationTask + + +def normalize_label(text: str) -> str: + """ + Normalize text for exact-match scoring. + + This removes common punctuation and markdown-like formatting so GPT outputs + such as '**Shortness of breath**.' still score correctly. + """ + text = text.strip().lower() + text = re.sub(r"[*_`]+", "", text) + text = re.sub(r"[^\w\s/]+", "", text) + return " ".join(text.split()) + + +def score_prediction(pred: str, gold: str) -> int: + """Return 1 if prediction matches gold after normalization, else 0.""" + return int(normalize_label(pred) == normalize_label(gold)) + + +def build_prompt(input_text: str, use_context: bool) -> str: + """ + Build a constrained prompt for GPT evaluation. + + This version explicitly tells the model to identify and expand + the abbreviation, improving performance under contextual input.
+ """ + if use_context: + return ( + f"Sentence: {input_text}\n" + "What does the abbreviation in this sentence stand for?\n" + "Return only the expansion in plain text. " + "No explanation, no markdown, no punctuation." + ) + + return ( + "Expand the following clinical abbreviation into its medical meaning.\n" + f"Abbreviation: {input_text}\n" + "Return only the expansion in plain text. " + "No explanation, no markdown, no punctuation." + ) + + +def accuracy_lookup( + samples: list[dict[str, Any]], + model: AbbreviationLookupModel, + task: ClinicalAbbreviationTask, +) -> float: + """Evaluate lookup baseline accuracy.""" + correct = 0 + total = 0 + + for sample in samples: + processed = task(sample) + pred = model.predict(processed["input"]) + gold = processed["label"] + + correct += score_prediction(pred, gold) + total += 1 + + return correct / total if total > 0 else 0.0 + + +def query_gpt( + client: OpenAI, + prompt: str, + model_name: str = "gpt-4.1-mini", +) -> str: + """ + Query GPT and return stripped text output. + + If a request fails, return 'unknown' so the evaluation can continue. 
+ """ + try: + response = client.responses.create( + model=model_name, + input=prompt, + ) + return response.output_text.strip() + except Exception as exc: + print(f"GPT query failed: {exc}") + return "unknown" + + +def accuracy_gpt( + samples: list[dict[str, Any]], + task: ClinicalAbbreviationTask, + client: OpenAI, + model_name: str = "gpt-4.1-mini", + max_samples: int | None = None, +) -> float: + """Evaluate GPT accuracy on the benchmark.""" + correct = 0 + total = 0 + + eval_samples = samples[:max_samples] if max_samples is not None else samples + + for sample in eval_samples: + processed = task(sample) + prompt = build_prompt( + input_text=processed["input"], + use_context=task.use_context, + ) + pred = query_gpt(client, prompt, model_name=model_name) + gold = processed["label"] + + correct += score_prediction(pred, gold) + total += 1 + + return correct / total if total > 0 else 0.0 + + +def main() -> None: + load_dotenv() + + api_key = os.getenv("OPENAI_API_KEY") + if not api_key: + print("Skipping GPT example:a OPENAI_API_KEY not found.") + return + + client = OpenAI(api_key=api_key) + + dataset = MedLingoDataset(root="test-resources") + records = dataset.process() + + samples = [] + for record in records: + for s in record["medlingo"]: + samples.append(s) + + # Baseline model + lookup_model = AbbreviationLookupModel(normalize=True) + lookup_model.fit(samples) + + # Tasks + base_task = ClinicalAbbreviationTask(use_context=False) + context_task = ClinicalAbbreviationTask(use_context=True) + + lowercase_samples = [{**s, "abbr": s["abbr"].lower()} for s in samples] + noisy_samples = [{**s, "abbr": s["abbr"] + "!!!"} for s in samples] + + # Keep GPT run small at first for cost/control + max_gpt_samples = min(10, len(samples)) + print(f"Using {max_gpt_samples} samples for GPT evaluation.\n") + + results = { + "lookup_base_abbreviation_only": accuracy_lookup( + samples, lookup_model, base_task + ), + "lookup_lowercase_abbreviation": accuracy_lookup( + 
lowercase_samples, lookup_model, base_task + ), + "lookup_short_clinical_context": accuracy_lookup( + samples, lookup_model, context_task + ), + "lookup_punctuation_noise": accuracy_lookup( + noisy_samples, lookup_model, base_task + ), + "gpt_base_abbreviation_only": accuracy_gpt( + samples, base_task, client, max_samples=max_gpt_samples + ), + "gpt_lowercase_abbreviation": accuracy_gpt( + lowercase_samples, base_task, client, max_samples=max_gpt_samples + ), + "gpt_short_clinical_context": accuracy_gpt( + samples, context_task, client, max_samples=max_gpt_samples + ), + "gpt_punctuation_noise": accuracy_gpt( + noisy_samples, base_task, client, max_samples=max_gpt_samples + ), + } + + print("=== GPT vs Lookup Results ===") + print(f"{'Condition':30} {'Lookup':>8} {'GPT':>8}") + print("-" * 50) + print( + f"{'Abbreviation only':30} " + f"{results['lookup_base_abbreviation_only']:.3f} " + f"{results['gpt_base_abbreviation_only']:.3f}" + ) + print( + f"{'Lowercase abbreviation':30} " + f"{results['lookup_lowercase_abbreviation']:.3f} " + f"{results['gpt_lowercase_abbreviation']:.3f}" + ) + print( + f"{'Short clinical context':30} " + f"{results['lookup_short_clinical_context']:.3f} " + f"{results['gpt_short_clinical_context']:.3f}" + ) + print( + f"{'Punctuation noise':30} " + f"{results['lookup_punctuation_noise']:.3f} " + f"{results['gpt_punctuation_noise']:.3f}" + ) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/examples/my_replication.py b/examples/my_replication.py new file mode 100644 index 000000000..d4a06305e --- /dev/null +++ b/examples/my_replication.py @@ -0,0 +1,52 @@ +from pyhealth.datasets.medlingo import MedLingoDataset +from pyhealth.tasks.medlingo_task import MedLingoTask +from pyhealth.models.abbreviation_lookup import AbbreviationLookupModel + +""" +This script demonstrates a replication of the MedLingo clinical abbreviation expansion task. 
+It loads the MedLingo dataset, processes it into task-ready format, and evaluates a simple rule-based abbreviation lookup model. +Contributors: + Tedra Birch (tbirch2@illinois.edu) + +Paper: + Diagnosing Our Datasets: How Does My Language Model Learn Clinical Information? + https://arxiv.org/abs/2505.15024 + +""" +def main() -> None: + dataset = MedLingoDataset(root="test-resources") + records = dataset.process() + + task = MedLingoTask() + processed = task.process(records) + + model = AbbreviationLookupModel(normalize=True) + model.fit( + [ + {"abbr": item["input"], "label": item["target"]} + for item in processed + ] + ) + + correct = 0 + total = len(processed) + + for item in processed: + pred = model.predict(item["input"]) + if pred == item["target"]: + correct += 1 + + accuracy = correct / total if total > 0 else 0.0 + + print("=== MedLingo Replication Pipeline ===") + print(f"Loaded {len(records)} records") + print(f"Processed {len(processed)} task samples") + print(f"Accuracy: {accuracy:.3f}") + print("Example sample:") + print(processed[0]) + print("Example prediction:") + print(model.predict(processed[0]['input'])) + + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/pyhealth/datasets/__init__.py b/pyhealth/datasets/__init__.py index 50b1b3887..491588e30 100644 --- a/pyhealth/datasets/__init__.py +++ b/pyhealth/datasets/__init__.py @@ -91,3 +91,4 @@ def __init__(self, *args, **kwargs): save_processors, ) from .collate import collate_temporal +from .medlingo import MedLingoDataset \ No newline at end of file diff --git a/pyhealth/datasets/configs/medlingo.yaml b/pyhealth/datasets/configs/medlingo.yaml new file mode 100644 index 000000000..d49cacdf6 --- /dev/null +++ b/pyhealth/datasets/configs/medlingo.yaml @@ -0,0 +1,13 @@ +dataset_name: medlingo +task: abbreviation_expansion +modality: text + +tables: + - medlingo + +fields: + - abbr + - context + - label + +label_field: label \ No newline at end of file diff --git 
a/pyhealth/datasets/medlingo.py b/pyhealth/datasets/medlingo.py new file mode 100644 index 000000000..1371cafa2 --- /dev/null +++ b/pyhealth/datasets/medlingo.py @@ -0,0 +1,89 @@ +import json +from pathlib import Path +from typing import Dict, List + +from pyhealth.datasets import BaseDataset + +class MedLingoDataset(BaseDataset): + """ + MedLingo Dataset for clinical abbreviation interpretation. + + Contributor: + Tedra Birch (tbirch2@illinois.edu) + + Paper: + Diagnosing Our Datasets: How Does My Language Model Learn Clinical Information? + https://arxiv.org/abs/2505.15024 + + This dataset is inspired by the MedLingo benchmark and is constructed from + cleaned, curated clinical abbreviation samples. + + Each sample in medlingo_samples.json contains: + - abbr: clinical abbreviation string + - context: short clinical text snippet + - label: ground truth expanded meaning + - source: source of the sample (e.g. "mimic_iv_discharge"); read from the JSON file but not included in the records returned by process() + + Args: + root: Root directory containing medlingo_samples.json (e.g., "test-resources" for demo usage) + config_path: Optional path to dataset config yaml. + + Example: + >>> dataset = MedLingoDataset(root="data") + >>> records = dataset.process() + """ + + def __init__( + self, + root: str = "", + config_path: str | None = None, + ) -> None: + tables = ["medlingo"] # single table dataset + super().__init__( + root=root, + tables=tables, + dataset_name="medlingo", + config_path=config_path, + ) + + @classmethod + def from_json(cls, filepath: str | Path) -> "MedLingoDataset": + dataset = cls(root=str(Path(filepath).parent)) + return dataset + + + def process(self) -> List[Dict]: + """ + Load MedLingo JSON samples and convert them into PyHealth-style records. + + Returns: + A list of patient/visit records with a medlingo table.
+ """ + file_path = Path(self.root) / "medlingo_samples.json" + + + # Check if the file exists + if not file_path.exists(): + raise FileNotFoundError(f"{file_path} not found.") + + + with open(file_path, "r", encoding="utf-8") as f: + samples = json.load(f) + + data = [] + + # Convert each sample into the standardized format + for i, sample in enumerate(samples): + data.append({ + "patient_id": f"patient_{i}", + "visit_id": f"visit_{i}", + "medlingo": [ + { + "abbr": sample["abbr"], + "context": sample["context"], + "label": sample["label"], + } + ] + }) + + return data \ No newline at end of file diff --git a/pyhealth/models/abbreviation_lookup.py b/pyhealth/models/abbreviation_lookup.py new file mode 100644 index 000000000..05d89220f --- /dev/null +++ b/pyhealth/models/abbreviation_lookup.py @@ -0,0 +1,73 @@ +import re +from typing import Any + + +class AbbreviationLookupModel: + """ + Simple rule-based model for clinical abbreviation interpretation. + + Contributor: + Tedra Birch (tbirch2@illinois.edu) + + Paper: + Diagnosing Our Datasets: How Does My Language Model Learn Clinical Information? + https://arxiv.org/abs/2505.15024 + + This model builds a dictionary mapping clinical abbreviations to their + expanded labels. It optionally normalizes input by stripping punctuation + and converting text to uppercase. + + Args: + normalize: Whether to normalize inputs before lookup. + + Example: + >>> model = AbbreviationLookupModel() + >>> model.fit(samples) + >>> model.predict("SOB") + """ + + def __init__(self, normalize: bool = True) -> None: + self.normalize = normalize + self.lookup: dict[str, str] = {} + + def _normalize_text(self, text: str) -> str: + """ + Normalize text for rule-based lookup. + + Args: + text: Input abbreviation string. + + Returns: + Normalized abbreviation string. 
+ """ + text = text.strip() + if self.normalize: + text = re.sub(r"[^A-Za-z0-9/]+", "", text) + text = text.upper() + return text + + def fit(self, samples: list[dict[str, Any]]) -> None: + """ + Fit the lookup table from abbreviation-label pairs. + + Args: + samples: List of samples containing 'abbr' and 'label' keys. + """ + for sample in samples: + key = self._normalize_text(sample["abbr"]) + self.lookup[key] = sample["label"] + + def predict(self, input_text: str) -> str: + """ + Predict the expanded meaning of an abbreviation. + + Args: + input_text: Abbreviation string. + + Returns: + Expanded abbreviation label if found, otherwise 'unknown'. + """ + key = self._normalize_text(input_text) + return self.lookup.get(key, "unknown") + + \ No newline at end of file diff --git a/pyhealth/tasks/clinical_abbreviation.py b/pyhealth/tasks/clinical_abbreviation.py new file mode 100644 index 000000000..3d9fbed75 --- /dev/null +++ b/pyhealth/tasks/clinical_abbreviation.py @@ -0,0 +1,78 @@ +import re +from typing import Any, Dict + +class ClinicalAbbreviationTask: + """ + Task for clinical abbreviation interpretation. + + Contributor: + Tedra Birch (tbirch2@illinois.edu) + + Paper: + Diagnosing Our Datasets: How Does My Language Model Learn Clinical Information? + https://arxiv.org/abs/2505.15024 + + + This task converts Medlingo samples into model-ready input/label pairs. + + If `use_context` is False, the task uses the abbreviation directly. + If `use_context` is True, the task attempts to extract a likely + abbreviation from the clinical context. + + """ + + task_name: str = "clinical_abbreviation" + input_schema = {"input": "str"} + output_schema = {"label": "str"} + + def __init__(self, use_context: bool = False) -> None: + self.use_context = use_context + + def extract_abbreviation(self, text: str) -> str: + """ + Extract a likely clinical abbreviation from text. + + Priority: + 1. uppercase abbreviations like SOB, BP, CHF, HTN + 2. 
mixed-case shorthand like Hx, Dx, Rx + + Args: + text: The clinical context from which to extract an abbreviation. + + Returns: + The extracted abbreviation, or an empty string if none is found. + """ + # First, try to find uppercase abbreviations (2+ letters) + upper_match = re.search(r"\b([A-Z]{2,})\b", text) + if upper_match: + return upper_match.group(0) + + # Then, try to find mixed-case shorthand (2+ letters) + mixed_match = re.search(r"\b([A-Z][a-z]{1,})\b", text) + if mixed_match: + return mixed_match.group(0) + + return "" + + def __call__(self, sample: Dict[str, Any]) -> Dict[str, str]: + """ + Convert a MedLingo sample into task-ready format. + + Args: + sample: A MedLingo sample containing abbreviation, context, and label. + + Returns: + A dictionary with model input and expected label. + """ + context = sample.get("context", "").strip() + + if self.use_context and context: + extracted = self.extract_abbreviation(context) + model_input = extracted if extracted else sample["abbr"] + else: + model_input = sample["abbr"] + + return { + "input": model_input, + "label": sample["label"], + } \ No newline at end of file diff --git a/pyhealth/tasks/medlingo_task.py b/pyhealth/tasks/medlingo_task.py new file mode 100644 index 000000000..a28d1aa1a --- /dev/null +++ b/pyhealth/tasks/medlingo_task.py @@ -0,0 +1,58 @@ +from typing import Any + +from pyhealth.tasks import BaseTask + + +class MedLingoTask(BaseTask): + """ + Task for MedLingo-style clinical abbreviation expansion. + + Contributor: + Tedra Birch (tbirch2@illinois.edu) + + Paper: + Diagnosing Our Datasets: How Does My Language Model Learn Clinical Information? + https://arxiv.org/abs/2505.15024 + + This task converts MedLingo dataset records into model-ready input/target pairs. 
+ """ + + task_name: str = "medlingo_task" + input_schema = {"input": "str"} + output_schema = {"target": "str"} + + def __init__(self) -> None: + super().__init__() + + def __call__(self, sample: dict[str, Any]) -> dict[str, str]: + """ + Convert a single MedLingo sample into task-ready format. + + Args: + sample: A dictionary containing the fields 'context' and 'label'. + + Returns: + A dictionary with the processed input and target fields. + """ + return { + "input": sample["context"], + "target": sample["label"], + } + + def process(self, dataset): + """ + Convert processed MedLingo records into task-ready samples. + + Args: + dataset: Output of MedLingoDataset.process(). + + Returns: + A list of dictionaries with input and target fields. + """ + output = [] + + for record in dataset: + for sample in record["medlingo"]: + output.append(self(sample)) + + return output \ No newline at end of file diff --git a/test-resources/medlingo_samples.json b/test-resources/medlingo_samples.json new file mode 100644 index 000000000..cc39858dd --- /dev/null +++ b/test-resources/medlingo_samples.json @@ -0,0 +1,92 @@ +[ + { + "abbr": "SOB", + "context": "SOB, worsening abd distension and discomfort", + "label": "shortness of breath", + "source": "mimic_iv_discharge" + }, + { + "abbr": "COPD", + "context": "persistent cough with mucus.", + "label": "chronic obstructive pulmonary disease", + "source": "mimic_iv_discharge" + }, + { + "abbr": "Hx", + "context": "Hx of cocaine and heroin abuse", + "label": "history", + "source": "mimic_iv_discharge" + }, + { + "abbr": "COPD", + "context": "hepatic encephalopathy", + "label": "chronic obstructive pulmonary disease", + "source": "mimic_iv_discharge" + }, + { + "abbr": "CHF", + "context": "known CHF with fluid overload", + "label": "congestive heart failure", + "source": "mimic_iv_discharge" + }, + { + "abbr": "COPD", + "context": "wheezing on admission", + "label": "chronic obstructive pulmonary disease", + "source": 
"mimic_iv_discharge" + }, + { + "abbr": "CP", + "context": "chest tightness persistent", + "label": "chest pain", + "source": "mimic_iv_discharge" + }, + { + "abbr": "CHF", + "context": "chronic decompensated heart failure", + "label": "congestive heart failure", + "source": "mimic_iv_discharge" + }, + { + "abbr": "Hx", + "context": "Hx of sigmoid diverticulitis", + "label": "history", + "source": "mimic_iv_discharge" + }, + { + "abbr": "Hx", + "context": "Hx of bladder cancer", + "label": "history", + "source": "mimic_iv_discharge" + }, + { + "abbr": "Hx", + "context": "Hx of cervical CA", + "label": "history", + "source": "mimic_iv_discharge" + }, + { + "abbr": "Dx", + "context": "Dx: lung cancer 2", + "label": "diagnosis", + "source": "mimic_iv_discharge" + }, + { + "abbr": "Dx", + "context": "Dx: depression", + "label": "diagnosis", + "source": "mimic_iv_discharge" + }, + { + "abbr": "Dx", + "context": "Dx: Seizure Disorder", + "label": "diagnosis", + "source": "mimic_iv_discharge" + }, + { + "abbr": "Dx", + "context": "Dx: Meningioma", + "label": "diagnosis", + "source": "mimic_iv_discharge" + } +] \ No newline at end of file diff --git a/tests/test_abbreviation_lookup.py b/tests/test_abbreviation_lookup.py new file mode 100644 index 000000000..45197eb79 --- /dev/null +++ b/tests/test_abbreviation_lookup.py @@ -0,0 +1,15 @@ +from pyhealth.models.abbreviation_lookup import AbbreviationLookupModel + + +def test_lookup_model_predicts_known_abbreviation() -> None: + samples = [ + {"abbr": "SOB", "label": "shortness of breath"}, + {"abbr": "BP", "label": "blood pressure"}, + ] + + model = AbbreviationLookupModel(normalize=True) + model.fit(samples) + + assert model.predict("SOB") == "shortness of breath" + assert model.predict("sob") == "shortness of breath" + assert model.predict("BP!!!") == "blood pressure" \ No newline at end of file diff --git a/tests/test_clinical_abbreviation.py b/tests/test_clinical_abbreviation.py new file mode 100644 index 
000000000..38d1a7c75 --- /dev/null +++ b/tests/test_clinical_abbreviation.py @@ -0,0 +1,31 @@ +from pyhealth.tasks.clinical_abbreviation import ClinicalAbbreviationTask + + +def test_task_without_context() -> None: + task = ClinicalAbbreviationTask(use_context=False) + + sample = { + "abbr": "SOB", + "context": "Patient presents with SOB.", + "label": "shortness of breath", + } + + result = task(sample) + + assert result["input"] == "SOB" + assert result["label"] == "shortness of breath" + + +def test_task_with_context() -> None: + task = ClinicalAbbreviationTask(use_context=True) + + sample = { + "abbr": "SOB", + "context": "Patient presents with SOB.", + "label": "shortness of breath", + } + + result = task(sample) + + assert result["input"] == "SOB" + assert result["label"] == "shortness of breath" \ No newline at end of file diff --git a/tests/test_medlingo.py b/tests/test_medlingo.py new file mode 100644 index 000000000..f939f4623 --- /dev/null +++ b/tests/test_medlingo.py @@ -0,0 +1,39 @@ +from pyhealth.datasets.medlingo import MedLingoDataset + + +def test_medlingo_dataset_structure(): + dataset = MedLingoDataset(root=".") + + # synthetic raw samples (what your JSON would contain) + samples = [ + { + "abbr": "SOB", + "context": "Patient presents with SOB.", + "label": "shortness of breath", + }, + { + "abbr": "BP", + "context": "BP remained stable overnight.", + "label": "blood pressure", + }, + ] + + # monkey patch process input + def mock_process(): + data = [] + for i, sample in enumerate(samples): + data.append({ + "patient_id": f"patient_{i}", + "visit_id": f"visit_{i}", + "medlingo": [sample], + }) + return data + + dataset.process = mock_process # override + + output = dataset.process() + + # actual assertions + assert len(output) == 2 + assert output[0]["medlingo"][0]["abbr"] == "SOB" + assert output[1]["medlingo"][0]["label"] == "blood pressure" \ No newline at end of file