"""Transaction API endpoints."""

from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from typing import List, Optional, Dict
from datetime import date

from app.api.deps import get_db
from app.models import Transaction, Position, PositionTransaction
from app.schemas import TransactionResponse

router = APIRouter()


def _get_transaction_or_404(db: Session, transaction_id: int) -> Transaction:
    """
    Fetch a transaction by primary key.

    Args:
        db: Database session
        transaction_id: Transaction ID

    Returns:
        The matching Transaction row

    Raises:
        HTTPException: 404 if no transaction has this ID
    """
    transaction = (
        db.query(Transaction).filter(Transaction.id == transaction_id).first()
    )
    if transaction is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Transaction {transaction_id} not found",
        )
    return transaction


def _to_float(value) -> Optional[float]:
    """
    Convert a numeric column value to float, mapping only None to None.

    A zero value is returned as 0.0. Truthiness must not be used here, because
    0 / Decimal("0") are falsy and would otherwise be reported as missing.
    """
    return float(value) if value is not None else None


def _serialize_transaction(txn: Transaction) -> Dict:
    """Build the JSON-friendly dict for one transaction in a position-details response."""
    return {
        "id": txn.id,
        "run_date": txn.run_date.isoformat(),
        "action": txn.action,
        "symbol": txn.symbol,
        "description": txn.description,
        "quantity": _to_float(txn.quantity),
        "price": _to_float(txn.price),
        "amount": _to_float(txn.amount),
        "commission": _to_float(txn.commission),
        "fees": _to_float(txn.fees),
    }


@router.get("", response_model=List[TransactionResponse])
def list_transactions(
    account_id: Optional[int] = None,
    symbol: Optional[str] = None,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    skip: int = Query(default=0, ge=0),
    limit: int = Query(default=50, ge=0, le=500),
    db: Session = Depends(get_db),
):
    """
    List transactions with optional filtering.

    Args:
        account_id: Filter by account ID
        symbol: Filter by symbol (an empty string is treated as "no filter")
        start_date: Filter by start date (inclusive)
        end_date: Filter by end date (inclusive)
        skip: Number of records to skip (pagination, must be >= 0)
        limit: Maximum number of records to return (0..500)
        db: Database session

    Returns:
        List of transactions, newest run_date first (ties broken by id, descending)
    """
    query = db.query(Transaction)

    # Apply filters. Use explicit None checks so that a legitimate 0 id is
    # not silently ignored.
    if account_id is not None:
        query = query.filter(Transaction.account_id == account_id)
    if symbol:
        query = query.filter(Transaction.symbol == symbol)
    if start_date is not None:
        query = query.filter(Transaction.run_date >= start_date)
    if end_date is not None:
        query = query.filter(Transaction.run_date <= end_date)

    # Order by date descending; id gives a stable order within the same date
    # so pagination does not skip or repeat rows.
    query = query.order_by(Transaction.run_date.desc(), Transaction.id.desc())

    # Pagination. The bounds on skip/limit are enforced by Query(ge=..., le=...)
    # above, so negative values never reach OFFSET/LIMIT.
    return query.offset(skip).limit(limit).all()


@router.get("/{transaction_id}", response_model=TransactionResponse)
def get_transaction(transaction_id: int, db: Session = Depends(get_db)):
    """
    Get transaction by ID.

    Args:
        transaction_id: Transaction ID
        db: Database session

    Returns:
        Transaction details

    Raises:
        HTTPException: If transaction not found
    """
    return _get_transaction_or_404(db, transaction_id)


@router.get("/{transaction_id}/position-details")
def get_transaction_position_details(
    transaction_id: int, db: Session = Depends(get_db)
) -> Dict:
    """
    Get full position details for a transaction, including all related transactions.

    This endpoint finds the position associated with a transaction and returns:
    - All transactions that are part of the same position
    - Position metadata (type, status, P&L, etc.)
    - A strategy label (see _classify_option_strategy)

    Args:
        transaction_id: Transaction ID
        db: Database session

    Returns:
        Dictionary with "position" metadata and the position's "transactions"
        sorted by run_date (then id)

    Raises:
        HTTPException: If the transaction, its position link, or the position
            itself cannot be found
    """
    # Ensure the transaction exists (raises 404 otherwise).
    _get_transaction_or_404(db, transaction_id)

    # Find the position this transaction belongs to. If a transaction is
    # linked to several positions, only the first link returned is used.
    position_link = (
        db.query(PositionTransaction)
        .filter(PositionTransaction.transaction_id == transaction_id)
        .first()
    )
    if position_link is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Transaction {transaction_id} is not part of any position",
        )

    position = (
        db.query(Position)
        .filter(Position.id == position_link.position_id)
        .first()
    )
    if position is None:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Position {position_link.position_id} not found",
        )

    # NOTE(review): link.transaction is accessed per link; depending on the
    # relationship's loading strategy this may issue one query per link.
    all_transactions = [
        _serialize_transaction(link.transaction)
        for link in position.transaction_links
    ]

    # Chronological order. ISO date strings sort correctly as text; id breaks
    # ties deterministically for same-date transactions.
    all_transactions.sort(key=lambda t: (t["run_date"], t["id"]))

    strategy = _classify_option_strategy(position, all_transactions)

    return {
        "position": {
            "id": position.id,
            "symbol": position.symbol,
            "option_symbol": position.option_symbol,
            "position_type": position.position_type.value,
            "status": position.status.value,
            "open_date": position.open_date.isoformat(),
            "close_date": (
                position.close_date.isoformat() if position.close_date else None
            ),
            "total_quantity": float(position.total_quantity),
            "avg_entry_price": _to_float(position.avg_entry_price),
            "avg_exit_price": _to_float(position.avg_exit_price),
            "realized_pnl": _to_float(position.realized_pnl),
            "unrealized_pnl": _to_float(position.unrealized_pnl),
            "strategy": strategy,
        },
        "transactions": all_transactions,
    }


def _classify_option_strategy(position: Position, transactions: List[Dict]) -> str:
    """
    Classify the strategy based on position type and the sign of total_quantity.

    Args:
        position: Position object; its position_type.value is compared against
            "stock", "call" and "put"
        transactions: List of serialized transaction dicts. Currently unused;
            kept so the classification can later inspect the trades.

    Returns:
        One of "Stock", "Long Call", "Short Call (Covered Call)", "Long Put",
        "Short Put (Cash-Secured Put)", or "Unknown" for any other type.

    Note:
        The "Covered Call" / "Cash-Secured Put" labels are assumptions. No
        check is made for an offsetting stock position or reserved cash, so a
        naked short option receives the same label.
    """
    position_type = position.position_type.value

    if position_type == "stock":
        return "Stock"

    # A negative net quantity means the position is short (written options).
    is_short = position.total_quantity < 0

    if position_type == "call":
        # TODO: detect an offsetting stock holding to distinguish covered from
        # naked calls.
        return "Short Call (Covered Call)" if is_short else "Long Call"

    if position_type == "put":
        # TODO: check available cash to distinguish cash-secured from naked puts.
        return "Short Put (Cash-Secured Put)" if is_short else "Long Put"

    return "Unknown"