"""
This script ports models from VQ-Diffusion (https://github.com/microsoft/VQ-Diffusion) to diffusers.

It currently only supports porting the checkpoints trained on the ITHQ dataset.

ITHQ model:
```sh
# From the root directory of diffusers.

# Download the VQVAE checkpoint. The URL contains `&`, so it must be quoted.
$ wget 'https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_vqvae.pth?sv=2020-10-02&st=2022-05-30T15%3A17%3A18Z&se=2030-05-31T15%3A17%3A00Z&sr=b&sp=r&sig=1jVavHFPpUjDs%2FTO1V3PTezaNbPp2Nx8MxiWI7y6fEY%3D' -O ithq_vqvae.pth

# Download the VQVAE config.
# NOTE that in VQ-Diffusion the documented file is `configs/ithq.yaml`, but the target class
# `image_synthesis.modeling.codecs.image_codec.ema_vqvae.PatchVQVAE`
# loads `OUTPUT/pretrained_model/taming_dvae/config.yaml`.
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/OUTPUT/pretrained_model/taming_dvae/config.yaml -O ithq_vqvae.yaml

# Download the main model checkpoint.
$ wget 'https://facevcstandard.blob.core.windows.net/v-zhictang/Improved-VQ-Diffusion_model_release/ithq_learnable.pth?sv=2020-10-02&st=2022-05-30T10%3A22%3A06Z&se=2030-05-31T10%3A22%3A00Z&sr=b&sp=r&sig=GOE%2Bza02%2FPnGxYVOOPtwrTR4RA3%2F5NVgMxdW4kjaEZ8%3D' -O ithq_learnable.pth

# Download the main model config.
$ wget https://raw.githubusercontent.com/microsoft/VQ-Diffusion/main/configs/ithq.yaml -O ithq.yaml

# Run the conversion script.
$ python ./scripts/convert_vq_diffusion_to_diffusers.py \
    --checkpoint_path ./ithq_learnable.pth \
    --original_config_file ./ithq.yaml \
    --vqvae_checkpoint_path ./ithq_vqvae.pth \
    --vqvae_original_config_file ./ithq_vqvae.yaml \
    --dump_path <path to save pre-trained `VQDiffusionPipeline`>
```
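
Once the conversion has run, the saved pipeline can be loaded like any other diffusers pipeline.
A minimal sketch (the dump path, prompt, and device are placeholders, not part of this script):
```py
import torch
from diffusers import VQDiffusionPipeline

pipe = VQDiffusionPipeline.from_pretrained("<path the pipeline was saved to>")
if torch.cuda.is_available():
    pipe = pipe.to("cuda")

# VQ-Diffusion is text conditioned; a single prompt returns a single image.
image = pipe("a teddy bear playing in the pool").images[0]
image.save("generated.png")
```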
"""

import argparse
import tempfile

import torch
import yaml
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import CLIPTextModel, CLIPTokenizer
from yaml.loader import FullLoader

from diffusers import Transformer2DModel, VQDiffusionPipeline, VQDiffusionScheduler, VQModel
from diffusers.pipelines.vq_diffusion.pipeline_vq_diffusion import LearnedClassifierFreeSamplingEmbeddings


try:
    from omegaconf import OmegaConf
except ImportError:
    raise ImportError(
        "OmegaConf is required to convert the VQ Diffusion checkpoints. Please install it with `pip install"
        " OmegaConf`."
    )


# vqvae model

PORTED_VQVAES = ["image_synthesis.modeling.codecs.image_codec.patch_vqgan.PatchVQGAN"]


def vqvae_model_from_original_config(original_config):
    assert original_config.target in PORTED_VQVAES, f"{original_config.target} has not yet been ported to diffusers."

    original_config = original_config.params

    original_encoder_config = original_config.encoder_config.params
    original_decoder_config = original_config.decoder_config.params

    in_channels = original_encoder_config.in_channels
    out_channels = original_decoder_config.out_ch

    down_block_types = get_down_block_types(original_encoder_config)
    up_block_types = get_up_block_types(original_decoder_config)

    assert original_encoder_config.ch == original_decoder_config.ch
    assert original_encoder_config.ch_mult == original_decoder_config.ch_mult
    block_out_channels = tuple(
        [original_encoder_config.ch * a_ch_mult for a_ch_mult in original_encoder_config.ch_mult]
    )

    assert original_encoder_config.num_res_blocks == original_decoder_config.num_res_blocks
    layers_per_block = original_encoder_config.num_res_blocks

    assert original_encoder_config.z_channels == original_decoder_config.z_channels
    latent_channels = original_encoder_config.z_channels

    num_vq_embeddings = original_config.n_embed

    # hardcoded to 32 groups to match the GroupNorm used in the original VQGAN blocks
    norm_num_groups = 32

    e_dim = original_config.embed_dim

    model = VQModel(
        in_channels=in_channels,
        out_channels=out_channels,
        down_block_types=down_block_types,
        up_block_types=up_block_types,
        block_out_channels=block_out_channels,
        layers_per_block=layers_per_block,
        latent_channels=latent_channels,
        num_vq_embeddings=num_vq_embeddings,
        norm_num_groups=norm_num_groups,
        vq_embed_dim=e_dim,
    )

    return model


def get_down_block_types(original_encoder_config):
    attn_resolutions = coerce_attn_resolutions(original_encoder_config.attn_resolutions)
    num_resolutions = len(original_encoder_config.ch_mult)
    resolution = coerce_resolution(original_encoder_config.resolution)

    curr_res = resolution
    down_block_types = []

    for _ in range(num_resolutions):
        if curr_res in attn_resolutions:
            down_block_type = "AttnDownEncoderBlock2D"
        else:
            down_block_type = "DownEncoderBlock2D"

        down_block_types.append(down_block_type)

        curr_res = [r // 2 for r in curr_res]

    return down_block_types


def get_up_block_types(original_decoder_config):
    attn_resolutions = coerce_attn_resolutions(original_decoder_config.attn_resolutions)
    num_resolutions = len(original_decoder_config.ch_mult)
    resolution = coerce_resolution(original_decoder_config.resolution)

    curr_res = [r // 2 ** (num_resolutions - 1) for r in resolution]
    up_block_types = []

    for _ in reversed(range(num_resolutions)):
        if curr_res in attn_resolutions:
            up_block_type = "AttnUpDecoderBlock2D"
        else:
            up_block_type = "UpDecoderBlock2D"

        up_block_types.append(up_block_type)

        curr_res = [r * 2 for r in curr_res]

    return up_block_types


def coerce_attn_resolutions(attn_resolutions):
    attn_resolutions = OmegaConf.to_object(attn_resolutions)
    attn_resolutions_ = []
    for ar in attn_resolutions:
        if isinstance(ar, (list, tuple)):
            attn_resolutions_.append(list(ar))
        else:
            attn_resolutions_.append([ar, ar])
    return attn_resolutions_


def coerce_resolution(resolution):
    resolution = OmegaConf.to_object(resolution)
    if isinstance(resolution, int):
        resolution = [resolution, resolution]
    elif isinstance(resolution, (tuple, list)):
        resolution = list(resolution)
    else:
        raise ValueError("Unknown type of resolution:", resolution)
    return resolution
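

# As an illustration of how the helpers above combine (the numbers are hypothetical, not the
# ITHQ config values): with `resolution=256`, `ch_mult=[1, 2, 4]` and `attn_resolutions=[64]`,
# the encoder visits the resolutions 256 -> 128 -> 64, so `get_down_block_types` returns
# ["DownEncoderBlock2D", "DownEncoderBlock2D", "AttnDownEncoderBlock2D"], while
# `get_up_block_types` walks the same resolutions in reverse and returns
# ["AttnUpDecoderBlock2D", "UpDecoderBlock2D", "UpDecoderBlock2D"].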


# done vqvae model

# vqvae checkpoint


def vqvae_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # encoder
    diffusers_checkpoint.update(vqvae_encoder_to_diffusers_checkpoint(model, checkpoint))

    # quant_conv
    diffusers_checkpoint.update(
        {
            "quant_conv.weight": checkpoint["quant_conv.weight"],
            "quant_conv.bias": checkpoint["quant_conv.bias"],
        }
    )

    # quantize
    diffusers_checkpoint.update({"quantize.embedding.weight": checkpoint["quantize.embedding"]})

    # post_quant_conv
    diffusers_checkpoint.update(
        {
            "post_quant_conv.weight": checkpoint["post_quant_conv.weight"],
            "post_quant_conv.bias": checkpoint["post_quant_conv.bias"],
        }
    )

    # decoder
    diffusers_checkpoint.update(vqvae_decoder_to_diffusers_checkpoint(model, checkpoint))

    return diffusers_checkpoint


def vqvae_encoder_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # conv_in
    diffusers_checkpoint.update(
        {
            "encoder.conv_in.weight": checkpoint["encoder.conv_in.weight"],
            "encoder.conv_in.bias": checkpoint["encoder.conv_in.bias"],
        }
    )

    # down_blocks
    for down_block_idx, down_block in enumerate(model.encoder.down_blocks):
        diffusers_down_block_prefix = f"encoder.down_blocks.{down_block_idx}"
        down_block_prefix = f"encoder.down.{down_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(down_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_down_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{down_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                vqvae_resnet_to_diffusers_checkpoint(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # downsample

        # do not include the downsample when on the last down block,
        # which has no downsample layer in the original model
        if down_block_idx != len(model.encoder.down_blocks) - 1:
            # there is a single downsample layer per down block,
            # so the checkpoint index is hardcoded to 0
            diffusers_downsample_prefix = f"{diffusers_down_block_prefix}.downsamplers.0.conv"
            downsample_prefix = f"{down_block_prefix}.downsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_downsample_prefix}.weight": checkpoint[f"{downsample_prefix}.weight"],
                    f"{diffusers_downsample_prefix}.bias": checkpoint[f"{downsample_prefix}.bias"],
                }
            )

        # attentions

        if hasattr(down_block, "attentions"):
            for attention_idx, _ in enumerate(down_block.attentions):
                diffusers_attention_prefix = f"{diffusers_down_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{down_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    vqvae_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # done down_blocks

    # mid block

    # mid block attentions

    # there is a single hardcoded attention block in the middle of the encoder
    diffusers_attention_prefix = "encoder.mid_block.attentions.0"
    attention_prefix = "encoder.mid.attn_1"
    diffusers_checkpoint.update(
        vqvae_attention_to_diffusers_checkpoint(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets

    for diffusers_resnet_idx, resnet in enumerate(model.encoder.mid_block.resnets):
        diffusers_resnet_prefix = f"encoder.mid_block.resnets.{diffusers_resnet_idx}"

        # the hardcoded block prefixes in the original checkpoint are 1-indexed
        orig_resnet_idx = diffusers_resnet_idx + 1
        # there are two hardcoded resnets in the middle of the encoder
        resnet_prefix = f"encoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            vqvae_resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "encoder.conv_norm_out.weight": checkpoint["encoder.norm_out.weight"],
            "encoder.conv_norm_out.bias": checkpoint["encoder.norm_out.bias"],
            # conv_out
            "encoder.conv_out.weight": checkpoint["encoder.conv_out.weight"],
            "encoder.conv_out.bias": checkpoint["encoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint


def vqvae_decoder_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    # conv_in
    diffusers_checkpoint.update(
        {
            "decoder.conv_in.weight": checkpoint["decoder.conv_in.weight"],
            "decoder.conv_in.bias": checkpoint["decoder.conv_in.bias"],
        }
    )

    # up_blocks

    for diffusers_up_block_idx, up_block in enumerate(model.decoder.up_blocks):
        # up_blocks are stored in reverse order in the VQ-Diffusion checkpoint
        orig_up_block_idx = len(model.decoder.up_blocks) - 1 - diffusers_up_block_idx

        diffusers_up_block_prefix = f"decoder.up_blocks.{diffusers_up_block_idx}"
        up_block_prefix = f"decoder.up.{orig_up_block_idx}"

        # resnets
        for resnet_idx, resnet in enumerate(up_block.resnets):
            diffusers_resnet_prefix = f"{diffusers_up_block_prefix}.resnets.{resnet_idx}"
            resnet_prefix = f"{up_block_prefix}.block.{resnet_idx}"

            diffusers_checkpoint.update(
                vqvae_resnet_to_diffusers_checkpoint(
                    resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
                )
            )

        # upsample

        # do not include the upsample when on the last up block,
        # which has no upsample layer in the original model
        if diffusers_up_block_idx != len(model.decoder.up_blocks) - 1:
            # there is a single upsample layer per up block,
            # so the checkpoint index is hardcoded to 0
            diffusers_upsample_prefix = f"{diffusers_up_block_prefix}.upsamplers.0.conv"
            upsample_prefix = f"{up_block_prefix}.upsample.conv"
            diffusers_checkpoint.update(
                {
                    f"{diffusers_upsample_prefix}.weight": checkpoint[f"{upsample_prefix}.weight"],
                    f"{diffusers_upsample_prefix}.bias": checkpoint[f"{upsample_prefix}.bias"],
                }
            )

        # attentions

        if hasattr(up_block, "attentions"):
            for attention_idx, _ in enumerate(up_block.attentions):
                diffusers_attention_prefix = f"{diffusers_up_block_prefix}.attentions.{attention_idx}"
                attention_prefix = f"{up_block_prefix}.attn.{attention_idx}"
                diffusers_checkpoint.update(
                    vqvae_attention_to_diffusers_checkpoint(
                        checkpoint,
                        diffusers_attention_prefix=diffusers_attention_prefix,
                        attention_prefix=attention_prefix,
                    )
                )

    # done up_blocks

    # mid block

    # mid block attentions

    # there is a single hardcoded attention block in the middle of the decoder
    diffusers_attention_prefix = "decoder.mid_block.attentions.0"
    attention_prefix = "decoder.mid.attn_1"
    diffusers_checkpoint.update(
        vqvae_attention_to_diffusers_checkpoint(
            checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
        )
    )

    # mid block resnets

    for diffusers_resnet_idx, resnet in enumerate(model.decoder.mid_block.resnets):
        diffusers_resnet_prefix = f"decoder.mid_block.resnets.{diffusers_resnet_idx}"

        # the hardcoded block prefixes in the original checkpoint are 1-indexed
        orig_resnet_idx = diffusers_resnet_idx + 1
        # there are two hardcoded resnets in the middle of the decoder
        resnet_prefix = f"decoder.mid.block_{orig_resnet_idx}"

        diffusers_checkpoint.update(
            vqvae_resnet_to_diffusers_checkpoint(
                resnet, checkpoint, diffusers_resnet_prefix=diffusers_resnet_prefix, resnet_prefix=resnet_prefix
            )
        )

    diffusers_checkpoint.update(
        {
            # conv_norm_out
            "decoder.conv_norm_out.weight": checkpoint["decoder.norm_out.weight"],
            "decoder.conv_norm_out.bias": checkpoint["decoder.norm_out.bias"],
            # conv_out
            "decoder.conv_out.weight": checkpoint["decoder.conv_out.weight"],
            "decoder.conv_out.bias": checkpoint["decoder.conv_out.bias"],
        }
    )

    return diffusers_checkpoint


def vqvae_resnet_to_diffusers_checkpoint(resnet, checkpoint, *, diffusers_resnet_prefix, resnet_prefix):
    rv = {
        # norm1
        f"{diffusers_resnet_prefix}.norm1.weight": checkpoint[f"{resnet_prefix}.norm1.weight"],
        f"{diffusers_resnet_prefix}.norm1.bias": checkpoint[f"{resnet_prefix}.norm1.bias"],
        # conv1
        f"{diffusers_resnet_prefix}.conv1.weight": checkpoint[f"{resnet_prefix}.conv1.weight"],
        f"{diffusers_resnet_prefix}.conv1.bias": checkpoint[f"{resnet_prefix}.conv1.bias"],
        # norm2
        f"{diffusers_resnet_prefix}.norm2.weight": checkpoint[f"{resnet_prefix}.norm2.weight"],
        f"{diffusers_resnet_prefix}.norm2.bias": checkpoint[f"{resnet_prefix}.norm2.bias"],
        # conv2
        f"{diffusers_resnet_prefix}.conv2.weight": checkpoint[f"{resnet_prefix}.conv2.weight"],
        f"{diffusers_resnet_prefix}.conv2.bias": checkpoint[f"{resnet_prefix}.conv2.bias"],
    }

    if resnet.conv_shortcut is not None:
        rv.update(
            {
                f"{diffusers_resnet_prefix}.conv_shortcut.weight": checkpoint[f"{resnet_prefix}.nin_shortcut.weight"],
                f"{diffusers_resnet_prefix}.conv_shortcut.bias": checkpoint[f"{resnet_prefix}.nin_shortcut.bias"],
            }
        )

    return rv


def vqvae_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    return {
        # group_norm
        f"{diffusers_attention_prefix}.group_norm.weight": checkpoint[f"{attention_prefix}.norm.weight"],
        f"{diffusers_attention_prefix}.group_norm.bias": checkpoint[f"{attention_prefix}.norm.bias"],
        # query
        f"{diffusers_attention_prefix}.query.weight": checkpoint[f"{attention_prefix}.q.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.query.bias": checkpoint[f"{attention_prefix}.q.bias"],
        # key
        f"{diffusers_attention_prefix}.key.weight": checkpoint[f"{attention_prefix}.k.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.key.bias": checkpoint[f"{attention_prefix}.k.bias"],
        # value
        f"{diffusers_attention_prefix}.value.weight": checkpoint[f"{attention_prefix}.v.weight"][:, :, 0, 0],
        f"{diffusers_attention_prefix}.value.bias": checkpoint[f"{attention_prefix}.v.bias"],
        # proj_attn
        f"{diffusers_attention_prefix}.proj_attn.weight": checkpoint[f"{attention_prefix}.proj_out.weight"][
            :, :, 0, 0
        ],
        f"{diffusers_attention_prefix}.proj_attn.bias": checkpoint[f"{attention_prefix}.proj_out.bias"],
    }
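

# The original VQGAN attention layers implement q/k/v and the output projection as 1x1 Conv2d
# layers, while the diffusers attention block uses Linear layers, so the `[:, :, 0, 0]` indexing
# above squeezes away the trailing 1x1 spatial dimensions. As a hypothetical illustration (shapes
# only, not taken from the ITHQ checkpoint): a conv weight of shape (512, 512, 1, 1) becomes a
# (512, 512) linear weight, while the biases carry over unchanged.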


# done vqvae checkpoint

# transformer model

PORTED_DIFFUSIONS = ["image_synthesis.modeling.transformers.diffusion_transformer.DiffusionTransformer"]
PORTED_TRANSFORMERS = ["image_synthesis.modeling.transformers.transformer_utils.Text2ImageTransformer"]
PORTED_CONTENT_EMBEDDINGS = ["image_synthesis.modeling.embeddings.dalle_mask_image_embedding.DalleMaskImageEmbedding"]


def transformer_model_from_original_config(
    original_diffusion_config, original_transformer_config, original_content_embedding_config
):
    assert (
        original_diffusion_config.target in PORTED_DIFFUSIONS
    ), f"{original_diffusion_config.target} has not yet been ported to diffusers."
    assert (
        original_transformer_config.target in PORTED_TRANSFORMERS
    ), f"{original_transformer_config.target} has not yet been ported to diffusers."
    assert (
        original_content_embedding_config.target in PORTED_CONTENT_EMBEDDINGS
    ), f"{original_content_embedding_config.target} has not yet been ported to diffusers."

    original_diffusion_config = original_diffusion_config.params
    original_transformer_config = original_transformer_config.params
    original_content_embedding_config = original_content_embedding_config.params

    inner_dim = original_transformer_config["n_embd"]

    n_heads = original_transformer_config["n_head"]

    # VQ-Diffusion gives the dimensionality of the attention layers as the total inner dimension
    # (number of heads times the dimension of a single head). Diffusers expects the number of
    # heads and the per-head dimension separately, so split the inner dimension here.
    assert inner_dim % n_heads == 0
    d_head = inner_dim // n_heads

    depth = original_transformer_config["n_layer"]
    context_dim = original_transformer_config["condition_dim"]

    num_embed = original_content_embedding_config["num_embed"]
    # the number of embeddings in the transformer includes the mask embedding,
    # which the content embedding (the vqvae codebook) does not include
    num_embed = num_embed + 1

    height = original_transformer_config["content_spatial_size"][0]
    width = original_transformer_config["content_spatial_size"][1]

    assert width == height, "width has to be equal to height"
    dropout = original_transformer_config["resid_pdrop"]
    num_embeds_ada_norm = original_diffusion_config["diffusion_step"]

    model_kwargs = {
        "attention_bias": True,
        "cross_attention_dim": context_dim,
        "attention_head_dim": d_head,
        "num_layers": depth,
        "dropout": dropout,
        "num_attention_heads": n_heads,
        "num_vector_embeds": num_embed,
        "num_embeds_ada_norm": num_embeds_ada_norm,
        "norm_num_groups": 32,
        "sample_size": width,
        "activation_fn": "geglu-approximate",
    }

    model = Transformer2DModel(**model_kwargs)
    return model


# done transformer model

# transformer checkpoint


def transformer_original_checkpoint_to_diffusers_checkpoint(model, checkpoint):
    diffusers_checkpoint = {}

    transformer_prefix = "transformer.transformer"

    diffusers_latent_image_embedding_prefix = "latent_image_embedding"
    latent_image_embedding_prefix = f"{transformer_prefix}.content_emb"

    # latent image embedding
    diffusers_checkpoint.update(
        {
            f"{diffusers_latent_image_embedding_prefix}.emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.emb.weight"
            ],
            f"{diffusers_latent_image_embedding_prefix}.height_emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.height_emb.weight"
            ],
            f"{diffusers_latent_image_embedding_prefix}.width_emb.weight": checkpoint[
                f"{latent_image_embedding_prefix}.width_emb.weight"
            ],
        }
    )

    # transformer blocks
    for transformer_block_idx, transformer_block in enumerate(model.transformer_blocks):
        diffusers_transformer_block_prefix = f"transformer_blocks.{transformer_block_idx}"
        transformer_block_prefix = f"{transformer_prefix}.blocks.{transformer_block_idx}"

        # ada norm block
        diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm1"
        ada_norm_prefix = f"{transformer_block_prefix}.ln1"

        diffusers_checkpoint.update(
            transformer_ada_norm_to_diffusers_checkpoint(
                checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
            )
        )

        # attention block
        diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn1"
        attention_prefix = f"{transformer_block_prefix}.attn1"

        diffusers_checkpoint.update(
            transformer_attention_to_diffusers_checkpoint(
                checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
            )
        )

        # ada norm block
        diffusers_ada_norm_prefix = f"{diffusers_transformer_block_prefix}.norm2"
        ada_norm_prefix = f"{transformer_block_prefix}.ln1_1"

        diffusers_checkpoint.update(
            transformer_ada_norm_to_diffusers_checkpoint(
                checkpoint, diffusers_ada_norm_prefix=diffusers_ada_norm_prefix, ada_norm_prefix=ada_norm_prefix
            )
        )

        # attention block
        diffusers_attention_prefix = f"{diffusers_transformer_block_prefix}.attn2"
        attention_prefix = f"{transformer_block_prefix}.attn2"

        diffusers_checkpoint.update(
            transformer_attention_to_diffusers_checkpoint(
                checkpoint, diffusers_attention_prefix=diffusers_attention_prefix, attention_prefix=attention_prefix
            )
        )

        # norm block
        diffusers_norm_block_prefix = f"{diffusers_transformer_block_prefix}.norm3"
        norm_block_prefix = f"{transformer_block_prefix}.ln2"

        diffusers_checkpoint.update(
            {
                f"{diffusers_norm_block_prefix}.weight": checkpoint[f"{norm_block_prefix}.weight"],
                f"{diffusers_norm_block_prefix}.bias": checkpoint[f"{norm_block_prefix}.bias"],
            }
        )

        # feedforward block
        diffusers_feedforward_prefix = f"{diffusers_transformer_block_prefix}.ff"
        feedforward_prefix = f"{transformer_block_prefix}.mlp"

        diffusers_checkpoint.update(
            transformer_feedforward_to_diffusers_checkpoint(
                checkpoint,
                diffusers_feedforward_prefix=diffusers_feedforward_prefix,
                feedforward_prefix=feedforward_prefix,
            )
        )

    # to logits

    diffusers_norm_out_prefix = "norm_out"
    norm_out_prefix = f"{transformer_prefix}.to_logits.0"

    diffusers_checkpoint.update(
        {
            f"{diffusers_norm_out_prefix}.weight": checkpoint[f"{norm_out_prefix}.weight"],
            f"{diffusers_norm_out_prefix}.bias": checkpoint[f"{norm_out_prefix}.bias"],
        }
    )

    diffusers_out_prefix = "out"
    out_prefix = f"{transformer_prefix}.to_logits.1"

    diffusers_checkpoint.update(
        {
            f"{diffusers_out_prefix}.weight": checkpoint[f"{out_prefix}.weight"],
            f"{diffusers_out_prefix}.bias": checkpoint[f"{out_prefix}.bias"],
        }
    )

    return diffusers_checkpoint


def transformer_ada_norm_to_diffusers_checkpoint(checkpoint, *, diffusers_ada_norm_prefix, ada_norm_prefix):
    return {
        f"{diffusers_ada_norm_prefix}.emb.weight": checkpoint[f"{ada_norm_prefix}.emb.weight"],
        f"{diffusers_ada_norm_prefix}.linear.weight": checkpoint[f"{ada_norm_prefix}.linear.weight"],
        f"{diffusers_ada_norm_prefix}.linear.bias": checkpoint[f"{ada_norm_prefix}.linear.bias"],
    }


def transformer_attention_to_diffusers_checkpoint(checkpoint, *, diffusers_attention_prefix, attention_prefix):
    return {
        # key
        f"{diffusers_attention_prefix}.to_k.weight": checkpoint[f"{attention_prefix}.key.weight"],
        f"{diffusers_attention_prefix}.to_k.bias": checkpoint[f"{attention_prefix}.key.bias"],
        # query
        f"{diffusers_attention_prefix}.to_q.weight": checkpoint[f"{attention_prefix}.query.weight"],
        f"{diffusers_attention_prefix}.to_q.bias": checkpoint[f"{attention_prefix}.query.bias"],
        # value
        f"{diffusers_attention_prefix}.to_v.weight": checkpoint[f"{attention_prefix}.value.weight"],
        f"{diffusers_attention_prefix}.to_v.bias": checkpoint[f"{attention_prefix}.value.bias"],
        # linear out
        f"{diffusers_attention_prefix}.to_out.0.weight": checkpoint[f"{attention_prefix}.proj.weight"],
        f"{diffusers_attention_prefix}.to_out.0.bias": checkpoint[f"{attention_prefix}.proj.bias"],
    }


def transformer_feedforward_to_diffusers_checkpoint(checkpoint, *, diffusers_feedforward_prefix, feedforward_prefix):
    return {
        f"{diffusers_feedforward_prefix}.net.0.proj.weight": checkpoint[f"{feedforward_prefix}.0.weight"],
        f"{diffusers_feedforward_prefix}.net.0.proj.bias": checkpoint[f"{feedforward_prefix}.0.bias"],
        f"{diffusers_feedforward_prefix}.net.2.weight": checkpoint[f"{feedforward_prefix}.2.weight"],
        f"{diffusers_feedforward_prefix}.net.2.bias": checkpoint[f"{feedforward_prefix}.2.bias"],
    }


# done transformer checkpoint


def read_config_file(filename):
    # The original yaml config files contain tags that OmegaConf does not load directly,
    # so the file is first read with yaml's FullLoader and the resulting dict is then
    # wrapped in an OmegaConf config object.
    with open(filename) as f:
        original_config = yaml.load(f, FullLoader)

    return OmegaConf.create(original_config)


# The vqvae is defined by its own config file and checkpoint, so the script takes
# separate arguments for it.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--vqvae_checkpoint_path",
        default=None,
        type=str,
        required=True,
        help="Path to the vqvae checkpoint to convert.",
    )

    parser.add_argument(
        "--vqvae_original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture for the vqvae.",
    )

    parser.add_argument(
        "--checkpoint_path", default=None, type=str, required=True, help="Path to the checkpoint to convert."
    )

    parser.add_argument(
        "--original_config_file",
        default=None,
        type=str,
        required=True,
        help="The YAML config file corresponding to the original architecture.",
    )

    parser.add_argument("--dump_path", default=None, type=str, required=True, help="Path to the output model.")

    parser.add_argument(
        "--checkpoint_load_device",
        default="cpu",
        type=str,
        required=False,
        help="The device passed to `map_location` when loading checkpoints.",
    )

    # The original VQ-Diffusion always loads the ema weights when they are present,
    # so the converted pipeline uses them by default as well.
    parser.add_argument(
        "--no_use_ema",
        action="store_true",
        required=False,
        help=(
            "Set to not use the ema weights from the original VQ-Diffusion checkpoint. You probably do not want to"
            " set it as the original VQ-Diffusion always uses the ema weights when loading models."
        ),
    )

    args = parser.parse_args()

    use_ema = not args.no_use_ema

    print(f"loading checkpoints to {args.checkpoint_load_device}")

    checkpoint_map_location = torch.device(args.checkpoint_load_device)

    # vqvae_model

    print(f"loading vqvae, config: {args.vqvae_original_config_file}, checkpoint: {args.vqvae_checkpoint_path}")

    vqvae_original_config = read_config_file(args.vqvae_original_config_file).model
    vqvae_checkpoint = torch.load(args.vqvae_checkpoint_path, map_location=checkpoint_map_location)["model"]

    with init_empty_weights():
        vqvae_model = vqvae_model_from_original_config(vqvae_original_config)

    vqvae_diffusers_checkpoint = vqvae_original_checkpoint_to_diffusers_checkpoint(vqvae_model, vqvae_checkpoint)

    with tempfile.NamedTemporaryFile() as vqvae_diffusers_checkpoint_file:
        torch.save(vqvae_diffusers_checkpoint, vqvae_diffusers_checkpoint_file.name)
        del vqvae_diffusers_checkpoint
        del vqvae_checkpoint
        load_checkpoint_and_dispatch(vqvae_model, vqvae_diffusers_checkpoint_file.name, device_map="auto")

    print("done loading vqvae")

    # done vqvae_model

    # transformer_model

    print(
        f"loading transformer, config: {args.original_config_file}, checkpoint: {args.checkpoint_path}, use ema:"
        f" {use_ema}"
    )

    original_config = read_config_file(args.original_config_file).model

    diffusion_config = original_config.params.diffusion_config
    transformer_config = original_config.params.diffusion_config.params.transformer_config
    content_embedding_config = original_config.params.diffusion_config.params.content_emb_config

    pre_checkpoint = torch.load(args.checkpoint_path, map_location=checkpoint_map_location)

    if use_ema:
        if "ema" in pre_checkpoint:
            checkpoint = {}
            for k, v in pre_checkpoint["model"].items():
                checkpoint[k] = v

            for k, v in pre_checkpoint["ema"].items():
                # the ema weights are stored without the leading `transformer.` prefix,
                # so add it back so they override the corresponding non-ema weights
                checkpoint[f"transformer.{k}"] = v
        else:
            print("attempted to load ema weights but no ema weights are specified in the loaded checkpoint.")
            checkpoint = pre_checkpoint["model"]
    else:
        checkpoint = pre_checkpoint["model"]

    del pre_checkpoint

    with init_empty_weights():
        transformer_model = transformer_model_from_original_config(
            diffusion_config, transformer_config, content_embedding_config
        )

    diffusers_transformer_checkpoint = transformer_original_checkpoint_to_diffusers_checkpoint(
        transformer_model, checkpoint
    )

    # classifier free sampling embeddings interlude

    # The learned embeddings are stored on the transformer in the original VQ-Diffusion. We store them
    # in a separate model, so we pull them off the checkpoint before the checkpoint is deleted.

    learnable_classifier_free_sampling_embeddings = diffusion_config.params.learnable_cf

    if learnable_classifier_free_sampling_embeddings:
        learned_classifier_free_sampling_embeddings_embeddings = checkpoint["transformer.empty_text_embed"]
    else:
        learned_classifier_free_sampling_embeddings_embeddings = None

    # done classifier free sampling embeddings interlude

    with tempfile.NamedTemporaryFile() as diffusers_transformer_checkpoint_file:
        torch.save(diffusers_transformer_checkpoint, diffusers_transformer_checkpoint_file.name)
        del diffusers_transformer_checkpoint
        del checkpoint
        load_checkpoint_and_dispatch(transformer_model, diffusers_transformer_checkpoint_file.name, device_map="auto")

    print("done loading transformer")

    # done transformer_model

    # text encoder

    print("loading CLIP text encoder")

    clip_name = "openai/clip-vit-base-patch32"

    # The original VQ-Diffusion pads text tokens with token id 0, which is the `!` token in the
    # CLIP tokenizer, so the pad token is set explicitly here (and checked by the assert below).
    pad_token = "!"

    tokenizer_model = CLIPTokenizer.from_pretrained(clip_name, pad_token=pad_token, device_map="auto")

    assert tokenizer_model.convert_tokens_to_ids(pad_token) == 0

    text_encoder_model = CLIPTextModel.from_pretrained(clip_name)

    print("done loading CLIP text encoder")

    # done text encoder

    # scheduler

    scheduler_model = VQDiffusionScheduler(
        # the scheduler has the same number of classes as the transformer (the codebook entries plus the mask class)
        num_vec_classes=transformer_model.num_vector_embeds
    )

    # done scheduler

    # learned classifier free sampling embeddings

    with init_empty_weights():
        learned_classifier_free_sampling_embeddings_model = LearnedClassifierFreeSamplingEmbeddings(
            learnable_classifier_free_sampling_embeddings,
            hidden_size=text_encoder_model.config.hidden_size,
            length=tokenizer_model.model_max_length,
        )

    learned_classifier_free_sampling_checkpoint = {
        "embeddings": learned_classifier_free_sampling_embeddings_embeddings.float()
    }

    with tempfile.NamedTemporaryFile() as learned_classifier_free_sampling_checkpoint_file:
        torch.save(learned_classifier_free_sampling_checkpoint, learned_classifier_free_sampling_checkpoint_file.name)
        del learned_classifier_free_sampling_checkpoint
        del learned_classifier_free_sampling_embeddings_embeddings
        load_checkpoint_and_dispatch(
            learned_classifier_free_sampling_embeddings_model,
            learned_classifier_free_sampling_checkpoint_file.name,
            device_map="auto",
        )

    # done learned classifier free sampling embeddings

    print(f"saving VQ diffusion model, path: {args.dump_path}")

    pipe = VQDiffusionPipeline(
        vqvae=vqvae_model,
        transformer=transformer_model,
        tokenizer=tokenizer_model,
        text_encoder=text_encoder_model,
        learned_classifier_free_sampling_embeddings=learned_classifier_free_sampling_embeddings_model,
        scheduler=scheduler_model,
    )
    pipe.save_pretrained(args.dump_path)

    print("done writing VQ diffusion model")