"""Service for calculating performance metrics and unrealized P&L."""

import logging
from datetime import datetime, timedelta
from decimal import Decimal
from functools import lru_cache
from typing import Dict, Optional

import yfinance as yf
from sqlalchemy import and_, func
from sqlalchemy.orm import Session

from app.models import Position, Transaction
from app.models.position import PositionStatus

logger = logging.getLogger(__name__)

# Yahoo Finance ``info`` fields tried, in order, when looking up a price.
_PRICE_FIELDS = ("currentPrice", "regularMarketPrice", "previousClose")

# Multiplier applied to every position's per-unit price move.
# NOTE(review): applied regardless of position type; confirm all positions
# tracked here are 100-share option contracts.
_CONTRACT_MULTIPLIER = 100


def _optional_float(value) -> Optional[float]:
    """Convert *value* to float, mapping only ``None`` (not 0) to ``None``."""
    return None if value is None else float(value)


class PerformanceCalculator:
    """
    Service for calculating performance metrics and market data.

    Integrates with Yahoo Finance API for real-time pricing of open positions.
    """

    def __init__(self, db: Session, cache_ttl: int = 60):
        """
        Initialize performance calculator.

        Args:
            db: Database session
            cache_ttl: Cache time-to-live in seconds (default: 60)
        """
        self.db = db
        self.cache_ttl = cache_ttl
        # symbol -> (price, time fetched); entries older than cache_ttl are refetched.
        self._price_cache: Dict[str, tuple[Decimal, datetime]] = {}

    def calculate_unrealized_pnl(self, position: Position) -> Optional[Decimal]:
        """
        Calculate unrealized P&L for an open position.

        For long positions (``total_quantity >= 0``) the P&L is
        ``(current - entry) * |quantity| * 100``; for short positions it is
        ``(entry - current) * |quantity| * 100``. Commissions and fees of all
        transactions linked to the position are then subtracted.

        Args:
            position: Open position to calculate P&L for

        Returns:
            Unrealized P&L, or None if the position is not open, has no
            average entry price, or market data is unavailable.
        """
        if position.status != PositionStatus.OPEN:
            return None

        # Check the local field first so we don't spend a network call on a
        # position whose P&L can't be computed anyway.
        if position.avg_entry_price is None:
            return None

        # NOTE(review): the quote is fetched for ``position.symbol``. If option
        # positions store the underlying ticker here, this compares an
        # underlying price against an option entry price; confirm against the
        # Position model.
        current_price = self.get_current_price(position.symbol)
        if current_price is None:
            return None

        quantity = abs(position.total_quantity)
        if position.total_quantity < 0:
            # Short position: profit when price decreases.
            price_move = position.avg_entry_price - current_price
        else:
            # Long position: profit when price increases.
            price_move = current_price - position.avg_entry_price
        pnl = price_move * quantity * _CONTRACT_MULTIPLIER

        pnl -= self._total_fees(position)
        return pnl

    @staticmethod
    def _total_fees(position: Position) -> Decimal:
        """
        Sum commissions and fees over every transaction linked to *position*.

        NOTE(review): this includes all linked transactions, not only opening
        ones (e.g. partial closes too); confirm that is intended.
        """
        total = Decimal("0")
        for link in position.transaction_links:
            txn = link.transaction
            # None/0 values contribute nothing.
            if txn.commission:
                total += txn.commission
            if txn.fees:
                total += txn.fees
        return total

    def update_open_positions_pnl(self, account_id: int) -> int:
        """
        Update unrealized P&L for all open positions in an account.

        Positions whose P&L cannot be computed (e.g. no market data) keep their
        previously stored ``unrealized_pnl``. The session is committed once at
        the end, even if no position changed.

        Args:
            account_id: Account ID to update

        Returns:
            Number of positions updated
        """
        open_positions = (
            self.db.query(Position)
            .filter(
                and_(
                    Position.account_id == account_id,
                    Position.status == PositionStatus.OPEN,
                )
            )
            .all()
        )

        updated = 0
        for position in open_positions:
            unrealized_pnl = self.calculate_unrealized_pnl(position)
            if unrealized_pnl is not None:
                position.unrealized_pnl = unrealized_pnl
                updated += 1

        self.db.commit()
        return updated

    def get_current_price(self, symbol: str) -> Optional[Decimal]:
        """
        Get current market price for a symbol.

        Uses Yahoo Finance API with a per-instance cache (``cache_ttl``
        seconds) to reduce API calls. Only successful lookups are cached.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Current price or None if unavailable
        """
        cached = self._price_cache.get(symbol)
        if cached is not None:
            price, fetched_at = cached
            if datetime.now() - fetched_at < timedelta(seconds=self.cache_ttl):
                return price

        try:
            current_price = self._extract_price(yf.Ticker(symbol).info)
        except Exception:
            # Best-effort: a failed quote must not abort P&L/stat calculation,
            # but record why instead of silently discarding the error.
            logger.warning(
                "Failed to fetch price for %s from Yahoo Finance",
                symbol,
                exc_info=True,
            )
            return None

        if current_price is not None:
            self._price_cache[symbol] = (current_price, datetime.now())
        return current_price

    @staticmethod
    def _extract_price(info) -> Optional[Decimal]:
        """Return the first truthy price field in a Yahoo Finance ``info`` payload."""
        for field in _PRICE_FIELDS:
            if field in info and info[field]:
                # Go through str() so the Decimal doesn't inherit float artifacts.
                return Decimal(str(info[field]))
        return None

    def calculate_account_stats(self, account_id: int) -> Dict:
        """
        Calculate aggregate statistics for an account.

        Side effect: refreshes and commits ``unrealized_pnl`` for the
        account's open positions (see ``update_open_positions_pnl``).

        Args:
            account_id: Account ID

        Returns:
            Dictionary with performance metrics
        """
        positions = (
            self.db.query(Position).filter(Position.account_id == account_id).all()
        )
        open_positions = [p for p in positions if p.status == PositionStatus.OPEN]
        closed_positions = [
            p for p in positions if p.status == PositionStatus.CLOSED
        ]

        # Closed-position stats are computed before the commit below, which
        # expires every loaded object and would force a reload per row.
        total_realized_pnl = sum(
            (p.realized_pnl or Decimal("0") for p in closed_positions),
            Decimal("0"),
        )
        closed_with_pnl = [p for p in closed_positions if p.realized_pnl is not None]
        win_rate, avg_win, avg_loss = self._win_loss_stats(closed_with_pnl)

        # Refresh unrealized P&L (commits), then total it over open rows only.
        self.update_open_positions_pnl(account_id)
        total_unrealized_pnl = sum(
            (p.unrealized_pnl or Decimal("0") for p in open_positions),
            Decimal("0"),
        )

        # Latest transaction that actually carries a balance; a trailing row
        # with a NULL balance must not reset the reported balance to 0.
        latest_txn = (
            self.db.query(Transaction)
            .filter(
                and_(
                    Transaction.account_id == account_id,
                    Transaction.cash_balance.isnot(None),
                )
            )
            .order_by(Transaction.run_date.desc(), Transaction.id.desc())
            .first()
        )
        current_balance = (
            latest_txn.cash_balance if latest_txn is not None else Decimal("0")
        )

        return {
            "total_positions": len(positions),
            "open_positions": len(open_positions),
            "closed_positions": len(closed_positions),
            "total_realized_pnl": float(total_realized_pnl),
            "total_unrealized_pnl": float(total_unrealized_pnl),
            "total_pnl": float(total_realized_pnl + total_unrealized_pnl),
            "win_rate": float(win_rate),
            "avg_win": float(avg_win),
            "avg_loss": float(avg_loss),
            "current_balance": float(current_balance),
        }

    @staticmethod
    def _win_loss_stats(closed_with_pnl: list) -> tuple:
        """
        Compute (win rate %, average win, average loss) over closed positions.

        Break-even trades (P&L == 0) count toward the win-rate denominator but
        are neither wins nor losses. Returns ``(0.0, 0, 0)`` for no trades.
        """
        if not closed_with_pnl:
            return 0.0, Decimal("0"), Decimal("0")

        wins = [p.realized_pnl for p in closed_with_pnl if p.realized_pnl > 0]
        losses = [p.realized_pnl for p in closed_with_pnl if p.realized_pnl < 0]

        win_rate = len(wins) / len(closed_with_pnl) * 100
        avg_win = sum(wins) / len(wins) if wins else Decimal("0")
        avg_loss = sum(losses) / len(losses) if losses else Decimal("0")
        return win_rate, avg_win, avg_loss

    def get_balance_history(self, account_id: int, days: int = 30) -> list[Dict]:
        """
        Get account balance history for charting.

        Args:
            account_id: Account ID
            days: Number of days to retrieve

        Returns:
            List of {date, balance} dictionaries, sorted by date, with one
            entry per day (the balance of that day's last transaction).
        """
        cutoff_date = datetime.now().date() - timedelta(days=days)

        rows = (
            self.db.query(Transaction.run_date, Transaction.cash_balance)
            .filter(
                and_(
                    Transaction.account_id == account_id,
                    Transaction.run_date >= cutoff_date,
                    Transaction.cash_balance.isnot(None),
                )
            )
            # Tie-break on id so "last row of the day" is deterministic; this
            # matches the latest-transaction ordering used in
            # calculate_account_stats.
            .order_by(Transaction.run_date, Transaction.id)
            .all()
        )

        # Later rows overwrite earlier ones, leaving each day's last balance.
        daily_balances = {row.run_date: float(row.cash_balance) for row in rows}

        return [
            {"date": day.isoformat(), "balance": balance}
            for day, balance in sorted(daily_balances.items())
        ]

    def get_top_trades(self, account_id: int, limit: int = 10) -> list[Dict]:
        """
        Get top performing trades (by realized P&L).

        Args:
            account_id: Account ID
            limit: Maximum number of trades to return

        Returns:
            List of trade dictionaries, highest realized P&L first
        """
        return self._closed_trades(
            account_id, Position.realized_pnl.desc(), limit
        )

    def get_worst_trades(self, account_id: int, limit: int = 20) -> list[Dict]:
        """
        Get worst performing trades (biggest losses by realized P&L).

        Args:
            account_id: Account ID
            limit: Maximum number of trades to return

        Returns:
            List of trade dictionaries, lowest realized P&L first
        """
        return self._closed_trades(
            account_id, Position.realized_pnl.asc(), limit
        )

    def _closed_trades(self, account_id: int, ordering, limit: int) -> list[Dict]:
        """Query closed positions with a realized P&L, ordered and limited, as dicts."""
        positions = (
            self.db.query(Position)
            .filter(
                and_(
                    Position.account_id == account_id,
                    Position.status == PositionStatus.CLOSED,
                    Position.realized_pnl.isnot(None),
                )
            )
            .order_by(ordering)
            .limit(limit)
            .all()
        )
        return [self._serialize_trade(p) for p in positions]

    @staticmethod
    def _serialize_trade(p: Position) -> Dict:
        """Convert a closed position to a JSON-friendly trade dictionary."""
        return {
            "symbol": p.symbol,
            "option_symbol": p.option_symbol,
            "position_type": p.position_type.value,
            "open_date": p.open_date.isoformat(),
            "close_date": p.close_date.isoformat() if p.close_date else None,
            "quantity": float(p.total_quantity),
            # A 0 price (e.g. closed at zero) is a real value, not "missing".
            "entry_price": _optional_float(p.avg_entry_price),
            "exit_price": _optional_float(p.avg_exit_price),
            "realized_pnl": float(p.realized_pnl),
        }