Initial release v1.1.0

- Complete MVP for tracking Fidelity brokerage account performance
- Transaction import from CSV with deduplication
- Automatic FIFO position tracking with options support
- Real-time P&L calculations with market data caching
- Dashboard with timeframe filtering (30/90/180 days, 1 year, YTD, all time)
- Docker-based deployment with PostgreSQL backend
- React/TypeScript frontend with TailwindCSS
- FastAPI backend with SQLAlchemy ORM

Features:
- Multi-account support
- Import via CSV upload or filesystem
- Open and closed position tracking
- Balance history charting
- Performance analytics and metrics
- Top trades analysis
- Responsive UI design

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Chris
2026-01-22 14:27:43 -05:00
commit eea4469095
90 changed files with 14513 additions and 0 deletions

View File

@@ -0,0 +1,6 @@
"""Business logic services."""
from app.services.import_service import ImportService, ImportResult
from app.services.position_tracker import PositionTracker
from app.services.performance_calculator import PerformanceCalculator
__all__ = ["ImportService", "ImportResult", "PositionTracker", "PerformanceCalculator"]

View File

@@ -0,0 +1,149 @@
"""Service for importing transactions from CSV files."""
from pathlib import Path
from typing import List, Dict, Any, NamedTuple
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from app.parsers import FidelityParser
from app.models import Transaction
from app.utils import generate_transaction_hash
class ImportResult(NamedTuple):
"""
Result of an import operation.
Attributes:
imported: Number of successfully imported transactions
skipped: Number of skipped duplicate transactions
errors: List of error messages
total_rows: Total number of rows processed
"""
imported: int
skipped: int
errors: List[str]
total_rows: int
class ImportService:
"""
Service for importing transactions from brokerage CSV files.
Handles parsing, deduplication, and database insertion.
"""
def __init__(self, db: Session):
"""
Initialize import service.
Args:
db: Database session
"""
self.db = db
self.parser = FidelityParser() # Can be extended to support multiple parsers
def import_from_file(self, file_path: Path, account_id: int) -> ImportResult:
"""
Import transactions from a CSV file.
Args:
file_path: Path to CSV file
account_id: ID of the account to import transactions for
Returns:
ImportResult with statistics
Raises:
FileNotFoundError: If file doesn't exist
ValueError: If file format is invalid
"""
# Parse CSV file
parse_result = self.parser.parse(file_path)
imported = 0
skipped = 0
errors = list(parse_result.errors)
# Process each transaction
for txn_data in parse_result.transactions:
try:
# Generate deduplication hash
unique_hash = generate_transaction_hash(
account_id=account_id,
run_date=txn_data["run_date"],
symbol=txn_data.get("symbol"),
action=txn_data["action"],
amount=txn_data.get("amount"),
quantity=txn_data.get("quantity"),
price=txn_data.get("price"),
)
# Check if transaction already exists
existing = (
self.db.query(Transaction)
.filter(Transaction.unique_hash == unique_hash)
.first()
)
if existing:
skipped += 1
continue
# Create new transaction
transaction = Transaction(
account_id=account_id,
unique_hash=unique_hash,
**txn_data
)
self.db.add(transaction)
self.db.commit()
imported += 1
except IntegrityError:
# Duplicate hash (edge case if concurrent imports)
self.db.rollback()
skipped += 1
except Exception as e:
self.db.rollback()
errors.append(f"Failed to import transaction: {str(e)}")
return ImportResult(
imported=imported,
skipped=skipped,
errors=errors,
total_rows=parse_result.row_count,
)
def import_from_directory(
self, directory: Path, account_id: int, pattern: str = "*.csv"
) -> Dict[str, ImportResult]:
"""
Import transactions from all CSV files in a directory.
Args:
directory: Path to directory containing CSV files
account_id: ID of the account to import transactions for
pattern: Glob pattern for matching files (default: *.csv)
Returns:
Dictionary mapping filename to ImportResult
"""
if not directory.exists() or not directory.is_dir():
raise ValueError(f"Invalid directory: {directory}")
results = {}
for file_path in directory.glob(pattern):
try:
result = self.import_from_file(file_path, account_id)
results[file_path.name] = result
except Exception as e:
results[file_path.name] = ImportResult(
imported=0,
skipped=0,
errors=[str(e)],
total_rows=0,
)
return results

View File

