changed top_k, system prompt, retrieval logic, and added highlighting

2026-04-19 14:34:52 -05:00
parent cc1539743e
commit 248317c959


@@ -1,17 +1,17 @@
 import os
 os.environ["MCP_ALLOW_ALL_ORIGINS"] = "1"
+import re
 import json
+import numpy as np
 from pathlib import Path
 from sentence_transformers import SentenceTransformer
 from qdrant_client import QdrantClient
 from qdrant_client.models import Filter, FieldCondition, MatchValue
 from mcp.server.fastmcp import FastMCP
-# Add this right after the FastMCP import, before anything else
 from mcp.server import streamable_http
-streamable_http.ALLOWED_ORIGINS = None  # try this first
-# If that doesn't work, patch the actual check function:
+streamable_http.ALLOWED_ORIGINS = None
 import mcp.server.streamable_http as _sh
 _sh.is_valid_origin = lambda origin, allowed: True
 import uvicorn
@@ -20,7 +20,6 @@ from starlette.middleware.base import BaseHTTPMiddleware
 from starlette.requests import Request
 from mcp.server.transport_security import TransportSecuritySettings, TransportSecurityMiddleware
-# Monkey-patch to disable DNS rebinding protection entirely
 TransportSecurityMiddleware.__init__ = lambda self, settings=None: setattr(
     self, "settings", TransportSecuritySettings(enable_dns_rebinding_protection=False)
 )
@@ -38,42 +37,94 @@ with open(project_root / "data" / "processed" / "parent_lookup.json") as f:
     parent_lookup = json.load(f)
 # ── Config ─────────────────────────────────────────────────────────────────
-TOP_K = 10
+TOP_K = 6
-SYSTEM_PROMPT = """You are an expert AP US History tutor helping a student ace their APUSH exam.
-You have access to the search_textbook tool. Call it before answering ANY history question.
-ANSWERING:
-- Cite inline like (Ch5, p.153) after every specific claim
-- **Bold** key terms, dates, names, and critical facts
-- Correct false premises directly — don't reinforce wrong assumptions
-- If the textbook doesn't cover it, answer from general knowledge and prefix with "Outside textbook:"
-FORMAT — match the question type:
-- One word/fact → one word
-- SAQ → 1 focused paragraph, dense with evidence
-- LEQ/DBQ → full essay: context, thesis, body paragraphs with evidence, nuance
-- General question → clear prose, as long as needed
-END EVERY RESPONSE WITH:
+SYSTEM_PROMPT = """You are an elite AP US History tutor. Your only goal is to help the student master APUSH and score a 5.
+━━━ TOOL USE ━━━
+ALWAYS call search_textbook before answering any history question — no exceptions.
+For complex questions (LEQ/DBQ/thematic), call it 2-3 times with different search angles to get full coverage.
+━━━ CITATIONS ━━━
+- Cite inline after every specific claim: (Ch5, p.153)
+- The **bolded sentences** in each source are the most relevant — prioritize citing and building on those
+- Never invent or guess a citation — if unsure, say "Outside textbook:"
+- If the textbook is silent on something relevant, supplement with general knowledge, clearly labeled
+━━━ ACCURACY ━━━
+- Correct false premises immediately and directly — never reinforce a wrong assumption
+- Distinguish causation from correlation, primary from secondary causes
+- Note historiographical debates where relevant (e.g. revisionist vs traditional interpretations)
+- Be precise with dates, names, legislation, and turning points — vagueness loses points on the exam
+━━━ FORMAT — match the question type exactly ━━━
+- Identification / one fact → one concise answer, one citation
+- SAQ (Short Answer) → 3 tight paragraphs: claim → evidence → analysis. No intro/conclusion fluff
+- LEQ (Long Essay) → Full essay: contextualization → thesis → 3 body paragraphs (each with specific evidence + analysis) → conclusion with complexity
+- DBQ → Same as LEQ plus: sourcing, audience/purpose/context for docs, corroboration across docs
+- Compare/contrast → Use parallel structure, explicit similarities AND differences
+- General question → Clear prose, as long as needed, no padding
+━━━ APUSH EXAM SKILLS ━━━
+When writing essays, explicitly hit the College Board rubric:
+- Contextualization: zoom out to broader historical context BEFORE the thesis
+- Thesis: historically defensible, specific, addresses complexity (not just "there were many causes")
+- Evidence: at least 2 specific pieces of evidence per body paragraph
+- Analysis: explain HOW and WHY, not just what happened
+- Complexity: demonstrate nuance — turning points, continuity vs change, multiple causation, or cross-period connections
+━━━ END EVERY RESPONSE WITH ━━━
 ---
 **Sources Used:**
-[list every source from the tool output with chapter, section, page, and score]
-**Retrieval Confidence:** HIGH/MEDIUM/LOW"""
+[list each source: Ch# Section p.### — score: X.XXX]
+**Retrieval Confidence:** HIGH / MEDIUM / LOW
+**Exam Tip:** [one sentence of targeted advice for how this topic typically appears on the APUSH exam]"""
 # ── Embed ──────────────────────────────────────────────────────────────────
-def embed_query(query: str) -> list[float]:
+def embed_query(query: str) -> np.ndarray:
     return model.encode(
         f"search_query: {query}",
         normalize_embeddings=True,
-    ).tolist()
+    )
+# ── Highlight ──────────────────────────────────────────────────────────────
+def highlight_passage(query_emb: np.ndarray, passage: str) -> str:
+    """
+    Bold the top 3 most query-relevant sentences using the already-loaded
+    embedder. Reuses the query embedding computed during retrieval, so the
+    query is never re-embedded; only the passage sentences need one batched
+    encode call.
+    """
+    sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', passage) if len(s.strip()) > 20]
+    if not sentences:
+        return passage
+    sent_embs = model.encode(
+        [f"search_document: {s}" for s in sentences],
+        normalize_embeddings=True,
+        batch_size=32,
+        show_progress_bar=False,
+    )
+    scores = sent_embs @ query_emb  # cosine similarity (both sides are normalized)
+    top_n = min(3, len(scores))
+    threshold = float(sorted(scores)[-top_n])
+    highlighted = passage
+    for sent, score in zip(sentences, scores):
+        if float(score) >= threshold:
+            # avoid double-bolding if the sentence is somehow already bolded
+            if f"**{sent}**" not in highlighted:
+                highlighted = highlighted.replace(sent, f"**{sent}**")
+    return highlighted
 # ── Retrieve ───────────────────────────────────────────────────────────────
 def retrieve(query: str) -> dict:
+    query_emb = embed_query(query)  # compute once, reuse for highlighting
     hits = qdrant.query_points(
         collection_name=COLLECTION,
-        query=embed_query(query),
+        query=query_emb.tolist(),
         limit=TOP_K,
         query_filter=Filter(
             must_not=[
@@ -98,20 +149,22 @@ def retrieve(query: str) -> dict:
         seen_parents.add(pid)
         unique_hits.append(h)
-    unique_hits = unique_hits[:5]
+    unique_hits = unique_hits[:4]
     sources = []
     for h in unique_hits:
         pid = h.payload["parent_id"]
         parts = parent_lookup.get(pid, [])
         full_text = "\n\n".join(p["text"] for p in parts)
+        highlighted = highlight_passage(query_emb, full_text)  # reuse query_emb
         sources.append({
             "score": h.score,
             "chapter_num": h.payload["chapter_num"],
             "chapter_title": h.payload["chapter_title"],
             "section_title": h.payload["section_title"],
             "textbook_page": h.payload["textbook_page"],
-            "text": full_text,
+            "text": highlighted,
         })
     return {
@@ -124,13 +177,12 @@ def retrieve(query: str) -> dict:
 # ── Origin bypass middleware ────────────────────────────────────────────────
 class AllowAllOriginsMiddleware(BaseHTTPMiddleware):
     async def dispatch(self, request: Request, call_next):
-        # Spoof origin so FastMCP's internal check passes
         request._headers = request.headers.mutablecopy()
         request._headers["origin"] = "http://127.0.0.1:11434"
         return await call_next(request)
 # ── MCP Server ─────────────────────────────────────────────────────────────
-mcp = FastMCP("APUSH Tutor")
+mcp = FastMCP("APUSH Tutor", instructions=SYSTEM_PROMPT)
 @mcp.tool()
 def search_textbook(query: str) -> str:
@@ -159,11 +211,6 @@ def search_textbook(query: str) -> str:
     return header + passages + footer
-@mcp.prompt()
-def system_prompt() -> str:
-    """The APUSH tutor system prompt."""
-    return SYSTEM_PROMPT
 # ── Run ────────────────────────────────────────────────────────────────────
 if __name__ == "__main__":
     app = mcp.streamable_http_app()
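
A quick way to sanity-check the new highlighting step in isolation is to run its scoring and thresholding logic on stand-in vectors instead of the loaded SentenceTransformer. The sketch below uses only numpy; the toy 2-D vectors, the top_n of 2, and the printed output are illustrative assumptions, not part of the commit. The point it demonstrates: with L2-normalized embeddings, sent_embs @ query_emb gives each sentence's cosine similarity to the query, and every sentence scoring at or above the Nth-best score is selected for bolding.

# Minimal sketch of the scoring/threshold logic used by highlight_passage.
# Assumption: toy 2-D vectors stand in for real model embeddings.
import numpy as np

query_emb = np.array([1.0, 0.0])              # already unit length
sent_embs = np.array([
    [0.9, 0.436],                             # close to the query
    [0.0, 1.0],                               # unrelated
    [0.7, 0.714],                             # somewhat related
])
sent_embs /= np.linalg.norm(sent_embs, axis=1, keepdims=True)

scores = sent_embs @ query_emb                # cosine similarities, since both sides are normalized
top_n = min(2, len(scores))                   # the real function uses min(3, len(scores))
threshold = float(sorted(scores)[-top_n])     # score of the Nth-best sentence
to_bold = [i for i, s in enumerate(scores) if float(s) >= threshold]
print(scores.round(3), to_bold)               # -> [0.9 0.  0.7] [0, 2]

Sentences at indices 0 and 2 clear the threshold and would be wrapped in ** ** by the replace loop in highlight_passage, which is what drives the "bolded sentences are the most relevant" instruction in the new SYSTEM_PROMPT.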