Initial release v1.1.0
- Complete MVP for tracking Fidelity brokerage account performance
- Transaction import from CSV with deduplication
- Automatic FIFO position tracking with options support
- Real-time P&L calculations with market data caching
- Dashboard with timeframe filtering (30/90/180 days, 1 year, YTD, all time)
- Docker-based deployment with PostgreSQL backend
- React/TypeScript frontend with TailwindCSS
- FastAPI backend with SQLAlchemy ORM

Features:
- Multi-account support
- Import via CSV upload or filesystem
- Open and closed position tracking
- Balance history charting
- Performance analytics and metrics
- Top trades analysis
- Responsive UI design

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
6
backend/app/services/__init__.py
Normal file
6
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,6 @@
|
||||
"""Business logic services."""
|
||||
from app.services.import_service import ImportService, ImportResult
|
||||
from app.services.position_tracker import PositionTracker
|
||||
from app.services.performance_calculator import PerformanceCalculator
|
||||
|
||||
__all__ = ["ImportService", "ImportResult", "PositionTracker", "PerformanceCalculator"]
|
||||
149
backend/app/services/import_service.py
Normal file
149
backend/app/services/import_service.py
Normal file
@@ -0,0 +1,149 @@
|
||||
"""Service for importing transactions from CSV files."""
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, NamedTuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.parsers import FidelityParser
|
||||
from app.models import Transaction
|
||||
from app.utils import generate_transaction_hash
|
||||
|
||||
|
||||
class ImportResult(NamedTuple):
    """
    Immutable summary of one import run.

    Attributes:
        imported: Count of transactions written to the database
        skipped: Count of rows skipped as duplicates
        errors: Human-readable error messages collected during the run
        total_rows: Total number of CSV rows that were processed
    """

    imported: int
    skipped: int
    errors: List[str]
    total_rows: int
|
||||
|
||||
|
||||
class ImportService:
    """
    Imports brokerage CSV transaction files into the database.

    Responsible for parsing, duplicate detection (via a per-row hash),
    and persistence of new transactions.
    """

    def __init__(self, db: Session):
        """
        Initialize import service.

        Args:
            db: Database session
        """
        self.db = db
        # Only Fidelity exports are understood today; swap in other
        # parsers here to support additional brokerages.
        self.parser = FidelityParser()

    def import_from_file(self, file_path: Path, account_id: int) -> ImportResult:
        """
        Import transactions from a single CSV file.

        Args:
            file_path: Path to CSV file
            account_id: ID of the account to import transactions for

        Returns:
            ImportResult with statistics

        Raises:
            FileNotFoundError: If file doesn't exist
            ValueError: If file format is invalid
        """
        parsed = self.parser.parse(file_path)

        imported_count = 0
        duplicate_count = 0
        # Seed with any row-level errors the parser already collected.
        error_messages = list(parsed.errors)

        for row in parsed.transactions:
            try:
                # Deduplication key derived from the row's identifying fields.
                row_hash = generate_transaction_hash(
                    account_id=account_id,
                    run_date=row["run_date"],
                    symbol=row.get("symbol"),
                    action=row["action"],
                    amount=row.get("amount"),
                    quantity=row.get("quantity"),
                    price=row.get("price"),
                )

                # Skip rows whose hash is already persisted.
                already_present = (
                    self.db.query(Transaction)
                    .filter(Transaction.unique_hash == row_hash)
                    .first()
                )
                if already_present is not None:
                    duplicate_count += 1
                    continue

                self.db.add(
                    Transaction(
                        account_id=account_id,
                        unique_hash=row_hash,
                        **row,
                    )
                )
                # Commit per row so one bad row cannot roll back the rest.
                self.db.commit()
                imported_count += 1

            except IntegrityError:
                # Hash collision at the DB level (e.g. concurrent imports).
                self.db.rollback()
                duplicate_count += 1
            except Exception as exc:
                self.db.rollback()
                error_messages.append(f"Failed to import transaction: {str(exc)}")

        return ImportResult(
            imported=imported_count,
            skipped=duplicate_count,
            errors=error_messages,
            total_rows=parsed.row_count,
        )

    def import_from_directory(
        self, directory: Path, account_id: int, pattern: str = "*.csv"
    ) -> Dict[str, ImportResult]:
        """
        Import transactions from every matching CSV file in a directory.

        Args:
            directory: Path to directory containing CSV files
            account_id: ID of the account to import transactions for
            pattern: Glob pattern for matching files (default: *.csv)

        Returns:
            Dictionary mapping filename to ImportResult
        """
        if not (directory.exists() and directory.is_dir()):
            raise ValueError(f"Invalid directory: {directory}")

        results: Dict[str, ImportResult] = {}
        for csv_path in directory.glob(pattern):
            try:
                results[csv_path.name] = self.import_from_file(csv_path, account_id)
            except Exception as exc:
                # A failed file becomes an all-error result instead of
                # aborting the whole directory import.
                results[csv_path.name] = ImportResult(
                    imported=0,
                    skipped=0,
                    errors=[str(exc)],
                    total_rows=0,
                )

        return results
