#!/usr/bin/env python3
import json
import logging
import os
from datetime import datetime, timezone
from typing import Any, Dict, List, Optional

import arabic_reshaper
import pandas as pd
import tabula
from bidi.algorithm import get_display
from hazm import Chunker, Normalizer, sent_tokenize, word_tokenize
from pymongo import MongoClient

class PersianTextProcessor:
    """Run Hazm-based Persian NLP and table extraction over uploaded files.

    File records are read from the ``rawfiles`` collection of the ``vr_expo``
    MongoDB database; per-batch results are written to ``processed_persian``
    and a ``persian_nlp`` entry is appended to each file's
    ``processingHistory``.
    """

    # MIME types that extract_tables parses as binary containers. Decoding them
    # as text typically fails, which previously aborted the file before its
    # tables could be extracted, so their text content is not read.
    _BINARY_MIME_TYPES = frozenset({
        'application/pdf',
        'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet',
    })

    # Errors go through logging (stderr by default) rather than print(), so
    # they do not corrupt JSON written to stdout by the command-line entry point.
    _log = logging.getLogger(__name__)

    def __init__(
        self,
        chunker_model: str = 'resources/chunker.model',
        mongodb_uri: Optional[str] = None,
        pos_tagger: Optional[Any] = None,
    ):
        """Load Hazm components and connect to MongoDB.

        Args:
            chunker_model: Path to the Hazm chunker model (the default is
                relative to the current working directory).
            mongodb_uri: MongoDB connection string. Defaults to the
                ``MONGODB_URI`` environment variable, then
                ``mongodb://localhost:27017``.
            pos_tagger: Optional object with a ``tag(tokens)`` method returning
                ``(word, tag)`` pairs, e.g. ``hazm.POSTagger``. Hazm's chunker
                expects POS-tagged input; without a tagger every token is given
                an empty tag.
        """
        self.normalizer = Normalizer()
        self.chunker = Chunker(model=chunker_model)
        self.pos_tagger = pos_tagger

        uri = mongodb_uri or os.getenv('MONGODB_URI', 'mongodb://localhost:27017')
        self.client = MongoClient(uri)
        self.db = self.client.vr_expo

    def close(self) -> None:
        """Close the MongoDB client opened in ``__init__``."""
        self.client.close()

    @staticmethod
    def _utc_now() -> datetime:
        """Return the current time as a timezone-aware UTC datetime."""
        return datetime.now(timezone.utc)

    @staticmethod
    def _table_entry(frame: pd.DataFrame, **identity: Any) -> Dict:
        """Serialize one DataFrame as ``{**identity, "headers", "data"}``.

        Missing cells (NaN/NaT) become ``None`` so the rows encode cleanly to
        both BSON and strict JSON.
        """
        rows = [
            [None if pd.isna(cell) else cell for cell in row]
            for row in frame.values.tolist()
        ]
        return {**identity, "headers": frame.columns.tolist(), "data": rows}

    def extract_tables(self, file_path: str, mime_type: str) -> List[Dict]:
        """Extract tables from a PDF, CSV or XLSX document.

        Each non-empty table becomes ``{"table_index", "headers", "data"}``;
        XLSX tables additionally carry ``sheet_name``. ``table_index`` is the
        table's position in the source (empty tables are skipped but still
        consume an index). Unsupported MIME types yield ``[]``.

        Extraction is best-effort: any failure is logged and the tables
        collected so far are returned.
        """
        tables: List[Dict] = []

        try:
            if mime_type == 'application/pdf':
                for index, frame in enumerate(tabula.read_pdf(file_path, pages='all')):
                    if not frame.empty:
                        tables.append(self._table_entry(frame, table_index=index))

            elif mime_type in ['text/csv', 'application/vnd.ms-excel']:
                # NOTE(review): application/vnd.ms-excel is also the MIME type of
                # legacy .xls workbooks, which read_csv cannot parse (the failure
                # is logged below); presumably it is listed because some clients
                # report .csv uploads with it — confirm.
                frame = pd.read_csv(file_path)
                if not frame.empty:
                    tables.append(self._table_entry(frame, table_index=0))

            elif mime_type == 'application/vnd.openxmlformats-officedocument.spreadsheetml.sheet':
                # sheet_name=None loads every sheet as {sheet name: DataFrame}.
                sheets = pd.read_excel(file_path, sheet_name=None)
                for index, (sheet_name, frame) in enumerate(sheets.items()):
                    if not frame.empty:
                        tables.append(
                            self._table_entry(frame, table_index=index, sheet_name=sheet_name)
                        )

        except Exception:
            self._log.exception("Error extracting tables from %s", file_path)

        return tables

    @staticmethod
    def _node_text(node: Any) -> str:
        """Return the words under one top-level chunk-tree node as one string."""
        # Children of the chunker's tree are either subtrees (chunks) whose
        # leaves are (word, tag) pairs, or bare (word, tag) pairs for tokens
        # the chunker left outside any chunk.
        if isinstance(node, tuple):
            return node[0]
        return ' '.join(word for word, _tag in node.leaves())

    def process_text(self, text: str) -> Dict:
        """Normalize Persian text and derive sentences, phrases and display text.

        Returns:
            Dict with ``normalized`` (Hazm-normalized text), ``sentences``
            (sentence strings), ``phrases`` (one string per top-level chunker
            node; unchunked tokens become one-word phrases) and
            ``display_text`` (the normalized text reshaped and bidi-reordered
            for display).
        """
        normalized = self.normalizer.normalize(text)
        if not normalized.strip():
            # Nothing to tokenize; do not hand the chunker an empty sentence.
            return {
                "normalized": normalized,
                "sentences": [],
                "phrases": [],
                "display_text": normalized,
            }

        sentences = sent_tokenize(normalized)

        # The chunker works on (word, tag) tokens; the tokens themselves are
        # not included in the output.
        words = word_tokenize(normalized)
        if self.pos_tagger is not None:
            tagged = self.pos_tagger.tag(words)
        else:
            # NOTE(review): with empty POS tags the chunker's output is likely
            # unreliable — pass a pos_tagger to __init__ for real phrases.
            tagged = [(word, '') for word in words]
        phrases = [self._node_text(node) for node in self.chunker.parse(tagged)]

        bidi_text = get_display(arabic_reshaper.reshape(normalized))

        return {
            "normalized": normalized,
            "sentences": sentences,
            "phrases": phrases,
            "display_text": bidi_text,
        }

    @staticmethod
    def _has_persian(record: Dict) -> bool:
        """Return True if the record's ``metadata.languages`` contains ``'fa'``."""
        metadata = record.get('metadata') or {}
        return 'fa' in (metadata.get('languages') or [])

    def _record_history(self, file_id: Any, status: str, **extra: Any) -> None:
        """Append a ``persian_nlp`` entry to a raw file's ``processingHistory``."""
        entry = {
            "stage": "persian_nlp",
            "timestamp": self._utc_now(),
            "status": status,
            **extra,
        }
        self.db.rawfiles.update_one(
            {"_id": file_id},
            {"$push": {"processingHistory": entry}},
        )

    def _process_file(self, record: Dict) -> Dict:
        """Run text processing and table extraction for one raw-file record."""
        path = record['path']
        mime_type = record.get('mimeType', '')

        if mime_type in self._BINARY_MIME_TYPES:
            # Only the tables of these formats are processed; their text
            # content is not extracted here.
            content = ''
        else:
            with open(path, 'r', encoding=record.get('encoding', 'utf-8')) as handle:
                content = handle.read()

        processed = self.process_text(content)
        tables = self.extract_tables(path, mime_type)

        return {
            "fileId": str(record['_id']),
            "originalName": record['originalName'],
            "processed_content": {
                "normalized_text": processed["normalized"],
                "sentences": processed["sentences"],
                "phrases": processed["phrases"],
                "display_text": processed["display_text"],
                "tables": tables,
            },
            "metadata": {
                "processedAt": self._utc_now(),
                "status": "success",
                "table_count": len(tables),
            },
        }

    def process_batch(self, batch_id: str) -> Dict:
        """Process every Persian file of a batch and store the batch summary.

        Files whose ``metadata.languages`` lacks ``'fa'`` are skipped. Each
        processed file gets a ``persian_nlp`` history entry with status
        ``success`` or ``error``; a failing file is logged and does not abort
        the batch. The summary is inserted into ``processed_persian`` and
        returned — note that pymongo's ``insert_one`` adds an ``_id`` to it.
        """
        batch_data = {
            "batchId": batch_id,
            "processed_files": [],
            "timestamp": self._utc_now(),
        }

        for record in self.db.rawfiles.find({"batchId": batch_id}):
            try:
                if not self._has_persian(record):
                    continue

                file_data = self._process_file(record)
                batch_data["processed_files"].append(file_data)
                self._record_history(
                    record['_id'],
                    "success",
                    details={"tables_extracted": file_data["metadata"]["table_count"]},
                )

            except Exception as exc:
                # Best-effort per file: record the failure and continue with
                # the rest of the batch. .get() keeps a record without a name
                # from raising inside this handler.
                self._log.exception(
                    "Error processing file %s",
                    record.get('originalName', record['_id']),
                )
                self._record_history(record['_id'], "error", error=str(exc))

        self.db.processed_persian.insert_one(batch_data)
        return batch_data

if __name__ == "__main__":
    import sys

    def _json_default(value):
        """Encode the non-JSON values a batch result contains.

        process_batch returns datetime timestamps and, after insert_one, a
        bson ObjectId under ``_id``; json.dumps cannot encode either natively.
        """
        if isinstance(value, datetime):
            return value.isoformat()
        # Remaining case is expected to be the ObjectId; its str() is the hex id.
        return str(value)

    if len(sys.argv) != 2:
        print("Usage: python persian_processor.py <batch_id>", file=sys.stderr)
        sys.exit(1)

    processor = PersianTextProcessor()
    try:
        result = processor.process_batch(sys.argv[1])
    finally:
        # Release the MongoDB connection pool even if processing fails.
        processor.client.close()
    print(json.dumps(result, ensure_ascii=False, indent=2, default=_json_default))