Initial release v1.1.0
Complete MVP for tracking Fidelity brokerage account performance.

- Transaction import from CSV with deduplication
- Automatic FIFO position tracking with options support
- Real-time P&L calculations with market data caching
- Dashboard with timeframe filtering (30/90/180 days, 1 year, YTD, all time)
- Docker-based deployment with PostgreSQL backend
- React/TypeScript frontend with TailwindCSS
- FastAPI backend with SQLAlchemy ORM

Features:
- Multi-account support
- Import via CSV upload or filesystem
- Open and closed position tracking
- Balance history charting
- Performance analytics and metrics
- Top trades analysis
- Responsive UI design

Co-Authored-By: Claude Sonnet 4.5 <noreply@anthropic.com>
This commit is contained in:
1
backend/app/api/endpoints/__init__.py
Normal file
1
backend/app/api/endpoints/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""API endpoint modules."""
|
||||
151
backend/app/api/endpoints/accounts.py
Normal file
151
backend/app/api/endpoints/accounts.py
Normal file
@@ -0,0 +1,151 @@
|
||||
"""Account management API endpoints."""
|
||||
from fastapi import APIRouter, Depends, HTTPException, status
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import List
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.models import Account
|
||||
from app.schemas import AccountCreate, AccountUpdate, AccountResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("", response_model=AccountResponse, status_code=status.HTTP_201_CREATED)
|
||||
def create_account(account: AccountCreate, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Create a new brokerage account.
|
||||
|
||||
Args:
|
||||
account: Account creation data
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Created account
|
||||
|
||||
Raises:
|
||||
HTTPException: If account number already exists
|
||||
"""
|
||||
# Check if account number already exists
|
||||
existing = (
|
||||
db.query(Account)
|
||||
.filter(Account.account_number == account.account_number)
|
||||
.first()
|
||||
)
|
||||
|
||||
if existing:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST,
|
||||
detail=f"Account with number {account.account_number} already exists",
|
||||
)
|
||||
|
||||
# Create new account
|
||||
db_account = Account(**account.model_dump())
|
||||
db.add(db_account)
|
||||
db.commit()
|
||||
db.refresh(db_account)
|
||||
|
||||
return db_account
|
||||
|
||||
|
||||
@router.get("", response_model=List[AccountResponse])
|
||||
def list_accounts(skip: int = 0, limit: int = 100, db: Session = Depends(get_db)):
|
||||
"""
|
||||
List all accounts.
|
||||
|
||||
Args:
|
||||
skip: Number of records to skip
|
||||
limit: Maximum number of records to return
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of accounts
|
||||
"""
|
||||
accounts = db.query(Account).offset(skip).limit(limit).all()
|
||||
return accounts
|
||||
|
||||
|
||||
@router.get("/{account_id}", response_model=AccountResponse)
|
||||
def get_account(account_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Get account by ID.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Account details
|
||||
|
||||
Raises:
|
||||
HTTPException: If account not found
|
||||
"""
|
||||
account = db.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if not account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Account {account_id} not found",
|
||||
)
|
||||
|
||||
return account
|
||||
|
||||
|
||||
@router.put("/{account_id}", response_model=AccountResponse)
|
||||
def update_account(
|
||||
account_id: int, account_update: AccountUpdate, db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Update account details.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
account_update: Updated account data
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Updated account
|
||||
|
||||
Raises:
|
||||
HTTPException: If account not found
|
||||
"""
|
||||
db_account = db.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if not db_account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Account {account_id} not found",
|
||||
)
|
||||
|
||||
# Update fields
|
||||
update_data = account_update.model_dump(exclude_unset=True)
|
||||
for field, value in update_data.items():
|
||||
setattr(db_account, field, value)
|
||||
|
||||
db.commit()
|
||||
db.refresh(db_account)
|
||||
|
||||
return db_account
|
||||
|
||||
|
||||
@router.delete("/{account_id}", status_code=status.HTTP_204_NO_CONTENT)
|
||||
def delete_account(account_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Delete an account and all associated data.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
db: Database session
|
||||
|
||||
Raises:
|
||||
HTTPException: If account not found
|
||||
"""
|
||||
db_account = db.query(Account).filter(Account.id == account_id).first()
|
||||
|
||||
if not db_account:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Account {account_id} not found",
|
||||
)
|
||||
|
||||
db.delete(db_account)
|
||||
db.commit()
|
||||
111
backend/app/api/endpoints/analytics.py
Normal file
111
backend/app/api/endpoints/analytics.py
Normal file
@@ -0,0 +1,111 @@
|
||||
"""Analytics API endpoints."""
|
||||
from fastapi import APIRouter, Depends, Query
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.services.performance_calculator import PerformanceCalculator
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/overview/{account_id}")
|
||||
def get_overview(account_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Get overview statistics for an account.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary with performance metrics
|
||||
"""
|
||||
calculator = PerformanceCalculator(db)
|
||||
stats = calculator.calculate_account_stats(account_id)
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/balance-history/{account_id}")
|
||||
def get_balance_history(
|
||||
account_id: int,
|
||||
days: int = Query(default=30, ge=1, le=3650),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get account balance history for charting.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
days: Number of days to retrieve (default: 30)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of {date, balance} dictionaries
|
||||
"""
|
||||
calculator = PerformanceCalculator(db)
|
||||
history = calculator.get_balance_history(account_id, days)
|
||||
return {"data": history}
|
||||
|
||||
|
||||
@router.get("/top-trades/{account_id}")
|
||||
def get_top_trades(
|
||||
account_id: int,
|
||||
limit: int = Query(default=20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get top performing trades.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
limit: Maximum number of trades to return (default: 20)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of trade dictionaries
|
||||
"""
|
||||
calculator = PerformanceCalculator(db)
|
||||
trades = calculator.get_top_trades(account_id, limit)
|
||||
return {"data": trades}
|
||||
|
||||
|
||||
@router.get("/worst-trades/{account_id}")
|
||||
def get_worst_trades(
|
||||
account_id: int,
|
||||
limit: int = Query(default=20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get worst performing trades (biggest losses).
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
limit: Maximum number of trades to return (default: 20)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of trade dictionaries
|
||||
"""
|
||||
calculator = PerformanceCalculator(db)
|
||||
trades = calculator.get_worst_trades(account_id, limit)
|
||||
return {"data": trades}
|
||||
|
||||
|
||||
@router.post("/update-pnl/{account_id}")
|
||||
def update_unrealized_pnl(account_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Update unrealized P&L for all open positions in an account.
|
||||
|
||||
Fetches current market prices and recalculates P&L.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of positions updated
|
||||
"""
|
||||
calculator = PerformanceCalculator(db)
|
||||
updated = calculator.update_open_positions_pnl(account_id)
|
||||
return {"positions_updated": updated}
|
||||
273
backend/app/api/endpoints/analytics_v2.py
Normal file
273
backend/app/api/endpoints/analytics_v2.py
Normal file
@@ -0,0 +1,273 @@
|
||||
"""
|
||||
Enhanced analytics API endpoints with efficient market data handling.
|
||||
|
||||
This version uses PerformanceCalculatorV2 with:
|
||||
- Database-backed price caching
|
||||
- Rate-limited API calls
|
||||
- Stale-while-revalidate pattern for better UX
|
||||
"""
|
||||
from fastapi import APIRouter, Depends, Query, BackgroundTasks
|
||||
from sqlalchemy.orm import Session
|
||||
from typing import Optional
|
||||
from datetime import date
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.services.performance_calculator_v2 import PerformanceCalculatorV2
|
||||
from app.services.market_data_service import MarketDataService
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("/overview/{account_id}")
|
||||
def get_overview(
|
||||
account_id: int,
|
||||
refresh_prices: bool = Query(default=False, description="Force fresh price fetch"),
|
||||
max_api_calls: int = Query(default=5, ge=0, le=50, description="Max Yahoo Finance API calls"),
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Get overview statistics for an account.
|
||||
|
||||
By default, uses cached prices (stale-while-revalidate pattern).
|
||||
Set refresh_prices=true to force fresh data (may be slow).
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
refresh_prices: Whether to fetch fresh prices from Yahoo Finance
|
||||
max_api_calls: Maximum number of API calls to make
|
||||
start_date: Filter positions opened on or after this date
|
||||
end_date: Filter positions opened on or before this date
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary with performance metrics and cache stats
|
||||
"""
|
||||
calculator = PerformanceCalculatorV2(db, cache_ttl=300)
|
||||
|
||||
# If not refreshing, use cached only (fast)
|
||||
if not refresh_prices:
|
||||
max_api_calls = 0
|
||||
|
||||
stats = calculator.calculate_account_stats(
|
||||
account_id,
|
||||
update_prices=True,
|
||||
max_api_calls=max_api_calls,
|
||||
start_date=start_date,
|
||||
end_date=end_date
|
||||
)
|
||||
|
||||
return stats
|
||||
|
||||
|
||||
@router.get("/balance-history/{account_id}")
|
||||
def get_balance_history(
|
||||
account_id: int,
|
||||
days: int = Query(default=30, ge=1, le=3650),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get account balance history for charting.
|
||||
|
||||
This endpoint doesn't need market data, so it's always fast.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
days: Number of days to retrieve (default: 30)
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of {date, balance} dictionaries
|
||||
"""
|
||||
calculator = PerformanceCalculatorV2(db)
|
||||
history = calculator.get_balance_history(account_id, days)
|
||||
return {"data": history}
|
||||
|
||||
|
||||
@router.get("/top-trades/{account_id}")
|
||||
def get_top_trades(
|
||||
account_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get top performing trades.
|
||||
|
||||
This endpoint only uses closed positions, so no market data needed.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
limit: Maximum number of trades to return (default: 10)
|
||||
start_date: Filter positions closed on or after this date
|
||||
end_date: Filter positions closed on or before this date
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of trade dictionaries
|
||||
"""
|
||||
calculator = PerformanceCalculatorV2(db)
|
||||
trades = calculator.get_top_trades(account_id, limit, start_date, end_date)
|
||||
return {"data": trades}
|
||||
|
||||
|
||||
@router.get("/worst-trades/{account_id}")
|
||||
def get_worst_trades(
|
||||
account_id: int,
|
||||
limit: int = Query(default=10, ge=1, le=100),
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Get worst performing trades.
|
||||
|
||||
This endpoint only uses closed positions, so no market data needed.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
limit: Maximum number of trades to return (default: 10)
|
||||
start_date: Filter positions closed on or after this date
|
||||
end_date: Filter positions closed on or before this date
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of trade dictionaries
|
||||
"""
|
||||
calculator = PerformanceCalculatorV2(db)
|
||||
trades = calculator.get_worst_trades(account_id, limit, start_date, end_date)
|
||||
return {"data": trades}
|
||||
|
||||
|
||||
@router.post("/refresh-prices/{account_id}")
|
||||
def refresh_prices(
|
||||
account_id: int,
|
||||
max_api_calls: int = Query(default=10, ge=1, le=50),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Manually trigger a price refresh for open positions.
|
||||
|
||||
This is useful when you want fresh data but don't want to wait
|
||||
on the dashboard load.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
max_api_calls: Maximum number of Yahoo Finance API calls
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Update statistics
|
||||
"""
|
||||
calculator = PerformanceCalculatorV2(db, cache_ttl=300)
|
||||
|
||||
stats = calculator.update_open_positions_pnl(
|
||||
account_id,
|
||||
max_api_calls=max_api_calls,
|
||||
allow_stale=False # Force fresh fetches
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Price refresh completed",
|
||||
"stats": stats
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh-prices-background/{account_id}")
|
||||
def refresh_prices_background(
|
||||
account_id: int,
|
||||
background_tasks: BackgroundTasks,
|
||||
max_api_calls: int = Query(default=20, ge=1, le=50),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Trigger a background price refresh.
|
||||
|
||||
This returns immediately while prices are fetched in the background.
|
||||
Client can poll /overview endpoint to see updated data.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
background_tasks: FastAPI background tasks
|
||||
max_api_calls: Maximum number of Yahoo Finance API calls
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Acknowledgment that background task was started
|
||||
"""
|
||||
def refresh_task():
|
||||
calculator = PerformanceCalculatorV2(db, cache_ttl=300)
|
||||
calculator.update_open_positions_pnl(
|
||||
account_id,
|
||||
max_api_calls=max_api_calls,
|
||||
allow_stale=False
|
||||
)
|
||||
|
||||
background_tasks.add_task(refresh_task)
|
||||
|
||||
return {
|
||||
"message": "Price refresh started in background",
|
||||
"account_id": account_id,
|
||||
"max_api_calls": max_api_calls
|
||||
}
|
||||
|
||||
|
||||
@router.post("/refresh-stale-cache")
|
||||
def refresh_stale_cache(
|
||||
min_age_minutes: int = Query(default=10, ge=1, le=1440),
|
||||
limit: int = Query(default=20, ge=1, le=100),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Background maintenance endpoint to refresh stale cached prices.
|
||||
|
||||
This can be called periodically (e.g., via cron) to keep cache fresh.
|
||||
|
||||
Args:
|
||||
min_age_minutes: Only refresh prices older than this many minutes
|
||||
limit: Maximum number of prices to refresh
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of prices refreshed
|
||||
"""
|
||||
market_data = MarketDataService(db, cache_ttl_seconds=300)
|
||||
|
||||
refreshed = market_data.refresh_stale_prices(
|
||||
min_age_seconds=min_age_minutes * 60,
|
||||
limit=limit
|
||||
)
|
||||
|
||||
return {
|
||||
"message": "Stale price refresh completed",
|
||||
"refreshed": refreshed,
|
||||
"min_age_minutes": min_age_minutes
|
||||
}
|
||||
|
||||
|
||||
@router.delete("/clear-old-cache")
|
||||
def clear_old_cache(
|
||||
older_than_days: int = Query(default=30, ge=1, le=365),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
Clear old cached prices from database.
|
||||
|
||||
Args:
|
||||
older_than_days: Delete prices older than this many days
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of records deleted
|
||||
"""
|
||||
market_data = MarketDataService(db)
|
||||
|
||||
deleted = market_data.clear_cache(older_than_days=older_than_days)
|
||||
|
||||
return {
|
||||
"message": "Old cache cleared",
|
||||
"deleted": deleted,
|
||||
"older_than_days": older_than_days
|
||||
}
|
||||
128
backend/app/api/endpoints/import_endpoint.py
Normal file
128
backend/app/api/endpoints/import_endpoint.py
Normal file
@@ -0,0 +1,128 @@
|
||||
"""Import API endpoints for CSV file uploads."""
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, File, status
|
||||
from sqlalchemy.orm import Session
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import shutil
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.services import ImportService
|
||||
from app.services.position_tracker import PositionTracker
|
||||
from app.config import settings
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post("/upload/{account_id}")
|
||||
def upload_csv(
|
||||
account_id: int, file: UploadFile = File(...), db: Session = Depends(get_db)
|
||||
):
|
||||
"""
|
||||
Upload and import a CSV file for an account.
|
||||
|
||||
Args:
|
||||
account_id: Account ID to import transactions for
|
||||
file: CSV file to upload
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Import statistics
|
||||
|
||||
Raises:
|
||||
HTTPException: If import fails
|
||||
"""
|
||||
if not file.filename.endswith(".csv"):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_400_BAD_REQUEST, detail="File must be a CSV"
|
||||
)
|
||||
|
||||
# Save uploaded file to temporary location
|
||||
try:
|
||||
with tempfile.NamedTemporaryFile(delete=False, suffix=".csv") as tmp_file:
|
||||
shutil.copyfileobj(file.file, tmp_file)
|
||||
tmp_path = Path(tmp_file.name)
|
||||
|
||||
# Import transactions
|
||||
import_service = ImportService(db)
|
||||
result = import_service.import_from_file(tmp_path, account_id)
|
||||
|
||||
# Rebuild positions after import
|
||||
if result.imported > 0:
|
||||
position_tracker = PositionTracker(db)
|
||||
positions_created = position_tracker.rebuild_positions(account_id)
|
||||
else:
|
||||
positions_created = 0
|
||||
|
||||
# Clean up temporary file
|
||||
tmp_path.unlink()
|
||||
|
||||
return {
|
||||
"filename": file.filename,
|
||||
"imported": result.imported,
|
||||
"skipped": result.skipped,
|
||||
"errors": result.errors,
|
||||
"total_rows": result.total_rows,
|
||||
"positions_created": positions_created,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Import failed: {str(e)}",
|
||||
)
|
||||
|
||||
|
||||
@router.post("/filesystem/{account_id}")
|
||||
def import_from_filesystem(account_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Import all CSV files from the filesystem import directory.
|
||||
|
||||
Args:
|
||||
account_id: Account ID to import transactions for
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Import statistics for all files
|
||||
|
||||
Raises:
|
||||
HTTPException: If import directory doesn't exist
|
||||
"""
|
||||
import_dir = Path(settings.IMPORT_DIR)
|
||||
|
||||
if not import_dir.exists():
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Import directory not found: {import_dir}",
|
||||
)
|
||||
|
||||
try:
|
||||
import_service = ImportService(db)
|
||||
results = import_service.import_from_directory(import_dir, account_id)
|
||||
|
||||
# Rebuild positions if any transactions were imported
|
||||
total_imported = sum(r.imported for r in results.values())
|
||||
if total_imported > 0:
|
||||
position_tracker = PositionTracker(db)
|
||||
positions_created = position_tracker.rebuild_positions(account_id)
|
||||
else:
|
||||
positions_created = 0
|
||||
|
||||
return {
|
||||
"files": {
|
||||
filename: {
|
||||
"imported": result.imported,
|
||||
"skipped": result.skipped,
|
||||
"errors": result.errors,
|
||||
"total_rows": result.total_rows,
|
||||
}
|
||||
for filename, result in results.items()
|
||||
},
|
||||
"total_imported": total_imported,
|
||||
"positions_created": positions_created,
|
||||
}
|
||||
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
|
||||
detail=f"Import failed: {str(e)}",
|
||||
)
|
||||
104
backend/app/api/endpoints/positions.py
Normal file
104
backend/app/api/endpoints/positions.py
Normal file
@@ -0,0 +1,104 @@
|
||||
"""Position API endpoints."""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_
|
||||
from typing import List, Optional
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.models import Position
|
||||
from app.models.position import PositionStatus
|
||||
from app.schemas import PositionResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=List[PositionResponse])
|
||||
def list_positions(
|
||||
account_id: Optional[int] = None,
|
||||
status_filter: Optional[PositionStatus] = Query(
|
||||
default=None, alias="status", description="Filter by position status"
|
||||
),
|
||||
symbol: Optional[str] = None,
|
||||
skip: int = 0,
|
||||
limit: int = Query(default=100, le=500),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List positions with optional filtering.
|
||||
|
||||
Args:
|
||||
account_id: Filter by account ID
|
||||
status_filter: Filter by status (open/closed)
|
||||
symbol: Filter by symbol
|
||||
skip: Number of records to skip (pagination)
|
||||
limit: Maximum number of records to return
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of positions
|
||||
"""
|
||||
query = db.query(Position)
|
||||
|
||||
# Apply filters
|
||||
if account_id:
|
||||
query = query.filter(Position.account_id == account_id)
|
||||
|
||||
if status_filter:
|
||||
query = query.filter(Position.status == status_filter)
|
||||
|
||||
if symbol:
|
||||
query = query.filter(Position.symbol == symbol)
|
||||
|
||||
# Order by most recent first
|
||||
query = query.order_by(Position.open_date.desc(), Position.id.desc())
|
||||
|
||||
# Pagination
|
||||
positions = query.offset(skip).limit(limit).all()
|
||||
|
||||
return positions
|
||||
|
||||
|
||||
@router.get("/{position_id}", response_model=PositionResponse)
|
||||
def get_position(position_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Get position by ID.
|
||||
|
||||
Args:
|
||||
position_id: Position ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Position details
|
||||
|
||||
Raises:
|
||||
HTTPException: If position not found
|
||||
"""
|
||||
position = db.query(Position).filter(Position.id == position_id).first()
|
||||
|
||||
if not position:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Position {position_id} not found",
|
||||
)
|
||||
|
||||
return position
|
||||
|
||||
|
||||
@router.post("/{account_id}/rebuild")
|
||||
def rebuild_positions(account_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Rebuild all positions for an account from transactions.
|
||||
|
||||
Args:
|
||||
account_id: Account ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Number of positions created
|
||||
"""
|
||||
from app.services.position_tracker import PositionTracker
|
||||
|
||||
position_tracker = PositionTracker(db)
|
||||
positions_created = position_tracker.rebuild_positions(account_id)
|
||||
|
||||
return {"positions_created": positions_created}
|
||||
227
backend/app/api/endpoints/transactions.py
Normal file
227
backend/app/api/endpoints/transactions.py
Normal file
@@ -0,0 +1,227 @@
|
||||
"""Transaction API endpoints."""
|
||||
from fastapi import APIRouter, Depends, HTTPException, Query, status
|
||||
from sqlalchemy.orm import Session
|
||||
from sqlalchemy import and_, or_
|
||||
from typing import List, Optional, Dict
|
||||
from datetime import date
|
||||
|
||||
from app.api.deps import get_db
|
||||
from app.models import Transaction, Position, PositionTransaction
|
||||
from app.schemas import TransactionResponse
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get("", response_model=List[TransactionResponse])
|
||||
def list_transactions(
|
||||
account_id: Optional[int] = None,
|
||||
symbol: Optional[str] = None,
|
||||
start_date: Optional[date] = None,
|
||||
end_date: Optional[date] = None,
|
||||
skip: int = 0,
|
||||
limit: int = Query(default=50, le=500),
|
||||
db: Session = Depends(get_db),
|
||||
):
|
||||
"""
|
||||
List transactions with optional filtering.
|
||||
|
||||
Args:
|
||||
account_id: Filter by account ID
|
||||
symbol: Filter by symbol
|
||||
start_date: Filter by start date (inclusive)
|
||||
end_date: Filter by end date (inclusive)
|
||||
skip: Number of records to skip (pagination)
|
||||
limit: Maximum number of records to return
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
List of transactions
|
||||
"""
|
||||
query = db.query(Transaction)
|
||||
|
||||
# Apply filters
|
||||
if account_id:
|
||||
query = query.filter(Transaction.account_id == account_id)
|
||||
|
||||
if symbol:
|
||||
query = query.filter(Transaction.symbol == symbol)
|
||||
|
||||
if start_date:
|
||||
query = query.filter(Transaction.run_date >= start_date)
|
||||
|
||||
if end_date:
|
||||
query = query.filter(Transaction.run_date <= end_date)
|
||||
|
||||
# Order by date descending
|
||||
query = query.order_by(Transaction.run_date.desc(), Transaction.id.desc())
|
||||
|
||||
# Pagination
|
||||
transactions = query.offset(skip).limit(limit).all()
|
||||
|
||||
return transactions
|
||||
|
||||
|
||||
@router.get("/{transaction_id}", response_model=TransactionResponse)
|
||||
def get_transaction(transaction_id: int, db: Session = Depends(get_db)):
|
||||
"""
|
||||
Get transaction by ID.
|
||||
|
||||
Args:
|
||||
transaction_id: Transaction ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Transaction details
|
||||
|
||||
Raises:
|
||||
HTTPException: If transaction not found
|
||||
"""
|
||||
transaction = (
|
||||
db.query(Transaction).filter(Transaction.id == transaction_id).first()
|
||||
)
|
||||
|
||||
if not transaction:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Transaction {transaction_id} not found",
|
||||
)
|
||||
|
||||
return transaction
|
||||
|
||||
|
||||
@router.get("/{transaction_id}/position-details")
|
||||
def get_transaction_position_details(
|
||||
transaction_id: int, db: Session = Depends(get_db)
|
||||
) -> Dict:
|
||||
"""
|
||||
Get full position details for a transaction, including all related transactions.
|
||||
|
||||
This endpoint finds the position associated with a transaction and returns:
|
||||
- All transactions that are part of the same position
|
||||
- Position metadata (type, status, P&L, etc.)
|
||||
- Strategy classification for options (covered call, cash-secured put, etc.)
|
||||
|
||||
Args:
|
||||
transaction_id: Transaction ID
|
||||
db: Database session
|
||||
|
||||
Returns:
|
||||
Dictionary with position details and all related transactions
|
||||
|
||||
Raises:
|
||||
HTTPException: If transaction not found or not part of a position
|
||||
"""
|
||||
# Find the transaction
|
||||
transaction = (
|
||||
db.query(Transaction).filter(Transaction.id == transaction_id).first()
|
||||
)
|
||||
|
||||
if not transaction:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Transaction {transaction_id} not found",
|
||||
)
|
||||
|
||||
# Find the position this transaction belongs to
|
||||
position_link = (
|
||||
db.query(PositionTransaction)
|
||||
.filter(PositionTransaction.transaction_id == transaction_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not position_link:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Transaction {transaction_id} is not part of any position",
|
||||
)
|
||||
|
||||
# Get the position with all its transactions
|
||||
position = (
|
||||
db.query(Position)
|
||||
.filter(Position.id == position_link.position_id)
|
||||
.first()
|
||||
)
|
||||
|
||||
if not position:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail=f"Position not found",
|
||||
)
|
||||
|
||||
# Get all transactions for this position
|
||||
all_transactions = []
|
||||
for link in position.transaction_links:
|
||||
txn = link.transaction
|
||||
all_transactions.append({
|
||||
"id": txn.id,
|
||||
"run_date": txn.run_date.isoformat(),
|
||||
"action": txn.action,
|
||||
"symbol": txn.symbol,
|
||||
"description": txn.description,
|
||||
"quantity": float(txn.quantity) if txn.quantity else None,
|
||||
"price": float(txn.price) if txn.price else None,
|
||||
"amount": float(txn.amount) if txn.amount else None,
|
||||
"commission": float(txn.commission) if txn.commission else None,
|
||||
"fees": float(txn.fees) if txn.fees else None,
|
||||
})
|
||||
|
||||
# Sort transactions by date
|
||||
all_transactions.sort(key=lambda t: t["run_date"])
|
||||
|
||||
# Determine strategy type for options
|
||||
strategy = _classify_option_strategy(position, all_transactions)
|
||||
|
||||
return {
|
||||
"position": {
|
||||
"id": position.id,
|
||||
"symbol": position.symbol,
|
||||
"option_symbol": position.option_symbol,
|
||||
"position_type": position.position_type.value,
|
||||
"status": position.status.value,
|
||||
"open_date": position.open_date.isoformat(),
|
||||
"close_date": position.close_date.isoformat() if position.close_date else None,
|
||||
"total_quantity": float(position.total_quantity),
|
||||
"avg_entry_price": float(position.avg_entry_price) if position.avg_entry_price is not None else None,
|
||||
"avg_exit_price": float(position.avg_exit_price) if position.avg_exit_price is not None else None,
|
||||
"realized_pnl": float(position.realized_pnl) if position.realized_pnl is not None else None,
|
||||
"unrealized_pnl": float(position.unrealized_pnl) if position.unrealized_pnl is not None else None,
|
||||
"strategy": strategy,
|
||||
},
|
||||
"transactions": all_transactions,
|
||||
}
|
||||
|
||||
|
||||
def _classify_option_strategy(position: Position, transactions: List[Dict]) -> str:
    """Classify the strategy represented by a position.

    Args:
        position: Position object.
        transactions: List of transaction dictionaries (currently unused;
            kept for interface stability and future refinement).

    Returns:
        Strategy name, e.g. "Stock", "Long Call",
        "Short Call (Covered Call)", "Short Put (Cash-Secured Put)",
        or "Unknown" for unrecognized position types.
    """
    kind = position.position_type.value
    if kind == "stock":
        return "Stock"

    # Negative quantity marks a short position. A short call/put might be
    # covered/cash-secured or naked; distinguishing would require checking
    # for an offsetting stock position, so the labels stay combined here.
    is_short = position.total_quantity < 0
    labels = {
        ("call", True): "Short Call (Covered Call)",
        ("call", False): "Long Call",
        ("put", True): "Short Put (Cash-Secured Put)",
        ("put", False): "Long Put",
    }
    return labels.get((kind, is_short), "Unknown")
|
||||
Reference in New Issue
Block a user