|
||||
330
backend/app/services/market_data_service.py
Normal file
330
backend/app/services/market_data_service.py
Normal file
@@ -0,0 +1,330 @@
|
||||
"""
|
||||
Market data service with rate limiting, caching, and batch processing.
|
||||
|
||||
This service handles fetching market prices from Yahoo Finance with:
|
||||
- Database-backed caching to survive restarts
|
||||
- Rate limiting with exponential backoff
|
||||
- Batch processing to reduce API calls
|
||||
- Stale-while-revalidate pattern for better UX
|
||||
"""
|
||||
import time
|
||||
import yfinance as yf
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from typing import Dict, List, Optional
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
from app.models.market_price import MarketPrice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketDataService:
    """Service for fetching and caching market prices with rate limiting."""

    def __init__(self, db: Session, cache_ttl_seconds: int = 300):
        """
        Initialize market data service.

        Args:
            db: Database session
            cache_ttl_seconds: How long cached prices are considered fresh (default: 5 minutes)
        """
        self.db = db
        self.cache_ttl = cache_ttl_seconds
        # Adaptive inter-request delay: decays 10% on each success, doubles on
        # HTTP 429 (see _fetch_from_yahoo). Bounded to [0.5s, 10s].
        self._rate_limit_delay = 0.5  # Start with 500ms between requests
        self._last_request_time = 0.0
        self._consecutive_errors = 0
        self._max_retries = 3

    @staticmethod
    def _is_valid_stock_symbol(symbol: str) -> bool:
        """
        Check if a symbol is a valid stock ticker (not an option symbol or CUSIP).

        Args:
            symbol: Symbol to check

        Returns:
            True if it looks like a valid stock ticker
        """
        # Reject empty input and anything longer than 5 characters —
        # option symbols and CUSIPs are longer than typical tickers.
        if not symbol or len(symbol) > 5:
            return False

        # Stock symbols should start with a letter, not a number
        # Numbers indicate CUSIP codes or option symbols
        if symbol[0].isdigit():
            return False

        # Remaining characters must all be letters once the share-class
        # separators '-' and '.' (e.g. BRK.B / BRK-B) are stripped.
        return symbol.replace('-', '').replace('.', '').isalpha()

    def get_price(self, symbol: str, allow_stale: bool = True) -> Optional[Decimal]:
        """
        Get current price for a symbol with caching.

        Args:
            symbol: Stock ticker symbol
            allow_stale: If True, return stale cache data instead of None

        Returns:
            Price or None if unavailable
        """
        # Skip invalid symbols (option symbols, CUSIPs, etc.)
        if not self._is_valid_stock_symbol(symbol):
            logger.debug(f"Skipping invalid symbol: {symbol} (not a stock ticker)")
            return None

        # Check database cache first
        cached = self._get_cached_price(symbol)

        if cached:
            price, age_seconds = cached
            if age_seconds < self.cache_ttl:
                # Fresh cache hit
                logger.debug(f"Cache HIT (fresh): {symbol} = ${price} (age: {age_seconds}s)")
                return price
            elif allow_stale:
                # Stale cache hit, but we'll return it
                logger.debug(f"Cache HIT (stale): {symbol} = ${price} (age: {age_seconds}s)")
                return price

        # Cache miss or expired - fetch from Yahoo Finance
        logger.info(f"Cache MISS: {symbol}, fetching from Yahoo Finance...")
        fresh_price = self._fetch_from_yahoo(symbol)

        if fresh_price is not None:
            self._update_cache(symbol, fresh_price)
            return fresh_price

        # If fetch failed and we have stale data, return it
        # NOTE(review): this branch looks unreachable — any cached entry was
        # already returned above when allow_stale is True, and the condition
        # requires allow_stale. Confirm intent.
        if cached and allow_stale:
            price, age_seconds = cached
            logger.warning(f"Yahoo fetch failed, using stale cache: {symbol} = ${price} (age: {age_seconds}s)")
            return price

        return None

    def get_prices_batch(
        self,
        symbols: List[str],
        allow_stale: bool = True,
        max_fetches: int = 10
    ) -> Dict[str, Optional[Decimal]]:
        """
        Get prices for multiple symbols with rate limiting.

        Args:
            symbols: List of ticker symbols
            allow_stale: Return stale cache data if available
            max_fetches: Maximum number of API calls to make (remaining use cache)

        Returns:
            Dictionary mapping symbol to price (or None if unavailable)
        """
        results = {}
        symbols_to_fetch = []

        # First pass: Check cache for all symbols
        for symbol in symbols:
            # Skip invalid symbols
            if not self._is_valid_stock_symbol(symbol):
                logger.debug(f"Skipping invalid symbol in batch: {symbol}")
                results[symbol] = None
                continue
            cached = self._get_cached_price(symbol)

            if cached:
                price, age_seconds = cached
                if age_seconds < self.cache_ttl:
                    # Fresh cache - use it
                    results[symbol] = price
                elif allow_stale:
                    # Stale but usable
                    results[symbol] = price
                    # NOTE(review): entries older than 2x TTL are served stale
                    # but never queued for refresh here; they only get updated
                    # via refresh_stale_prices(). Confirm this is intended.
                    if age_seconds < self.cache_ttl * 2:  # Not TOO stale
                        symbols_to_fetch.append(symbol)
                else:
                    # Stale and not allowing stale - need to fetch
                    symbols_to_fetch.append(symbol)
            else:
                # No cache at all
                symbols_to_fetch.append(symbol)

        # Second pass: Fetch missing/stale symbols (with limit)
        if symbols_to_fetch:
            logger.info(f"Batch fetching {len(symbols_to_fetch)} symbols (max: {max_fetches})")

            for i, symbol in enumerate(symbols_to_fetch[:max_fetches]):
                if i > 0:
                    # Rate limiting delay
                    time.sleep(self._rate_limit_delay)

                price = self._fetch_from_yahoo(symbol)
                if price is not None:
                    results[symbol] = price
                    self._update_cache(symbol, price)
                elif symbol not in results:
                    # No cached value and fetch failed
                    results[symbol] = None

        return results

    def refresh_stale_prices(self, min_age_seconds: int = 300, limit: int = 20) -> int:
        """
        Background task to refresh stale prices.

        Args:
            min_age_seconds: Only refresh prices older than this
            limit: Maximum number of prices to refresh

        Returns:
            Number of prices refreshed
        """
        cutoff_time = datetime.utcnow() - timedelta(seconds=min_age_seconds)

        # Get stale prices ordered by oldest first
        stale_prices = (
            self.db.query(MarketPrice)
            .filter(MarketPrice.fetched_at < cutoff_time)
            .order_by(MarketPrice.fetched_at.asc())
            .limit(limit)
            .all()
        )

        refreshed = 0
        for cached_price in stale_prices:
            # Unconditional delay keeps this background task gentle on the API.
            time.sleep(self._rate_limit_delay)

            fresh_price = self._fetch_from_yahoo(cached_price.symbol)
            if fresh_price is not None:
                self._update_cache(cached_price.symbol, fresh_price)
                refreshed += 1

        logger.info(f"Refreshed {refreshed}/{len(stale_prices)} stale prices")
        return refreshed

    def _get_cached_price(self, symbol: str) -> Optional[tuple[Decimal, float]]:
        """
        Get cached price from database.

        Returns:
            Tuple of (price, age_in_seconds) or None if not cached
        """
        cached = (
            self.db.query(MarketPrice)
            .filter(MarketPrice.symbol == symbol)
            .first()
        )

        if cached:
            # Age math assumes fetched_at is naive UTC, which is how
            # _update_cache writes it (datetime.utcnow()).
            age = (datetime.utcnow() - cached.fetched_at).total_seconds()
            return (cached.price, age)

        return None

    def _update_cache(self, symbol: str, price: Decimal) -> None:
        """Update or insert price in database cache."""
        cached = (
            self.db.query(MarketPrice)
            .filter(MarketPrice.symbol == symbol)
            .first()
        )

        if cached:
            cached.price = price
            cached.fetched_at = datetime.utcnow()
        else:
            new_price = MarketPrice(
                symbol=symbol,
                price=price,
                fetched_at=datetime.utcnow()
            )
            self.db.add(new_price)

        # Commit immediately so other sessions/restarts see the price.
        self.db.commit()

    def _fetch_from_yahoo(self, symbol: str) -> Optional[Decimal]:
        """
        Fetch price from Yahoo Finance with rate limiting and retries.

        Returns:
            Price or None if fetch failed
        """
        for attempt in range(self._max_retries):
            try:
                # Rate limiting: sleep off whatever remains of the delay
                # window since the previous request.
                elapsed = time.time() - self._last_request_time
                if elapsed < self._rate_limit_delay:
                    time.sleep(self._rate_limit_delay - elapsed)

                self._last_request_time = time.time()

                # Fetch from Yahoo
                ticker = yf.Ticker(symbol)
                info = ticker.info

                # Try different price fields, most-current first.
                for field in ["currentPrice", "regularMarketPrice", "previousClose"]:
                    if field in info and info[field]:
                        price = Decimal(str(info[field]))

                        # Success - reset error tracking
                        self._consecutive_errors = 0
                        self._rate_limit_delay = max(0.5, self._rate_limit_delay * 0.9)  # Gradually decrease delay

                        logger.debug(f"Fetched {symbol} = ${price}")
                        return price

                # No price found in response
                logger.warning(f"No price data in Yahoo response for {symbol}")
                return None

            except Exception as e:
                error_str = str(e).lower()

                if "429" in error_str or "too many requests" in error_str:
                    # Rate limit hit - back off exponentially
                    self._consecutive_errors += 1
                    self._rate_limit_delay = min(10.0, self._rate_limit_delay * 2)  # Double delay, max 10s

                    logger.warning(
                        f"Rate limit hit for {symbol} (attempt {attempt + 1}/{self._max_retries}), "
                        f"backing off to {self._rate_limit_delay}s delay"
                    )

                    if attempt < self._max_retries - 1:
                        time.sleep(self._rate_limit_delay * (attempt + 1))  # Longer wait for retries
                        continue
                else:
                    # Other error: not retryable, give up on this symbol.
                    logger.error(f"Error fetching {symbol}: {e}")
                    return None

        logger.error(f"Failed to fetch {symbol} after {self._max_retries} attempts")
        return None

    def clear_cache(self, older_than_days: int = 30) -> int:
        """
        Clear old cached prices.

        Args:
            older_than_days: Delete prices older than this many days

        Returns:
            Number of records deleted
        """
        cutoff = datetime.utcnow() - timedelta(days=older_than_days)

        # Bulk delete; returns the affected row count.
        deleted = (
            self.db.query(MarketPrice)
            .filter(MarketPrice.fetched_at < cutoff)
            .delete()
        )

        self.db.commit()
        logger.info(f"Cleared {deleted} cached prices older than {older_than_days} days")
        return deleted
