| | import gradio as gr |
| | import cv2 |
| | import numpy as np |
| | import os |
| | import tempfile |
| | import time |
| | import axengine as axe |
| | import common |
| | import imgproc |
| | import socket |
| |
|
# rgb_range: pixel value range assumed by the EDSR pre/post-processing
#   (255 means 8-bit values are used without rescaling).
# scale: upscaling factor; matches the "x2" model files loaded below and is
#   used to bicubically resize the chroma planes for ESPCN.
rgb_range, scale = 255, 2
def from_numpy(x):
    """Return *x* unchanged if it is already an ndarray, else wrap it with ``np.array``."""
    if isinstance(x, np.ndarray):
        return x
    return np.array(x)
| |
|
def quantize(img, rgb_range):
    """Snap *img* to the 8-bit grid and return it in the original range.

    The values are mapped from ``[0, rgb_range]`` to ``[0, 255]``, clipped,
    rounded to integers, then mapped back by the same factor.
    """
    to_byte = 255 / rgb_range
    byte_img = np.clip(img * to_byte, 0, 255)
    return np.round(byte_img) / to_byte
| |
|
| | |
def init_SRmodel(EDSR_path="../model_convert/axmodel/edsr_baseline_x2_1.axmodel",
                 ESPCN_path="../model_convert/axmodel/espcn_x2_T9.axmodel"):
    """Create one axengine inference session per super-resolution model.

    Args:
        EDSR_path: path to the EDSR ``.axmodel`` file.
        ESPCN_path: path to the ESPCN ``.axmodel`` file.
            NOTE(review): both defaults are relative paths, presumably
            resolved against the current working directory - confirm the
            script is launched from the expected folder.

    Returns:
        list: ``[EDSR_session, ESPCN_session]``, in that order.
    """
    return [axe.InferenceSession(model_path) for model_path in (EDSR_path, ESPCN_path)]


# Both models are loaded once at import time; EDSR_infer / ESPCN_infer bind
# SR_sessions[0] / SR_sessions[1] as their default session arguments.
SR_sessions = init_SRmodel()
| |
|
def EDSR_infer(frame, EDSR_session=SR_sessions[0]):
    """Super-resolve a single frame with the EDSR model.

    Args:
        frame: image as produced by OpenCV (cv2.imread / VideoCapture.read),
            i.e. BGR channel order; presumably H x W x 3 uint8 - TODO confirm.
        EDSR_session: axengine session for EDSR; defaults to the session
            created at import time.

    Returns:
        np.ndarray: uint8 image in H x W x C layout, transposed from the
        model's C x H x W output after the leading batch axis is squeezed.

    Raises:
        RuntimeError: if the session returns no outputs.
    """
    output_names = [x.name for x in EDSR_session.get_outputs()]
    input_name = EDSR_session.get_inputs()[0].name

    # The common.* helpers return tuples, hence the single-element unpacking.
    lr_image, = common.set_channel(frame, n_channels=3)
    lr_image, = common.np_prepare(lr_image, rgb_range=rgb_range)

    outputs = EDSR_session.run(output_names, {input_name: lr_image})

    # The SR image is the first model output. quantize() needs a single array:
    # collecting several outputs into a Python list would make `list * float`
    # raise TypeError there, so always take the first one.
    if isinstance(outputs, (list, tuple)):
        if len(outputs) == 0:
            raise RuntimeError("EDSR model returned no outputs")
        outputs = outputs[0]
    sr = from_numpy(outputs)

    sr = quantize(sr, rgb_range).squeeze(0)  # drop the leading batch axis
    normalized = sr * 255 / rgb_range        # rescale to the 0..255 range
    return normalized.transpose(1, 2, 0).astype(np.uint8)  # CHW -> HWC
