import os
from dataclasses import dataclass
from typing import List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from transformers import AutoModelForCausalLM, AutoProcessor, AutoTokenizer
from PIL import Image
from tqdm import tqdm
import pytesseract
import wandb
@dataclass
class ExtractedField:
text: str
bbox: List[int]
field_type: str
confidence: float
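# Illustrative only: a detected price region could be represented as
#   ExtractedField(text="R12.99", bbox=[410, 233, 498, 251],
#                  field_type="price", confidence=0.91)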
class QwenVLTillSlipDataset(Dataset):
def __init__(self,
image_dir: str,
processor,
tokenizer,
max_length: int = 512):
self.image_dir = image_dir
self.processor = processor
self.tokenizer = tokenizer
self.max_length = max_length
self.image_files = [f for f in os.listdir(image_dir)
if f.endswith(('.jpg', '.png', '.jpeg'))]
# Initialize detection model
self.detector = self.initialize_detector()
    def initialize_detector(self):
        """Initialize the Faster R-CNN model used to localize receipt fields."""
        model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
        # Replace the COCO classification head with one for our classes:
        # 4 field types (time, price, quantity, item_name) + background
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 5)
        # Load fine-tuned detector weights if available; without them the new
        # head is randomly initialized and detections will not be meaningful
        if os.path.exists('till_slip_detector.pth'):
            model.load_state_dict(torch.load('till_slip_detector.pth', map_location='cpu'))
        model.eval()
        return model
def detect_and_extract_fields(self, image: Image.Image) -> List[ExtractedField]:
"""Detect regions and extract text using OCR"""
# Convert image for detection
image_tensor = F.to_tensor(image).unsqueeze(0)
# Get detections
with torch.no_grad():
predictions = self.detector(image_tensor)
        # Keep only confident detections (score > 0.5)
        keep = predictions[0]['scores'] > 0.5
        boxes = predictions[0]['boxes'][keep]
        scores = predictions[0]['scores'][keep]
        labels = predictions[0]['labels'][keep]
field_types = ['time', 'price', 'quantity', 'item_name']
extracted_fields = []
# Extract text from each region
for box, score, label in zip(boxes, scores, labels):
x1, y1, x2, y2 = map(int, box.tolist())
region = image.crop((x1, y1, x2, y2))
# OCR the region
text = pytesseract.image_to_string(region, config='--psm 7').strip()
if text: # Only add if text was found
extracted_fields.append(ExtractedField(
text=text,
bbox=[x1, y1, x2, y2],
field_type=field_types[label.item() - 1],
confidence=score.item()
))
return extracted_fields
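    # create_instruction_text below turns a detection/OCR result like the
    # following (illustrative values only) into an instruction/response pair:
    #   [ExtractedField("09:41", [...], "time", 0.97),
    #    ExtractedField("Milk 1L", [...], "item_name", 0.88),
    #    ExtractedField("2", [...], "quantity", 0.84),
    #    ExtractedField("R32.98", [...], "price", 0.91)]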
def create_instruction_text(self, fields: List[ExtractedField]) -> Tuple[str, str]:
"""Create instruction and response text for Qwen-VL training"""
instruction = "Extract the following information from this till slip: time, items with quantities and prices."
# Group fields by type
grouped_fields = {}
for field in fields:
if field.field_type not in grouped_fields:
grouped_fields[field.field_type] = []
grouped_fields[field.field_type].append(field)
# Create structured response
response_parts = []
# Add time if found
if 'time' in grouped_fields:
time_field = max(grouped_fields['time'], key=lambda x: x.confidence)
response_parts.append(f"Time: {time_field.text}")
# Process items
items = []
if 'item_name' in grouped_fields:
for item_field in grouped_fields['item_name']:
item_info = [f"Item: {item_field.text}"]
# Find quantity and price near this item
item_y = (item_field.bbox[1] + item_field.bbox[3]) / 2
# Match quantity
if 'quantity' in grouped_fields:
quantities = [q for q in grouped_fields['quantity']
if abs((q.bbox[1] + q.bbox[3])/2 - item_y) < 20]
if quantities:
quantity = max(quantities, key=lambda x: x.confidence)
item_info.append(f"Quantity: {quantity.text}")
# Match price
if 'price' in grouped_fields:
prices = [p for p in grouped_fields['price']
if abs((p.bbox[1] + p.bbox[3])/2 - item_y) < 20]
if prices:
price = max(prices, key=lambda x: x.confidence)
item_info.append(f"Price: {price.text}")
items.append(" | ".join(item_info))
if items:
response_parts.append("Items:\n" + "\n".join(items))
response = "\n".join(response_parts)
return instruction, response
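    # An illustrative response built from the fields above (not real data):
    #   Time: 09:41
    #   Items:
    #   Item: Milk 1L | Quantity: 2 | Price: R32.98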
def __len__(self):
return len(self.image_files)
def __getitem__(self, idx):
image_file = self.image_files[idx]
image_path = os.path.join(self.image_dir, image_file)
image = Image.open(image_path).convert('RGB')
# Detect and extract fields
extracted_fields = self.detect_and_extract_fields(image)
# Create instruction and response
instruction, response = self.create_instruction_text(extracted_fields)
        # Process image + instruction for Qwen-VL. Fixed-length padding keeps
        # every sample the same shape so the default DataLoader collate works.
        inputs = self.processor(
            images=image,
            text=instruction,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )
        # Tokenize the structured response as the supervision target
        # (pad tokens in the labels could additionally be set to -100 so the
        # loss ignores them)
        response_ids = self.tokenizer(
            response,
            return_tensors="pt",
            padding="max_length",
            truncation=True,
            max_length=self.max_length
        )["input_ids"]
        inputs["labels"] = response_ids
        # Remove the batch dimension added by return_tensors="pt"
        return {k: v.squeeze(0) for k, v in inputs.items()}
def train_qwen_vl(
model,
train_loader,
val_loader,
optimizer,
num_epochs,
device,
processor,
wandb_project="qwen-vl-till-slip"
):
wandb.init(project=wandb_project)
    model = model.to(device)
for epoch in range(num_epochs):
train_loss = 0
val_loss = 0
# Training loop
model.train()
for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
batch = {k: v.to(device) for k, v in batch.items()}
optimizer.zero_grad()
outputs = model(**batch)
loss = outputs.loss
loss.backward()
optimizer.step()
train_loss += loss.item()
wandb.log({
"batch_loss": loss.item(),
"learning_rate": optimizer.param_groups[0]['lr']
})
# Validation loop
model.eval()
with torch.no_grad():
for batch in val_loader:
batch = {k: v.to(device) for k, v in batch.items()}
outputs = model(**batch)
val_loss += outputs.loss.item()
                # Log sample predictions every 5 epochs
                if epoch % 5 == 0:
                    # Pass the full processed batch (minus labels) so any image
                    # tensors produced by the processor reach generate()
                    gen_inputs = {k: v for k, v in batch.items() if k != "labels"}
                    predictions = model.generate(**gen_inputs, max_new_tokens=200)
                    pred_texts = processor.batch_decode(predictions, skip_special_tokens=True)
                    actual_texts = processor.batch_decode(batch["labels"], skip_special_tokens=True)
# Log example predictions
for pred, actual in zip(pred_texts[:2], actual_texts[:2]):
wandb.log({
"predictions": wandb.Table(
columns=["Predicted", "Actual"],
data=[[pred, actual]]
)
})
avg_train_loss = train_loss / len(train_loader)
avg_val_loss = val_loss / len(val_loader)
wandb.log({
"epoch": epoch,
"avg_train_loss": avg_train_loss,
"avg_val_loss": avg_val_loss
})
print(f'Epoch {epoch+1}/{num_epochs}:')
print(f'Average Training Loss: {avg_train_loss:.4f}')
print(f'Average Validation Loss: {avg_val_loss:.4f}')
def main():
wandb.login()
    # Initialize Qwen-VL ("Qwen/Qwen-VL-Chat" requires trust_remote_code=True;
    # newer Qwen2.5-VL releases such as "Qwen/Qwen2.5-VL-7B-Instruct" have
    # native transformers support and could be substituted here)
    model_name = "Qwen/Qwen-VL-Chat"
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
# Create datasets
train_dataset = QwenVLTillSlipDataset(
image_dir="path/to/train/images",
processor=processor,
tokenizer=tokenizer
)
val_dataset = QwenVLTillSlipDataset(
image_dir="path/to/val/images",
processor=processor,
tokenizer=tokenizer
)
# Create dataloaders
train_loader = DataLoader(
train_dataset,
batch_size=4,
shuffle=True
)
val_loader = DataLoader(
val_dataset,
batch_size=4
)
# Initialize optimizer
optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
# Set device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Train the model
train_qwen_vl(
model=model,
train_loader=train_loader,
val_loader=val_loader,
optimizer=optimizer,
num_epochs=10,
device=device,
processor=processor
)
# Save the fine-tuned model
model.save_pretrained("fine_tuned_qwen_vl_till_slip")
processor.save_pretrained("fine_tuned_qwen_vl_till_slip")
wandb.finish()
# Inference function
def process_till_slip_with_qwen(model, processor, image_path: str) -> str:
"""
Process a till slip image with fine-tuned Qwen-VL
"""
image = Image.open(image_path).convert('RGB')
# Create instruction
instruction = "Extract the following information from this till slip: time, items with quantities and prices."
    # Process the image and instruction, then move tensors to the model's device
    inputs = processor(
        images=image,
        text=instruction,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    device = next(model.parameters()).device
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Generate response, passing the full processed inputs so any image
    # tensors reach generate()
    model.eval()
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200)
    # Decode response
    response = processor.decode(outputs[0], skip_special_tokens=True)
    return response
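# Example usage after fine-tuning; the checkpoint directory matches the save
# path used in main(), and the image path is illustrative:
#   model = AutoModelForCausalLM.from_pretrained(
#       "fine_tuned_qwen_vl_till_slip", trust_remote_code=True)
#   processor = AutoProcessor.from_pretrained(
#       "fine_tuned_qwen_vl_till_slip", trust_remote_code=True)
#   print(process_till_slip_with_qwen(model, processor, "till_slip.jpg"))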
if __name__ == "__main__":
main()