This repository has been archived on 2026-04-23. You can view files and clone it. You cannot open issues or pull requests or push a commit.
Files
apush-rag/mcp_server.py
2026-04-19 15:54:47 -05:00

179 lines
7.2 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters
This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.
import os

# Must be set BEFORE any mcp module is imported so it is visible at import
# time — presumably read by mcp's origin handling (TODO confirm the installed
# mcp version still honors this variable).
os.environ["MCP_ALLOW_ALL_ORIGINS"] = "1"

import re
import json
import numpy as np
from pathlib import Path
from sentence_transformers import SentenceTransformer
from qdrant_client import QdrantClient
from qdrant_client.models import Filter, FieldCondition, MatchValue
from mcp.server.fastmcp import FastMCP
from mcp.server import streamable_http

# Origin bypass #1: clear the allow-list on the streamable-http transport.
streamable_http.ALLOWED_ORIGINS = None

import mcp.server.streamable_http as _sh

# Origin bypass #2: replace the origin validator so every Origin header is
# accepted regardless of the configured allow-list.
_sh.is_valid_origin = lambda origin, allowed: True

import uvicorn
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
import httpx
from mcp.server.transport_security import TransportSecuritySettings, TransportSecurityMiddleware

# Origin bypass #3: monkeypatch the transport-security middleware constructor
# to ignore caller-supplied settings and always disable DNS-rebinding
# protection.
# NOTE(review): all three bypasses trade security for local convenience —
# only acceptable while the server binds to 127.0.0.1 (see __main__ below).
TransportSecurityMiddleware.__init__ = lambda self, settings=None: setattr(
    self, "settings", TransportSecuritySettings(enable_dns_rebinding_protection=False)
)
# ── Paths ──────────────────────────────────────────────────────────────────
project_root = Path(__file__).resolve().parent

# ── Models / Clients ───────────────────────────────────────────────────────
# Embedding model shared by query and passage encoding; trust_remote_code is
# required because nomic-embed ships custom modeling code on the Hub.
model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
# Embedded, file-backed Qdrant instance — no separate server process needed.
qdrant = QdrantClient(path=str(project_root / "data" / "qdrant_local"))
COLLECTION = "apush_chunks"

# Maps parent_id -> list of chunk dicts (each carrying a "text" key), used to
# reassemble the full parent passage from a child-chunk hit (see retrieve()).
with open(project_root / "data" / "processed" / "parent_lookup.json") as f:
    parent_lookup = json.load(f)

# ── Config ─────────────────────────────────────────────────────────────────
TOP_K = 10  # raw vector hits fetched before parent-level deduplication
# ── Embed ──────────────────────────────────────────────────────────────────
def embed_query(query: str) -> np.ndarray:
    """Return an L2-normalized embedding for *query*.

    The "search_query: " prefix is the task instruction nomic-embed expects
    for query-side encoding (documents use "search_document: ").
    """
    prefixed = f"search_query: {query}"
    return model.encode(prefixed, normalize_embeddings=True)
# ── Highlight ──────────────────────────────────────────────────────────────
def highlight_passage(query_emb: np.ndarray, passage: str) -> str:
    """Bold (markdown ``**…**``) the passage sentences most similar to the query.

    Splits *passage* on sentence-ending punctuation, embeds every sentence
    longer than 20 characters, and wraps the top-3 scoring sentences (by
    cosine similarity against *query_emb*) in ``**`` markers. Returns the
    passage unchanged when no usable sentences are found.
    """
    candidates = [
        piece.strip()
        for piece in re.split(r'(?<=[.!?])\s+', passage)
        if len(piece.strip()) > 20
    ]
    if not candidates:
        return passage

    embeddings = model.encode(
        [f"search_document: {sentence}" for sentence in candidates],
        normalize_embeddings=True,
        batch_size=32,
        show_progress_bar=False,
    )
    similarities = embeddings @ query_emb

    # Threshold at the 3rd-highest score (or lower if fewer sentences exist),
    # so ties at the cutoff may bold slightly more than three sentences.
    keep = min(3, len(similarities))
    cutoff = float(sorted(similarities)[-keep])

    result = passage
    for sentence, sim in zip(candidates, similarities):
        # Skip sentences already wrapped to avoid nesting ** markers.
        if float(sim) >= cutoff and f"**{sentence}**" not in result:
            result = result.replace(sentence, f"**{sentence}**")
    return result
