| # Copyright (c) 2026 ByteDance Ltd. and/or its affiliates | |
| # SPDX-License-Identifier: MIT | |
| from .generation_utils import generate_block | |
| from transformers.models.llama.modeling_llama import LlamaForCausalLM | |
| from transformers.generation.utils import GenerationConfig | |
| import torch | |
class SeedDiffcoderForCausalLM(LlamaForCausalLM):
    """Llama causal LM whose ``generate`` decodes through ``generate_block``.

    Only ``generate`` is overridden here; model construction and the forward
    pass are inherited unchanged from ``LlamaForCausalLM``.
    """

    def generate(
        self,
        input_ids=None,
        generation_config: GenerationConfig = None,
        **kwargs,
    ):
        """Generate token ids for a prompt by delegating to ``generate_block``.

        Args:
            input_ids: Prompt token ids, passed to ``generate_block`` as its
                ``prompt`` argument. If omitted, an ``inputs=`` keyword (the
                prompt parameter name used by ``transformers``'
                ``GenerationMixin.generate``) is accepted as an alias.
            generation_config: Accepted only for signature compatibility with
                ``GenerationMixin.generate``. It is NOT used: nothing from it is
                forwarded to ``generate_block``, so decoding options must be
                passed as keyword arguments instead.
            **kwargs: Forwarded unchanged to ``generate_block``.

        Returns:
            The ``output_ids`` value returned by ``generate_block``. Its second
            return value (``nfe``) is discarded.

        Raises:
            ValueError: If no prompt is given via ``input_ids`` or ``inputs``.
        """
        if input_ids is None:
            # Callers following GenerationMixin.generate's signature pass the
            # prompt as ``inputs=``; pop it so it is not forwarded a second
            # time inside **kwargs.
            input_ids = kwargs.pop("inputs", None)
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # generation_config is intentionally not consulted: generate_block is
        # called only with the model, the prompt, and the caller's kwargs.
        # NOTE(review): the discarded second result is named ``nfe`` in
        # generation_utils -- presumably a count of model evaluations; confirm.
        output_ids, _nfe = generate_block(
            model=self,
            prompt=input_ids,
            **kwargs,
        )
        return output_ids