@@ -0,0 +1,330 @@
"""
Market data service with rate limiting, caching, and batch processing.
This service handles fetching market prices from Yahoo Finance with:
- Database-backed caching to survive restarts
- Rate limiting with exponential backoff
- Batch processing to reduce API calls
- Stale-while-revalidate pattern for better UX
"""
import time
import yfinance as yf
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import Dict, List, Optional
from decimal import Decimal
from datetime import datetime, timedelta
import logging
from app.models.market_price import MarketPrice
logger = logging.getLogger(__name__)
class MarketDataService:
"""Service for fetching and caching market prices with rate limiting."""
def __init__(self, db: Session, cache_ttl_seconds: int = 300):
"""
Initialize market data service.
Args:
db: Database session
cache_ttl_seconds: How long cached prices are considered fresh (default: 5 minutes)
"""
self.db = db
self.cache_ttl = cache_ttl_seconds
self._rate_limit_delay = 0.5 # Start with 500ms between requests
self._last_request_time = 0.0
self._consecutive_errors = 0
self._max_retries = 3
@staticmethod
def _is_valid_stock_symbol(symbol: str) -> bool:
"""
Check if a symbol is a valid stock ticker (not an option symbol or CUSIP).
Args:
symbol: Symbol to check
Returns:
True if it looks like a valid stock ticker
"""
if not symbol or len(symbol) > 5:
return False
# Stock symbols should start with a letter, not a number
# Numbers indicate CUSIP codes or option symbols
if symbol[0].isdigit():
return False
# Should be mostly uppercase letters
# Allow $ for preferred shares (e.g., BRK.B becomes BRK-B)
return symbol.replace('-', '').replace('.', '').isalpha()
def get_price(self, symbol: str, allow_stale: bool = True) -> Optional[Decimal]:
"""
Get current price for a symbol with caching.
Args:
symbol: Stock ticker symbol
allow_stale: If True, return stale cache data instead of None
Returns:
Price or None if unavailable
"""
# Skip invalid symbols (option symbols, CUSIPs, etc.)
if not self._is_valid_stock_symbol(symbol):
logger.debug(f"Skipping invalid symbol: {symbol} (not a stock ticker)")
return None
# Check database cache first
cached = self._get_cached_price(symbol)
if cached:
price, age_seconds = cached
if age_seconds < self.cache_ttl:
# Fresh cache hit
logger.debug(f"Cache HIT (fresh): {symbol} = ${price} (age: {age_seconds}s)")
return price
elif allow_stale:
# Stale cache hit, but we'll return it
logger.debug(f"Cache HIT (stale): {symbol} = ${price} (age: {age_seconds}s)")
return price
# Cache miss or expired - fetch from Yahoo Finance
logger.info(f"Cache MISS: {symbol}, fetching from Yahoo Finance...")
fresh_price = self._fetch_from_yahoo(symbol)
if fresh_price is not None:
self._update_cache(symbol, fresh_price)
return fresh_price
# If fetch failed and we have stale data, return it
if cached and allow_stale:
price, age_seconds = cached
logger.warning(f"Yahoo fetch failed, using stale cache: {symbol} = ${price} (age: {age_seconds}s)")
return price
return None
def get_prices_batch(
self,
symbols: List[str],
allow_stale: bool = True,
max_fetches: int = 10
) -> Dict[str, Optional[Decimal]]:
"""
Get prices for multiple symbols with rate limiting.
Args:
symbols: List of ticker symbols
allow_stale: Return stale cache data if available
max_fetches: Maximum number of API calls to make (remaining use cache)
Returns:
Dictionary mapping symbol to price (or None if unavailable)
"""
results = {}
symbols_to_fetch = []
# First pass: Check cache for all symbols
for symbol in symbols:
# Skip invalid symbols
if not self._is_valid_stock_symbol(symbol):
logger.debug(f"Skipping invalid symbol in batch: {symbol}")
results[symbol] = None
continue
cached = self._get_cached_price(symbol)
if cached:
price, age_seconds = cached
if age_seconds < self.cache_ttl:
# Fresh cache - use it
results[symbol] = price
elif allow_stale:
# Stale but usable
results[symbol] = price
if age_seconds < self.cache_ttl * 2: # Not TOO stale
symbols_to_fetch.append(symbol)
else:
# Stale and not allowing stale - need to fetch
symbols_to_fetch.append(symbol)
else:
# No cache at all
symbols_to_fetch.append(symbol)
# Second pass: Fetch missing/stale symbols (with limit)
if symbols_to_fetch:
logger.info(f"Batch fetching {len(symbols_to_fetch)} symbols (max: {max_fetches})")
for i, symbol in enumerate(symbols_to_fetch[:max_fetches]):
if i > 0:
# Rate limiting delay
time.sleep(self._rate_limit_delay)
price = self._fetch_from_yahoo(symbol)
if price is not None:
results[symbol] = price
self._update_cache(symbol, price)
elif symbol not in results:
# No cached value and fetch failed
results[symbol] = None
return results
def refresh_stale_prices(self, min_age_seconds: int = 300, limit: int = 20) -> int:
"""
Background task to refresh stale prices.
Args:
min_age_seconds: Only refresh prices older than this
limit: Maximum number of prices to refresh
Returns:
Number of prices refreshed
"""
cutoff_time = datetime.utcnow() - timedelta(seconds=min_age_seconds)
# Get stale prices ordered by oldest first
stale_prices = (
self.db.query(MarketPrice)
.filter(MarketPrice.fetched_at < cutoff_time)
.order_by(MarketPrice.fetched_at.asc())
.limit(limit)
.all()
)
refreshed = 0
for cached_price in stale_prices:
time.sleep(self._rate_limit_delay)
fresh_price = self._fetch_from_yahoo(cached_price.symbol)
if fresh_price is not None:
self._update_cache(cached_price.symbol, fresh_price)
refreshed += 1
logger.info(f"Refreshed {refreshed}/{len(stale_prices)} stale prices")
return refreshed
def _get_cached_price(self, symbol: str) -> Optional[tuple[Decimal, float]]:
"""
Get cached price from database.
Returns:
Tuple of (price, age_in_seconds) or None if not cached
"""
cached = (
self.db.query(MarketPrice)
.filter(MarketPrice.symbol == symbol)
.first()
)
if cached:
age = (datetime.utcnow() - cached.fetched_at).total_seconds()
return (cached.price, age)
return None
def _update_cache(self, symbol: str, price: Decimal) -> None:
"""Update or insert price in database cache."""
cached = (
self.db.query(MarketPrice)
.filter(MarketPrice.symbol == symbol)
.first()
)
if cached:
cached.price = price
cached.fetched_at = datetime.utcnow()
else:
new_price = MarketPrice(
symbol=symbol,
price=price,
fetched_at=datetime.utcnow()
)
self.db.add(new_price)
self.db.commit()
def _fetch_from_yahoo(self, symbol: str) -> Optional[Decimal]:
"""
Fetch price from Yahoo Finance with rate limiting and retries.
Returns:
Price or None if fetch failed
"""
for attempt in range(self._max_retries):
try:
# Rate limiting
elapsed = time.time() - self._last_request_time
if elapsed < self._rate_limit_delay:
time.sleep(self._rate_limit_delay - elapsed)
self._last_request_time = time.time()
# Fetch from Yahoo
ticker = yf.Ticker(symbol)
info = ticker.info
# Try different price fields
for field in ["currentPrice", "regularMarketPrice", "previousClose"]:
if field in info and info[field]:
price = Decimal(str(info[field]))
# Success - reset error tracking
self._consecutive_errors = 0
self._rate_limit_delay = max(0.5, self._rate_limit_delay * 0.9) # Gradually decrease delay
logger.debug(f"Fetched {symbol} = ${price}")
return price
# No price found in response
logger.warning(f"No price data in Yahoo response for {symbol}")
return None
except Exception as e:
error_str = str(e).lower()
if "429" in error_str or "too many requests" in error_str:
# Rate limit hit - back off exponentially
self._consecutive_errors += 1
self._rate_limit_delay = min(10.0, self._rate_limit_delay * 2) # Double delay, max 10s
logger.warning(
f"Rate limit hit for {symbol} (attempt {attempt + 1}/{self._max_retries}), "
f"backing off to {self._rate_limit_delay}s delay"
)
if attempt < self._max_retries - 1:
time.sleep(self._rate_limit_delay * (attempt + 1)) # Longer wait for retries
continue
else:
# Other error
logger.error(f"Error fetching {symbol}: {e}")
return None
logger.error(f"Failed to fetch {symbol} after {self._max_retries} attempts")
return None
def clear_cache(self, older_than_days: int = 30) -> int:
"""
Clear old cached prices.
Args:
older_than_days: Delete prices older than this many days
Returns:
Number of records deleted
"""
cutoff = datetime.utcnow() - timedelta(days=older_than_days)
deleted = (
self.db.query(MarketPrice)
.filter(MarketPrice.fetched_at < cutoff)
.delete()
)
self.db.commit()
logger.info(f"Cleared {deleted} cached prices older than {older_than_days} days")
return deleted

