import os
from dataclasses import dataclass
from typing import List, Tuple

import torch
from torch.utils.data import Dataset, DataLoader
from torchvision.models.detection import fasterrcnn_resnet50_fpn
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.transforms import functional as F
from transformers import AutoTokenizer, AutoProcessor, AutoModelForCausalLM
from PIL import Image
from tqdm import tqdm
import pytesseract
import wandb

@dataclass
class ExtractedField:
    text: str
    bbox: List[int]
    field_type: str
    confidence: float
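    # Illustrative example (hypothetical values):
    #   ExtractedField(text="MILK 1L", bbox=[12, 344, 310, 370],
    #                  field_type="item_name", confidence=0.91)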

class QwenVLTillSlipDataset(Dataset):
    def __init__(self, 
                 image_dir: str, 
                 processor,
                 tokenizer,
                 max_length: int = 512):
        self.image_dir = image_dir
        self.processor = processor
        self.tokenizer = tokenizer
        self.max_length = max_length
        self.image_files = [f for f in os.listdir(image_dir) 
                          if f.endswith(('.jpg', '.png', '.jpeg'))]
        
        # Initialize detection model
        self.detector = self.initialize_detector()
        
    def initialize_detector(self):
        """Initialize the Faster R-CNN detector used to localize field regions."""
        model = fasterrcnn_resnet50_fpn(weights="DEFAULT")
        in_features = model.roi_heads.box_predictor.cls_score.in_features
        # 4 field types (time, price, quantity, item_name) + background
        model.roi_heads.box_predictor = FastRCNNPredictor(in_features, 5)
        
        # Load fine-tuned weights if available
        if os.path.exists('till_slip_detector.pth'):
            model.load_state_dict(torch.load('till_slip_detector.pth', map_location='cpu'))
        
        model.eval()
        return model
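    # Note: the 5-class box predictor is freshly initialized; unless
    # 'till_slip_detector.pth' contains weights fine-tuned on annotated
    # till-slip regions, the detections (and the OCR-derived labels built
    # from them) will be unreliable.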
    
    def detect_and_extract_fields(self, image: Image.Image) -> List[ExtractedField]:
        """Detect regions and extract text using OCR"""
        # Convert image for detection
        image_tensor = F.to_tensor(image).unsqueeze(0)
        
        # Get detections
        with torch.no_grad():
            predictions = self.detector(image_tensor)
        
        # Keep only confident detections (score threshold 0.5)
        keep = predictions[0]['scores'] > 0.5
        boxes = predictions[0]['boxes'][keep]
        scores = predictions[0]['scores'][keep]
        labels = predictions[0]['labels'][keep]
        
        field_types = ['time', 'price', 'quantity', 'item_name']
        extracted_fields = []
        
        # Extract text from each region
        for box, score, label in zip(boxes, scores, labels):
            x1, y1, x2, y2 = map(int, box.tolist())
            region = image.crop((x1, y1, x2, y2))
            
            # OCR the region
            text = pytesseract.image_to_string(region, config='--psm 7').strip()
            
            if text:  # Only add if text was found
                extracted_fields.append(ExtractedField(
                    text=text,
                    bbox=[x1, y1, x2, y2],
                    field_type=field_types[label.item() - 1],
                    confidence=score.item()
                ))
        
        return extracted_fields
    
    def create_instruction_text(self, fields: List[ExtractedField]) -> Tuple[str, str]:
        """Create instruction and response text for Qwen-VL training"""
        instruction = "Extract the following information from this till slip: time, items with quantities and prices."
        
        # Group fields by type
        grouped_fields = {}
        for field in fields:
            if field.field_type not in grouped_fields:
                grouped_fields[field.field_type] = []
            grouped_fields[field.field_type].append(field)
        
        # Create structured response
        response_parts = []
        
        # Add time if found
        if 'time' in grouped_fields:
            time_field = max(grouped_fields['time'], key=lambda x: x.confidence)
            response_parts.append(f"Time: {time_field.text}")
        
        # Process items
        items = []
        if 'item_name' in grouped_fields:
            for item_field in grouped_fields['item_name']:
                item_info = [f"Item: {item_field.text}"]
                
                # Find quantity and price near this item
                item_y = (item_field.bbox[1] + item_field.bbox[3]) / 2
                
                # Match quantity
                if 'quantity' in grouped_fields:
                    quantities = [q for q in grouped_fields['quantity'] 
                                if abs((q.bbox[1] + q.bbox[3])/2 - item_y) < 20]
                    if quantities:
                        quantity = max(quantities, key=lambda x: x.confidence)
                        item_info.append(f"Quantity: {quantity.text}")
                
                # Match price
                if 'price' in grouped_fields:
                    prices = [p for p in grouped_fields['price'] 
                            if abs((p.bbox[1] + p.bbox[3])/2 - item_y) < 20]
                    if prices:
                        price = max(prices, key=lambda x: x.confidence)
                        item_info.append(f"Price: {price.text}")
                
                items.append(" | ".join(item_info))
        
        if items:
            response_parts.append("Items:\n" + "\n".join(items))
        
        response = "\n".join(response_parts)
        return instruction, response
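    # Illustrative output of this method for a hypothetical slip:
    #   instruction: "Extract the following information from this till slip: ..."
    #   response:    "Time: 14:32\nItems:\nItem: MILK 1L | Quantity: 2 | Price: 31.98"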
    
    def __len__(self):
        return len(self.image_files)
    
    def __getitem__(self, idx):
        image_file = self.image_files[idx]
        image_path = os.path.join(self.image_dir, image_file)
        image = Image.open(image_path).convert('RGB')
        
        # Detect and extract fields
        extracted_fields = self.detect_and_extract_fields(image)
        
        # Create instruction and response
        instruction, response = self.create_instruction_text(extracted_fields)
        
        # Process for Qwen-VL
        inputs = self.processor(
            images=image,
            text=instruction,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        )
        
        # Tokenize the response to use as labels.
        # Note: this supervises the response tokens only; depending on the
        # model, labels may need to be concatenated with the prompt tokens
        # and the prompt positions masked with -100.
        response_ids = self.tokenizer(
            response,
            return_tensors="pt",
            padding=True,
            truncation=True,
            max_length=self.max_length
        )["input_ids"]
        
        inputs["labels"] = response_ids
        
        # Remove batch dimension
        return {k: v.squeeze(0) for k, v in inputs.items()}
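
# The default DataLoader collate cannot stack variable-length token tensors, so
# batch_size > 1 will fail without a custom collate_fn. A minimal padding
# collate sketch (assuming 1-D token tensors and fixed-shape image tensors;
# names here are illustrative, not part of the original script):
def pad_collate_fn(batch, pad_token_id: int = 0):
    """Pad 1-D token tensors to the longest in the batch; stack the rest."""
    collated = {}
    for key in batch[0]:
        tensors = [sample[key] for sample in batch]
        if tensors[0].dim() == 1:
            if key == "labels":
                pad_value = -100          # ignored by the loss
            elif key == "attention_mask":
                pad_value = 0
            else:
                pad_value = pad_token_id
            collated[key] = torch.nn.utils.rnn.pad_sequence(
                tensors, batch_first=True, padding_value=pad_value)
        else:
            collated[key] = torch.stack(tensors)
    return collated
# If adopted: DataLoader(train_dataset, batch_size=4, shuffle=True,
#                        collate_fn=pad_collate_fn)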

def train_qwen_vl(
    model,
    train_loader,
    val_loader,
    optimizer,
    num_epochs,
    device,
    processor,
    wandb_project="qwen-vl-till-slip"
):
    wandb.init(project=wandb_project)
    
    model = model.to(device)
    
    for epoch in range(num_epochs):
        train_loss = 0
        val_loss = 0
        
        # Training loop
        model.train()
        for batch in tqdm(train_loader, desc=f'Epoch {epoch+1}/{num_epochs}'):
            batch = {k: v.to(device) for k, v in batch.items()}
            
            optimizer.zero_grad()
            outputs = model(**batch)
            loss = outputs.loss
            loss.backward()
            optimizer.step()
            
            train_loss += loss.item()
            
            wandb.log({
                "batch_loss": loss.item(),
                "learning_rate": optimizer.param_groups[0]['lr']
            })
        
        # Validation loop
        model.eval()
        with torch.no_grad():
            for batch in val_loader:
                batch = {k: v.to(device) for k, v in batch.items()}
                outputs = model(**batch)
                val_loss += outputs.loss.item()
                
                # Generate sample predictions; pass everything except the
                # labels so the visual inputs reach the model as well
                if epoch % 5 == 0:  # Log examples every 5 epochs
                    gen_inputs = {k: v for k, v in batch.items() if k != "labels"}
                    predictions = model.generate(**gen_inputs, max_new_tokens=200)
                    
                    pred_texts = processor.batch_decode(predictions, skip_special_tokens=True)
                    actual_texts = processor.batch_decode(batch["labels"], skip_special_tokens=True)
                    
                    # Log example predictions
                    for pred, actual in zip(pred_texts[:2], actual_texts[:2]):
                        wandb.log({
                            "predictions": wandb.Table(
                                columns=["Predicted", "Actual"],
                                data=[[pred, actual]]
                            )
                        })
        
        avg_train_loss = train_loss / len(train_loader)
        avg_val_loss = val_loss / len(val_loader)
        
        wandb.log({
            "epoch": epoch,
            "avg_train_loss": avg_train_loss,
            "avg_val_loss": avg_val_loss
        })
        
        print(f'Epoch {epoch+1}/{num_epochs}:')
        print(f'Average Training Loss: {avg_train_loss:.4f}')
        print(f'Average Validation Loss: {avg_val_loss:.4f}')

def main():
    wandb.login()
    
    # Initialize Qwen-VL (Qwen-VL-Chat ships custom modelling code, hence
    # trust_remote_code=True). A newer Qwen2.5-VL checkpoint, e.g.
    # "Qwen/Qwen2.5-VL-7B-Instruct", can be substituted with a recent
    # transformers release.
    model_name = "Qwen/Qwen-VL-Chat"
    processor = AutoProcessor.from_pretrained(model_name, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(model_name, trust_remote_code=True)
    
    # Create datasets
    train_dataset = QwenVLTillSlipDataset(
        image_dir="path/to/train/images",
        processor=processor,
        tokenizer=tokenizer
    )
    
    val_dataset = QwenVLTillSlipDataset(
        image_dir="path/to/val/images",
        processor=processor,
        tokenizer=tokenizer
    )
    
    # Create dataloaders
    train_loader = DataLoader(
        train_dataset,
        batch_size=4,
        shuffle=True
    )
    
    val_loader = DataLoader(
        val_dataset,
        batch_size=4
    )
    
    # Initialize optimizer
    optimizer = torch.optim.AdamW(model.parameters(), lr=2e-5)
    
    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    
    # Train the model
    train_qwen_vl(
        model=model,
        train_loader=train_loader,
        val_loader=val_loader,
        optimizer=optimizer,
        num_epochs=10,
        device=device,
        processor=processor
    )
    
    # Save the fine-tuned model
    model.save_pretrained("fine_tuned_qwen_vl_till_slip")
    processor.save_pretrained("fine_tuned_qwen_vl_till_slip")
    
    wandb.finish()

# Inference function
def process_till_slip_with_qwen(model, processor, image_path: str) -> str:
    """
    Process a till slip image with fine-tuned Qwen-VL
    """
    image = Image.open(image_path).convert('RGB')
    
    # Create instruction
    instruction = "Extract the following information from this till slip: time, items with quantities and prices."
    
    # Process image and instruction
    inputs = processor(
        images=image,
        text=instruction,
        return_tensors="pt",
        padding=True,
        truncation=True
    )
    
    # Generate response (pass all processor outputs so the image features
    # reach the model)
    model.eval()
    with torch.no_grad():
        outputs = model.generate(**inputs, max_new_tokens=200)
    
    # Decode response
    response = processor.decode(outputs[0], skip_special_tokens=True)
    return response
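# Example usage (hypothetical paths), assuming main() has saved the fine-tuned
# model and processor to "fine_tuned_qwen_vl_till_slip":
#   model = AutoModelForCausalLM.from_pretrained(
#       "fine_tuned_qwen_vl_till_slip", trust_remote_code=True)
#   processor = AutoProcessor.from_pretrained(
#       "fine_tuned_qwen_vl_till_slip", trust_remote_code=True)
#   print(process_till_slip_with_qwen(model, processor, "sample_slip.jpg"))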

if __name__ == "__main__":
    main()