| import torch |
|
|
| from diffusers.pipelines import FluxPipeline |
| from omini.pipeline.flux_omini import Condition, generate, seed_everything, convert_to_condition |
| from omini.rotation import RotationConfig, RotationTuner |
| from PIL import Image |
|
|
|
|
def load_rotation(transformer, path: str, adapter_name: str = "default", strict: bool = False):
    """
    Load rotation adapter weights into ``transformer``.

    Weights are re-keyed from the flat checkpoint layout
    (``...rotation.<param>``) to the per-adapter layout the tuner registers
    (``...rotation.<adapter_name>.<param>``) and moved to the transformer's
    device (and dtype, for floating-point tensors) before loading.

    Args:
        transformer: Module that will receive the adapter weights.
        path: Directory containing the saved adapter weights.
        adapter_name: Name of the adapter to load.
        strict: Whether to strictly match all keys in ``load_state_dict``.

    Returns:
        The raw state dict as loaded from disk (before re-keying).

    Raises:
        FileNotFoundError: If neither ``<adapter_name>.safetensors`` nor
            ``<adapter_name>.pth`` exists under ``path``.
    """
    import os

    device = transformer.device
    print(f"device for loading: {device}")

    safetensors_path = os.path.join(path, f"{adapter_name}.safetensors")
    pth_path = os.path.join(path, f"{adapter_name}.pth")

    if os.path.exists(safetensors_path):
        # Lazy import: only required for the safetensors format.
        from safetensors.torch import load_file
        state_dict = load_file(safetensors_path)
        print(f"Loaded rotation adapter from {safetensors_path}")
    elif os.path.exists(pth_path):
        state_dict = torch.load(pth_path, map_location=device)
        print(f"Loaded rotation adapter from {pth_path}")
    else:
        raise FileNotFoundError(
            f"No adapter weights found for '{adapter_name}' in {path}\n"
            f"Looking for: {safetensors_path} or {pth_path}"
        )

    # Cast to the live transformer's device/dtype so load_state_dict does not
    # silently mix precisions (e.g. bf16 model vs fp32 checkpoint).
    transformer_device = next(transformer.parameters()).device
    transformer_dtype = next(transformer.parameters()).dtype

    state_dict_with_adapter = {}
    for k, v in state_dict.items():
        # ".rotation.<param>" -> ".rotation.<adapter_name>.<param>"
        new_key = k.replace(".rotation.", f".rotation.{adapter_name}.")
        if "_adapter_config" in new_key:
            print(f"adapter_config key: {new_key}")
        if v.dtype in (torch.long, torch.int, torch.int32, torch.int64, torch.bool):
            # Integer/bool buffers (indices, masks) keep their dtype;
            # only move them to the target device.
            state_dict_with_adapter[new_key] = v.to(device=transformer_device)
        else:
            state_dict_with_adapter[new_key] = v.to(
                device=transformer_device, dtype=transformer_dtype
            )

    # BUG FIX: the original rebuilt state_dict_with_adapter here from the raw
    # state_dict with a plain re-keying comprehension, which discarded the
    # device/dtype conversion performed above.

    missing, unexpected = transformer.load_state_dict(
        state_dict_with_adapter,
        strict=strict,
    )

    if missing:
        print(f"Missing keys: {missing[:5]}{'...' if len(missing) > 5 else ''}")
    if unexpected:
        print(f"Unexpected keys: {unexpected[:5]}{'...' if len(unexpected) > 5 else ''}")

    # Surface the training-time config, if one was saved next to the weights.
    config_path = os.path.join(path, f"{adapter_name}_config.yaml")
    if os.path.exists(config_path):
        import yaml  # lazy: only needed when a config file is present
        with open(config_path, 'r') as f:
            config = yaml.safe_load(f)
        print(f"Loaded config: {config}")

    total_params = sum(p.numel() for p in state_dict.values())
    print(f"Loaded {len(state_dict)} tensors ({total_params:,} parameters)")

    return state_dict
|
|
|
|
| |
# Load the reference photo and force three channels.
image = Image.open("assets/coffee.png").convert("RGB")

