| using System; |
| using TorchSharp; |
|
|
/// <summary>
/// Runs a strided, deterministic reverse sampling loop on top of a <see cref="DDPM"/> model.
/// Starting from timestep <c>TimeSteps - 1</c>, each iteration asks the model for a
/// conditional and an unconditional prediction, blends them with a guidance scale,
/// converts the blend into a noise estimate and a clean-sample estimate, and re-noises
/// the clean estimate to the previous (smaller) timestep.
/// </summary>
public class DDIMSampler
{
    /// <summary>Number of timesteps the sampling loop walks down from (indices <c>TimeSteps - 1</c> to 0).</summary>
    private const int TimeSteps = 1000;

    private readonly DDPM _model;
    private readonly torch.Device _device;

    /// <summary>
    /// Creates a sampler bound to <paramref name="model"/>; timestep tensors are created on the model's device.
    /// </summary>
    /// <param name="model">The diffusion model used for every prediction.</param>
    /// <param name="scale">
    /// Unused. Retained only so existing callers keep compiling; the effective guidance
    /// scale is the one passed to <see cref="Sample"/>.
    /// </param>
    /// <exception cref="ArgumentNullException"><paramref name="model"/> is null.</exception>
    public DDIMSampler(DDPM model, float scale = 9.0f)
    {
        if (model == null)
        {
            throw new ArgumentNullException(nameof(model));
        }

        _model = model;
        _device = model.Device;
    }

    /// <summary>
    /// Runs the sampling loop with gradient tracking disabled and returns the final tensor.
    /// </summary>
    /// <param name="img">Starting tensor; its first dimension is used as the batch size.</param>
    /// <param name="condition">Conditioning input, forwarded unchanged to the model.</param>
    /// <param name="unconditional_condition">Unconditional conditioning input, forwarded unchanged to the model.</param>
    /// <param name="steps">
    /// Requested number of sampling steps, in the range 1..1000. The stride between visited
    /// timesteps is <c>1000 / steps</c> (integer division), so when 1000 is not divisible by
    /// <paramref name="steps"/> the loop may run slightly more iterations than requested.
    /// </param>
    /// <param name="scale">Guidance weight applied to the difference between the conditional and unconditional predictions.</param>
    /// <returns>The tensor produced by the last iteration.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="img"/> is null.</exception>
    /// <exception cref="ArgumentOutOfRangeException"><paramref name="steps"/> is outside 1..1000.</exception>
    public torch.Tensor Sample(torch.Tensor img, torch.Tensor condition, torch.Tensor unconditional_condition, int steps = 50, float scale = 9.0f)
    {
        if (img == null)
        {
            throw new ArgumentNullException(nameof(img));
        }

        // steps == 0 would divide by zero; steps > TimeSteps gives a stride of 0 and a
        // negative steps gives a negative stride, either of which makes the loop never end.
        if (steps < 1 || steps > TimeSteps)
        {
            throw new ArgumentOutOfRangeException(nameof(steps), steps, "steps must be between 1 and " + TimeSteps + ".");
        }

        var gap = TimeSteps / steps;
        var batch = img.shape[0];

        using (torch.no_grad())
        {
            for (var i = TimeSteps - 1; i >= 0; i -= gap)
            {
                // The final iteration would step below zero; clamp so it targets timestep 0.
                var previousStep = Math.Max(i - gap, 0);

                // The timestep tensors are created here and only read by the model calls
                // below, so they are released as soon as the iteration finishes.
                using (var tCur = torch.full(batch, i, dtype: torch.ScalarType.Int64, device: _device))
                using (var tPrev = torch.full(batch, previousStep, dtype: torch.ScalarType.Int64, device: _device))
                {
                    (var uncondOutput, var condOutput) = _model.DiffusionModel(img, condition, unconditional_condition, tCur);

                    // Guided output: start from the unconditional prediction and move
                    // toward the conditional one by a factor of `scale`.
                    var modelOutput = uncondOutput + scale * (condOutput - uncondOutput);

                    var predictedEps = _model.PredictEPSFromZANDV(img, tCur, modelOutput);
                    var predictedX0 = _model.PredictStartFromZANDV(img, tCur, modelOutput);

                    // Re-noise the clean estimate to the previous timestep using the predicted noise.
                    img = _model.QSample(predictedX0, tPrev, predictedEps);

                    // Per-step console output kept from the original implementation.
                    Console.WriteLine(img);
                }
            }

            return img;
        }
    }
}