View File

@@ -0,0 +1,364 @@
"""Service for calculating performance metrics and unrealized P&L."""
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
from typing import Dict, Optional
from decimal import Decimal
from datetime import datetime, timedelta
import yfinance as yf
from functools import lru_cache
from app.models import Position, Transaction
from app.models.position import PositionStatus
class PerformanceCalculator:
"""
Service for calculating performance metrics and market data.
Integrates with Yahoo Finance API for real-time pricing of open positions.
"""
def __init__(self, db: Session, cache_ttl: int = 60):
"""
Initialize performance calculator.
Args:
db: Database session
cache_ttl: Cache time-to-live in seconds (default: 60)
"""
self.db = db
self.cache_ttl = cache_ttl
self._price_cache: Dict[str, tuple[Decimal, datetime]] = {}
def calculate_unrealized_pnl(self, position: Position) -> Optional[Decimal]:
"""
Calculate unrealized P&L for an open position.
Args:
position: Open position to calculate P&L for
Returns:
Unrealized P&L or None if market data unavailable
"""
if position.status != PositionStatus.OPEN:
return None
# Get current market price
current_price = self.get_current_price(position.symbol)
if current_price is None:
return None
if position.avg_entry_price is None:
return None
# Calculate P&L based on position direction
quantity = abs(position.total_quantity)
is_short = position.total_quantity < 0
if is_short:
# Short position: profit when price decreases
pnl = (position.avg_entry_price - current_price) * quantity * 100
else:
# Long position: profit when price increases
pnl = (current_price - position.avg_entry_price) * quantity * 100
# Subtract fees and commissions from opening transactions
total_fees = Decimal("0")
for link in position.transaction_links:
txn = link.transaction
if txn.commission:
total_fees += txn.commission
if txn.fees:
total_fees += txn.fees
pnl -= total_fees
return pnl
def update_open_positions_pnl(self, account_id: int) -> int:
"""
Update unrealized P&L for all open positions in an account.
Args:
account_id: Account ID to update
Returns:
Number of positions updated
"""
open_positions = (
self.db.query(Position)
.filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.OPEN,
)
)
.all()
)
updated = 0
for position in open_positions:
unrealized_pnl = self.calculate_unrealized_pnl(position)
if unrealized_pnl is not None:
position.unrealized_pnl = unrealized_pnl
updated += 1
self.db.commit()
return updated
def get_current_price(self, symbol: str) -> Optional[Decimal]:
"""
Get current market price for a symbol.
Uses Yahoo Finance API with caching to reduce API calls.
Args:
symbol: Stock ticker symbol
Returns:
Current price or None if unavailable
"""
# Check cache
if symbol in self._price_cache:
price, timestamp = self._price_cache[symbol]
if datetime.now() - timestamp < timedelta(seconds=self.cache_ttl):
return price
# Fetch from Yahoo Finance
try:
ticker = yf.Ticker(symbol)
info = ticker.info
# Try different price fields
current_price = None
for field in ["currentPrice", "regularMarketPrice", "previousClose"]:
if field in info and info[field]:
current_price = Decimal(str(info[field]))
break
if current_price is not None:
# Cache the price
self._price_cache[symbol] = (current_price, datetime.now())
return current_price
except Exception:
# Failed to fetch price
pass
return None
def calculate_account_stats(self, account_id: int) -> Dict:
"""
Calculate aggregate statistics for an account.
Args:
account_id: Account ID
Returns:
Dictionary with performance metrics
"""
# Get all positions
positions = (
self.db.query(Position)
.filter(Position.account_id == account_id)
.all()
)
total_positions = len(positions)
open_positions_count = sum(
1 for p in positions if p.status == PositionStatus.OPEN
)
closed_positions_count = sum(
1 for p in positions if p.status == PositionStatus.CLOSED
)
# Calculate P&L
total_realized_pnl = sum(
(p.realized_pnl or Decimal("0"))
for p in positions
if p.status == PositionStatus.CLOSED
)
# Update unrealized P&L for open positions
self.update_open_positions_pnl(account_id)
total_unrealized_pnl = sum(
(p.unrealized_pnl or Decimal("0"))
for p in positions
if p.status == PositionStatus.OPEN
)
# Calculate win rate and average win/loss
closed_with_pnl = [
p for p in positions
if p.status == PositionStatus.CLOSED and p.realized_pnl is not None
]
if closed_with_pnl:
winning_trades = [p for p in closed_with_pnl if p.realized_pnl > 0]
losing_trades = [p for p in closed_with_pnl if p.realized_pnl < 0]
win_rate = (len(winning_trades) / len(closed_with_pnl)) * 100
avg_win = (
sum(p.realized_pnl for p in winning_trades) / len(winning_trades)
if winning_trades
else Decimal("0")
)
avg_loss = (
sum(p.realized_pnl for p in losing_trades) / len(losing_trades)
if losing_trades
else Decimal("0")
)
else:
win_rate = 0.0
avg_win = Decimal("0")
avg_loss = Decimal("0")
# Get current account balance from latest transaction
latest_txn = (
self.db.query(Transaction)
.filter(Transaction.account_id == account_id)
.order_by(Transaction.run_date.desc(), Transaction.id.desc())
.first()
)
current_balance = (
latest_txn.cash_balance if latest_txn and latest_txn.cash_balance else Decimal("0")
)
return {
"total_positions": total_positions,
"open_positions": open_positions_count,
"closed_positions": closed_positions_count,
"total_realized_pnl": float(total_realized_pnl),
"total_unrealized_pnl": float(total_unrealized_pnl),
"total_pnl": float(total_realized_pnl + total_unrealized_pnl),
"win_rate": float(win_rate),
"avg_win": float(avg_win),
"avg_loss": float(avg_loss),
"current_balance": float(current_balance),
}
def get_balance_history(
self, account_id: int, days: int = 30
) -> list[Dict]:
"""
Get account balance history for charting.
Args:
account_id: Account ID
days: Number of days to retrieve
Returns:
List of {date, balance} dictionaries
"""
cutoff_date = datetime.now().date() - timedelta(days=days)
transactions = (
self.db.query(Transaction.run_date, Transaction.cash_balance)
.filter(
and_(
Transaction.account_id == account_id,
Transaction.run_date >= cutoff_date,
Transaction.cash_balance.isnot(None),
)
)
.order_by(Transaction.run_date)
.all()
)
# Get one balance per day (use last transaction of the day)
daily_balances = {}
for txn in transactions:
daily_balances[txn.run_date] = float(txn.cash_balance)
return [
{"date": date.isoformat(), "balance": balance}
for date, balance in sorted(daily_balances.items())
]
def get_top_trades(
self, account_id: int, limit: int = 10
) -> list[Dict]:
"""
Get top performing trades (by realized P&L).
Args:
account_id: Account ID
limit: Maximum number of trades to return
Returns:
List of trade dictionaries
"""
positions = (
self.db.query(Position)
.filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.CLOSED,
Position.realized_pnl.isnot(None),
)
)
.order_by(Position.realized_pnl.desc())
.limit(limit)
.all()
)
return [
{
"symbol": p.symbol,
"option_symbol": p.option_symbol,
"position_type": p.position_type.value,
"open_date": p.open_date.isoformat(),
"close_date": p.close_date.isoformat() if p.close_date else None,
"quantity": float(p.total_quantity),
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
"realized_pnl": float(p.realized_pnl),
}
for p in positions
]
def get_worst_trades(
self, account_id: int, limit: int = 20
) -> list[Dict]:
"""
Get worst performing trades (biggest losses by realized P&L).
Args:
account_id: Account ID
limit: Maximum number of trades to return
Returns:
List of trade dictionaries
"""
positions = (
self.db.query(Position)
.filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.CLOSED,
Position.realized_pnl.isnot(None),
)
)
.order_by(Position.realized_pnl.asc())
.limit(limit)
.all()
)
return [
{
"symbol": p.symbol,
"option_symbol": p.option_symbol,
"position_type": p.position_type.value,
"open_date": p.open_date.isoformat(),
"close_date": p.close_date.isoformat() if p.close_date else None,
"quantity": float(p.total_quantity),
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
"realized_pnl": float(p.realized_pnl),
}
for p in positions
]

