Initial release v1.1.0
- Complete MVP for tracking Fidelity brokerage account performance
- Transaction import from CSV with deduplication
- Automatic FIFO position tracking with options support
- Real-time P&L calculations with market data caching
- Dashboard with timeframe filtering (30/90/180 days, 1 year, YTD, all time)
- Docker-based deployment with PostgreSQL backend
- React/TypeScript frontend with TailwindCSS
- FastAPI backend with SQLAlchemy ORM

Features:
- Multi-account support
- Import via CSV upload or filesystem
- Open and closed position tracking
- Balance history charting
- Performance analytics and metrics
- Top trades analysis
- Responsive UI design

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
42
backend/Dockerfile
Normal file
@@ -0,0 +1,42 @@
# Multi-stage build for Python FastAPI backend
FROM python:3.11-slim as builder

WORKDIR /app

# Install system dependencies
RUN apt-get update && apt-get install -y \
    gcc \
    postgresql-client \
    && rm -rf /var/lib/apt/lists/*

# Copy requirements and install Python dependencies
COPY requirements.txt .
RUN pip install --no-cache-dir --user -r requirements.txt

# Final stage
FROM python:3.11-slim

WORKDIR /app

# Install runtime dependencies
RUN apt-get update && apt-get install -y \
    postgresql-client \
    && rm -rf /var/lib/apt/lists/*

# Copy Python dependencies from builder
COPY --from=builder /root/.local /root/.local

# Copy application code
COPY . .

# Make sure scripts in .local are usable
ENV PATH=/root/.local/bin:$PATH

# Create imports directory
RUN mkdir -p /app/imports

# Expose port
EXPOSE 8000

# Run migrations and start server
CMD alembic upgrade head && uvicorn app.main:app --host 0.0.0.0 --port 8000
52
backend/alembic.ini
Normal file
@@ -0,0 +1,52 @@
# Alembic configuration file

[alembic]
# Path to migration scripts
script_location = alembic

# Template used to generate migration files
file_template = %%(year)d%%(month).2d%%(day).2d_%%(hour).2d%%(minute).2d_%%(rev)s_%%(slug)s

# Timezone for migration timestamps
timezone = UTC

# Prepend the project root to sys.path so app modules can be imported
prepend_sys_path = .

# Version location specification
version_path_separator = os

# Logging configuration
[loggers]
keys = root,sqlalchemy,alembic

[handlers]
keys = console

[formatters]
keys = generic

[logger_root]
level = WARN
handlers = console
qualname =

[logger_sqlalchemy]
level = WARN
handlers =
qualname = sqlalchemy.engine

[logger_alembic]
level = INFO
handlers =
qualname = alembic

[handler_console]
class = StreamHandler
args = (sys.stderr,)
level = NOTSET
formatter = generic

[formatter_generic]
format = %(levelname)-5.5s [%(name)s] %(message)s
datefmt = %H:%M:%S
72
backend/alembic/env.py
Normal file
@@ -0,0 +1,72 @@
"""Alembic environment configuration for database migrations."""
from logging.config import fileConfig
from sqlalchemy import engine_from_config, pool
from alembic import context
import sys
from pathlib import Path

# Add parent directory to path to import app modules
sys.path.insert(0, str(Path(__file__).resolve().parent.parent))

from app.config import settings
from app.database import Base
from app.models import Account, Transaction, Position, PositionTransaction

# Alembic Config object
config = context.config

# Override sqlalchemy.url with our settings
config.set_main_option("sqlalchemy.url", settings.database_url)

# Interpret the config file for Python logging
if config.config_file_name is not None:
    fileConfig(config.config_file_name)

# Target metadata for autogenerate support
target_metadata = Base.metadata


def run_migrations_offline() -> None:
    """
    Run migrations in 'offline' mode.

    This configures the context with just a URL and not an Engine,
    though an Engine is acceptable here as well. By skipping the Engine
    creation we don't even need a DBAPI to be available.
    """
    url = config.get_main_option("sqlalchemy.url")
    context.configure(
        url=url,
        target_metadata=target_metadata,
        literal_binds=True,
        dialect_opts={"paramstyle": "named"},
    )

    with context.begin_transaction():
        context.run_migrations()


def run_migrations_online() -> None:
    """
    Run migrations in 'online' mode.

    In this scenario we need to create an Engine and associate a
    connection with the context.
    """
    connectable = engine_from_config(
        config.get_section(config.config_ini_section),
        prefix="sqlalchemy.",
        poolclass=pool.NullPool,
    )

    with connectable.connect() as connection:
        context.configure(connection=connection, target_metadata=target_metadata)

        with context.begin_transaction():
            context.run_migrations()


if context.is_offline_mode():
    run_migrations_offline()
else:
    run_migrations_online()
25
backend/alembic/script.py.mako
Normal file
@@ -0,0 +1,25 @@
"""${message}

Revision ID: ${up_revision}
Revises: ${down_revision | comma,n}
Create Date: ${create_date}

"""
from typing import Sequence, Union
from alembic import op
import sqlalchemy as sa
${imports if imports else ""}

# revision identifiers, used by Alembic.
revision: str = ${repr(up_revision)}
down_revision: Union[str, None] = ${repr(down_revision)}
branch_labels: Union[str, Sequence[str], None] = ${repr(branch_labels)}
depends_on: Union[str, Sequence[str], None] = ${repr(depends_on)}


def upgrade() -> None:
    ${upgrades if upgrades else "pass"}


def downgrade() -> None:
    ${downgrades if downgrades else "pass"}
83
backend/alembic/versions/001_initial_schema.py
Normal file
@@ -0,0 +1,83 @@
"""Initial schema

Revision ID: 001_initial_schema
Revises:
Create Date: 2026-01-20 10:00:00.000000

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '001_initial_schema'
down_revision = None
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create accounts table
    op.create_table(
        'accounts',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('account_number', sa.String(length=50), nullable=False),
        sa.Column('account_name', sa.String(length=200), nullable=False),
        sa.Column('account_type', sa.Enum('CASH', 'MARGIN', name='accounttype'), nullable=False),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_accounts_id'), 'accounts', ['id'], unique=False)
    op.create_index(op.f('ix_accounts_account_number'), 'accounts', ['account_number'], unique=True)

    # Create transactions table
    op.create_table(
        'transactions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('account_id', sa.Integer(), nullable=False),
        sa.Column('run_date', sa.Date(), nullable=False),
        sa.Column('action', sa.String(length=500), nullable=False),
        sa.Column('symbol', sa.String(length=50), nullable=True),
        sa.Column('description', sa.String(length=500), nullable=True),
        sa.Column('transaction_type', sa.String(length=20), nullable=True),
        sa.Column('exchange_quantity', sa.Numeric(precision=20, scale=8), nullable=True),
        sa.Column('exchange_currency', sa.String(length=10), nullable=True),
        sa.Column('currency', sa.String(length=10), nullable=True),
        sa.Column('price', sa.Numeric(precision=20, scale=8), nullable=True),
        sa.Column('quantity', sa.Numeric(precision=20, scale=8), nullable=True),
        sa.Column('exchange_rate', sa.Numeric(precision=20, scale=8), nullable=True),
        sa.Column('commission', sa.Numeric(precision=20, scale=2), nullable=True),
        sa.Column('fees', sa.Numeric(precision=20, scale=2), nullable=True),
        sa.Column('accrued_interest', sa.Numeric(precision=20, scale=2), nullable=True),
        sa.Column('amount', sa.Numeric(precision=20, scale=2), nullable=True),
        sa.Column('cash_balance', sa.Numeric(precision=20, scale=2), nullable=True),
        sa.Column('settlement_date', sa.Date(), nullable=True),
        sa.Column('unique_hash', sa.String(length=64), nullable=False),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
        sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_transactions_id'), 'transactions', ['id'], unique=False)
    op.create_index(op.f('ix_transactions_account_id'), 'transactions', ['account_id'], unique=False)
    op.create_index(op.f('ix_transactions_run_date'), 'transactions', ['run_date'], unique=False)
    op.create_index(op.f('ix_transactions_symbol'), 'transactions', ['symbol'], unique=False)
    op.create_index(op.f('ix_transactions_unique_hash'), 'transactions', ['unique_hash'], unique=True)
    op.create_index('idx_account_date', 'transactions', ['account_id', 'run_date'], unique=False)
    op.create_index('idx_account_symbol', 'transactions', ['account_id', 'symbol'], unique=False)


def downgrade() -> None:
    op.drop_index('idx_account_symbol', table_name='transactions')
    op.drop_index('idx_account_date', table_name='transactions')
    op.drop_index(op.f('ix_transactions_unique_hash'), table_name='transactions')
    op.drop_index(op.f('ix_transactions_symbol'), table_name='transactions')
    op.drop_index(op.f('ix_transactions_run_date'), table_name='transactions')
    op.drop_index(op.f('ix_transactions_account_id'), table_name='transactions')
    op.drop_index(op.f('ix_transactions_id'), table_name='transactions')
    op.drop_table('transactions')
    op.drop_index(op.f('ix_accounts_account_number'), table_name='accounts')
    op.drop_index(op.f('ix_accounts_id'), table_name='accounts')
    op.drop_table('accounts')
    op.execute('DROP TYPE accounttype')
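
The unique index on transactions.unique_hash is what makes CSV re-imports idempotent: the importer hashes each row and skips rows whose hash already exists. The hashing itself lives in ImportService, which is not part of this commit excerpt, so the following is only a sketch under assumed field names:

    # Hypothetical sketch of a row hash for deduplication; the real logic
    # lives in ImportService and may use different fields.
    import hashlib

    def transaction_hash(account_id: int, row: dict) -> str:
        """Stable SHA-256 over identifying CSV fields (assumed field set)."""
        key_fields = ("Run Date", "Action", "Symbol", "Quantity", "Price", "Amount")
        raw = f"{account_id}|" + "|".join(str(row.get(f, "")) for f in key_fields)
        # 64 hex characters, matching the String(length=64) column above
        return hashlib.sha256(raw.encode("utf-8")).hexdigest()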
70
backend/alembic/versions/002_add_positions.py
Normal file
@@ -0,0 +1,70 @@
"""Add positions tables

Revision ID: 002_add_positions
Revises: 001_initial_schema
Create Date: 2026-01-20 15:00:00.000000

"""
from alembic import op
import sqlalchemy as sa

# revision identifiers, used by Alembic.
revision = '002_add_positions'
down_revision = '001_initial_schema'
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create positions table
    op.create_table(
        'positions',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('account_id', sa.Integer(), nullable=False),
        sa.Column('symbol', sa.String(length=50), nullable=False),
        sa.Column('option_symbol', sa.String(length=100), nullable=True),
        sa.Column('position_type', sa.Enum('STOCK', 'CALL', 'PUT', name='positiontype'), nullable=False),
        sa.Column('status', sa.Enum('OPEN', 'CLOSED', name='positionstatus'), nullable=False),
        sa.Column('open_date', sa.Date(), nullable=False),
        sa.Column('close_date', sa.Date(), nullable=True),
        sa.Column('total_quantity', sa.Numeric(precision=20, scale=8), nullable=False),
        sa.Column('avg_entry_price', sa.Numeric(precision=20, scale=8), nullable=True),
        sa.Column('avg_exit_price', sa.Numeric(precision=20, scale=8), nullable=True),
        sa.Column('realized_pnl', sa.Numeric(precision=20, scale=2), nullable=True),
        sa.Column('unrealized_pnl', sa.Numeric(precision=20, scale=2), nullable=True),
        sa.Column('created_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
        sa.Column('updated_at', sa.DateTime(timezone=True), server_default=sa.text('now()'), nullable=False),
        sa.ForeignKeyConstraint(['account_id'], ['accounts.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('id')
    )
    op.create_index(op.f('ix_positions_id'), 'positions', ['id'], unique=False)
    op.create_index(op.f('ix_positions_account_id'), 'positions', ['account_id'], unique=False)
    op.create_index(op.f('ix_positions_symbol'), 'positions', ['symbol'], unique=False)
    op.create_index(op.f('ix_positions_option_symbol'), 'positions', ['option_symbol'], unique=False)
    op.create_index(op.f('ix_positions_status'), 'positions', ['status'], unique=False)
    op.create_index('idx_account_status', 'positions', ['account_id', 'status'], unique=False)
    op.create_index('idx_account_symbol_status', 'positions', ['account_id', 'symbol', 'status'], unique=False)

    # Create position_transactions junction table
    op.create_table(
        'position_transactions',
        sa.Column('position_id', sa.Integer(), nullable=False),
        sa.Column('transaction_id', sa.Integer(), nullable=False),
        sa.ForeignKeyConstraint(['position_id'], ['positions.id'], ondelete='CASCADE'),
        sa.ForeignKeyConstraint(['transaction_id'], ['transactions.id'], ondelete='CASCADE'),
        sa.PrimaryKeyConstraint('position_id', 'transaction_id')
    )


def downgrade() -> None:
    op.drop_table('position_transactions')
    op.drop_index('idx_account_symbol_status', table_name='positions')
    op.drop_index('idx_account_status', table_name='positions')
    op.drop_index(op.f('ix_positions_status'), table_name='positions')
    op.drop_index(op.f('ix_positions_option_symbol'), table_name='positions')
    op.drop_index(op.f('ix_positions_symbol'), table_name='positions')
    op.drop_index(op.f('ix_positions_account_id'), table_name='positions')
    op.drop_index(op.f('ix_positions_id'), table_name='positions')
    op.drop_table('positions')
    op.execute('DROP TYPE positionstatus')
    op.execute('DROP TYPE positiontype')
40
backend/alembic/versions/add_market_prices_table.py
Normal file
@@ -0,0 +1,40 @@
"""Add market_prices table for price caching

Revision ID: 003_market_prices
Revises: 002_add_positions
Create Date: 2026-01-20 16:00:00.000000

"""
from alembic import op
import sqlalchemy as sa
from datetime import datetime


# revision identifiers, used by Alembic.
revision = '003_market_prices'
down_revision = '002_add_positions'
branch_labels = None
depends_on = None


def upgrade() -> None:
    # Create market_prices table
    op.create_table(
        'market_prices',
        sa.Column('id', sa.Integer(), nullable=False),
        sa.Column('symbol', sa.String(length=20), nullable=False),
        sa.Column('price', sa.Numeric(precision=20, scale=6), nullable=False),
        sa.Column('fetched_at', sa.DateTime(), nullable=False, default=datetime.utcnow),
        sa.Column('source', sa.String(length=50), default='yahoo_finance'),
        sa.PrimaryKeyConstraint('id')
    )

    # Create indexes
    op.create_index('idx_market_prices_symbol', 'market_prices', ['symbol'], unique=True)
    op.create_index('idx_symbol_fetched', 'market_prices', ['symbol', 'fetched_at'])


def downgrade() -> None:
    op.drop_index('idx_symbol_fetched', table_name='market_prices')
    op.drop_index('idx_market_prices_symbol', table_name='market_prices')
    op.drop_table('market_prices')
2
backend/app/__init__.py
Normal file
@@ -0,0 +1,2 @@
"""myFidelityTracker backend application."""
__version__ = "1.0.0"
1
backend/app/api/__init__.py
Normal file
@@ -0,0 +1 @@
"""API routes and endpoints."""
19
backend/app/api/deps.py
Normal file
@@ -0,0 +1,19 @@
"""API dependencies."""
from typing import Generator
from sqlalchemy.orm import Session

from app.database import SessionLocal


def get_db() -> Generator[Session, None, None]:
    """
    Dependency that provides a database session.

    Yields:
        Database session
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
1
backend/app/api/endpoints/__init__.py
Normal file
@@ -0,0 +1 @@
"""API endpoint modules."""
151
backend/app/api/endpoints/accounts.py
Normal file
@@ -0,0 +1,151 @@
"""Account management API endpoints."""
from fastapi import APIRouter, Depends, HTTPException, status
from sqlalchemy.orm import Session
from typing import List

from app.api.deps import get_db
from app.models import Account
from app.schemas import AccountCreate, AccountUpdate, AccountResponse

router = APIRouter()


@router.post("", response_model=AccountResponse, status_code=status.HTTP_201_CREATED)
def create_account(account: AccountCreate, db: Session = Depends(get_db)):
    """
    Create a new brokerage account.

    Args:
        account: Account creation data
        db: Database session

    Returns:
        Created account

    Raises:
        HTTPException: If account number already exists
    """
    # Check if account number already exists
    existing = (
        db.query(Account)
        .filter(Account.account_number == account.account_number)
        .first()
    )

    if existing:
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Account with number {account.account_number} already exists",
        )

    # Create new account
    db_account = Account(**account.model_dump())
    db.add(db_account)
    db.commit()
    db.refresh(db_account)

    return db_account


@router.get("", response_model=List[AccountResponse])
def list_accounts(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
    """
    List all accounts.

    Args:
        skip: Number of records to skip
        limit: Maximum number of records to return
        db: Database session

    Returns:
        List of accounts
    """
    accounts = db.query(Account).offset(skip).limit(limit).all()
    return accounts


@router.get("/{account_id}", response_model=AccountResponse)
def get_account(account_id: int, db: Session = Depends(get_db)):
    """
    Get account by ID.

    Args:
        account_id: Account ID
        db: Database session

    Returns:
        Account details

    Raises:
        HTTPException: If account not found
    """
    account = db.query(Account).filter(Account.id == account_id).first()

    if not account:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Account {account_id} not found",
        )

    return account


@router.put("/{account_id}", response_model=AccountResponse)
def update_account(
    account_id: int, account_update: AccountUpdate, db: Session = Depends(get_db)
):
    """
    Update account details.

    Args:
        account_id: Account ID
        account_update: Updated account data
        db: Database session

    Returns:
        Updated account

    Raises:
        HTTPException: If account not found
    """
    db_account = db.query(Account).filter(Account.id == account_id).first()

    if not db_account:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Account {account_id} not found",
        )

    # Update fields
    update_data = account_update.model_dump(exclude_unset=True)
    for field, value in update_data.items():
        setattr(db_account, field, value)

    db.commit()
    db.refresh(db_account)

    return db_account


@router.delete("/{account_id}", status_code=status.HTTP_204_NO_CONTENT)
def delete_account(account_id: int, db: Session = Depends(get_db)):
    """
    Delete an account and all associated data.

    Args:
        account_id: Account ID
        db: Database session

    Raises:
        HTTPException: If account not found
    """
    db_account = db.query(Account).filter(Account.id == account_id).first()

    if not db_account:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Account {account_id} not found",
        )

    db.delete(db_account)
    db.commit()
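
With this router mounted under /api/accounts (see app/main.py below), a client can exercise the endpoints as follows; the host/port and the AccountCreate field names are assumptions based on the Dockerfile and the initial schema above:

    # Illustrative client calls; assumes the backend at localhost:8000.
    import requests

    BASE = "http://localhost:8000/api"

    # Create an account (field names follow the accounts schema above)
    resp = requests.post(f"{BASE}/accounts", json={
        "account_number": "X12345678",
        "account_name": "Individual Brokerage",
        "account_type": "cash",
    })
    resp.raise_for_status()

    # List accounts with pagination
    accounts = requests.get(f"{BASE}/accounts", params={"skip": 0, "limit": 10}).json()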
111
backend/app/api/endpoints/analytics.py
Normal file
@@ -0,0 +1,111 @@
"""Analytics API endpoints."""
from fastapi import APIRouter, Depends, Query
from sqlalchemy.orm import Session
from typing import Optional

from app.api.deps import get_db
from app.services.performance_calculator import PerformanceCalculator

router = APIRouter()


@router.get("/overview/{account_id}")
def get_overview(account_id: int, db: Session = Depends(get_db)):
    """
    Get overview statistics for an account.

    Args:
        account_id: Account ID
        db: Database session

    Returns:
        Dictionary with performance metrics
    """
    calculator = PerformanceCalculator(db)
    stats = calculator.calculate_account_stats(account_id)
    return stats


@router.get("/balance-history/{account_id}")
def get_balance_history(
    account_id: int,
    days: int = Query(default=30, ge=1, le=3650),
    db: Session = Depends(get_db),
):
    """
    Get account balance history for charting.

    Args:
        account_id: Account ID
        days: Number of days to retrieve (default: 30)
        db: Database session

    Returns:
        List of {date, balance} dictionaries
    """
    calculator = PerformanceCalculator(db)
    history = calculator.get_balance_history(account_id, days)
    return {"data": history}


@router.get("/top-trades/{account_id}")
def get_top_trades(
    account_id: int,
    limit: int = Query(default=20, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """
    Get top performing trades.

    Args:
        account_id: Account ID
        limit: Maximum number of trades to return (default: 20)
        db: Database session

    Returns:
        List of trade dictionaries
    """
    calculator = PerformanceCalculator(db)
    trades = calculator.get_top_trades(account_id, limit)
    return {"data": trades}


@router.get("/worst-trades/{account_id}")
def get_worst_trades(
    account_id: int,
    limit: int = Query(default=20, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """
    Get worst performing trades (biggest losses).

    Args:
        account_id: Account ID
        limit: Maximum number of trades to return (default: 20)
        db: Database session

    Returns:
        List of trade dictionaries
    """
    calculator = PerformanceCalculator(db)
    trades = calculator.get_worst_trades(account_id, limit)
    return {"data": trades}


@router.post("/update-pnl/{account_id}")
def update_unrealized_pnl(account_id: int, db: Session = Depends(get_db)):
    """
    Update unrealized P&L for all open positions in an account.

    Fetches current market prices and recalculates P&L.

    Args:
        account_id: Account ID
        db: Database session

    Returns:
        Number of positions updated
    """
    calculator = PerformanceCalculator(db)
    updated = calculator.update_open_positions_pnl(account_id)
    return {"positions_updated": updated}
273
backend/app/api/endpoints/analytics_v2.py
Normal file
@@ -0,0 +1,273 @@
"""
Enhanced analytics API endpoints with efficient market data handling.

This version uses PerformanceCalculatorV2 with:
- Database-backed price caching
- Rate-limited API calls
- Stale-while-revalidate pattern for better UX
"""
from fastapi import APIRouter, Depends, Query, BackgroundTasks
from sqlalchemy.orm import Session
from typing import Optional
from datetime import date

from app.api.deps import get_db
from app.services.performance_calculator_v2 import PerformanceCalculatorV2
from app.services.market_data_service import MarketDataService

router = APIRouter()


@router.get("/overview/{account_id}")
def get_overview(
    account_id: int,
    refresh_prices: bool = Query(default=False, description="Force fresh price fetch"),
    max_api_calls: int = Query(default=5, ge=0, le=50, description="Max Yahoo Finance API calls"),
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    db: Session = Depends(get_db)
):
    """
    Get overview statistics for an account.

    By default, uses cached prices (stale-while-revalidate pattern).
    Set refresh_prices=true to force fresh data (may be slow).

    Args:
        account_id: Account ID
        refresh_prices: Whether to fetch fresh prices from Yahoo Finance
        max_api_calls: Maximum number of API calls to make
        start_date: Filter positions opened on or after this date
        end_date: Filter positions opened on or before this date
        db: Database session

    Returns:
        Dictionary with performance metrics and cache stats
    """
    calculator = PerformanceCalculatorV2(db, cache_ttl=300)

    # If not refreshing, use cached only (fast)
    if not refresh_prices:
        max_api_calls = 0

    stats = calculator.calculate_account_stats(
        account_id,
        update_prices=True,
        max_api_calls=max_api_calls,
        start_date=start_date,
        end_date=end_date
    )

    return stats


@router.get("/balance-history/{account_id}")
def get_balance_history(
    account_id: int,
    days: int = Query(default=30, ge=1, le=3650),
    db: Session = Depends(get_db),
):
    """
    Get account balance history for charting.

    This endpoint doesn't need market data, so it's always fast.

    Args:
        account_id: Account ID
        days: Number of days to retrieve (default: 30)
        db: Database session

    Returns:
        List of {date, balance} dictionaries
    """
    calculator = PerformanceCalculatorV2(db)
    history = calculator.get_balance_history(account_id, days)
    return {"data": history}


@router.get("/top-trades/{account_id}")
def get_top_trades(
    account_id: int,
    limit: int = Query(default=10, ge=1, le=100),
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    db: Session = Depends(get_db),
):
    """
    Get top performing trades.

    This endpoint only uses closed positions, so no market data needed.

    Args:
        account_id: Account ID
        limit: Maximum number of trades to return (default: 10)
        start_date: Filter positions closed on or after this date
        end_date: Filter positions closed on or before this date
        db: Database session

    Returns:
        List of trade dictionaries
    """
    calculator = PerformanceCalculatorV2(db)
    trades = calculator.get_top_trades(account_id, limit, start_date, end_date)
    return {"data": trades}


@router.get("/worst-trades/{account_id}")
def get_worst_trades(
    account_id: int,
    limit: int = Query(default=10, ge=1, le=100),
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    db: Session = Depends(get_db),
):
    """
    Get worst performing trades.

    This endpoint only uses closed positions, so no market data needed.

    Args:
        account_id: Account ID
        limit: Maximum number of trades to return (default: 10)
        start_date: Filter positions closed on or after this date
        end_date: Filter positions closed on or before this date
        db: Database session

    Returns:
        List of trade dictionaries
    """
    calculator = PerformanceCalculatorV2(db)
    trades = calculator.get_worst_trades(account_id, limit, start_date, end_date)
    return {"data": trades}


@router.post("/refresh-prices/{account_id}")
def refresh_prices(
    account_id: int,
    max_api_calls: int = Query(default=10, ge=1, le=50),
    db: Session = Depends(get_db),
):
    """
    Manually trigger a price refresh for open positions.

    This is useful when you want fresh data but don't want to wait
    on the dashboard load.

    Args:
        account_id: Account ID
        max_api_calls: Maximum number of Yahoo Finance API calls
        db: Database session

    Returns:
        Update statistics
    """
    calculator = PerformanceCalculatorV2(db, cache_ttl=300)

    stats = calculator.update_open_positions_pnl(
        account_id,
        max_api_calls=max_api_calls,
        allow_stale=False  # Force fresh fetches
    )

    return {
        "message": "Price refresh completed",
        "stats": stats
    }


@router.post("/refresh-prices-background/{account_id}")
def refresh_prices_background(
    account_id: int,
    background_tasks: BackgroundTasks,
    max_api_calls: int = Query(default=20, ge=1, le=50),
    db: Session = Depends(get_db),
):
    """
    Trigger a background price refresh.

    This returns immediately while prices are fetched in the background.
    Client can poll the /overview endpoint to see updated data.

    Args:
        account_id: Account ID
        background_tasks: FastAPI background tasks
        max_api_calls: Maximum number of Yahoo Finance API calls
        db: Database session

    Returns:
        Acknowledgment that the background task was started
    """
    def refresh_task():
        calculator = PerformanceCalculatorV2(db, cache_ttl=300)
        calculator.update_open_positions_pnl(
            account_id,
            max_api_calls=max_api_calls,
            allow_stale=False
        )

    background_tasks.add_task(refresh_task)

    return {
        "message": "Price refresh started in background",
        "account_id": account_id,
        "max_api_calls": max_api_calls
    }


@router.post("/refresh-stale-cache")
def refresh_stale_cache(
    min_age_minutes: int = Query(default=10, ge=1, le=1440),
    limit: int = Query(default=20, ge=1, le=100),
    db: Session = Depends(get_db),
):
    """
    Background maintenance endpoint to refresh stale cached prices.

    This can be called periodically (e.g., via cron) to keep the cache fresh.

    Args:
        min_age_minutes: Only refresh prices older than this many minutes
        limit: Maximum number of prices to refresh
        db: Database session

    Returns:
        Number of prices refreshed
    """
    market_data = MarketDataService(db, cache_ttl_seconds=300)

    refreshed = market_data.refresh_stale_prices(
        min_age_seconds=min_age_minutes * 60,
        limit=limit
    )

    return {
        "message": "Stale price refresh completed",
        "refreshed": refreshed,
        "min_age_minutes": min_age_minutes
    }


@router.delete("/clear-old-cache")
def clear_old_cache(
    older_than_days: int = Query(default=30, ge=1, le=365),
    db: Session = Depends(get_db),
):
    """
    Clear old cached prices from the database.

    Args:
        older_than_days: Delete prices older than this many days
        db: Database session

    Returns:
        Number of records deleted
    """
    market_data = MarketDataService(db)

    deleted = market_data.clear_cache(older_than_days=older_than_days)

    return {
        "message": "Old cache cleared",
        "deleted": deleted,
        "older_than_days": older_than_days
    }
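
The stale-while-revalidate behavior named in the module docstring is implemented inside MarketDataService and PerformanceCalculatorV2, neither of which is shown in this excerpt. A minimal sketch of the policy, with all names and behavior assumed:

    # Minimal stale-while-revalidate sketch; names and behavior are assumptions.
    from datetime import datetime, timedelta

    def get_price(cache: dict, symbol: str, fetch, ttl_seconds: int = 300):
        """Serve cached values immediately; fetch synchronously only on a true miss."""
        entry = cache.get(symbol)  # (price, fetched_at) or None
        now = datetime.utcnow()
        if entry is not None:
            price, fetched_at = entry
            if now - fetched_at >= timedelta(seconds=ttl_seconds):
                # Stale: still serve the old value now; a background task
                # (see /refresh-prices-background above) revalidates later.
                pass
            return price
        price = fetch(symbol)  # cache miss: fetch synchronously
        cache[symbol] = (price, now)
        return price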
128
backend/app/api/endpoints/import_endpoint.py
Normal file
@@ -0,0 +1,128 @@
"""Import API endpoints for CSV file uploads."""
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status
from sqlalchemy.orm import Session
from pathlib import Path
import tempfile
import shutil

from app.api.deps import get_db
from app.services import ImportService
from app.services.position_tracker import PositionTracker
from app.config import settings

router = APIRouter()


@router.post("/upload/{account_id}")
def upload_csv(
    account_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)
):
    """
    Upload and import a CSV file for an account.

    Args:
        account_id: Account ID to import transactions for
        file: CSV file to upload
        db: Database session

    Returns:
        Import statistics

    Raises:
        HTTPException: If import fails
    """
    if not file.filename.endswith(".csv"):
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST, detail="File must be a CSV"
        )

    # Save uploaded file to temporary location
    try:
        with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp_file:
            shutil.copyfileobj(file.file, tmp_file)
            tmp_path = Path(tmp_file.name)

        # Import transactions
        import_service = ImportService(db)
        result = import_service.import_from_file(tmp_path, account_id)

        # Rebuild positions after import
        if result.imported > 0:
            position_tracker = PositionTracker(db)
            positions_created = position_tracker.rebuild_positions(account_id)
        else:
            positions_created = 0

        # Clean up temporary file
        tmp_path.unlink()

        return {
            "filename": file.filename,
            "imported": result.imported,
            "skipped": result.skipped,
            "errors": result.errors,
            "total_rows": result.total_rows,
            "positions_created": positions_created,
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Import failed: {str(e)}",
        )


@router.post("/filesystem/{account_id}")
def import_from_filesystem(account_id: int, db: Session = Depends(get_db)):
    """
    Import all CSV files from the filesystem import directory.

    Args:
        account_id: Account ID to import transactions for
        db: Database session

    Returns:
        Import statistics for all files

    Raises:
        HTTPException: If the import directory doesn't exist
    """
    import_dir = Path(settings.IMPORT_DIR)

    if not import_dir.exists():
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Import directory not found: {import_dir}",
        )

    try:
        import_service = ImportService(db)
        results = import_service.import_from_directory(import_dir, account_id)

        # Rebuild positions if any transactions were imported
        total_imported = sum(r.imported for r in results.values())
        if total_imported > 0:
            position_tracker = PositionTracker(db)
            positions_created = position_tracker.rebuild_positions(account_id)
        else:
            positions_created = 0

        return {
            "files": {
                filename: {
                    "imported": result.imported,
                    "skipped": result.skipped,
                    "errors": result.errors,
                    "total_rows": result.total_rows,
                }
                for filename, result in results.items()
            },
            "total_imported": total_imported,
            "positions_created": positions_created,
        }

    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"Import failed: {str(e)}",
        )
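
An upload against this endpoint is a standard multipart POST; the host/port and filename below are assumptions:

    # Illustrative CSV upload for account_id=1; assumes localhost:8000.
    import requests

    with open("Accounts_History.csv", "rb") as fh:
        resp = requests.post(
            "http://localhost:8000/api/import/upload/1",
            files={"file": ("Accounts_History.csv", fh, "text/csv")},
        )
    resp.raise_for_status()
    print(resp.json())  # imported / skipped / errors / total_rows / positions_created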
104
backend/app/api/endpoints/positions.py
Normal file
@@ -0,0 +1,104 @@
"""Position API endpoints."""
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import List, Optional

from app.api.deps import get_db
from app.models import Position
from app.models.position import PositionStatus
from app.schemas import PositionResponse

router = APIRouter()


@router.get("", response_model=List[PositionResponse])
def list_positions(
    account_id: Optional[int] = None,
    status_filter: Optional[PositionStatus] = Query(
        default=None, alias="status", description="Filter by position status"
    ),
    symbol: Optional[str] = None,
    skip: int = 0,
    limit: int = Query(default=100, le=500),
    db: Session = Depends(get_db),
):
    """
    List positions with optional filtering.

    Args:
        account_id: Filter by account ID
        status_filter: Filter by status (open/closed)
        symbol: Filter by symbol
        skip: Number of records to skip (pagination)
        limit: Maximum number of records to return
        db: Database session

    Returns:
        List of positions
    """
    query = db.query(Position)

    # Apply filters
    if account_id:
        query = query.filter(Position.account_id == account_id)

    if status_filter:
        query = query.filter(Position.status == status_filter)

    if symbol:
        query = query.filter(Position.symbol == symbol)

    # Order by most recent first
    query = query.order_by(Position.open_date.desc(), Position.id.desc())

    # Pagination
    positions = query.offset(skip).limit(limit).all()

    return positions


@router.get("/{position_id}", response_model=PositionResponse)
def get_position(position_id: int, db: Session = Depends(get_db)):
    """
    Get position by ID.

    Args:
        position_id: Position ID
        db: Database session

    Returns:
        Position details

    Raises:
        HTTPException: If position not found
    """
    position = db.query(Position).filter(Position.id == position_id).first()

    if not position:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Position {position_id} not found",
        )

    return position


@router.post("/{account_id}/rebuild")
def rebuild_positions(account_id: int, db: Session = Depends(get_db)):
    """
    Rebuild all positions for an account from transactions.

    Args:
        account_id: Account ID
        db: Database session

    Returns:
        Number of positions created
    """
    from app.services.position_tracker import PositionTracker

    position_tracker = PositionTracker(db)
    positions_created = position_tracker.rebuild_positions(account_id)

    return {"positions_created": positions_created}
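
The rebuild endpoint delegates to PositionTracker, whose FIFO matching (named in the commit message) is not part of this excerpt. A self-contained sketch of FIFO lot matching for realized P&L, as one plausible reading of that scheme:

    # Hedged sketch of FIFO lot matching; PositionTracker's real
    # implementation is not shown in this commit excerpt.
    from collections import deque
    from decimal import Decimal

    def fifo_realized_pnl(fills):
        """fills: iterable of (qty, price); positive qty = buy, negative = sell."""
        lots = deque()               # open lots as [remaining_qty, entry_price]
        realized = Decimal("0")
        for qty, price in fills:
            if qty > 0:
                lots.append([qty, price])
                continue
            to_close = -qty
            while to_close > 0 and lots:
                lot = lots[0]
                matched = min(lot[0], to_close)
                realized += matched * (price - lot[1])  # exit minus entry
                lot[0] -= matched
                to_close -= matched
                if lot[0] == 0:
                    lots.popleft()
        return realized

    # buy 10 @ 100, buy 5 @ 110, sell 12 @ 120 -> 10*20 + 2*10 = 220
    assert fifo_realized_pnl([(Decimal(10), Decimal(100)),
                              (Decimal(5), Decimal(110)),
                              (Decimal(-12), Decimal(120))]) == Decimal(220)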
227
backend/app/api/endpoints/transactions.py
Normal file
@@ -0,0 +1,227 @@
"""Transaction API endpoints."""
from fastapi import APIRouter, Depends, HTTPException, Query, status
from sqlalchemy.orm import Session
from sqlalchemy import and_, or_
from typing import List, Optional, Dict
from datetime import date

from app.api.deps import get_db
from app.models import Transaction, Position, PositionTransaction
from app.schemas import TransactionResponse

router = APIRouter()


@router.get("", response_model=List[TransactionResponse])
def list_transactions(
    account_id: Optional[int] = None,
    symbol: Optional[str] = None,
    start_date: Optional[date] = None,
    end_date: Optional[date] = None,
    skip: int = 0,
    limit: int = Query(default=50, le=500),
    db: Session = Depends(get_db),
):
    """
    List transactions with optional filtering.

    Args:
        account_id: Filter by account ID
        symbol: Filter by symbol
        start_date: Filter by start date (inclusive)
        end_date: Filter by end date (inclusive)
        skip: Number of records to skip (pagination)
        limit: Maximum number of records to return
        db: Database session

    Returns:
        List of transactions
    """
    query = db.query(Transaction)

    # Apply filters
    if account_id:
        query = query.filter(Transaction.account_id == account_id)

    if symbol:
        query = query.filter(Transaction.symbol == symbol)

    if start_date:
        query = query.filter(Transaction.run_date >= start_date)

    if end_date:
        query = query.filter(Transaction.run_date <= end_date)

    # Order by date descending
    query = query.order_by(Transaction.run_date.desc(), Transaction.id.desc())

    # Pagination
    transactions = query.offset(skip).limit(limit).all()

    return transactions


@router.get("/{transaction_id}", response_model=TransactionResponse)
def get_transaction(transaction_id: int, db: Session = Depends(get_db)):
    """
    Get transaction by ID.

    Args:
        transaction_id: Transaction ID
        db: Database session

    Returns:
        Transaction details

    Raises:
        HTTPException: If transaction not found
    """
    transaction = (
        db.query(Transaction).filter(Transaction.id == transaction_id).first()
    )

    if not transaction:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Transaction {transaction_id} not found",
        )

    return transaction


@router.get("/{transaction_id}/position-details")
def get_transaction_position_details(
    transaction_id: int, db: Session = Depends(get_db)
) -> Dict:
    """
    Get full position details for a transaction, including all related transactions.

    This endpoint finds the position associated with a transaction and returns:
    - All transactions that are part of the same position
    - Position metadata (type, status, P&L, etc.)
    - Strategy classification for options (covered call, cash-secured put, etc.)

    Args:
        transaction_id: Transaction ID
        db: Database session

    Returns:
        Dictionary with position details and all related transactions

    Raises:
        HTTPException: If transaction not found or not part of a position
    """
    # Find the transaction
    transaction = (
        db.query(Transaction).filter(Transaction.id == transaction_id).first()
    )

    if not transaction:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Transaction {transaction_id} not found",
        )

    # Find the position this transaction belongs to
    position_link = (
        db.query(PositionTransaction)
        .filter(PositionTransaction.transaction_id == transaction_id)
        .first()
    )

    if not position_link:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail=f"Transaction {transaction_id} is not part of any position",
        )

    # Get the position with all its transactions
    position = (
        db.query(Position)
        .filter(Position.id == position_link.position_id)
        .first()
    )

    if not position:
        raise HTTPException(
            status_code=status.HTTP_404_NOT_FOUND,
            detail="Position not found",
        )

    # Get all transactions for this position; use `is not None` checks so
    # legitimate zero values are not converted to null
    all_transactions = []
    for link in position.transaction_links:
        txn = link.transaction
        all_transactions.append({
            "id": txn.id,
            "run_date": txn.run_date.isoformat(),
            "action": txn.action,
            "symbol": txn.symbol,
            "description": txn.description,
            "quantity": float(txn.quantity) if txn.quantity is not None else None,
            "price": float(txn.price) if txn.price is not None else None,
            "amount": float(txn.amount) if txn.amount is not None else None,
            "commission": float(txn.commission) if txn.commission is not None else None,
            "fees": float(txn.fees) if txn.fees is not None else None,
        })

    # Sort transactions by date
    all_transactions.sort(key=lambda t: t["run_date"])

    # Determine strategy type for options
    strategy = _classify_option_strategy(position, all_transactions)

    return {
        "position": {
            "id": position.id,
            "symbol": position.symbol,
            "option_symbol": position.option_symbol,
            "position_type": position.position_type.value,
            "status": position.status.value,
            "open_date": position.open_date.isoformat(),
            "close_date": position.close_date.isoformat() if position.close_date else None,
            "total_quantity": float(position.total_quantity),
            "avg_entry_price": float(position.avg_entry_price) if position.avg_entry_price is not None else None,
            "avg_exit_price": float(position.avg_exit_price) if position.avg_exit_price is not None else None,
            "realized_pnl": float(position.realized_pnl) if position.realized_pnl is not None else None,
            "unrealized_pnl": float(position.unrealized_pnl) if position.unrealized_pnl is not None else None,
            "strategy": strategy,
        },
        "transactions": all_transactions,
    }


def _classify_option_strategy(position: Position, transactions: List[Dict]) -> str:
    """
    Classify the option strategy based on position type and transactions.

    Args:
        position: Position object
        transactions: List of transaction dictionaries

    Returns:
        Strategy name (e.g., "Long Call", "Covered Call", "Cash-Secured Put")
    """
    if position.position_type.value == "stock":
        return "Stock"

    # Check if this is a short or long position
    is_short = position.total_quantity < 0

    # For options
    if position.position_type.value == "call":
        if is_short:
            # Short call - could be covered or naked; determining which would
            # require checking for a corresponding stock position
            return "Short Call (Covered Call)"
        else:
            return "Long Call"
    elif position.position_type.value == "put":
        if is_short:
            # Short put - could be cash-secured or naked
            return "Short Put (Cash-Secured Put)"
        else:
            return "Long Put"

    return "Unknown"
53
backend/app/config.py
Normal file
@@ -0,0 +1,53 @@
"""
Application configuration settings.
Loads configuration from environment variables with sensible defaults.
"""
from pydantic_settings import BaseSettings
from typing import Optional


class Settings(BaseSettings):
    """Application settings loaded from environment variables."""

    # Database configuration
    POSTGRES_HOST: str = "postgres"
    POSTGRES_PORT: int = 5432
    POSTGRES_DB: str = "fidelitytracker"
    POSTGRES_USER: str = "fidelity"
    POSTGRES_PASSWORD: str = "fidelity123"

    # API configuration
    API_V1_PREFIX: str = "/api"
    PROJECT_NAME: str = "myFidelityTracker"

    # CORS configuration - allow all origins for local development
    CORS_ORIGINS: str = "*"

    @property
    def cors_origins_list(self) -> list[str]:
        """Parse CORS origins from a comma-separated string."""
        if self.CORS_ORIGINS == "*":
            return ["*"]
        return [origin.strip() for origin in self.CORS_ORIGINS.split(",")]

    # File import configuration
    IMPORT_DIR: str = "/app/imports"

    # Market data cache TTL (seconds)
    MARKET_DATA_CACHE_TTL: int = 60

    @property
    def database_url(self) -> str:
        """Construct the PostgreSQL database URL."""
        return (
            f"postgresql://{self.POSTGRES_USER}:{self.POSTGRES_PASSWORD}"
            f"@{self.POSTGRES_HOST}:{self.POSTGRES_PORT}/{self.POSTGRES_DB}"
        )

    class Config:
        env_file = ".env"
        case_sensitive = True


# Global settings instance
settings = Settings()
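
Because Settings is a pydantic BaseSettings, any field can be overridden through the environment (or a .env file). For example, pointing the app at a local database; the values here are examples only:

    # Environment overrides read by Settings at instantiation time.
    import os
    os.environ["POSTGRES_HOST"] = "localhost"     # e.g. running outside Docker
    os.environ["POSTGRES_PASSWORD"] = "s3cret"

    from app.config import Settings
    print(Settings().database_url)
    # postgresql://fidelity:s3cret@localhost:5432/fidelitytracker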
38
backend/app/database.py
Normal file
@@ -0,0 +1,38 @@
"""
Database configuration and session management.
Provides SQLAlchemy engine and session factory.
"""
from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker

from app.config import settings

# Create SQLAlchemy engine
engine = create_engine(
    settings.database_url,
    pool_pre_ping=True,  # Enable connection health checks
    pool_size=10,
    max_overflow=20
)

# Create session factory
SessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=engine)

# Base class for SQLAlchemy models
Base = declarative_base()


def get_db():
    """
    Dependency function that provides a database session.
    Automatically closes the session after the request is completed.

    Yields:
        Session: SQLAlchemy database session
    """
    db = SessionLocal()
    try:
        yield db
    finally:
        db.close()
66
backend/app/main.py
Normal file
@@ -0,0 +1,66 @@
"""
FastAPI application entry point for myFidelityTracker.

This module initializes the FastAPI application, configures CORS,
and registers all API routers.
"""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

from app.config import settings
from app.api.endpoints import accounts, transactions, positions, import_endpoint
from app.api.endpoints import analytics_v2 as analytics

# Create FastAPI application
app = FastAPI(
    title=settings.PROJECT_NAME,
    description="Track and analyze your Fidelity brokerage account performance",
    version="1.0.0",
)

# Configure CORS middleware - allow all origins for local network access
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # Allow all origins for local development
    allow_credentials=False,  # Must be False when using allow_origins=["*"]
    allow_methods=["*"],
    allow_headers=["*"],
)


# Register API routers
app.include_router(
    accounts.router, prefix=f"{settings.API_V1_PREFIX}/accounts", tags=["accounts"]
)
app.include_router(
    transactions.router,
    prefix=f"{settings.API_V1_PREFIX}/transactions",
    tags=["transactions"],
)
app.include_router(
    positions.router, prefix=f"{settings.API_V1_PREFIX}/positions", tags=["positions"]
)
app.include_router(
    analytics.router, prefix=f"{settings.API_V1_PREFIX}/analytics", tags=["analytics"]
)
app.include_router(
    import_endpoint.router,
    prefix=f"{settings.API_V1_PREFIX}/import",
    tags=["import"],
)


@app.get("/")
def root():
    """Root endpoint returning API information."""
    return {
        "name": settings.PROJECT_NAME,
        "version": "1.0.0",
        "message": "Welcome to myFidelityTracker API",
    }


@app.get("/health")
def health_check():
    """Health check endpoint."""
    return {"status": "healthy"}
7
backend/app/models/__init__.py
Normal file
@@ -0,0 +1,7 @@
"""SQLAlchemy models for the application."""
from app.models.account import Account
from app.models.transaction import Transaction
from app.models.position import Position, PositionTransaction
from app.models.market_price import MarketPrice

__all__ = ["Account", "Transaction", "Position", "PositionTransaction", "MarketPrice"]
41
backend/app/models/account.py
Normal file
@@ -0,0 +1,41 @@
"""Account model representing a brokerage account."""
from sqlalchemy import Column, Integer, String, DateTime, Enum
from sqlalchemy.orm import relationship
from sqlalchemy.sql import func
import enum

from app.database import Base


class AccountType(str, enum.Enum):
    """Enumeration of account types."""
    CASH = "cash"
    MARGIN = "margin"


class Account(Base):
    """
    Represents a brokerage account.

    Attributes:
        id: Primary key
        account_number: Unique account identifier
        account_name: Human-readable account name
        account_type: Type of account (cash or margin)
        created_at: Timestamp of account creation
        updated_at: Timestamp of last update
        transactions: Related transactions
        positions: Related positions
    """
    __tablename__ = "accounts"

    id = Column(Integer, primary_key=True, index=True)
    account_number = Column(String(50), unique=True, nullable=False, index=True)
    account_name = Column(String(200), nullable=False)
    account_type = Column(Enum(AccountType), nullable=False, default=AccountType.CASH)
    created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
    updated_at = Column(DateTime(timezone=True), onupdate=func.now(), server_default=func.now(), nullable=False)

    # Relationships
    transactions = relationship("Transaction", back_populates="account", cascade="all, delete-orphan")
    positions = relationship("Position", back_populates="account", cascade="all, delete-orphan")
29
backend/app/models/market_price.py
Normal file
@@ -0,0 +1,29 @@
"""Market price cache model for storing Yahoo Finance data."""
|
||||
from sqlalchemy import Column, Integer, String, Numeric, DateTime, Index
|
||||
from datetime import datetime
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class MarketPrice(Base):
|
||||
"""
|
||||
Cache table for market prices from Yahoo Finance.
|
||||
|
||||
Stores the last fetched price for each symbol to reduce API calls.
|
||||
"""
|
||||
|
||||
__tablename__ = "market_prices"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
symbol = Column(String(20), unique=True, nullable=False, index=True)
|
||||
price = Column(Numeric(precision=20, scale=6), nullable=False)
|
||||
fetched_at = Column(DateTime, nullable=False, default=datetime.utcnow)
|
||||
source = Column(String(50), default="yahoo_finance")
|
||||
|
||||
# Index for quick lookups by symbol and freshness checks
|
||||
__table_args__ = (
|
||||
Index('idx_symbol_fetched', 'symbol', 'fetched_at'),
|
||||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<MarketPrice(symbol={self.symbol}, price={self.price}, fetched_at={self.fetched_at})>"
|
||||
104
backend/app/models/position.py
Normal file
@@ -0,0 +1,104 @@
"""Position model representing a trading position."""
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Numeric, ForeignKey, Date, Enum, Index
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
import enum
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class PositionType(str, enum.Enum):
|
||||
"""Enumeration of position types."""
|
||||
STOCK = "stock"
|
||||
CALL = "call"
|
||||
PUT = "put"
|
||||
|
||||
|
||||
class PositionStatus(str, enum.Enum):
|
||||
"""Enumeration of position statuses."""
|
||||
OPEN = "open"
|
||||
CLOSED = "closed"
|
||||
|
||||
|
||||
class Position(Base):
|
||||
"""
|
||||
Represents a trading position (open or closed).
|
||||
|
||||
A position aggregates related transactions (entries and exits) for a specific security.
|
||||
For options, tracks strikes, expirations, and option-specific details.
|
||||
|
||||
Attributes:
|
||||
id: Primary key
|
||||
account_id: Foreign key to account
|
||||
symbol: Base trading symbol (e.g., AAPL)
|
||||
option_symbol: Full option symbol if applicable (e.g., -AAPL260116C150)
|
||||
position_type: Type (stock, call, put)
|
||||
status: Status (open, closed)
|
||||
open_date: Date position was opened
|
||||
close_date: Date position was closed (if closed)
|
||||
total_quantity: Net quantity (can be negative for short positions)
|
||||
avg_entry_price: Average entry price
|
||||
avg_exit_price: Average exit price (if closed)
|
||||
realized_pnl: Realized profit/loss for closed positions
|
||||
unrealized_pnl: Unrealized profit/loss for open positions
|
||||
created_at: Timestamp of record creation
|
||||
updated_at: Timestamp of last update
|
||||
"""
|
||||
__tablename__ = "positions"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
account_id = Column(Integer, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
|
||||
# Symbol information
|
||||
symbol = Column(String(50), nullable=False, index=True)
|
||||
option_symbol = Column(String(100), index=True) # Full option symbol for options
|
||||
position_type = Column(Enum(PositionType), nullable=False, default=PositionType.STOCK)
|
||||
|
||||
# Status and dates
|
||||
status = Column(Enum(PositionStatus), nullable=False, default=PositionStatus.OPEN, index=True)
|
||||
open_date = Column(Date, nullable=False)
|
||||
close_date = Column(Date)
|
||||
|
||||
# Position metrics
|
||||
total_quantity = Column(Numeric(20, 8), nullable=False) # Can be negative for short
|
||||
avg_entry_price = Column(Numeric(20, 8))
|
||||
avg_exit_price = Column(Numeric(20, 8))
|
||||
|
||||
# P&L tracking
|
||||
realized_pnl = Column(Numeric(20, 2)) # For closed positions
|
||||
unrealized_pnl = Column(Numeric(20, 2)) # For open positions
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), server_default=func.now(), nullable=False)
|
||||
|
||||
# Relationships
|
||||
account = relationship("Account", back_populates="positions")
|
||||
transaction_links = relationship("PositionTransaction", back_populates="position", cascade="all, delete-orphan")
|
||||
|
||||
# Composite indexes for common queries
|
||||
__table_args__ = (
|
||||
Index('idx_account_status', 'account_id', 'status'),
|
||||
Index('idx_account_symbol_status', 'account_id', 'symbol', 'status'),
|
||||
)
|
||||
|
||||
|
||||
class PositionTransaction(Base):
|
||||
"""
|
||||
Junction table linking positions to transactions.
|
||||
|
||||
A position can have multiple transactions (entries, exits, adjustments).
|
||||
A transaction can be part of multiple positions (e.g., closing multiple lots).
|
||||
|
||||
Attributes:
|
||||
position_id: Foreign key to position
|
||||
transaction_id: Foreign key to transaction
|
||||
"""
|
||||
__tablename__ = "position_transactions"
|
||||
|
||||
position_id = Column(Integer, ForeignKey("positions.id", ondelete="CASCADE"), primary_key=True)
|
||||
transaction_id = Column(Integer, ForeignKey("transactions.id", ondelete="CASCADE"), primary_key=True)
|
||||
|
||||
# Relationships
|
||||
position = relationship("Position", back_populates="transaction_links")
|
||||
transaction = relationship("Transaction", back_populates="position_links")
|
||||
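
# Sketch of the many-to-many link described above: closing a single sell
# order against two lots yields two PositionTransaction rows sharing one
# transaction_id (lot_a, lot_b, and sell_txn are illustrative names):
#   db.add_all([
#       PositionTransaction(position_id=lot_a.id, transaction_id=sell_txn.id),
#       PositionTransaction(position_id=lot_b.id, transaction_id=sell_txn.id),
#   ])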
81
backend/app/models/transaction.py
Normal file
@@ -0,0 +1,81 @@
"""Transaction model representing a brokerage transaction."""
|
||||
from sqlalchemy import Column, Integer, String, DateTime, Numeric, ForeignKey, Date, Index
|
||||
from sqlalchemy.orm import relationship
|
||||
from sqlalchemy.sql import func
|
||||
|
||||
from app.database import Base
|
||||
|
||||
|
||||
class Transaction(Base):
|
||||
"""
|
||||
Represents a single brokerage transaction.
|
||||
|
||||
Attributes:
|
||||
id: Primary key
|
||||
account_id: Foreign key to account
|
||||
run_date: Date the transaction was recorded
|
||||
action: Description of the transaction action
|
||||
symbol: Trading symbol
|
||||
description: Full transaction description
|
||||
transaction_type: Type (Cash/Margin)
|
||||
exchange_quantity: Quantity in exchange currency
|
||||
exchange_currency: Exchange currency code
|
||||
currency: Transaction currency
|
||||
price: Transaction price per unit
|
||||
quantity: Number of shares/contracts
|
||||
exchange_rate: Currency exchange rate
|
||||
commission: Commission fees
|
||||
fees: Additional fees
|
||||
accrued_interest: Interest accrued
|
||||
amount: Total transaction amount
|
||||
cash_balance: Account balance after transaction
|
||||
settlement_date: Date transaction settles
|
||||
unique_hash: SHA-256 hash for deduplication
|
||||
created_at: Timestamp of record creation
|
||||
updated_at: Timestamp of last update
|
||||
"""
|
||||
__tablename__ = "transactions"
|
||||
|
||||
id = Column(Integer, primary_key=True, index=True)
|
||||
account_id = Column(Integer, ForeignKey("accounts.id", ondelete="CASCADE"), nullable=False, index=True)
|
||||
|
||||
# Transaction details from CSV
|
||||
run_date = Column(Date, nullable=False, index=True)
|
||||
action = Column(String(500), nullable=False)
|
||||
symbol = Column(String(50), index=True)
|
||||
description = Column(String(500))
|
||||
transaction_type = Column(String(20)) # Cash, Margin
|
||||
|
||||
# Quantities and currencies
|
||||
exchange_quantity = Column(Numeric(20, 8))
|
||||
exchange_currency = Column(String(10))
|
||||
currency = Column(String(10))
|
||||
|
||||
# Financial details
|
||||
price = Column(Numeric(20, 8))
|
||||
quantity = Column(Numeric(20, 8))
|
||||
exchange_rate = Column(Numeric(20, 8))
|
||||
commission = Column(Numeric(20, 2))
|
||||
fees = Column(Numeric(20, 2))
|
||||
accrued_interest = Column(Numeric(20, 2))
|
||||
amount = Column(Numeric(20, 2))
|
||||
cash_balance = Column(Numeric(20, 2))
|
||||
|
||||
settlement_date = Column(Date)
|
||||
|
||||
# Deduplication hash
|
||||
unique_hash = Column(String(64), unique=True, nullable=False, index=True)
|
||||
|
||||
# Timestamps
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now(), nullable=False)
|
||||
updated_at = Column(DateTime(timezone=True), onupdate=func.now(), server_default=func.now(), nullable=False)
|
||||
|
||||
# Relationships
|
||||
account = relationship("Account", back_populates="transactions")
|
||||
position_links = relationship("PositionTransaction", back_populates="transaction", cascade="all, delete-orphan")
|
||||
|
||||
# Composite index for common queries
|
||||
__table_args__ = (
|
||||
Index('idx_account_date', 'account_id', 'run_date'),
|
||||
Index('idx_account_symbol', 'account_id', 'symbol'),
|
||||
)
|
||||
5
backend/app/parsers/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""CSV parser modules for various brokerage formats."""
|
||||
from app.parsers.base_parser import BaseParser, ParseResult
|
||||
from app.parsers.fidelity_parser import FidelityParser
|
||||
|
||||
__all__ = ["BaseParser", "ParseResult", "FidelityParser"]
|
||||
99
backend/app/parsers/base_parser.py
Normal file
@@ -0,0 +1,99 @@
"""Base parser interface for brokerage CSV files."""
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Dict, Any, NamedTuple
|
||||
from pathlib import Path
|
||||
import pandas as pd
|
||||
|
||||
|
||||
class ParseResult(NamedTuple):
|
||||
"""
|
||||
Result of parsing a brokerage CSV file.
|
||||
|
||||
Attributes:
|
||||
transactions: List of parsed transaction dictionaries
|
||||
errors: List of error messages encountered during parsing
|
||||
row_count: Total number of rows processed
|
||||
"""
|
||||
transactions: List[Dict[str, Any]]
|
||||
errors: List[str]
|
||||
row_count: int
|
||||
|
||||
|
||||
class BaseParser(ABC):
|
||||
"""
|
||||
Abstract base class for brokerage CSV parsers.
|
||||
|
||||
Provides a standard interface for parsing CSV files from different brokerages.
|
||||
Subclasses must implement the parse() method for their specific format.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def parse(self, file_path: Path) -> ParseResult:
|
||||
"""
|
||||
Parse a brokerage CSV file into standardized transaction dictionaries.
|
||||
|
||||
Args:
|
||||
file_path: Path to the CSV file to parse
|
||||
|
||||
Returns:
|
||||
ParseResult containing transactions, errors, and row count
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file does not exist
|
||||
ValueError: If the file format is invalid
|
||||
"""
|
||||
pass
|
||||
|
||||
def _read_csv(self, file_path: Path, **kwargs) -> pd.DataFrame:
|
||||
"""
|
||||
Read CSV file into a pandas DataFrame with error handling.
|
||||
|
||||
Args:
|
||||
file_path: Path to CSV file
|
||||
**kwargs: Additional arguments passed to pd.read_csv()
|
||||
|
||||
Returns:
|
||||
DataFrame containing CSV data
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file does not exist
|
||||
pd.errors.EmptyDataError: If file is empty
|
||||
"""
|
||||
if not file_path.exists():
|
||||
raise FileNotFoundError(f"CSV file not found: {file_path}")
|
||||
|
||||
return pd.read_csv(file_path, **kwargs)
|
||||
|
||||
@staticmethod
|
||||
def _safe_decimal(value: Any) -> Any:
|
||||
"""
|
||||
Safely convert value to decimal-compatible format, handling NaN and None.
|
||||
|
||||
Args:
|
||||
value: Value to convert
|
||||
|
||||
Returns:
|
||||
Converted value or None if invalid
|
||||
"""
|
||||
if pd.isna(value):
|
||||
return None
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
|
||||
@staticmethod
|
||||
def _safe_date(value: Any) -> Any:
|
||||
"""
|
||||
Safely convert value to date, handling NaN and None.
|
||||
|
||||
Args:
|
||||
value: Value to convert
|
||||
|
||||
Returns:
|
||||
Converted date or None if invalid
|
||||
"""
|
||||
if pd.isna(value):
|
||||
return None
|
||||
if value == "":
|
||||
return None
|
||||
return value
|
||||
257
backend/app/parsers/fidelity_parser.py
Normal file
@@ -0,0 +1,257 @@
"""Fidelity brokerage CSV parser."""
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any
|
||||
import pandas as pd
|
||||
from datetime import datetime
|
||||
import re
|
||||
|
||||
from app.parsers.base_parser import BaseParser, ParseResult
|
||||
|
||||
|
||||
class FidelityParser(BaseParser):
|
||||
"""
|
||||
Parser for Fidelity brokerage account history CSV files.
|
||||
|
||||
Expected CSV columns:
|
||||
- Run Date
|
||||
- Action
|
||||
- Symbol
|
||||
- Description
|
||||
- Type
|
||||
- Exchange Quantity
|
||||
- Exchange Currency
|
||||
- Currency
|
||||
- Price
|
||||
- Quantity
|
||||
- Exchange Rate
|
||||
- Commission
|
||||
- Fees
|
||||
- Accrued Interest
|
||||
- Amount
|
||||
- Cash Balance
|
||||
- Settlement Date
|
||||
"""
|
||||
|
||||
# Expected column names in Fidelity CSV
|
||||
EXPECTED_COLUMNS = [
|
||||
"Run Date",
|
||||
"Action",
|
||||
"Symbol",
|
||||
"Description",
|
||||
"Type",
|
||||
"Exchange Quantity",
|
||||
"Exchange Currency",
|
||||
"Currency",
|
||||
"Price",
|
||||
"Quantity",
|
||||
"Exchange Rate",
|
||||
"Commission",
|
||||
"Fees",
|
||||
"Accrued Interest",
|
||||
"Amount",
|
||||
"Cash Balance",
|
||||
"Settlement Date",
|
||||
]
|
||||
|
||||
def parse(self, file_path: Path) -> ParseResult:
|
||||
"""
|
||||
Parse a Fidelity CSV file into standardized transaction dictionaries.
|
||||
|
||||
Args:
|
||||
file_path: Path to the Fidelity CSV file
|
||||
|
||||
Returns:
|
||||
ParseResult containing parsed transactions, errors, and row count
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the file does not exist
|
||||
ValueError: If the CSV format is invalid
|
||||
"""
|
||||
errors = []
|
||||
transactions = []
|
||||
|
||||
try:
|
||||
# Read CSV, skipping empty rows at the beginning
|
||||
df = self._read_csv(file_path, skiprows=self._find_header_row(file_path))
|
||||
|
||||
# Validate columns
|
||||
missing_cols = set(self.EXPECTED_COLUMNS) - set(df.columns)
|
||||
if missing_cols:
|
||||
raise ValueError(f"Missing required columns: {missing_cols}")
|
||||
|
||||
# Parse each row
|
||||
for idx, row in df.iterrows():
|
||||
try:
|
||||
transaction = self._parse_row(row)
|
||||
if transaction:
|
||||
transactions.append(transaction)
|
||||
except Exception as e:
|
||||
errors.append(f"Row {idx + 1}: {str(e)}")
|
||||
|
||||
return ParseResult(
|
||||
transactions=transactions, errors=errors, row_count=len(df)
|
||||
)
|
||||
|
||||
except FileNotFoundError as e:
|
||||
raise e
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to parse Fidelity CSV: {str(e)}")
|
||||
|
||||
def _find_header_row(self, file_path: Path) -> int:
|
||||
"""
|
||||
Find the row number where the header starts in Fidelity CSV.
|
||||
|
||||
Fidelity CSVs may have empty rows or metadata at the beginning.
|
||||
|
||||
Args:
|
||||
file_path: Path to CSV file
|
||||
|
||||
Returns:
|
||||
Row number (0-indexed) where the header is located
|
||||
"""
|
||||
with open(file_path, "r", encoding="utf-8-sig") as f:
|
||||
for i, line in enumerate(f):
|
||||
if "Run Date" in line:
|
||||
return i
|
||||
return 0 # Default to first row if not found
|
||||
|
||||
def _extract_real_ticker(self, symbol: str, description: str, action: str) -> str:
|
||||
"""
|
||||
Extract the real underlying ticker from option descriptions.
|
||||
|
||||
Fidelity uses internal reference numbers (like 6736999MM) in the Symbol column
|
||||
for options, but the real ticker is in the Description/Action in parentheses.
|
||||
|
||||
Examples:
|
||||
- Description: "CALL (OPEN) OPENDOOR JAN 16 26 (100 SHS)"
|
||||
- Action: "YOU SOLD CLOSING TRANSACTION CALL (OPEN) OPENDOOR..."
|
||||
|
||||
Args:
|
||||
symbol: Symbol from CSV (might be Fidelity internal reference)
|
||||
description: Description field
|
||||
action: Action field
|
||||
|
||||
Returns:
|
||||
Real ticker symbol, or original symbol if not found
|
||||
"""
|
||||
# If symbol looks normal (letters only, not Fidelity's numeric codes), return it
|
||||
if symbol and re.match(r'^[A-Z]{1,5}$', symbol):
|
||||
return symbol
|
||||
|
||||
# Try to extract from description first (more reliable)
|
||||
# Pattern: (TICKER) or CALL (TICKER) or PUT (TICKER)
|
||||
if description:
|
||||
# Look for pattern like "CALL (OPEN)" or "PUT (AAPL)"
|
||||
match = re.search(r'(?:CALL|PUT)\s*\(([A-Z]+)\)', description, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Look for standalone (TICKER) pattern
|
||||
match = re.search(r'\(([A-Z]{1,5})\)', description)
|
||||
if match:
|
||||
ticker = match.group(1)
|
||||
# Make sure it's not something like (100 or (Margin)
|
||||
if not ticker.isdigit() and ticker not in ['MARGIN', 'CASH', 'SHS']:
|
||||
return ticker
|
||||
|
||||
# Fall back to action field
|
||||
if action:
|
||||
match = re.search(r'(?:CALL|PUT)\s*\(([A-Z]+)\)', action, re.IGNORECASE)
|
||||
if match:
|
||||
return match.group(1)
|
||||
|
||||
# Return original symbol if we couldn't extract anything better
|
||||
return symbol if symbol else None
|
||||
|
||||
def _parse_row(self, row: pd.Series) -> Dict[str, Any]:
|
||||
"""
|
||||
Parse a single row from Fidelity CSV into a transaction dictionary.
|
||||
|
||||
Args:
|
||||
row: Pandas Series representing one CSV row
|
||||
|
||||
Returns:
|
||||
Dictionary with transaction data, or None if row should be skipped
|
||||
|
||||
Raises:
|
||||
ValueError: If required fields are missing or invalid
|
||||
"""
|
||||
# Parse dates
|
||||
run_date = self._parse_date(row["Run Date"])
|
||||
settlement_date = self._parse_date(row["Settlement Date"])
|
||||
|
||||
# Extract raw values
|
||||
raw_symbol = self._safe_string(row["Symbol"])
|
||||
description = self._safe_string(row["Description"])
|
||||
action = str(row["Action"]).strip() if pd.notna(row["Action"]) else ""
|
||||
|
||||
# Extract the real ticker (especially important for options)
|
||||
actual_symbol = self._extract_real_ticker(raw_symbol, description, action)
|
||||
|
||||
# Extract and clean values
|
||||
transaction = {
|
||||
"run_date": run_date,
|
||||
"action": action,
|
||||
"symbol": actual_symbol,
|
||||
"description": description,
|
||||
"transaction_type": self._safe_string(row["Type"]),
|
||||
"exchange_quantity": self._safe_decimal(row["Exchange Quantity"]),
|
||||
"exchange_currency": self._safe_string(row["Exchange Currency"]),
|
||||
"currency": self._safe_string(row["Currency"]),
|
||||
"price": self._safe_decimal(row["Price"]),
|
||||
"quantity": self._safe_decimal(row["Quantity"]),
|
||||
"exchange_rate": self._safe_decimal(row["Exchange Rate"]),
|
||||
"commission": self._safe_decimal(row["Commission"]),
|
||||
"fees": self._safe_decimal(row["Fees"]),
|
||||
"accrued_interest": self._safe_decimal(row["Accrued Interest"]),
|
||||
"amount": self._safe_decimal(row["Amount"]),
|
||||
"cash_balance": self._safe_decimal(row["Cash Balance"]),
|
||||
"settlement_date": settlement_date,
|
||||
}
|
||||
|
||||
return transaction
|
||||
|
||||
def _parse_date(self, date_value: Any) -> Any:
|
||||
"""
|
||||
Parse date value from CSV, handling various formats.
|
||||
|
||||
Args:
|
||||
date_value: Date value from CSV (string or datetime)
|
||||
|
||||
Returns:
|
||||
datetime.date object or None if empty/invalid
|
||||
"""
|
||||
if pd.isna(date_value) or date_value == "":
|
||||
return None
|
||||
|
||||
# If already a datetime object
|
||||
if isinstance(date_value, datetime):
|
||||
return date_value.date()
|
||||
|
||||
# Try parsing common date formats
|
||||
date_str = str(date_value).strip()
|
||||
if not date_str:
|
||||
return None
|
||||
|
||||
# Try common formats
|
||||
for fmt in ["%m/%d/%Y", "%Y-%m-%d", "%m-%d-%Y"]:
|
||||
try:
|
||||
return datetime.strptime(date_str, fmt).date()
|
||||
except ValueError:
|
||||
continue
|
||||
|
||||
return None
|
||||
|
||||
def _safe_string(self, value: Any) -> str:
|
||||
"""
|
||||
Safely convert value to string, handling NaN and empty values.
|
||||
|
||||
Args:
|
||||
value: Value to convert
|
||||
|
||||
Returns:
|
||||
String value or None if empty
|
||||
"""
|
||||
if pd.isna(value) or value == "":
|
||||
return None
|
||||
return str(value).strip()
|
||||
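
# Example of the ticker extraction above, using the values from the
# _extract_real_ticker docstring (here "OPEN" is Opendoor's ticker, not an
# open/close marker):
#   FidelityParser()._extract_real_ticker(
#       "6736999MM",
#       "CALL (OPEN) OPENDOOR JAN 16 26 (100 SHS)",
#       "YOU SOLD CLOSING TRANSACTION CALL (OPEN) OPENDOOR JAN 16 26 (100 SHS)",
#   )  # -> "OPEN"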
14
backend/app/schemas/__init__.py
Normal file
@@ -0,0 +1,14 @@
"""Pydantic schemas for API request/response validation."""
|
||||
from app.schemas.account import AccountCreate, AccountUpdate, AccountResponse
|
||||
from app.schemas.transaction import TransactionCreate, TransactionResponse
|
||||
from app.schemas.position import PositionResponse, PositionStats
|
||||
|
||||
__all__ = [
|
||||
"AccountCreate",
|
||||
"AccountUpdate",
|
||||
"AccountResponse",
|
||||
"TransactionCreate",
|
||||
"TransactionResponse",
|
||||
"PositionResponse",
|
||||
"PositionStats",
|
||||
]
|
||||
34
backend/app/schemas/account.py
Normal file
@@ -0,0 +1,34 @@
"""Pydantic schemas for account-related API operations."""
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from app.models.account import AccountType
|
||||
|
||||
|
||||
class AccountBase(BaseModel):
|
||||
"""Base schema for account data."""
|
||||
account_number: str = Field(..., description="Unique account identifier")
|
||||
account_name: str = Field(..., description="Human-readable account name")
|
||||
account_type: AccountType = Field(default=AccountType.CASH, description="Account type")
|
||||
|
||||
|
||||
class AccountCreate(AccountBase):
|
||||
"""Schema for creating a new account."""
|
||||
pass
|
||||
|
||||
|
||||
class AccountUpdate(BaseModel):
|
||||
"""Schema for updating an existing account."""
|
||||
account_name: Optional[str] = Field(None, description="Updated account name")
|
||||
account_type: Optional[AccountType] = Field(None, description="Updated account type")
|
||||
|
||||
|
||||
class AccountResponse(AccountBase):
|
||||
"""Schema for account API responses."""
|
||||
id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
45
backend/app/schemas/position.py
Normal file
@@ -0,0 +1,45 @@
"""Pydantic schemas for position-related API operations."""
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import date, datetime
|
||||
from typing import Optional
|
||||
from decimal import Decimal
|
||||
|
||||
from app.models.position import PositionType, PositionStatus
|
||||
|
||||
|
||||
class PositionBase(BaseModel):
|
||||
"""Base schema for position data."""
|
||||
symbol: str
|
||||
option_symbol: Optional[str] = None
|
||||
position_type: PositionType
|
||||
status: PositionStatus
|
||||
open_date: date
|
||||
close_date: Optional[date] = None
|
||||
total_quantity: Decimal
|
||||
avg_entry_price: Optional[Decimal] = None
|
||||
avg_exit_price: Optional[Decimal] = None
|
||||
realized_pnl: Optional[Decimal] = None
|
||||
unrealized_pnl: Optional[Decimal] = None
|
||||
|
||||
|
||||
class PositionResponse(PositionBase):
|
||||
"""Schema for position API responses."""
|
||||
id: int
|
||||
account_id: int
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
|
||||
|
||||
class PositionStats(BaseModel):
|
||||
"""Schema for aggregate position statistics."""
|
||||
total_positions: int = Field(..., description="Total number of positions")
|
||||
open_positions: int = Field(..., description="Number of open positions")
|
||||
closed_positions: int = Field(..., description="Number of closed positions")
|
||||
total_realized_pnl: Decimal = Field(..., description="Total realized P&L")
|
||||
total_unrealized_pnl: Decimal = Field(..., description="Total unrealized P&L")
|
||||
win_rate: float = Field(..., description="Percentage of profitable trades")
|
||||
avg_win: Decimal = Field(..., description="Average profit on winning trades")
|
||||
avg_loss: Decimal = Field(..., description="Average loss on losing trades")
|
||||
44
backend/app/schemas/transaction.py
Normal file
@@ -0,0 +1,44 @@
"""Pydantic schemas for transaction-related API operations."""
|
||||
from pydantic import BaseModel, Field
|
||||
from datetime import date, datetime
|
||||
from typing import Optional
|
||||
from decimal import Decimal
|
||||
|
||||
|
||||
class TransactionBase(BaseModel):
|
||||
"""Base schema for transaction data."""
|
||||
run_date: date
|
||||
action: str
|
||||
symbol: Optional[str] = None
|
||||
description: Optional[str] = None
|
||||
transaction_type: Optional[str] = None
|
||||
exchange_quantity: Optional[Decimal] = None
|
||||
exchange_currency: Optional[str] = None
|
||||
currency: Optional[str] = None
|
||||
price: Optional[Decimal] = None
|
||||
quantity: Optional[Decimal] = None
|
||||
exchange_rate: Optional[Decimal] = None
|
||||
commission: Optional[Decimal] = None
|
||||
fees: Optional[Decimal] = None
|
||||
accrued_interest: Optional[Decimal] = None
|
||||
amount: Optional[Decimal] = None
|
||||
cash_balance: Optional[Decimal] = None
|
||||
settlement_date: Optional[date] = None
|
||||
|
||||
|
||||
class TransactionCreate(TransactionBase):
|
||||
"""Schema for creating a new transaction."""
|
||||
account_id: int
|
||||
unique_hash: str
|
||||
|
||||
|
||||
class TransactionResponse(TransactionBase):
|
||||
"""Schema for transaction API responses."""
|
||||
id: int
|
||||
account_id: int
|
||||
unique_hash: str
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
class Config:
|
||||
from_attributes = True
|
||||
6
backend/app/services/__init__.py
Normal file
@@ -0,0 +1,6 @@
"""Business logic services."""
|
||||
from app.services.import_service import ImportService, ImportResult
|
||||
from app.services.position_tracker import PositionTracker
|
||||
from app.services.performance_calculator import PerformanceCalculator
|
||||
|
||||
__all__ = ["ImportService", "ImportResult", "PositionTracker", "PerformanceCalculator"]
|
||||
149
backend/app/services/import_service.py
Normal file
@@ -0,0 +1,149 @@
"""Service for importing transactions from CSV files."""
|
||||
from pathlib import Path
|
||||
from typing import List, Dict, Any, NamedTuple
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from app.parsers import FidelityParser
|
||||
from app.models import Transaction
|
||||
from app.utils import generate_transaction_hash
|
||||
|
||||
|
||||
class ImportResult(NamedTuple):
|
||||
"""
|
||||
Result of an import operation.
|
||||
|
||||
Attributes:
|
||||
imported: Number of successfully imported transactions
|
||||
skipped: Number of skipped duplicate transactions
|
||||
errors: List of error messages
|
||||
total_rows: Total number of rows processed
|
||||
"""
|
||||
imported: int
|
||||
skipped: int
|
||||
errors: List[str]
|
||||
total_rows: int
|
||||
|
||||
|
||||
class ImportService:
|
||||
"""
|
||||
Service for importing transactions from brokerage CSV files.
|
||||
|
||||
Handles parsing, deduplication, and database insertion.
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session):
|
||||
"""
|
||||
Initialize import service.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
"""
|
||||
self.db = db
|
||||
self.parser = FidelityParser() # Can be extended to support multiple parsers
|
||||
|
||||
def import_from_file(self, file_path: Path, account_id: int) -> ImportResult:
|
||||
"""
|
||||
Import transactions from a CSV file.
|
||||
|
||||
Args:
|
||||
file_path: Path to CSV file
|
||||
account_id: ID of the account to import transactions for
|
||||
|
||||
Returns:
|
||||
ImportResult with statistics
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If file doesn't exist
|
||||
ValueError: If file format is invalid
|
||||
"""
|
||||
# Parse CSV file
|
||||
parse_result = self.parser.parse(file_path)
|
||||
|
||||
imported = 0
|
||||
skipped = 0
|
||||
errors = list(parse_result.errors)
|
||||
|
||||
# Process each transaction
|
||||
for txn_data in parse_result.transactions:
|
||||
try:
|
||||
# Generate deduplication hash
|
||||
unique_hash = generate_transaction_hash(
|
||||
account_id=account_id,
|
||||
run_date=txn_data["run_date"],
|
||||
symbol=txn_data.get("symbol"),
|
||||
action=txn_data["action"],
|
||||
amount=txn_data.get("amount"),
|
||||
quantity=txn_data.get("quantity"),
|
||||
price=txn_data.get("price"),
|
||||
)
|
||||
|
||||
# Check if transaction already exists
|
||||
existing = (
|
||||
self.db.query(Transaction)
|
||||
.filter(Transaction.unique_hash == unique_hash)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
skipped += 1
|
||||
continue
|
||||
|
||||
# Create new transaction
|
||||
transaction = Transaction(
|
||||
account_id=account_id,
|
||||
unique_hash=unique_hash,
|
||||
**txn_data
|
||||
)
|
||||
|
||||
self.db.add(transaction)
|
||||
self.db.commit()
|
||||
imported += 1
|
||||
|
||||
except IntegrityError:
|
||||
# Duplicate hash (edge case if concurrent imports)
|
||||
self.db.rollback()
|
||||
skipped += 1
|
||||
except Exception as e:
|
||||
self.db.rollback()
|
||||
errors.append(f"Failed to import transaction: {str(e)}")
|
||||
|
||||
return ImportResult(
|
||||
imported=imported,
|
||||
skipped=skipped,
|
||||
errors=errors,
|
||||
total_rows=parse_result.row_count,
|
||||
)
|
||||
|
||||
def import_from_directory(
|
||||
self, directory: Path, account_id: int, pattern: str = "*.csv"
|
||||
) -> Dict[str, ImportResult]:
|
||||
"""
|
||||
Import transactions from all CSV files in a directory.
|
||||
|
||||
Args:
|
||||
directory: Path to directory containing CSV files
|
||||
account_id: ID of the account to import transactions for
|
||||
pattern: Glob pattern for matching files (default: *.csv)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping filename to ImportResult
|
||||
"""
|
||||
if not directory.exists() or not directory.is_dir():
|
||||
raise ValueError(f"Invalid directory: {directory}")
|
||||
|
||||
results = {}
|
||||
|
||||
for file_path in directory.glob(pattern):
|
||||
try:
|
||||
result = self.import_from_file(file_path, account_id)
|
||||
results[file_path.name] = result
|
||||
except Exception as e:
|
||||
results[file_path.name] = ImportResult(
|
||||
imported=0,
|
||||
skipped=0,
|
||||
errors=[str(e)],
|
||||
total_rows=0,
|
||||
)
|
||||
|
||||
return results
|
||||
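
# Note: generate_transaction_hash is imported from app.utils, which is not
# part of this excerpt. A minimal sketch consistent with the call site above
# (hypothetical; the field order and separator are assumptions):
#   import hashlib
#
#   def generate_transaction_hash(account_id, run_date, symbol, action,
#                                 amount=None, quantity=None, price=None) -> str:
#       parts = [str(account_id), str(run_date), str(symbol or ""), str(action),
#                str(amount), str(quantity), str(price)]
#       # 64-character hex digest, matching the String(64) unique_hash column
#       return hashlib.sha256("|".join(parts).encode("utf-8")).hexdigest()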
330
backend/app/services/market_data_service.py
Normal file
@@ -0,0 +1,330 @@
"""
|
||||
Market data service with rate limiting, caching, and batch processing.
|
||||
|
||||
This service handles fetching market prices from Yahoo Finance with:
|
||||
- Database-backed caching to survive restarts
|
||||
- Rate limiting with exponential backoff
|
||||
- Batch processing to reduce API calls
|
||||
- Stale-while-revalidate pattern for better UX
|
||||
"""
|
||||
import time
|
||||
import yfinance as yf
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from typing import Dict, List, Optional
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
from app.models.market_price import MarketPrice
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class MarketDataService:
|
||||
"""Service for fetching and caching market prices with rate limiting."""
|
||||
|
||||
def __init__(self, db: Session, cache_ttl_seconds: int = 300):
|
||||
"""
|
||||
Initialize market data service.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
cache_ttl_seconds: How long cached prices are considered fresh (default: 5 minutes)
|
||||
"""
|
||||
self.db = db
|
||||
self.cache_ttl = cache_ttl_seconds
|
||||
self._rate_limit_delay = 0.5 # Start with 500ms between requests
|
||||
self._last_request_time = 0.0
|
||||
self._consecutive_errors = 0
|
||||
self._max_retries = 3
|
||||
|
||||
@staticmethod
|
||||
def _is_valid_stock_symbol(symbol: str) -> bool:
|
||||
"""
|
||||
Check if a symbol is a valid stock ticker (not an option symbol or CUSIP).
|
||||
|
||||
Args:
|
||||
symbol: Symbol to check
|
||||
|
||||
Returns:
|
||||
True if it looks like a valid stock ticker
|
||||
"""
|
||||
if not symbol or len(symbol) > 5:
|
||||
return False
|
||||
|
||||
# Stock symbols should start with a letter, not a number
|
||||
# Numbers indicate CUSIP codes or option symbols
|
||||
if symbol[0].isdigit():
|
||||
return False
|
||||
|
||||
# Should be mostly uppercase letters
|
||||
# Allow $ for preferred shares (e.g., BRK.B becomes BRK-B)
|
||||
return symbol.replace('-', '').replace('.', '').isalpha()
|
||||
|
||||
def get_price(self, symbol: str, allow_stale: bool = True) -> Optional[Decimal]:
|
||||
"""
|
||||
Get current price for a symbol with caching.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
allow_stale: If True, return stale cache data instead of None
|
||||
|
||||
Returns:
|
||||
Price or None if unavailable
|
||||
"""
|
||||
# Skip invalid symbols (option symbols, CUSIPs, etc.)
|
||||
if not self._is_valid_stock_symbol(symbol):
|
||||
logger.debug(f"Skipping invalid symbol: {symbol} (not a stock ticker)")
|
||||
return None
|
||||
|
||||
# Check database cache first
|
||||
cached = self._get_cached_price(symbol)
|
||||
|
||||
if cached:
|
||||
price, age_seconds = cached
|
||||
if age_seconds < self.cache_ttl:
|
||||
# Fresh cache hit
|
||||
logger.debug(f"Cache HIT (fresh): {symbol} = ${price} (age: {age_seconds}s)")
|
||||
return price
|
||||
elif allow_stale:
|
||||
# Stale cache hit, but we'll return it
|
||||
logger.debug(f"Cache HIT (stale): {symbol} = ${price} (age: {age_seconds}s)")
|
||||
return price
|
||||
|
||||
# Cache miss or expired - fetch from Yahoo Finance
|
||||
logger.info(f"Cache MISS: {symbol}, fetching from Yahoo Finance...")
|
||||
fresh_price = self._fetch_from_yahoo(symbol)
|
||||
|
||||
if fresh_price is not None:
|
||||
self._update_cache(symbol, fresh_price)
|
||||
return fresh_price
|
||||
|
||||
# If fetch failed and we have stale data, return it
|
||||
if cached and allow_stale:
|
||||
price, age_seconds = cached
|
||||
logger.warning(f"Yahoo fetch failed, using stale cache: {symbol} = ${price} (age: {age_seconds}s)")
|
||||
return price
|
||||
|
||||
return None
|
||||
|
||||
def get_prices_batch(
|
||||
self,
|
||||
symbols: List[str],
|
||||
allow_stale: bool = True,
|
||||
max_fetches: int = 10
|
||||
) -> Dict[str, Optional[Decimal]]:
|
||||
"""
|
||||
Get prices for multiple symbols with rate limiting.
|
||||
|
||||
Args:
|
||||
symbols: List of ticker symbols
|
||||
allow_stale: Return stale cache data if available
|
||||
max_fetches: Maximum number of API calls to make (remaining use cache)
|
||||
|
||||
Returns:
|
||||
Dictionary mapping symbol to price (or None if unavailable)
|
||||
"""
|
||||
results = {}
|
||||
symbols_to_fetch = []
|
||||
|
||||
# First pass: Check cache for all symbols
|
||||
for symbol in symbols:
|
||||
# Skip invalid symbols
|
||||
if not self._is_valid_stock_symbol(symbol):
|
||||
logger.debug(f"Skipping invalid symbol in batch: {symbol}")
|
||||
results[symbol] = None
|
||||
continue
|
||||
cached = self._get_cached_price(symbol)
|
||||
|
||||
if cached:
|
||||
price, age_seconds = cached
|
||||
if age_seconds < self.cache_ttl:
|
||||
# Fresh cache - use it
|
||||
results[symbol] = price
|
||||
elif allow_stale:
|
||||
# Stale but usable
|
||||
results[symbol] = price
|
||||
if age_seconds < self.cache_ttl * 2: # Not TOO stale
|
||||
symbols_to_fetch.append(symbol)
|
||||
else:
|
||||
# Stale and not allowing stale - need to fetch
|
||||
symbols_to_fetch.append(symbol)
|
||||
else:
|
||||
# No cache at all
|
||||
symbols_to_fetch.append(symbol)
|
||||
|
||||
# Second pass: Fetch missing/stale symbols (with limit)
|
||||
if symbols_to_fetch:
|
||||
logger.info(f"Batch fetching {len(symbols_to_fetch)} symbols (max: {max_fetches})")
|
||||
|
||||
for i, symbol in enumerate(symbols_to_fetch[:max_fetches]):
|
||||
if i > 0:
|
||||
# Rate limiting delay
|
||||
time.sleep(self._rate_limit_delay)
|
||||
|
||||
price = self._fetch_from_yahoo(symbol)
|
||||
if price is not None:
|
||||
results[symbol] = price
|
||||
self._update_cache(symbol, price)
|
||||
elif symbol not in results:
|
||||
# No cached value and fetch failed
|
||||
results[symbol] = None
|
||||
|
||||
return results
|
||||
|
||||
def refresh_stale_prices(self, min_age_seconds: int = 300, limit: int = 20) -> int:
|
||||
"""
|
||||
Background task to refresh stale prices.
|
||||
|
||||
Args:
|
||||
min_age_seconds: Only refresh prices older than this
|
||||
limit: Maximum number of prices to refresh
|
||||
|
||||
Returns:
|
||||
Number of prices refreshed
|
||||
"""
|
||||
cutoff_time = datetime.utcnow() - timedelta(seconds=min_age_seconds)
|
||||
|
||||
# Get stale prices ordered by oldest first
|
||||
stale_prices = (
|
||||
self.db.query(MarketPrice)
|
||||
.filter(MarketPrice.fetched_at < cutoff_time)
|
||||
.order_by(MarketPrice.fetched_at.asc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
refreshed = 0
|
||||
for cached_price in stale_prices:
|
||||
time.sleep(self._rate_limit_delay)
|
||||
|
||||
fresh_price = self._fetch_from_yahoo(cached_price.symbol)
|
||||
if fresh_price is not None:
|
||||
self._update_cache(cached_price.symbol, fresh_price)
|
||||
refreshed += 1
|
||||
|
||||
logger.info(f"Refreshed {refreshed}/{len(stale_prices)} stale prices")
|
||||
return refreshed
|
||||
|
||||
def _get_cached_price(self, symbol: str) -> Optional[tuple[Decimal, float]]:
|
||||
"""
|
||||
Get cached price from database.
|
||||
|
||||
Returns:
|
||||
Tuple of (price, age_in_seconds) or None if not cached
|
||||
"""
|
||||
cached = (
|
||||
self.db.query(MarketPrice)
|
||||
.filter(MarketPrice.symbol == symbol)
|
||||
.first()
|
||||
)
|
||||
|
||||
if cached:
|
||||
age = (datetime.utcnow() - cached.fetched_at).total_seconds()
|
||||
return (cached.price, age)
|
||||
|
||||
return None
|
||||
|
||||
def _update_cache(self, symbol: str, price: Decimal) -> None:
|
||||
"""Update or insert price in database cache."""
|
||||
cached = (
|
||||
self.db.query(MarketPrice)
|
||||
.filter(MarketPrice.symbol == symbol)
|
||||
.first()
|
||||
)
|
||||
|
||||
if cached:
|
||||
cached.price = price
|
||||
cached.fetched_at = datetime.utcnow()
|
||||
else:
|
||||
new_price = MarketPrice(
|
||||
symbol=symbol,
|
||||
price=price,
|
||||
fetched_at=datetime.utcnow()
|
||||
)
|
||||
self.db.add(new_price)
|
||||
|
||||
self.db.commit()
|
||||
|
||||
def _fetch_from_yahoo(self, symbol: str) -> Optional[Decimal]:
|
||||
"""
|
||||
Fetch price from Yahoo Finance with rate limiting and retries.
|
||||
|
||||
Returns:
|
||||
Price or None if fetch failed
|
||||
"""
|
||||
for attempt in range(self._max_retries):
|
||||
try:
|
||||
# Rate limiting
|
||||
elapsed = time.time() - self._last_request_time
|
||||
if elapsed < self._rate_limit_delay:
|
||||
time.sleep(self._rate_limit_delay - elapsed)
|
||||
|
||||
self._last_request_time = time.time()
|
||||
|
||||
# Fetch from Yahoo
|
||||
ticker = yf.Ticker(symbol)
|
||||
info = ticker.info
|
||||
|
||||
# Try different price fields
|
||||
for field in ["currentPrice", "regularMarketPrice", "previousClose"]:
|
||||
if field in info and info[field]:
|
||||
price = Decimal(str(info[field]))
|
||||
|
||||
# Success - reset error tracking
|
||||
self._consecutive_errors = 0
|
||||
self._rate_limit_delay = max(0.5, self._rate_limit_delay * 0.9) # Gradually decrease delay
|
||||
|
||||
logger.debug(f"Fetched {symbol} = ${price}")
|
||||
return price
|
||||
|
||||
# No price found in response
|
||||
logger.warning(f"No price data in Yahoo response for {symbol}")
|
||||
return None
|
||||
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
|
||||
if "429" in error_str or "too many requests" in error_str:
|
||||
# Rate limit hit - back off exponentially
|
||||
self._consecutive_errors += 1
|
||||
self._rate_limit_delay = min(10.0, self._rate_limit_delay * 2) # Double delay, max 10s
|
||||
|
||||
logger.warning(
|
||||
f"Rate limit hit for {symbol} (attempt {attempt + 1}/{self._max_retries}), "
|
||||
f"backing off to {self._rate_limit_delay}s delay"
|
||||
)
|
||||
|
||||
if attempt < self._max_retries - 1:
|
||||
time.sleep(self._rate_limit_delay * (attempt + 1)) # Longer wait for retries
|
||||
continue
|
||||
else:
|
||||
# Other error
|
||||
logger.error(f"Error fetching {symbol}: {e}")
|
||||
return None
|
||||
|
||||
logger.error(f"Failed to fetch {symbol} after {self._max_retries} attempts")
|
||||
return None
|
||||
|
||||
def clear_cache(self, older_than_days: int = 30) -> int:
|
||||
"""
|
||||
Clear old cached prices.
|
||||
|
||||
Args:
|
||||
older_than_days: Delete prices older than this many days
|
||||
|
||||
Returns:
|
||||
Number of records deleted
|
||||
"""
|
||||
cutoff = datetime.utcnow() - timedelta(days=older_than_days)
|
||||
|
||||
deleted = (
|
||||
self.db.query(MarketPrice)
|
||||
.filter(MarketPrice.fetched_at < cutoff)
|
||||
.delete()
|
||||
)
|
||||
|
||||
self.db.commit()
|
||||
logger.info(f"Cleared {deleted} cached prices older than {older_than_days} days")
|
||||
return deleted
|
||||
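
# Usage sketch: batch-fetch with a cap on live API calls; symbols beyond the
# cap fall back to (possibly stale) cached values:
#   service = MarketDataService(db, cache_ttl_seconds=300)
#   prices = service.get_prices_batch(["AAPL", "MSFT", "TSLA"], max_fetches=5)
#   price = service.get_price("AAPL")  # None for invalid symbols or failed fetches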
364
backend/app/services/performance_calculator.py
Normal file
@@ -0,0 +1,364 @@
"""Service for calculating performance metrics and unrealized P&L."""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, func
|
||||
from typing import Dict, Optional
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
import yfinance as yf
|
||||
from functools import lru_cache
|
||||
|
||||
from app.models import Position, Transaction
|
||||
from app.models.position import PositionStatus
|
||||
|
||||
|
||||
class PerformanceCalculator:
|
||||
"""
|
||||
Service for calculating performance metrics and market data.
|
||||
|
||||
Integrates with Yahoo Finance API for real-time pricing of open positions.
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, cache_ttl: int = 60):
|
||||
"""
|
||||
Initialize performance calculator.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
cache_ttl: Cache time-to-live in seconds (default: 60)
|
||||
"""
|
||||
self.db = db
|
||||
self.cache_ttl = cache_ttl
|
||||
self._price_cache: Dict[str, tuple[Decimal, datetime]] = {}
|
||||
|
||||
def calculate_unrealized_pnl(self, position: Position) -> Optional[Decimal]:
|
||||
"""
|
||||
Calculate unrealized P&L for an open position.
|
||||
|
||||
Args:
|
||||
position: Open position to calculate P&L for
|
||||
|
||||
Returns:
|
||||
Unrealized P&L or None if market data unavailable
|
||||
"""
|
||||
if position.status != PositionStatus.OPEN:
|
||||
return None
|
||||
|
||||
# Get current market price
|
||||
current_price = self.get_current_price(position.symbol)
|
||||
if current_price is None:
|
||||
return None
|
||||
|
||||
if position.avg_entry_price is None:
|
||||
return None
|
||||
|
||||
# Calculate P&L based on position direction
|
||||
quantity = abs(position.total_quantity)
|
||||
is_short = position.total_quantity < 0
|
||||
|
||||
if is_short:
|
||||
# Short position: profit when price decreases
|
||||
pnl = (position.avg_entry_price - current_price) * quantity * 100
|
||||
else:
|
||||
# Long position: profit when price increases
|
||||
pnl = (current_price - position.avg_entry_price) * quantity * 100
|
||||
|
||||
# Subtract fees and commissions from opening transactions
|
||||
total_fees = Decimal("0")
|
||||
for link in position.transaction_links:
|
||||
txn = link.transaction
|
||||
if txn.commission:
|
||||
total_fees += txn.commission
|
||||
if txn.fees:
|
||||
total_fees += txn.fees
|
||||
|
||||
pnl -= total_fees
|
||||
|
||||
return pnl
|
||||
|
||||
def update_open_positions_pnl(self, account_id: int) -> int:
|
||||
"""
|
||||
Update unrealized P&L for all open positions in an account.
|
||||
|
||||
Args:
|
||||
account_id: Account ID to update
|
||||
|
||||
Returns:
|
||||
Number of positions updated
|
||||
"""
|
||||
open_positions = (
|
||||
self.db.query(Position)
|
||||
.filter(
|
||||
and_(
|
||||
Position.account_id == account_id,
|
||||
Position.status == PositionStatus.OPEN,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
updated = 0
|
||||
for position in open_positions:
|
||||
unrealized_pnl = self.calculate_unrealized_pnl(position)
|
||||
if unrealized_pnl is not None:
|
||||
position.unrealized_pnl = unrealized_pnl
|
||||
updated += 1
|
||||
|
||||
self.db.commit()
|
||||
return updated
|
||||
|
||||
def get_current_price(self, symbol: str) -> Optional[Decimal]:
|
||||
"""
|
||||
Get current market price for a symbol.
|
||||
|
||||
Uses Yahoo Finance API with caching to reduce API calls.
|
||||
|
||||
Args:
|
||||
symbol: Stock ticker symbol
|
||||
|
||||
Returns:
|
||||
Current price or None if unavailable
|
||||
"""
|
||||
# Check cache
|
||||
if symbol in self._price_cache:
|
||||
price, timestamp = self._price_cache[symbol]
|
||||
if datetime.now() - timestamp < timedelta(seconds=self.cache_ttl):
|
||||
return price
|
||||
|
||||
# Fetch from Yahoo Finance
|
||||
try:
|
||||
ticker = yf.Ticker(symbol)
|
||||
info = ticker.info
|
||||
|
||||
# Try different price fields
|
||||
current_price = None
|
||||
for field in ["currentPrice", "regularMarketPrice", "previousClose"]:
|
||||
if field in info and info[field]:
|
||||
current_price = Decimal(str(info[field]))
|
||||
break
|
||||
|
||||
if current_price is not None:
|
||||
# Cache the price
|
||||
self._price_cache[symbol] = (current_price, datetime.now())
|
||||
return current_price
|
||||
|
||||
except Exception:
|
||||
# Failed to fetch price
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
def calculate_account_stats(self, account_id: int) -> Dict:
|
||||
"""
|
||||
Calculate aggregate statistics for an account.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
|
||||
Returns:
|
||||
Dictionary with performance metrics
|
||||
"""
|
||||
# Get all positions
|
||||
positions = (
|
||||
self.db.query(Position)
|
||||
.filter(Position.account_id == account_id)
|
||||
.all()
|
||||
)
|
||||
|
||||
total_positions = len(positions)
|
||||
open_positions_count = sum(
|
||||
1 for p in positions if p.status == PositionStatus.OPEN
|
||||
)
|
||||
closed_positions_count = sum(
|
||||
1 for p in positions if p.status == PositionStatus.CLOSED
|
||||
)
|
||||
|
||||
# Calculate P&L
|
||||
total_realized_pnl = sum(
|
||||
(p.realized_pnl or Decimal("0"))
|
||||
for p in positions
|
||||
if p.status == PositionStatus.CLOSED
|
||||
)
|
||||
|
||||
# Update unrealized P&L for open positions
|
||||
self.update_open_positions_pnl(account_id)
|
||||
|
||||
total_unrealized_pnl = sum(
|
||||
(p.unrealized_pnl or Decimal("0"))
|
||||
for p in positions
|
||||
if p.status == PositionStatus.OPEN
|
||||
)
|
||||
|
||||
# Calculate win rate and average win/loss
|
||||
closed_with_pnl = [
|
||||
p for p in positions
|
||||
if p.status == PositionStatus.CLOSED and p.realized_pnl is not None
|
||||
]
|
||||
|
||||
if closed_with_pnl:
|
||||
winning_trades = [p for p in closed_with_pnl if p.realized_pnl > 0]
|
||||
losing_trades = [p for p in closed_with_pnl if p.realized_pnl < 0]
|
||||
|
||||
win_rate = (len(winning_trades) / len(closed_with_pnl)) * 100
|
||||
|
||||
avg_win = (
|
||||
sum(p.realized_pnl for p in winning_trades) / len(winning_trades)
|
||||
if winning_trades
|
||||
else Decimal("0")
|
||||
)
|
||||
|
||||
avg_loss = (
|
||||
sum(p.realized_pnl for p in losing_trades) / len(losing_trades)
|
||||
if losing_trades
|
||||
else Decimal("0")
|
||||
)
|
||||
else:
|
||||
win_rate = 0.0
|
||||
avg_win = Decimal("0")
|
||||
avg_loss = Decimal("0")
|
||||
|
||||
# Get current account balance from latest transaction
|
||||
latest_txn = (
|
||||
self.db.query(Transaction)
|
||||
.filter(Transaction.account_id == account_id)
|
||||
.order_by(Transaction.run_date.desc(), Transaction.id.desc())
|
||||
.first()
|
||||
)
|
||||
|
||||
current_balance = (
|
||||
latest_txn.cash_balance if latest_txn and latest_txn.cash_balance else Decimal("0")
|
||||
)
|
||||
|
||||
return {
|
||||
"total_positions": total_positions,
|
||||
"open_positions": open_positions_count,
|
||||
"closed_positions": closed_positions_count,
|
||||
"total_realized_pnl": float(total_realized_pnl),
|
||||
"total_unrealized_pnl": float(total_unrealized_pnl),
|
||||
"total_pnl": float(total_realized_pnl + total_unrealized_pnl),
|
||||
"win_rate": float(win_rate),
|
||||
"avg_win": float(avg_win),
|
||||
"avg_loss": float(avg_loss),
|
||||
"current_balance": float(current_balance),
|
||||
}
|
||||
|
||||
def get_balance_history(
|
||||
self, account_id: int, days: int = 30
|
||||
) -> list[Dict]:
|
||||
"""
|
||||
Get account balance history for charting.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
days: Number of days to retrieve
|
||||
|
||||
Returns:
|
||||
List of {date, balance} dictionaries
|
||||
"""
|
||||
cutoff_date = datetime.now().date() - timedelta(days=days)
|
||||
|
||||
transactions = (
|
||||
self.db.query(Transaction.run_date, Transaction.cash_balance)
|
||||
.filter(
|
||||
and_(
|
||||
Transaction.account_id == account_id,
|
||||
Transaction.run_date >= cutoff_date,
|
||||
Transaction.cash_balance.isnot(None),
|
||||
)
|
||||
)
|
||||
.order_by(Transaction.run_date)
|
||||
.all()
|
||||
)
|
||||
|
||||
# Get one balance per day (use last transaction of the day)
|
||||
daily_balances = {}
|
||||
for txn in transactions:
|
||||
daily_balances[txn.run_date] = float(txn.cash_balance)
|
||||
|
||||
return [
|
||||
{"date": date.isoformat(), "balance": balance}
|
||||
for date, balance in sorted(daily_balances.items())
|
||||
]
|
||||
|
||||
def get_top_trades(
|
||||
self, account_id: int, limit: int = 10
|
||||
) -> list[Dict]:
|
||||
"""
|
||||
Get top performing trades (by realized P&L).
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
limit: Maximum number of trades to return
|
||||
|
||||
Returns:
|
||||
List of trade dictionaries
|
||||
"""
|
||||
positions = (
|
||||
self.db.query(Position)
|
||||
.filter(
|
||||
and_(
|
||||
Position.account_id == account_id,
|
||||
Position.status == PositionStatus.CLOSED,
|
||||
Position.realized_pnl.isnot(None),
|
||||
)
|
||||
)
|
||||
.order_by(Position.realized_pnl.desc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"symbol": p.symbol,
|
||||
"option_symbol": p.option_symbol,
|
||||
"position_type": p.position_type.value,
|
||||
"open_date": p.open_date.isoformat(),
|
||||
"close_date": p.close_date.isoformat() if p.close_date else None,
|
||||
"quantity": float(p.total_quantity),
|
||||
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
|
||||
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
|
||||
"realized_pnl": float(p.realized_pnl),
|
||||
}
|
||||
for p in positions
|
||||
]
|
||||
|
||||
def get_worst_trades(
|
||||
self, account_id: int, limit: int = 20
|
||||
) -> list[Dict]:
|
||||
"""
|
||||
Get worst performing trades (biggest losses by realized P&L).
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
limit: Maximum number of trades to return
|
||||
|
||||
Returns:
|
||||
List of trade dictionaries
|
||||
"""
|
||||
positions = (
|
||||
self.db.query(Position)
|
||||
.filter(
|
||||
and_(
|
||||
Position.account_id == account_id,
|
||||
Position.status == PositionStatus.CLOSED,
|
||||
Position.realized_pnl.isnot(None),
|
||||
)
|
||||
)
|
||||
.order_by(Position.realized_pnl.asc())
|
||||
.limit(limit)
|
||||
.all()
|
||||
)
|
||||
|
||||
return [
|
||||
{
|
||||
"symbol": p.symbol,
|
||||
"option_symbol": p.option_symbol,
|
||||
"position_type": p.position_type.value,
|
||||
"open_date": p.open_date.isoformat(),
|
||||
"close_date": p.close_date.isoformat() if p.close_date else None,
|
||||
"quantity": float(p.total_quantity),
|
||||
"entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
|
||||
"exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
|
||||
"realized_pnl": float(p.realized_pnl),
|
||||
}
|
||||
for p in positions
|
||||
]
|
||||
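
# Worked example for calculate_unrealized_pnl: a long position of 2 contracts
# entered at 1.50 and marked at 2.10, with 1.30 in total fees and commissions:
#   pnl = (2.10 - 1.50) * 2 * 100 - 1.30 = 118.70
# The fixed 100x factor is the option contract multiplier; as written it is
# applied to stock positions as well.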
433
backend/app/services/performance_calculator_v2.py
Normal file
@@ -0,0 +1,433 @@
"""
|
||||
Improved performance calculator with rate-limited market data fetching.
|
||||
|
||||
This version uses the MarketDataService for efficient, cached price lookups.
|
||||
"""
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from typing import Dict, Optional
|
||||
from decimal import Decimal
|
||||
from datetime import datetime, timedelta
|
||||
import logging
|
||||
|
||||
from app.models import Position, Transaction
|
||||
from app.models.position import PositionStatus
|
||||
from app.services.market_data_service import MarketDataService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class PerformanceCalculatorV2:
|
||||
"""
|
||||
Enhanced performance calculator with efficient market data handling.
|
||||
|
||||
Features:
|
||||
- Database-backed price caching
|
||||
- Rate-limited API calls
|
||||
- Batch price fetching
|
||||
- Stale-while-revalidate pattern
|
||||
"""
|
||||
|
||||
def __init__(self, db: Session, cache_ttl: int = 300):
|
||||
"""
|
||||
Initialize performance calculator.
|
||||
|
||||
Args:
|
||||
db: Database session
|
||||
cache_ttl: Cache time-to-live in seconds (default: 5 minutes)
|
||||
"""
|
||||
self.db = db
|
||||
self.market_data = MarketDataService(db, cache_ttl_seconds=cache_ttl)
|
||||
|
||||
def calculate_unrealized_pnl(self, position: Position, current_price: Optional[Decimal] = None) -> Optional[Decimal]:
|
||||
"""
|
||||
Calculate unrealized P&L for an open position.
|
||||
|
||||
Args:
|
||||
position: Open position to calculate P&L for
|
||||
current_price: Optional pre-fetched current price (avoids API call)
|
||||
|
||||
Returns:
|
||||
Unrealized P&L or None if market data unavailable
|
||||
"""
|
||||
if position.status != PositionStatus.OPEN:
|
||||
return None
|
||||
|
||||
# Use provided price or fetch it
|
||||
if current_price is None:
|
||||
current_price = self.market_data.get_price(position.symbol, allow_stale=True)
|
||||
|
||||
if current_price is None or position.avg_entry_price is None:
|
||||
return None
|
||||
|
||||
# Calculate P&L based on position direction
|
||||
quantity = abs(position.total_quantity)
|
||||
is_short = position.total_quantity < 0
|
||||
|
||||
if is_short:
|
||||
# Short position: profit when price decreases
|
||||
pnl = (position.avg_entry_price - current_price) * quantity * 100
|
||||
else:
|
||||
# Long position: profit when price increases
|
||||
pnl = (current_price - position.avg_entry_price) * quantity * 100
|
||||
|
||||
# Subtract fees and commissions from opening transactions
|
||||
total_fees = Decimal("0")
|
||||
for link in position.transaction_links:
|
||||
txn = link.transaction
|
||||
if txn.commission:
|
||||
total_fees += txn.commission
|
||||
if txn.fees:
|
||||
total_fees += txn.fees
|
||||
|
||||
pnl -= total_fees
|
||||
|
||||
return pnl
|
||||
|
||||
def update_open_positions_pnl(
|
||||
self,
|
||||
account_id: int,
|
||||
max_api_calls: int = 10,
|
||||
allow_stale: bool = True
|
||||
) -> Dict[str, int]:
|
||||
"""
|
||||
Update unrealized P&L for all open positions in an account.
|
||||
|
||||
Uses batch fetching with rate limiting to avoid overwhelming Yahoo Finance API.
|
||||
|
||||
Args:
|
||||
account_id: Account ID to update
|
||||
max_api_calls: Maximum number of Yahoo Finance API calls to make
|
||||
allow_stale: Allow using stale cached prices
|
||||
|
||||
Returns:
|
||||
Dictionary with update statistics
|
||||
"""
|
||||
open_positions = (
|
||||
self.db.query(Position)
|
||||
.filter(
|
||||
and_(
|
||||
Position.account_id == account_id,
|
||||
Position.status == PositionStatus.OPEN,
|
||||
)
|
||||
)
|
||||
.all()
|
||||
)
|
||||
|
||||
if not open_positions:
|
||||
return {
|
||||
"total": 0,
|
||||
"updated": 0,
|
||||
"cached": 0,
|
||||
"failed": 0
|
||||
}
|
||||
|
||||
# Get unique symbols
|
||||
symbols = list(set(p.symbol for p in open_positions))
|
||||
|
||||
logger.info(f"Updating P&L for {len(open_positions)} positions across {len(symbols)} symbols")
|
||||
|
||||
# Fetch prices in batch
|
||||
prices = self.market_data.get_prices_batch(
|
||||
symbols,
|
||||
allow_stale=allow_stale,
|
||||
max_fetches=max_api_calls
|
||||
)
|
||||
|
||||
# Update P&L for each position
|
||||
updated = 0
|
||||
cached = 0
|
||||
failed = 0
|
||||
|
||||
for position in open_positions:
|
||||
price = prices.get(position.symbol)
|
||||
|
||||
if price is not None:
|
||||
unrealized_pnl = self.calculate_unrealized_pnl(position, current_price=price)
|
||||
|
||||
if unrealized_pnl is not None:
|
||||
position.unrealized_pnl = unrealized_pnl
|
||||
updated += 1
|
||||
|
||||
# Check if price was from cache (age > 0) or fresh fetch
|
||||
cached_info = self.market_data._get_cached_price(position.symbol)
|
||||
if cached_info:
|
||||
_, age = cached_info
|
||||
if age < self.market_data.cache_ttl:
|
||||
cached += 1
|
||||
else:
|
||||
failed += 1
|
||||
else:
|
||||
failed += 1
|
||||
logger.warning(f"Could not get price for {position.symbol}")
|
||||
|
||||
self.db.commit()
|
||||
|
||||
logger.info(
|
||||
f"Updated {updated}/{len(open_positions)} positions "
|
||||
f"(cached: {cached}, failed: {failed})"
|
||||
)
|
||||
|
||||
return {
|
||||
"total": len(open_positions),
|
||||
"updated": updated,
|
||||
"cached": cached,
|
||||
"failed": failed
|
||||
}
|
||||
|
||||
    def calculate_account_stats(
        self,
        account_id: int,
        update_prices: bool = True,
        max_api_calls: int = 10,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> Dict:
        """
        Calculate aggregate statistics for an account.

        Args:
            account_id: Account ID
            update_prices: Whether to fetch fresh prices (if False, uses cached prices only)
            max_api_calls: Maximum number of Yahoo Finance API calls
            start_date: Filter positions opened on or after this date
            end_date: Filter positions opened on or before this date

        Returns:
            Dictionary with performance metrics
        """
        # Get all positions, with optional date filtering
        query = self.db.query(Position).filter(Position.account_id == account_id)

        if start_date:
            query = query.filter(Position.open_date >= start_date)
        if end_date:
            query = query.filter(Position.open_date <= end_date)

        positions = query.all()

        total_positions = len(positions)
        open_positions_count = sum(
            1 for p in positions if p.status == PositionStatus.OPEN
        )
        closed_positions_count = sum(
            1 for p in positions if p.status == PositionStatus.CLOSED
        )

        # Calculate realized P&L (doesn't need market data)
        total_realized_pnl = sum(
            (p.realized_pnl or Decimal("0"))
            for p in positions
            if p.status == PositionStatus.CLOSED
        )

        # Update unrealized P&L for open positions
        update_stats = None
        if update_prices and open_positions_count > 0:
            update_stats = self.update_open_positions_pnl(
                account_id,
                max_api_calls=max_api_calls,
                allow_stale=True,
            )

        # Calculate total unrealized P&L
        total_unrealized_pnl = sum(
            (p.unrealized_pnl or Decimal("0"))
            for p in positions
            if p.status == PositionStatus.OPEN
        )

        # Calculate win rate and average win/loss
        closed_with_pnl = [
            p for p in positions
            if p.status == PositionStatus.CLOSED and p.realized_pnl is not None
        ]

        if closed_with_pnl:
            winning_trades = [p for p in closed_with_pnl if p.realized_pnl > 0]
            losing_trades = [p for p in closed_with_pnl if p.realized_pnl < 0]

            win_rate = (len(winning_trades) / len(closed_with_pnl)) * 100

            avg_win = (
                sum(p.realized_pnl for p in winning_trades) / len(winning_trades)
                if winning_trades
                else Decimal("0")
            )

            avg_loss = (
                sum(p.realized_pnl for p in losing_trades) / len(losing_trades)
                if losing_trades
                else Decimal("0")
            )
        else:
            win_rate = 0.0
            avg_win = Decimal("0")
            avg_loss = Decimal("0")

        # Get the current account balance from the latest transaction
        latest_txn = (
            self.db.query(Transaction)
            .filter(Transaction.account_id == account_id)
            .order_by(Transaction.run_date.desc(), Transaction.id.desc())
            .first()
        )

        current_balance = (
            latest_txn.cash_balance if latest_txn and latest_txn.cash_balance else Decimal("0")
        )

        result = {
            "total_positions": total_positions,
            "open_positions": open_positions_count,
            "closed_positions": closed_positions_count,
            "total_realized_pnl": float(total_realized_pnl),
            "total_unrealized_pnl": float(total_unrealized_pnl),
            "total_pnl": float(total_realized_pnl + total_unrealized_pnl),
            "win_rate": float(win_rate),
            "avg_win": float(avg_win),
            "avg_loss": float(avg_loss),
            "current_balance": float(current_balance),
        }

        # Add update stats if prices were fetched
        if update_stats:
            result["price_update_stats"] = update_stats

        return result
    def get_balance_history(
        self, account_id: int, days: int = 30
    ) -> list[Dict]:
        """
        Get account balance history for charting.

        This doesn't need market data, just transaction history.

        Args:
            account_id: Account ID
            days: Number of days to retrieve

        Returns:
            List of {date, balance} dictionaries
        """
        cutoff_date = datetime.now().date() - timedelta(days=days)

        transactions = (
            self.db.query(Transaction.run_date, Transaction.cash_balance)
            .filter(
                and_(
                    Transaction.account_id == account_id,
                    Transaction.run_date >= cutoff_date,
                    Transaction.cash_balance.isnot(None),
                )
            )
            .order_by(Transaction.run_date)
            .all()
        )

        # Keep one balance per day (the last transaction of the day wins)
        daily_balances = {}
        for txn in transactions:
            daily_balances[txn.run_date] = float(txn.cash_balance)

        return [
            {"date": date.isoformat(), "balance": balance}
            for date, balance in sorted(daily_balances.items())
        ]
    def get_top_trades(
        self,
        account_id: int,
        limit: int = 10,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> list[Dict]:
        """
        Get top performing trades (by realized P&L).

        This doesn't need market data, just closed positions.

        Args:
            account_id: Account ID
            limit: Maximum number of trades to return
            start_date: Filter positions closed on or after this date
            end_date: Filter positions closed on or before this date

        Returns:
            List of trade dictionaries
        """
        query = self.db.query(Position).filter(
            and_(
                Position.account_id == account_id,
                Position.status == PositionStatus.CLOSED,
                Position.realized_pnl.isnot(None),
            )
        )

        # Apply date filters if provided
        if start_date:
            query = query.filter(Position.close_date >= start_date)
        if end_date:
            query = query.filter(Position.close_date <= end_date)

        positions = query.order_by(Position.realized_pnl.desc()).limit(limit).all()

        return [
            {
                "symbol": p.symbol,
                "option_symbol": p.option_symbol,
                "position_type": p.position_type.value,
                "open_date": p.open_date.isoformat(),
                "close_date": p.close_date.isoformat() if p.close_date else None,
                "quantity": float(p.total_quantity),
                "entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
                "exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
                "realized_pnl": float(p.realized_pnl),
            }
            for p in positions
        ]

    def get_worst_trades(
        self,
        account_id: int,
        limit: int = 10,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> list[Dict]:
        """
        Get worst performing trades (by realized P&L).

        This doesn't need market data, just closed positions.

        Args:
            account_id: Account ID
            limit: Maximum number of trades to return
            start_date: Filter positions closed on or after this date
            end_date: Filter positions closed on or before this date

        Returns:
            List of trade dictionaries
        """
        query = self.db.query(Position).filter(
            and_(
                Position.account_id == account_id,
                Position.status == PositionStatus.CLOSED,
                Position.realized_pnl.isnot(None),
            )
        )

        # Apply date filters if provided
        if start_date:
            query = query.filter(Position.close_date >= start_date)
        if end_date:
            query = query.filter(Position.close_date <= end_date)

        positions = query.order_by(Position.realized_pnl.asc()).limit(limit).all()

        return [
            {
                "symbol": p.symbol,
                "option_symbol": p.option_symbol,
                "position_type": p.position_type.value,
                "open_date": p.open_date.isoformat(),
                "close_date": p.close_date.isoformat() if p.close_date else None,
                "quantity": float(p.total_quantity),
                "entry_price": float(p.avg_entry_price) if p.avg_entry_price else None,
                "exit_price": float(p.avg_exit_price) if p.avg_exit_price else None,
                "realized_pnl": float(p.realized_pnl),
            }
            for p in positions
        ]
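
The calculator is meant to be driven from a request handler or a script. A minimal usage sketch, assuming the SessionLocal factory from app.database (seen in the seeder below) and an existing account; the account ID of 1 is a placeholder, not part of this diff:

from app.database import SessionLocal
from app.services.performance_calculator_v2 import PerformanceCalculatorV2

db = SessionLocal()
try:
    calc = PerformanceCalculatorV2(db, cache_ttl=300)
    # Cap outbound Yahoo Finance calls; stale cached prices fill in the rest.
    stats = calc.calculate_account_stats(
        account_id=1,  # placeholder account ID
        update_prices=True,
        max_api_calls=5,
    )
    print(stats["total_pnl"], stats["win_rate"], stats["current_balance"])
finally:
    db.close()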
465
backend/app/services/position_tracker.py
Normal file
@@ -0,0 +1,465 @@
"""Service for tracking and calculating trading positions."""
from sqlalchemy.orm import Session
from sqlalchemy import and_
from typing import List, Optional, Dict
from decimal import Decimal
from collections import defaultdict
from datetime import datetime
import re

from app.models import Transaction, Position, PositionTransaction
from app.models.position import PositionType, PositionStatus
from app.utils import parse_option_symbol


class PositionTracker:
    """
    Service for tracking trading positions from transactions.

    Matches opening and closing transactions using the FIFO (first-in,
    first-out) method. Handles stocks, calls, and puts, including complex
    scenarios such as assignments and expirations.
    """

    def __init__(self, db: Session):
        """
        Initialize the position tracker.

        Args:
            db: Database session
        """
        self.db = db
    def rebuild_positions(self, account_id: int) -> int:
        """
        Rebuild all positions for an account from its transactions.

        Deletes existing positions and recalculates from scratch.

        Args:
            account_id: Account ID to rebuild positions for

        Returns:
            Number of positions created
        """
        # Delete existing positions
        self.db.query(Position).filter(Position.account_id == account_id).delete()
        self.db.commit()

        # Get all transactions ordered by date
        transactions = (
            self.db.query(Transaction)
            .filter(Transaction.account_id == account_id)
            .order_by(Transaction.run_date, Transaction.id)
            .all()
        )

        # Group transactions by symbol and option details.
        # Options are grouped by the full contract (symbol + strike + expiration);
        # stocks are grouped by symbol only.
        symbol_txns = defaultdict(list)
        for txn in transactions:
            if txn.symbol:
                # Create a unique grouping key
                grouping_key = self._get_grouping_key(txn)
                symbol_txns[grouping_key].append(txn)

        # Process each symbol/contract group
        position_count = 0
        for grouping_key, txns in symbol_txns.items():
            positions = self._process_symbol_transactions(account_id, grouping_key, txns)
            position_count += len(positions)

        self.db.commit()
        return position_count
    def _process_symbol_transactions(
        self, account_id: int, symbol: str, transactions: List[Transaction]
    ) -> List[Position]:
        """
        Process all transactions for a single symbol to create positions.

        Args:
            account_id: Account ID
            symbol: Trading symbol (or option grouping key)
            transactions: List of transactions for this symbol

        Returns:
            List of created Position objects
        """
        positions = []

        # Determine the position type from the first transaction
        position_type = (
            self._determine_position_type_from_txn(transactions[0])
            if transactions
            else PositionType.STOCK
        )

        # Track open positions using FIFO
        open_positions: List[Dict] = []

        for txn in transactions:
            action = txn.action.upper()

            # Determine if this is an opening or closing transaction
            if self._is_opening_transaction(action):
                # Create a new open position
                open_pos = {
                    "transactions": [txn],
                    "quantity": abs(txn.quantity) if txn.quantity else Decimal("0"),
                    "entry_price": txn.price,
                    "open_date": txn.run_date,
                    "is_short": "SELL" in action or "SOLD" in action,
                }
                open_positions.append(open_pos)

            elif self._is_closing_transaction(action):
                # Close positions using FIFO
                close_quantity = abs(txn.quantity) if txn.quantity else Decimal("0")
                remaining_to_close = close_quantity

                while remaining_to_close > 0 and open_positions:
                    open_pos = open_positions[0]
                    open_qty = open_pos["quantity"]

                    if open_qty <= remaining_to_close:
                        # Close the entire position
                        open_pos["transactions"].append(txn)
                        position = self._create_position(
                            account_id,
                            symbol,
                            position_type,
                            open_pos,
                            close_date=txn.run_date,
                            exit_price=txn.price,
                            close_quantity=open_qty,
                        )
                        positions.append(position)
                        open_positions.pop(0)
                        remaining_to_close -= open_qty
                    else:
                        # Partially close the position: split off the closed portion
                        closed_portion = {
                            "transactions": open_pos["transactions"] + [txn],
                            "quantity": remaining_to_close,
                            "entry_price": open_pos["entry_price"],
                            "open_date": open_pos["open_date"],
                            "is_short": open_pos["is_short"],
                        }
                        position = self._create_position(
                            account_id,
                            symbol,
                            position_type,
                            closed_portion,
                            close_date=txn.run_date,
                            exit_price=txn.price,
                            close_quantity=remaining_to_close,
                        )
                        positions.append(position)

                        # Update the open position with the remaining quantity
                        open_pos["quantity"] -= remaining_to_close
                        remaining_to_close = Decimal("0")

            elif self._is_expiration(action):
                # Handle option expirations
                expire_quantity = abs(txn.quantity) if txn.quantity else Decimal("0")
                remaining_to_expire = expire_quantity

                while remaining_to_expire > 0 and open_positions:
                    open_pos = open_positions[0]
                    open_qty = open_pos["quantity"]

                    if open_qty <= remaining_to_expire:
                        # Expire the entire position
                        open_pos["transactions"].append(txn)
                        position = self._create_position(
                            account_id,
                            symbol,
                            position_type,
                            open_pos,
                            close_date=txn.run_date,
                            exit_price=Decimal("0"),  # Expired worthless
                            close_quantity=open_qty,
                        )
                        positions.append(position)
                        open_positions.pop(0)
                        remaining_to_expire -= open_qty
                    else:
                        # Partially expire
                        closed_portion = {
                            "transactions": open_pos["transactions"] + [txn],
                            "quantity": remaining_to_expire,
                            "entry_price": open_pos["entry_price"],
                            "open_date": open_pos["open_date"],
                            "is_short": open_pos["is_short"],
                        }
                        position = self._create_position(
                            account_id,
                            symbol,
                            position_type,
                            closed_portion,
                            close_date=txn.run_date,
                            exit_price=Decimal("0"),
                            close_quantity=remaining_to_expire,
                        )
                        positions.append(position)
                        open_pos["quantity"] -= remaining_to_expire
                        remaining_to_expire = Decimal("0")

        # Create positions for any remaining open positions
        for open_pos in open_positions:
            position = self._create_position(
                account_id, symbol, position_type, open_pos
            )
            positions.append(position)

        return positions
    def _create_position(
        self,
        account_id: int,
        symbol: str,
        position_type: PositionType,
        position_data: Dict,
        close_date: Optional[datetime] = None,
        exit_price: Optional[Decimal] = None,
        close_quantity: Optional[Decimal] = None,
    ) -> Position:
        """
        Create a Position database object.

        Args:
            account_id: Account ID
            symbol: Trading symbol
            position_type: Type of position
            position_data: Dictionary with position information
            close_date: Close date (if closed)
            exit_price: Exit price (if closed)
            close_quantity: Quantity closed (if closed)

        Returns:
            Created Position object
        """
        is_closed = close_date is not None
        quantity = close_quantity if close_quantity else position_data["quantity"]

        # Calculate P&L if closed
        realized_pnl = None
        if is_closed and position_data["entry_price"] and exit_price is not None:
            # Option contracts cover 100 shares; stock positions use the raw quantity
            multiplier = Decimal("1") if position_type == PositionType.STOCK else Decimal("100")

            if position_data["is_short"]:
                # Short position: profit when price decreases
                realized_pnl = (
                    position_data["entry_price"] - exit_price
                ) * quantity * multiplier
            else:
                # Long position: profit when price increases
                realized_pnl = (
                    exit_price - position_data["entry_price"]
                ) * quantity * multiplier

            # Subtract fees and commissions
            for txn in position_data["transactions"]:
                if txn.commission:
                    realized_pnl -= txn.commission
                if txn.fees:
                    realized_pnl -= txn.fees

        # Extract the option symbol from the first transaction if this is an option
        option_symbol = None
        if position_type != PositionType.STOCK and position_data["transactions"]:
            first_txn = position_data["transactions"][0]
            # Try to extract option details from the description
            option_symbol = self._extract_option_symbol_from_description(
                first_txn.description, first_txn.action, symbol
            )

        # Create the position
        position = Position(
            account_id=account_id,
            symbol=symbol,
            option_symbol=option_symbol,
            position_type=position_type,
            status=PositionStatus.CLOSED if is_closed else PositionStatus.OPEN,
            open_date=position_data["open_date"],
            close_date=close_date,
            total_quantity=quantity if not position_data["is_short"] else -quantity,
            avg_entry_price=position_data["entry_price"],
            avg_exit_price=exit_price,
            realized_pnl=realized_pnl,
        )

        self.db.add(position)
        self.db.flush()  # Get the position ID

        # Link transactions to the position
        for txn in position_data["transactions"]:
            link = PositionTransaction(
                position_id=position.id, transaction_id=txn.id
            )
            self.db.add(link)

        return position
    def _extract_option_symbol_from_description(
        self, description: str, action: str, base_symbol: str
    ) -> Optional[str]:
        """
        Extract an option symbol from a transaction description.

        Example: "CALL (TGT) TARGET CORP JAN 16 26 $95 (100 SHS)"
        Returns: "-TGT260116C95"

        Args:
            description: Transaction description
            action: Transaction action
            base_symbol: Underlying symbol

        Returns:
            Option symbol in standard format, or None if it can't be parsed
        """
        if not description:
            return None

        # Determine if CALL or PUT
        call_or_put = None
        if "CALL" in description.upper():
            call_or_put = "C"
        elif "PUT" in description.upper():
            call_or_put = "P"
        else:
            return None

        # Extract date and strike: "JAN 16 26 $95"
        # Pattern: MONTH DAY YY $STRIKE
        date_strike_pattern = r'([A-Z]{3})\s+(\d{1,2})\s+(\d{2})\s+\$([\d.]+)'
        match = re.search(date_strike_pattern, description)

        if not match:
            return None

        month_abbr, day, year, strike = match.groups()

        # Convert the month abbreviation to a number
        month_map = {
            'JAN': '01', 'FEB': '02', 'MAR': '03', 'APR': '04',
            'MAY': '05', 'JUN': '06', 'JUL': '07', 'AUG': '08',
            'SEP': '09', 'OCT': '10', 'NOV': '11', 'DEC': '12'
        }

        month = month_map.get(month_abbr.upper())
        if not month:
            return None

        # Format: -SYMBOL + YYMMDD + C/P + STRIKE
        # Drop the decimal point only when the strike is a whole number,
        # so fractional strikes (e.g. $2.50) stay parseable.
        strike_num = float(strike)
        strike_str = str(int(strike_num)) if strike_num.is_integer() else strike

        option_symbol = f"-{base_symbol}{year}{month}{day.zfill(2)}{call_or_put}{strike_str}"
        return option_symbol
    def _determine_position_type_from_txn(self, txn: Transaction) -> PositionType:
        """
        Determine the position type from a transaction's action/description.

        Args:
            txn: Transaction to analyze

        Returns:
            PositionType (STOCK, CALL, or PUT)
        """
        # Check the action and description for option indicators
        action_upper = txn.action.upper() if txn.action else ""
        desc_upper = txn.description.upper() if txn.description else ""

        # Look for CALL or PUT keywords
        if "CALL" in action_upper or "CALL" in desc_upper:
            return PositionType.CALL
        elif "PUT" in action_upper or "PUT" in desc_upper:
            return PositionType.PUT

        # Fall back to checking the symbol format (for backwards compatibility)
        if txn.symbol and txn.symbol.startswith("-"):
            option_info = parse_option_symbol(txn.symbol)
            if option_info:
                return (
                    PositionType.CALL
                    if option_info.option_type == "CALL"
                    else PositionType.PUT
                )

        return PositionType.STOCK

    def _get_base_symbol(self, symbol: str) -> str:
        """Extract the base symbol from an option symbol."""
        if symbol.startswith("-"):
            option_info = parse_option_symbol(symbol)
            if option_info:
                return option_info.underlying_symbol
        return symbol

    def _is_opening_transaction(self, action: str) -> bool:
        """Check if an action represents opening a position."""
        opening_keywords = [
            "OPENING TRANSACTION",
            "YOU BOUGHT OPENING",
            "YOU SOLD OPENING",
        ]
        return any(keyword in action for keyword in opening_keywords)

    def _is_closing_transaction(self, action: str) -> bool:
        """Check if an action represents closing a position."""
        closing_keywords = [
            "CLOSING TRANSACTION",
            "YOU BOUGHT CLOSING",
            "YOU SOLD CLOSING",
            "ASSIGNED",
        ]
        return any(keyword in action for keyword in closing_keywords)

    def _is_expiration(self, action: str) -> bool:
        """Check if an action represents an expiration."""
        return "EXPIRED" in action

    def _get_grouping_key(self, txn: Transaction) -> str:
        """
        Create a unique grouping key for a transaction.

        For options, returns symbol + option details (e.g., "TGT-JAN16-100C").
        For stocks, returns just the symbol (e.g., "TGT").

        Args:
            txn: Transaction to create a key for

        Returns:
            Grouping key string
        """
        # Determine if this is an option transaction
        action_upper = txn.action.upper() if txn.action else ""
        desc_upper = txn.description.upper() if txn.description else ""

        is_option = (
            "CALL" in action_upper
            or "CALL" in desc_upper
            or "PUT" in action_upper
            or "PUT" in desc_upper
        )

        if not is_option or not txn.description:
            # Stock transaction - group by symbol only
            return txn.symbol

        # Option transaction - extract strike and expiration to create a unique key
        # Pattern: "CALL (TGT) TARGET CORP JAN 16 26 $100 (100 SHS)"
        date_strike_pattern = r'([A-Z]{3})\s+(\d{1,2})\s+(\d{2})\s+\$([\d.]+)'
        match = re.search(date_strike_pattern, txn.description)

        if not match:
            # Can't parse option details; fall back to symbol only
            return txn.symbol

        month_abbr, day, year, strike = match.groups()

        # Determine call or put
        call_or_put = "C" if "CALL" in desc_upper else "P"

        # Create a key like "TGT-JAN16-100C" (SYMBOL-MONTHDAY-STRIKE + C/P)
        strike_num = float(strike)
        strike_str = str(int(strike_num)) if strike_num.is_integer() else strike

        grouping_key = f"{txn.symbol}-{month_abbr}{day}-{strike_str}{call_or_put}"
        return grouping_key
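
The FIFO matching above is easier to see stripped of the ORM. A self-contained sketch of the same lot-matching idea for a long stock position (all names here are illustrative, not part of the codebase):

from collections import deque
from decimal import Decimal

lots = deque()  # open lots, oldest first

def open_lot(qty: Decimal, price: Decimal) -> None:
    lots.append({"qty": qty, "price": price})

def close_qty(qty: Decimal, exit_price: Decimal) -> Decimal:
    """Close `qty` against the oldest lots first; return realized P&L."""
    pnl = Decimal("0")
    while qty > 0 and lots:
        lot = lots[0]
        matched = min(qty, lot["qty"])
        pnl += (exit_price - lot["price"]) * matched
        lot["qty"] -= matched
        qty -= matched
        if lot["qty"] == 0:
            lots.popleft()
    return pnl

open_lot(Decimal("10"), Decimal("100"))          # buy 10 @ $100
open_lot(Decimal("5"), Decimal("110"))           # buy 5 @ $110
print(close_qty(Decimal("12"), Decimal("120")))  # 10*$20 + 2*$10 = $220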
5
backend/app/utils/__init__.py
Normal file
@@ -0,0 +1,5 @@
"""Utility functions and helpers."""
from app.utils.deduplication import generate_transaction_hash
from app.utils.option_parser import parse_option_symbol, OptionInfo

__all__ = ["generate_transaction_hash", "parse_option_symbol", "OptionInfo"]
65
backend/app/utils/deduplication.py
Normal file
@@ -0,0 +1,65 @@
"""Transaction deduplication utilities."""
import hashlib
from datetime import date
from decimal import Decimal
from typing import Optional


def generate_transaction_hash(
    account_id: int,
    run_date: date,
    symbol: Optional[str],
    action: str,
    amount: Optional[Decimal],
    quantity: Optional[Decimal],
    price: Optional[Decimal],
) -> str:
    """
    Generate a unique SHA-256 hash for a transaction to prevent duplicates.

    The hash is generated from the key attributes that uniquely identify
    a transaction: account, date, symbol, action, amount, quantity, and price.

    Args:
        account_id: Account identifier
        run_date: Transaction date
        symbol: Trading symbol
        action: Transaction action description
        amount: Transaction amount
        quantity: Number of shares/contracts
        price: Price per unit

    Returns:
        str: 64-character hexadecimal SHA-256 hash

    Example:
        >>> generate_transaction_hash(
        ...     account_id=1,
        ...     run_date=date(2025, 12, 26),
        ...     symbol="AAPL",
        ...     action="YOU BOUGHT",
        ...     amount=Decimal("-1500.00"),
        ...     quantity=Decimal("10"),
        ...     price=Decimal("150.00")
        ... )
        'a1b2c3d4...'
    """
    # Convert values to strings, handling None values
    symbol_str = symbol or ""
    amount_str = str(amount) if amount is not None else ""
    quantity_str = str(quantity) if quantity is not None else ""
    price_str = str(price) if price is not None else ""

    # Create the hash string with a pipe delimiter
    hash_string = (
        f"{account_id}|"
        f"{run_date.isoformat()}|"
        f"{symbol_str}|"
        f"{action}|"
        f"{amount_str}|"
        f"{quantity_str}|"
        f"{price_str}"
    )

    # Generate the SHA-256 hash
    return hashlib.sha256(hash_string.encode("utf-8")).hexdigest()
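
One property worth making explicit: the digest is deterministic over these seven fields, so re-importing the same CSV is idempotent, while two genuinely distinct fills sharing the same date, symbol, action, amount, quantity, and price would collide. A quick sketch of the determinism check:

from datetime import date
from decimal import Decimal

from app.utils import generate_transaction_hash

# The same logical transaction always hashes to the same value, so a
# re-imported CSV row can be skipped by a unique-hash lookup.
fields = dict(
    account_id=1,
    run_date=date(2025, 12, 26),
    symbol="AAPL",
    action="YOU BOUGHT",
    amount=Decimal("-1500.00"),
    quantity=Decimal("10"),
    price=Decimal("150.00"),
)
assert generate_transaction_hash(**fields) == generate_transaction_hash(**fields)
print(generate_transaction_hash(**fields))  # 64-char hex digest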
91
backend/app/utils/option_parser.py
Normal file
@@ -0,0 +1,91 @@
"""Option symbol parsing utilities."""
import re
from datetime import datetime
from typing import Optional, NamedTuple
from decimal import Decimal


class OptionInfo(NamedTuple):
    """
    Parsed option information.

    Attributes:
        underlying_symbol: Base ticker symbol (e.g., "AAPL")
        expiration_date: Option expiration date
        option_type: "CALL" or "PUT"
        strike_price: Strike price
    """
    underlying_symbol: str
    expiration_date: datetime
    option_type: str
    strike_price: Decimal


def parse_option_symbol(option_symbol: str) -> Optional[OptionInfo]:
    """
    Parse the Fidelity option symbol format into its components.

    Fidelity format: -SYMBOL + YYMMDD + C/P + STRIKE
    Example: -AAPL260116C150 = AAPL call expiring Jan 16, 2026 at a $150 strike

    Args:
        option_symbol: Fidelity option symbol string

    Returns:
        OptionInfo object if parsing is successful, None otherwise

    Examples:
        >>> parse_option_symbol("-AAPL260116C150")
        OptionInfo(underlying_symbol='AAPL',
                   expiration_date=datetime(2026, 1, 16),
                   option_type='CALL',
                   strike_price=Decimal('150'))

        >>> parse_option_symbol("-TSLA251219P500")
        OptionInfo(underlying_symbol='TSLA',
                   expiration_date=datetime(2025, 12, 19),
                   option_type='PUT',
                   strike_price=Decimal('500'))
    """
    # Regex pattern: -SYMBOL + YYMMDD + C/P + STRIKE
    #   Symbol: one or more uppercase letters
    #   Date: 6 digits (YYMMDD)
    #   Type: C (call) or P (put)
    #   Strike: digits with an optional decimal point
    pattern = r"^-([A-Z]+)(\d{6})([CP])(\d+\.?\d*)$"

    match = re.match(pattern, option_symbol)
    if not match:
        return None

    symbol, date_str, option_type, strike_str = match.groups()

    # Parse the date (YYMMDD format)
    try:
        # Assume 20XX years (works until 2100)
        year = 2000 + int(date_str[:2])
        month = int(date_str[2:4])
        day = int(date_str[4:6])
        expiration_date = datetime(year, month, day)
    except (ValueError, IndexError):
        return None

    # Parse the option type
    option_type_full = "CALL" if option_type == "C" else "PUT"

    # Parse the strike price
    try:
        strike_price = Decimal(strike_str)
    except (ValueError, ArithmeticError):
        return None

    return OptionInfo(
        underlying_symbol=symbol,
        expiration_date=expiration_date,
        option_type=option_type_full,
        strike_price=strike_price,
    )
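
A quick round-trip check of the parser (a sketch; the assertions simply restate the format described above):

from app.utils import parse_option_symbol

info = parse_option_symbol("-AAPL260116C150")
assert info is not None
print(info.underlying_symbol, info.expiration_date.date(), info.option_type, info.strike_price)
# AAPL 2026-01-16 CALL 150

# Non-option symbols fall through to None, so callers can use the result as a type probe.
assert parse_option_symbol("AAPL") is None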
12
backend/requirements.txt
Normal file
@@ -0,0 +1,12 @@
fastapi==0.109.0
uvicorn[standard]==0.27.0
sqlalchemy==2.0.25
alembic==1.13.1
psycopg2-binary==2.9.9
pydantic==2.5.3
pydantic-settings==2.1.0
python-multipart==0.0.6
pandas==2.1.4
yfinance==0.2.35
python-dateutil==2.8.2
pytz==2024.1
94
backend/seed_demo_data.py
Normal file
@@ -0,0 +1,94 @@
"""
Demo data seeder script.
Creates a sample account and imports the provided CSV file.
"""
import sys
from pathlib import Path

# Add the parent directory to the path
sys.path.insert(0, str(Path(__file__).parent))

from sqlalchemy.orm import Session
from app.database import SessionLocal, engine, Base
from app.models import Account
from app.services import ImportService
from app.services.position_tracker import PositionTracker


def seed_demo_data():
    """Seed the demo account and its transactions."""
    print("🌱 Seeding demo data...")

    # Create tables
    Base.metadata.create_all(bind=engine)

    # Create a database session
    db = SessionLocal()

    try:
        # Check if the demo account already exists
        existing = (
            db.query(Account)
            .filter(Account.account_number == "DEMO123456")
            .first()
        )

        if existing:
            print("✅ Demo account already exists")
            demo_account = existing
        else:
            # Create the demo account
            demo_account = Account(
                account_number="DEMO123456",
                account_name="Demo Trading Account",
                account_type="margin",
            )
            db.add(demo_account)
            db.commit()
            db.refresh(demo_account)
            print(f"✅ Created demo account (ID: {demo_account.id})")

        # Check for the CSV file
        csv_path = Path("/app/imports/History_for_Account_X38661988.csv")
        if not csv_path.exists():
            # Try an alternative path (development)
            csv_path = Path(__file__).parent.parent / "History_for_Account_X38661988.csv"

        if not csv_path.exists():
            print("⚠️  Sample CSV file not found. Skipping import.")
            print("   Place the CSV file in /app/imports/ to seed demo data.")
            return

        # Import transactions
        print(f"📊 Importing transactions from {csv_path.name}...")
        import_service = ImportService(db)
        result = import_service.import_from_file(csv_path, demo_account.id)

        print(f"✅ Imported {result.imported} transactions")
        print(f"   Skipped {result.skipped} duplicates")
        if result.errors:
            print(f"   ⚠️  {len(result.errors)} errors occurred")

        # Build positions
        if result.imported > 0:
            print("📈 Building positions...")
            position_tracker = PositionTracker(db)
            positions_created = position_tracker.rebuild_positions(demo_account.id)
            print(f"✅ Created {positions_created} positions")

        print("\n🎉 Demo data seeded successfully!")
        print(f"\n📝 Demo Account Details:")
        print(f"   Account Number: {demo_account.account_number}")
        print(f"   Account Name: {demo_account.account_name}")
        print(f"   Account ID: {demo_account.id}")

    except Exception as e:
        print(f"❌ Error seeding demo data: {e}")
        db.rollback()
        raise

    finally:
        db.close()


if __name__ == "__main__":
    seed_demo_data()