# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT

from typing import Optional

import torch
from transformers.generation.utils import GenerationConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from .generation_utils import generate_block


class SeedDiffcoderForCausalLM(LlamaForCausalLM):
    """Llama causal-LM whose ``generate`` is delegated to ``generate_block``.

    Everything except ``generate`` is inherited unchanged from
    ``LlamaForCausalLM``. The overridden ``generate`` does not call the
    inherited implementation; it hands the prompt and any extra keyword
    arguments to the project's ``generate_block`` routine.
    """

    @torch.no_grad()
    def generate(
        self,
        input_ids: Optional[torch.Tensor] = None,
        generation_config: Optional[GenerationConfig] = None,
        **kwargs,
    ):
        """Generate output token ids for ``input_ids`` via ``generate_block``.

        Runs under ``torch.no_grad()``, so no autograd graph is recorded.

        Args:
            input_ids: Prompt token ids, passed to ``generate_block`` as
                ``prompt``. Required despite the ``None`` default, which is
                kept for keyword-call compatibility with the inherited
                signature.
            generation_config: Accepted for signature compatibility with
                ``LlamaForCausalLM.generate`` but currently NOT used. It is
                not forwarded to ``generate_block``.
            **kwargs: Forwarded verbatim to ``generate_block``.

        Returns:
            The first element of the ``(output_ids, nfe)`` pair returned by
            ``generate_block``.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # TODO(review): generation_config is ignored. The previous code
        # defaulted it to self.generation_config and then discarded it.
        # Decide whether its fields should be mapped onto generate_block
        # kwargs.

        # The second return value is discarded; presumably it is a
        # number-of-function-evaluations counter (confirm in
        # generation_utils).
        output_ids, _nfe = generate_block(
            model=self,
            prompt=input_ids,
            **kwargs,
        )
        return output_ids