View File

@@ -0,0 +1,433 @@
"""
Improved performance calculator with rate-limited market data fetching.
This version uses the MarketDataService for efficient, cached price lookups.
"""
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import Dict, Optional
from decimal import Decimal
from datetime import datetime, timedelta
import logging
from app.models import Position, Transaction
from app.models.position import PositionStatus
from app.services.market_data_service import MarketDataService
logger = logging.getLogger(__name__)
class PerformanceCalculatorV2:
"""
Enhanced performance calculator with efficient market data handling.
Features:
- Database-backed price caching
- Rate-limited API calls
- Batch price fetching
- Stale-while-revalidate pattern
"""
def __init__(self, db: Session, cache_ttl: int = 300):
"""
Initialize performance calculator.
Args:
db: Database session
cache_ttl: Cache time-to-live in seconds (default: 5 minutes)
"""
self.db = db
self.market_data = MarketDataService(db, cache_ttl_seconds=cache_ttl)
def calculate_unrealized_pnl(self, position: Position, current_price: Optional[Decimal] = None) -> Optional[Decimal]:
"""
Calculate unrealized P&L for an open position.
Args:
position: Open position to calculate P&L for
current_price: Optional pre-fetched current price (avoids API call)
Returns:
Unrealized P&L or None if market data unavailable
"""
if position.status != PositionStatus.OPEN:
return None
# Use provided price or fetch it
if current_price is None:
current_price = self.market_data.get_price(position.symbol, allow_stale=True)
if current_price is None or position.avg_entry_price is None:
return None
# Calculate P&L based on position direction
quantity = abs(position.total_quantity)
is_short = position.total_quantity < 0
if is_short:
# Short position: profit when price decreases
pnl = (position.avg_entry_price - current_price) * quantity * 100
else:
# Long position: profit when price increases
pnl = (current_price - position.avg_entry_price) * quantity * 100
# Subtract fees and commissions from opening transactions
total_fees = Decimal("0")
for link in position.transaction_links:
txn = link.transaction
if txn.commission:
total_fees += txn.commission
if txn.fees:
total_fees += txn.fees
pnl -= total_fees
return pnl
def update_open_positions_pnl(
self,
account_id: int,
max_api_calls: int = 10,
allow_stale: bool = True
) -> Dict[str, int]:
"""
Update unrealized P&L for all open positions in an account.
Uses batch fetching with rate limiting to avoid overwhelming Yahoo Finance API.
Args:
account_id: Account ID to update
max_api_calls: Maximum number of Yahoo Finance API calls to make
allow_stale: Allow using stale cached prices
Returns:
Dictionary with update statistics
"""
open_positions = (
self.db.query(Position)
.filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.OPEN,
)
)
.all()
)
if not open_positions:
return {
"total": 0,
"updated": 0,
"cached": 0,
"failed": 0
}
# Get unique symbols
symbols = list(set(p.symbol for p in open_positions))
logger.info(f"Updating P&L for {len(open_positions)} positions across {len(symbols)} symbols")
# Fetch prices in batch
prices = self.market_data.get_prices_batch(
symbols,
allow_stale=allow_stale,
max_fetches=max_api_calls
)
# Update P&L for each position
updated = 0
cached = 0
failed = 0
for position in open_positions:
price = prices.get(position.symbol)
if price is not None:
unrealized_pnl = self.calculate_unrealized_pnl(position, current_price=price)
if unrealized_pnl is not None:
position.unrealized_pnl = unrealized_pnl
updated += 1
# Check if price was from cache (age > 0) or fresh fetch
cached_info = self.market_data._get_cached_price(position.symbol)
if cached_info:
_, age = cached_info
if age < self.market_data.cache_ttl:
cached += 1
else:
failed += 1
else:
failed += 1
logger.warning(f"Could not get price for {position.symbol}")
self.db.commit()
logger.info(
f"Updated {updated}/{len(open_positions)} positions "
f"(cached: {cached}, failed: {failed})"
)
return {
"total": len(open_positions),
"updated": updated,
"cached": cached,
"failed": failed
}
def calculate_account_stats(
self,
account_id: int,
update_prices: bool = True,
max_api_calls: int = 10,
start_date = None,
end_date = None
) -> Dict:
"""
Calculate aggregate statistics for an account.
Args:
account_id: Account ID
update_prices: Whether to fetch fresh prices (if False, uses cached only)
max_api_calls: Maximum number of Yahoo Finance API calls
start_date: Filter positions opened on or after this date
end_date: Filter positions opened on or before this date
Returns:
Dictionary with performance metrics
"""
# Get all positions with optional date filtering
query = self.db.query(Position).filter(Position.account_id == account_id)
if start_date:
query = query.filter(Position.open_date >= start_date)
if end_date:
query = query.filter(Position.open_date <= end_date)
positions = query.all()
total_positions = len(positions)
open_positions_count = sum(
1 for p in positions if p.status == PositionStatus.OPEN
)
closed_positions_count = sum(
1 for p in positions if p.status == PositionStatus.CLOSED
)
# Calculate realized P&L (doesn't need market data)
total_realized_pnl = sum(
(p.realized_pnl or Decimal("0"))
for p in positions
if p.status == PositionStatus.CLOSED
)
# Update unrealized P&L for open positions
update_stats = None
if update_prices and open_positions_count > 0:
update_stats = self.update_open_positions_pnl(
account_id,
max_api_calls=max_api_calls,
allow_stale=True
)
# Calculate total unrealized P&L
total_unrealized_pnl = sum(
(p.unrealized_pnl or Decimal("0"))
for p in positions
if p.status == PositionStatus.OPEN
)
# Calculate win rate and average win/loss
closed_with_pnl = [
p for p in positions
if p.status == PositionStatus.CLOSED and p.realized_pnl is not None
]
if closed_with_pnl:
winning_trades = [p for p in closed_with_pnl if p.realized_pnl > 0]
losing_trades = [p for p in closed_with_pnl if p.realized_pnl < 0]
win_rate = (len(winning_trades) / len(closed_with_pnl)) * 100
avg_win = (
sum(p.realized_pnl for p in winning_trades) / len(winning_trades)
if winning_trades
else Decimal("0")
)
avg_loss = (
sum(p.realized_pnl for p in losing_trades) / len(losing_trades)
if losing_trades
else Decimal("0")
)
else:
win_rate = 0.0
avg_win = Decimal("0")
avg_loss = Decimal("0")
# Get current account balance from latest transaction
latest_txn = (
self.db.query(Transaction)
.filter(Transaction.account_id == account_id)
.order_by(Transaction.run_date.desc(), Transaction.id.desc())
.first()
)
current_balance = (
latest_txn.cash_balance if latest_txn and latest_txn.cash_balance else Decimal("0")
)
result = {
"total_positions": total_positions,
"open_positions": open_positions_count,
"closed_positions": closed_positions_count,
"total_realized_pnl": float(total_realized_pnl),
"total_unrealized_pnl": float(total_unrealized_pnl),
"total_pnl": float(total_realized_pnl + total_unrealized_pnl),
"win_rate": float(win_rate),
"avg_win": float(avg_win),
"avg_loss": float(avg_loss),
"current_balance": float(current_balance),
}
# Add update stats if prices were fetched
if update_stats:
result["price_update_stats"] = update_stats
return result
def get_balance_history(
self, account_id: int, days: int = 30
) -> list[Dict]:
"""
Get account balance history for charting.
This doesn't need market data, just transaction history.
Args:
account_id: Account ID
days: Number of days to retrieve
Returns:
List of {date, balance} dictionaries
"""
cutoff_date = datetime.now().date() - timedelta(days=days)
transactions = (
self.db.query(Transaction.run_date, Transaction.cash_balance)
.filter(
and_(
Transaction.account_id == account_id,
Transaction.run_date >= cutoff_date,
Transaction.cash_balance.isnot(None),
)
)
.order_by(Transaction.run_date)
.all()
)
# Get one balance per day (use last transaction of the day)
daily_balances = {}
for txn in transactions:
daily_balances[txn.run_date] = float(txn.cash_balance)
return [
{"date": date.isoformat(), "balance": balance}
for date, balance in sorted(daily_balances.items())
]
def get_top_trades(
self, account_id: int, limit: int = 10, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
) -> list[Dict]:
"""
Get top performing trades (by realized P&L).
This doesn't need market data, just closed positions.
Args:
account_id: Account ID
limit: Maximum number of trades to return
start_date: Filter positions closed on or after this date
end_date: Filter positions closed on or before this date
Returns:
List of trade dictionaries
"""
query = self.db.query(Position).filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.CLOSED,
Position.realized_pnl.isnot(None),
)
)
# Apply date filters if provided
if start_date:
query = query.filter(Position.close_date >= start_date)
if end_date:
query = query.filter(Position.close_date <= end_date)
positions = query.order_by(Position.realized_pnl.desc()).limit(limit).all()
return [
{
"symbol": p.symbol,
"option_symbol": p.option_symbol,
"position_type": p.position_type.value,
"open_date": p.open_date.isoformat(),
"close_date": p.close_date.isoformat() if p.close_date else None,
"quantity": float(p.total_quantity),
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
"realized_pnl": float(p.realized_pnl),
}
for p in positions
]
def get_worst_trades(
self, account_id: int, limit: int = 10, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
) -> list[Dict]:
"""
Get worst performing trades (by realized P&L).
This doesn't need market data, just closed positions.
Args:
account_id: Account ID
limit: Maximum number of trades to return
start_date: Filter positions closed on or after this date
end_date: Filter positions closed on or before this date
Returns:
List of trade dictionaries
"""
query = self.db.query(Position).filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.CLOSED,
Position.realized_pnl.isnot(None),
)
)
# Apply date filters if provided
if start_date:
query = query.filter(Position.close_date >= start_date)
if end_date:
query = query.filter(Position.close_date <= end_date)
positions = query.order_by(Position.realized_pnl.asc()).limit(limit).all()
return [
{
"symbol": p.symbol,
"option_symbol": p.option_symbol,
"position_type": p.position_type.value,
"open_date": p.open_date.isoformat(),
"close_date": p.close_date.isoformat() if p.close_date else None,
"quantity": float(p.total_quantity),
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
"realized_pnl": float(p.realized_pnl),
}
for p in positions
]