| |
|
def ESPCN_infer(frame, ESPCN_session=SR_sessions[1]):
    """Super-resolve a single BGR frame with the ESPCN model.

    Only the luma (Y) plane is fed to the network. The two chroma planes
    (Cb, Cr) are enlarged by ``scale`` with bicubic interpolation and merged
    back with the network's Y output before converting to BGR.

    Args:
        frame: OpenCV BGR image; split into Y/Cb/Cr by
            ``imgproc.preprocess_one_frame`` (presumably float planes in
            [0, 1] - TODO confirm against imgproc).
        ESPCN_session: axengine session for ESPCN; defaults to the session
            created at import time.

    Returns:
        np.ndarray: uint8 BGR image, values clipped to [0, 255].
    """
    input_name = ESPCN_session.get_inputs()[0].name
    output_names = [o.name for o in ESPCN_session.get_outputs()]

    y_lr, cb_lr, cr_lr = imgproc.preprocess_one_frame(frame)

    def _bicubic_up(plane):
        # cv2.resize takes (width, height).
        target_size = (int(plane.shape[1] * scale), int(plane.shape[0] * scale))
        return cv2.resize(plane, target_size, interpolation=cv2.INTER_CUBIC)

    cb_up = _bicubic_up(cb_lr)
    cr_up = _bicubic_up(cr_lr)

    raw = ESPCN_session.run(output_names, {input_name: y_lr})
    if not isinstance(raw, (list, tuple)):
        sr = from_numpy(raw)
    elif len(raw) == 1:
        sr = from_numpy(raw[0])
    else:
        sr = [from_numpy(item) for item in raw]

    y_sr = imgproc.array_to_image(sr).astype(np.float32) / 255.0
    ycbcr = cv2.merge([y_sr[:, :, 0], cb_up, cr_up])
    bgr = imgproc.ycbcr_to_bgr(ycbcr)
    return np.clip(bgr * 255.0, 0, 255).astype(np.uint8)
| |
|
| | |
| | |
| | |
def EDSR_MODEL(input_data, is_video=False):
    """Apply EDSR to one frame, or to every frame of a video.

    Args:
        input_data: a single frame, or an iterable of frames when
            ``is_video`` is True.
        is_video: selects per-frame processing of ``input_data``.

    Returns:
        The super-resolved frame, or a list of super-resolved frames in
        input order.
    """
    if not is_video:
        return EDSR_infer(frame=input_data)
    return [EDSR_infer(frame=frame) for frame in input_data]
| |
|
def ESPCN_MODEL(input_data, is_video=False):
    """Apply ESPCN to one frame, or to every frame of a video.

    Args:
        input_data: a single frame, or an iterable of frames when
            ``is_video`` is True.
        is_video: selects per-frame processing of ``input_data``.

    Returns:
        The super-resolved frame, or a list of super-resolved frames in
        input order.
    """
    if not is_video:
        return ESPCN_infer(frame=input_data)
    return [ESPCN_infer(frame=frame) for frame in input_data]
| |
|
| | |
| | |
| | |
class AppState:
    """Viewer state read by the image-toggle callbacks.

    Attributes:
        original_img: the uploaded image as read by cv2 (BGR), or None.
        sr_img: the super-resolved image (BGR), or None.
        is_video: True when the last processed upload was a video.
    """

    def __init__(self) -> None:
        self.original_img, self.sr_img = None, None
        self.is_video = False


# Single module-level instance shared by every browser session;
# process_super_resolution() replaces it at the start of each run.
app_state = AppState()
| |
|
| | |
| | |
| | |
_VIDEO_EXTENSIONS = ('.mp4', '.avi', '.mov', '.mkv')
# Used when the container reports no usable frame rate (0 or NaN).
_FALLBACK_FPS = 25.0


def _upload_path(input_file):
    """Return the filesystem path of a ``gr.File`` upload.

    Newer Gradio versions pass a path string; older ones pass a tempfile-like
    object whose ``.name`` is the path (NOTE(review): confirm for the pinned
    Gradio version). ``str`` has no ``name`` attribute, so strings pass through.
    """
    return getattr(input_file, "name", input_file)


def _is_video_path(path):
    """Return True when the file extension of *path* is a supported video type."""
    return os.path.splitext(path)[1].lower() in _VIDEO_EXTENSIONS