|
||||
364
backend/app/services/performance_calculator.py
Normal file
364
backend/app/services/performance_calculator.py
Normal file
@@ -0,0 +1,364 @@
|
||||
"""Service for calculating performance metrics and unrealized P&L."""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func
|
||||
from typing import Dict, Optional
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
import yfinance as yf
|
||||
from functools import lru_cache
|
||||
|
||||
from app.models import Position, Transaction
|
||||
from app.models.position import PositionStatus
|
||||
|
||||
|
||||
class PerformanceCalculator:
    """
    Service for calculating performance metrics and market data.

    Integrates with Yahoo Finance API for real-time pricing of open positions.
    """

    def __init__(self, db: Session, cache_ttl: int = 60):
        """
        Initialize performance calculator.

        Args:
            db: Database session
            cache_ttl: Cache time-to-live in seconds (default: 60)
        """
        self.db = db
        self.cache_ttl = cache_ttl
        # In-process price cache: symbol -> (price, time it was fetched).
        self._price_cache: Dict[str, tuple[Decimal, datetime]] = {}

    def calculate_unrealized_pnl(self, position: Position) -> Optional[Decimal]:
        """
        Calculate unrealized P&L for an open position.

        Args:
            position: Open position to calculate P&L for

        Returns:
            Unrealized P&L or None if market data unavailable
        """
        if position.status != PositionStatus.OPEN:
            return None

        # Get current market price
        current_price = self.get_current_price(position.symbol)
        if current_price is None:
            return None

        if position.avg_entry_price is None:
            return None

        # Calculate P&L based on position direction; negative quantity
        # marks a short position.
        quantity = abs(position.total_quantity)
        is_short = position.total_quantity < 0

        # NOTE(review): the unconditional 100x multiplier assumes option
        # contracts — confirm equities shouldn't use a multiplier of 1.
        if is_short:
            # Short position: profit when price decreases
            pnl = (position.avg_entry_price - current_price) * quantity * 100
        else:
            # Long position: profit when price increases
            pnl = (current_price - position.avg_entry_price) * quantity * 100

        # Subtract fees and commissions from opening transactions
        total_fees = Decimal("0")
        for link in position.transaction_links:
            txn = link.transaction
            if txn.commission:
                total_fees += txn.commission
            if txn.fees:
                total_fees += txn.fees

        pnl -= total_fees

        return pnl

    def update_open_positions_pnl(self, account_id: int) -> int:
        """
        Update unrealized P&L for all open positions in an account.

        Args:
            account_id: Account ID to update

        Returns:
            Number of positions updated
        """
        open_positions = (
            self.db.query(Position)
            .filter(
                and_(
                    Position.account_id == account_id,
                    Position.status == PositionStatus.OPEN,
                )
            )
            .all()
        )

        updated = 0
        for position in open_positions:
            unrealized_pnl = self.calculate_unrealized_pnl(position)
            if unrealized_pnl is not None:
                position.unrealized_pnl = unrealized_pnl
                updated += 1

        self.db.commit()
        return updated

    def get_current_price(self, symbol: str) -> Optional[Decimal]:
        """
        Get current market price for a symbol.

        Uses Yahoo Finance API with caching to reduce API calls.

        Args:
            symbol: Stock ticker symbol

        Returns:
            Current price or None if unavailable
        """
        # Check the in-process cache first.
        if symbol in self._price_cache:
            price, timestamp = self._price_cache[symbol]
            if datetime.now() - timestamp < timedelta(seconds=self.cache_ttl):
                return price

        # Fetch from Yahoo Finance
        try:
            ticker = yf.Ticker(symbol)
            info = ticker.info

            # Try different price fields, most-current first.
            current_price = None
            for field in ["currentPrice", "regularMarketPrice", "previousClose"]:
                if field in info and info[field]:
                    current_price = Decimal(str(info[field]))
                    break

            if current_price is not None:
                # Cache the price
                self._price_cache[symbol] = (current_price, datetime.now())
                return current_price

        except Exception:
            # Failed to fetch price — treat as unavailable rather than raising.
            pass

        return None

    def calculate_account_stats(self, account_id: int) -> Dict:
        """
        Calculate aggregate statistics for an account.

        Args:
            account_id: Account ID

        Returns:
            Dictionary with performance metrics
        """
        # Get all positions
        positions = (
            self.db.query(Position)
            .filter(Position.account_id == account_id)
            .all()
        )

        total_positions = len(positions)
        open_positions_count = sum(
            1 for p in positions if p.status == PositionStatus.OPEN
        )
        closed_positions_count = sum(
            1 for p in positions if p.status == PositionStatus.CLOSED
        )

        # Realized P&L over closed positions only
        total_realized_pnl = sum(
            (p.realized_pnl or Decimal("0"))
            for p in positions
            if p.status == PositionStatus.CLOSED
        )

        # Refresh unrealized P&L before summing; the ORM identity map means
        # the objects in `positions` see the updated values.
        self.update_open_positions_pnl(account_id)

        total_unrealized_pnl = sum(
            (p.unrealized_pnl or Decimal("0"))
            for p in positions
            if p.status == PositionStatus.OPEN
        )

        # Calculate win rate and average win/loss
        closed_with_pnl = [
            p for p in positions
            if p.status == PositionStatus.CLOSED and p.realized_pnl is not None
        ]

        if closed_with_pnl:
            winning_trades = [p for p in closed_with_pnl if p.realized_pnl > 0]
            losing_trades = [p for p in closed_with_pnl if p.realized_pnl < 0]

            win_rate = (len(winning_trades) / len(closed_with_pnl)) * 100

            avg_win = (
                sum(p.realized_pnl for p in winning_trades) / len(winning_trades)
                if winning_trades
                else Decimal("0")
            )

            avg_loss = (
                sum(p.realized_pnl for p in losing_trades) / len(losing_trades)
                if losing_trades
                else Decimal("0")
            )
        else:
            win_rate = 0.0
            avg_win = Decimal("0")
            avg_loss = Decimal("0")

        # Get current account balance from latest transaction
        latest_txn = (
            self.db.query(Transaction)
            .filter(Transaction.account_id == account_id)
            .order_by(Transaction.run_date.desc(), Transaction.id.desc())
            .first()
        )

        current_balance = (
            latest_txn.cash_balance if latest_txn and latest_txn.cash_balance else Decimal("0")
        )

        return {
            "total_positions": total_positions,
            "open_positions": open_positions_count,
            "closed_positions": closed_positions_count,
            "total_realized_pnl": float(total_realized_pnl),
            "total_unrealized_pnl": float(total_unrealized_pnl),
            "total_pnl": float(total_realized_pnl + total_unrealized_pnl),
            "win_rate": float(win_rate),
            "avg_win": float(avg_win),
            "avg_loss": float(avg_loss),
            "current_balance": float(current_balance),
        }

    def get_balance_history(
        self, account_id: int, days: int = 30
    ) -> list[Dict]:
        """
        Get account balance history for charting.

        Args:
            account_id: Account ID
            days: Number of days to retrieve

        Returns:
            List of {date, balance} dictionaries
        """
        cutoff_date = datetime.now().date() - timedelta(days=days)

        transactions = (
            self.db.query(Transaction.run_date, Transaction.cash_balance)
            .filter(
                and_(
                    Transaction.account_id == account_id,
                    Transaction.run_date >= cutoff_date,
                    Transaction.cash_balance.isnot(None),
                )
            )
            .order_by(Transaction.run_date)
            .all()
        )

        # One balance per day: rows are ordered by run_date, so the last
        # transaction of each day wins the dict slot.
        daily_balances = {}
        for txn in transactions:
            daily_balances[txn.run_date] = float(txn.cash_balance)

        return [
            {"date": date.isoformat(), "balance": balance}
            for date, balance in sorted(daily_balances.items())
        ]

    def _closed_trades(self, account_id: int, limit: int, best_first: bool):
        """Query closed positions with realized P&L, ordered best- or worst-first."""
        order = (
            Position.realized_pnl.desc() if best_first else Position.realized_pnl.asc()
        )
        return (
            self.db.query(Position)
            .filter(
                and_(
                    Position.account_id == account_id,
                    Position.status == PositionStatus.CLOSED,
                    Position.realized_pnl.isnot(None),
                )
            )
            .order_by(order)
            .limit(limit)
            .all()
        )

    @staticmethod
    def _trade_to_dict(p: Position) -> Dict:
        """Serialize a closed position into the trade dictionary returned to callers."""
        return {
            "symbol": p.symbol,
            "option_symbol": p.option_symbol,
            "position_type": p.position_type.value,
            "open_date": p.open_date.isoformat(),
            "close_date": p.close_date.isoformat() if p.close_date else None,
            "quantity": float(p.total_quantity),
            "entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
            "exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
            "realized_pnl": float(p.realized_pnl),
        }

    def get_top_trades(
        self, account_id: int, limit: int = 10
    ) -> list[Dict]:
        """
        Get top performing trades (by realized P&L).

        Args:
            account_id: Account ID
            limit: Maximum number of trades to return

        Returns:
            List of trade dictionaries
        """
        return [
            self._trade_to_dict(p)
            for p in self._closed_trades(account_id, limit, best_first=True)
        ]

    def get_worst_trades(
        self, account_id: int, limit: int = 20
    ) -> list[Dict]:
        """
        Get worst performing trades (biggest losses by realized P&L).

        Args:
            account_id: Account ID
            limit: Maximum number of trades to return

        Returns:
            List of trade dictionaries
        """
        return [
            self._trade_to_dict(p)
            for p in self._closed_trades(account_id, limit, best_first=False)
        ]
