| import google.generativeai as genai |
| from google.generativeai.types import HarmBlockThreshold, HarmCategory |
| import gradio as gr |
| from PIL import Image, ImageDraw, ImageFont |
| import json |
|
|
| |
async def get_bounding_boxes(prompt: str, image: Image.Image, api_key: str):
    """Ask Gemini 1.5 Flash for bounding boxes of the objects described in *prompt*.

    Args:
        prompt: Description of the object(s) to locate in the image.
        image: PIL image sent to the model alongside the text prompt.
        api_key: Gemini API key used to configure the client.

    Returns:
        Tuple ``(bounding_boxes, explanation)`` where ``bounding_boxes`` is a
        list of ``{"label": str, "box": [ymin, xmin, ymax, xmax]}`` dicts with
        coordinates on the model's 0-1000 normalized scale (see the system
        prompt), and ``explanation`` is the model's free-text rationale.

    Raises:
        gr.Error: On an invalid API key, rate limiting, any other generation
            failure, or a model response that is not the expected JSON shape.
    """
    system_prompt = """
You are a helpful assistant, who always responds with the bounding box and label with the explanation JSON based on the user input, and nothing else.
Your response can also include multiple bounding boxes and their labels in the list.
The values in the list should be integers.
Here are some example responses:
{
    "explanation": "User asked for the bounding box of the dragon, so I will provide the bounding box of the dragon.",
    "bounding_boxes": [
        {"label": "dragon", "box": [ymin, xmin, ymax, xmax]}
    ]
}
{
    "explanation": "User asked for the bounding box of the fruits which are red in color, so I will provide the bounding box of the Apple and the Tomato.",
    "bounding_boxes": [
        {"label": "apple", "box": [ymin, xmin, ymax, xmax]},
        {"label": "tomato", "box": [ymin, xmin, ymax, xmax]}
    ]
}
""".strip()

    prompt = f"Return the bounding boxes and labels of: {prompt}"

    messages = [
        {"role": "user", "parts": [prompt, image]},
    ]

    genai.configure(api_key=api_key)

    generation_config = {
        "temperature": 1,
        "max_output_tokens": 8192,
        # Forces the model to emit JSON, which we parse below.
        "response_mime_type": "application/json",
    }

    model = genai.GenerativeModel(
        model_name="gemini-1.5-flash",
        generation_config=generation_config,
        # All safety filters disabled so ordinary detection prompts are not blocked.
        safety_settings={
            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE
        },
        system_instruction=system_prompt
    )

    try:
        response = await model.generate_content_async(messages)
    except Exception as e:
        # Map common API failures onto user-facing Gradio errors.
        if "API key not valid" in str(e):
            raise gr.Error(
                "Invalid API key. Please provide a valid Gemini API key.")
        elif "rate limit" in str(e).lower():
            raise gr.Error("Rate limit exceeded for the API key.")
        else:
            raise gr.Error(f"Failed to generate content: {str(e)}")

    # The model is instructed to return JSON, but it can still produce
    # malformed output; surface that as a clean error rather than a raw
    # JSONDecodeError/KeyError traceback.
    try:
        response_json = json.loads(response.text)
        explanation = response_json["explanation"]
        bounding_boxes = response_json["bounding_boxes"]
    except (json.JSONDecodeError, KeyError, TypeError) as e:
        raise gr.Error(f"Model returned an unexpected response format: {e}")

    return bounding_boxes, explanation
|
|
| |
async def adjust_bounding_box(bounding_boxes, image):
    """Scale model box coordinates to pixel space.

    The model emits boxes as ``[ymin, xmin, ymax, xmax]`` on a 0-1000
    normalized scale; this converts each one to absolute pixel coordinates
    in PIL drawing order ``[xmin, ymin, xmax, ymax]``.
    """
    img_w, img_h = image.size
    scaled = []
    for entry in bounding_boxes:
        ymin, xmin, ymax, xmax = entry["box"]
        scaled.append({
            "label": entry["label"],
            "box": [
                xmin / 1000 * img_w,
                ymin / 1000 * img_h,
                xmax / 1000 * img_w,
                ymax / 1000 * img_h,
            ],
        })
    return scaled
|
|
| |
async def process_image(image, text, api_key):
    """Run the full detection pipeline and annotate the image.

    Args:
        image: Filesystem path to the input image (Gradio `type="filepath"`).
        text: Description of the object(s) to detect.
        api_key: Gemini API key; an empty value is rejected up front.

    Returns:
        Tuple of (model explanation, annotated PIL image, human-readable
        string of the detected boxes' pixel coordinates).
    """
    if not api_key:
        raise gr.Error("Please provide a Gemini API key.")

    pil_image = Image.open(image)

    raw_boxes, explanation = await get_bounding_boxes(text, pil_image, api_key)
    pixel_boxes = await adjust_bounding_box(raw_boxes, pil_image)

    painter = ImageDraw.Draw(pil_image)
    label_font = ImageFont.load_default(size=20)
    for entry in pixel_boxes:
        rect = entry["box"]
        painter.rectangle(rect, outline="red", width=3)
        # Draw the label just above the box's top-left corner.
        painter.text((rect[0], rect[1] - 25), entry["label"],
                     fill="red", font=label_font)

    coords_text = "\n".join(
        f"{entry['label']}: {entry['box']}" for entry in pixel_boxes)

    return explanation, pil_image, coords_text
|
|
| |
async def gradio_app(image, text, api_key):
    """Gradio entry point; delegates to the async detection pipeline."""
    result = await process_image(image, text, api_key)
    return result
|
|
| |
# Gradio UI: image + object description + API key in; explanation,
# annotated image, and box coordinates out.
detector_inputs = [
    gr.Image(type="filepath"),
    gr.Textbox(label="Object(s) to detect", value="person"),
    gr.Textbox(label="Your Gemini API Key", type="password"),
]
detector_outputs = [
    gr.Textbox(label="Explanation"),
    gr.Image(type="pil", label="Output Image"),
    gr.Textbox(label="Coordinates of the detected objects"),
]

iface = gr.Interface(
    fn=gradio_app,
    inputs=detector_inputs,
    outputs=detector_outputs,
    title="OBJECT DETECTOR ✨",
    description="Detect objects in images using the Gemini 1.5 Flash model.",
    allow_flagging="never",
)

iface.launch()
|
|