Initial release v1.1.0

- Complete MVP for tracking Fidelity brokerage account performance
- Transaction import from CSV with deduplication
- Automatic FIFO position tracking with options support
- Real-time P&L calculations with market data caching
- Dashboard with timeframe filtering (30/90/180 days, 1 year, YTD, all time)
- Docker-based deployment with PostgreSQL backend
- React/TypeScript frontend with TailwindCSS
- FastAPI backend with SQLAlchemy ORM

Features:
- Multi-account support
- Import via CSV upload or filesystem
- Open and closed position tracking
- Balance history charting
- Performance analytics and metrics
- Top trades analysis
- Responsive UI design

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
Chris
2026-01-22 14:27:43 -05:00
commit eea4469095
90 changed files with 14513 additions and 0 deletions

42
backend/Dockerfile Normal file
View File

@@ -0,0 +1,42 @@
# Multi-stage build for Python FastAPI backend
FROM python:3.11-slim as builder
WORKDIR /app
# Install system dependencies
RUN apt-get update && apt-get install -y \
gcc \
postgresql-client \
&& rm -rf /var/lib/apt/lists/*
# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt
# Final stage
FROM python:3.11-slim
WORKDIR /app
# Install runtime dependencies
RUN apt-get update && apt-get install -y \
postgresql-client \
&& rm -rf /var/lib/apt/lists/*
# Copy Python dependencies from builder
COPY --from=builder /root/.local /root/.local
# Copy application code
COPY . .
# Make sure scripts in .local are usable
ENV PATH=/root/.local/bin:$PATH
# Create imports directory
RUN mkdir -p /app/imports
# Expose port
EXPOSE 8000
# Run migrations and start server
CMD alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000

52
backend/alembic.ini Normal file
View File

@@ -0,0 +1,52 @@
# Alembic configuration file
[alembic]
# Path to migration scripts
script_location = alembic
# Template used to generate migration files
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s
# Timezone for migration timestamps
timezone = UTC
# Prepend migration scripts with proper encoding
prepend_sys_path = .
# Version location specification
version_path_separator = os
# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic
[handlers]
keys = console
[formatters]
keys = generic
[logger_root]
level = WARN
handlers = console
qualname =
[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine
[logger_alembic]
level = INFO
handlers =
qualname = alembic
[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic
[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S

72
backend/alembic/env.py Normal file
View File

@@ -0,0 +1,72 @@
"""Alembic environment configuration for database migrations."""
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import sys
from pathlib import Path
# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))
from app.config import settings
from app.database import Base
from app.models import Account, Transaction, Position, PositionTransaction
# Alembic Config object
config = context.config
# Override sqlalchemy.url with our settings
config.set_main_option("sqlalchemy.url", settings.database_url)
# Interpret the config file for Python logging
if config.config_file_name is not None:
fileConfig(config.config_file_name)
# Target metadata for autogenerate support
target_metadata = Base.metadata
def run_migrations_offline() -> None:
"""
Run migrations in 'offline' mode.
This configures the context with just a URL and not an Engine,
though an Engine is acceptable here as well. By skipping the Engine
creation we don't even need a DBAPI to be available.
"""
url = config.get_main_option("sqlalchemy.url")
context.configure(
url=url,
target_metadata=target_metadata,
literal_binds=True,
dialect_opts={"paramstyle": "named"},
)
with context.begin_transaction():
context.run_migrations()
def run_migrations_online() -> None:
"""
Run migrations in 'online' mode.
In this scenario we need to create an Engine and associate a
connection with the context.
"""
connectable = engine_from_config(
config.get_section(config.config_ini_section),
prefix="sqlalchemy.",
poolclass=pool.NullPool,
)
with connectable.connect() as connection:
context.configure(connection=connection, target_metadata=target_metadata)
with context.begin_transaction():
context.run_migrations()
if context.is_offline_mode():
run_migrations_offline()
else:
run_migrations_online()

View File

@@ -0,0 +1,25 @@
"""${message}
Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}
"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}
# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}
def upgrade() -> None:
${upgrades if upgrades else "pass"}
def downgrade() -> None:
${downgrades if downgrades else "pass"}

View File

@@ -0,0 +1,83 @@
"""Initial schema
Revision ID: 001_initial_schema
Revises:
Create Date: 2026-01-20 10:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '001_initial_schema'
down_revision = None
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create accounts table
op.create_table(
'accounts',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('account_number', sa.String(length=50), nullable=False),
sa.Column('account_name', sa.String(length=200), nullable=False),
sa.Column('account_type', sa.Enum('CASH', 'MARGIN', name='accounttype'), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_accounts_id'), 'accounts', ['id'], unique=False)
op.create_index(op.f('ix_accounts_account_number'), 'accounts', ['account_number'], unique=True)
# Create transactions table
op.create_table(
'transactions',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('account_id', sa.Integer(), nullable=False),
sa.Column('run_date', sa.Date(), nullable=False),
sa.Column('action', sa.String(length=500), nullable=False),
sa.Column('symbol', sa.String(length=50), nullable=True),
sa.Column('description', sa.String(length=500), nullable=True),
sa.Column('transaction_type', sa.String(length=20), nullable=True),
sa.Column('exchange_quantity', sa.Numeric(precision=20, scale=8), nullable=True),
sa.Column('exchange_currency', sa.String(length=10), nullable=True),
sa.Column('currency', sa.String(length=10), nullable=True),
sa.Column('price', sa.Numeric(precision=20, scale=8), nullable=True),
sa.Column('quantity', sa.Numeric(precision=20, scale=8), nullable=True),
sa.Column('exchange_rate', sa.Numeric(precision=20, scale=8), nullable=True),
sa.Column('commission', sa.Numeric(precision=20, scale=2), nullable=True),
sa.Column('fees', sa.Numeric(precision=20, scale=2), nullable=True),
sa.Column('accrued_interest', sa.Numeric(precision=20, scale=2), nullable=True),
sa.Column('amount', sa.Numeric(precision=20, scale=2), nullable=True),
sa.Column('cash_balance', sa.Numeric(precision=20, scale=2), nullable=True),
sa.Column('settlement_date', sa.Date(), nullable=True),
sa.Column('unique_hash', sa.String(length=64), nullable=False),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_transactions_id'), 'transactions', ['id'], unique=False)
op.create_index(op.f('ix_transactions_account_id'), 'transactions', ['account_id'], unique=False)
op.create_index(op.f('ix_transactions_run_date'), 'transactions', ['run_date'], unique=False)
op.create_index(op.f('ix_transactions_symbol'), 'transactions', ['symbol'], unique=False)
op.create_index(op.f('ix_transactions_unique_hash'), 'transactions', ['unique_hash'], unique=True)
op.create_index('idx_account_date', 'transactions', ['account_id', 'run_date'], unique=False)
op.create_index('idx_account_symbol', 'transactions', ['account_id', 'symbol'], unique=False)
def downgrade() -> None:
op.drop_index('idx_account_symbol', table_name='transactions')
op.drop_index('idx_account_date', table_name='transactions')
op.drop_index(op.f('ix_transactions_unique_hash'), table_name='transactions')
op.drop_index(op.f('ix_transactions_symbol'), table_name='transactions')
op.drop_index(op.f('ix_transactions_run_date'), table_name='transactions')
op.drop_index(op.f('ix_transactions_account_id'), table_name='transactions')
op.drop_index(op.f('ix_transactions_id'), table_name='transactions')
op.drop_table('transactions')
op.drop_index(op.f('ix_accounts_account_number'), table_name='accounts')
op.drop_index(op.f('ix_accounts_id'), table_name='accounts')
op.drop_table('accounts')
op.execute('DROP TYPE accounttype')

View File

@@ -0,0 +1,70 @@
"""Add positions tables
Revision ID: 002_add_positions
Revises: 001_initial_schema
Create Date: 2026-01-20 15:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '002_add_positions'
down_revision = '001_initial_schema'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create positions table
op.create_table(
'positions',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('account_id', sa.Integer(), nullable=False),
sa.Column('symbol', sa.String(length=50), nullable=False),
sa.Column('option_symbol', sa.String(length=100), nullable=True),
sa.Column('position_type', sa.Enum('STOCK', 'CALL', 'PUT', name='positiontype'), nullable=False),
sa.Column('status', sa.Enum('OPEN', 'CLOSED', name='positionstatus'), nullable=False),
sa.Column('open_date', sa.Date(), nullable=False),
sa.Column('close_date', sa.Date(), nullable=True),
sa.Column('total_quantity', sa.Numeric(precision=20, scale=8), nullable=False),
sa.Column('avg_entry_price', sa.Numeric(precision=20, scale=8), nullable=True),
sa.Column('avg_exit_price', sa.Numeric(precision=20, scale=8), nullable=True),
sa.Column('realized_pnl', sa.Numeric(precision=20, scale=2), nullable=True),
sa.Column('unrealized_pnl', sa.Numeric(precision=20, scale=2), nullable=True),
sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_positions_id'), 'positions', ['id'], unique=False)
op.create_index(op.f('ix_positions_account_id'), 'positions', ['account_id'], unique=False)
op.create_index(op.f('ix_positions_symbol'), 'positions', ['symbol'], unique=False)
op.create_index(op.f('ix_positions_option_symbol'), 'positions', ['option_symbol'], unique=False)
op.create_index(op.f('ix_positions_status'), 'positions', ['status'], unique=False)
op.create_index('idx_account_status', 'positions', ['account_id', 'status'], unique=False)
op.create_index('idx_account_symbol_status', 'positions', ['account_id', 'symbol', 'status'], unique=False)
# Create position_transactions junction table
op.create_table(
'position_transactions',
sa.Column('position_id', sa.Integer(), nullable=False),
sa.Column('transaction_id', sa.Integer(), nullable=False),
sa.ForeignKeyConstraint(['position_id'], ['positions.id'], ondelete='CASCADE'),
sa.ForeignKeyConstraint(['transaction_id'], ['transactions.id'], ondelete='CASCADE'),
sa.PrimaryKeyConstraint('position_id', 'transaction_id')
)
def downgrade() -> None:
op.drop_table('position_transactions')
op.drop_index('idx_account_symbol_status', table_name='positions')
op.drop_index('idx_account_status', table_name='positions')
op.drop_index(op.f('ix_positions_status'), table_name='positions')
op.drop_index(op.f('ix_positions_option_symbol'), table_name='positions')
op.drop_index(op.f('ix_positions_symbol'), table_name='positions')
op.drop_index(op.f('ix_positions_account_id'), table_name='positions')
op.drop_index(op.f('ix_positions_id'), table_name='positions')
op.drop_table('positions')
op.execute('DROP TYPE positionstatus')
op.execute('DROP TYPE positiontype')

View File

@@ -0,0 +1,40 @@
"""Add market_prices table for price caching
Revision ID: 003_market_prices
Revises: 002_add_positions
Create Date: 2026-01-20 16:00:00.000000
"""
from alembic import op
import sqlalchemy as sa
from datetime import datetime
# revision identifiers, used by Alembic.
revision = '003_market_prices'
down_revision = '002_add_positions'
branch_labels = None
depends_on = None
def upgrade() -> None:
# Create market_prices table
op.create_table(
'market_prices',
sa.Column('id', sa.Integer(), nullable=False),
sa.Column('symbol', sa.String(length=20), nullable=False),
sa.Column('price', sa.Numeric(precision=20, scale=6), nullable=False),
sa.Column('fetched_at', sa.DateTime(), nullable=False, default=datetime.utcnow),
sa.Column('source', sa.String(length=50), default='yahoo_finance'),
sa.PrimaryKeyConstraint('id')
)
# Create indexes
op.create_index('idx_market_prices_symbol', 'market_prices', ['symbol'], unique=True)
op.create_index('idx_symbol_fetched', 'market_prices', ['symbol', 'fetched_at'])
def downgrade() -> None:
op.drop_index('idx_symbol_fetched', table_name='market_prices')
op.drop_index('idx_market_prices_symbol', table_name='market_prices')
op.drop_table('market_prices')

2
backend/app/__init__.py Normal file
View File

@@ -0,0 +1,2 @@
"""myFidelityTracker backend application."""
__version__ = "1.0.0"

View File

@@ -0,0 +1 @@
"""API routes and endpoints."""

19
backend/app/api/deps.py Normal file
View File

@@ -0,0 +1,19 @@
"""API dependencies."""
from typing import Generator
from sqlalchemy.orm import Session
from app.database import SessionLocal
def get_db() -> Generator[Session, None, None]:
"""
Dependency that provides a database session.
Yields:
Database session
"""
db = SessionLocal()
try:
yield db
finally:
db.close()

View File

@@ -0,0 +1 @@
"""API endpoint modules."""

View File

@@ -0,0 +1,151 @@
"""Account management API endpoints."""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List
from app.api.deps import get_db
from app.models import Account
from app.schemas import AccountCreate, AccountUpdate, AccountResponse
router = APIRouter()
@router.post("", response_model=AccountResponse, status_code=status.HTTP_201_CREATED)
def create_account(account: AccountCreate, db: Session = Depends(get_db)):
"""
Create a new brokerage account.
Args:
account: Account creation data
db: Database session
Returns:
Created account
Raises:
HTTPException: If account number already exists
"""
# Check if account number already exists
existing = (
db.query(Account)
.filter(Account.account_number == account.account_number)
.first()
)
if existing:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail=f"Account with number {account.account_number} already exists",
)
# Create new account
db_account = Account(**account.model_dump())
db.add(db_account)
db.commit()
db.refresh(db_account)
return db_account
@router.get("", response_model=List[AccountResponse])
def list_accounts(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
"""
List all accounts.
Args:
skip: Number of records to skip
limit: Maximum number of records to return
db: Database session
Returns:
List of accounts
"""
accounts = db.query(Account).offset(skip).limit(limit).all()
return accounts
@router.get("/{account_id}", response_model=AccountResponse)
def get_account(account_id: int, db: Session = Depends(get_db)):
"""
Get account by ID.
Args:
account_id: Account ID
db: Database session
Returns:
Account details
Raises:
HTTPException: If account not found
"""
account = db.query(Account).filter(Account.id == account_id).first()
if not account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Account {account_id} not found",
)
return account
@router.put("/{account_id}", response_model=AccountResponse)
def update_account(
account_id: int, account_update: AccountUpdate, db: Session = Depends(get_db)
):
"""
Update account details.
Args:
account_id: Account ID
account_update: Updated account data
db: Database session
Returns:
Updated account
Raises:
HTTPException: If account not found
"""
db_account = db.query(Account).filter(Account.id == account_id).first()
if not db_account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Account {account_id} not found",
)
# Update fields
update_data = account_update.model_dump(exclude_unset=True)
for field, value in update_data.items():
setattr(db_account, field, value)
db.commit()
db.refresh(db_account)
return db_account
@router.delete("/{account_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_account(account_id: int, db: Session = Depends(get_db)):
"""
Delete an account and all associated data.
Args:
account_id: Account ID
db: Database session
Raises:
HTTPException: If account not found
"""
db_account = db.query(Account).filter(Account.id == account_id).first()
if not db_account:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Account {account_id} not found",
)
db.delete(db_account)
db.commit()

View File

@@ -0,0 +1,111 @@
"""Analytics API endpoints."""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from typing import Optional
from app.api.deps import get_db
from app.services.performance_calculator import PerformanceCalculator
router = APIRouter()
@router.get("/overview/{account_id}")
def get_overview(account_id: int, db: Session = Depends(get_db)):
"""
Get overview statistics for an account.
Args:
account_id: Account ID
db: Database session
Returns:
Dictionary with performance metrics
"""
calculator = PerformanceCalculator(db)
stats = calculator.calculate_account_stats(account_id)
return stats
@router.get("/balance-history/{account_id}")
def get_balance_history(
account_id: int,
days: int = Query(default=30, ge=1, le=3650),
db: Session = Depends(get_db),
):
"""
Get account balance history for charting.
Args:
account_id: Account ID
days: Number of days to retrieve (default: 30)
db: Database session
Returns:
List of {date, balance} dictionaries
"""
calculator = PerformanceCalculator(db)
history = calculator.get_balance_history(account_id, days)
return {"data": history}
@router.get("/top-trades/{account_id}")
def get_top_trades(
account_id: int,
limit: int = Query(default=20, ge=1, le=100),
db: Session = Depends(get_db),
):
"""
Get top performing trades.
Args:
account_id: Account ID
limit: Maximum number of trades to return (default: 20)
db: Database session
Returns:
List of trade dictionaries
"""
calculator = PerformanceCalculator(db)
trades = calculator.get_top_trades(account_id, limit)
return {"data": trades}
@router.get("/worst-trades/{account_id}")
def get_worst_trades(
account_id: int,
limit: int = Query(default=20, ge=1, le=100),
db: Session = Depends(get_db),
):
"""
Get worst performing trades (biggest losses).
Args:
account_id: Account ID
limit: Maximum number of trades to return (default: 20)
db: Database session
Returns:
List of trade dictionaries
"""
calculator = PerformanceCalculator(db)
trades = calculator.get_worst_trades(account_id, limit)
return {"data": trades}
@router.post("/update-pnl/{account_id}")
def update_unrealized_pnl(account_id: int, db: Session = Depends(get_db)):
"""
Update unrealized P&L for all open positions in an account.
Fetches current market prices and recalculates P&L.
Args:
account_id: Account ID
db: Database session
Returns:
Number of positions updated
"""
calculator = PerformanceCalculator(db)
updated = calculator.update_open_positions_pnl(account_id)
return {"positions_updated": updated}

View File

@@ -0,0 +1,273 @@
"""
Enhanced analytics API endpoints with efficient market data handling.
This version uses PerformanceCalculatorV2 with:
- Database-backed price caching
- Rate-limited API calls
- Stale-while-revalidate pattern for better UX
"""
from fastapi import APIRouter, Depends, Query, BackgroundTasks
from sqlalchemy.orm import Session
from typing import Optional
from datetime import date
from app.api.deps import get_db
from app.services.performance_calculator_v2 import PerformanceCalculatorV2
from app.services.market_data_service import MarketDataService
router = APIRouter()
@router.get("/overview/{account_id}")
def get_overview(
account_id: int,
refresh_prices: bool = Query(default=False, description="Force fresh price fetch"),
max_api_calls: int = Query(default=5, ge=0, le=50, description="Max Yahoo Finance API calls"),
start_date: Optional[date] = None,
end_date: Optional[date] = None,
db: Session = Depends(get_db)
):
"""
Get overview statistics for an account.
By default, uses cached prices (stale-while-revalidate pattern).
Set refresh_prices=true to force fresh data (may be slow).
Args:
account_id: Account ID
refresh_prices: Whether to fetch fresh prices from Yahoo Finance
max_api_calls: Maximum number of API calls to make
start_date: Filter positions opened on or after this date
end_date: Filter positions opened on or before this date
db: Database session
Returns:
Dictionary with performance metrics and cache stats
"""
calculator = PerformanceCalculatorV2(db, cache_ttl=300)
# If not refreshing, use cached only (fast)
if not refresh_prices:
max_api_calls = 0
stats = calculator.calculate_account_stats(
account_id,
update_prices=True,
max_api_calls=max_api_calls,
start_date=start_date,
end_date=end_date
)
return stats
@router.get("/balance-history/{account_id}")
def get_balance_history(
account_id: int,
days: int = Query(default=30, ge=1, le=3650),
db: Session = Depends(get_db),
):
"""
Get account balance history for charting.
This endpoint doesn't need market data, so it's always fast.
Args:
account_id: Account ID
days: Number of days to retrieve (default: 30)
db: Database session
Returns:
List of {date, balance} dictionaries
"""
calculator = PerformanceCalculatorV2(db)
history = calculator.get_balance_history(account_id, days)
return {"data": history}
@router.get("/top-trades/{account_id}")
def get_top_trades(
account_id: int,
limit: int = Query(default=10, ge=1, le=100),
start_date: Optional[date] = None,
end_date: Optional[date] = None,
db: Session = Depends(get_db),
):
"""
Get top performing trades.
This endpoint only uses closed positions, so no market data needed.
Args:
account_id: Account ID
limit: Maximum number of trades to return (default: 10)
start_date: Filter positions closed on or after this date
end_date: Filter positions closed on or before this date
db: Database session
Returns:
List of trade dictionaries
"""
calculator = PerformanceCalculatorV2(db)
trades = calculator.get_top_trades(account_id, limit, start_date, end_date)
return {"data": trades}
@router.get("/worst-trades/{account_id}")
def get_worst_trades(
account_id: int,
limit: int = Query(default=10, ge=1, le=100),
start_date: Optional[date] = None,
end_date: Optional[date] = None,
db: Session = Depends(get_db),
):
"""
Get worst performing trades.
This endpoint only uses closed positions, so no market data needed.
Args:
account_id: Account ID
limit: Maximum number of trades to return (default: 10)
start_date: Filter positions closed on or after this date
end_date: Filter positions closed on or before this date
db: Database session
Returns:
List of trade dictionaries
"""
calculator = PerformanceCalculatorV2(db)
trades = calculator.get_worst_trades(account_id, limit, start_date, end_date)
return {"data": trades}
@router.post("/refresh-prices/{account_id}")
def refresh_prices(
account_id: int,
max_api_calls: int = Query(default=10, ge=1, le=50),
db: Session = Depends(get_db),
):
"""
Manually trigger a price refresh for open positions.
This is useful when you want fresh data but don't want to wait
on the dashboard load.
Args:
account_id: Account ID
max_api_calls: Maximum number of Yahoo Finance API calls
db: Database session
Returns:
Update statistics
"""
calculator = PerformanceCalculatorV2(db, cache_ttl=300)
stats = calculator.update_open_positions_pnl(
account_id,
max_api_calls=max_api_calls,
allow_stale=False # Force fresh fetches
)
return {
"message": "Price refresh completed",
"stats": stats
}
@router.post("/refresh-prices-background/{account_id}")
def refresh_prices_background(
account_id: int,
background_tasks: BackgroundTasks,
max_api_calls: int = Query(default=20, ge=1, le=50),
db: Session = Depends(get_db),
):
"""
Trigger a background price refresh.
This returns immediately while prices are fetched in the background.
Client can poll /overview endpoint to see updated data.
Args:
account_id: Account ID
background_tasks: FastAPI background tasks
max_api_calls: Maximum number of Yahoo Finance API calls
db: Database session
Returns:
Acknowledgment that background task was started
"""
def refresh_task():
calculator = PerformanceCalculatorV2(db, cache_ttl=300)
calculator.update_open_positions_pnl(
account_id,
max_api_calls=max_api_calls,
allow_stale=False
)
background_tasks.add_task(refresh_task)
return {
"message": "Price refresh started in background",
"account_id": account_id,
"max_api_calls": max_api_calls
}
@router.post("/refresh-stale-cache")
def refresh_stale_cache(
min_age_minutes: int = Query(default=10, ge=1, le=1440),
limit: int = Query(default=20, ge=1, le=100),
db: Session = Depends(get_db),
):
"""
Background maintenance endpoint to refresh stale cached prices.
This can be called periodically (e.g., via cron) to keep cache fresh.
Args:
min_age_minutes: Only refresh prices older than this many minutes
limit: Maximum number of prices to refresh
db: Database session
Returns:
Number of prices refreshed
"""
market_data = MarketDataService(db, cache_ttl_seconds=300)
refreshed = market_data.refresh_stale_prices(
min_age_seconds=min_age_minutes * 60,
limit=limit
)
return {
"message": "Stale price refresh completed",
"refreshed": refreshed,
"min_age_minutes": min_age_minutes
}
@router.delete("/clear-old-cache")
def clear_old_cache(
older_than_days: int = Query(default=30, ge=1, le=365),
db: Session = Depends(get_db),
):
"""
Clear old cached prices from database.
Args:
older_than_days: Delete prices older than this many days
db: Database session
Returns:
Number of records deleted
"""
market_data = MarketDataService(db)
deleted = market_data.clear_cache(older_than_days=older_than_days)
return {
"message": "Old cache cleared",
"deleted": deleted,
"older_than_days": older_than_days
}

View File

@@ -0,0 +1,128 @@
"""Import API endpoints for CSV file uploads."""
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status
from sqlalchemy.orm import Session
from pathlib import Path
import tempfile
import shutil
from app.api.deps import get_db
from app.services import ImportService
from app.services.position_tracker import PositionTracker
from app.config import settings
router = APIRouter()
@router.post("/upload/{account_id}")
def upload_csv(
account_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)
):
"""
Upload and import a CSV file for an account.
Args:
account_id: Account ID to import transactions for
file: CSV file to upload
db: Database session
Returns:
Import statistics
Raises:
HTTPException: If import fails
"""
if not file.filename.endswith(".csv"):
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST, detail="File must be a CSV"
)
# Save uploaded file to temporary location
try:
with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp_file:
shutil.copyfileobj(file.file, tmp_file)
tmp_path = Path(tmp_file.name)
# Import transactions
import_service = ImportService(db)
result = import_service.import_from_file(tmp_path, account_id)
# Rebuild positions after import
if result.imported > 0:
position_tracker = PositionTracker(db)
positions_created = position_tracker.rebuild_positions(account_id)
else:
positions_created = 0
# Clean up temporary file
tmp_path.unlink()
return {
"filename": file.filename,
"imported": result.imported,
"skipped": result.skipped,
"errors": result.errors,
"total_rows": result.total_rows,
"positions_created": positions_created,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Import failed: {str(e)}",
)
@router.post("/filesystem/{account_id}")
def import_from_filesystem(account_id: int, db: Session = Depends(get_db)):
"""
Import all CSV files from the filesystem import directory.
Args:
account_id: Account ID to import transactions for
db: Database session
Returns:
Import statistics for all files
Raises:
HTTPException: If import directory doesn't exist
"""
import_dir = Path(settings.IMPORT_DIR)
if not import_dir.exists():
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Import directory not found: {import_dir}",
)
try:
import_service = ImportService(db)
results = import_service.import_from_directory(import_dir, account_id)
# Rebuild positions if any transactions were imported
total_imported = sum(r.imported for r in results.values())
if total_imported > 0:
position_tracker = PositionTracker(db)
positions_created = position_tracker.rebuild_positions(account_id)
else:
positions_created = 0
return {
"files": {
filename: {
"imported": result.imported,
"skipped": result.skipped,
"errors": result.errors,
"total_rows": result.total_rows,
}
for filename, result in results.items()
},
"total_imported": total_imported,
"positions_created": positions_created,
}
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"Import failed: {str(e)}",
)

View File

@@ -0,0 +1,104 @@
"""Position API endpoints."""
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import List, Optional
from app.api.deps import get_db
from app.models import Position
from app.models.position import PositionStatus
from app.schemas import PositionResponse
router = APIRouter()
@router.get("", response_model=List[PositionResponse])
def list_positions(
account_id: Optional[int] = None,
status_filter: Optional[PositionStatus] = Query(
default=None, alias="status", description="Filter by position status"
),
symbol: Optional[str] = None,
skip: int = 0,
limit: int = Query(default=100, le=500),
db: Session = Depends(get_db),
):
"""
List positions with optional filtering.
Args:
account_id: Filter by account ID
status_filter: Filter by status (open/closed)
symbol: Filter by symbol
skip: Number of records to skip (pagination)
limit: Maximum number of records to return
db: Database session
Returns:
List of positions
"""
query = db.query(Position)
# Apply filters
if account_id:
query = query.filter(Position.account_id == account_id)
if status_filter:
query = query.filter(Position.status == status_filter)
if symbol:
query = query.filter(Position.symbol == symbol)
# Order by most recent first
query = query.order_by(Position.open_date.desc(), Position.id.desc())
# Pagination
positions = query.offset(skip).limit(limit).all()
return positions
@router.get("/{position_id}", response_model=PositionResponse)
def get_position(position_id: int, db: Session = Depends(get_db)):
"""
Get position by ID.
Args:
position_id: Position ID
db: Database session
Returns:
Position details
Raises:
HTTPException: If position not found
"""
position = db.query(Position).filter(Position.id == position_id).first()
if not position:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Position {position_id} not found",
)
return position
@router.post("/{account_id}/rebuild")
def rebuild_positions(account_id: int, db: Session = Depends(get_db)):
"""
Rebuild all positions for an account from transactions.
Args:
account_id: Account ID
db: Database session
Returns:
Number of positions created
"""
from app.services.position_tracker import PositionTracker
position_tracker = PositionTracker(db)
positions_created = position_tracker.rebuild_positions(account_id)
return {"positions_created": positions_created}

View File

@@ -0,0 +1,227 @@
"""Transaction API endpoints."""
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from typing import List, Optional, Dict
from datetime import date
from app.api.deps import get_db
from app.models import Transaction, Position, PositionTransaction
from app.schemas import TransactionResponse
router = APIRouter()
@router.get("", response_model=List[TransactionResponse])
def list_transactions(
account_id: Optional[int] = None,
symbol: Optional[str] = None,
start_date: Optional[date] = None,
end_date: Optional[date] = None,
skip: int = 0,
limit: int = Query(default=50, le=500),
db: Session = Depends(get_db),
):
"""
List transactions with optional filtering.
Args:
account_id: Filter by account ID
symbol: Filter by symbol
start_date: Filter by start date (inclusive)
end_date: Filter by end date (inclusive)
skip: Number of records to skip (pagination)
limit: Maximum number of records to return
db: Database session
Returns:
List of transactions
"""
query = db.query(Transaction)
# Apply filters
if account_id:
query = query.filter(Transaction.account_id == account_id)
if symbol:
query = query.filter(Transaction.symbol == symbol)
if start_date:
query = query.filter(Transaction.run_date >= start_date)
if end_date:
query = query.filter(Transaction.run_date <= end_date)
# Order by date descending
query = query.order_by(Transaction.run_date.desc(), Transaction.id.desc())
# Pagination
transactions = query.offset(skip).limit(limit).all()
return transactions
@router.get("/{transaction_id}", response_model=TransactionResponse)
def get_transaction(transaction_id: int, db: Session = Depends(get_db)):
"""
Get transaction by ID.
Args:
transaction_id: Transaction ID
db: Database session
Returns:
Transaction details
Raises:
HTTPException: If transaction not found
"""
transaction = (
db.query(Transaction).filter(Transaction.id == transaction_id).first()
)
if not transaction:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Transaction {transaction_id} not found",
)
return transaction
@router.get("/{transaction_id}/position-details")
def get_transaction_position_details(
transaction_id: int, db: Session = Depends(get_db)
) -> Dict:
"""
Get full position details for a transaction, including all related transactions.
This endpoint finds the position associated with a transaction and returns:
- All transactions that are part of the same position
- Position metadata (type, status, P&L, etc.)
- Strategy classification for options (covered call, cash-secured put, etc.)
Args:
transaction_id: Transaction ID
db: Database session
Returns:
Dictionary with position details and all related transactions
Raises:
HTTPException: If transaction not found or not part of a position
"""
# Find the transaction
transaction = (
db.query(Transaction).filter(Transaction.id == transaction_id).first()
)
if not transaction:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Transaction {transaction_id} not found",
)
# Find the position this transaction belongs to
position_link = (
db.query(PositionTransaction)
.filter(PositionTransaction.transaction_id == transaction_id)
.first()
)
if not position_link:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Transaction {transaction_id} is not part of any position",
)
# Get the position with all its transactions
position = (
db.query(Position)
.filter(Position.id == position_link.position_id)
.first()
)
if not position:
raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND,
detail=f"Position not found",
)
# Get all transactions for this position
all_transactions = []
for link in position.transaction_links:
txn = link.transaction
all_transactions.append({
"id": txn.id,
"run_date": txn.run_date.isoformat(),
"action": txn.action,
"symbol": txn.symbol,
"description": txn.description,
"quantity": float(txn.quantity) if txn.quantity else None,
"price": float(txn.price) if txn.price else None,
"amount": float(txn.amount) if txn.amount else None,
"commission": float(txn.commission) if txn.commission else None,
"fees": float(txn.fees) if txn.fees else None,
})
# Sort transactions by date
all_transactions.sort(key=lambda t: t["run_date"])
# Determine strategy type for options
strategy = _classify_option_strategy(position, all_transactions)
return {
"position": {
"id": position.id,
"symbol": position.symbol,
"option_symbol": position.option_symbol,
"position_type": position.position_type.value,
"status": position.status.value,
"open_date": position.open_date.isoformat(),
"close_date": position.close_date.isoformat() if position.close_date else None,
"total_quantity": float(position.total_quantity),
"avg_entry_price": float(position.avg_entry_price) if position.avg_entry_price is not None else None,
"avg_exit_price": float(position.avg_exit_price) if position.avg_exit_price is not None else None,
"realized_pnl": float(position.realized_pnl) if position.realized_pnl is not None else None,
"unrealized_pnl": float(position.unrealized_pnl) if position.unrealized_pnl is not None else None,
"strategy": strategy,
},
"transactions": all_transactions,
}
def _classify_option_strategy(position: Position, transactions: List[Dict]) -> str:
"""
Classify the option strategy based on position type and transactions.
Args:
position: Position object
transactions: List of transaction dictionaries
Returns:
Strategy name (e.g., "Long Call", "Covered Call", "Cash-Secured Put")
"""
if position.position_type.value == "stock":
return "Stock"
# Check if this is a short or long position
is_short = position.total_quantity < 0
# For options
if position.position_type.value == "call":
if is_short:
# Short call - could be covered or naked
# We'd need to check if there's a corresponding stock position to determine
# For now, just return "Short Call" (could enhance later)
return "Short Call (Covered Call)"
else:
return "Long Call"
elif position.position_type.value == "put":
if is_short:
# Short put - could be cash-secured or naked
return "Short Put (Cash-Secured Put)"
else:
return "Long Put"
return "Unknown"

53
backend/app/config.py Normal file
View File

@@ -0,0 +1,53 @@
"""
Application configuration settings.
Loads configuration from environment variables with sensible defaults.
"""
from pydantic_settings import BaseSettings
from typing import Optional
class Settings(BaseSettings):
"""Application settings loaded from environment variables."""
# Database configuration
POSTGRES_HOST: str = "postgres"
POSTGRES_PORT: int = 5432
POSTGRES_DB: str = "fidelitytracker"
POSTGRES_USER: str = "fidelity"
POSTGRES_PASSWORD: str = "fidelity123"
# API configuration
API_V1_PREFIX: str = "/api"
PROJECT_NAME: str = "myFidelityTracker"
# CORS configuration - allow all origins for local development
CORS_ORIGINS: str = "*"
@property
def cors_origins_list(self) -> list[str]:
"""Parse CORS origins from comma-separated string."""
if self.CORS_ORIGINS == "*":
return ["*"]
return [origin.strip() for origin in self.CORS_ORIGINS.split(",")]
# File import configuration
IMPORT_DIR: str = "/app/imports"
# Market data cache TTL (seconds)
MARKET_DATA_CACHE_TTL: int = 60
@property
def database_url(self) -> str:
"""Construct PostgreSQL database URL."""
return (
f"postgresql://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}"
f"@{self.POSTGRES_HOST}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}"
)
class Config:
env_file = ".env"
case_sensitive = True
# Global settings instance
settings = Settings()

38
backend/app/database.py Normal file
View File

@@ -0,0 +1,38 @@
"""
Database configuration and session management.
Provides SQLAlchemy engine and session factory.
"""
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker
from app.config import settings
# Create SQLAlchemy engine
engine = create_engine(
settings.database_url,
pool_pre_ping=True, # Enable connection health checks
pool_size=10,
max_overflow=20
)
# Create session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)
# Base class for SQLAlchemy models
Base = declarative_base()
def get_db():
"""
Dependency function that provides a database session.
Automatically closes the session after the request is completed.
Yields:
Session: SQLAlchemy database session
"""
db = SessionLocal()
try:
yield db
finally:
db.close()

66
backend/app/main.py Normal file
View File

@@ -0,0 +1,66 @@
"""
FastAPI application entry point for myFidelityTracker.
This module initializes the FastAPI application, configures CORS,
and registers all API routers.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from app.config import settings
from app.api.endpoints import accounts, transactions, positions, import_endpoint
from app.api.endpoints import analytics_v2 as analytics
# Create FastAPI application
app = FastAPI(
title=settings.PROJECT_NAME,
description="Track and analyze your Fidelity brokerage account performance",
version="1.0.0",
)
# Configure CORS middleware - allow all origins for local network access
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # Allow all origins for local development
allow_credentials=False, # Must be False when using allow_origins=["*"]
allow_methods=["*"],
allow_headers=["*"],
)
# Register API routers
app.include_router(
accounts.router, prefix=f"{settings.API_V1_PREFIX}/accounts", tags=["accounts"]
)
app.include_router(
transactions.router,
prefix=f"{settings.API_V1_PREFIX}/transactions",
tags=["transactions"],
)
app.include_router(
positions.router, prefix=f"{settings.API_V1_PREFIX}/positions", tags=["positions"]
)
app.include_router(
analytics.router, prefix=f"{settings.API_V1_PREFIX}/analytics", tags=["analytics"]
)
app.include_router(
import_endpoint.router,
prefix=f"{settings.API_V1_PREFIX}/import",
tags=["import"],
)
@app.get("/")
def root():
"""Root endpoint returning API information."""
return {
"name": settings.PROJECT_NAME,
"version": "1.0.0",
"message": "Welcome to myFidelityTracker API",
}
@app.get("/health")
def health_check():
"""Health check endpoint."""
return {"status": "healthy"}

View File

@@ -0,0 +1,7 @@
"""SQLAlchemy models for the application."""
from app.models.account import Account
from app.models.transaction import Transaction
from app.models.position import Position, PositionTransaction
from app.models.market_price import MarketPrice
__all__ = ["Account", "Transaction", "Position", "PositionTransaction", "MarketPrice"]

View File

@@ -0,0 +1,41 @@
"""Account model representing a brokerage account."""
from sqlalchemy import Column, Integer, String, DateTime, Enum
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
import enum
from app.database import Base
class AccountType(str, enum.Enum):
"""Enumeration of account types."""
CASH = "cash"
MARGIN = "margin"
class Account(Base):
"""
Represents a brokerage account.
Attributes:
id: Primary key
account_number: Unique account identifier
account_name: Human-readable account name
account_type: Type of account (cash or margin)
created_at: Timestamp of account creation
updated_at: Timestamp of last update
transactions: Related transactions
positions: Related positions
"""
__tablename__ = "accounts"
id = Column(Integer, primary_key=True, index=True)
account_number = Column(String(50), unique=True, nullable=False, index=True)
account_name = Column(String(200), nullable=False)
account_type = Column(Enum(AccountType), nullable=False, default=AccountType.CASH)
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), server_default=func.now(), nullable=False)
# Relationships
transactions = relationship("Transaction", back_populates="account", cascade="all, delete-orphan")
positions = relationship("Position", back_populates="account", cascade="all, delete-orphan")

View File

@@ -0,0 +1,29 @@
"""Market price cache model for storing Yahoo Finance data."""
from sqlalchemy import Column, Integer, String, Numeric, DateTime, Index
from datetime import datetime
from app.database import Base
class MarketPrice(Base):
"""
Cache table for market prices from Yahoo Finance.
Stores the last fetched price for each symbol to reduce API calls.
"""
__tablename__ = "market_prices"
id = Column(Integer, primary_key=True, index=True)
symbol = Column(String(20), unique=True, nullable=False, index=True)
price = Column(Numeric(precision=20, scale=6), nullable=False)
fetched_at = Column(DateTime, nullable=False, default=datetime.utcnow)
source = Column(String(50), default="yahoo_finance")
# Index for quick lookups by symbol and freshness checks
__table_args__ = (
Index('idx_symbol_fetched', 'symbol', 'fetched_at'),
)
def __repr__(self):
return f"<MarketPrice(symbol={self.symbol}, price={self.price}, fetched_at={self.fetched_at})>"

View File

@@ -0,0 +1,104 @@
"""Position model representing a trading position."""
from sqlalchemy import Column, Integer, String, DateTime, Numeric, ForeignKey, Date, Enum, Index
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
import enum
from app.database import Base
class PositionType(str, enum.Enum):
"""Enumeration of position types."""
STOCK = "stock"
CALL = "call"
PUT = "put"
class PositionStatus(str, enum.Enum):
"""Enumeration of position statuses."""
OPEN = "open"
CLOSED = "closed"
class Position(Base):
"""
Represents a trading position (open or closed).
A position aggregates related transactions (entries and exits) for a specific security.
For options, tracks strikes, expirations, and option-specific details.
Attributes:
id: Primary key
account_id: Foreign key to account
symbol: Base trading symbol (e.g., AAPL)
option_symbol: Full option symbol if applicable (e.g., -AAPL260116C150)
position_type: Type (stock, call, put)
status: Status (open, closed)
open_date: Date position was opened
close_date: Date position was closed (if closed)
total_quantity: Net quantity (can be negative for short positions)
avg_entry_price: Average entry price
avg_exit_price: Average exit price (if closed)
realized_pnl: Realized profit/loss for closed positions
unrealized_pnl: Unrealized profit/loss for open positions
created_at: Timestamp of record creation
updated_at: Timestamp of last update
"""
__tablename__ = "positions"
id = Column(Integer, primary_key=True, index=True)
account_id = Column(Integer, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True)
# Symbol information
symbol = Column(String(50), nullable=False, index=True)
option_symbol = Column(String(100), index=True) # Full option symbol for options
position_type = Column(Enum(PositionType), nullable=False, default=PositionType.STOCK)
# Status and dates
status = Column(Enum(PositionStatus), nullable=False, default=PositionStatus.OPEN, index=True)
open_date = Column(Date, nullable=False)
close_date = Column(Date)
# Position metrics
total_quantity = Column(Numeric(20, 8), nullable=False) # Can be negative for short
avg_entry_price = Column(Numeric(20, 8))
avg_exit_price = Column(Numeric(20, 8))
# P&L tracking
realized_pnl = Column(Numeric(20, 2)) # For closed positions
unrealized_pnl = Column(Numeric(20, 2)) # For open positions
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), server_default=func.now(), nullable=False)
# Relationships
account = relationship("Account", back_populates="positions")
transaction_links = relationship("PositionTransaction", back_populates="position", cascade="all, delete-orphan")
# Composite indexes for common queries
__table_args__ = (
Index('idx_account_status', 'account_id', 'status'),
Index('idx_account_symbol_status', 'account_id', 'symbol', 'status'),
)
class PositionTransaction(Base):
"""
Junction table linking positions to transactions.
A position can have multiple transactions (entries, exits, adjustments).
A transaction can be part of multiple positions (e.g., closing multiple lots).
Attributes:
position_id: Foreign key to position
transaction_id: Foreign key to transaction
"""
__tablename__ = "position_transactions"
position_id = Column(Integer, ForeignKey("positions.id", ondelete="CASCADE"), primary_key=True)
transaction_id = Column(Integer, ForeignKey("transactions.id", ondelete="CASCADE"), primary_key=True)
# Relationships
position = relationship("Position", back_populates="transaction_links")
transaction = relationship("Transaction", back_populates="position_links")

View File

@@ -0,0 +1,81 @@
"""Transaction model representing a brokerage transaction."""
from sqlalchemy import Column, Integer, String, DateTime, Numeric, ForeignKey, Date, Index
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
from app.database import Base
class Transaction(Base):
"""
Represents a single brokerage transaction.
Attributes:
id: Primary key
account_id: Foreign key to account
run_date: Date the transaction was recorded
action: Description of the transaction action
symbol: Trading symbol
description: Full transaction description
transaction_type: Type (Cash/Margin)
exchange_quantity: Quantity in exchange currency
exchange_currency: Exchange currency code
currency: Transaction currency
price: Transaction price per unit
quantity: Number of shares/contracts
exchange_rate: Currency exchange rate
commission: Commission fees
fees: Additional fees
accrued_interest: Interest accrued
amount: Total transaction amount
cash_balance: Account balance after transaction
settlement_date: Date transaction settles
unique_hash: SHA-256 hash for deduplication
created_at: Timestamp of record creation
updated_at: Timestamp of last update
"""
__tablename__ = "transactions"
id = Column(Integer, primary_key=True, index=True)
account_id = Column(Integer, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True)
# Transaction details from CSV
run_date = Column(Date, nullable=False, index=True)
action = Column(String(500), nullable=False)
symbol = Column(String(50), index=True)
description = Column(String(500))
transaction_type = Column(String(20)) # Cash, Margin
# Quantities and currencies
exchange_quantity = Column(Numeric(20, 8))
exchange_currency = Column(String(10))
currency = Column(String(10))
# Financial details
price = Column(Numeric(20, 8))
quantity = Column(Numeric(20, 8))
exchange_rate = Column(Numeric(20, 8))
commission = Column(Numeric(20, 2))
fees = Column(Numeric(20, 2))
accrued_interest = Column(Numeric(20, 2))
amount = Column(Numeric(20, 2))
cash_balance = Column(Numeric(20, 2))
settlement_date = Column(Date)
# Deduplication hash
unique_hash = Column(String(64), unique=True, nullable=False, index=True)
# Timestamps
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), server_default=func.now(), nullable=False)
# Relationships
account = relationship("Account", back_populates="transactions")
position_links = relationship("PositionTransaction", back_populates="transaction", cascade="all, delete-orphan")
# Composite index for common queries
__table_args__ = (
Index('idx_account_date', 'account_id', 'run_date'),
Index('idx_account_symbol', 'account_id', 'symbol'),
)

View File

@@ -0,0 +1,5 @@
"""CSV parser modules for various brokerage formats."""
from app.parsers.base_parser import BaseParser, ParseResult
from app.parsers.fidelity_parser import FidelityParser
__all__ = ["BaseParser", "ParseResult", "FidelityParser"]

View File

@@ -0,0 +1,99 @@
"""Base parser interface for brokerage CSV files."""
from abc import ABC, abstractmethod
from typing import List, Dict, Any, NamedTuple
from pathlib import Path
import pandas as pd
class ParseResult(NamedTuple):
"""
Result of parsing a brokerage CSV file.
Attributes:
transactions: List of parsed transaction dictionaries
errors: List of error messages encountered during parsing
row_count: Total number of rows processed
"""
transactions: List[Dict[str, Any]]
errors: List[str]
row_count: int
class BaseParser(ABC):
"""
Abstract base class for brokerage CSV parsers.
Provides a standard interface for parsing CSV files from different brokerages.
Subclasses must implement the parse() method for their specific format.
"""
@abstractmethod
def parse(self, file_path: Path) -> ParseResult:
"""
Parse a brokerage CSV file into standardized transaction dictionaries.
Args:
file_path: Path to the CSV file to parse
Returns:
ParseResult containing transactions, errors, and row count
Raises:
FileNotFoundError: If the file does not exist
ValueError: If the file format is invalid
"""
pass
def _read_csv(self, file_path: Path, **kwargs) -> pd.DataFrame:
"""
Read CSV file into a pandas DataFrame with error handling.
Args:
file_path: Path to CSV file
**kwargs: Additional arguments passed to pd.read_csv()
Returns:
DataFrame containing CSV data
Raises:
FileNotFoundError: If file does not exist
pd.errors.EmptyDataError: If file is empty
"""
if not file_path.exists():
raise FileNotFoundError(f"CSV file not found: {file_path}")
return pd.read_csv(file_path, **kwargs)
@staticmethod
def _safe_decimal(value: Any) -> Any:
"""
Safely convert value to decimal-compatible format, handling NaN and None.
Args:
value: Value to convert
Returns:
Converted value or None if invalid
"""
if pd.isna(value):
return None
if value == "":
return None
return value
@staticmethod
def _safe_date(value: Any) -> Any:
"""
Safely convert value to date, handling NaN and None.
Args:
value: Value to convert
Returns:
Converted date or None if invalid
"""
if pd.isna(value):
return None
if value == "":
return None
return value

View File

@@ -0,0 +1,257 @@
"""Fidelity brokerage CSV parser."""
from pathlib import Path
from typing import List, Dict, Any
import pandas as pd
from datetime import datetime
import re
from app.parsers.base_parser import BaseParser, ParseResult
class FidelityParser(BaseParser):
"""
Parser for Fidelity brokerage account history CSV files.
Expected CSV columns:
- Run Date
- Action
- Symbol
- Description
- Type
- Exchange Quantity
- Exchange Currency
- Currency
- Price
- Quantity
- Exchange Rate
- Commission
- Fees
- Accrued Interest
- Amount
- Cash Balance
- Settlement Date
"""
# Expected column names in Fidelity CSV
EXPECTED_COLUMNS = [
"Run Date",
"Action",
"Symbol",
"Description",
"Type",
"Exchange Quantity",
"Exchange Currency",
"Currency",
"Price",
"Quantity",
"Exchange Rate",
"Commission",
"Fees",
"Accrued Interest",
"Amount",
"Cash Balance",
"Settlement Date",
]
def parse(self, file_path: Path) -> ParseResult:
"""
Parse a Fidelity CSV file into standardized transaction dictionaries.
Args:
file_path: Path to the Fidelity CSV file
Returns:
ParseResult containing parsed transactions, errors, and row count
Raises:
FileNotFoundError: If the file does not exist
ValueError: If the CSV format is invalid
"""
errors = []
transactions = []
try:
# Read CSV, skipping empty rows at the beginning
df = self._read_csv(file_path, skiprows=self._find_header_row(file_path))
# Validate columns
missing_cols = set(self.EXPECTED_COLUMNS) - set(df.columns)
if missing_cols:
raise ValueError(f"Missing required columns: {missing_cols}")
# Parse each row
for idx, row in df.iterrows():
try:
transaction = self._parse_row(row)
if transaction:
transactions.append(transaction)
except Exception as e:
errors.append(f"Row {idx + 1}: {str(e)}")
return ParseResult(
transactions=transactions, errors=errors, row_count=len(df)
)
except FileNotFoundError as e:
raise e
except Exception as e:
raise ValueError(f"Failed to parse Fidelity CSV: {str(e)}")
def _find_header_row(self, file_path: Path) -> int:
"""
Find the row number where the header starts in Fidelity CSV.
Fidelity CSVs may have empty rows or metadata at the beginning.
Args:
file_path: Path to CSV file
Returns:
Row number (0-indexed) where the header is located
"""
with open(file_path, "r", encoding="utf-8-sig") as f:
for i, line in enumerate(f):
if "Run Date" in line:
return i
return 0 # Default to first row if not found
def _extract_real_ticker(self, symbol: str, description: str, action: str) -> str:
"""
Extract the real underlying ticker from option descriptions.
Fidelity uses internal reference numbers (like 6736999MM) in the Symbol column
for options, but the real ticker is in the Description/Action in parentheses.
Examples:
- Description: "CALL (OPEN) OPENDOOR JAN 16 26 (100 SHS)"
- Action: "YOU SOLD CLOSING TRANSACTION CALL (OPEN) OPENDOOR..."
Args:
symbol: Symbol from CSV (might be Fidelity internal reference)
description: Description field
action: Action field
Returns:
Real ticker symbol, or original symbol if not found
"""
# If symbol looks normal (letters only, not Fidelity's numeric codes), return it
if symbol and re.match(r'^[A-Z]{1,5}$', symbol):
return symbol
# Try to extract from description first (more reliable)
# Pattern: (TICKER) or CALL (TICKER) or PUT (TICKER)
if description:
# Look for pattern like "CALL (OPEN)" or "PUT (AAPL)"
match = re.search(r'(?:CALL|PUT)\s*\(([A-Z]+)\)', description, re.IGNORECASE)
if match:
return match.group(1)
# Look for standalone (TICKER) pattern
match = re.search(r'\(([A-Z]{1,5})\)', description)
if match:
ticker = match.group(1)
# Make sure it's not something like (100 or (Margin)
if not ticker.isdigit() and ticker not in ['MARGIN', 'CASH', 'SHS']:
return ticker
# Fall back to action field
if action:
match = re.search(r'(?:CALL|PUT)\s*\(([A-Z]+)\)', action, re.IGNORECASE)
if match:
return match.group(1)
# Return original symbol if we couldn't extract anything better
return symbol if symbol else None
def _parse_row(self, row: pd.Series) -> Dict[str, Any]:
"""
Parse a single row from Fidelity CSV into a transaction dictionary.
Args:
row: Pandas Series representing one CSV row
Returns:
Dictionary with transaction data, or None if row should be skipped
Raises:
ValueError: If required fields are missing or invalid
"""
# Parse dates
run_date = self._parse_date(row["Run Date"])
settlement_date = self._parse_date(row["Settlement Date"])
# Extract raw values
raw_symbol = self._safe_string(row["Symbol"])
description = self._safe_string(row["Description"])
action = str(row["Action"]).strip() if pd.notna(row["Action"]) else ""
# Extract the real ticker (especially important for options)
actual_symbol = self._extract_real_ticker(raw_symbol, description, action)
# Extract and clean values
transaction = {
"run_date": run_date,
"action": action,
"symbol": actual_symbol,
"description": description,
"transaction_type": self._safe_string(row["Type"]),
"exchange_quantity": self._safe_decimal(row["Exchange Quantity"]),
"exchange_currency": self._safe_string(row["Exchange Currency"]),
"currency": self._safe_string(row["Currency"]),
"price": self._safe_decimal(row["Price"]),
"quantity": self._safe_decimal(row["Quantity"]),
"exchange_rate": self._safe_decimal(row["Exchange Rate"]),
"commission": self._safe_decimal(row["Commission"]),
"fees": self._safe_decimal(row["Fees"]),
"accrued_interest": self._safe_decimal(row["Accrued Interest"]),
"amount": self._safe_decimal(row["Amount"]),
"cash_balance": self._safe_decimal(row["Cash Balance"]),
"settlement_date": settlement_date,
}
return transaction
def _parse_date(self, date_value: Any) -> Any:
"""
Parse date value from CSV, handling various formats.
Args:
date_value: Date value from CSV (string or datetime)
Returns:
datetime.date object or None if empty/invalid
"""
if pd.isna(date_value) or date_value == "":
return None
# If already a datetime object
if isinstance(date_value, datetime):
return date_value.date()
# Try parsing common date formats
date_str = str(date_value).strip()
if not date_str:
return None
# Try common formats
for fmt in ["%m/%d/%Y", "%Y-%m-%d", "%m-%d-%Y"]:
try:
return datetime.strptime(date_str, fmt).date()
except ValueError:
continue
return None
def _safe_string(self, value: Any) -> str:
"""
Safely convert value to string, handling NaN and empty values.
Args:
value: Value to convert
Returns:
String value or None if empty
"""
if pd.isna(value) or value == "":
return None
return str(value).strip()

View File

@@ -0,0 +1,14 @@
"""Pydantic schemas for API request/response validation."""
from app.schemas.account import AccountCreate, AccountUpdate, AccountResponse
from app.schemas.transaction import TransactionCreate, TransactionResponse
from app.schemas.position import PositionResponse, PositionStats
__all__ = [
"AccountCreate",
"AccountUpdate",
"AccountResponse",
"TransactionCreate",
"TransactionResponse",
"PositionResponse",
"PositionStats",
]

View File

@@ -0,0 +1,34 @@
"""Pydantic schemas for account-related API operations."""
from pydantic import BaseModel, Field
from datetime import datetime
from typing import Optional
from app.models.account import AccountType
class AccountBase(BaseModel):
"""Base schema for account data."""
account_number: str = Field(..., description="Unique account identifier")
account_name: str = Field(..., description="Human-readable account name")
account_type: AccountType = Field(default=AccountType.CASH, description="Account type")
class AccountCreate(AccountBase):
"""Schema for creating a new account."""
pass
class AccountUpdate(BaseModel):
"""Schema for updating an existing account."""
account_name: Optional[str] = Field(None, description="Updated account name")
account_type: Optional[AccountType] = Field(None, description="Updated account type")
class AccountResponse(AccountBase):
"""Schema for account API responses."""
id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,45 @@
"""Pydantic schemas for position-related API operations."""
from pydantic import BaseModel, Field
from datetime import date, datetime
from typing import Optional
from decimal import Decimal
from app.models.position import PositionType, PositionStatus
class PositionBase(BaseModel):
"""Base schema for position data."""
symbol: str
option_symbol: Optional[str] = None
position_type: PositionType
status: PositionStatus
open_date: date
close_date: Optional[date] = None
total_quantity: Decimal
avg_entry_price: Optional[Decimal] = None
avg_exit_price: Optional[Decimal] = None
realized_pnl: Optional[Decimal] = None
unrealized_pnl: Optional[Decimal] = None
class PositionResponse(PositionBase):
"""Schema for position API responses."""
id: int
account_id: int
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True
class PositionStats(BaseModel):
"""Schema for aggregate position statistics."""
total_positions: int = Field(..., description="Total number of positions")
open_positions: int = Field(..., description="Number of open positions")
closed_positions: int = Field(..., description="Number of closed positions")
total_realized_pnl: Decimal = Field(..., description="Total realized P&L")
total_unrealized_pnl: Decimal = Field(..., description="Total unrealized P&L")
win_rate: float = Field(..., description="Percentage of profitable trades")
avg_win: Decimal = Field(..., description="Average profit on winning trades")
avg_loss: Decimal = Field(..., description="Average loss on losing trades")

View File

@@ -0,0 +1,44 @@
"""Pydantic schemas for transaction-related API operations."""
from pydantic import BaseModel, Field
from datetime import date, datetime
from typing import Optional
from decimal import Decimal
class TransactionBase(BaseModel):
"""Base schema for transaction data."""
run_date: date
action: str
symbol: Optional[str] = None
description: Optional[str] = None
transaction_type: Optional[str] = None
exchange_quantity: Optional[Decimal] = None
exchange_currency: Optional[str] = None
currency: Optional[str] = None
price: Optional[Decimal] = None
quantity: Optional[Decimal] = None
exchange_rate: Optional[Decimal] = None
commission: Optional[Decimal] = None
fees: Optional[Decimal] = None
accrued_interest: Optional[Decimal] = None
amount: Optional[Decimal] = None
cash_balance: Optional[Decimal] = None
settlement_date: Optional[date] = None
class TransactionCreate(TransactionBase):
"""Schema for creating a new transaction."""
account_id: int
unique_hash: str
class TransactionResponse(TransactionBase):
"""Schema for transaction API responses."""
id: int
account_id: int
unique_hash: str
created_at: datetime
updated_at: datetime
class Config:
from_attributes = True

View File

@@ -0,0 +1,6 @@
"""Business logic services."""
from app.services.import_service import ImportService, ImportResult
from app.services.position_tracker import PositionTracker
from app.services.performance_calculator import PerformanceCalculator
__all__ = ["ImportService", "ImportResult", "PositionTracker", "PerformanceCalculator"]

View File

@@ -0,0 +1,149 @@
"""Service for importing transactions from CSV files."""
from pathlib import Path
from typing import List, Dict, Any, NamedTuple
from sqlalchemy.orm import Session
from sqlalchemy.exc import IntegrityError
from app.parsers import FidelityParser
from app.models import Transaction
from app.utils import generate_transaction_hash
class ImportResult(NamedTuple):
"""
Result of an import operation.
Attributes:
imported: Number of successfully imported transactions
skipped: Number of skipped duplicate transactions
errors: List of error messages
total_rows: Total number of rows processed
"""
imported: int
skipped: int
errors: List[str]
total_rows: int
class ImportService:
"""
Service for importing transactions from brokerage CSV files.
Handles parsing, deduplication, and database insertion.
"""
def __init__(self, db: Session):
"""
Initialize import service.
Args:
db: Database session
"""
self.db = db
self.parser = FidelityParser() # Can be extended to support multiple parsers
def import_from_file(self, file_path: Path, account_id: int) -> ImportResult:
"""
Import transactions from a CSV file.
Args:
file_path: Path to CSV file
account_id: ID of the account to import transactions for
Returns:
ImportResult with statistics
Raises:
FileNotFoundError: If file doesn't exist
ValueError: If file format is invalid
"""
# Parse CSV file
parse_result = self.parser.parse(file_path)
imported = 0
skipped = 0
errors = list(parse_result.errors)
# Process each transaction
for txn_data in parse_result.transactions:
try:
# Generate deduplication hash
unique_hash = generate_transaction_hash(
account_id=account_id,
run_date=txn_data["run_date"],
symbol=txn_data.get("symbol"),
action=txn_data["action"],
amount=txn_data.get("amount"),
quantity=txn_data.get("quantity"),
price=txn_data.get("price"),
)
# Check if transaction already exists
existing = (
self.db.query(Transaction)
.filter(Transaction.unique_hash == unique_hash)
.first()
)
if existing:
skipped += 1
continue
# Create new transaction
transaction = Transaction(
account_id=account_id,
unique_hash=unique_hash,
**txn_data
)
self.db.add(transaction)
self.db.commit()
imported += 1
except IntegrityError:
# Duplicate hash (edge case if concurrent imports)
self.db.rollback()
skipped += 1
except Exception as e:
self.db.rollback()
errors.append(f"Failed to import transaction: {str(e)}")
return ImportResult(
imported=imported,
skipped=skipped,
errors=errors,
total_rows=parse_result.row_count,
)
def import_from_directory(
self, directory: Path, account_id: int, pattern: str = "*.csv"
) -> Dict[str, ImportResult]:
"""
Import transactions from all CSV files in a directory.
Args:
directory: Path to directory containing CSV files
account_id: ID of the account to import transactions for
pattern: Glob pattern for matching files (default: *.csv)
Returns:
Dictionary mapping filename to ImportResult
"""
if not directory.exists() or not directory.is_dir():
raise ValueError(f"Invalid directory: {directory}")
results = {}
for file_path in directory.glob(pattern):
try:
result = self.import_from_file(file_path, account_id)
results[file_path.name] = result
except Exception as e:
results[file_path.name] = ImportResult(
imported=0,
skipped=0,
errors=[str(e)],
total_rows=0,
)
return results

View File

@@ -0,0 +1,330 @@
"""
Market data service with rate limiting, caching, and batch processing.
This service handles fetching market prices from Yahoo Finance with:
- Database-backed caching to survive restarts
- Rate limiting with exponential backoff
- Batch processing to reduce API calls
- Stale-while-revalidate pattern for better UX
"""
import time
import yfinance as yf
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import Dict, List, Optional
from decimal import Decimal
from datetime import datetime, timedelta
import logging
from app.models.market_price import MarketPrice
logger = logging.getLogger(__name__)
class MarketDataService:
"""Service for fetching and caching market prices with rate limiting."""
def __init__(self, db: Session, cache_ttl_seconds: int = 300):
"""
Initialize market data service.
Args:
db: Database session
cache_ttl_seconds: How long cached prices are considered fresh (default: 5 minutes)
"""
self.db = db
self.cache_ttl = cache_ttl_seconds
self._rate_limit_delay = 0.5 # Start with 500ms between requests
self._last_request_time = 0.0
self._consecutive_errors = 0
self._max_retries = 3
@staticmethod
def _is_valid_stock_symbol(symbol: str) -> bool:
"""
Check if a symbol is a valid stock ticker (not an option symbol or CUSIP).
Args:
symbol: Symbol to check
Returns:
True if it looks like a valid stock ticker
"""
if not symbol or len(symbol) > 5:
return False
# Stock symbols should start with a letter, not a number
# Numbers indicate CUSIP codes or option symbols
if symbol[0].isdigit():
return False
# Should be mostly uppercase letters
# Allow $ for preferred shares (e.g., BRK.B becomes BRK-B)
return symbol.replace('-', '').replace('.', '').isalpha()
def get_price(self, symbol: str, allow_stale: bool = True) -> Optional[Decimal]:
"""
Get current price for a symbol with caching.
Args:
symbol: Stock ticker symbol
allow_stale: If True, return stale cache data instead of None
Returns:
Price or None if unavailable
"""
# Skip invalid symbols (option symbols, CUSIPs, etc.)
if not self._is_valid_stock_symbol(symbol):
logger.debug(f"Skipping invalid symbol: {symbol} (not a stock ticker)")
return None
# Check database cache first
cached = self._get_cached_price(symbol)
if cached:
price, age_seconds = cached
if age_seconds < self.cache_ttl:
# Fresh cache hit
logger.debug(f"Cache HIT (fresh): {symbol} = ${price} (age: {age_seconds}s)")
return price
elif allow_stale:
# Stale cache hit, but we'll return it
logger.debug(f"Cache HIT (stale): {symbol} = ${price} (age: {age_seconds}s)")
return price
# Cache miss or expired - fetch from Yahoo Finance
logger.info(f"Cache MISS: {symbol}, fetching from Yahoo Finance...")
fresh_price = self._fetch_from_yahoo(symbol)
if fresh_price is not None:
self._update_cache(symbol, fresh_price)
return fresh_price
# If fetch failed and we have stale data, return it
if cached and allow_stale:
price, age_seconds = cached
logger.warning(f"Yahoo fetch failed, using stale cache: {symbol} = ${price} (age: {age_seconds}s)")
return price
return None
def get_prices_batch(
self,
symbols: List[str],
allow_stale: bool = True,
max_fetches: int = 10
) -> Dict[str, Optional[Decimal]]:
"""
Get prices for multiple symbols with rate limiting.
Args:
symbols: List of ticker symbols
allow_stale: Return stale cache data if available
max_fetches: Maximum number of API calls to make (remaining use cache)
Returns:
Dictionary mapping symbol to price (or None if unavailable)
"""
results = {}
symbols_to_fetch = []
# First pass: Check cache for all symbols
for symbol in symbols:
# Skip invalid symbols
if not self._is_valid_stock_symbol(symbol):
logger.debug(f"Skipping invalid symbol in batch: {symbol}")
results[symbol] = None
continue
cached = self._get_cached_price(symbol)
if cached:
price, age_seconds = cached
if age_seconds < self.cache_ttl:
# Fresh cache - use it
results[symbol] = price
elif allow_stale:
# Stale but usable
results[symbol] = price
if age_seconds < self.cache_ttl * 2: # Not TOO stale
symbols_to_fetch.append(symbol)
else:
# Stale and not allowing stale - need to fetch
symbols_to_fetch.append(symbol)
else:
# No cache at all
symbols_to_fetch.append(symbol)
# Second pass: Fetch missing/stale symbols (with limit)
if symbols_to_fetch:
logger.info(f"Batch fetching {len(symbols_to_fetch)} symbols (max: {max_fetches})")
for i, symbol in enumerate(symbols_to_fetch[:max_fetches]):
if i > 0:
# Rate limiting delay
time.sleep(self._rate_limit_delay)
price = self._fetch_from_yahoo(symbol)
if price is not None:
results[symbol] = price
self._update_cache(symbol, price)
elif symbol not in results:
# No cached value and fetch failed
results[symbol] = None
return results
def refresh_stale_prices(self, min_age_seconds: int = 300, limit: int = 20) -> int:
"""
Background task to refresh stale prices.
Args:
min_age_seconds: Only refresh prices older than this
limit: Maximum number of prices to refresh
Returns:
Number of prices refreshed
"""
cutoff_time = datetime.utcnow() - timedelta(seconds=min_age_seconds)
# Get stale prices ordered by oldest first
stale_prices = (
self.db.query(MarketPrice)
.filter(MarketPrice.fetched_at < cutoff_time)
.order_by(MarketPrice.fetched_at.asc())
.limit(limit)
.all()
)
refreshed = 0
for cached_price in stale_prices:
time.sleep(self._rate_limit_delay)
fresh_price = self._fetch_from_yahoo(cached_price.symbol)
if fresh_price is not None:
self._update_cache(cached_price.symbol, fresh_price)
refreshed += 1
logger.info(f"Refreshed {refreshed}/{len(stale_prices)} stale prices")
return refreshed
def _get_cached_price(self, symbol: str) -> Optional[tuple[Decimal, float]]:
"""
Get cached price from database.
Returns:
Tuple of (price, age_in_seconds) or None if not cached
"""
cached = (
self.db.query(MarketPrice)
.filter(MarketPrice.symbol == symbol)
.first()
)
if cached:
age = (datetime.utcnow() - cached.fetched_at).total_seconds()
return (cached.price, age)
return None
def _update_cache(self, symbol: str, price: Decimal) -> None:
"""Update or insert price in database cache."""
cached = (
self.db.query(MarketPrice)
.filter(MarketPrice.symbol == symbol)
.first()
)
if cached:
cached.price = price
cached.fetched_at = datetime.utcnow()
else:
new_price = MarketPrice(
symbol=symbol,
price=price,
fetched_at=datetime.utcnow()
)
self.db.add(new_price)
self.db.commit()
def _fetch_from_yahoo(self, symbol: str) -> Optional[Decimal]:
"""
Fetch price from Yahoo Finance with rate limiting and retries.
Returns:
Price or None if fetch failed
"""
for attempt in range(self._max_retries):
try:
# Rate limiting
elapsed = time.time() - self._last_request_time
if elapsed < self._rate_limit_delay:
time.sleep(self._rate_limit_delay - elapsed)
self._last_request_time = time.time()
# Fetch from Yahoo
ticker = yf.Ticker(symbol)
info = ticker.info
# Try different price fields
for field in ["currentPrice", "regularMarketPrice", "previousClose"]:
if field in info and info[field]:
price = Decimal(str(info[field]))
# Success - reset error tracking
self._consecutive_errors = 0
self._rate_limit_delay = max(0.5, self._rate_limit_delay * 0.9) # Gradually decrease delay
logger.debug(f"Fetched {symbol} = ${price}")
return price
# No price found in response
logger.warning(f"No price data in Yahoo response for {symbol}")
return None
except Exception as e:
error_str = str(e).lower()
if "429" in error_str or "too many requests" in error_str:
# Rate limit hit - back off exponentially
self._consecutive_errors += 1
self._rate_limit_delay = min(10.0, self._rate_limit_delay * 2) # Double delay, max 10s
logger.warning(
f"Rate limit hit for {symbol} (attempt {attempt + 1}/{self._max_retries}), "
f"backing off to {self._rate_limit_delay}s delay"
)
if attempt < self._max_retries - 1:
time.sleep(self._rate_limit_delay * (attempt + 1)) # Longer wait for retries
continue
else:
# Other error
logger.error(f"Error fetching {symbol}: {e}")
return None
logger.error(f"Failed to fetch {symbol} after {self._max_retries} attempts")
return None
def clear_cache(self, older_than_days: int = 30) -> int:
"""
Clear old cached prices.
Args:
older_than_days: Delete prices older than this many days
Returns:
Number of records deleted
"""
cutoff = datetime.utcnow() - timedelta(days=older_than_days)
deleted = (
self.db.query(MarketPrice)
.filter(MarketPrice.fetched_at < cutoff)
.delete()
)
self.db.commit()
logger.info(f"Cleared {deleted} cached prices older than {older_than_days} days")
return deleted

View File

@@ -0,0 +1,364 @@
"""Service for calculating performance metrics and unrealized P&L."""
from sqlalchemy.orm import Session
from sqlalchemy import and_, func
from typing import Dict, Optional
from decimal import Decimal
from datetime import datetime, timedelta
import yfinance as yf
from functools import lru_cache
from app.models import Position, Transaction
from app.models.position import PositionStatus
class PerformanceCalculator:
"""
Service for calculating performance metrics and market data.
Integrates with Yahoo Finance API for real-time pricing of open positions.
"""
def __init__(self, db: Session, cache_ttl: int = 60):
"""
Initialize performance calculator.
Args:
db: Database session
cache_ttl: Cache time-to-live in seconds (default: 60)
"""
self.db = db
self.cache_ttl = cache_ttl
self._price_cache: Dict[str, tuple[Decimal, datetime]] = {}
def calculate_unrealized_pnl(self, position: Position) -> Optional[Decimal]:
"""
Calculate unrealized P&L for an open position.
Args:
position: Open position to calculate P&L for
Returns:
Unrealized P&L or None if market data unavailable
"""
if position.status != PositionStatus.OPEN:
return None
# Get current market price
current_price = self.get_current_price(position.symbol)
if current_price is None:
return None
if position.avg_entry_price is None:
return None
# Calculate P&L based on position direction
quantity = abs(position.total_quantity)
is_short = position.total_quantity < 0
if is_short:
# Short position: profit when price decreases
pnl = (position.avg_entry_price - current_price) * quantity * 100
else:
# Long position: profit when price increases
pnl = (current_price - position.avg_entry_price) * quantity * 100
# Subtract fees and commissions from opening transactions
total_fees = Decimal("0")
for link in position.transaction_links:
txn = link.transaction
if txn.commission:
total_fees += txn.commission
if txn.fees:
total_fees += txn.fees
pnl -= total_fees
return pnl
def update_open_positions_pnl(self, account_id: int) -> int:
"""
Update unrealized P&L for all open positions in an account.
Args:
account_id: Account ID to update
Returns:
Number of positions updated
"""
open_positions = (
self.db.query(Position)
.filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.OPEN,
)
)
.all()
)
updated = 0
for position in open_positions:
unrealized_pnl = self.calculate_unrealized_pnl(position)
if unrealized_pnl is not None:
position.unrealized_pnl = unrealized_pnl
updated += 1
self.db.commit()
return updated
def get_current_price(self, symbol: str) -> Optional[Decimal]:
"""
Get current market price for a symbol.
Uses Yahoo Finance API with caching to reduce API calls.
Args:
symbol: Stock ticker symbol
Returns:
Current price or None if unavailable
"""
# Check cache
if symbol in self._price_cache:
price, timestamp = self._price_cache[symbol]
if datetime.now() - timestamp < timedelta(seconds=self.cache_ttl):
return price
# Fetch from Yahoo Finance
try:
ticker = yf.Ticker(symbol)
info = ticker.info
# Try different price fields
current_price = None
for field in ["currentPrice", "regularMarketPrice", "previousClose"]:
if field in info and info[field]:
current_price = Decimal(str(info[field]))
break
if current_price is not None:
# Cache the price
self._price_cache[symbol] = (current_price, datetime.now())
return current_price
except Exception:
# Failed to fetch price
pass
return None
def calculate_account_stats(self, account_id: int) -> Dict:
"""
Calculate aggregate statistics for an account.
Args:
account_id: Account ID
Returns:
Dictionary with performance metrics
"""
# Get all positions
positions = (
self.db.query(Position)
.filter(Position.account_id == account_id)
.all()
)
total_positions = len(positions)
open_positions_count = sum(
1 for p in positions if p.status == PositionStatus.OPEN
)
closed_positions_count = sum(
1 for p in positions if p.status == PositionStatus.CLOSED
)
# Calculate P&L
total_realized_pnl = sum(
(p.realized_pnl or Decimal("0"))
for p in positions
if p.status == PositionStatus.CLOSED
)
# Update unrealized P&L for open positions
self.update_open_positions_pnl(account_id)
total_unrealized_pnl = sum(
(p.unrealized_pnl or Decimal("0"))
for p in positions
if p.status == PositionStatus.OPEN
)
# Calculate win rate and average win/loss
closed_with_pnl = [
p for p in positions
if p.status == PositionStatus.CLOSED and p.realized_pnl is not None
]
if closed_with_pnl:
winning_trades = [p for p in closed_with_pnl if p.realized_pnl > 0]
losing_trades = [p for p in closed_with_pnl if p.realized_pnl < 0]
win_rate = (len(winning_trades) / len(closed_with_pnl)) * 100
avg_win = (
sum(p.realized_pnl for p in winning_trades) / len(winning_trades)
if winning_trades
else Decimal("0")
)
avg_loss = (
sum(p.realized_pnl for p in losing_trades) / len(losing_trades)
if losing_trades
else Decimal("0")
)
else:
win_rate = 0.0
avg_win = Decimal("0")
avg_loss = Decimal("0")
# Get current account balance from latest transaction
latest_txn = (
self.db.query(Transaction)
.filter(Transaction.account_id == account_id)
.order_by(Transaction.run_date.desc(), Transaction.id.desc())
.first()
)
current_balance = (
latest_txn.cash_balance if latest_txn and latest_txn.cash_balance else Decimal("0")
)
return {
"total_positions": total_positions,
"open_positions": open_positions_count,
"closed_positions": closed_positions_count,
"total_realized_pnl": float(total_realized_pnl),
"total_unrealized_pnl": float(total_unrealized_pnl),
"total_pnl": float(total_realized_pnl + total_unrealized_pnl),
"win_rate": float(win_rate),
"avg_win": float(avg_win),
"avg_loss": float(avg_loss),
"current_balance": float(current_balance),
}
def get_balance_history(
self, account_id: int, days: int = 30
) -> list[Dict]:
"""
Get account balance history for charting.
Args:
account_id: Account ID
days: Number of days to retrieve
Returns:
List of {date, balance} dictionaries
"""
cutoff_date = datetime.now().date() - timedelta(days=days)
transactions = (
self.db.query(Transaction.run_date, Transaction.cash_balance)
.filter(
and_(
Transaction.account_id == account_id,
Transaction.run_date >= cutoff_date,
Transaction.cash_balance.isnot(None),
)
)
.order_by(Transaction.run_date)
.all()
)
# Get one balance per day (use last transaction of the day)
daily_balances = {}
for txn in transactions:
daily_balances[txn.run_date] = float(txn.cash_balance)
return [
{"date": date.isoformat(), "balance": balance}
for date, balance in sorted(daily_balances.items())
]
def get_top_trades(
self, account_id: int, limit: int = 10
) -> list[Dict]:
"""
Get top performing trades (by realized P&L).
Args:
account_id: Account ID
limit: Maximum number of trades to return
Returns:
List of trade dictionaries
"""
positions = (
self.db.query(Position)
.filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.CLOSED,
Position.realized_pnl.isnot(None),
)
)
.order_by(Position.realized_pnl.desc())
.limit(limit)
.all()
)
return [
{
"symbol": p.symbol,
"option_symbol": p.option_symbol,
"position_type": p.position_type.value,
"open_date": p.open_date.isoformat(),
"close_date": p.close_date.isoformat() if p.close_date else None,
"quantity": float(p.total_quantity),
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
"realized_pnl": float(p.realized_pnl),
}
for p in positions
]
def get_worst_trades(
self, account_id: int, limit: int = 20
) -> list[Dict]:
"""
Get worst performing trades (biggest losses by realized P&L).
Args:
account_id: Account ID
limit: Maximum number of trades to return
Returns:
List of trade dictionaries
"""
positions = (
self.db.query(Position)
.filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.CLOSED,
Position.realized_pnl.isnot(None),
)
)
.order_by(Position.realized_pnl.asc())
.limit(limit)
.all()
)
return [
{
"symbol": p.symbol,
"option_symbol": p.option_symbol,
"position_type": p.position_type.value,
"open_date": p.open_date.isoformat(),
"close_date": p.close_date.isoformat() if p.close_date else None,
"quantity": float(p.total_quantity),
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
"realized_pnl": float(p.realized_pnl),
}
for p in positions
]

View File

@@ -0,0 +1,433 @@
"""
Improved performance calculator with rate-limited market data fetching.
This version uses the MarketDataService for efficient, cached price lookups.
"""
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import Dict, Optional
from decimal import Decimal
from datetime import datetime, timedelta
import logging
from app.models import Position, Transaction
from app.models.position import PositionStatus
from app.services.market_data_service import MarketDataService
logger = logging.getLogger(__name__)
class PerformanceCalculatorV2:
"""
Enhanced performance calculator with efficient market data handling.
Features:
- Database-backed price caching
- Rate-limited API calls
- Batch price fetching
- Stale-while-revalidate pattern
"""
def __init__(self, db: Session, cache_ttl: int = 300):
"""
Initialize performance calculator.
Args:
db: Database session
cache_ttl: Cache time-to-live in seconds (default: 5 minutes)
"""
self.db = db
self.market_data = MarketDataService(db, cache_ttl_seconds=cache_ttl)
def calculate_unrealized_pnl(self, position: Position, current_price: Optional[Decimal] = None) -> Optional[Decimal]:
"""
Calculate unrealized P&L for an open position.
Args:
position: Open position to calculate P&L for
current_price: Optional pre-fetched current price (avoids API call)
Returns:
Unrealized P&L or None if market data unavailable
"""
if position.status != PositionStatus.OPEN:
return None
# Use provided price or fetch it
if current_price is None:
current_price = self.market_data.get_price(position.symbol, allow_stale=True)
if current_price is None or position.avg_entry_price is None:
return None
# Calculate P&L based on position direction
quantity = abs(position.total_quantity)
is_short = position.total_quantity < 0
if is_short:
# Short position: profit when price decreases
pnl = (position.avg_entry_price - current_price) * quantity * 100
else:
# Long position: profit when price increases
pnl = (current_price - position.avg_entry_price) * quantity * 100
# Subtract fees and commissions from opening transactions
total_fees = Decimal("0")
for link in position.transaction_links:
txn = link.transaction
if txn.commission:
total_fees += txn.commission
if txn.fees:
total_fees += txn.fees
pnl -= total_fees
return pnl
def update_open_positions_pnl(
self,
account_id: int,
max_api_calls: int = 10,
allow_stale: bool = True
) -> Dict[str, int]:
"""
Update unrealized P&L for all open positions in an account.
Uses batch fetching with rate limiting to avoid overwhelming Yahoo Finance API.
Args:
account_id: Account ID to update
max_api_calls: Maximum number of Yahoo Finance API calls to make
allow_stale: Allow using stale cached prices
Returns:
Dictionary with update statistics
"""
open_positions = (
self.db.query(Position)
.filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.OPEN,
)
)
.all()
)
if not open_positions:
return {
"total": 0,
"updated": 0,
"cached": 0,
"failed": 0
}
# Get unique symbols
symbols = list(set(p.symbol for p in open_positions))
logger.info(f"Updating P&L for {len(open_positions)} positions across {len(symbols)} symbols")
# Fetch prices in batch
prices = self.market_data.get_prices_batch(
symbols,
allow_stale=allow_stale,
max_fetches=max_api_calls
)
# Update P&L for each position
updated = 0
cached = 0
failed = 0
for position in open_positions:
price = prices.get(position.symbol)
if price is not None:
unrealized_pnl = self.calculate_unrealized_pnl(position, current_price=price)
if unrealized_pnl is not None:
position.unrealized_pnl = unrealized_pnl
updated += 1
# Check if price was from cache (age > 0) or fresh fetch
cached_info = self.market_data._get_cached_price(position.symbol)
if cached_info:
_, age = cached_info
if age < self.market_data.cache_ttl:
cached += 1
else:
failed += 1
else:
failed += 1
logger.warning(f"Could not get price for {position.symbol}")
self.db.commit()
logger.info(
f"Updated {updated}/{len(open_positions)} positions "
f"(cached: {cached}, failed: {failed})"
)
return {
"total": len(open_positions),
"updated": updated,
"cached": cached,
"failed": failed
}
def calculate_account_stats(
self,
account_id: int,
update_prices: bool = True,
max_api_calls: int = 10,
start_date = None,
end_date = None
) -> Dict:
"""
Calculate aggregate statistics for an account.
Args:
account_id: Account ID
update_prices: Whether to fetch fresh prices (if False, uses cached only)
max_api_calls: Maximum number of Yahoo Finance API calls
start_date: Filter positions opened on or after this date
end_date: Filter positions opened on or before this date
Returns:
Dictionary with performance metrics
"""
# Get all positions with optional date filtering
query = self.db.query(Position).filter(Position.account_id == account_id)
if start_date:
query = query.filter(Position.open_date >= start_date)
if end_date:
query = query.filter(Position.open_date <= end_date)
positions = query.all()
total_positions = len(positions)
open_positions_count = sum(
1 for p in positions if p.status == PositionStatus.OPEN
)
closed_positions_count = sum(
1 for p in positions if p.status == PositionStatus.CLOSED
)
# Calculate realized P&L (doesn't need market data)
total_realized_pnl = sum(
(p.realized_pnl or Decimal("0"))
for p in positions
if p.status == PositionStatus.CLOSED
)
# Update unrealized P&L for open positions
update_stats = None
if update_prices and open_positions_count > 0:
update_stats = self.update_open_positions_pnl(
account_id,
max_api_calls=max_api_calls,
allow_stale=True
)
# Calculate total unrealized P&L
total_unrealized_pnl = sum(
(p.unrealized_pnl or Decimal("0"))
for p in positions
if p.status == PositionStatus.OPEN
)
# Calculate win rate and average win/loss
closed_with_pnl = [
p for p in positions
if p.status == PositionStatus.CLOSED and p.realized_pnl is not None
]
if closed_with_pnl:
winning_trades = [p for p in closed_with_pnl if p.realized_pnl > 0]
losing_trades = [p for p in closed_with_pnl if p.realized_pnl < 0]
win_rate = (len(winning_trades) / len(closed_with_pnl)) * 100
avg_win = (
sum(p.realized_pnl for p in winning_trades) / len(winning_trades)
if winning_trades
else Decimal("0")
)
avg_loss = (
sum(p.realized_pnl for p in losing_trades) / len(losing_trades)
if losing_trades
else Decimal("0")
)
else:
win_rate = 0.0
avg_win = Decimal("0")
avg_loss = Decimal("0")
# Get current account balance from latest transaction
latest_txn = (
self.db.query(Transaction)
.filter(Transaction.account_id == account_id)
.order_by(Transaction.run_date.desc(), Transaction.id.desc())
.first()
)
current_balance = (
latest_txn.cash_balance if latest_txn and latest_txn.cash_balance else Decimal("0")
)
result = {
"total_positions": total_positions,
"open_positions": open_positions_count,
"closed_positions": closed_positions_count,
"total_realized_pnl": float(total_realized_pnl),
"total_unrealized_pnl": float(total_unrealized_pnl),
"total_pnl": float(total_realized_pnl + total_unrealized_pnl),
"win_rate": float(win_rate),
"avg_win": float(avg_win),
"avg_loss": float(avg_loss),
"current_balance": float(current_balance),
}
# Add update stats if prices were fetched
if update_stats:
result["price_update_stats"] = update_stats
return result
def get_balance_history(
self, account_id: int, days: int = 30
) -> list[Dict]:
"""
Get account balance history for charting.
This doesn't need market data, just transaction history.
Args:
account_id: Account ID
days: Number of days to retrieve
Returns:
List of {date, balance} dictionaries
"""
cutoff_date = datetime.now().date() - timedelta(days=days)
transactions = (
self.db.query(Transaction.run_date, Transaction.cash_balance)
.filter(
and_(
Transaction.account_id == account_id,
Transaction.run_date >= cutoff_date,
Transaction.cash_balance.isnot(None),
)
)
.order_by(Transaction.run_date)
.all()
)
# Get one balance per day (use last transaction of the day)
daily_balances = {}
for txn in transactions:
daily_balances[txn.run_date] = float(txn.cash_balance)
return [
{"date": date.isoformat(), "balance": balance}
for date, balance in sorted(daily_balances.items())
]
def get_top_trades(
self, account_id: int, limit: int = 10, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
) -> list[Dict]:
"""
Get top performing trades (by realized P&L).
This doesn't need market data, just closed positions.
Args:
account_id: Account ID
limit: Maximum number of trades to return
start_date: Filter positions closed on or after this date
end_date: Filter positions closed on or before this date
Returns:
List of trade dictionaries
"""
query = self.db.query(Position).filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.CLOSED,
Position.realized_pnl.isnot(None),
)
)
# Apply date filters if provided
if start_date:
query = query.filter(Position.close_date >= start_date)
if end_date:
query = query.filter(Position.close_date <= end_date)
positions = query.order_by(Position.realized_pnl.desc()).limit(limit).all()
return [
{
"symbol": p.symbol,
"option_symbol": p.option_symbol,
"position_type": p.position_type.value,
"open_date": p.open_date.isoformat(),
"close_date": p.close_date.isoformat() if p.close_date else None,
"quantity": float(p.total_quantity),
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
"realized_pnl": float(p.realized_pnl),
}
for p in positions
]
def get_worst_trades(
self, account_id: int, limit: int = 10, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None
) -> list[Dict]:
"""
Get worst performing trades (by realized P&L).
This doesn't need market data, just closed positions.
Args:
account_id: Account ID
limit: Maximum number of trades to return
start_date: Filter positions closed on or after this date
end_date: Filter positions closed on or before this date
Returns:
List of trade dictionaries
"""
query = self.db.query(Position).filter(
and_(
Position.account_id == account_id,
Position.status == PositionStatus.CLOSED,
Position.realized_pnl.isnot(None),
)
)
# Apply date filters if provided
if start_date:
query = query.filter(Position.close_date >= start_date)
if end_date:
query = query.filter(Position.close_date <= end_date)
positions = query.order_by(Position.realized_pnl.asc()).limit(limit).all()
return [
{
"symbol": p.symbol,
"option_symbol": p.option_symbol,
"position_type": p.position_type.value,
"open_date": p.open_date.isoformat(),
"close_date": p.close_date.isoformat() if p.close_date else None,
"quantity": float(p.total_quantity),
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
"realized_pnl": float(p.realized_pnl),
}
for p in positions
]

View File

@@ -0,0 +1,465 @@
"""Service for tracking and calculating trading positions."""
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import List, Optional, Dict
from decimal import Decimal
from collections import defaultdict
from datetime import datetime
import re
from app.models import Transaction, Position, PositionTransaction
from app.models.position import PositionType, PositionStatus
from app.utils import parse_option_symbol
class PositionTracker:
"""
Service for tracking trading positions from transactions.
Matches opening and closing transactions using FIFO (First-In-First-Out) method.
Handles stocks, calls, and puts including complex scenarios like assignments and expirations.
"""
def __init__(self, db: Session):
"""
Initialize position tracker.
Args:
db: Database session
"""
self.db = db
def rebuild_positions(self, account_id: int) -> int:
"""
Rebuild all positions for an account from transactions.
Deletes existing positions and recalculates from scratch.
Args:
account_id: Account ID to rebuild positions for
Returns:
Number of positions created
"""
# Delete existing positions
self.db.query(Position).filter(Position.account_id == account_id).delete()
self.db.commit()
# Get all transactions ordered by date
transactions = (
self.db.query(Transaction)
.filter(Transaction.account_id == account_id)
.order_by(Transaction.run_date, Transaction.id)
.all()
)
# Group transactions by symbol and option details
# For options, we need to group by the full option contract (symbol + strike + expiration)
# For stocks, we group by symbol only
symbol_txns = defaultdict(list)
for txn in transactions:
if txn.symbol:
# Create a unique grouping key
grouping_key = self._get_grouping_key(txn)
symbol_txns[grouping_key].append(txn)
# Process each symbol/contract group
position_count = 0
for grouping_key, txns in symbol_txns.items():
positions = self._process_symbol_transactions(account_id, grouping_key, txns)
position_count += len(positions)
self.db.commit()
return position_count
def _process_symbol_transactions(
self, account_id: int, symbol: str, transactions: List[Transaction]
) -> List[Position]:
"""
Process all transactions for a single symbol to create positions.
Args:
account_id: Account ID
symbol: Trading symbol
transactions: List of transactions for this symbol
Returns:
List of created Position objects
"""
positions = []
# Determine position type from first transaction
position_type = self._determine_position_type_from_txn(transactions[0]) if transactions else PositionType.STOCK
# Track open positions using FIFO
open_positions: List[Dict] = []
for txn in transactions:
action = txn.action.upper()
# Determine if this is an opening or closing transaction
if self._is_opening_transaction(action):
# Create new open position
open_pos = {
"transactions": [txn],
"quantity": abs(txn.quantity) if txn.quantity else Decimal("0"),
"entry_price": txn.price,
"open_date": txn.run_date,
"is_short": "SELL" in action or "SOLD" in action,
}
open_positions.append(open_pos)
elif self._is_closing_transaction(action):
# Close positions using FIFO
close_quantity = abs(txn.quantity) if txn.quantity else Decimal("0")
remaining_to_close = close_quantity
while remaining_to_close > 0 and open_positions:
open_pos = open_positions[0]
open_qty = open_pos["quantity"]
if open_qty <= remaining_to_close:
# Close entire position
open_pos["transactions"].append(txn)
position = self._create_position(
account_id,
symbol,
position_type,
open_pos,
close_date=txn.run_date,
exit_price=txn.price,
close_quantity=open_qty,
)
positions.append(position)
open_positions.pop(0)
remaining_to_close -= open_qty
else:
# Partially close position
# Split into closed portion
closed_portion = {
"transactions": open_pos["transactions"] + [txn],
"quantity": remaining_to_close,
"entry_price": open_pos["entry_price"],
"open_date": open_pos["open_date"],
"is_short": open_pos["is_short"],
}
position = self._create_position(
account_id,
symbol,
position_type,
closed_portion,
close_date=txn.run_date,
exit_price=txn.price,
close_quantity=remaining_to_close,
)
positions.append(position)
# Update open position with remaining quantity
open_pos["quantity"] -= remaining_to_close
remaining_to_close = Decimal("0")
elif self._is_expiration(action):
# Handle option expirations
expire_quantity = abs(txn.quantity) if txn.quantity else Decimal("0")
remaining_to_expire = expire_quantity
while remaining_to_expire > 0 and open_positions:
open_pos = open_positions[0]
open_qty = open_pos["quantity"]
if open_qty <= remaining_to_expire:
# Expire entire position
open_pos["transactions"].append(txn)
position = self._create_position(
account_id,
symbol,
position_type,
open_pos,
close_date=txn.run_date,
exit_price=Decimal("0"), # Expired worthless
close_quantity=open_qty,
)
positions.append(position)
open_positions.pop(0)
remaining_to_expire -= open_qty
else:
# Partially expire
closed_portion = {
"transactions": open_pos["transactions"] + [txn],
"quantity": remaining_to_expire,
"entry_price": open_pos["entry_price"],
"open_date": open_pos["open_date"],
"is_short": open_pos["is_short"],
}
position = self._create_position(
account_id,
symbol,
position_type,
closed_portion,
close_date=txn.run_date,
exit_price=Decimal("0"),
close_quantity=remaining_to_expire,
)
positions.append(position)
open_pos["quantity"] -= remaining_to_expire
remaining_to_expire = Decimal("0")
# Create positions for any remaining open positions
for open_pos in open_positions:
position = self._create_position(
account_id, symbol, position_type, open_pos
)
positions.append(position)
return positions
def _create_position(
self,
account_id: int,
symbol: str,
position_type: PositionType,
position_data: Dict,
close_date: Optional[datetime] = None,
exit_price: Optional[Decimal] = None,
close_quantity: Optional[Decimal] = None,
) -> Position:
"""
Create a Position database object.
Args:
account_id: Account ID
symbol: Trading symbol
position_type: Type of position
position_data: Dictionary with position information
close_date: Close date (if closed)
exit_price: Exit price (if closed)
close_quantity: Quantity closed (if closed)
Returns:
Created Position object
"""
is_closed = close_date is not None
quantity = close_quantity if close_quantity else position_data["quantity"]
# Calculate P&L if closed
realized_pnl = None
if is_closed and position_data["entry_price"] and exit_price is not None:
if position_data["is_short"]:
# Short position: profit when price decreases
realized_pnl = (
position_data["entry_price"] - exit_price
) * quantity * 100
else:
# Long position: profit when price increases
realized_pnl = (
exit_price - position_data["entry_price"]
) * quantity * 100
# Subtract fees and commissions
for txn in position_data["transactions"]:
if txn.commission:
realized_pnl -= txn.commission
if txn.fees:
realized_pnl -= txn.fees
# Extract option symbol from first transaction if this is an option
option_symbol = None
if position_type != PositionType.STOCK and position_data["transactions"]:
first_txn = position_data["transactions"][0]
# Try to extract option details from description
option_symbol = self._extract_option_symbol_from_description(
first_txn.description, first_txn.action, symbol
)
# Create position
position = Position(
account_id=account_id,
symbol=symbol,
option_symbol=option_symbol,
position_type=position_type,
status=PositionStatus.CLOSED if is_closed else PositionStatus.OPEN,
open_date=position_data["open_date"],
close_date=close_date,
total_quantity=quantity if not position_data["is_short"] else -quantity,
avg_entry_price=position_data["entry_price"],
avg_exit_price=exit_price,
realized_pnl=realized_pnl,
)
self.db.add(position)
self.db.flush() # Get position ID
# Link transactions to position
for txn in position_data["transactions"]:
link = PositionTransaction(
position_id=position.id, transaction_id=txn.id
)
self.db.add(link)
return position
def _extract_option_symbol_from_description(
self, description: str, action: str, base_symbol: str
) -> Optional[str]:
"""
Extract option symbol from transaction description.
Example: "CALL (TGT) TARGET CORP JAN 16 26 $95 (100 SHS)"
Returns: "-TGT260116C95"
Args:
description: Transaction description
action: Transaction action
base_symbol: Underlying symbol
Returns:
Option symbol in standard format, or None if can't parse
"""
if not description:
return None
# Determine if CALL or PUT
call_or_put = None
if "CALL" in description.upper():
call_or_put = "C"
elif "PUT" in description.upper():
call_or_put = "P"
else:
return None
# Extract date and strike: "JAN 16 26 $95"
# Pattern: MONTH DAY YY $STRIKE
date_strike_pattern = r'([A-Z]{3})\s+(\d{1,2})\s+(\d{2})\s+\$([\d.]+)'
match = re.search(date_strike_pattern, description)
if not match:
return None
month_abbr, day, year, strike = match.groups()
# Convert month abbreviation to number
month_map = {
'JAN': '01', 'FEB': '02', 'MAR': '03', 'APR': '04',
'MAY': '05', 'JUN': '06', 'JUL': '07', 'AUG': '08',
'SEP': '09', 'OCT': '10', 'NOV': '11', 'DEC': '12'
}
month = month_map.get(month_abbr.upper())
if not month:
return None
# Format: -SYMBOL + YYMMDD + C/P + STRIKE
# Remove decimal point from strike if it's a whole number
strike_num = float(strike)
strike_str = str(int(strike_num)) if strike_num.is_integer() else strike.replace('.', '')
option_symbol = f"-{base_symbol}{year}{month}{day.zfill(2)}{call_or_put}{strike_str}"
return option_symbol
def _determine_position_type_from_txn(self, txn: Transaction) -> PositionType:
"""
Determine position type from transaction action/description.
Args:
txn: Transaction to analyze
Returns:
PositionType (STOCK, CALL, or PUT)
"""
# Check action and description for option indicators
action_upper = txn.action.upper() if txn.action else ""
desc_upper = txn.description.upper() if txn.description else ""
# Look for CALL or PUT keywords
if "CALL" in action_upper or "CALL" in desc_upper:
return PositionType.CALL
elif "PUT" in action_upper or "PUT" in desc_upper:
return PositionType.PUT
# Fall back to checking symbol format (for backwards compatibility)
if txn.symbol and txn.symbol.startswith("-"):
option_info = parse_option_symbol(txn.symbol)
if option_info:
return (
PositionType.CALL
if option_info.option_type == "CALL"
else PositionType.PUT
)
return PositionType.STOCK
def _get_base_symbol(self, symbol: str) -> str:
"""Extract base symbol from option symbol."""
if symbol.startswith("-"):
option_info = parse_option_symbol(symbol)
if option_info:
return option_info.underlying_symbol
return symbol
def _is_opening_transaction(self, action: str) -> bool:
"""Check if action represents opening a position."""
opening_keywords = [
"OPENING TRANSACTION",
"YOU BOUGHT OPENING",
"YOU SOLD OPENING",
]
return any(keyword in action for keyword in opening_keywords)
def _is_closing_transaction(self, action: str) -> bool:
"""Check if action represents closing a position."""
closing_keywords = [
"CLOSING TRANSACTION",
"YOU BOUGHT CLOSING",
"YOU SOLD CLOSING",
"ASSIGNED",
]
return any(keyword in action for keyword in closing_keywords)
def _is_expiration(self, action: str) -> bool:
"""Check if action represents an expiration."""
return "EXPIRED" in action
def _get_grouping_key(self, txn: Transaction) -> str:
"""
Create a unique grouping key for transactions.
For options, returns: symbol + option details (e.g., "TGT-JAN16-100C")
For stocks, returns: just the symbol (e.g., "TGT")
Args:
txn: Transaction to create key for
Returns:
Grouping key string
"""
# Determine if this is an option transaction
action_upper = txn.action.upper() if txn.action else ""
desc_upper = txn.description.upper() if txn.description else ""
is_option = "CALL" in action_upper or "CALL" in desc_upper or "PUT" in action_upper or "PUT" in desc_upper
if not is_option or not txn.description:
# Stock transaction - group by symbol only
return txn.symbol
# Option transaction - extract strike and expiration to create unique key
# Pattern: "CALL (TGT) TARGET CORP JAN 16 26 $100 (100 SHS)"
date_strike_pattern = r'([A-Z]{3})\s+(\d{1,2})\s+(\d{2})\s+\$([\d.]+)'
match = re.search(date_strike_pattern, txn.description)
if not match:
# Can't parse option details, fall back to symbol only
return txn.symbol
month_abbr, day, year, strike = match.groups()
# Determine call or put
call_or_put = "C" if "CALL" in desc_upper else "P"
# Create key: SYMBOL-MONTHDAY-STRIKEC/P
# e.g., "TGT-JAN16-100C"
strike_num = float(strike)
strike_str = str(int(strike_num)) if strike_num.is_integer() else strike
grouping_key = f"{txn.symbol}-{month_abbr}{day}-{strike_str}{call_or_put}"
return grouping_key

View File

@@ -0,0 +1,5 @@
"""Utility functions and helpers."""
from app.utils.deduplication import generate_transaction_hash
from app.utils.option_parser import parse_option_symbol, OptionInfo
__all__ = ["generate_transaction_hash", "parse_option_symbol", "OptionInfo"]

View File

@@ -0,0 +1,65 @@
"""Transaction deduplication utilities."""
import hashlib
from datetime import date
from decimal import Decimal
from typing import Optional
def generate_transaction_hash(
account_id: int,
run_date: date,
symbol: Optional[str],
action: str,
amount: Optional[Decimal],
quantity: Optional[Decimal],
price: Optional[Decimal],
) -> str:
"""
Generate a unique SHA-256 hash for a transaction to prevent duplicates.
The hash is generated from key transaction attributes that uniquely identify
a transaction: account, date, symbol, action, amount, quantity, and price.
Args:
account_id: Account identifier
run_date: Transaction date
symbol: Trading symbol
action: Transaction action description
amount: Transaction amount
quantity: Number of shares/contracts
price: Price per unit
Returns:
str: 64-character hexadecimal SHA-256 hash
Example:
>>> generate_transaction_hash(
... account_id=1,
... run_date=date(2025, 12, 26),
... symbol="AAPL",
... action="YOU BOUGHT",
... amount=Decimal("-1500.00"),
... quantity=Decimal("10"),
... price=Decimal("150.00")
... )
'a1b2c3d4...'
"""
# Convert values to strings, handling None values
symbol_str = symbol or ""
amount_str = str(amount) if amount is not None else ""
quantity_str = str(quantity) if quantity is not None else ""
price_str = str(price) if price is not None else ""
# Create hash string with pipe delimiter
hash_string = (
f"{account_id}|"
f"{run_date.isoformat()}|"
f"{symbol_str}|"
f"{action}|"
f"{amount_str}|"
f"{quantity_str}|"
f"{price_str}"
)
# Generate SHA-256 hash
return hashlib.sha256(hash_string.encode("utf-8")).hexdigest()

View File

@@ -0,0 +1,91 @@
"""Option symbol parsing utilities."""
import re
from datetime import datetime
from typing import Optional, NamedTuple
from decimal import Decimal
class OptionInfo(NamedTuple):
"""
Parsed option information.
Attributes:
underlying_symbol: Base ticker symbol (e.g., "AAPL")
expiration_date: Option expiration date
option_type: "CALL" or "PUT"
strike_price: Strike price
"""
underlying_symbol: str
expiration_date: datetime
option_type: str
strike_price: Decimal
def parse_option_symbol(option_symbol: str) -> Optional[OptionInfo]:
"""
Parse Fidelity option symbol format into components.
Fidelity format: -SYMBOL + YYMMDD + C/P + STRIKE
Example: -AAPL260116C150 = AAPL Call expiring Jan 16, 2026 at $150 strike
Args:
option_symbol: Fidelity option symbol string
Returns:
OptionInfo object if parsing successful, None otherwise
Examples:
>>> parse_option_symbol("-AAPL260116C150")
OptionInfo(
underlying_symbol='AAPL',
expiration_date=datetime(2026, 1, 16),
option_type='CALL',
strike_price=Decimal('150')
)
>>> parse_option_symbol("-TSLA251219P500")
OptionInfo(
underlying_symbol='TSLA',
expiration_date=datetime(2025, 12, 19),
option_type='PUT',
strike_price=Decimal('500')
)
"""
# Regex pattern: -SYMBOL + YYMMDD + C/P + STRIKE
# Symbol: one or more uppercase letters
# Date: 6 digits (YYMMDD)
# Type: C (call) or P (put)
# Strike: digits with optional decimal point
pattern = r"^-([A-Z]+)(\d{6})([CP])(\d+\.?\d*)$"
match = re.match(pattern, option_symbol)
if not match:
return None
symbol, date_str, option_type, strike_str = match.groups()
# Parse date (YYMMDD format)
try:
# Assume 20XX for years (works until 2100)
year = 2000 + int(date_str[:2])
month = int(date_str[2:4])
day = int(date_str[4:6])
expiration_date = datetime(year, month, day)
except (ValueError, IndexError):
return None
# Parse option type
option_type_full = "CALL" if option_type == "C" else "PUT"
# Parse strike price
try:
strike_price = Decimal(strike_str)
except (ValueError, ArithmeticError):
return None
return OptionInfo(
underlying_symbol=symbol,
expiration_date=expiration_date,
option_type=option_type_full,
strike_price=strike_price,
)

12
backend/requirements.txt Normal file
View File

@@ -0,0 +1,12 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
sqlalchemy==2.0.25
alembic==1.13.1
psycopg2-binary==2.9.9
pydantic==2.5.3
pydantic-settings==2.1.0
python-multipart==0.0.6
pandas==2.1.4
yfinance==0.2.35
python-dateutil==2.8.2
pytz==2024.1

94
backend/seed_demo_data.py Normal file
View File

@@ -0,0 +1,94 @@
"""
Demo data seeder script.
Creates a sample account and imports the provided CSV file.
"""
import sys
from pathlib import Path
# Add parent directory to path
sys.path.insert(0, str(Path(__file__).parent))
from sqlalchemy.orm import Session
from app.database import SessionLocal, engine, Base
from app.models import Account
from app.services import ImportService
from app.services.position_tracker import PositionTracker
def seed_demo_data():
"""Seed demo account and transactions."""
print("🌱 Seeding demo data...")
# Create tables
Base.metadata.create_all(bind=engine)
# Create database session
db = SessionLocal()
try:
# Check if demo account already exists
existing = (
db.query(Account)
.filter(Account.account_number == "DEMO123456")
.first()
)
if existing:
print("✅ Demo account already exists")
demo_account = existing
else:
# Create demo account
demo_account = Account(
account_number="DEMO123456",
account_name="Demo Trading Account",
account_type="margin",
)
db.add(demo_account)
db.commit()
db.refresh(demo_account)
print(f"✅ Created demo account (ID: {demo_account.id})")
# Check for CSV file
csv_path = Path("/app/imports/History_for_Account_X38661988.csv")
if not csv_path.exists():
# Try alternative path (development)
csv_path = Path(__file__).parent.parent / "History_for_Account_X38661988.csv"
if not csv_path.exists():
print("⚠️ Sample CSV file not found. Skipping import.")
print(" Place the CSV file in /app/imports/ to seed demo data.")
return
# Import transactions
print(f"📊 Importing transactions from {csv_path.name}...")
import_service = ImportService(db)
result = import_service.import_from_file(csv_path, demo_account.id)
print(f"✅ Imported {result.imported} transactions")
print(f" Skipped {result.skipped} duplicates")
if result.errors:
print(f" ⚠️ {len(result.errors)} errors occurred")
# Build positions
if result.imported > 0:
print("📈 Building positions...")
position_tracker = PositionTracker(db)
positions_created = position_tracker.rebuild_positions(demo_account.id)
print(f"✅ Created {positions_created} positions")
print("\n🎉 Demo data seeded successfully!")
print(f"\n📝 Demo Account Details:")
print(f" Account Number: {demo_account.account_number}")
print(f" Account Name: {demo_account.account_name}")
print(f" Account ID: {demo_account.id}")
except Exception as e:
print(f"❌ Error seeding demo data: {e}")
db.rollback()
raise
finally:
db.close()
if __name__ == "__main__":
seed_demo_data()