|
|
|
|
@@ -1,26 +1,26 @@
|
|
|
|
|
import os
|
|
|
|
|
os.environ["MCP_ALLOW_ALL_ORIGINS"] = "1"
|
|
|
|
|
|
|
|
|
|
import re
|
|
|
|
|
import json
|
|
|
|
|
import numpy as np
|
|
|
|
|
from pathlib import Path
|
|
|
|
|
from sentence_transformers import SentenceTransformer
|
|
|
|
|
from qdrant_client import QdrantClient
|
|
|
|
|
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
|
|
|
|
from mcp.server.fastmcp import FastMCP
|
|
|
|
|
# Add this right after the FastMCP import, before anything else
|
|
|
|
|
from mcp.server import streamable_http
|
|
|
|
|
streamable_http.ALLOWED_ORIGINS = None # try this first
|
|
|
|
|
streamable_http.ALLOWED_ORIGINS = None
|
|
|
|
|
|
|
|
|
|
# If that doesn't work, patch the actual check function:
|
|
|
|
|
import mcp.server.streamable_http as _sh
|
|
|
|
|
_sh.is_valid_origin = lambda origin, allowed: True
|
|
|
|
|
import uvicorn
|
|
|
|
|
from starlette.middleware.cors import CORSMiddleware
|
|
|
|
|
from starlette.middleware.base import BaseHTTPMiddleware
|
|
|
|
|
from starlette.requests import Request
|
|
|
|
|
import httpx
|
|
|
|
|
from mcp.server.transport_security import TransportSecuritySettings, TransportSecurityMiddleware
|
|
|
|
|
|
|
|
|
|
# Monkey-patch to disable DNS rebinding protection entirely
|
|
|
|
|
TransportSecurityMiddleware.__init__ = lambda self, settings=None: setattr(
|
|
|
|
|
self, "settings", TransportSecuritySettings(enable_dns_rebinding_protection=False)
|
|
|
|
|
)
|
|
|
|
|
@@ -40,40 +40,47 @@ with open(project_root / "data" / "processed" / "parent_lookup.json") as f:
|
|
|
|
|
# ── Config ─────────────────────────────────────────────────────────────────
|
|
|
|
|
TOP_K = 10
|
|
|
|
|
|
|
|
|
|
SYSTEM_PROMPT = """You are an expert AP US History tutor helping a student ace their APUSH exam.
|
|
|
|
|
|
|
|
|
|
You have access to the search_textbook tool. Call it before answering ANY history question.
|
|
|
|
|
|
|
|
|
|
ANSWERING:
|
|
|
|
|
- Cite inline like (Ch5, p.153) after every specific claim
|
|
|
|
|
- **Bold** key terms, dates, names, and critical facts
|
|
|
|
|
- Correct false premises directly — don't reinforce wrong assumptions
|
|
|
|
|
- If the textbook doesn't cover it, answer from general knowledge and prefix with "Outside textbook:"
|
|
|
|
|
|
|
|
|
|
FORMAT — match the question type:
|
|
|
|
|
- One word/fact → one word
|
|
|
|
|
- SAQ → 1 focused paragraph, dense with evidence
|
|
|
|
|
- LEQ/DBQ → full essay: context, thesis, body paragraphs with evidence, nuance
|
|
|
|
|
- General question → clear prose, as long as needed
|
|
|
|
|
|
|
|
|
|
END EVERY RESPONSE WITH:
|
|
|
|
|
---
|
|
|
|
|
**Sources Used:**
|
|
|
|
|
[list every source from the tool output with chapter, section, page, and score]
|
|
|
|
|
**Retrieval Confidence:** HIGH/MEDIUM/LOW"""
|
|
|
|
|
|
|
|
|
|
# ── Embed ──────────────────────────────────────────────────────────────────
|
|
|
|
|
def embed_query(query: str) -> list[float]:
|
|
|
|
|
def embed_query(query: str) -> np.ndarray:
|
|
|
|
|
return model.encode(
|
|
|
|
|
f"search_query: {query}",
|
|
|
|
|
normalize_embeddings=True,
|
|
|
|
|
).tolist()
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
# ── Highlight ──────────────────────────────────────────────────────────────
|
|
|
|
|
def highlight_passage(query_emb: np.ndarray, passage: str) -> str:
|
|
|
|
|
sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', passage) if len(s.strip()) > 20]
|
|
|
|
|
if not sentences:
|
|
|
|
|
return passage
|
|
|
|
|
|
|
|
|
|
sent_embs = model.encode(
|
|
|
|
|
[f"search_document: {s}" for s in sentences],
|
|
|
|
|
normalize_embeddings=True,
|
|
|
|
|
batch_size=32,
|
|
|
|
|
show_progress_bar=False,
|
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
scores = sent_embs @ query_emb
|
|
|
|
|
top_n = min(3, len(scores))
|
|
|
|
|
threshold = float(sorted(scores)[-top_n])
|
|
|
|
|
|
|
|
|
|
highlighted = passage
|
|
|
|
|
for sent, score in zip(sentences, scores):
|
|
|
|
|
if float(score) >= threshold:
|
|
|
|
|
if f"**{sent}**" not in highlighted:
|
|
|
|
|
highlighted = highlighted.replace(sent, f"**{sent}**")
|
|
|
|
|
|
|
|
|
|
return highlighted
|
|
|
|
|
|
|
|
|
|
# ── Retrieve ───────────────────────────────────────────────────────────────
|
|
|
|
|
def retrieve(query: str) -> dict:
|
|
|
|
|
query_emb = embed_query(query)
|
|
|
|
|
|
|
|
|
|
hits = qdrant.query_points(
|
|
|
|
|
collection_name=COLLECTION,
|
|
|
|
|
query=embed_query(query),
|
|
|
|
|
query=query_emb.tolist(),
|
|
|
|
|
limit=TOP_K,
|
|
|
|
|
query_filter=Filter(
|
|
|
|
|
must_not=[
|
|
|
|
|
@@ -83,12 +90,7 @@ def retrieve(query: str) -> dict:
|
|
|
|
|
).points
|
|
|
|
|
|
|
|
|
|
top_score = hits[0].score if hits else 0
|
|
|
|
|
if top_score >= 0.70:
|
|
|
|
|
confidence = "HIGH"
|
|
|
|
|
elif top_score >= 0.50:
|
|
|
|
|
confidence = "MEDIUM"
|
|
|
|
|
else:
|
|
|
|
|
confidence = "LOW"
|
|
|
|
|
confidence = "HIGH" if top_score >= 0.70 else "MEDIUM" if top_score >= 0.50 else "LOW"
|
|
|
|
|
|
|
|
|
|
seen_parents = set()
|
|
|
|
|
unique_hits = []
|
|
|
|
|
@@ -105,13 +107,15 @@ def retrieve(query: str) -> dict:
|
|
|
|
|
pid = h.payload["parent_id"]
|
|
|
|
|
parts = parent_lookup.get(pid, [])
|
|
|
|
|
full_text = "\n\n".join(p["text"] for p in parts)
|
|
|
|
|
highlighted = highlight_passage(query_emb, full_text)
|
|
|
|
|
|
|
|
|
|
sources.append({
|
|
|
|
|
"score": h.score,
|
|
|
|
|
"chapter_num": h.payload["chapter_num"],
|
|
|
|
|
"chapter_title": h.payload["chapter_title"],
|
|
|
|
|
"section_title": h.payload["section_title"],
|
|
|
|
|
"textbook_page": h.payload["textbook_page"],
|
|
|
|
|
"text": full_text,
|
|
|
|
|
"text": highlighted,
|
|
|
|
|
})
|
|
|
|
|
|
|
|
|
|
return {
|
|
|
|
|
@@ -121,10 +125,9 @@ def retrieve(query: str) -> dict:
|
|
|
|
|
"sources": sources,
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
# ── Origin bypass middleware ────────────────────────────────────────────────
|
|
|
|
|
# ── Origin bypass middleware ───────────────────────────────────────────────
|
|
|
|
|
class AllowAllOriginsMiddleware(BaseHTTPMiddleware):
|
|
|
|
|
async def dispatch(self, request: Request, call_next):
|
|
|
|
|
# Spoof origin so FastMCP's internal check passes
|
|
|
|
|
request._headers = request.headers.mutablecopy()
|
|
|
|
|
request._headers["origin"] = "http://127.0.0.1:11434"
|
|
|
|
|
return await call_next(request)
|
|
|
|
|
@@ -136,9 +139,10 @@ mcp = FastMCP("APUSH Tutor")
|
|
|
|
|
def search_textbook(query: str) -> str:
|
|
|
|
|
"""
|
|
|
|
|
Search the AP US History textbook for relevant passages.
|
|
|
|
|
Use this for any question about US history before answering.
|
|
|
|
|
Always cite sources inline and list all sources at the end.
|
|
|
|
|
Bold or emphasize the most important phrases in your answer.
|
|
|
|
|
Call this before answering ANY US history question.
|
|
|
|
|
For broad topics call it multiple times with different search angles.
|
|
|
|
|
Returns passages with the most relevant sentences bolded.
|
|
|
|
|
Always cite inline (Ch#, p.###) and list sources at the end.
|
|
|
|
|
"""
|
|
|
|
|
retrieved = retrieve(query)
|
|
|
|
|
|
|
|
|
|
@@ -159,10 +163,6 @@ def search_textbook(query: str) -> str:
|
|
|
|
|
|
|
|
|
|
return header + passages + footer
|
|
|
|
|
|
|
|
|
|
@mcp.prompt()
|
|
|
|
|
def system_prompt() -> str:
|
|
|
|
|
"""The APUSH tutor system prompt."""
|
|
|
|
|
return SYSTEM_PROMPT
|
|
|
|
|
|
|
|
|
|
# ── Run ────────────────────────────────────────────────────────────────────
|
|
|
|
|
if __name__ == "__main__":
|
|
|
|
|
|