Add TransducerRunner and rename AsrRunner to Seq2SeqRunner (#18961)
kirklandsign wants to merge 1 commit into main
🔗 Helpful Links
🧪 See artifacts and rendered test results at hud.pytorch.org/pr/pytorch/executorch/18961
Note: Links to docs will display an error until the docs builds have been completed.
❌ 3 New Failures, 2 Unrelated Failures
As of commit 67e6800 with merge base 7fdd306:
NEW FAILURES - The following jobs have failed:
BROKEN TRUNK - The following jobs failed but were present on the merge base:
👉 Rebase onto the `viable/strict` branch to avoid these failures
This comment was automatically generated by Dr. CI and updates every 15 minutes.
@kirklandsign has exported this pull request. If you are a Meta employee, you can view the originating Diff in D100892465.
@claude review
This PR needs a
Claude finished @kirklandsign's task in 1m 28s
Code Review: Add TransducerRunner and rename AsrRunner to Seq2SeqRunner
Summary
Well-structured refactoring that cleanly separates Seq2Seq (Whisper) and Transducer (RNN-T/TDT) ASR architectures. The extraction of ~190 lines of greedy decode logic from
Issues
1.
2. Redundant model loading for timestamp metadata in
This duplicates the module load cost at runtime.
3.
Minor / Nits
4.
    auto mel = outputs[0].toTensor();
    return std::make_shared<::executorch::aten::Tensor>(std::move(mel));
This works but is inconsistent with how
5. Potential out-of-bounds in encoder output access
    int64_t encoded_len = enc_outputs[1].toTensor().const_data_ptr<int64_t>()[0];
No bounds check on the tensor having at least one element. If the encoder returns a 0-element length tensor, this would be undefined behavior. The check at lines 261-265 verifies that both outputs exist and are tensors, but not that the length tensor is non-empty.
6.
7. Missing
Looks Good
Pull request overview
This PR restructures extension/asr/runner/ to support both encoder-decoder (Seq2Seq, e.g. Whisper) and Transducer (RNN‑T/TDT) ASR architectures by introducing separate runner implementations and migrating example/JNI consumers.
Changes:
- Renames `AsrRunner` to `Seq2SeqRunner` (with backward-compatible `runner.h` redirect and type aliases).
- Adds `TransducerRunner` implementing greedy transducer decoding and optional bundled-preprocessor execution.
- Migrates Whisper/Parakeet examples and Android JNI includes to the new runner headers, and updates build files (CMake/Buck).
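The backward-compatible rename in the first bullet can be pictured with a small sketch. This is only an illustration under stated assumptions: the namespace `asr_sketch`, the constructor, and the `max_new_tokens` field are hypothetical stand-ins, and the real headers may look different. Only the four type names come from the PR.

```cpp
#include <string>
#include <utility>

namespace asr_sketch { // hypothetical namespace, not the real ExecuTorch one

// Stand-in for the renamed config type (the field is illustrative only).
struct Seq2SeqTranscribeConfig {
  int max_new_tokens = 128;
};

// Stand-in for the renamed runner class.
class Seq2SeqRunner {
 public:
  explicit Seq2SeqRunner(std::string model_path)
      : model_path_(std::move(model_path)) {}
  const std::string& model_path() const { return model_path_; }

 private:
  std::string model_path_;
};

// Backward-compatible aliases: code that still names the old types keeps
// compiling and refers to exactly the same classes.
using AsrRunner = Seq2SeqRunner;
using AsrTranscribeConfig = Seq2SeqTranscribeConfig;

} // namespace asr_sketch
```

Because a `using` alias introduces no new type, an existing consumer such as the JNI layer could in principle keep using the old names with no API change, which matches what the PR summary claims for Android.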
Reviewed changes
Copilot reviewed 13 out of 13 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| extension/asr/runner/transducer_runner.h | Defines new TransducerRunner API, Token type, and TransducerConfig. |
| extension/asr/runner/transducer_runner.cpp | Implements model loading, optional preprocessor invocation, and transducer greedy decode loop. |
| extension/asr/runner/seq2seq_runner.h | Defines Seq2SeqRunner API and backward-compatible AsrRunner aliases. |
| extension/asr/runner/seq2seq_runner.cpp | Renames implementation from AsrRunner to Seq2SeqRunner and updates includes. |
| extension/asr/runner/runner.h | Becomes a thin backward-compat include redirect to seq2seq_runner.h. |
| extension/asr/runner/targets.bzl | Adds Buck targets for seq2seq_runner and transducer_runner libraries. |
| extension/asr/runner/CMakeLists.txt | Builds ASR runner library from seq2seq_runner.cpp + transducer_runner.cpp. |
| extension/asr/runner/BUCK | Adds migration BUCK wrapper for shared targets. |
| extension/android/jni/jni_layer_asr.cpp | Updates include to the new Seq2Seq runner header. |
| examples/models/whisper/main.cpp | Migrates usage from AsrRunner/AsrTranscribeConfig to Seq2SeqRunner/Seq2SeqTranscribeConfig. |
| examples/models/parakeet/types.h | Switches Parakeet Token to alias shared asr::Token. |
| examples/models/parakeet/main.cpp | Replaces inline greedy decode with TransducerRunner preprocess()+transcribe(). |
| examples/models/parakeet/CMakeLists.txt | Links Parakeet example against extension_asr_runner. |
    #include <cstdint>
    #include <functional>
    #include <memory>
transducer_runner.h uses std::optional/std::nullopt in the TransducerRunner constructor but does not include <optional>. Add the missing include so this header compiles when included on its own (including from transducer_runner.cpp).
Suggested change:
    -#include <memory>
    +#include <memory>
    +#include <optional>
    load(":targets.bzl", "define_common_targets")
This second load(":targets.bzl", "define_common_targets") appears redundant (the symbol is already loaded above). Consider removing the duplicate to simplify the BUCK file.
Suggested change:
    -load(":targets.bzl", "define_common_targets")
    ET_CHECK_OR_RETURN_ERROR(
        blank_id_result.ok(),
        Internal,
        "Model must export 'blank_id' as a constant_method.");
    blank_id_ = blank_id_result.get()[0].toInt();
load_model_metadata() indexes blank_id_result.get()[0] without verifying that the constant_method returned at least one value (and that it has the expected scalar type). If the model returns an empty vector or a tensor scalar, this can crash. Consider validating the output count and type before calling toInt().
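The validation this comment asks for can be sketched in isolation. The `Value` type below is a hypothetical stand-in built on `std::variant`, not the ExecuTorch `EValue`, and `first_int` is an illustrative helper name; the point is only the order of the checks (count first, then type, then read).

```cpp
#include <cstdint>
#include <optional>
#include <variant>
#include <vector>

// Hypothetical stand-in for a runtime value; this is NOT the ExecuTorch EValue.
using Value = std::variant<std::monostate, int64_t, double>;

// Returns the first output as an integer only if it exists and holds an
// int64_t. Any other shape of result yields std::nullopt instead of
// reading out of bounds or reinterpreting the wrong type.
std::optional<int64_t> first_int(const std::vector<Value>& outputs) {
  if (outputs.empty()) {
    return std::nullopt; // the method returned no values at all
  }
  if (const int64_t* v = std::get_if<int64_t>(&outputs[0])) {
    return *v;
  }
  return std::nullopt; // present, but not an integer
}
```

A caller would then turn the empty optional into an error return rather than indexing `[0]` unconditionally.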
    // --- Encode ---
    int64_t mel_len_value = preprocessed_features->size(1);
    std::vector<int64_t> mel_len_data = {mel_len_value};
    auto mel_len = ::executorch::extension::from_blob(
        mel_len_data.data(), {1}, ::executorch::aten::ScalarType::Long);
transcribe() derives mel_len_value from preprocessed_features->size(1). If the preprocessor pads the mel tensor and provides an explicit mel_len, this will pass the padded length to the encoder. Consider threading the true mel_len through (e.g., from preprocess()) instead of inferring it from the tensor shape.
    std::vector<::executorch::runtime::EValue>{sos_token, h, c});
    ET_CHECK_OK_OR_RETURN_ERROR(decoder_init_result.error());

    auto& init_outputs = decoder_init_result.get();
decoder_init_result.get() is indexed at [0..2] without checking the output count/types. If decoder_step returns fewer outputs (or non-tensors), this will read out of bounds. Add an explicit output validation similar to the encoder check before indexing.
Suggested change:
    -auto& init_outputs = decoder_init_result.get();
    +auto& init_outputs = decoder_init_result.get();
    +if (init_outputs.size() < 3 || !init_outputs[0].isTensor() ||
    +    !init_outputs[1].isTensor() || !init_outputs[2].isTensor()) {
    +  ET_LOG(
    +      Error,
    +      "Method %s returned invalid outputs: expected at least 3 tensor outputs, got %zu",
    +      kDecoderStepMethodName,
    +      init_outputs.size());
    +  return Error::InvalidArgument;
    +}
    struct Token {
      uint64_t id;
      int64_t start_offset; // Frame index in the encoder output
      int64_t duration; // TDT duration value (0 for standard RNN-T)
Token::duration is documented as "0 for standard RNN-T", but the implementation uses dur=1 when durations is empty (and TransducerConfig also describes standard RNN-T as duration==1). Please reconcile the docs/contract so downstream timestamp code can interpret duration consistently.
Suggested change:
    -int64_t duration; // TDT duration value (0 for standard RNN-T)
    +int64_t duration; // Duration in encoder frames (1 for standard RNN-T)
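To see why the duration contract matters downstream, here is a minimal sketch of turning a `Token` into a time span. The `Token` fields follow the PR; `TimeSpan`, `token_span`, and the `seconds_per_frame` parameter are hypothetical names introduced for illustration, and the actual Parakeet timestamp code may work differently.

```cpp
#include <cstdint>

// Fields as in the PR; the duration comment follows the suggested contract.
struct Token {
  uint64_t id;
  int64_t start_offset; // frame index in the encoder output
  int64_t duration;     // duration in encoder frames
};

// Hypothetical result type for this sketch.
struct TimeSpan {
  double start_s;
  double end_s;
};

// Converts a token's frame span to seconds. seconds_per_frame is an assumed
// parameter (encoder frame stride in seconds), supplied by the caller.
TimeSpan token_span(const Token& t, double seconds_per_frame) {
  const double start = static_cast<double>(t.start_offset) * seconds_per_frame;
  const double end =
      static_cast<double>(t.start_offset + t.duration) * seconds_per_frame;
  return {start, end};
}
```

Under this interpretation a duration of 1 gives each standard RNN-T token a one-frame span, while a duration of 0 would collapse it to a zero-length span, so the two documented values are not interchangeable.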
    // TDT duration values. Empty means standard RNN-T (duration always 1).
    // If the model exports a "durations" constant_method, it will be used
    // instead of this field.
TransducerConfig says a model-exported "durations" constant_method will override config.durations, but TransducerRunner::load()/load_model_metadata() never reads such a constant_method. Either implement that override or update the comment to match actual behavior.
Suggested change:
    -// TDT duration values. Empty means standard RNN-T (duration always 1).
    -// If the model exports a "durations" constant_method, it will be used
    -// instead of this field.
    +// TDT duration values used by the decode loop. Empty means standard
    +// RNN-T (duration always 1).
    #pragma once

    #include <functional>
    #include <memory>
seq2seq_runner.h uses std::optional in the Seq2SeqRunner constructor signature but doesn't include <optional>. It currently relies on transitive includes; please add <optional> to keep the header self-contained and avoid include-order fragility.
Suggested change:
    -#include <memory>
    +#include <memory>
    +#include <optional>
    Internal,
    "Preprocessor returned unexpected output.");

    auto mel = outputs[0].toTensor();
    return std::make_shared<::executorch::aten::Tensor>(std::move(mel));
preprocess() only consumes outputs[0] from the preprocessor, but the header documents the method as returning (mel, mel_len). If mel_len is provided (or mel is padded), dropping it can cause transcribe() to pass an incorrect length to the encoder. Consider consuming outputs[1] when present and using it to compute/override mel_len.
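One way to thread the true length through is sketched here, under stated assumptions: `Features` and `encoder_length` are hypothetical names, the struct is a simplified stand-in for the preprocessor result rather than the runner's real types, and clamping to the padded shape is a design choice of this sketch, not something the PR specifies.

```cpp
#include <cstdint>
#include <optional>

// Hypothetical summary of what a preprocessor returned; not the runner API.
struct Features {
  int64_t padded_frames;               // frame count implied by the mel shape
  std::optional<int64_t> valid_frames; // explicit mel_len, if one was returned
};

// Picks the length to hand to the encoder: prefer the explicit mel_len,
// clamped to [0, padded_frames]; fall back to the tensor shape only when
// the preprocessor did not report a length.
int64_t encoder_length(const Features& f) {
  if (!f.valid_frames.has_value()) {
    return f.padded_frames;
  }
  int64_t len = *f.valid_frames;
  if (len < 0) {
    len = 0;
  }
  if (len > f.padded_frames) {
    len = f.padded_frames;
  }
  return len;
}
```

With this shape, a padded mel of 3000 frames that reports 1234 valid frames would be encoded with length 1234 instead of 3000.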
    int64_t k = joint_result.get()[0].toTensor().const_data_ptr<int64_t>()[0];

    // Compute frame advance duration
    int64_t dur = 1; // default for standard RNN-T
    if (is_tdt) {
joint_result.get() is indexed at [0] (and [1] when is_tdt) without validating the number/type of outputs. Add a guard that checks output count and that the expected outputs are tensors/scalars before reading const_data_ptr<int64_t>().
Summary:
Pull Request resolved: #18961
Restructure `extension/asr/runner/` to support both Seq2Seq and Transducer ASR architectures:
1. **Rename** `AsrRunner` → `Seq2SeqRunner` (with backward-compat alias in runner.h)
   - `runner.cpp` → `seq2seq_runner.cpp`, `runner.h` → `seq2seq_runner.h`
   - Old `runner.h` kept as a thin include redirect for backward compatibility
2. **Add** `TransducerRunner` for RNN-T/TDT models
   - Extracts the ~190-line greedy decode loop from `examples/models/parakeet/main.cpp`
   - Auto-reads model metadata (blank_id, num_rnn_layers, pred_hidden) from constant_methods
   - Supports both standard RNN-T (duration=1) and TDT (variable durations)
   - Returns `vector<Token>` with frame offsets for downstream timestamp computation
   - Includes `preprocess()` method for models with bundled preprocessor
   - Exposes tokenizer via `tokenizer()` getter for downstream use
3. **Unify Token type**: `parakeet::Token` is now an alias for `asr::Token`
4. **Migrate consumers**
   - Whisper `main.cpp`: `AsrRunner` → `Seq2SeqRunner`
   - Parakeet `main.cpp`: inline decode → `TransducerRunner::transcribe()`
   - Android JNI: updated include (uses backward-compat alias, no API change)
Model-specific post-processing (timestamp computation) remains in `examples/models/parakeet/`.
Differential Revision: D100892465
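For readers unfamiliar with the greedy decode loop the summary refers to, here is a rough, self-contained sketch of a TDT-style greedy loop. It is not the PR's implementation: `JointFn` is a hypothetical callback standing in for the decoder and joint networks, `max_symbols_per_frame` is an assumed safety limit, and the empty-durations case simply follows the PR's stated convention of a duration of 1.

```cpp
#include <algorithm>
#include <cstdint>
#include <functional>
#include <utility>
#include <vector>

struct Token {
  uint64_t id;
  int64_t start_offset; // frame at which the token was emitted
  int64_t duration;     // frames advanced after emitting it
};

// Hypothetical stand-in for decoder + joint: given the current frame and the
// last emitted token, returns (best token id, best duration index).
using JointFn =
    std::function<std::pair<int64_t, int64_t>(int64_t frame, int64_t last)>;

std::vector<Token> greedy_decode(
    int64_t num_frames,
    int64_t blank_id,
    const std::vector<int64_t>& durations, // empty: duration is always 1
    const JointFn& joint,
    int64_t max_symbols_per_frame = 10) { // assumed guard against stalls
  std::vector<Token> tokens;
  int64_t t = 0;
  int64_t last = blank_id; // blank doubles as start-of-sequence here
  int64_t symbols_on_frame = 0;
  while (t < num_frames) {
    const std::pair<int64_t, int64_t> step = joint(t, last);
    const int64_t k = step.first;
    int64_t dur = 1;
    if (!durations.empty()) {
      // Clamp the index so a bad model output cannot read out of bounds.
      const int64_t hi = static_cast<int64_t>(durations.size()) - 1;
      const int64_t idx = std::min(std::max<int64_t>(step.second, 0), hi);
      dur = durations[static_cast<size_t>(idx)];
    }
    if (k == blank_id) {
      t += std::max<int64_t>(dur, 1); // a blank always moves forward
      symbols_on_frame = 0;
      continue;
    }
    tokens.push_back({static_cast<uint64_t>(k), t, dur});
    last = k;
    if (dur > 0) {
      t += dur;
      symbols_on_frame = 0;
    } else if (++symbols_on_frame >= max_symbols_per_frame) {
      t += 1; // too many zero-duration emissions: force progress
      symbols_on_frame = 0;
    }
  }
  return tokens;
}
```

Every iteration either advances `t` or counts toward the per-frame symbol limit, so the loop terminates even if the joint callback keeps predicting zero-duration tokens.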
abbbb8f to a29f1d7
a29f1d7 to 67e6800