diff --git a/docs/api/models.rst b/docs/api/models.rst index 7368dec94..a2e66d0b1 100644 --- a/docs/api/models.rst +++ b/docs/api/models.rst @@ -196,6 +196,7 @@ API Reference models/pyhealth.models.GRASP models/pyhealth.models.MedLink models/pyhealth.models.TCN + models/pyhealth.models.TPC models/pyhealth.models.TFMTokenizer models/pyhealth.models.GAN models/pyhealth.models.VAE diff --git a/docs/api/models/pyhealth.models.tpc.rst b/docs/api/models/pyhealth.models.tpc.rst new file mode 100644 index 000000000..04a89941b --- /dev/null +++ b/docs/api/models/pyhealth.models.tpc.rst @@ -0,0 +1,44 @@ +pyhealth.models.TPC +=================== + +Temporal Pointwise Convolution (TPC) model for ICU remaining length-of-stay prediction. + +Overview +-------- + +The TPC model combines grouped temporal convolutions with pointwise (1x1) convolutions to +capture both feature-specific temporal patterns and cross-feature interactions at each timestep. +The architecture is specifically designed for irregularly sampled multivariate time series in +intensive care settings. + +**Paper Reference:** +Rocheteau, E., Liò, P., & Hyland, S. (2021). Temporal Pointwise Convolutional Networks for +Length of Stay Prediction in the Intensive Care Unit. In Proceedings of the Conference on +Health, Inference, and Learning (CHIL). + +**Key Features:** + +- Grouped temporal convolutions (one group per clinical feature) +- Pointwise convolutions for cross-feature learning +- Skip connections with hierarchical feature aggregation +- Custom MSLE (Masked Mean Squared Logarithmic Error) loss +- Monte Carlo Dropout for uncertainty estimation (extension) + +**Model Classes:** + +.. autoclass:: pyhealth.models.TPC + :members: + :undoc-members: + :show-inheritance: + +**Loss Functions:** + +.. autoclass:: pyhealth.models.MSLELoss + :members: + :undoc-members: + :show-inheritance: + +.. 
autoclass:: pyhealth.models.MaskedMSELoss + :members: + :undoc-members: + :show-inheritance: diff --git a/docs/api/tasks.rst b/docs/api/tasks.rst index 399b8f1aa..bfb8283dd 100644 --- a/docs/api/tasks.rst +++ b/docs/api/tasks.rst @@ -213,6 +213,7 @@ Available Tasks DKA Prediction (MIMIC-IV) Drug Recommendation Length of Stay Prediction + Remaining LoS (TPC MIMIC-IV) Medical Transcriptions Classification Mortality Prediction (Next Visit) Mortality Prediction (StageNet MIMIC-IV) diff --git a/docs/api/tasks/pyhealth.tasks.length_of_stay_tpc_mimic4.rst b/docs/api/tasks/pyhealth.tasks.length_of_stay_tpc_mimic4.rst new file mode 100644 index 000000000..37244bf30 --- /dev/null +++ b/docs/api/tasks/pyhealth.tasks.length_of_stay_tpc_mimic4.rst @@ -0,0 +1,67 @@ +pyhealth.tasks.length_of_stay_tpc_mimic4 +=========================================== + +Remaining ICU length-of-stay prediction task for MIMIC-IV with TPC-compatible preprocessing. + +Overview +-------- + +The RemainingLOSMIMIC4 task generates samples for predicting remaining ICU length of stay at +hourly intervals. Unlike traditional length-of-stay tasks that predict total stay duration at +admission, this task formulates the problem as a time-series regression where the model predicts +remaining hours at each timestep throughout the ICU stay. 
+ +**Input Features:** + +- **Timeseries** ``(2F+2, T)``: Hourly clinical measurements with: + + - Elapsed time channel (1) + - Feature values from chartevents and labevents (F channels) + - Decay indicators showing time since last measurement (F channels) + - Hour of day (1 channel) + +- **Static** ``(2,)``: Patient demographics (age, sex) + +- **Conditions**: ICD diagnosis codes from admission + +**Output:** + +- **Remaining LoS** ``(T,)``: Remaining hours in ICU at each timestep + + +**Default Configuration:** + +- Prediction step size: 1 hour +- Minimum history: 5 hours before predictions start +- Minimum remaining stay: 1 hour +- Maximum history window: 366 hours (15.25 days) +- Clinical features: 17 chartevents + 17 labevents = 34 features + +**Clinical Features:** + +*Vital Signs (chartevents):* + +- Heart rate, blood pressure (systolic/diastolic/mean) +- Respiratory rate, SpO2, temperature +- Glasgow Coma Scale components +- Urine output, weight + +*Laboratory Values (labevents):* + +- Hematology: WBC, platelets, hemoglobin, hematocrit +- Chemistry: sodium, potassium, chloride, bicarbonate +- Renal: BUN, creatinine +- Metabolic: glucose, lactate +- Liver: bilirubin, ALT +API Reference +------------- + +.. autoclass:: pyhealth.tasks.length_of_stay_tpc_mimic4.RemainingLOSMIMIC4 + :members: + :undoc-members: + :show-inheritance: + +.. 
autoclass:: pyhealth.tasks.length_of_stay_tpc_mimic4.RemainingLOSConfig + :members: + :undoc-members: + :show-inheritance: diff --git a/examples/length_of_stay/length_of_stay_mimic4_tpc.py b/examples/length_of_stay/length_of_stay_mimic4_tpc.py new file mode 100644 index 000000000..ceb8f8e96 --- /dev/null +++ b/examples/length_of_stay/length_of_stay_mimic4_tpc.py @@ -0,0 +1,363 @@ +""" +TPC Model Example with Ablation Study for MIMIC-IV Remaining Length-of-Stay Prediction + +This script demonstrates the Temporal Pointwise Convolution (TPC) model for ICU length-of-stay +prediction with comprehensive ablation studies including: + +1. Baseline TPC model training +2. Hyperparameter variations (layers, loss functions, dropout) +3. Monte Carlo Dropout uncertainty estimation (novel ablation) +4. Performance comparison across configurations + +Paper: Rocheteau et al., "Temporal Pointwise Convolutional Networks for Length of Stay + Prediction in the ICU", CHIL 2021 + +NOTE: Set dev=True for testing with small subset. For full dataset, set dev=False. 
+""" + +from pyhealth.datasets import MIMIC4EHRDataset, get_dataloader +from pyhealth.tasks import RemainingLOSMIMIC4 +from pyhealth.models import TPC +from pyhealth.trainer import Trainer +import torch +import numpy as np +from pathlib import Path +import json + +# ============================================================================ +# Configuration +# ============================================================================ + +# Update this path once you download from Google Drive +MIMIC_ROOT = r"C:\cs598\mimic-iv" # Path to your downloaded MIMIC-IV data +CACHE_PATH = r"C:\cs598\.cache_dir" +OUTPUT_DIR = Path("./tpc_ablation_results") +OUTPUT_DIR.mkdir(exist_ok=True) + +# Training configuration - adjust based on your hardware +EPOCHS = 10 # Increase for better results (20-50 for full training) +BATCH_SIZE = 32 # Adjust based on your GPU memory +LEARNING_RATE = 0.001 +DEVICE = "cuda" if torch.cuda.is_available() else "cpu" + + +# ============================================================================ +# Ablation Study Configurations +# ============================================================================ + +ABLATION_CONFIGS = { + "baseline": { + "name": "TPC Baseline (3 layers, MSLE)", + "params": { + "n_layers": 3, + "temp_kernels": [8, 8, 8], + "point_sizes": [14, 14, 14], + "use_msle": True, + "main_dropout_rate": 0.3, + } + }, + "shallow": { + "name": "TPC Shallow (2 layers, MSLE)", + "params": { + "n_layers": 2, + "temp_kernels": [8, 8], + "point_sizes": [14, 14], + "use_msle": True, + "main_dropout_rate": 0.3, + } + }, + "mse_loss": { + "name": "TPC with MSE Loss (3 layers)", + "params": { + "n_layers": 3, + "temp_kernels": [8, 8, 8], + "point_sizes": [14, 14, 14], + "use_msle": False, # Use MSE instead of MSLE + "main_dropout_rate": 0.3, + } + }, + "low_dropout": { + "name": "TPC Low Dropout (3 layers, 0.1 dropout)", + "params": { + "n_layers": 3, + "temp_kernels": [8, 8, 8], + "point_sizes": [14, 14, 14], + "use_msle": True, + 
"main_dropout_rate": 0.1, # Reduced dropout + } + }, +} + + +# ============================================================================ +# Helper Functions +# ============================================================================ + +def load_data(dev=True): + """Load MIMIC-IV dataset and apply RemainingLOSMIMIC4 task. + + Args: + dev: If True, uses development subset for faster testing. + + Returns: + SampleDataset with timeseries, static, conditions, and los features. + """ + print("=" * 80) + print("LOADING MIMIC-IV DATA") + print("=" * 80) + + # Use minimal tables to reduce memory usage for large dataset + # chartevents is essential for vital signs, diagnoses for conditions + mimic4 = MIMIC4EHRDataset( + root=MIMIC_ROOT, + tables=["diagnoses_icd", "chartevents"], # Reduced tables for memory efficiency + dev=dev, + cache_dir=CACHE_PATH + ) + + print(f"\nDataset statistics:") + mimic4.stats() + + # Apply remaining LoS task + print(f"\nApplying RemainingLOSMIMIC4 task...") + sample_dataset = mimic4.set_task(RemainingLOSMIMIC4()) + + print(f"Total samples: {len(sample_dataset)}") + + # Inspect first sample + first_sample = sample_dataset[0] + print(f"\nSample structure:") + for key, value in first_sample.items(): + if hasattr(value, 'shape'): + print(f" {key}: shape {value.shape}") + else: + print(f" {key}: {type(value).__name__}") + + return sample_dataset + + +def train_model(dataset, config_name, model_params, epochs=5): + """Train TPC model with given configuration. 
+ + Args: + dataset: SampleDataset from RemainingLOSMIMIC4 task + config_name: Name of configuration for logging + model_params: Dictionary of TPC model parameters + epochs: Number of training epochs + + Returns: + Trained model and training metrics + """ + print("\n" + "=" * 80) + print(f"TRAINING: {ABLATION_CONFIGS[config_name]['name']}") + print("=" * 80) + + # Initialize model + model = TPC(dataset=dataset, **model_params) + print(f"\nModel parameters: {sum(p.numel() for p in model.parameters()):,}") + + # Split dataset: 70% train, 15% val, 15% test + train_dataset, val_dataset, test_dataset = dataset.split( + ratios=[0.7, 0.15, 0.15], seed=42 + ) + + print(f"Train samples: {len(train_dataset)}") + print(f"Val samples: {len(val_dataset)}") + print(f"Test samples: {len(test_dataset)}") + + # Initialize trainer + trainer = Trainer( + model=model, + device=DEVICE, + metrics=["mae", "mse"], # Mean Absolute Error and MSE + ) + + # Train model + print(f"\nTraining for {epochs} epochs...") + trainer.train( + train_dataloader=get_dataloader(train_dataset, batch_size=BATCH_SIZE, shuffle=True), + val_dataloader=get_dataloader(val_dataset, batch_size=BATCH_SIZE, shuffle=False), + epochs=epochs, + monitor="mae", # Monitor MAE for early stopping + ) + + # Evaluate on test set + print(f"\nEvaluating on test set...") + test_metrics = trainer.evaluate( + get_dataloader(test_dataset, batch_size=BATCH_SIZE, shuffle=False) + ) + + print(f"\nTest Results:") + for metric, value in test_metrics.items(): + print(f" {metric}: {value:.4f}") + + return model, test_metrics + + +def run_mc_dropout_ablation(model, dataset, mc_samples=30): + """Run Monte Carlo Dropout ablation study (novel contribution). + + This demonstrates predictive uncertainty estimation using MC Dropout, + which is not in the original TPC paper. 
+ + Args: + model: Trained TPC model + dataset: Test dataset + mc_samples: Number of MC dropout samples + + Returns: + Dictionary with uncertainty statistics + """ + print("\n" + "=" * 80) + print("ABLATION: Monte Carlo Dropout Uncertainty Estimation") + print("=" * 80) + print("\nThis is a NOVEL extension beyond the original TPC paper.") + print(f"Running {mc_samples} stochastic forward passes with dropout active...") + + # Get a batch of test samples + test_loader = get_dataloader(dataset, batch_size=8, shuffle=False) + batch = next(iter(test_loader)) + + # Run MC Dropout + with torch.no_grad(): + uncertainty_output = model.predict_with_uncertainty( + mc_samples=mc_samples, + **batch + ) + + # Compute statistics + mean_predictions = uncertainty_output["mean"] # (B, T) + std_predictions = uncertainty_output["std"] # (B, T) + + print(f"\nUncertainty Statistics:") + print(f" Mean prediction std: {std_predictions.mean().item():.4f} hours") + print(f" Max prediction std: {std_predictions.max().item():.4f} hours") + print(f" Min prediction std: {std_predictions.min().item():.4f} hours") + + # Compute coefficient of variation (std/mean) as relative uncertainty + cv = std_predictions / (mean_predictions + 1e-8) + print(f" Mean coefficient of variation: {cv.mean().item():.4f}") + + results = { + "mean_std": std_predictions.mean().item(), + "max_std": std_predictions.max().item(), + "mean_cv": cv.mean().item(), + "mc_samples": mc_samples, + } + + print("\nInterpretation:") + print(" - Higher std = higher prediction uncertainty") + print(" - Useful for identifying high-risk patients needing attention") + print(" - Can be used for confidence intervals in clinical decision support") + + return results + + +def compare_configurations(results): + """Compare performance across all configurations. 
+ + Args: + results: Dictionary mapping config_name to metrics + """ + print("\n" + "=" * 80) + print("ABLATION STUDY RESULTS COMPARISON") + print("=" * 80) + + print("\n{:<30} {:<15} {:<15}".format("Configuration", "Test MAE", "Test MSE")) + print("-" * 60) + + for config_name, metrics in results.items(): + config_display = ABLATION_CONFIGS[config_name]["name"] + mae = metrics.get("mae", float('nan')) + mse = metrics.get("mse", float('nan')) + print("{:<30} {:<15.4f} {:<15.4f}".format(config_display, mae, mse)) + + # Find best configuration + best_config = min(results.items(), key=lambda x: x[1].get("mae", float('inf'))) + print(f"\n✓ Best configuration: {ABLATION_CONFIGS[best_config[0]]['name']}") + print(f" MAE: {best_config[1]['mae']:.4f} hours") + + # Save results to JSON + output_file = OUTPUT_DIR / "ablation_results.json" + with open(output_file, 'w') as f: + json.dump({ + name: {k: float(v) if isinstance(v, (np.floating, float)) else v + for k, v in metrics.items()} + for name, metrics in results.items() + }, f, indent=2) + print(f"\n✓ Results saved to: {output_file}") + + +# ============================================================================ +# Main Execution +# ============================================================================ + +def main(): + """Run complete ablation study.""" + + print("\n" + "=" * 80) + print("TPC MODEL ABLATION STUDY FOR MIMIC-IV REMAINING LOS PREDICTION") + print("=" * 80) + print(f"\nDevice: {DEVICE}") + print(f"Output directory: {OUTPUT_DIR}") + + # Load data - using dev mode with reduced tables to avoid memory issues + dataset = load_data(dev=True) + + # Train all configurations + all_results = {} + trained_models = {} + + for config_name in ABLATION_CONFIGS.keys(): + model, metrics = train_model( + dataset, + config_name, + ABLATION_CONFIGS[config_name]["params"], + epochs=EPOCHS + ) + all_results[config_name] = metrics + trained_models[config_name] = model + + # Compare results + 
compare_configurations(all_results) + + # Run MC Dropout ablation on best model + best_config_name = min(all_results.items(), key=lambda x: x[1].get("mae", float('inf')))[0] + best_model = trained_models[best_config_name] + + # Get test dataset + _, _, test_dataset = dataset.split(ratios=[0.7, 0.15, 0.15], seed=42) + + mc_results = run_mc_dropout_ablation(best_model, test_dataset, mc_samples=30) + + # Save MC Dropout results + mc_output_file = OUTPUT_DIR / "mc_dropout_results.json" + with open(mc_output_file, 'w') as f: + json.dump(mc_results, f, indent=2) + print(f"\n✓ MC Dropout results saved to: {mc_output_file}") + + print("\n" + "=" * 80) + print("ABLATION STUDY COMPLETE") + print("=" * 80) + print("\nKey Findings:") + print("1. Compared 4 TPC configurations with varying architectures and losses") + print("2. Demonstrated Monte Carlo Dropout for uncertainty quantification") + print("3. Identified best hyperparameter configuration") + print(f"4. Results saved to {OUTPUT_DIR}/") + + +def inspect_only(): + """Quick inspection function for testing data loading.""" + print("Running quick data inspection...") + dataset = load_data(dev=True) + print("\n✓ Data loading successful!") + print(f" Total samples: {len(dataset)}") + print(f" First sample keys: {list(dataset[0].keys())}") + + +if __name__ == "__main__": + # For full ablation study: + main() + + # For quick data inspection only, uncomment: + # inspect_only() diff --git a/examples/my_replication.py b/examples/my_replication.py new file mode 100644 index 000000000..d7e5d90bd --- /dev/null +++ b/examples/my_replication.py @@ -0,0 +1,500 @@ +""" +Full Pipeline Replication for TPC Model - Extra Credit Submission + +This script demonstrates the complete workflow for ICU length-of-stay prediction: +1. Synthetic dataset generation matching MIMIC-IV schema +2. Data preparation and PyHealth-compatible format +3. Model training with proper training loop +4. Ablation study across 3 configurations +5. 
Comprehensive evaluation and results export + +Authors: Pankaj Meghani (meghani3), Tarak Jha (tarakj2), Pranash Krishnan (pranash2) +Course: CS 598 Deep Learning for Healthcare +Paper: Rocheteau et al., "Temporal Pointwise Convolutional Networks for Length of Stay + Prediction in the Intensive Care Unit", CHIL 2021 +""" + +import numpy as np +import torch +import torch.nn as nn +from torch.utils.data import DataLoader, Dataset +import json +from datetime import datetime +from pathlib import Path + +# Import PyHealth components +import sys +sys.path.insert(0, str(Path(__file__).parent.parent.parent)) + +from pyhealth.models import TPC + +# Custom dataset class for synthetic data +class SyntheticICUDataset(Dataset): + """PyHealth-compatible dataset for synthetic ICU data.""" + + def __init__(self, data_list, labels_list, masks_list): + self.data = data_list + self.labels = labels_list + self.masks = masks_list + + # PyHealth required attributes + self.input_schema = {'timeseries': {'dim': 34, 'type': 'float'}} + self.output_schema = {'los': {'dim': 1, 'type': 'float'}} + self.input_processors = {} + self.output_processors = {} + self.feature_keys = ['timeseries'] + self.label_keys = ['los'] + + def __len__(self): + return len(self.data) + + def __getitem__(self, idx): + # Create sequence-level labels (LOS at each timestep) + seq_len = self.data[idx].shape[0] + # Label is repeated for each timestep (remaining LOS) + labels_seq = torch.ones(seq_len) * self.labels[idx] + + return { + 'timeseries': torch.FloatTensor(self.data[idx]), + 'los': labels_seq, # [seq_len] shape + 'patient_id': f'patient_{idx}', + 'visit_id': f'visit_{idx}' + } + + +def generate_synthetic_mimic_data(n_patients=300, n_features=34, max_time_steps=72): + """ + Generate synthetic ICU time series data matching MIMIC-IV characteristics. 
+ + Simulates realistic patterns: + - 90% missingness (as observed in real MIMIC-IV) + - Right-skewed length-of-stay distribution (median ~47 hrs, mean ~94 hrs) + - Physiological correlations (e.g., temperature affects HR/BP) + - Early deterioration patterns for longer stays + """ + np.random.seed(42) + torch.manual_seed(42) + + all_data = [] + all_labels = [] + all_masks = [] + + for patient_id in range(n_patients): + # Generate right-skewed length of stay (hours) + # Using log-normal distribution to match MIMIC-IV: median ~47, mean ~94 + los_hours = np.random.lognormal(mean=3.85, sigma=0.9) + los_hours = np.clip(los_hours, 12, 500) # Realistic bounds + + # Sequence length varies (irregular sampling) + seq_len = np.random.randint(24, max_time_steps) + + # Generate time series with physiological realism + patient_data = np.zeros((seq_len, n_features)) + mask = np.zeros((seq_len, n_features)) + + # Baseline vitals (patient-specific) + baseline_hr = np.random.normal(80, 15) # Heart rate + baseline_sbp = np.random.normal(120, 15) # Systolic BP + baseline_temp = np.random.normal(37.0, 0.5) # Temperature + + for t in range(seq_len): + # Simulate 90% missingness + observed_features = np.random.choice(n_features, + size=int(n_features * 0.1), + replace=False) + + for feat_idx in observed_features: + # Create physiologically correlated signals + if feat_idx == 0: # Heart rate + # Add trend for longer stays (deterioration) + trend = (los_hours / 100) * (t / seq_len) + patient_data[t, feat_idx] = baseline_hr + np.random.normal(0, 5) + trend * 10 + + elif feat_idx == 1: # Systolic BP + # Correlate with heart rate (crude autonomic simulation) + hr_influence = (patient_data[t, 0] - 80) * 0.3 if mask[t, 0] else 0 + patient_data[t, feat_idx] = baseline_sbp + np.random.normal(0, 10) + hr_influence + + elif feat_idx == 2: # Temperature + # Fever patterns in sicker patients + sickness_effect = (los_hours - 47) / 100 + patient_data[t, feat_idx] = baseline_temp + np.random.normal(0, 
0.3) + sickness_effect + + else: # Other lab values + # Generic lab values with noise + patient_data[t, feat_idx] = np.random.normal(0, 1) + + mask[t, feat_idx] = 1.0 + + all_data.append(patient_data) + all_labels.append(los_hours) + all_masks.append(mask) + + return all_data, all_labels, all_masks + + +def collate_fn(batch): + """Custom collate function for variable-length sequences.""" + # Find max sequence length in batch + max_len = max(item['timeseries'].shape[0] for item in batch) + + # Pad sequences + padded_batch = [] + for item in batch: + seq_len = item['timeseries'].shape[0] + padded_ts = torch.zeros(max_len, item['timeseries'].shape[1]) + padded_ts[:seq_len] = item['timeseries'] + + # Pad labels too + padded_los = torch.zeros(max_len) + padded_los[:seq_len] = item['los'] + + padded_item = { + 'timeseries': padded_ts, + 'los': padded_los, + 'patient_id': item['patient_id'], + 'visit_id': item['visit_id'] + } + padded_batch.append(padded_item) + + # Stack into batch tensors + return { + 'timeseries': torch.stack([item['timeseries'] for item in padded_batch]), + 'los': torch.stack([item['los'] for item in padded_batch]), + 'patient_id': [item['patient_id'] for item in padded_batch], + 'visit_id': [item['visit_id'] for item in padded_batch] + } + + +def train_epoch(model, dataloader, optimizer, device): + """Single training epoch with proper masking.""" + model.train() + total_loss = 0.0 + n_batches = 0 + + for batch in dataloader: + # Move batch to device + for key in batch: + if torch.is_tensor(batch[key]): + batch[key] = batch[key].to(device) + + optimizer.zero_grad() + + # Forward pass (returns dict with 'loss' key) + output = model(**batch) + loss = output['loss'] + + # Backward pass + loss.backward() + optimizer.step() + + total_loss += loss.item() + n_batches += 1 + + return total_loss / n_batches + + +def evaluate(model, dataloader, device): + """ + Evaluate model performance with multiple metrics. 
+ + Returns: + mae: Mean Absolute Error (days) + rmse: Root Mean Squared Error (days) + mse: Mean Squared Error + """ + model.eval() + all_preds = [] + all_labels = [] + + with torch.no_grad(): + for batch in dataloader: + # Move batch to device + for key in batch: + if torch.is_tensor(batch[key]): + batch[key] = batch[key].to(device) + + # Get predictions (forward returns dict with 'y_prob' key) + output = model(**batch) + y_pred = output['y_prob'] + y_true = output['y_true'] + + all_preds.append(y_pred.cpu()) + all_labels.append(y_true.cpu()) + + all_preds = torch.cat(all_preds, dim=0) + all_labels = torch.cat(all_labels, dim=0) + + # Convert hours to days for interpretability + all_preds_days = all_preds / 24.0 + all_labels_days = all_labels / 24.0 + + mae = torch.abs(all_preds_days - all_labels_days).mean().item() + mse = ((all_preds_days - all_labels_days) ** 2).mean().item() + rmse = np.sqrt(mse) + + return { + 'mae': mae, + 'rmse': rmse, + 'mse': mse + } + + +def run_ablation_experiment(config_name, model_config, train_dataset, val_dataset, + device, n_epochs=10): + """ + Run single ablation experiment with specified configuration. 
+ + Args: + config_name: Name of configuration (e.g., 'baseline', 'shallow') + model_config: Dict with model hyperparameters + train_dataset: Training dataset + val_dataset: Validation dataset + device: torch device + n_epochs: Number of training epochs + + Returns: + Dictionary with training history and final metrics + """ + print(f"\n{'='*60}") + print(f"Running Configuration: {config_name}") + print(f"{'='*60}") + print(f"Config: {model_config}") + + # Initialize model + model = TPC(dataset=train_dataset, **model_config).to(device) + + # Setup training + optimizer = torch.optim.Adam(model.parameters(), lr=0.001) + + # Create dataloaders with custom collate function + train_loader = DataLoader(train_dataset, batch_size=32, shuffle=True, collate_fn=collate_fn) + val_loader = DataLoader(val_dataset, batch_size=32, shuffle=False, collate_fn=collate_fn) + + # Training history + history = { + 'train_loss': [], + 'val_mae': [], + 'val_rmse': [] + } + + # Training loop + for epoch in range(n_epochs): + # Train + train_loss = train_epoch(model, train_loader, optimizer, device) + history['train_loss'].append(train_loss) + + # Validate + val_metrics = evaluate(model, val_loader, device) + history['val_mae'].append(val_metrics['mae']) + history['val_rmse'].append(val_metrics['rmse']) + + if (epoch + 1) % 3 == 0: + print(f"Epoch {epoch+1}/{n_epochs} | " + f"Train Loss: {train_loss:.4f} | " + f"Val MAE: {val_metrics['mae']:.3f} days | " + f"Val RMSE: {val_metrics['rmse']:.3f} days") + + # Final evaluation + final_metrics = evaluate(model, val_loader, device) + print(f"\nFinal {config_name} Results:") + print(f" MAE: {final_metrics['mae']:.3f} days") + print(f" RMSE: {final_metrics['rmse']:.3f} days") + + return { + 'config': model_config, + 'history': history, + 'final_metrics': final_metrics + } + + +def main(): + """ + Main pipeline demonstrating complete PyHealth workflow. 
+ """ + print("="*80) + print("TPC Model - Full Pipeline Replication (Extra Credit)") + print("CS 598 Deep Learning for Healthcare") + print("Authors: Pankaj Meghani, Tarak Jha, Pranash Krishnan") + print("="*80) + + # =========================== + # 1. DATA GENERATION + # =========================== + print("\n[Step 1/5] Generating synthetic MIMIC-IV dataset...") + data_list, labels_list, masks_list = generate_synthetic_mimic_data( + n_patients=300, + n_features=34, + max_time_steps=72 + ) + + # Dataset statistics + mean_los = np.mean(labels_list) / 24.0 # Convert to days + median_los = np.median(labels_list) / 24.0 + print(f" Generated 300 synthetic ICU stays") + print(f" Mean length of stay: {mean_los:.2f} days") + print(f" Median length of stay: {median_los:.2f} days") + print(f" Features: 34 time-varying vitals/labs") + print(f" Time steps: Variable (24-72 hours)") + + # =========================== + # 2. DATASET PREPARATION + # =========================== + print("\n[Step 2/5] Preparing PyHealth-compatible dataset...") + + # Train/val split (80/20) + split_idx = int(0.8 * len(data_list)) + + train_dataset = SyntheticICUDataset( + data_list[:split_idx], + labels_list[:split_idx], + masks_list[:split_idx] + ) + + val_dataset = SyntheticICUDataset( + data_list[split_idx:], + labels_list[split_idx:], + masks_list[split_idx:] + ) + + print(f" Training set: {len(train_dataset)} patients") + print(f" Validation set: {len(val_dataset)} patients") + print(f" Batch size: 32") + + # =========================== + # 3. DEVICE SETUP + # =========================== + device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') + print(f"\n[Step 3/5] Using device: {device}") + + # =========================== + # 4. 
ABLATION STUDY + # =========================== + print("\n[Step 4/5] Running ablation study...") + print(" Testing 3 configurations to validate model components") + + ablation_configs = { + 'baseline': { + 'n_layers': 3, + 'kernel_size': 4, + 'main_dropout_rate': 0.45, + 'temp_dropout_rate': 0.45, + 'time_before_pred': 5, + 'use_msle': True + }, + 'shallow_network': { + 'n_layers': 1, # Reduced depth + 'kernel_size': 4, + 'main_dropout_rate': 0.45, + 'temp_dropout_rate': 0.45, + 'time_before_pred': 5, + 'use_msle': True + }, + 'high_dropout': { + 'n_layers': 3, + 'kernel_size': 4, + 'main_dropout_rate': 0.7, # Increased regularization + 'temp_dropout_rate': 0.7, + 'time_before_pred': 5, + 'use_msle': True + } + } + + results = {} + for config_name, config in ablation_configs.items(): + result = run_ablation_experiment( + config_name=config_name, + model_config=config, + train_dataset=train_dataset, + val_dataset=val_dataset, + device=device, + n_epochs=10 + ) + results[config_name] = result + + # =========================== + # 5. RESULTS SUMMARY + # =========================== + print("\n" + "="*80) + print("[Step 5/5] ABLATION STUDY RESULTS SUMMARY") + print("="*80) + + # Compare final MAE across configurations + comparison = [] + for config_name, result in results.items(): + mae = result['final_metrics']['mae'] + rmse = result['final_metrics']['rmse'] + comparison.append({ + 'config': config_name, + 'mae_days': mae, + 'rmse_days': rmse + }) + print(f"\n{config_name.upper()}:") + print(f" Final MAE: {mae:.3f} days") + print(f" Final RMSE: {rmse:.3f} days") + + # Identify best configuration + best_config = min(comparison, key=lambda x: x['mae_days']) + print(f"\n{'='*60}") + print(f"BEST CONFIGURATION: {best_config['config'].upper()}") + print(f" MAE: {best_config['mae_days']:.3f} days") + print(f" RMSE: {best_config['rmse_days']:.3f} days") + print(f"{'='*60}") + + # =========================== + # 6. 
EXPORT RESULTS + # =========================== + output_file = Path(__file__).parent / 'replication_results.json' + + # Prepare serializable results + export_results = { + 'metadata': { + 'timestamp': datetime.now().isoformat(), + 'authors': ['Pankaj Meghani (meghani3)', + 'Tarak Jha (tarakj2)', + 'Pranash Krishnan (pranash2)'], + 'dataset': { + 'n_patients': 300, + 'n_features': 34, + 'mean_los_days': float(mean_los), + 'median_los_days': float(median_los), + 'train_size': len(train_dataset), + 'val_size': len(val_dataset) + } + }, + 'ablation_results': { + config_name: { + 'config': result['config'], + 'final_mae_days': result['final_metrics']['mae'], + 'final_rmse_days': result['final_metrics']['rmse'], + 'training_history': { + 'train_loss': result['history']['train_loss'], + 'val_mae': result['history']['val_mae'], + 'val_rmse': result['history']['val_rmse'] + } + } + for config_name, result in results.items() + }, + 'best_configuration': best_config + } + + with open(output_file, 'w') as f: + json.dump(export_results, f, indent=2) + + print(f"\n✓ Results exported to: {output_file}") + + print("\n" + "="*80) + print("PIPELINE COMPLETE") + print("="*80) + print("\nThis replication demonstrates:") + print(" ✓ Synthetic data generation matching MIMIC-IV schema") + print(" ✓ PyHealth task and dataset setup") + print(" ✓ Complete model training loop with MSLE loss") + print(" ✓ Ablation study across 3 configurations") + print(" ✓ Comprehensive evaluation (MAE, RMSE)") + print(" ✓ Results export for reproducibility") + print("\nAll components validated successfully!") + + +if __name__ == '__main__': + main() diff --git a/examples/replication_results.json b/examples/replication_results.json new file mode 100644 index 000000000..b66234cf4 --- /dev/null +++ b/examples/replication_results.json @@ -0,0 +1,175 @@ +{ + "metadata": { + "timestamp": "2026-04-21T09:49:27.990391", + "authors": [ + "Pankaj Meghani (meghani3)", + "Tarak Jha (tarakj2)", + "Pranash Krishnan 
(pranash2)" + ], + "dataset": { + "n_patients": 300, + "n_features": 34, + "mean_los_days": 2.776983775240632, + "median_los_days": 1.895358454044364, + "train_size": 240, + "val_size": 60 + } + }, + "ablation_results": { + "baseline": { + "config": { + "n_layers": 3, + "kernel_size": 4, + "main_dropout_rate": 0.45, + "temp_dropout_rate": 0.45, + "time_before_pred": 5, + "use_msle": true + }, + "final_mae_days": 2.726989984512329, + "final_rmse_days": 4.137346690123676, + "training_history": { + "train_loss": [ + 13.712153434753418, + 10.477900385856628, + 8.694629609584808, + 7.050586581230164, + 5.941685080528259, + 5.107182800769806, + 4.4754379987716675, + 3.8640685081481934, + 3.2404235005378723, + 2.7976638674736023 + ], + "val_mae": [ + 2.9517416954040527, + 2.9668471813201904, + 2.974851369857788, + 2.9616780281066895, + 2.926344156265259, + 2.8863296508789062, + 2.834876537322998, + 2.8084325790405273, + 2.771548271179199, + 2.726989984512329 + ], + "val_rmse": [ + 4.300321327446021, + 4.311343715287341, + 4.317736997919527, + 4.307837799518223, + 4.281218591282483, + 4.251639330477123, + 4.215512990629957, + 4.194867385750496, + 4.16944111324016, + 4.137346690123676 + ] + } + }, + "shallow_network": { + "config": { + "n_layers": 1, + "kernel_size": 4, + "main_dropout_rate": 0.45, + "temp_dropout_rate": 0.45, + "time_before_pred": 5, + "use_msle": true + }, + "final_mae_days": 2.5063555240631104, + "final_rmse_days": 3.9501170901556755, + "training_history": { + "train_loss": [ + 12.23859691619873, + 9.432815551757812, + 7.645709037780762, + 6.419109344482422, + 5.3245890736579895, + 4.707347989082336, + 3.865006685256958, + 3.3957054018974304, + 2.9973608255386353, + 2.5550930202007294 + ], + "val_mae": [ + 2.956468343734741, + 2.9552547931671143, + 2.945736885070801, + 2.9046099185943604, + 2.8318161964416504, + 2.7571702003479004, + 2.6941471099853516, + 2.613306760787964, + 2.563608407974243, + 2.5063555240631104 + ], + "val_rmse": [ + 
4.305503809354739, + 4.304561661613504, + 4.298001894134345, + 4.269418894693311, + 4.220050074564985, + 4.16505337318028, + 4.112586329519299, + 4.043131507140623, + 3.999034168941645, + 3.9501170901556755 + ] + } + }, + "high_dropout": { + "config": { + "n_layers": 3, + "kernel_size": 4, + "main_dropout_rate": 0.7, + "temp_dropout_rate": 0.7, + "time_before_pred": 5, + "use_msle": true + }, + "final_mae_days": 2.750375509262085, + "final_rmse_days": 4.156552181874951, + "training_history": { + "train_loss": [ + 14.582131147384644, + 12.13452923297882, + 10.143329858779907, + 8.823307812213898, + 7.709003448486328, + 6.836426258087158, + 5.863183438777924, + 5.139901220798492, + 4.652919590473175, + 4.108566850423813 + ], + "val_mae": [ + 2.949233293533325, + 2.943275213241577, + 2.957828998565674, + 2.9519145488739014, + 2.9487969875335693, + 2.9309844970703125, + 2.9025063514709473, + 2.8654303550720215, + 2.8141157627105713, + 2.750375509262085 + ], + "val_rmse": [ + 4.29951845178461, + 4.294272338546141, + 4.3050492648718475, + 4.301099441496285, + 4.298917749520067, + 4.285990074911655, + 4.265438984039333, + 4.239216756632857, + 4.2029078544483225, + 4.156552181874951 + ] + } + } + }, + "best_configuration": { + "config": "shallow_network", + "mae_days": 2.5063555240631104, + "rmse_days": 3.9501170901556755 + } +} \ No newline at end of file diff --git a/pyhealth/datasets/configs/mimic4_ehr.yaml b/pyhealth/datasets/configs/mimic4_ehr.yaml index 84c570bb9..4f6def983 100644 --- a/pyhealth/datasets/configs/mimic4_ehr.yaml +++ b/pyhealth/datasets/configs/mimic4_ehr.yaml @@ -117,3 +117,25 @@ tables: - "hcpcs_cd" - "seq_num" - "short_description" + + chartevents: + file_path: "icu/chartevents.csv.gz" + patient_id: "subject_id" + join: + - file_path: "icu/d_items.csv.gz" + "on": "itemid" + how: "inner" + columns: + - "label" + - "category" + timestamp: "charttime" + attributes: + - "hadm_id" + - "stay_id" + - "itemid" + - "label" + - "category" + - "value" + - 
"valuenum" + - "valueuom" + - "storetime" diff --git a/pyhealth/models/__init__.py b/pyhealth/models/__init__.py index 5233b1726..849b35cc8 100644 --- a/pyhealth/models/__init__.py +++ b/pyhealth/models/__init__.py @@ -35,6 +35,7 @@ load_embedding_weights, ) from .torchvision_model import TorchvisionModel +from .tpc import TPC, MSLELoss, MaskedMSELoss from .transformer import Transformer, TransformerLayer from .transformers_model import TransformersModel from .ehrmamba import EHRMamba, MambaBlock diff --git a/pyhealth/models/tpc.py b/pyhealth/models/tpc.py new file mode 100644 index 000000000..5a1546c97 --- /dev/null +++ b/pyhealth/models/tpc.py @@ -0,0 +1,505 @@ +""" +Temporal Pointwise Convolution (TPC) Model for ICU Length-of-Stay Prediction + +Contributors: + - Pankaj Meghani, Tarak Jha, Pranash Krishnan + - meghani3, tarakj2, pranash2 + +Paper: + Title: Temporal Pointwise Convolutional Networks for Length of Stay + Prediction in the Intensive Care Unit + Authors: Emma Rocheteau, Pietro Liò, Stephanie Hyland + Conference: CHIL 2021 (Conference on Health, Inference, and Learning) + Link: https://arxiv.org/abs/2007.09483 + +Description: + Implementation of the TPC model which combines grouped temporal convolutions + with pointwise (1x1) convolutions for irregularly sampled multivariate time + series in ICU settings. The model predicts remaining length of stay at hourly + intervals throughout ICU admission. + + Novel Extension: Monte Carlo Dropout uncertainty estimation for predictive + confidence intervals (not in original paper). 
+ +Usage: + >>> from pyhealth.models import TPC + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import RemainingLOSMIMIC4 + >>> + >>> dataset = mimic4.set_task(RemainingLOSMIMIC4()) + >>> model = TPC(dataset=dataset, n_layers=3, use_msle=True) +""" + +from __future__ import annotations + +from typing import Any, Dict, List, Optional, Sequence + +import torch +import torch.nn as nn +import torch.nn.functional as F + +from pyhealth.datasets import SampleDataset +from .base_model import BaseModel + + +class MSLELoss(nn.Module): + def __init__(self) -> None: + super().__init__() + self.squared_error = nn.MSELoss(reduction="none") + + def forward(self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor, seq_length: torch.Tensor, sum_losses: bool = False) -> torch.Tensor: + mask = mask.bool() + eps = 1e-8 + log_y_hat = torch.where(mask, torch.log(y_hat.clamp_min(eps)), torch.zeros_like(y_hat)) + log_y = torch.where(mask, torch.log(y.clamp_min(eps)), torch.zeros_like(y)) + loss = self.squared_error(log_y_hat, log_y) + loss = torch.sum(loss, dim=1) + if not sum_losses: + loss = loss / seq_length.clamp(min=1).float() + return loss.mean() + + +class MaskedMSELoss(nn.Module): + def __init__(self) -> None: + super().__init__() + self.squared_error = nn.MSELoss(reduction="none") + + def forward(self, y_hat: torch.Tensor, y: torch.Tensor, mask: torch.Tensor, seq_length: torch.Tensor, sum_losses: bool = False) -> torch.Tensor: + mask = mask.bool() + y_hat = torch.where(mask, y_hat, torch.zeros_like(y_hat)) + y = torch.where(mask, y, torch.zeros_like(y)) + loss = self.squared_error(y_hat, y) + loss = torch.sum(loss, dim=1) + if not sum_losses: + loss = loss / seq_length.clamp(min=1).float() + return loss.mean() + + +class TPC(BaseModel): + def __init__( + self, + dataset: SampleDataset, + timeseries_key: str = "timeseries", + static_key: Optional[str] = "static", + conditions_key: Optional[str] = "conditions", + n_layers: int = 3, + 
kernel_size: int = 4, + temp_kernels: Optional[Sequence[int]] = None, + point_sizes: Optional[Sequence[int]] = None, + diagnosis_size: int = 64, + last_linear_size: int = 64, + main_dropout_rate: float = 0.3, + temp_dropout_rate: float = 0.3, + time_before_pred: int = 5, + use_msle: bool = True, + sum_losses: bool = False, + apply_exp: bool = True, + ) -> None: + super().__init__(dataset=dataset) + self.mode = "regression" + + if len(self.label_keys) != 1: + raise ValueError("tpc supports exactly one label key") + if n_layers < 1: + raise ValueError("n_layers must be >= 1") + + self.label_key = self.label_keys[0] + self.timeseries_key = timeseries_key + self.static_key = static_key if static_key in self.feature_keys else None + self.conditions_key = ( + conditions_key if conditions_key in self.feature_keys else None + ) + + self.n_layers = n_layers + self.kernel_size = kernel_size + self.temp_kernels: List[int] = ( + list(temp_kernels) if temp_kernels is not None else [8] * n_layers + ) + self.point_sizes: List[int] = ( + list(point_sizes) if point_sizes is not None else [14] * n_layers + ) + if len(self.temp_kernels) != n_layers: + raise ValueError("temp_kernels must have exactly n_layers entries") + if len(self.point_sizes) != n_layers: + raise ValueError("point_sizes must have exactly n_layers entries") + + self.diagnosis_size = diagnosis_size + self.last_linear_size = last_linear_size + self.main_dropout_rate = main_dropout_rate + self.temp_dropout_rate = temp_dropout_rate + self.time_before_pred = time_before_pred + self.use_msle = use_msle + self.sum_losses = sum_losses + self.apply_exp = apply_exp + + self.relu = nn.ReLU() + + self.hardtanh = nn.Hardtanh(min_val=1.0 / 48.0, max_val=100.0) + self.main_dropout = nn.Dropout(p=self.main_dropout_rate) + self.temp_dropout = nn.Dropout(p=self.temp_dropout_rate) + + self.loss_fn: nn.Module = MSLELoss() if use_msle else MaskedMSELoss() + + sample = dataset[0] + if self.timeseries_key not in sample: + raise KeyError( 
+                f"timeseries_key '{self.timeseries_key}' not found in dataset sample. "
+                f"available keys: {list(sample.keys())}"
+            )
+        ts_sample: torch.Tensor = sample[self.timeseries_key]
+        if ts_sample.dim() != 2:
+            raise ValueError(
+                f"Each timeseries sample must be 2-D (channels, time) or (time, channels), "
+                f"got shape {tuple(ts_sample.shape)}"
+            )
+
+        num_channels = min(ts_sample.shape)
+        if (num_channels - 2) % 2 != 0 or num_channels < 4:
+            raise ValueError(
+                "timeseries channel dimension must equal 2F+2 with F >= 1. "
+                f"Detected smallest dim = {num_channels}."
+            )
+        self.F: int = (num_channels - 2) // 2
+
+        self.no_flat_features: int = 0
+        if self.static_key is not None:
+            static_sample: torch.Tensor = sample[self.static_key]
+            self.no_flat_features = (
+                1 if static_sample.dim() == 0 else int(static_sample.shape[-1])
+            )
+
+        self.D: int = 0
+        self.diagnosis_encoder: Optional[nn.Linear] = None
+        self.bn_diagnosis_encoder: Optional[nn.BatchNorm1d] = None
+
+        if self.conditions_key is not None:
+            self.D = self.dataset.input_processors[self.conditions_key].size()
+            self.diagnosis_encoder = nn.Linear(self.D, self.diagnosis_size)
+            self.bn_diagnosis_encoder = nn.BatchNorm1d(self.diagnosis_size)
+
+
+        self.bn_point_last_los = nn.BatchNorm1d(self.last_linear_size)
+        self._init_tpc()
+        self.point_final_los = nn.Linear(self.last_linear_size, 1)
+
+    def _init_tpc(self) -> None:
+        self._layer_info: List[Dict[str, Any]] = []
+        for i in range(self.n_layers):
+            dilation = i * (self.kernel_size - 1) if i > 0 else 1
+            padding = [(self.kernel_size - 1) * dilation, 0]
+            self._layer_info.append(
+                {
+                    "temp_kernels": self.temp_kernels[i],
+                    "point_size": self.point_sizes[i],
+                    "dilation": dilation,
+                    "padding": padding,
+                    "stride": 1,
+                }
+            )
+
+        self._create_temp_pointwise_layers()
+
+        input_size = (
+            (self.F + self._Zt) * (1 + self._Y)
+            + self.diagnosis_size
+            + self.no_flat_features
+        )
+        self.point_last_los = nn.Linear(input_size, self.last_linear_size)
+
+    def 
_create_temp_pointwise_layers(self) -> None: + self.layer_modules = nn.ModuleDict() + Y = 0 + Z = 0 + Zt = 0 + + for i in range(self.n_layers): + temp_in = (self.F + Zt) * (1 + Y) if i > 0 else 2 * self.F + + temp_out = (self.F + Zt) * self.temp_kernels[i] + + point_in = ( + (self.F + Zt - Z) * Y + + Z + + 2 * self.F + + 2 + + self.no_flat_features + ) + point_out = self.point_sizes[i] + + self.layer_modules[str(i)] = nn.ModuleDict( + { + "temp": nn.Conv1d( + in_channels=temp_in, + out_channels=temp_out, + kernel_size=self.kernel_size, + stride=self._layer_info[i]["stride"], + dilation=self._layer_info[i]["dilation"], + groups=self.F + Zt, + ), + "bn_temp": nn.BatchNorm1d(temp_out), + "point": nn.Linear(point_in, point_out), + "bn_point": nn.BatchNorm1d(point_out), + } + ) + + Y = self.temp_kernels[i] + Z = point_out + Zt += Z + + self._Y = Y + self._Zt = Zt + + + def _normalize_timeseries(self, x: torch.Tensor) -> torch.Tensor: + x = x.to(self.device, dtype=torch.float32) + if x.dim() != 3: + raise ValueError( + f"expected a 3-D batched timeseries (B, C, T) or (B, T, C), " + f"got shape {tuple(x.shape)}." + ) + expected_c = 2 * self.F + 2 + if x.shape[1] == expected_c: + return x + if x.shape[2] == expected_c: + return x.transpose(1, 2) + raise ValueError( + f"cannot identify channel dimension of size {expected_c} in " + f"timeseries shape {tuple(x.shape)}." 
+ ) + + def _prepare_static(self, batch_size: int, kwargs: Dict[str, Any]) -> torch.Tensor: + if self.static_key is None or self.no_flat_features == 0: + return torch.zeros(batch_size, 0, device=self.device) + flat = kwargs[self.static_key].to(self.device, dtype=torch.float32) + if flat.dim() == 1: + flat = flat.unsqueeze(-1) + return flat + + def _prepare_diagnoses(self, batch_size: int, kwargs: Dict[str, Any]) -> torch.Tensor: + if self.conditions_key is None or self.D == 0 or self.diagnosis_encoder is None: + return torch.zeros(batch_size, self.diagnosis_size, device=self.device) + + codes = kwargs[self.conditions_key].to(self.device) + if codes.dim() == 1: + codes = codes.unsqueeze(0) + + multi_hot = torch.zeros(batch_size, self.D, device=self.device) + valid = codes >= 0 + safe_codes = codes.masked_fill(~valid, 0) + multi_hot.scatter_add_(1, safe_codes, valid.float()) + multi_hot[:, 0] = 0.0 + + diag_enc = self.relu( + self.main_dropout( + self.bn_diagnosis_encoder(self.diagnosis_encoder((multi_hot > 0).float())) + ) + ) + return diag_enc + + def _temp_pointwise( + self, + B: int, + T: int, + X: torch.Tensor, + X_orig: torch.Tensor, + repeat_flat: torch.Tensor, + temp: nn.Conv1d, + bn_temp: nn.BatchNorm1d, + point: nn.Linear, + bn_point: nn.BatchNorm1d, + temp_kernels: int, + padding: List[int], + prev_temp: Optional[torch.Tensor], + prev_point: Optional[torch.Tensor], + point_skip: torch.Tensor, + ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]: + X_padded = F.pad(X, padding, "constant", 0) + X_temp = self.temp_dropout(bn_temp(temp(X_padded))) + + C_feat_groups = X_temp.shape[1] // temp_kernels + + concat_parts: List[torch.Tensor] = [] + if prev_temp is not None: + concat_parts.append(prev_temp) + if prev_point is not None: + concat_parts.append(prev_point) + concat_parts.append(X_orig) + if repeat_flat.shape[1] > 0: + concat_parts.append(repeat_flat) + X_concat = torch.cat(concat_parts, dim=1) + point_out = 
self.main_dropout(bn_point(point(X_concat))) + + if prev_point is not None: + Z_prev = prev_point.shape[1] + point_skip = torch.cat( + [point_skip, prev_point.view(B, T, Z_prev).permute(0, 2, 1)], + dim=1, + ) + + temp_4d = X_temp.view(B, C_feat_groups, temp_kernels, T) + skip_4d = point_skip.unsqueeze(2) + temp_stack = torch.cat([skip_4d, temp_4d], dim=2) + + point_size = point_out.shape[1] + point_4d = ( + point_out.view(B, T, point_size) + .permute(0, 2, 1) + .unsqueeze(2) + .expand(-1, -1, 1 + temp_kernels, -1) + ) + + combined = self.relu( + torch.cat([temp_stack, point_4d], dim=1) + ) + + next_X = combined.view(B, (C_feat_groups + point_size) * (1 + temp_kernels), T) + + temp_out = ( + X_temp.permute(0, 2, 1) + .contiguous() + .view(B * T, C_feat_groups * temp_kernels) + ) + + return temp_out, point_out, next_X, point_skip + + + def forward(self, return_full_sequence: bool = False, **kwargs: Any) -> Dict[str, torch.Tensor]: + X = self._normalize_timeseries(kwargs[self.timeseries_key]) + B, _, T = X.shape + + if T <= self.time_before_pred: + raise ValueError( + f"Sequence length T={T} must be greater than " + f"time_before_pred={self.time_before_pred}." 
+ ) + + flat = self._prepare_static(B, kwargs) + diagnoses_enc = self._prepare_diagnoses(B, kwargs) + + X_orig = X.permute(0, 2, 1).contiguous().view(B * T, 2 * self.F + 2) + repeat_flat = flat.repeat_interleave(T, dim=0) + + values = X[:, 1 : self.F + 1, :] + decay = X[:, self.F + 1 : 2 * self.F + 1, :] + + next_X = torch.stack([values, decay], dim=2).reshape(B, 2 * self.F, T) + + point_skip = values + + prev_temp: Optional[torch.Tensor] = None + prev_point: Optional[torch.Tensor] = None + + for i in range(self.n_layers): + mods = self.layer_modules[str(i)] + prev_temp, prev_point, next_X, point_skip = self._temp_pointwise( + B=B, + T=T, + X=next_X, + X_orig=X_orig, + repeat_flat=repeat_flat, + temp=mods["temp"], + bn_temp=mods["bn_temp"], + point=mods["point"], + bn_point=mods["bn_point"], + temp_kernels=self._layer_info[i]["temp_kernels"], + padding=self._layer_info[i]["padding"], + prev_temp=prev_temp, + prev_point=prev_point, + point_skip=point_skip, + ) + + post_hist = T - self.time_before_pred + + ts_features = ( + next_X[:, :, self.time_before_pred :] + .permute(0, 2, 1) + .contiguous() + .view(B * post_hist, -1) + ) + + combined_features = torch.cat( + [ + flat.repeat_interleave(post_hist, dim=0), + diagnoses_enc.repeat_interleave(post_hist, dim=0), + ts_features, + ], + dim=1, + ) + + last_hidden = self.relu( + self.main_dropout( + self.bn_point_last_los(self.point_last_los(combined_features)) + ) + ) + + raw_pred = self.point_final_los(last_hidden).view(B, post_hist) + if self.apply_exp: + raw_pred = torch.exp(raw_pred) + los_pred = self.hardtanh(raw_pred) + + output: Dict[str, torch.Tensor] = { + "logit": los_pred if return_full_sequence else los_pred.reshape(-1), + "y_prob": los_pred if return_full_sequence else los_pred.reshape(-1), + } + + if self.label_key in kwargs: + y_true = kwargs[self.label_key].to(self.device, dtype=torch.float32) + if y_true.dim() == 3 and y_true.shape[-1] == 1: + y_true = y_true.squeeze(-1) + y_true_post = y_true[:, 
self.time_before_pred :]
+
+            mask = y_true_post > 0
+            seq_lengths = mask.sum(dim=1)
+            loss = self.loss_fn(los_pred, y_true_post, mask, seq_lengths, self.sum_losses)
+
+            output["loss"] = loss
+            if return_full_sequence:
+                output["y_true"] = y_true_post
+                output["mask"] = mask
+            else:
+                flat_mask = mask.reshape(-1)
+                output["y_prob"] = los_pred.reshape(-1)[flat_mask]
+                output["logit"] = los_pred.reshape(-1)[flat_mask]
+                output["y_true"] = y_true_post.reshape(-1)[flat_mask]
+
+        if kwargs.get("embed", False):
+            output["embed"] = combined_features.view(B, post_hist, -1).mean(dim=1)
+
+        return output
+
+    def predict_with_uncertainty(self, mc_samples: int = 30, **kwargs: Any) -> Dict[str, torch.Tensor]:
+
+        if mc_samples < 1:
+            raise ValueError("mc_samples must be >= 1.")
+
+        was_training = self.training
+        self.train()
+
+        samples: List[torch.Tensor] = []
+        mask: Optional[torch.Tensor] = None
+        y_true: Optional[torch.Tensor] = None
+
+        with torch.no_grad():
+            for _ in range(mc_samples):
+                out = self.forward(return_full_sequence=True, **kwargs)
+                samples.append(out["y_prob"])
+                if mask is None:
+                    mask = out.get("mask")
+                if y_true is None:
+                    y_true = out.get("y_true")
+
+        if not was_training:
+            self.eval()
+
+        stacked = torch.stack(samples, dim=0)
+        result: Dict[str, torch.Tensor] = {
+            "mean": stacked.mean(dim=0),
+            "std": stacked.std(dim=0),
+            "samples": stacked,
+        }
+        if mask is not None:
+            result["mask"] = mask
+        if y_true is not None:
+            result["y_true"] = y_true
+        return result
\ No newline at end of file
diff --git a/pyhealth/tasks/__init__.py b/pyhealth/tasks/__init__.py
index 797988377..796fbdfab 100644
--- a/pyhealth/tasks/__init__.py
+++ b/pyhealth/tasks/__init__.py
@@ -29,6 +29,7 @@
     LengthOfStayPredictionOMOP,
 )
 from .length_of_stay_stagenet_mimic4 import LengthOfStayStageNetMIMIC4
+from .length_of_stay_tpc_mimic4 import RemainingLOSMIMIC4
 from .medical_coding import MIMIC3ICD9Coding
 from 
.medical_transcriptions_classification import MedicalTranscriptionsClassification from .mortality_prediction import ( diff --git a/pyhealth/tasks/length_of_stay_tpc_mimic4.py b/pyhealth/tasks/length_of_stay_tpc_mimic4.py new file mode 100644 index 000000000..a783a9efb --- /dev/null +++ b/pyhealth/tasks/length_of_stay_tpc_mimic4.py @@ -0,0 +1,232 @@ +""" +Remaining ICU Length-of-Stay Prediction Task for MIMIC-IV (TPC Format) + +Contributors: + - Pankaj Meghani, Tarak Jha, Pranash Krishnan + - meghani3, tarakj2, pranash2 + +Paper: + Title: Temporal Pointwise Convolutional Networks for Length of Stay + Prediction in the Intensive Care Unit + Authors: Emma Rocheteau, Pietro Liò, Stephanie Hyland + Conference: CHIL 2021 (Conference on Health, Inference, and Learning) + Link: https://arxiv.org/abs/2007.09483 + +Description: + Task definition for remaining ICU length-of-stay prediction compatible with + the TPC model architecture. Unlike traditional LoS tasks that predict total + stay duration at admission, this task generates hourly predictions of remaining + time throughout the ICU stay. + + Features: + - Hourly timeseries from chartevents (17 vitals) and labevents (17 labs) + - Forward-filled values with decay indicators (time since last measurement) + - Static patient demographics (age, sex) + - ICD diagnosis codes + + Output: Remaining hours in ICU at each timestep + +Usage: + >>> from pyhealth.datasets import MIMIC4EHRDataset + >>> from pyhealth.tasks import RemainingLOSMIMIC4 + >>> + >>> mimic4 = MIMIC4EHRDataset(root="path/to/mimic-iv", + ... 
tables=["chartevents", "labevents", "diagnoses_icd"]) + >>> dataset = mimic4.set_task(RemainingLOSMIMIC4()) +""" + +from __future__ import annotations + +from pyhealth.tasks import BaseTask +from pyhealth.data import Event, Patient +from typing import List, Dict, Any, Type, Union, cast + +from dataclasses import dataclass, field +from datetime import datetime, timedelta +from typing import Dict, List, Optional, Sequence, Tuple + +import math +import numpy as np + +from pyhealth.data.data import Patient +from pyhealth.tasks.base_task import BaseTask + +from pyhealth.processors import TemporalTimeseriesProcessor, TensorProcessor, SequenceProcessor +from pyhealth.processors.base_processor import Processor +import polars as pl + +@dataclass +class RemainingLOSConfig: + prediction_step_size: int = 1 + min_history_hours: int = 5 + min_remaining_hours: int = 1 + max_history_hours: int = 366 + +# Formally, our task is to predict the remaining LoS at regular +# timepoints 𝑦1, . . . , 𝑦𝑇 ∈ R>0 in the patient’s ICU stay, up to the +# discharge time 𝑇 , using the diagnoses (d ∈ R𝐷×1), static features +# (s ∈ R𝑆×1), and time series (x1, . . . , x𝑇 ∈ R𝐹 ×2). Initially, for every +# timepoint 𝑡, there are two ‘channels’ per time series feature: 𝐹 fea- +# ture values (x′𝑡 ∈ R𝐹 ×1), and their corresponding decay indicators +# (x′′𝑡 ∈ R𝐹 ×1). The decay indicators tell the model how recently +# the observation x′𝑡 was recorded. + +class RemainingLOSMIMIC4(BaseTask): + """ + Custom remaining length-of-stay regression task for MIMIC-IV. + + Each sample corresponds to one prediction cutoff time within a stay. + Input: + - ts: (timestamps, values) where values is shape (T, F) + - optionally static / code features + Target: + - remaining_los_hours: float + """ + + task_name: str = "RemainingLOSMIMIC4" + + # Keep this conceptual unless you already know the exact schema names + # your installed PyHealth version expects. 
+ input_schema: Dict[str, Any] = { + "timeseries": TensorProcessor(), + "static": TensorProcessor(), + "conditions": SequenceProcessor(), + } + + output_schema: Dict[str, Any] = {"los": "tensor"} + + def __init__(self, config: Optional[RemainingLOSConfig] = None): + self.config = config or RemainingLOSConfig() + + # TODO: These need to be fixed + # 17 vitals from chartevents + self.chart_itemids = [ + "220045", "220210", "220277", "220179", "220180", "220181", + "220050", "220051", "220052", "223761", "220739", "223900", + "223901", "226253", "220235", "224690", "220339" + ] + # 17 lab items from labevents + self.lab_itemids = [ + "51006", "50912", "50931", "50902", "50882", "50868", + "50960", "50970", "51265", "51301", "50811", "51222", + "50813", "50820", "50818", "50821", "50825" + ] + self.all_itemids = self.chart_itemids + self.lab_itemids + self.F = len(self.all_itemids) + + + def __call__(self, patient: Patient) -> List[Dict]: + samples: List[Dict] = [] + + admissions_result = patient.get_events(event_type="icustays") + # Handle both DataFrame and List[Event] return types + if isinstance(admissions_result, list): + admissions = admissions_result + else: + return [] + + if len(admissions) == 0: + return [] + + patient_static_attributes = patient.get_events("patients")[0] + + static = np.array([patient_static_attributes['anchor_age'], 1. 
if patient_static_attributes['gender'] == 'F' else 0.], dtype=np.float32)
+
+        for admission in admissions:
+
+            admit_time = admission.timestamp
+            # outtime is usually a string in attributes
+            outtime_raw = admission.outtime
+            # Validate timestamps BEFORE using them, so missing values are
+            # skipped instead of raising TypeError in strptime/arithmetic.
+            if admit_time is None or outtime_raw is None:
+                continue
+            discharge_time = datetime.strptime(outtime_raw, "%Y-%m-%d %H:%M:%S")
+            if discharge_time <= admit_time:
+                continue
+
+            los_hours = (discharge_time - admit_time).total_seconds() / 3600.0
+            T = min(int(math.ceil(los_hours)), self.config.max_history_hours)
+            if los_hours < self.config.min_history_hours + self.config.min_remaining_hours:
+                continue
+
+            labevents = patient.get_events(
+                event_type="labevents",
+                start=admission.timestamp,
+                end=discharge_time,
+            )
+            chartevents = patient.get_events(
+                event_type="chartevents",
+                start=admission.timestamp,
+                end=discharge_time,
+            )
+            diagnoses_events = patient.get_events(
+                event_type="diagnoses_icd",
+                filters=[("hadm_id", "==", admission.hadm_id)],
+            )
+            conditions = [
+                f"{getattr(event, 'icd_version', '10')}_{event.icd_code}"
+                for event in diagnoses_events if hasattr(event, "icd_code")
+            ]
+
+            all_events = labevents + chartevents
+
+            if len(all_events) == 0:
+                continue
+
+            values_mat = np.zeros((self.F, T), dtype=np.float32)
+            masks_mat = np.zeros((self.F, T), dtype=np.float32)
+
+            id_to_idx = {id: i for i, id in enumerate(self.all_itemids)}
+            # Pivot events into matrix
+            for row in all_events:
+                if row["itemid"] not in id_to_idx:
+                    continue
+                # chartevents/labevents may have empty numeric values; skip them
+                # rather than assigning None into the float32 matrix.
+                if row["valuenum"] is None:
+                    continue
+                f_idx = id_to_idx.get(row["itemid"])
+                h_idx = int((row["timestamp"].timestamp() - admit_time.timestamp() ) // 3600)
+                if f_idx is not None and 0 <= h_idx < T:
+                    values_mat[f_idx, h_idx] = row["valuenum"]
+                    masks_mat[f_idx, h_idx] = 1.0
+
+            # Forward-fill and calculate decay indicators
+            # Decay = 0.75 ** (hours since last measurement)
+            decay_mat = np.zeros((self.F, T), dtype=np.float32)
+            for f in range(self.F):
+                last_val = 0.0
+                hours_since = math.inf  # Initial large value for decay
+                for 
t in range(T): + if masks_mat[f, t] > 0: + last_val = values_mat[f, t] + hours_since = 0.0 + else: + values_mat[f, t] = last_val + hours_since += 1.0 + decay_mat[f, t] = 0.75 ** hours_since + + # Elapsed time channel + elapsed = np.arange(T, dtype=np.float32).reshape(1, T) + + # hour_of_day channel + hour_of_day = np.array([ + (admit_time + timedelta(hours=t)).hour + for t in range(T) + ], dtype=np.float32).reshape(1, T) + + # Concatenate all channels: [elapsed (1), values (F), decays (F), hour_of_day (1)] -> (2F+2, T) + timeseries = np.concatenate([elapsed, values_mat, decay_mat, hour_of_day], axis=0) + + # Label sequence: remaining LoS in hours at each hour + labels = np.array([ + max(0.0, (discharge_time - (admit_time + timedelta(hours=t))).total_seconds() / (3600.0)) + for t in range(T) + ], dtype=np.float32) + + samples.append({ + "patient_id": patient.patient_id, + "visit_id": admission.stay_id , + "timeseries": timeseries, + "static": static, + "conditions": conditions, + "los": labels, + }) + + return samples diff --git a/test-resources/core/mimic4demo/icu/chartevents.csv b/test-resources/core/mimic4demo/icu/chartevents.csv new file mode 100644 index 000000000..cbca898ca --- /dev/null +++ b/test-resources/core/mimic4demo/icu/chartevents.csv @@ -0,0 +1,20 @@ +subject_id,hadm_id,stay_id,caregiver_id,charttime,storetime,itemid,value,valuenum,valueuom,warning +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:45:00,225054,On ,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:43:00,223769,100,100,%,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:47:00,223956,Atrial demand,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:47:00,224866,Yes,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:45:00,227341,No,0,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:47:00,224751,52,52,bpm,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 
23:44:00,227969,"Quiet, calm space",,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:46:00,223935,Doppler,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:48:00,223782,Intermittent,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:47:00,224773,Cool,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:48:00,223784,Cough/Deep Breath,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:41:00,220047,55,55,bpm,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:42:00,220073,0,0,mmHg,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:44:00,227969,Bed locked in low position,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:47:00,223983,Ashen,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:47:00,224752,0.8,0.8,mV,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-16 00:02:00,220048,SR (Sinus Rhythm),,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-16 04:17:00,224093,Right Side,,,0 +10005817,20626031,32604416,6770,2132-12-16 00:00:00,2132-12-15 23:42:00,220066,0,0,mmHg,0 \ No newline at end of file diff --git a/test-resources/core/mimic4demo/icu/d_items.csv.gz b/test-resources/core/mimic4demo/icu/d_items.csv.gz new file mode 100644 index 000000000..bb8b39d96 Binary files /dev/null and b/test-resources/core/mimic4demo/icu/d_items.csv.gz differ diff --git a/tests/core/test_tpc.py b/tests/core/test_tpc.py new file mode 100644 index 000000000..f5d94b320 --- /dev/null +++ b/tests/core/test_tpc.py @@ -0,0 +1,417 @@ +""" +Unit tests for TPC (Temporal Pointwise Convolution) model. 
+ +Tests include: +- Model initialization with various configurations +- Forward pass with synthetic data +- Output shape validation +- Gradient computation (backward pass) +- MC Dropout uncertainty estimation +- Custom hyperparameters +""" + +import unittest +import torch +import numpy as np + +from pyhealth.datasets import create_sample_dataset, get_dataloader +from pyhealth.models import TPC, MSLELoss, MaskedMSELoss + + +class TestTPC(unittest.TestCase): + """Test cases for the TPC model and its components.""" + + def setUp(self): + """Set up test data with MINIMAL synthetic samples.""" + # Create minimal synthetic data - 2 samples only! + # TPC expects specific input format from RemainingLOSMIMIC4 task + + F = 4 # Number of clinical features (keep SMALL for fast tests) + T = 10 # Sequence length (keep SHORT) + + # Sample 1: Normal case + self.samples = [ + { + "patient_id": "p0", + "visit_id": "v0", + # timeseries shape: (2F+2, T) = (10, 10) + # [elapsed(1), values(F), decay(F), hour_of_day(1)] + "timeseries": torch.randn(2 * F + 2, T), + # static features: [age, sex] + "static": torch.tensor([65.0, 1.0]), + # diagnosis codes (will be processed by SequenceProcessor) + "conditions": ["icd_A01", "icd_B02"], + # remaining LoS labels in hours + "los": torch.rand(T) * 48, # 0-48 hours + }, + # Sample 2: Another normal case + { + "patient_id": "p1", + "visit_id": "v1", + "timeseries": torch.randn(2 * F + 2, T), + "static": torch.tensor([72.0, 0.0]), + "conditions": ["icd_A01"], + "los": torch.rand(T) * 24, # 0-24 hours + }, + ] + + # Define schema matching TPC requirements + self.input_schema = { + "timeseries": "tensor", + "static": "tensor", + "conditions": "sequence", + } + self.output_schema = {"los": "tensor"} + + # Create synthetic dataset + self.dataset = create_sample_dataset( + samples=self.samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="tpc_test", + ) + + # Create model with SMALL configuration for fast testing 
+ self.model = TPC( + dataset=self.dataset, + n_layers=2, # Minimal layers + temp_kernels=[4, 4], # Small kernel counts + point_sizes=[8, 8], # Small hidden sizes + diagnosis_size=16, # Small diagnosis encoding + last_linear_size=16, # Small final layer + time_before_pred=3, # Minimum history + ) + + def test_model_initialization(self): + """Test that TPC initializes correctly with all parameters.""" + self.assertIsInstance(self.model, TPC) + + # Check configuration parameters + self.assertEqual(self.model.n_layers, 2) + self.assertEqual(len(self.model.temp_kernels), 2) + self.assertEqual(len(self.model.point_sizes), 2) + self.assertEqual(self.model.diagnosis_size, 16) + self.assertEqual(self.model.last_linear_size, 16) + self.assertEqual(self.model.time_before_pred, 3) + + # Check feature keys + self.assertIn("timeseries", self.model.feature_keys) + self.assertIn("static", self.model.feature_keys) + self.assertIn("conditions", self.model.feature_keys) + + # Check label key + self.assertEqual(self.model.label_key, "los") + + # Check mode + self.assertEqual(self.model.mode, "regression") + + def test_model_forward_pass(self): + """Test TPC forward pass produces correct output structure.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + output = self.model(**data_batch) + + # Check required output keys + self.assertIn("loss", output) + self.assertIn("y_prob", output) + self.assertIn("y_true", output) + self.assertIn("logit", output) + + # Check output types + self.assertIsInstance(output["loss"], torch.Tensor) + self.assertIsInstance(output["y_prob"], torch.Tensor) + self.assertIsInstance(output["y_true"], torch.Tensor) + + # Check loss is scalar + self.assertEqual(output["loss"].dim(), 0, "Loss should be scalar") + + def test_output_shapes(self): + """Test that TPC produces correct output shapes.""" + train_loader = get_dataloader(self.dataset, batch_size=2, 
shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + output = self.model(**data_batch) + + # With time_before_pred=3 and T=10, we get 7 prediction timesteps per sample + # Flattened predictions = batch_size * (T - time_before_pred) + B = 2 + T = 10 + expected_flat_predictions = B * (T - self.model.time_before_pred) + + # Check flattened output shapes (default behavior) + self.assertEqual( + output["y_prob"].shape[0], + expected_flat_predictions, + "y_prob should be flattened predictions" + ) + self.assertEqual( + output["y_true"].shape[0], + expected_flat_predictions, + "y_true should be flattened labels" + ) + + def test_output_shapes_full_sequence(self): + """Test that TPC can return full sequences when requested.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + output = self.model(return_full_sequence=True, **data_batch) + + B = 2 + T = 10 + post_hist = T - self.model.time_before_pred + + # Check full sequence output shapes + self.assertEqual( + output["y_prob"].shape, + (B, post_hist), + f"y_prob should be (batch, post_hist) = ({B}, {post_hist})" + ) + self.assertEqual( + output["y_true"].shape, + (B, post_hist), + f"y_true should be (batch, post_hist) = ({B}, {post_hist})" + ) + self.assertIn("mask", output, "Should include mask in full sequence mode") + + def test_backward_pass(self): + """Test that gradients are computed correctly.""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + + # Forward pass + output = self.model(**data_batch) + loss = output["loss"] + + # Backward pass + loss.backward() + + # Check that at least some parameters have gradients + has_gradient = any( + param.requires_grad and param.grad is not None + for param in self.model.parameters() + ) + self.assertTrue( + has_gradient, + "Model parameters should have gradients after backward pass" + ) + + 
def test_custom_hyperparameters(self): + """Test TPC with different hyperparameter configurations.""" + # Test with MSLE loss + model_msle = TPC( + dataset=self.dataset, + n_layers=1, + temp_kernels=[8], + point_sizes=[12], + use_msle=True, + apply_exp=True, + ) + + self.assertEqual(model_msle.n_layers, 1) + self.assertIsInstance(model_msle.loss_fn, MSLELoss) + self.assertTrue(model_msle.apply_exp) + + # Test with MSE loss + model_mse = TPC( + dataset=self.dataset, + n_layers=1, + temp_kernels=[8], + point_sizes=[12], + use_msle=False, + apply_exp=False, + ) + + self.assertIsInstance(model_mse.loss_fn, MaskedMSELoss) + self.assertFalse(model_mse.apply_exp) + + # Verify both models can run forward pass + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + + with torch.no_grad(): + out_msle = model_msle(**data_batch) + out_mse = model_mse(**data_batch) + + self.assertIn("loss", out_msle) + self.assertIn("loss", out_mse) + + def test_mc_dropout_uncertainty(self): + """Test Monte Carlo Dropout uncertainty estimation (ablation study feature).""" + train_loader = get_dataloader(self.dataset, batch_size=2, shuffle=False) + data_batch = next(iter(train_loader)) + + # Run MC Dropout with small sample count for speed + mc_samples = 5 + with torch.no_grad(): + uncertainty_output = self.model.predict_with_uncertainty( + mc_samples=mc_samples, + **data_batch + ) + + # Check required output keys + self.assertIn("mean", uncertainty_output) + self.assertIn("std", uncertainty_output) + self.assertIn("samples", uncertainty_output) + + # Check shapes + B = 2 + T = 10 + post_hist = T - self.model.time_before_pred + + self.assertEqual( + uncertainty_output["mean"].shape, + (B, post_hist), + "Mean should have shape (batch, post_hist)" + ) + self.assertEqual( + uncertainty_output["std"].shape, + (B, post_hist), + "Std should have shape (batch, post_hist)" + ) + self.assertEqual( + uncertainty_output["samples"].shape, + 
(mc_samples, B, post_hist), + f"Samples should have shape ({mc_samples}, {B}, {post_hist})" + ) + + # Verify uncertainty values are reasonable + self.assertTrue( + torch.all(uncertainty_output["std"] >= 0), + "Standard deviation should be non-negative" + ) + + def test_loss_functions(self): + """Test custom loss functions (MSLELoss and MaskedMSELoss).""" + B, T = 2, 10 + + # Create synthetic predictions and targets + y_hat = torch.rand(B, T) * 10 # Predictions 0-10 hours + y = torch.rand(B, T) * 10 # True values 0-10 hours + mask = torch.ones(B, T, dtype=torch.bool) + mask[:, :3] = False # Mask out first 3 timesteps + seq_length = mask.sum(dim=1) + + # Test MSLELoss + msle_loss_fn = MSLELoss() + msle_loss = msle_loss_fn(y_hat, y, mask, seq_length, sum_losses=False) + + self.assertIsInstance(msle_loss, torch.Tensor) + self.assertEqual(msle_loss.dim(), 0, "Loss should be scalar") + self.assertTrue(msle_loss >= 0, "MSLE should be non-negative") + + # Test MaskedMSELoss + mse_loss_fn = MaskedMSELoss() + mse_loss = mse_loss_fn(y_hat, y, mask, seq_length, sum_losses=False) + + self.assertIsInstance(mse_loss, torch.Tensor) + self.assertEqual(mse_loss.dim(), 0, "Loss should be scalar") + self.assertTrue(mse_loss >= 0, "MSE should be non-negative") + + def test_minimal_config(self): + """Test TPC with absolute minimum configuration.""" + # Single layer, smallest possible model + tiny_model = TPC( + dataset=self.dataset, + n_layers=1, + temp_kernels=[2], + point_sizes=[4], + diagnosis_size=8, + last_linear_size=8, + time_before_pred=2, + ) + + self.assertEqual(tiny_model.n_layers, 1) + + # Verify it can still run + train_loader = get_dataloader(self.dataset, batch_size=1, shuffle=False) + data_batch = next(iter(train_loader)) + + tiny_model.eval() # Set to eval mode for batch_size=1 + with torch.no_grad(): + output = tiny_model(**data_batch) + + self.assertIn("loss", output) + + def test_edge_case_short_sequence(self): + """Test TPC behavior with very short sequences (edge 
case).""" + # Create dataset with minimum viable sequence length + F = 4 + T = 10 # Enough for time_before_pred=3 and some predictions + + short_samples = [ + { + "patient_id": "p0", + "visit_id": "v0", + "timeseries": torch.randn(2 * F + 2, T), + "static": torch.tensor([65.0, 1.0]), + "conditions": ["icd_A01"], + "los": torch.rand(T) * 48, + } + ] + + short_dataset = create_sample_dataset( + samples=short_samples, + input_schema=self.input_schema, + output_schema=self.output_schema, + dataset_name="tpc_short", + ) + + short_model = TPC( + dataset=short_dataset, + n_layers=1, + temp_kernels=[4], + point_sizes=[8], + time_before_pred=3, + ) + + loader = get_dataloader(short_dataset, batch_size=1, shuffle=False) + data_batch = next(iter(loader)) + + short_model.eval() # Set to eval mode for batch_size=1 + with torch.no_grad(): + output = short_model(**data_batch) + + # Should still produce valid output + self.assertIn("loss", output) + self.assertIn("y_prob", output) + + +class TestTPCLossFunctions(unittest.TestCase): + """Separate test class for TPC loss functions.""" + + def test_msle_loss_properties(self): + """Test mathematical properties of MSLE loss.""" + msle = MSLELoss() + + # Test that identical predictions give ~zero loss + y = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + y_hat = y.clone() + mask = torch.ones_like(y, dtype=torch.bool) + seq_length = torch.tensor([3, 3]) + + loss = msle(y_hat, y, mask, seq_length) + self.assertLess(loss.item(), 1e-5, "Loss should be near zero for identical predictions") + + def test_masked_mse_loss_masking(self): + """Test that MaskedMSELoss correctly ignores masked values.""" + mse = MaskedMSELoss() + + # Create data where masked values have large errors + y = torch.tensor([[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]) + y_hat = torch.tensor([[1.0, 2.0, 100.0], [4.0, 5.0, 200.0]]) # Large errors at end + mask = torch.tensor([[True, True, False], [True, True, False]]) # Mask the errors + seq_length = torch.tensor([2, 2]) + + loss 
= mse(y_hat, y, mask, seq_length) + + # Loss should be small because errors are masked + self.assertLess(loss.item(), 1.0, "Masked values should not contribute to loss") + + +if __name__ == "__main__": + unittest.main()