# Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
# SPDX-License-Identifier: MIT
import torch
from transformers.generation.utils import GenerationConfig
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from .generation_utils import generate_block


class SeedDiffcoderForCausalLM(LlamaForCausalLM):
    """Llama-based causal LM whose generate() delegates to the block-wise
    decoding loop in generation_utils instead of the default HF loop."""

    @torch.no_grad()
    def generate(
        self,
        input_ids=None,
        generation_config: GenerationConfig = None,
        **kwargs,
    ):
        if input_ids is None:
            raise ValueError("input_ids must be provided")
        if generation_config is None:
            generation_config = self.generation_config
        prompt = input_ids
        # Delegate to generate_block; the second return value (nfe) is
        # discarded and only the generated token ids are returned.
        output_ids, nfe = generate_block(
            model=self,
            prompt=prompt,
            **kwargs,
        )
        return output_ids
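

# Hedged usage sketch, not part of the upstream API: the checkpoint path and
# the max_new_tokens kwarg forwarded to generate_block are illustrative
# assumptions; only the class and its generate() override above are defined
# by this file.
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model_path = "path/to/seed-diffcoder-checkpoint"  # assumed local checkpoint
    tokenizer = AutoTokenizer.from_pretrained(model_path)
    model = SeedDiffcoderForCausalLM.from_pretrained(model_path).eval()

    # Encode a prompt and run the block-wise generate() defined above.
    input_ids = tokenizer("def fibonacci(n):", return_tensors="pt").input_ids
    output_ids = model.generate(
        input_ids,
        max_new_tokens=128,  # assumed to be accepted by generate_block via **kwargs
    )
    print(tokenizer.decode(output_ids[0], skip_special_tokens=True))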