|
||||
433
backend/app/services/performance_calculator_v2.py
Normal file
433
backend/app/services/performance_calculator_v2.py
Normal file
@@ -0,0 +1,433 @@
|
||||
"""
|
||||
Improved performance calculator with rate-limited market data fetching.
|
||||
|
||||
This version uses the MarketDataService for efficient, cached price lookups.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from typing import Dict, Optional
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
from app.models import Position, Transaction
|
||||
from app.models.position import PositionStatus
|
||||
from app.services.market_data_service import MarketDataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PerformanceCalculatorV2:
|
||||
"""
|
||||
Enhanced performance calculator with efficient market data handling.
|
||||
|
||||
Features:
|
||||
- Database-backed price caching
|
||||
- Rate-limited API calls
|
||||
- Batch price fetching
|
||||
- Stale-while-revalidate pattern
|
||||
"""
|
||||
|
||||
    def __init__(self, db: Session, cache_ttl: int = 300):
        """
        Initialize performance calculator.

        Args:
            db: Database session
            cache_ttl: Cache time-to-live in seconds (default: 5 minutes)
        """
        self.db = db
        # Delegate all price lookups to the shared, DB-backed cache service
        # (rate-limited, batch-capable, survives restarts).
        self.market_data = MarketDataService(db, cache_ttl_seconds=cache_ttl)
|
||||
|
||||
def calculate_unrealized_pnl(self, position: Position, current_price: Optional[Decimal] = None) -> Optional[Decimal]:
|
||||
"""
|
||||
Calculate unrealized P&L for an open position.
|
||||
|
||||
Args:
|
||||
position: Open position to calculate P&L for
|
||||
current_price: Optional pre-fetched current price (avoids API call)
|
||||
|
||||
Returns:
|
||||
Unrealized P&L or None if market data unavailable
|
||||
"""
|
||||
if position.status != PositionStatus.OPEN:
|
||||
return None
|
||||
|
||||
# Use provided price or fetch it
|
||||
if current_price is None:
|
||||
current_price = self.market_data.get_price(position.symbol, allow_stale=True)
|
||||
|
||||
if current_price is None or position.avg_entry_price is None:
|
||||
return None
|
||||
|
||||
# Calculate P&L based on position direction
|
||||
quantity = abs(position.total_quantity)
|
||||
is_short = position.total_quantity < 0
|
||||
|
||||
if is_short:
|
||||
# Short position: profit when price decreases
|
||||
pnl = (position.avg_entry_price - current_price) * quantity * 100
|
||||
else:
|
||||
# Long position: profit when price increases
|
||||
pnl = (current_price - position.avg_entry_price) * quantity * 100
|
||||
|
||||
# Subtract fees and commissions from opening transactions
|
||||
total_fees = Decimal("0")
|
||||
for link in position.transaction_links:
|
||||
txn = link.transaction
|
||||
if txn.commission:
|
||||
total_fees += txn.commission
|
||||
if txn.fees:
|
||||
total_fees += txn.fees
|
||||
|
||||
pnl -= total_fees
|
||||
|
||||
return pnl
|
||||
|
||||
    def update_open_positions_pnl(
        self,
        account_id: int,
        max_api_calls: int = 10,
        allow_stale: bool = True
    ) -> Dict[str, int]:
        """
        Update unrealized P&L for all open positions in an account.

        Uses batch fetching with rate limiting to avoid overwhelming Yahoo Finance API.

        Args:
            account_id: Account ID to update
            max_api_calls: Maximum number of Yahoo Finance API calls to make
            allow_stale: Allow using stale cached prices

        Returns:
            Dictionary with update statistics (keys: total, updated, cached, failed)
        """
        open_positions = (
            self.db.query(Position)
            .filter(
                and_(
                    Position.account_id == account_id,
                    Position.status == PositionStatus.OPEN,
                )
            )
            .all()
        )

        if not open_positions:
            return {
                "total": 0,
                "updated": 0,
                "cached": 0,
                "failed": 0
            }

        # Get unique symbols so each is fetched at most once.
        symbols = list(set(p.symbol for p in open_positions))

        logger.info(f"Updating P&L for {len(open_positions)} positions across {len(symbols)} symbols")

        # Fetch prices in batch (cache-first, capped at max_api_calls live requests)
        prices = self.market_data.get_prices_batch(
            symbols,
            allow_stale=allow_stale,
            max_fetches=max_api_calls
        )

        # Update P&L for each position
        updated = 0
        cached = 0
        failed = 0

        for position in open_positions:
            price = prices.get(position.symbol)

            if price is not None:
                unrealized_pnl = self.calculate_unrealized_pnl(position, current_price=price)

                if unrealized_pnl is not None:
                    position.unrealized_pnl = unrealized_pnl
                    updated += 1

                    # Check if price was from cache (age > 0) or fresh fetch
                    # NOTE(review): this reads MarketDataService's private
                    # _get_cached_price AFTER the batch fetch wrote fresh
                    # prices to the cache, so freshly fetched symbols also
                    # count as "cached" — treat the statistic as approximate.
                    cached_info = self.market_data._get_cached_price(position.symbol)
                    if cached_info:
                        _, age = cached_info
                        if age < self.market_data.cache_ttl:
                            cached += 1
                else:
                    failed += 1
            else:
                failed += 1
                logger.warning(f"Could not get price for {position.symbol}")

        # Persist all updated unrealized_pnl values in one commit.
        self.db.commit()

        logger.info(
            f"Updated {updated}/{len(open_positions)} positions "
            f"(cached: {cached}, failed: {failed})"
        )

        return {
            "total": len(open_positions),
            "updated": updated,
            "cached": cached,
            "failed": failed
        }
|
||||
|
||||
def calculate_account_stats(
    self,
    account_id: int,
    update_prices: bool = True,
    max_api_calls: int = 10,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None
) -> Dict:
    """
    Calculate aggregate statistics for an account.

    Aggregates realized P&L from closed positions, optionally refreshes
    market prices to recompute unrealized P&L on open positions, and
    derives win rate and average win/loss metrics from closed trades.

    Args:
        account_id: Account ID
        update_prices: Whether to fetch fresh prices (if False, uses cached only)
        max_api_calls: Maximum number of Yahoo Finance API calls
        start_date: Filter positions opened on or after this date
        end_date: Filter positions opened on or before this date

    Returns:
        Dictionary with performance metrics: position counts, realized /
        unrealized / total P&L, win_rate, avg_win, avg_loss,
        current_balance, and (when prices were refreshed) a
        "price_update_stats" sub-dictionary.
    """
    # Get all positions with optional date filtering
    query = self.db.query(Position).filter(Position.account_id == account_id)

    if start_date:
        query = query.filter(Position.open_date >= start_date)
    if end_date:
        query = query.filter(Position.open_date <= end_date)

    positions = query.all()

    total_positions = len(positions)
    open_positions_count = sum(
        1 for p in positions if p.status == PositionStatus.OPEN
    )
    closed_positions_count = sum(
        1 for p in positions if p.status == PositionStatus.CLOSED
    )

    # Calculate realized P&L (doesn't need market data).
    # Positions with realized_pnl == None contribute 0.
    total_realized_pnl = sum(
        (p.realized_pnl or Decimal("0"))
        for p in positions
        if p.status == PositionStatus.CLOSED
    )

    # Update unrealized P&L for open positions.
    # NOTE: update_open_positions_pnl updates ALL open positions for the
    # account (not just the date-filtered set) and commits.
    update_stats = None
    if update_prices and open_positions_count > 0:
        update_stats = self.update_open_positions_pnl(
            account_id,
            max_api_calls=max_api_calls,
            allow_stale=True
        )

    # Calculate total unrealized P&L. The Position objects loaded above
    # are presumably the same ORM instances mutated by
    # update_open_positions_pnl (same session / identity map), so this
    # sum should see the refreshed values — TODO confirm session behavior.
    total_unrealized_pnl = sum(
        (p.unrealized_pnl or Decimal("0"))
        for p in positions
        if p.status == PositionStatus.OPEN
    )

    # Calculate win rate and average win/loss.
    # Break-even trades (realized_pnl == 0) count in the denominator but
    # are neither wins nor losses.
    closed_with_pnl = [
        p for p in positions
        if p.status == PositionStatus.CLOSED and p.realized_pnl is not None
    ]

    if closed_with_pnl:
        winning_trades = [p for p in closed_with_pnl if p.realized_pnl > 0]
        losing_trades = [p for p in closed_with_pnl if p.realized_pnl < 0]

        win_rate = (len(winning_trades) / len(closed_with_pnl)) * 100

        avg_win = (
            sum(p.realized_pnl for p in winning_trades) / len(winning_trades)
            if winning_trades
            else Decimal("0")
        )

        avg_loss = (
            sum(p.realized_pnl for p in losing_trades) / len(losing_trades)
            if losing_trades
            else Decimal("0")
        )
    else:
        win_rate = 0.0
        avg_win = Decimal("0")
        avg_loss = Decimal("0")

    # Get current account balance from latest transaction
    # (id desc breaks same-day ties).
    latest_txn = (
        self.db.query(Transaction)
        .filter(Transaction.account_id == account_id)
        .order_by(Transaction.run_date.desc(), Transaction.id.desc())
        .first()
    )

    current_balance = (
        latest_txn.cash_balance if latest_txn and latest_txn.cash_balance else Decimal("0")
    )

    # All Decimal values are converted to float for JSON serialization.
    result = {
        "total_positions": total_positions,
        "open_positions": open_positions_count,
        "closed_positions": closed_positions_count,
        "total_realized_pnl": float(total_realized_pnl),
        "total_unrealized_pnl": float(total_unrealized_pnl),
        "total_pnl": float(total_realized_pnl + total_unrealized_pnl),
        "win_rate": float(win_rate),
        "avg_win": float(avg_win),
        "avg_loss": float(avg_loss),
        "current_balance": float(current_balance),
    }

    # Add update stats if prices were fetched
    if update_stats:
        result["price_update_stats"] = update_stats

    return result
|
||||
|
||||
def get_balance_history(
    self, account_id: int, days: int = 30
) -> list[Dict]:
    """
    Get account balance history for charting.

    Needs only the transaction history — no market data. When several
    transactions share a day, the last row returned for that day wins.

    Args:
        account_id: Account ID
        days: Number of days to retrieve

    Returns:
        List of {"date": ISO date string, "balance": float} dictionaries,
        sorted by date ascending.
    """
    since = datetime.now().date() - timedelta(days=days)

    rows = (
        self.db.query(Transaction.run_date, Transaction.cash_balance)
        .filter(
            and_(
                Transaction.account_id == account_id,
                Transaction.run_date >= since,
                Transaction.cash_balance.isnot(None),
            )
        )
        .order_by(Transaction.run_date)
        .all()
    )

    # One balance per day: later rows overwrite earlier ones, so the
    # last transaction of each day is kept.
    per_day = {row.run_date: float(row.cash_balance) for row in rows}

    return [
        {"date": day.isoformat(), "balance": bal}
        for day, bal in sorted(per_day.items())
    ]
|
||||
|
||||
def get_top_trades(
    self, account_id: int, limit: int = 10, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
) -> list[Dict]:
    """
    Get top performing trades (by realized P&L).

    Needs only closed positions — no market data.

    Args:
        account_id: Account ID
        limit: Maximum number of trades to return
        start_date: Filter positions closed on or after this date
        end_date: Filter positions closed on or before this date

    Returns:
        List of trade dictionaries, best realized P&L first.
    """
    criteria = [
        Position.account_id == account_id,
        Position.status == PositionStatus.CLOSED,
        Position.realized_pnl.isnot(None),
    ]
    # Optional close-date window.
    if start_date:
        criteria.append(Position.close_date >= start_date)
    if end_date:
        criteria.append(Position.close_date <= end_date)

    best = (
        self.db.query(Position)
        .filter(and_(*criteria))
        .order_by(Position.realized_pnl.desc())
        .limit(limit)
        .all()
    )

    def as_row(pos: Position) -> Dict:
        # Decimal -> float for JSON serialization; nullable fields pass
        # through as None.
        return {
            "symbol": pos.symbol,
            "option_symbol": pos.option_symbol,
            "position_type": pos.position_type.value,
            "open_date": pos.open_date.isoformat(),
            "close_date": pos.close_date.isoformat() if pos.close_date else None,
            "quantity": float(pos.total_quantity),
            "entry_price": float(pos.avg_entry_price) if pos.avg_entry_price else None,
            "exit_price": float(pos.avg_exit_price) if pos.avg_exit_price else None,
            "realized_pnl": float(pos.realized_pnl),
        }

    return [as_row(pos) for pos in best]
|
||||
|
||||
def get_worst_trades(
    self, account_id: int, limit: int = 10, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
) -> list[Dict]:
    """
    Get worst performing trades (by realized P&L).

    Needs only closed positions — no market data.

    Args:
        account_id: Account ID
        limit: Maximum number of trades to return
        start_date: Filter positions closed on or after this date
        end_date: Filter positions closed on or before this date

    Returns:
        List of trade dictionaries, worst realized P&L first.
    """
    criteria = [
        Position.account_id == account_id,
        Position.status == PositionStatus.CLOSED,
        Position.realized_pnl.isnot(None),
    ]
    # Optional close-date window.
    if start_date:
        criteria.append(Position.close_date >= start_date)
    if end_date:
        criteria.append(Position.close_date <= end_date)

    worst = (
        self.db.query(Position)
        .filter(and_(*criteria))
        .order_by(Position.realized_pnl.asc())
        .limit(limit)
        .all()
    )

    def as_row(pos: Position) -> Dict:
        # Decimal -> float for JSON serialization; nullable fields pass
        # through as None.
        return {
            "symbol": pos.symbol,
            "option_symbol": pos.option_symbol,
            "position_type": pos.position_type.value,
            "open_date": pos.open_date.isoformat(),
            "close_date": pos.close_date.isoformat() if pos.close_date else None,
            "quantity": float(pos.total_quantity),
            "entry_price": float(pos.avg_entry_price) if pos.avg_entry_price else None,
            "exit_price": float(pos.avg_exit_price) if pos.avg_exit_price else None,
            "realized_pnl": float(pos.realized_pnl),
        }

    return [as_row(pos) for pos in worst]
|
||||
465
backend/app/services/position_tracker.py
Normal file
465
backend/app/services/position_tracker.py
Normal file
@@ -0,0 +1,465 @@
|
||||
"""Service for tracking and calculating trading positions."""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from typing import List, Optional, Dict
|
||||
from decimal import Decimal
|
||||
from collections import defaultdict
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
from app.models import Transaction, Position, PositionTransaction
|
||||
from app.models.position import PositionType, PositionStatus
|
||||
from app.utils import parse_option_symbol
|
||||
|
||||
|
||||
class PositionTracker:
    """
    Service for tracking trading positions from transactions.

    Matches opening and closing transactions using FIFO (First-In-First-Out).
    Handles stocks, calls, and puts, including complex scenarios like
    assignments and expirations.
    """

    def __init__(self, db: Session):
        """
        Initialize position tracker.

        Args:
            db: Database session
        """
        self.db = db

    def rebuild_positions(self, account_id: int) -> int:
        """
        Rebuild all positions for an account from transactions.

        Deletes existing positions and recalculates from scratch.

        Args:
            account_id: Account ID to rebuild positions for

        Returns:
            Number of positions created
        """
        # Drop existing positions so the rebuild is deterministic.
        self.db.query(Position).filter(Position.account_id == account_id).delete()
        self.db.commit()

        # All transactions in chronological order; id breaks same-day ties.
        transactions = (
            self.db.query(Transaction)
            .filter(Transaction.account_id == account_id)
            .order_by(Transaction.run_date, Transaction.id)
            .all()
        )

        # Group by contract: stocks by symbol only; options by
        # symbol + expiration + strike + call/put (see _get_grouping_key).
        symbol_txns = defaultdict(list)
        for txn in transactions:
            if txn.symbol:
                symbol_txns[self._get_grouping_key(txn)].append(txn)

        # Process each symbol/contract group independently.
        position_count = 0
        for grouping_key, txns in symbol_txns.items():
            created = self._process_symbol_transactions(account_id, grouping_key, txns)
            position_count += len(created)

        self.db.commit()
        return position_count

    def _process_symbol_transactions(
        self, account_id: int, symbol: str, transactions: List[Transaction]
    ) -> List[Position]:
        """
        Process all transactions for a single symbol/contract group.

        Opening transactions push a lot onto the FIFO queue; closing and
        expiration transactions consume lots from the front of the queue.

        Args:
            account_id: Account ID
            symbol: Grouping key for this batch (plain symbol for stocks,
                symbol+expiry+strike key for options).
                NOTE(review): for options this key, not the underlying
                ticker, is stored as Position.symbol — verify downstream
                consumers (e.g. price lookups) expect that.
            transactions: Transactions for this group, chronological order

        Returns:
            List of created Position objects
        """
        positions: List[Position] = []

        # Position type is inferred from the first transaction in the group.
        position_type = (
            self._determine_position_type_from_txn(transactions[0])
            if transactions
            else PositionType.STOCK
        )

        # FIFO queue of currently-open lots.
        open_positions: List[Dict] = []

        for txn in transactions:
            # Guard against NULL actions (consistent with the other parsers
            # in this class).
            action = txn.action.upper() if txn.action else ""

            if self._is_opening_transaction(action):
                # New open lot. A sold-to-open lot is a short position.
                open_positions.append({
                    "transactions": [txn],
                    "quantity": abs(txn.quantity) if txn.quantity else Decimal("0"),
                    "entry_price": txn.price,
                    "open_date": txn.run_date,
                    "is_short": "SELL" in action or "SOLD" in action,
                })

            elif self._is_closing_transaction(action):
                # Close at the transaction's price.
                positions.extend(
                    self._close_fifo(
                        account_id, symbol, position_type,
                        open_positions, txn, exit_price=txn.price,
                    )
                )

            elif self._is_expiration(action):
                # Option expired worthless: close at price 0.
                positions.extend(
                    self._close_fifo(
                        account_id, symbol, position_type,
                        open_positions, txn, exit_price=Decimal("0"),
                    )
                )

        # Anything left in the queue is still open.
        for open_pos in open_positions:
            positions.append(
                self._create_position(account_id, symbol, position_type, open_pos)
            )

        return positions

    def _close_fifo(
        self,
        account_id: int,
        symbol: str,
        position_type: PositionType,
        open_positions: List[Dict],
        txn: Transaction,
        exit_price: Optional[Decimal],
    ) -> List[Position]:
        """
        Consume open lots front-to-back to satisfy a closing/expiring txn.

        Mutates ``open_positions`` in place: fully-consumed lots are popped,
        a partially-consumed lot has its quantity reduced. If the closing
        quantity exceeds all open lots, the excess is silently ignored.

        Args:
            account_id: Account ID
            symbol: Grouping key / symbol for the created positions
            position_type: Position type for the created positions
            open_positions: FIFO queue of open lots (mutated)
            txn: The closing or expiration transaction
            exit_price: Exit price (txn price for closes, 0 for expirations)

        Returns:
            List of closed Position objects created
        """
        closed: List[Position] = []
        remaining = abs(txn.quantity) if txn.quantity else Decimal("0")

        while remaining > 0 and open_positions:
            open_pos = open_positions[0]
            open_qty = open_pos["quantity"]

            if open_qty <= remaining:
                # Entire lot is consumed.
                open_pos["transactions"].append(txn)
                closed.append(
                    self._create_position(
                        account_id, symbol, position_type, open_pos,
                        close_date=txn.run_date,
                        exit_price=exit_price,
                        close_quantity=open_qty,
                    )
                )
                open_positions.pop(0)
                remaining -= open_qty
            else:
                # Lot is split: close `remaining`, keep the rest open.
                closed_portion = {
                    "transactions": open_pos["transactions"] + [txn],
                    "quantity": remaining,
                    "entry_price": open_pos["entry_price"],
                    "open_date": open_pos["open_date"],
                    "is_short": open_pos["is_short"],
                }
                closed.append(
                    self._create_position(
                        account_id, symbol, position_type, closed_portion,
                        close_date=txn.run_date,
                        exit_price=exit_price,
                        close_quantity=remaining,
                    )
                )
                open_pos["quantity"] -= remaining
                remaining = Decimal("0")

        return closed

    def _create_position(
        self,
        account_id: int,
        symbol: str,
        position_type: PositionType,
        position_data: Dict,
        close_date: Optional[datetime] = None,
        exit_price: Optional[Decimal] = None,
        close_quantity: Optional[Decimal] = None,
    ) -> Position:
        """
        Create a Position database object (added to the session, not committed).

        Args:
            account_id: Account ID
            symbol: Trading symbol / grouping key
            position_type: Type of position
            position_data: Dict with transactions, quantity, entry_price,
                open_date, is_short
            close_date: Close date (if closed)
            exit_price: Exit price (if closed)
            close_quantity: Quantity closed (if closed)

        Returns:
            Created Position object
        """
        is_closed = close_date is not None
        quantity = close_quantity if close_quantity else position_data["quantity"]

        # Calculate P&L if closed.
        # BUGFIX: the 100x contract multiplier applies to options only;
        # stock P&L uses the raw share quantity.
        multiplier = Decimal("100") if position_type != PositionType.STOCK else Decimal("1")
        realized_pnl = None
        if is_closed and position_data["entry_price"] and exit_price is not None:
            if position_data["is_short"]:
                # Short position: profit when price decreases.
                realized_pnl = (
                    position_data["entry_price"] - exit_price
                ) * quantity * multiplier
            else:
                # Long position: profit when price increases.
                realized_pnl = (
                    exit_price - position_data["entry_price"]
                ) * quantity * multiplier

            # Subtract fees and commissions.
            # NOTE(review): when a lot is split across partial closes, the
            # opening transaction appears in every resulting position's
            # transaction list, so its commission/fees are subtracted more
            # than once. Consider prorating by closed quantity — TODO confirm.
            for txn in position_data["transactions"]:
                if txn.commission:
                    realized_pnl -= txn.commission
                if txn.fees:
                    realized_pnl -= txn.fees

        # Extract option symbol from the first transaction if this is an option.
        option_symbol = None
        if position_type != PositionType.STOCK and position_data["transactions"]:
            first_txn = position_data["transactions"][0]
            option_symbol = self._extract_option_symbol_from_description(
                first_txn.description, first_txn.action, symbol
            )

        # Short positions are stored with negative quantity.
        position = Position(
            account_id=account_id,
            symbol=symbol,
            option_symbol=option_symbol,
            position_type=position_type,
            status=PositionStatus.CLOSED if is_closed else PositionStatus.OPEN,
            open_date=position_data["open_date"],
            close_date=close_date,
            total_quantity=quantity if not position_data["is_short"] else -quantity,
            avg_entry_price=position_data["entry_price"],
            avg_exit_price=exit_price,
            realized_pnl=realized_pnl,
        )

        self.db.add(position)
        self.db.flush()  # Assigns position.id for the link rows below.

        # Link contributing transactions to the position.
        for txn in position_data["transactions"]:
            self.db.add(
                PositionTransaction(position_id=position.id, transaction_id=txn.id)
            )

        return position

    def _extract_option_symbol_from_description(
        self, description: str, action: str, base_symbol: str
    ) -> Optional[str]:
        """
        Extract an option symbol from a transaction description.

        Example: "CALL (TGT) TARGET CORP JAN 16 26 $95 (100 SHS)"
        Returns: "-TGT260116C95"

        Args:
            description: Transaction description
            action: Transaction action (unused; kept for interface stability)
            base_symbol: Underlying symbol (or grouping key)

        Returns:
            Option symbol in "-SYMBOL + YYMMDD + C/P + STRIKE" format,
            or None if the description can't be parsed.
        """
        if not description:
            return None

        # Determine if CALL or PUT.
        upper_desc = description.upper()
        if "CALL" in upper_desc:
            call_or_put = "C"
        elif "PUT" in upper_desc:
            call_or_put = "P"
        else:
            return None

        # Extract date and strike, e.g. "JAN 16 26 $95" (MONTH DAY YY $STRIKE).
        match = re.search(r'([A-Z]{3})\s+(\d{1,2})\s+(\d{2})\s+\$([\d.]+)', description)
        if not match:
            return None

        month_abbr, day, year, strike = match.groups()

        month_map = {
            'JAN': '01', 'FEB': '02', 'MAR': '03', 'APR': '04',
            'MAY': '05', 'JUN': '06', 'JUL': '07', 'AUG': '08',
            'SEP': '09', 'OCT': '10', 'NOV': '11', 'DEC': '12',
        }
        month = month_map.get(month_abbr.upper())
        if not month:
            return None

        # Whole-dollar strikes drop the decimal; fractional strikes just
        # drop the dot (e.g. "95.5" -> "955").
        strike_num = float(strike)
        strike_str = str(int(strike_num)) if strike_num.is_integer() else strike.replace('.', '')

        return f"-{base_symbol}{year}{month}{day.zfill(2)}{call_or_put}{strike_str}"

    def _determine_position_type_from_txn(self, txn: Transaction) -> PositionType:
        """
        Determine position type from a transaction's action/description.

        Args:
            txn: Transaction to analyze

        Returns:
            PositionType (STOCK, CALL, or PUT)
        """
        action_upper = txn.action.upper() if txn.action else ""
        desc_upper = txn.description.upper() if txn.description else ""

        # Look for CALL or PUT keywords first.
        if "CALL" in action_upper or "CALL" in desc_upper:
            return PositionType.CALL
        if "PUT" in action_upper or "PUT" in desc_upper:
            return PositionType.PUT

        # Fall back to the "-PREFIXED" option symbol format (backwards compat).
        if txn.symbol and txn.symbol.startswith("-"):
            option_info = parse_option_symbol(txn.symbol)
            if option_info:
                return (
                    PositionType.CALL
                    if option_info.option_type == "CALL"
                    else PositionType.PUT
                )

        return PositionType.STOCK

    def _get_base_symbol(self, symbol: str) -> str:
        """Extract the underlying symbol from an option symbol, else pass through."""
        if symbol.startswith("-"):
            option_info = parse_option_symbol(symbol)
            if option_info:
                return option_info.underlying_symbol
        return symbol

    def _is_opening_transaction(self, action: str) -> bool:
        """Check if an (uppercased) action opens a position."""
        return any(
            keyword in action
            for keyword in (
                "OPENING TRANSACTION",
                "YOU BOUGHT OPENING",
                "YOU SOLD OPENING",
            )
        )

    def _is_closing_transaction(self, action: str) -> bool:
        """Check if an (uppercased) action closes a position (incl. assignment)."""
        return any(
            keyword in action
            for keyword in (
                "CLOSING TRANSACTION",
                "YOU BOUGHT CLOSING",
                "YOU SOLD CLOSING",
                "ASSIGNED",
            )
        )

    def _is_expiration(self, action: str) -> bool:
        """Check if an (uppercased) action represents an option expiration."""
        return "EXPIRED" in action

    def _get_grouping_key(self, txn: Transaction) -> str:
        """
        Create a unique grouping key for a transaction.

        For options: symbol + expiry + strike + C/P (e.g. "TGT-JAN16-100C"),
        so different contracts on the same underlying don't get FIFO-matched
        against each other. For stocks: just the symbol.

        Args:
            txn: Transaction to create a key for

        Returns:
            Grouping key string
        """
        action_upper = txn.action.upper() if txn.action else ""
        desc_upper = txn.description.upper() if txn.description else ""

        is_option = (
            "CALL" in action_upper or "CALL" in desc_upper
            or "PUT" in action_upper or "PUT" in desc_upper
        )

        if not is_option or not txn.description:
            # Stock transaction — group by symbol only.
            return txn.symbol

        # Pattern: "CALL (TGT) TARGET CORP JAN 16 26 $100 (100 SHS)"
        match = re.search(r'([A-Z]{3})\s+(\d{1,2})\s+(\d{2})\s+\$([\d.]+)', txn.description)
        if not match:
            # Can't parse option details — fall back to symbol only.
            return txn.symbol

        month_abbr, day, year, strike = match.groups()
        call_or_put = "C" if "CALL" in desc_upper else "P"

        # Whole-dollar strikes drop the decimal in the key.
        strike_num = float(strike)
        strike_str = str(int(strike_num)) if strike_num.is_integer() else strike

        return f"{txn.symbol}-{month_abbr}{day}-{strike_str}{call_or_put}"
|
||||
Reference in New Issue
Block a user