FastCDM / scripts /app.py
BinyangQiu
Impl Dockerfile
c5b5c64
import gradio as gr
import cv2
from pathlib import Path
from fastcdm import FastCDM
CHROMEDRIVER_PATH = Path("driver/chromedriver")
def _wrap_latex(s: str) -> str:
s = s or ""
return s if s.strip().startswith("$$") else f"$$ {s} $$"
def preview_latex_gt(gt: str) -> str:
return _wrap_latex(gt)
def preview_latex_pred(pred: str) -> str:
return _wrap_latex(pred)
def compute_fastcdm(gt: str, pred: str):
print("-" * 20)
print(" gt:", gt)
print("pred:", pred)
print("-" * 20)
driver_path = str(CHROMEDRIVER_PATH) if CHROMEDRIVER_PATH.exists() else None
fastcdm = FastCDM(chromedriver=driver_path)
try:
f1, recall, precision, vis_img = fastcdm.compute(gt, pred, visualize=True)
if vis_img is not None:
vis_rgb = cv2.cvtColor(vis_img, cv2.COLOR_BGR2RGB)
else:
vis_rgb = None
metrics_md = f"**CDM得分(F1)**: {f1:.4f} \n**召回率**: {recall:.4f} \n**准确率**: {precision:.4f}"
return metrics_md, vis_rgb
finally:
fastcdm.close()
with gr.Blocks(title="FastCDM 可视化") as demo:
gr.Markdown("# FastCDM 可视化")
with gr.Row():
with gr.Column():
gt_input = gr.Textbox(
label="GT (LaTeX)",
lines=4,
placeholder="输入GT公式,例如: \\frac{1}{2}",
)
gt_md = gr.Markdown(
value="",
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": "\\(", "right": "\\)", "display": False},
{"left": "\\[", "right": "\\]", "display": True},
],
)
pred_input = gr.Textbox(
label="Pred (LaTeX)",
lines=4,
placeholder="输入Pred公式,例如: \\frac{1}{2}",
)
pred_md = gr.Markdown(
value="",
latex_delimiters=[
{"left": "$$", "right": "$$", "display": True},
{"left": "$", "right": "$", "display": False},
{"left": "\\(", "right": "\\)", "display": False},
{"left": "\\[", "right": "\\]", "display": True},
],
)
submit_btn = gr.Button("提交并评估")
with gr.Column():
metrics_out = gr.Markdown(label="评估指标")
vis_out = gr.Image(type="numpy", label="匹配可视化", format="png")
gt_input.change(fn=preview_latex_gt, inputs=gt_input, outputs=gt_md)
pred_input.change(fn=preview_latex_pred, inputs=pred_input, outputs=pred_md)
submit_btn.click(
fn=compute_fastcdm,
inputs=[gt_input, pred_input],
outputs=[metrics_out, vis_out],
)
if __name__ == "__main__":
demo.launch(server_name="0.0.0.0", server_port=7860)