def _read_video(path):
    """Decode every frame of the video at *path*.

    Returns:
        tuple: ``(frames, fps, total_frames)`` - the decoded frames, the
        container's frame rate and its reported frame count.

    Raises:
        gr.Error: if the video cannot be opened.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise gr.Error("无法读取视频!")
        total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        fps = cap.get(cv2.CAP_PROP_FPS)
        frames = []
        while True:
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the decoder even if reading fails part-way.
        cap.release()
    return frames, fps, total_frames


def _write_video(path, frames, fps):
    """Encode *frames* to *path* with the mp4v codec.

    The output size is taken from ``frames[0]``; all frames are expected to
    share it.

    Raises:
        gr.Error: if the writer cannot be opened.
    """
    h_out, w_out = frames[0].shape[:2]
    fourcc = cv2.VideoWriter_fourcc(*'mp4v')
    writer = cv2.VideoWriter(path, fourcc, fps, (w_out, h_out))
    try:
        if not writer.isOpened():
            raise gr.Error("无法创建输出视频!")
        for frame in frames:
            writer.write(frame)
    finally:
        writer.release()


def process_super_resolution(input_file, model_choice):
    """Run the selected super-resolution model on an uploaded image or video.

    Args:
        input_file: value of the ``gr.File`` upload (a path string, or a
            tempfile-like object exposing ``.name``).
        model_choice: "EDSR_MODEL" selects EDSR; any other value selects ESPCN.

    Returns:
        tuple: 8 values for (image_display, btn_original, btn_sr, image_label,
        output_video_player, download_video, download_image, info_box), in
        the order wired up by ``run_btn.click``.

    Raises:
        gr.Error: when nothing was uploaded, the input cannot be read or
        contains no frames, or the result cannot be written.
    """
    global app_state
    if input_file is None:
        raise gr.Error("请先上传图片或视频!")

    file_path = _upload_path(input_file)
    # Reset the shared viewer state for this run.
    app_state = AppState()
    info_text = ""
    model_func = EDSR_MODEL if model_choice == "EDSR_MODEL" else ESPCN_MODEL

    if _is_video_path(file_path):
        frames, fps, total_frames = _read_video(file_path)
        if not frames:
            raise gr.Error("视频中没有可读取的帧!")
        # `not fps > 0` also catches NaN, which `fps <= 0` would let through.
        if not fps > 0:
            fps = _FALLBACK_FPS
        info_text += f"🎬 视频信息:\n- 总帧数: {total_frames}\n- 帧率: {fps:.2f} FPS\n"

        start_time = time.perf_counter()
        output_data = model_func(frames, is_video=True)
        infer_time = time.perf_counter() - start_time
        info_text += f"\n⏱️ 推理时间: {infer_time:.2f} 秒\n"

        full_video_path = os.path.join(tempfile.gettempdir(), "sr_video_x2.mp4")
        h_out, w_out = output_data[0].shape[:2]
        info_text += f"- 超分后尺寸: {w_out} x {h_out}\n"
        _write_video(full_video_path, output_data, fps)

        app_state.is_video = True

        # Hide the image viewer; show the video player and its download slot.
        return (
            gr.update(value=None, visible=False),
            gr.update(visible=False),
            gr.update(visible=False),
            gr.update(value="当前: 无", visible=False),
            gr.update(value=full_video_path, visible=True),
            gr.update(value=full_video_path, visible=True),
            gr.update(visible=False),
            info_text
        )

    img = cv2.imread(file_path)
    if img is None:
        raise gr.Error("无法读取图片!")
    h, w = img.shape[:2]
    info_text += f"🖼️ 图片信息:\n- 原始尺寸: {w} x {h}\n"

    app_state.original_img = img.copy()
    start_time = time.perf_counter()
    sr_img = model_func(img, is_video=False)
    infer_time = time.perf_counter() - start_time
    info_text += f"\n⏱️ 推理时间: {infer_time:.2f} 秒\n"

    h_out, w_out = sr_img.shape[:2]
    info_text += f"- 超分后尺寸: {w_out} x {h_out}\n"

    sr_img_path = os.path.join(tempfile.gettempdir(), "sr_image_x2.png")
    # cv2.imwrite reports failure via its return value rather than raising.
    if not cv2.imwrite(sr_img_path, sr_img):
        raise gr.Error("无法保存超分图片!")
    app_state.sr_img = sr_img
    app_state.is_video = False

    # Show the original first (BGR -> RGB for display); offer the SR download.
    return (
        gr.update(value=app_state.original_img[:, :, ::-1], visible=True),
        gr.update(visible=True),
        gr.update(visible=True),
        gr.update(value="当前: 原图", visible=True),
        gr.update(visible=False),
        gr.update(visible=False),
        gr.update(value=sr_img_path, visible=True),
        info_text
    )
| |
|
| | |
| | |
| | |
def show_original():
    """Switch the viewer to the uploaded image.

    Returns:
        tuple: updates for (image_display, image_label). Both are no-op
        updates when no image has been processed yet.
    """
    bgr = app_state.original_img
    if bgr is not None:
        # Stored images come from cv2 (BGR); the display expects RGB.
        return gr.update(value=bgr[:, :, ::-1]), gr.update(value="当前: 原图")
    return gr.update(), gr.update()
| |
|
def show_sr():
    """Switch the viewer to the super-resolved image.

    Returns:
        tuple: updates for (image_display, image_label). Both are no-op
        updates when no SR image exists yet.
    """
    bgr = app_state.sr_img
    if bgr is not None:
        # Stored images are BGR (OpenCV order); reverse channels for display.
        return gr.update(value=bgr[:, :, ::-1]), gr.update(value="当前: 超分图")
    return gr.update(), gr.update()
| |
|
| | |
| | |
| | |
with gr.Blocks(title="超分辨率可视化工具") as demo:
    # Page layout. The image viewer, video player and download slots start
    # hidden and are revealed by toggle_ui() / process_super_resolution().
    gr.Markdown("## 🚀 超分辨率模型效果可视化")
    gr.Markdown("上传图片或视频,选择模型,点击箭头切换原图/超分图!")

    input_file = gr.File(
        label="📂 上传图片或视频",
        file_types=["image", "video"],
        file_count="single"
    )

    with gr.Row():
        model_choice = gr.Radio(
            choices=["EDSR_MODEL", "ESPCN_MODEL"],
            value="EDSR_MODEL",
            label="🔍 选择超分辨率模型"
        )
        run_btn = gr.Button("🚀 开始超分", variant="primary")

    # Image viewer: one display slot, toggled between the original and the
    # super-resolved image by the two buttons underneath.
    with gr.Column(visible=False) as image_section:
        image_label = gr.Textbox(value="当前: 原图", interactive=False, lines=1)
        image_display = gr.Image(
            label="🖼️ 图像显示",
            width=800,
            height=600
        )
        with gr.Row():
            btn_original = gr.Button("◀ 原图")
            btn_sr = gr.Button("超分图 ▶")

    # Kept outside image_section so it can be shown while the image viewer is hidden.
    output_video_player = gr.Video(
        label="▶️ 超分视频(高分辨率)",
        visible=False,
        height=450
    )

    with gr.Row():
        download_image = gr.File(label="📥 下载超分图片(原图)", visible=False)
        download_video = gr.File(label="📥 下载超分视频(完整分辨率)", visible=False)

    info_box = gr.Textbox(label="📊 处理信息", lines=6, interactive=False)

    # The order of `outputs` must match the 8-tuple returned by
    # process_super_resolution().
    run_btn.click(
        fn=process_super_resolution,
        inputs=[input_file, model_choice],
        outputs=[
            image_display,
            btn_original,
            btn_sr,
            image_label,
            output_video_player,
            download_video,
            download_image,
            info_box
        ]
    )

    btn_original.click(show_original, outputs=[image_display, image_label])
    btn_sr.click(show_sr, outputs=[image_display, image_label])

    def toggle_ui(file):
        """Show either the image widgets or the video widgets for a new upload.

        Args:
            file: the ``gr.File`` value. It is None when the upload is
                cleared; otherwise a path string (or, on some Gradio
                versions, a tempfile-like object exposing ``.name``).

        Returns:
            tuple: visibility updates for (image_section, download_image,
            output_video_player, download_video).
        """
        if file is None:
            return tuple(gr.update(visible=False) for _ in range(4))
        path = getattr(file, "name", file)
        # Compare the real extension instead of searching for ".mp4" etc.
        # anywhere in the path, which misfires on names like "clip.mp4.png".
        is_video = os.path.splitext(path)[1].lower() in ('.mp4', '.avi', '.mov', '.mkv')
        return (
            gr.update(visible=not is_video),
            gr.update(visible=not is_video),
            gr.update(visible=is_video),
            gr.update(visible=is_video)
        )

    input_file.change(
        fn=toggle_ui,
        inputs=input_file,
        outputs=[
            image_section,
            download_image,
            output_video_player,
            download_video
        ]
    )
def get_local_ip():
    """Return this machine's LAN IPv4 address, or "127.0.0.1" as a fallback.

    A UDP socket is "connected" to a public address only so that the OS
    selects the outbound interface; ``getsockname()`` then reports that
    interface's address. Any failure yields the loopback address.
    """
    fallback = "127.0.0.1"
    try:
        with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as probe:
            probe.connect(("8.8.8.8", 80))
            address = probe.getsockname()[0]
    except Exception:
        return fallback
    return address
| |
|
| |
|
if __name__ == "__main__":
    # Listen on all interfaces so the UI is reachable from the LAN.
    server_name = "0.0.0.0"
    server_port = 7860
    local_ip = get_local_ip()

    # Startup banner: one joined print yields the same stdout as printing
    # each line separately.
    separator = "=" * 50
    banner = [
        "\n" + separator,
        "🌐 SuperResolution 超分辨率 Web UI 已启动!",
        f"🔗 本地访问: http://127.0.0.1:{server_port}",
    ]
    if local_ip != "127.0.0.1":
        banner.append(f"🔗 局域网访问: http://{local_ip}:{server_port}")
    banner.append(separator + "\n")
    print("\n".join(banner))

    demo.launch(
        server_name=server_name,
        server_port=server_port,
        theme=gr.themes.Soft()
    )