# ── Retrieve ───────────────────────────────────────────────────────────────
def retrieve(query: str) -> dict:
    """Vector-search the textbook and return scored, highlighted sources.

    Returns a dict with the original query, a coarse confidence label
    (HIGH/MEDIUM/LOW from the top hit's score), the raw top score, and up to
    five parent-deduplicated sources with query-relevant sentences bolded.
    """
    query_emb = embed_query(query)

    # Chapter-review chunks are excluded — they summarize rather than teach.
    review_filter = Filter(
        must_not=[
            FieldCondition(key="is_chapter_review", match=MatchValue(value=True))
        ]
    )
    hits = qdrant.query_points(
        collection_name=COLLECTION,
        query=query_emb.tolist(),
        limit=TOP_K,
        query_filter=review_filter,
    ).points

    top_score = hits[0].score if hits else 0
    if top_score >= 0.70:
        confidence = "HIGH"
    elif top_score >= 0.50:
        confidence = "MEDIUM"
    else:
        confidence = "LOW"

    # Keep only the best (first-ranked) hit per parent document.
    seen = set()
    deduped = []
    for hit in hits:
        parent = hit.payload["parent_id"]
        if parent in seen:
            continue
        seen.add(parent)
        deduped.append(hit)

    sources = []
    for hit in deduped[:5]:
        # Reassemble the full parent passage from its stored chunks, then
        # highlight it against the query embedding.
        parts = parent_lookup.get(hit.payload["parent_id"], [])
        full_text = "\n\n".join(part["text"] for part in parts)
        sources.append({
            "score": hit.score,
            "chapter_num": hit.payload["chapter_num"],
            "chapter_title": hit.payload["chapter_title"],
            "section_title": hit.payload["section_title"],
            "textbook_page": hit.payload["textbook_page"],
            "text": highlight_passage(query_emb, full_text),
        })

    return {
        "query": query,
        "confidence": confidence,
        "top_score": top_score,
        "sources": sources,
    }
# ── Origin bypass middleware ───────────────────────────────────────────────
class AllowAllOriginsMiddleware(BaseHTTPMiddleware):
    # Rewrites every request's Origin header to a fixed loopback value so
    # that any downstream origin validation always sees an allowed origin.
    # NOTE(review): mutates Starlette's private ``request._headers`` — this
    # works with current Starlette but is fragile against internals changing;
    # verify after Starlette upgrades.
    async def dispatch(self, request: Request, call_next):
        # Spoof the Origin (11434 is presumably the Ollama client's origin —
        # TODO confirm) before the MCP transport's checks run.
        request._headers = request.headers.mutablecopy()
        request._headers["origin"] = "http://127.0.0.1:11434"
        return await call_next(request)
# ── MCP Server ─────────────────────────────────────────────────────────────
mcp = FastMCP("APUSH Tutor")


def _cite(s: dict) -> str:
    """Format one source's citation ("Ch# <section> p.###").

    Shared by the passage headers and the footer source list so the two
    citation formats cannot drift apart.
    """
    return f"Ch{s['chapter_num']} {s['section_title']} p.{s['textbook_page']}"


@mcp.tool()
def search_textbook(query: str) -> str:
    """
    Search the AP US History textbook for relevant passages.
    Call this before answering ANY US history question.
    For broad topics call it multiple times with different search angles.
    Returns passages with the most relevant sentences bolded.
    Always cite inline (Ch#, p.###) and list sources at the end.
    """
    # NOTE: the docstring above doubles as the tool description shown to the
    # LLM client — keep it prescriptive.
    retrieved = retrieve(query)
    if not retrieved["sources"]:
        return "No relevant passages found in the textbook."

    # Confidence banner, the passages separated by ---, then a source list.
    header = f"[Confidence: {retrieved['confidence']} | Top score: {retrieved['top_score']:.3f}]\n\n"
    passages = "\n\n---\n\n".join(
        f"[SOURCE {i+1} | {_cite(s)} | score: {s['score']:.3f}]\n{s['text']}"
        for i, s in enumerate(retrieved["sources"])
    )
    footer = "\n\n===SOURCES===\n" + "\n".join(
        f"[{i+1}] {_cite(s)} (score: {s['score']:.3f})"
        for i, s in enumerate(retrieved["sources"])
    )
    return header + passages + footer
# ── Run ────────────────────────────────────────────────────────────────────
if __name__ == "__main__":
    server_app = mcp.streamable_http_app()
    # Starlette wraps middleware in reverse add order: CORS (added last)
    # runs first, then the origin-rewriting middleware, then the MCP app.
    server_app.add_middleware(AllowAllOriginsMiddleware)
    server_app.add_middleware(
        CORSMiddleware,
        allow_origins=["*"],
        allow_methods=["*"],
        allow_headers=["*"],
    )
    print("Starting APUSH MCP server on http://127.0.0.1:52437/mcp")
    # Loopback bind only — the origin bypasses above assume a local server.
    uvicorn.run(server_app, host="127.0.0.1", port=52437)