import csv
import os

import numpy as np
import PIL
import torch
import torch_xla.core.xla_model as xm

import evaluate
import nltk
from datasets import load_dataset
from transformers import (
    AutoFeatureExtractor,
    AutoTokenizer,
    Seq2SeqTrainer,
    Seq2SeqTrainingArguments,
    VisionEncoderDecoderModel,
)

# Sentence tokenizer used by postprocess_text for ROUGE-Lsum formatting.
nltk.download('punkt')

# Disable Weights & Biases logging for this run.
os.environ["WANDB_DISABLED"] = "true"

# The Trainer handles TPU placement itself; grabbing the device here just
# verifies that the XLA runtime is available.
dev = xm.xla_device()
|
|
def tokenization_fn(captions, max_target_length):
    """Run tokenization on captions, padding/truncating to max_target_length."""
    labels = tokenizer(captions,
                       padding="max_length",
                       max_length=max_target_length,
                       truncation=True).input_ids
    return labels
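# Illustrative quick check (hypothetical captions; assumes the tokenizer
# loaded later in this script): each returned row is a list of exactly
# max_target_length token ids.
# tokenization_fn(["tag_a tag_b", "tag_c"], max_target_length=8)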
|
|
|
|
def feature_extraction_fn(image_paths, check_image=True):
    """
    Run feature extraction on images.
    If `check_image` is `True`, examples that fail during `Image.open()` are
    caught and discarded. Otherwise, an exception is raised.
    """
    if check_image:
        images = []
        to_keep = []
        for image_file in image_paths:
            try:
                img = PIL.Image.open(image_file)
                images.append(img)
                to_keep.append(True)
            except Exception:
                to_keep.append(False)
    else:
        images = [PIL.Image.open(image_file) for image_file in image_paths]
        to_keep = [True] * len(images)

    encoder_inputs = feature_extractor(images=images, return_tensors="np")
    # Return the keep-mask alongside the pixel values so the caller can drop
    # the labels of discarded images; otherwise pixel values and labels
    # silently misalign whenever an image fails to open.
    return encoder_inputs.pixel_values, to_keep
|
|
def transform(example_batch):
    # Alternative on-the-fly transform for datasets with decoded 'image' and
    # 'labels' columns; unused in this pipeline, since `set_transform` below
    # uses preprocess_fn instead.
    inputs = feature_extractor([x for x in example_batch['image']], return_tensors='pt')
    inputs['labels'] = example_batch['labels']
    return inputs
|
|
def preprocess_fn(example_batch):
    """Run tokenization + image feature extraction."""
    pixel_values, to_keep = feature_extraction_fn(example_batch['image_path'])
    labels = tokenization_fn(example_batch['tags'], 128)
    model_inputs = {}
    model_inputs['pixel_values'] = pixel_values
    # Keep only the labels whose image survived feature extraction.
    model_inputs['labels'] = [label for label, keep in zip(labels, to_keep) if keep]
    return model_inputs
|
|
def postprocess_text(preds, labels):
    preds = [pred.strip() for pred in preds]
    labels = [label.strip() for label in labels]

    # ROUGE-Lsum expects newline-separated sentences.
    preds = ["\n".join(nltk.sent_tokenize(pred)) for pred in preds]
    labels = ["\n".join(nltk.sent_tokenize(label)) for label in labels]

    return preds, labels
|
|
|
|
def compute_metrics(eval_preds):
    preds, labels = eval_preds
    if isinstance(preds, tuple):
        preds = preds[0]
    decoded_preds = tokenizer.batch_decode(preds, skip_special_tokens=True)
    if ignore_pad_token_for_loss:
        # Replace -100 in the labels, since the tokenizer cannot decode it.
        labels = np.where(labels != -100, labels, tokenizer.pad_token_id)
    decoded_labels = tokenizer.batch_decode(labels, skip_special_tokens=True)

    decoded_preds, decoded_labels = postprocess_text(decoded_preds,
                                                     decoded_labels)

    result = metric.compute(predictions=decoded_preds,
                            references=decoded_labels,
                            use_stemmer=True)
    result = {k: round(v * 100, 4) for k, v in result.items()}
    prediction_lens = [
        np.count_nonzero(pred != tokenizer.pad_token_id) for pred in preds
    ]
    result["gen_len"] = np.mean(prediction_lens)
    return result
|
|
def load_csv_as_dict(file_path):
    """Load a two-column CSV into a {first_column: second_column} dict."""
    with open(file_path, mode='r') as csv_file:
        reader = csv.reader(csv_file)
        result = {rows[0]: rows[1] for rows in reader}
    return result
|
|
image_encoder_model = "google/vit-base-patch16-224"
text_decode_model = "Thouph/GPT-E6-small"

model = VisionEncoderDecoderModel.from_encoder_decoder_pretrained(
    image_encoder_model, text_decode_model)
|
|
# Freeze everything first; only the decoder's cross-attention blocks
# (unfrozen below) will be trained.
model.eval()
for p in model.parameters():
    p.requires_grad = False
|
|
# Unfreeze the randomly initialized cross-attention layers (and their layer
# norms) in every decoder block.
for layer in model.decoder.transformer.h:
    layer.crossattention.train()
    for p in layer.crossattention.parameters():
        p.requires_grad = True
    layer.ln_cross_attn.train()
    for p in layer.ln_cross_attn.parameters():
        p.requires_grad = True
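# Sanity check (illustrative): at this point only the cross-attention weights
# and their layer norms should be trainable.
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
total = sum(p.numel() for p in model.parameters())
print(f"trainable params: {trainable} / {total}")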
|
|
# Image processor for the ViT encoder and tokenizer for the GPT decoder.
feature_extractor = AutoFeatureExtractor.from_pretrained(image_encoder_model)
tokenizer = AutoTokenizer.from_pretrained("Thouph/six_tokenizer_filtered_space_merge")
|
|
# GPT-style decoders have no pad token; reuse EOS for padding.
tokenizer.pad_token = tokenizer.eos_token
|
|
# Propagate the special-token ids to the model config so generation works.
model.config.eos_token_id = tokenizer.eos_token_id
model.config.decoder_start_token_id = tokenizer.bos_token_id
model.config.pad_token_id = tokenizer.pad_token_id

output_dir = "vit-gpt-model"
model.save_pretrained(output_dir)

# Re-assert the freeze by parameter name; keep the cross-attention weights
# and their layer norms trainable.
for name, param in model.named_parameters():
    if "crossattention" not in name and "ln_cross_attn" not in name:
        param.requires_grad = False

feature_extractor.save_pretrained(output_dir)
tokenizer.save_pretrained(output_dir)
|
|
|
|
|
|
dataset = load_dataset('csv', data_files=r"posts-2023-04-17_MD5_caption_sifted_no_symbol_purged_folder.csv")
print(dataset)


def add_image_path(example):
    """Build the absolute image path for each row from its folder and image id."""
    image_name = [i + '.jpg' for i in example["image_id"]]
    folder_name = example["folder_name"]
    image_path = [os.path.join(rf"/home/user/dump_small/{folder_name[i]}", image_name[i])
                  for i in range(len(image_name))]
    example['image_path'] = image_path
    return example
|
|
ds = dataset.map(add_image_path, batched=True, batch_size=8192)["train"]
print(ds)
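# Optional guard (an assumption about the local filesystem, not part of the
# original pipeline): drop rows whose image file is missing so feature
# extraction never fails mid-epoch.
# ds = ds.filter(lambda ex: os.path.exists(ex["image_path"]))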
|
|
ds = ds.train_test_split(test_size=0.02)
print(ds['train'][0:2])
# Apply tokenization + feature extraction lazily, at access time.
ds.set_transform(preprocess_fn)
print(ds['train'][0:2])
|
|
|
|
training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    evaluation_strategy="steps",
    eval_steps=100,
    gradient_accumulation_steps=4,
    per_device_train_batch_size=128,
    weight_decay=0.1,
    max_steps=10000,
    warmup_steps=1000,
    logging_strategy="steps",
    save_steps=5000,
    # TPUs support bf16 mixed precision; fp16 is CUDA-only in the Trainer.
    bf16=True,
    tpu_num_cores=8,
    per_device_eval_batch_size=128,
    output_dir="image-captioning-output",
    learning_rate=5e-4,
    lr_scheduler_type="cosine",
)
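# Effective global batch size implied by these settings:
# 128 per device * 8 TPU cores * 4 accumulation steps = 4096 examples/update.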
|
|
def collate_fn(batch):
    # preprocess_fn returns NumPy pixel values, so convert before stacking.
    return {
        'pixel_values': torch.stack([torch.as_tensor(x['pixel_values']) for x in batch]),
        'labels': torch.tensor([x['labels'] for x in batch])
    }
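# Illustrative shape check (assumes set_transform has already been applied):
# batch = collate_fn([ds['train'][i] for i in range(2)])
# batch['pixel_values'].shape  -> torch.Size([2, 3, 224, 224]) for ViT-base
# batch['labels'].shape        -> torch.Size([2, 128])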
|
|
metric = evaluate.load("rouge")
ignore_pad_token_for_loss = True
|
|
trainer = Seq2SeqTrainer(
    model=model,
    # Passing the feature extractor here makes the Trainer save it alongside
    # each checkpoint.
    tokenizer=feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=ds['train'],
    eval_dataset=ds['test'],
    data_collator=collate_fn,
)
|
|
|
|
trainer.train()
|
|
|
|
trainer.save_model("image-captioning-output1")
tokenizer.save_pretrained("image-captioning-output1")
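# Minimal generation sketch for the saved checkpoint (the image path and
# generation settings are illustrative placeholders):
# trained = VisionEncoderDecoderModel.from_pretrained("image-captioning-output1")
# pixel_values = feature_extractor(
#     images=[PIL.Image.open("/path/to/example.jpg")], return_tensors="pt"
# ).pixel_values
# ids = trained.generate(pixel_values, max_length=128)
# print(tokenizer.batch_decode(ids, skip_special_tokens=True))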
|
|