Compare commits
7 Commits
ebe592a6b1
...
main
| Author | SHA1 | Date | |
|---|---|---|---|
| 8383520a3a | |||
| 7935b01771 | |||
| 248317c959 | |||
| cc1539743e | |||
| 1511423057 | |||
| 50b4c1c905 | |||
| 3378ad7328 |
179
mcp_server.py
Normal file
179
mcp_server.py
Normal file
@@ -0,0 +1,179 @@
|
|||||||
|
import os
|
||||||
|
os.environ["MCP_ALLOW_ALL_ORIGINS"] = "1"
|
||||||
|
|
||||||
|
import re
|
||||||
|
import json
|
||||||
|
import numpy as np
|
||||||
|
from pathlib import Path
|
||||||
|
from sentence_transformers import SentenceTransformer
|
||||||
|
from qdrant_client import QdrantClient
|
||||||
|
from qdrant_client.models import Filter, FieldCondition, MatchValue
|
||||||
|
from mcp.server.fastmcp import FastMCP
|
||||||
|
from mcp.server import streamable_http
|
||||||
|
streamable_http.ALLOWED_ORIGINS = None
|
||||||
|
|
||||||
|
import mcp.server.streamable_http as _sh
|
||||||
|
_sh.is_valid_origin = lambda origin, allowed: True
|
||||||
|
import uvicorn
|
||||||
|
from starlette.middleware.cors import CORSMiddleware
|
||||||
|
from starlette.middleware.base import BaseHTTPMiddleware
|
||||||
|
from starlette.requests import Request
|
||||||
|
import httpx
|
||||||
|
from mcp.server.transport_security import TransportSecuritySettings, TransportSecurityMiddleware
|
||||||
|
|
||||||
|
TransportSecurityMiddleware.__init__ = lambda self, settings=None: setattr(
|
||||||
|
self, "settings", TransportSecuritySettings(enable_dns_rebinding_protection=False)
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Paths ──────────────────────────────────────────────────────────────────
|
||||||
|
project_root = Path(__file__).resolve().parent
|
||||||
|
|
||||||
|
# ── Models / Clients ───────────────────────────────────────────────────────
|
||||||
|
model = SentenceTransformer("nomic-ai/nomic-embed-text-v1.5", trust_remote_code=True)
|
||||||
|
|
||||||
|
qdrant = QdrantClient(path=str(project_root / "data" / "qdrant_local"))
|
||||||
|
COLLECTION = "apush_chunks"
|
||||||
|
|
||||||
|
with open(project_root / "data" / "processed" / "parent_lookup.json") as f:
|
||||||
|
parent_lookup = json.load(f)
|
||||||
|
|
||||||
|
# ── Config ─────────────────────────────────────────────────────────────────
|
||||||
|
TOP_K = 10
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
# ── Embed ──────────────────────────────────────────────────────────────────
|
||||||
|
def embed_query(query: str) -> np.ndarray:
|
||||||
|
return model.encode(
|
||||||
|
f"search_query: {query}",
|
||||||
|
normalize_embeddings=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
# ── Highlight ──────────────────────────────────────────────────────────────
|
||||||
|
def highlight_passage(query_emb: np.ndarray, passage: str) -> str:
|
||||||
|
sentences = [s.strip() for s in re.split(r'(?<=[.!?])\s+', passage) if len(s.strip()) > 20]
|
||||||
|
if not sentences:
|
||||||
|
return passage
|
||||||
|
|
||||||
|
sent_embs = model.encode(
|
||||||
|
[f"search_document: {s}" for s in sentences],
|
||||||
|
normalize_embeddings=True,
|
||||||
|
batch_size=32,
|
||||||
|
show_progress_bar=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
scores = sent_embs @ query_emb
|
||||||
|
top_n = min(3, len(scores))
|
||||||
|
threshold = float(sorted(scores)[-top_n])
|
||||||
|
|
||||||
|
highlighted = passage
|
||||||
|
for sent, score in zip(sentences, scores):
|
||||||
|
if float(score) >= threshold:
|
||||||
|
if f"**{sent}**" not in highlighted:
|
||||||
|
highlighted = highlighted.replace(sent, f"**{sent}**")
|
||||||
|
|
||||||
|
return highlighted
|
||||||
|
|
||||||
|
# ── Retrieve ───────────────────────────────────────────────────────────────
|
||||||
|
def retrieve(query: str) -> dict:
|
||||||
|
query_emb = embed_query(query)
|
||||||
|
|
||||||
|
hits = qdrant.query_points(
|
||||||
|
collection_name=COLLECTION,
|
||||||
|
query=query_emb.tolist(),
|
||||||
|
limit=TOP_K,
|
||||||
|
query_filter=Filter(
|
||||||
|
must_not=[
|
||||||
|
FieldCondition(key="is_chapter_review", match=MatchValue(value=True))
|
||||||
|
]
|
||||||
|
),
|
||||||
|
).points
|
||||||
|
|
||||||
|
top_score = hits[0].score if hits else 0
|
||||||
|
confidence = "HIGH" if top_score >= 0.70 else "MEDIUM" if top_score >= 0.50 else "LOW"
|
||||||
|
|
||||||
|
seen_parents = set()
|
||||||
|
unique_hits = []
|
||||||
|
for h in hits:
|
||||||
|
pid = h.payload["parent_id"]
|
||||||
|
if pid not in seen_parents:
|
||||||
|
seen_parents.add(pid)
|
||||||
|
unique_hits.append(h)
|
||||||
|
|
||||||
|
unique_hits = unique_hits[:5]
|
||||||
|
|
||||||
|
sources = []
|
||||||
|
for h in unique_hits:
|
||||||
|
pid = h.payload["parent_id"]
|
||||||
|
parts = parent_lookup.get(pid, [])
|
||||||
|
full_text = "\n\n".join(p["text"] for p in parts)
|
||||||
|
highlighted = highlight_passage(query_emb, full_text)
|
||||||
|
|
||||||
|
sources.append({
|
||||||
|
"score": h.score,
|
||||||
|
"chapter_num": h.payload["chapter_num"],
|
||||||
|
"chapter_title": h.payload["chapter_title"],
|
||||||
|
"section_title": h.payload["section_title"],
|
||||||
|
"textbook_page": h.payload["textbook_page"],
|
||||||
|
"text": highlighted,
|
||||||
|
})
|
||||||
|
|
||||||
|
return {
|
||||||
|
"query": query,
|
||||||
|
"confidence": confidence,
|
||||||
|
"top_score": top_score,
|
||||||
|
"sources": sources,
|
||||||
|
}
|
||||||
|
|
||||||
|
# ── Origin bypass middleware ───────────────────────────────────────────────
|
||||||
|
class AllowAllOriginsMiddleware(BaseHTTPMiddleware):
|
||||||
|
async def dispatch(self, request: Request, call_next):
|
||||||
|
request._headers = request.headers.mutablecopy()
|
||||||
|
request._headers["origin"] = "http://127.0.0.1:11434"
|
||||||
|
return await call_next(request)
|
||||||
|
|
||||||
|
# ── MCP Server ─────────────────────────────────────────────────────────────
|
||||||
|
mcp = FastMCP("APUSH Tutor")
|
||||||
|
|
||||||
|
@mcp.tool()
|
||||||
|
def search_textbook(query: str) -> str:
|
||||||
|
"""
|
||||||
|
Search the AP US History textbook for relevant passages.
|
||||||
|
Call this before answering ANY US history question.
|
||||||
|
For broad topics call it multiple times with different search angles.
|
||||||
|
Returns passages with the most relevant sentences bolded.
|
||||||
|
Always cite inline (Ch#, p.###) and list sources at the end.
|
||||||
|
"""
|
||||||
|
retrieved = retrieve(query)
|
||||||
|
|
||||||
|
if not retrieved["sources"]:
|
||||||
|
return "No relevant passages found in the textbook."
|
||||||
|
|
||||||
|
header = f"[Confidence: {retrieved['confidence']} | Top score: {retrieved['top_score']:.3f}]\n\n"
|
||||||
|
|
||||||
|
passages = "\n\n---\n\n".join(
|
||||||
|
f"[SOURCE {i+1} | Ch{s['chapter_num']} › {s['section_title']} › p.{s['textbook_page']} | score: {s['score']:.3f}]\n{s['text']}"
|
||||||
|
for i, s in enumerate(retrieved["sources"])
|
||||||
|
)
|
||||||
|
|
||||||
|
footer = "\n\n===SOURCES===\n" + "\n".join(
|
||||||
|
f"[{i+1}] Ch{s['chapter_num']} › {s['section_title']} › p.{s['textbook_page']} (score: {s['score']:.3f})"
|
||||||
|
for i, s in enumerate(retrieved["sources"])
|
||||||
|
)
|
||||||
|
|
||||||
|
return header + passages + footer
|
||||||
|
|
||||||
|
|
||||||
|
# ── Run ────────────────────────────────────────────────────────────────────
|
||||||
|
if __name__ == "__main__":
|
||||||
|
app = mcp.streamable_http_app()
|
||||||
|
app.add_middleware(AllowAllOriginsMiddleware)
|
||||||
|
app.add_middleware(
|
||||||
|
CORSMiddleware,
|
||||||
|
allow_origins=["*"],
|
||||||
|
allow_methods=["*"],
|
||||||
|
allow_headers=["*"],
|
||||||
|
)
|
||||||
|
|
||||||
|
print("Starting APUSH MCP server on http://127.0.0.1:52437/mcp")
|
||||||
|
uvicorn.run(app, host="127.0.0.1", port=52437)
|
||||||
@@ -393,7 +393,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.14.3"
|
"version": "3.14.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
@@ -1511,7 +1511,7 @@
|
|||||||
"name": "python",
|
"name": "python",
|
||||||
"nbconvert_exporter": "python",
|
"nbconvert_exporter": "python",
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.14.3"
|
"version": "3.14.4"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"nbformat": 4,
|
"nbformat": 4,
|
||||||
|
|||||||
File diff suppressed because one or more lines are too long
Reference in New Issue
Block a user