# Center-crop to a square of the shorter side, then resize to 512x512.
width, height = image.size
side = min(width, height)
left = (width - side) // 2
top = (height - side) // 2
image = image.crop((left, top, left + side, top + side)).resize((512, 512))

prompt = "In a bright room. A cup of a coffee with some beans on the side. They are placed on a dark wooden table."

# Build the Canny-edge conditioning input from the cropped photo.
condition = Condition(convert_to_condition("canny", image), "canny")

# Fix RNG state for reproducible generations.
seed_everything()
|
|
|
|
|
|
# Sweep the adapter's T value over i = 40..59 (T = 2.05 .. 3.0 in 0.05 steps),
# generating and saving one image per setting.
# NOTE(review): the full pipeline is reloaded from the hub each iteration —
# presumably so the RotationTuner never wraps an already-wrapped transformer;
# confirm this is intentional, as it is very slow.
for i in range(40, 60):
    pipe = FluxPipeline.from_pretrained(
        "black-forest-labs/FLUX.1-dev", torch_dtype=torch.bfloat16
    )

    transformer = pipe.transformer

    # Flag PEFT config as loaded so set_adapter() below is accepted.
    adapter_name = "default"
    transformer._hf_peft_config_loaded = True

    # Target regex covers the x_embedder plus the norm/attention/MLP linears
    # of both the joint transformer blocks and the single transformer blocks.
    rotation_adapter_config = {
        "r": 4,
        "num_rotations": 4,
        "target_modules": "(.*x_embedder|.*(?<!single_)transformer_blocks\\.[0-9]+\\.norm1\\.linear|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_k|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_q|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_v|.*(?<!single_)transformer_blocks\\.[0-9]+\\.attn\\.to_out\\.0|.*(?<!single_)transformer_blocks\\.[0-9]+\\.ff\\.net\\.2|.*single_transformer_blocks\\.[0-9]+\\.norm\\.linear|.*single_transformer_blocks\\.[0-9]+\\.proj_mlp|.*single_transformer_blocks\\.[0-9]+\\.proj_out|.*single_transformer_blocks\\.[0-9]+\\.attn.to_k|.*single_transformer_blocks\\.[0-9]+\\.attn.to_q|.*single_transformer_blocks\\.[0-9]+\\.attn.to_v|.*single_transformer_blocks\\.[0-9]+\\.attn.to_out)",
    }

    config = RotationConfig(**rotation_adapter_config)
    # The swept hyperparameter.  NOTE(review): assumes RotationConfig.T is a
    # rotation scaling/interpolation factor — confirm in omini.rotation.
    config.T = float(i + 1) / 20
    # Wraps the targeted modules in place; the tuner object itself is not
    # used again after construction.
    rotation_tuner = RotationTuner(
        transformer,
        config,
        adapter_name=adapter_name,
    )

    # Cast newly added adapter parameters to bf16 and activate the adapter
    # BEFORE loading weights, so the checkpoint lands on the wrapped modules.
    transformer = transformer.to(torch.bfloat16)
    transformer.set_adapter(adapter_name)

    # Load the trained rotation weights (checkpoint at training step 4000).
    load_rotation(
        transformer,
        path="runs/20251110-191859/ckpt/4000",
        adapter_name=adapter_name,
        strict=False,
    )

    pipe = pipe.to("cuda")

    # Generate with the Canny condition prepared at module level.
    result_img = generate(
        pipe,
        prompt=prompt,
        conditions=[condition],
    ).images[0]

    # Build a side-by-side strip: original | condition | generated result.
    concat_image = Image.new("RGB", (1536, 512))
    concat_image.paste(image, (0, 0))
    concat_image.paste(condition.condition, (512, 0))
    concat_image.paste(result_img, (1024, 0))

    # File names use i+1 so they line up with the T numerator (41..60).
    result_img.save(f"result_{i+1}.png")
    concat_image.save(f"result_concat_{i+1}.png")
    print(f"Saved result_{i+1}.png and result_concat_{i+1}.png")