"""Import API endpoints for CSV file uploads."""

import logging
import shutil
import tempfile
from pathlib import Path

from fastapi import APIRouter, Depends, File, HTTPException, UploadFile, status
from sqlalchemy.orm import Session

from app.api.deps import get_db
from app.config import settings
from app.services import ImportService
from app.services.position_tracker import PositionTracker

logger = logging.getLogger(__name__)

router = APIRouter()


def _remove_temp_file(path: Path) -> None:
    """Best-effort deletion of a temporary upload file; never raises."""
    try:
        path.unlink(missing_ok=True)
    except OSError:
        # Cleanup failure must not mask the real outcome of the request.
        logger.warning("Could not remove temporary upload file %s", path, exc_info=True)


@router.post("/upload/{account_id}")
def upload_csv(
    account_id: int,
    file: UploadFile = File(...),
    db: Session = Depends(get_db),
):
    """
    Upload and import a CSV file for an account.

    The upload is spooled to a temporary ``.csv`` file, imported via
    ``ImportService.import_from_file``, and, if at least one transaction
    was imported, positions are rebuilt for the account. The temporary
    file is always removed, whether the import succeeds or fails.

    Args:
        account_id: Account ID to import transactions for
        file: CSV file to upload (extension checked case-insensitively)
        db: Database session

    Returns:
        Import statistics: filename, imported, skipped, errors,
        total_rows and positions_created.

    Raises:
        HTTPException: 400 if the file has no name or is not a ``.csv``;
            500 if the import or position rebuild fails.
    """
    # UploadFile.filename can be None/empty; guard before calling string methods.
    if not file.filename or not file.filename.lower().endswith(".csv"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="File must be a CSV",
        )

    tmp_path = None
    try:
        # delete=False so the file survives closing and can be re-opened by
        # path on every platform; removal is handled in `finally` below.
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp_file:
            # Bind the path first so a failed copy is still cleaned up.
            tmp_path = Path(tmp_file.name)
            shutil.copyfileobj(file.file, tmp_file)

        import_service = ImportService(db)
        result = import_service.import_from_file(tmp_path, account_id)

        # Positions are derived from transactions; only rebuild when something changed.
        positions_created = 0
        if result.imported > 0:
            positions_created = PositionTracker(db).rebuild_positions(account_id)
    except HTTPException:
        # Let deliberate HTTP errors from downstream code keep their status.
        raise
    except Exception as e:
        # NOTE(review): whether the DB session is rolled back on failure
        # depends on get_db — confirm.
        logger.exception("CSV upload import failed for account %s", account_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Import failed: {str(e)}",
        ) from e
    finally:
        if tmp_path is not None:
            _remove_temp_file(tmp_path)

    return {
        "filename": file.filename,
        "imported": result.imported,
        "skipped": result.skipped,
        "errors": result.errors,
        "total_rows": result.total_rows,
        "positions_created": positions_created,
    }


@router.post("/filesystem/{account_id}")
def import_from_filesystem(account_id: int, db: Session = Depends(get_db)):
    """
    Import all CSV files from the filesystem import directory.

    Files are read from ``settings.IMPORT_DIR`` via
    ``ImportService.import_from_directory``; if any transactions were
    imported across all files, positions are rebuilt for the account.

    Args:
        account_id: Account ID to import transactions for
        db: Database session

    Returns:
        Import statistics per file (imported, skipped, errors, total_rows),
        plus total_imported and positions_created.

    Raises:
        HTTPException: 404 if the import directory doesn't exist (or is not
            a directory); 500 if the import or position rebuild fails.
    """
    import_dir = Path(settings.IMPORT_DIR)

    # is_dir() rather than exists(): a regular file at this path is not usable.
    if not import_dir.is_dir():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Import directory not found: {import_dir}",
        )

    try:
        import_service = ImportService(db)
        results = import_service.import_from_directory(import_dir, account_id)

        # Rebuild positions only if any transactions were imported.
        total_imported = sum(r.imported for r in results.values())
        positions_created = 0
        if total_imported > 0:
            positions_created = PositionTracker(db).rebuild_positions(account_id)
    except HTTPException:
        # Let deliberate HTTP errors from downstream code keep their status.
        raise
    except Exception as e:
        logger.exception("Filesystem import failed for account %s", account_id)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Import failed: {str(e)}",
        ) from e

    return {
        "files": {
            filename: {
                "imported": result.imported,
                "skipped": result.skipped,
                "errors": result.errors,
                "total_rows": result.total_rows,
            }
            for filename, result in results.items()
        },
        "total_imported": total_imported,
        "positions_created": positions_created,
    }