File size: 847 Bytes
8511ba7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
 # Copyright (c) 2026 ByteDance Ltd. and/or its affiliates
 # SPDX-License-Identifier: MIT

from .generation_utils import generate_block
from transformers.models.llama.modeling_llama import LlamaForCausalLM
from transformers.generation.utils import GenerationConfig
import torch

class SeedDiffcoderForCausalLM(LlamaForCausalLM):
    """Llama causal-LM whose ``generate`` is routed through ``generate_block``.

    Only ``generate`` is overridden; everything else is inherited from
    ``LlamaForCausalLM``.
    """

    @torch.no_grad()
    def generate(
        self,
        input_ids=None,
        generation_config: GenerationConfig = None,
        **kwargs,
    ):
        """Generate output ids for ``input_ids`` using ``generate_block``.

        Runs with autograd disabled (``torch.no_grad``).

        Args:
            input_ids: Prompt handed to ``generate_block`` as ``prompt``.
                Required; the ``None`` default only keeps the parameter
                usable as a keyword argument.
            generation_config: Accepted for signature compatibility with the
                inherited ``generate`` but NOT forwarded to ``generate_block``,
                so settings in it have no effect here.
            **kwargs: Forwarded unchanged to ``generate_block``.

        Returns:
            The first element of the pair returned by ``generate_block`` (the
            output ids). The second element (named ``nfe`` upstream, presumably
            a count of model forward evaluations - TODO confirm) is discarded.

        Raises:
            ValueError: If ``input_ids`` is ``None``.
        """
        if input_ids is None:
            raise ValueError("input_ids must be provided")

        # NOTE(review): generation_config is deliberately not forwarded. If its
        # settings should apply, map them onto generate_block's keyword
        # arguments explicitly.
        # Unpacking into two names keeps the check that generate_block
        # returns exactly a 2-element result.
        output_ids, _nfe = generate_block(model=self, prompt=input_ids, **kwargs)
        return output_ids