View File

@@ -0,0 +1,465 @@
"""Service for tracking and calculating trading positions."""
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import List, Optional, Dict
from decimal import Decimal
from collections import defaultdict
from datetime import datetime
import re
from app.models import Transaction, Position, PositionTransaction
from app.models.position import PositionType, PositionStatus
from app.utils import parse_option_symbol
class PositionTracker:
"""
Service for tracking trading positions from transactions.
Matches opening and closing transactions using FIFO (First-In-First-Out) method.
Handles stocks, calls, and puts including complex scenarios like assignments and expirations.
"""
def __init__(self, db: Session):
"""
Initialize position tracker.
Args:
db: Database session
"""
self.db = db
def rebuild_positions(self, account_id: int) -> int:
"""
Rebuild all positions for an account from transactions.
Deletes existing positions and recalculates from scratch.
Args:
account_id: Account ID to rebuild positions for
Returns:
Number of positions created
"""
# Delete existing positions
self.db.query(Position).filter(Position.account_id == account_id).delete()
self.db.commit()
# Get all transactions ordered by date
transactions = (
self.db.query(Transaction)
.filter(Transaction.account_id == account_id)
.order_by(Transaction.run_date, Transaction.id)
.all()
)
# Group transactions by symbol and option details
# For options, we need to group by the full option contract (symbol + strike + expiration)
# For stocks, we group by symbol only
symbol_txns = defaultdict(list)
for txn in transactions:
if txn.symbol:
# Create a unique grouping key
grouping_key = self._get_grouping_key(txn)
symbol_txns[grouping_key].append(txn)
# Process each symbol/contract group
position_count = 0
for grouping_key, txns in symbol_txns.items():
positions = self._process_symbol_transactions(account_id, grouping_key, txns)
position_count += len(positions)
self.db.commit()
return position_count
def _process_symbol_transactions(
self, account_id: int, symbol: str, transactions: List[Transaction]
) -> List[Position]:
"""
Process all transactions for a single symbol to create positions.
Args:
account_id: Account ID
symbol: Trading symbol
transactions: List of transactions for this symbol
Returns:
List of created Position objects
"""
positions = []
# Determine position type from first transaction
position_type = self._determine_position_type_from_txn(transactions[0]) if transactions else PositionType.STOCK
# Track open positions using FIFO
open_positions: List[Dict] = []
for txn in transactions:
action = txn.action.upper()
# Determine if this is an opening or closing transaction
if self._is_opening_transaction(action):
# Create new open position
open_pos = {
"transactions": [txn],
"quantity": abs(txn.quantity) if txn.quantity else Decimal("0"),
"entry_price": txn.price,
"open_date": txn.run_date,
"is_short": "SELL" in action or "SOLD" in action,
}
open_positions.append(open_pos)
elif self._is_closing_transaction(action):
# Close positions using FIFO
close_quantity = abs(txn.quantity) if txn.quantity else Decimal("0")
remaining_to_close = close_quantity
while remaining_to_close > 0 and open_positions:
open_pos = open_positions[0]
open_qty = open_pos["quantity"]
if open_qty <= remaining_to_close:
# Close entire position
open_pos["transactions"].append(txn)
position = self._create_position(
account_id,
symbol,
position_type,
open_pos,
close_date=txn.run_date,
exit_price=txn.price,
close_quantity=open_qty,
)
positions.append(position)
open_positions.pop(0)
remaining_to_close -= open_qty
else:
# Partially close position
# Split into closed portion
closed_portion = {
"transactions": open_pos["transactions"] + [txn],
"quantity": remaining_to_close,
"entry_price": open_pos["entry_price"],
"open_date": open_pos["open_date"],
"is_short": open_pos["is_short"],
}
position = self._create_position(
account_id,
symbol,
position_type,
closed_portion,
close_date=txn.run_date,
exit_price=txn.price,
close_quantity=remaining_to_close,
)
positions.append(position)
# Update open position with remaining quantity
open_pos["quantity"] -= remaining_to_close
remaining_to_close = Decimal("0")
elif self._is_expiration(action):
# Handle option expirations
expire_quantity = abs(txn.quantity) if txn.quantity else Decimal("0")
remaining_to_expire = expire_quantity
while remaining_to_expire > 0 and open_positions:
open_pos = open_positions[0]
open_qty = open_pos["quantity"]
if open_qty <= remaining_to_expire:
# Expire entire position
open_pos["transactions"].append(txn)
position = self._create_position(
account_id,
symbol,
position_type,
open_pos,
close_date=txn.run_date,
exit_price=Decimal("0"), # Expired worthless
close_quantity=open_qty,
)
positions.append(position)
open_positions.pop(0)
remaining_to_expire -= open_qty
else:
# Partially expire
closed_portion = {
"transactions": open_pos["transactions"] + [txn],
"quantity": remaining_to_expire,
"entry_price": open_pos["entry_price"],
"open_date": open_pos["open_date"],
"is_short": open_pos["is_short"],
}
position = self._create_position(
account_id,
symbol,
position_type,
closed_portion,
close_date=txn.run_date,
exit_price=Decimal("0"),
close_quantity=remaining_to_expire,
)
positions.append(position)
open_pos["quantity"] -= remaining_to_expire
remaining_to_expire = Decimal("0")
# Create positions for any remaining open positions
for open_pos in open_positions:
position = self._create_position(
account_id, symbol, position_type, open_pos
)
positions.append(position)
return positions
def _create_position(
self,
account_id: int,
symbol: str,
position_type: PositionType,
position_data: Dict,
close_date: Optional[datetime] = None,
exit_price: Optional[Decimal] = None,
close_quantity: Optional[Decimal] = None,
) -> Position:
"""
Create a Position database object.
Args:
account_id: Account ID
symbol: Trading symbol
position_type: Type of position
position_data: Dictionary with position information
close_date: Close date (if closed)
exit_price: Exit price (if closed)
close_quantity: Quantity closed (if closed)
Returns:
Created Position object
"""
is_closed = close_date is not None
quantity = close_quantity if close_quantity else position_data["quantity"]
# Calculate P&L if closed
realized_pnl = None
if is_closed and position_data["entry_price"] and exit_price is not None:
if position_data["is_short"]:
# Short position: profit when price decreases
realized_pnl = (
position_data["entry_price"] - exit_price
) * quantity * 100
else:
# Long position: profit when price increases
realized_pnl = (
exit_price - position_data["entry_price"]
) * quantity * 100
# Subtract fees and commissions
for txn in position_data["transactions"]:
if txn.commission:
realized_pnl -= txn.commission
if txn.fees:
realized_pnl -= txn.fees
# Extract option symbol from first transaction if this is an option
option_symbol = None
if position_type != PositionType.STOCK and position_data["transactions"]:
first_txn = position_data["transactions"][0]
# Try to extract option details from description
option_symbol = self._extract_option_symbol_from_description(
first_txn.description, first_txn.action, symbol
)
# Create position
position = Position(
account_id=account_id,
symbol=symbol,
option_symbol=option_symbol,
position_type=position_type,
status=PositionStatus.CLOSED if is_closed else PositionStatus.OPEN,
open_date=position_data["open_date"],
close_date=close_date,
total_quantity=quantity if not position_data["is_short"] else -quantity,
avg_entry_price=position_data["entry_price"],
avg_exit_price=exit_price,
realized_pnl=realized_pnl,
)
self.db.add(position)
self.db.flush() # Get position ID
# Link transactions to position
for txn in position_data["transactions"]:
link = PositionTransaction(
position_id=position.id, transaction_id=txn.id
)
self.db.add(link)
return position
def _extract_option_symbol_from_description(
self, description: str, action: str, base_symbol: str
) -> Optional[str]:
"""
Extract option symbol from transaction description.
Example: "CALL (TGT) TARGET CORP JAN 16 26 $95 (100 SHS)"
Returns: "-TGT260116C95"
Args:
description: Transaction description
action: Transaction action
base_symbol: Underlying symbol
Returns:
Option symbol in standard format, or None if can't parse
"""
if not description:
return None
# Determine if CALL or PUT
call_or_put = None
if "CALL" in description.upper():
call_or_put = "C"
elif "PUT" in description.upper():
call_or_put = "P"
else:
return None
# Extract date and strike: "JAN 16 26 $95"
# Pattern: MONTH DAY YY $STRIKE
date_strike_pattern = r'([A-Z]{3})\s+(\d{1,2})\s+(\d{2})\s+\$([\d.]+)'
match = re.search(date_strike_pattern, description)
if not match:
return None
month_abbr, day, year, strike = match.groups()
# Convert month abbreviation to number
month_map = {
'JAN': '01', 'FEB': '02', 'MAR': '03', 'APR': '04',
'MAY': '05', 'JUN': '06', 'JUL': '07', 'AUG': '08',
'SEP': '09', 'OCT': '10', 'NOV': '11', 'DEC': '12'
}
month = month_map.get(month_abbr.upper())
if not month:
return None
# Format: -SYMBOL + YYMMDD + C/P + STRIKE
# Remove decimal point from strike if it's a whole number
strike_num = float(strike)
strike_str = str(int(strike_num)) if strike_num.is_integer() else strike.replace('.', '')
option_symbol = f"-{base_symbol}{year}{month}{day.zfill(2)}{call_or_put}{strike_str}"
return option_symbol
def _determine_position_type_from_txn(self, txn: Transaction) -> PositionType:
"""
Determine position type from transaction action/description.
Args:
txn: Transaction to analyze
Returns:
PositionType (STOCK, CALL, or PUT)
"""
# Check action and description for option indicators
action_upper = txn.action.upper() if txn.action else ""
desc_upper = txn.description.upper() if txn.description else ""
# Look for CALL or PUT keywords
if "CALL" in action_upper or "CALL" in desc_upper:
return PositionType.CALL
elif "PUT" in action_upper or "PUT" in desc_upper:
return PositionType.PUT
# Fall back to checking symbol format (for backwards compatibility)
if txn.symbol and txn.symbol.startswith("-"):
option_info = parse_option_symbol(txn.symbol)
if option_info:
return (
PositionType.CALL
if option_info.option_type == "CALL"
else PositionType.PUT
)
return PositionType.STOCK
def _get_base_symbol(self, symbol: str) -> str:
"""Extract base symbol from option symbol."""
if symbol.startswith("-"):
option_info = parse_option_symbol(symbol)
if option_info:
return option_info.underlying_symbol
return symbol
def _is_opening_transaction(self, action: str) -> bool:
"""Check if action represents opening a position."""
opening_keywords = [
"OPENING TRANSACTION",
"YOU BOUGHT OPENING",
"YOU SOLD OPENING",
]
return any(keyword in action for keyword in opening_keywords)
def _is_closing_transaction(self, action: str) -> bool:
"""Check if action represents closing a position."""
closing_keywords = [
"CLOSING TRANSACTION",
"YOU BOUGHT CLOSING",
"YOU SOLD CLOSING",
"ASSIGNED",
]
return any(keyword in action for keyword in closing_keywords)
def _is_expiration(self, action: str) -> bool:
"""Check if action represents an expiration."""
return "EXPIRED" in action
def _get_grouping_key(self, txn: Transaction) -> str:
"""
Create a unique grouping key for transactions.
For options, returns: symbol + option details (e.g., "TGT-JAN16-100C")
For stocks, returns: just the symbol (e.g., "TGT")
Args:
txn: Transaction to create key for
Returns:
Grouping key string
"""
# Determine if this is an option transaction
action_upper = txn.action.upper() if txn.action else ""
desc_upper = txn.description.upper() if txn.description else ""
is_option = "CALL" in action_upper or "CALL" in desc_upper or "PUT" in action_upper or "PUT" in desc_upper
if not is_option or not txn.description:
# Stock transaction - group by symbol only
return txn.symbol
# Option transaction - extract strike and expiration to create unique key
# Pattern: "CALL (TGT) TARGET CORP JAN 16 26 $100 (100 SHS)"
date_strike_pattern = r'([A-Z]{3})\s+(\d{1,2})\s+(\d{2})\s+\$([\d.]+)'
match = re.search(date_strike_pattern, txn.description)
if not match:
# Can't parse option details, fall back to symbol only
return txn.symbol
month_abbr, day, year, strike = match.groups()
# Determine call or put
call_or_put = "C" if "CALL" in desc_upper else "P"
# Create key: SYMBOL-MONTHDAY-STRIKEC/P
# e.g., "TGT-JAN16-100C"
strike_num = float(strike)
strike_str = str(int(strike_num)) if strike_num.is_integer() else strike
grouping_key = f"{txn.symbol}-{month_abbr}{day}-{strike_str}{call_or_put}"
return grouping_key