import os
import random
import io
import av
import cv2
import decord
import imageio
from decord import VideoReader
import torch
import numpy as np
import math
import torch.nn.functional as F

decord.bridge.set_bridge("torch")


from transformers import AutoConfig, AutoModel

config = AutoConfig.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True)
model = AutoModel.from_pretrained("/fs-computility/video/heyinan/iv2hf/", trust_remote_code=True).to(config.device)
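# NOTE: the path above points to a local checkpoint directory; replace it with
# the location of your own copy of the weights before running this script.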


def get_frame_indices(num_frames, vlen, sample='rand', fix_start=None,
                      input_fps=1, max_num_frames=-1, start=None, end=None):
    # Optionally restrict sampling to the [start, end] window (given in seconds).
    start_frame, end_frame = 0, vlen
    if start is not None:
        start_frame = max(start_frame, int(start * input_fps))
    if end is not None:
        end_frame = min(end_frame, int(end * input_fps))

    if start_frame >= end_frame:
        raise ValueError("Start frame index must be less than end frame index")

    clip_length = end_frame - start_frame

    if sample in ["rand", "middle"]:
        acc_samples = min(num_frames, clip_length)

        # Split the clip into acc_samples equal intervals and take one frame from each.
        intervals = np.linspace(start=start_frame, stop=end_frame, num=acc_samples + 1).astype(int)
        ranges = []
        for idx, interv in enumerate(intervals[:-1]):
            ranges.append((interv, intervals[idx + 1] - 1))
        if sample == 'rand':
            try:
                frame_indices = [random.choice(range(x[0], x[1] + 1)) for x in ranges]
            except Exception:
                frame_indices = np.random.permutation(clip_length)[:acc_samples] + start_frame
                frame_indices.sort()
                frame_indices = list(frame_indices)
        elif fix_start is not None:
            frame_indices = [x[0] + fix_start for x in ranges]
        elif sample == 'middle':
            frame_indices = [(x[0] + x[1]) // 2 for x in ranges]
        else:
            raise NotImplementedError

        # If the clip has fewer frames than requested, repeat the last index.
        if len(frame_indices) < num_frames:
            padded_frame_indices = [frame_indices[-1]] * num_frames
            padded_frame_indices[:len(frame_indices)] = frame_indices
            frame_indices = padded_frame_indices
    elif "fps" in sample:
        # e.g. sample='fps1' samples one frame per second of the clip.
        output_fps = float(sample[3:])
        duration = float(clip_length) / input_fps
        delta = 1 / output_fps  # gap between sampled frames, in seconds
        frame_seconds = np.arange(0 + delta / 2, duration + delta / 2, delta)
        frame_indices = np.around(frame_seconds * input_fps).astype(int) + start_frame
        frame_indices = [e for e in frame_indices if e < end_frame]
        if max_num_frames > 0 and len(frame_indices) > max_num_frames:
            frame_indices = frame_indices[:max_num_frames]
    else:
        raise ValueError(f"Unknown sample mode: {sample}")
    return frame_indices
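
# Example (middle sampling): get_frame_indices(4, vlen=100, sample='middle')
# splits frames [0, 100) into 4 equal intervals and returns their midpoints,
# i.e. [12, 37, 62, 87].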


def read_frames_decord(
        video_path, num_frames, sample='middle', fix_start=None,
        max_num_frames=-1, client=None, trimmed30=False, start=None, end=None
):
    # Use a single decoding thread for .webm files; let decord choose otherwise.
    num_threads = 1 if video_path.endswith('.webm') else 0

    video_reader = VideoReader(video_path, num_threads=num_threads)
    vlen = len(video_reader)
    fps = video_reader.get_avg_fps()
    duration = vlen / float(fps)

    frame_indices = get_frame_indices(
        num_frames, vlen, sample=sample, fix_start=fix_start,
        input_fps=fps, max_num_frames=max_num_frames, start=start, end=end
    )

    frames = video_reader.get_batch(frame_indices)  # (T, H, W, C) via the torch bridge
    frames = frames.permute(0, 3, 1, 2)             # (T, C, H, W)
    return frames, frame_indices, duration
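
# For example, read_frames_decord("video1.mp4", 8) returns an 8-frame uint8 tensor
# of shape (8, 3, H, W), the 8 sampled frame indices, and the video duration in
# seconds; 'middle' sampling picks the centre frame of 8 equal segments.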


def get_text_feature(model, texts):
    # Tokenise the queries and encode them with the model's text encoder.
    text_input = model.tokenizer(texts).to(model.device)
    text_features = model.encode_text(text_input)
    return text_features


def get_similarity(video_feature, text_feature):
    # Cosine similarity: L2-normalise both feature sets, then take the dot product.
    video_feature = F.normalize(video_feature, dim=-1)
    text_feature = F.normalize(text_feature, dim=-1)
    sim_matrix = text_feature @ video_feature.T
    return sim_matrix
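
# Note: get_similarity is a standalone helper. For text_feature of shape (N_text, D)
# and video_feature of shape (N_video, D) it returns an (N_text, N_video) matrix of
# cosine similarities; get_top_videos below repeats the same normalisation inline.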


def get_top_videos(model, text_features, video_features, video_paths, texts):
    video_features = F.normalize(video_features, dim=-1)
    text_features = F.normalize(text_features, dim=-1)

    sim_matrix = text_features @ video_features.T

    # Report at most 5 videos per text query, and never more than we actually have.
    top_k = min(5, sim_matrix.shape[1])
    sim_matrix_top_k = torch.topk(sim_matrix, top_k, dim=1)[1]
    softmax_sim_matrix = F.softmax(sim_matrix, dim=1)

    retrieval_infos = {}
    for i in range(len(sim_matrix_top_k)):
        print("\n", texts[i])
        retrieval_infos[texts[i]] = []
        for j in range(top_k):
            video_idx = sim_matrix_top_k[i][j].item()
            # Probability over the candidate videos (softmax of the similarities).
            prob = softmax_sim_matrix[i][video_idx].item()
            print("top", j + 1, ":", video_paths[video_idx], "~prob:", prob)
            retrieval_infos[texts[i]].append({"video": video_paths[video_idx], "prob": prob, "rank": j + 1})
    return retrieval_infos


if __name__ == "__main__":
    video_features = []
    demo_videos = ["video1.mp4", "video2.mp4"]
    texts = ['a person talking', 'a logo', 'a building']

    # Encode each demo video: sample 8 frames, preprocess, and run the vision encoder.
    for video_path in demo_videos:
        frames, frame_indices, video_duration = read_frames_decord(video_path, 8)
        frames = model.transform(frames).unsqueeze(0).to(model.device)
        with torch.no_grad():
            video_feature = model.encode_vision(frames, test=True)
        video_features.append(video_feature)

    # Encode the text queries and rank the demo videos for each of them.
    text_features = get_text_feature(model, texts)
    video_features = torch.cat(video_features, dim=0).to(text_features.dtype).to(config.device)
    results = get_top_videos(model, text_features, video_features, demo_videos, texts)