""" Integration tests for the FastAPI API layer. Uses TestClient with an in-memory SQLite DB — no live market data. """ from datetime import date, timedelta from unittest.mock import MagicMock, patch import pytest from fastapi.testclient import TestClient from sqlalchemy import create_engine from sqlalchemy.orm import sessionmaker from app.database import Base, get_db from app.main import app # ─── Test DB setup ───────────────────────────────────────────────────────────── TEST_DB_URL = "sqlite://" # in-memory test_engine = create_engine(TEST_DB_URL, connect_args={"check_same_thread": False}) TestSessionLocal = sessionmaker(autocommit=False, autoflush=False, bind=test_engine) def override_get_db(): db = TestSessionLocal() try: yield db finally: db.close() @pytest.fixture(autouse=True) def setup_db(): Base.metadata.create_all(bind=test_engine) yield Base.metadata.drop_all(bind=test_engine) app.dependency_overrides[get_db] = override_get_db # Disable scheduler during tests with patch("app.main.start_scheduler"), patch("app.main.stop_scheduler"): client = TestClient(app, raise_server_exceptions=True) FAKE_TOKEN = "abc123device0000000000000000000000000000000000000000000000000000" # ─── Device registration ─────────────────────────────────────────────────────── def test_register_device(): resp = client.post("/api/v1/devices/register", json={"apns_token": FAKE_TOKEN, "device_name": "Test iPhone"}) assert resp.status_code == 200 data = resp.json() assert data["apns_token"] == FAKE_TOKEN assert "id" in data def test_register_device_idempotent(): client.post("/api/v1/devices/register", json={"apns_token": FAKE_TOKEN}) resp = client.post("/api/v1/devices/register", json={"apns_token": FAKE_TOKEN}) assert resp.status_code == 200 # ─── Portfolio ───────────────────────────────────────────────────────────────── @pytest.fixture def registered_device(): client.post("/api/v1/devices/register", json={"apns_token": FAKE_TOKEN}) return FAKE_TOKEN def 
test_add_portfolio(registered_device): resp = client.post( "/api/v1/portfolio", json=[{"ticker": "AAPL", "shares": 100}, {"ticker": "MSFT", "shares": 200}], headers={"X-Device-Token": registered_device}, ) assert resp.status_code == 200 tickers = [p["ticker"] for p in resp.json()] assert "AAPL" in tickers assert "MSFT" in tickers def test_get_portfolio_empty(registered_device): resp = client.get("/api/v1/portfolio", headers={"X-Device-Token": registered_device}) assert resp.status_code == 200 assert resp.json() == [] def test_get_portfolio_after_add(registered_device): client.post( "/api/v1/portfolio", json=[{"ticker": "NVDA", "shares": 50}], headers={"X-Device-Token": registered_device}, ) resp = client.get("/api/v1/portfolio", headers={"X-Device-Token": registered_device}) assert resp.status_code == 200 assert resp.json()[0]["ticker"] == "NVDA" def test_delete_ticker(registered_device): client.post( "/api/v1/portfolio", json=[{"ticker": "AMD", "shares": 100}], headers={"X-Device-Token": registered_device}, ) resp = client.delete("/api/v1/portfolio/AMD", headers={"X-Device-Token": registered_device}) assert resp.status_code == 204 remaining = client.get("/api/v1/portfolio", headers={"X-Device-Token": registered_device}).json() assert all(p["ticker"] != "AMD" for p in remaining) def test_portfolio_unregistered_device(): resp = client.get("/api/v1/portfolio", headers={"X-Device-Token": "nonexistent_token"}) assert resp.status_code == 404 # ─── Option Positions ────────────────────────────────────────────────────────── def test_log_position(registered_device): expiry = str(date.today() + timedelta(days=14)) resp = client.post( "/api/v1/positions", json={ "ticker": "AAPL", "strategy": "covered_call", "strike": 195.0, "expiration": expiry, "premium_received": 2.50, "contracts": 1, }, headers={"X-Device-Token": registered_device}, ) assert resp.status_code == 201 data = resp.json() assert data["ticker"] == "AAPL" assert data["status"] == "open" assert data["id"] is not 
None def test_close_position(registered_device): expiry = str(date.today() + timedelta(days=14)) pos = client.post( "/api/v1/positions", json={ "ticker": "AAPL", "strategy": "covered_call", "strike": 195.0, "expiration": expiry, "premium_received": 2.50, }, headers={"X-Device-Token": registered_device}, ).json() resp = client.patch( f"/api/v1/positions/{pos['id']}", json={"status": "closed", "close_reason": "bought_back"}, headers={"X-Device-Token": registered_device}, ) assert resp.status_code == 200 assert resp.json()["status"] == "closed" def test_get_open_positions_filter(registered_device): expiry = str(date.today() + timedelta(days=14)) client.post( "/api/v1/positions", json={"ticker": "TSLA", "strategy": "cash_secured_put", "strike": 200.0, "expiration": expiry, "premium_received": 3.0}, headers={"X-Device-Token": registered_device}, ) resp = client.get("/api/v1/positions?status=open", headers={"X-Device-Token": registered_device}) assert resp.status_code == 200 assert len(resp.json()) >= 1 assert all(p["status"] == "open" for p in resp.json()) # ─── Alerts ──────────────────────────────────────────────────────────────────── def test_get_alerts_empty(registered_device): resp = client.get("/api/v1/alerts", headers={"X-Device-Token": registered_device}) assert resp.status_code == 200 assert resp.json() == [] def test_acknowledge_alert(registered_device): from app.models.db_models import Alert from datetime import datetime db = TestSessionLocal() device_id = client.post("/api/v1/devices/register", json={"apns_token": FAKE_TOKEN}).json()["id"] alert = Alert( device_id=device_id, ticker="AAPL", alert_type="close_early", message="Test alert", sent_at=datetime.utcnow(), acknowledged=False, ) db.add(alert) db.commit() db.refresh(alert) alert_id = alert.id db.close() resp = client.patch( f"/api/v1/alerts/{alert_id}/acknowledge", headers={"X-Device-Token": FAKE_TOKEN}, ) assert resp.status_code == 200 assert resp.json()["acknowledged"] is True # ─── Health 
# ─── Health ────────────────────────────────────────────────────────────────────


def test_health():
    """GET /api/v1/health responds 200 with status "ok" while position_monitor.last_run is None."""
    # The former `patch("app.routers")` was dropped: it only swapped the package
    # attribute for a MagicMock after the routes were already registered on the
    # app, so it had no effect on request handling.
    # NOTE(review): this patch only takes effect if the health endpoint reads
    # `last_run` through the module at request time (not via a
    # `from app.services.position_monitor import last_run` binding) — confirm.
    with patch("app.services.position_monitor.last_run", None):
        resp = client.get("/api/v1/health")
    assert resp.status_code == 200
    assert resp.json()["status"] == "ok"