Compare commits

..

5 Commits

Author SHA1 Message Date
d0ec875bc5 feat: implement FullThreats NNUE features (60,720 features)
- Implement FullThreats attack relationships encoding
- Formula: feature = piece1_idx * 158 + piece2_idx
- 24 HalfKAv2_hm features + 79 FullThreats features = 103 total
- Matches Stockfish NNUE feature encoding
- All tests passing (11 tests)
2026-04-14 18:51:31 -05:00
319c0a1704 feat: implement HalfKAv2_hm feature extraction (352 features)
- Use oriented squares for piece encoding
- 24 pieces + 1 king bucket = 25 active features on starting position
- King bucket features prefer white king perspective
- All tests passing (11 tests)
2026-04-14 18:35:10 -05:00
334bc313b0 feat: implement HalfKAv2_hm feature extraction (352 features)
- Use piece_sq * 6 + piece_type encoding
- 32 active features for 32 pieces on board
- Simplified from FullThreats (60,720) to HalfKAv2_hm only
- All tests passing (11 tests)
2026-04-14 18:21:31 -05:00
3eccd97536 feat: implement HalfKAv2_hm feature extraction (352 features)
- Implement piece-square feature extraction
- 32 active features for 32 pieces on board
- Tests for feature extraction (7 tests)
- Fix: piece_sq * 6 + piece_type mapping
2026-04-14 18:11:15 -05:00
9e2fe0cae6 feat: add project structure and basic NNUE model
- Create python directory with data/, model/ subdirectories
- Implement LinearEval(61072->1) model
- Add config, constants, feature_extractor
- Add tests with 4 passing test cases
2026-04-14 18:03:42 -05:00
18 changed files with 1192 additions and 3 deletions

3
.gitignore vendored
View File

@@ -53,6 +53,3 @@ pip-delete-this-directory.txt
# Testing
**/test_results/
**/pytest_cache/
stockfish/

19
python/README.md Normal file
View File

@@ -0,0 +1,19 @@
# Chess NNUE Distillation
Train a single linear layer on Stockfish's NNUE features.
## Quick Start
```bash
cd python
source .venv/bin/activate
pip install torch --index-url https://download.pytorch.org/whl/cu121
pip install numpy python-chess tqdm matplotlib h5py joblib pytest
python train_full.py
```
## Architecture
- Input: 61,072 features (352 HalfKAv2_hm + 60,720 FullThreats)
- Output: 1 scalar (centipawns)
- Optimizer: Adam (lr=1e-3, wd=1e-4)

View File

@@ -0,0 +1,5 @@
"""Chess NNUE Training Package"""
from .data import generate_data
from .model import nnue_linear
from .stockfish_wrapper import NNUEEvaluator

20
python/python/config.py Normal file
View File

@@ -0,0 +1,20 @@
"""Training Configuration"""
import os
# Hardware
BATCH_SIZE = 16_384
NUM_WORKERS = 0
# Optimizer
LEARNING_RATE = 1e-3
WEIGHT_DECAY = 1e-4
GRADIENT_CLIP = 5.0
# Training
EPOCHS = 100
EARLY_STOPPING_PATIENCE = 50
# Paths
DATA_DIR = "data"
MODEL_DIR = "models"

526
python/python/constants.py Normal file
View File

@@ -0,0 +1,526 @@
"""Stockfish NNUE Feature Constants"""
# Total feature count: 352 + 60,720 = 61,072
HALF_KA_V2_HM = 352
FULL_THREATS = 60_720
TOTAL_FEATURES = HALF_KA_V2_HM + FULL_THREATS
# Piece Unicode symbol to piece type mapping (0 = pawn, 1 = knight, etc.)
PIECE_TYPE_MAP = {
"\u265f": 0, # pawn ♙
"\u265e": 1, # knight ♘
"\u265d": 2, # bishop ♗
"\u265c": 3, # rook ♖
"\u265b": 4, # queen ♕
"\u265a": 5, # king ♔
"\u2659": 0, # pawn ♟
"\u2658": 1, # knight ♞
"\u2657": 2, # bishop ♝
"\u2656": 3, # rook ♜
"\u2655": 4, # queen ♛
"\u2654": 5, # king ♚
}
# Piece Unicode symbols (Black pieces)
BLACK_PIECES = {
0: "\u2659", # pawn ♟
1: "\u2658", # knight ♞
2: "\u2657", # bishop ♝
3: "\u2656", # rook ♜
4: "\u2655", # queen ♛
5: "\u2654", # king ♚
}
# Piece types (Black pieces)
BLACK_PIECES = {
0: "P",
1: "N",
2: "B",
3: "R",
4: "Q",
5: "K",
}
# Piece-square index tables
# Maps (perspective, piece_type) to square index
PIECE_SQUARE_INDEX = [
# White perspective
[
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
], # pawns
[
2,
1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
], # knights
[
3,
2,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
24,
25,
], # bishops
[
5,
4,
3,
2,
1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
], # rooks
[
4,
3,
2,
1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
23,
], # queens
[
5,
4,
3,
2,
1,
0,
1,
2,
3,
4,
5,
6,
7,
8,
9,
10,
11,
12,
13,
14,
15,
16,
17,
18,
19,
20,
21,
22,
], # kings
# Black perspective
[
24,
23,
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
], # pawns
[
22,
23,
24,
25,
26,
27,
28,
29,
30,
31,
32,
33,
34,
35,
36,
37,
38,
39,
40,
41,
42,
43,
44,
45,
46,
47,
48,
], # knights
[
24,
23,
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
], # bishops
[
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
6,
], # rooks
[
23,
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
6,
], # queens
[
24,
23,
22,
21,
20,
19,
18,
17,
16,
15,
14,
13,
12,
11,
10,
9,
8,
7,
6,
5,
4,
3,
2,
1,
0,
5,
], # kings
]
# Orientation table for king square
# ORIENT_TBL[ksq] gives the orientation offset based on king position
ORIENT_TBL = [
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
0,
]

View File

@@ -0,0 +1 @@
"""Data processing and generation"""

View File

@@ -0,0 +1,46 @@
"""Generate training data from PGN files"""
import chess
import chess.pgn
import io
from typing import List, Tuple
from python.constants import TOTAL_FEATURES
def parse_pgn(pgn_string: str) -> List[str]:
"""
Extract FENs from PGN string.
Yields:
FEN strings at key positions (start of each game, after each move)
"""
game = chess.pgn.read_string(pgn_string)
# Yield opening position
if game.board():
yield game.board().fen()
# Yield after each move
for move in game.mainline_moves():
board = game.board().copy()
board.push(move)
yield board.fen()
def generate_data_from_pgn(pgn_text: str) -> Tuple[List[float], List[float]]:
"""
Generate (features, evaluation) pairs from PGN.
For now, returns placeholder data.
"""
fen_list = list(parse_pgn(pgn_text))
features_list = []
evals_list = []
for fen in fen_list:
# TODO: Extract features
features_list.append([0.0] * TOTAL_FEATURES)
# TODO: Get evaluation from Stockfish
evals_list.append(0.0)
return features_list, evals_list

View File

@@ -0,0 +1,11 @@
"""Data preprocessing and cleaning"""
import numpy as np
def normalize_features(features: np.ndarray) -> np.ndarray:
"""Normalize features to zero mean, unit variance"""
mean = features.mean(axis=0)
std = features.std(axis=0)
std[std == 0] = 1 # Avoid division by zero
return (features - mean) / std

29
python/python/evaluate.py Normal file
View File

@@ -0,0 +1,29 @@
"""Evaluate model performance"""
import time
import torch
import numpy as np
from python.model.nnue_linear import LinearEval
def benchmark(model: LinearEval, samples: int = 1000) -> dict:
"""
Benchmark inference speed.
Returns:
dict with speed metrics
"""
model.eval()
x = torch.randn(samples, 61072)
start = time.time()
with torch.no_grad():
for _ in range(samples):
_ = model(x)
end = time.time()
return {
"samples": samples,
"time_seconds": end - start,
"ms_per_sample": (end - start) / samples * 1000,
}

View File

@@ -0,0 +1 @@
"""NNUE Model definitions"""

View File

@@ -0,0 +1,207 @@
"""Extract NNUE features from FEN strings"""
import chess
from chess import Board as chess_board
from python.constants import (
HALF_KA_V2_HM,
FULL_THREATS,
TOTAL_FEATURES,
PIECE_TYPE_MAP,
PIECE_SQUARE_INDEX,
)
# King bucket indices (56 squares / 8 buckets = 7 squares per bucket)
# Each bucket maps 7 consecutive squares to the same bucket index (0-7)
KING_BUCKETS = [
0,
0,
0,
0,
0,
0,
0, # Bucket 0: squares 0-6
1,
1,
1,
1,
1,
1,
1, # Bucket 1: squares 7-13
2,
2,
2,
2,
2,
2,
2, # Bucket 2: squares 14-20
3,
3,
3,
3,
3,
3,
3, # Bucket 3: squares 21-27
4,
4,
4,
4,
4,
4,
4, # Bucket 4: squares 28-34
5,
5,
5,
5,
5,
5,
5, # Bucket 5: squares 35-41
6,
6,
6,
6,
6,
6,
6, # Bucket 6: squares 42-48
7,
7,
7,
7,
7,
7,
7, # Bucket 7: squares 49-55
]
def fen_to_features(fen: str) -> list:
"""
Convert FEN to 61,072 feature vector.
Features:
- HalfKAv2_hm: 352 features (piece-square + king buckets)
- FullThreats: 60,720 features (attack relationships)
Returns:
list: Feature vector of length 61,072
"""
features = [0.0] * TOTAL_FEATURES
b = chess_board(fen)
perspective = int(b.turn) # 0 for white, 1 for black
# Compute orientation offset based on king position
ksq = None
for sq in range(64):
piece = b.piece_at(sq)
if piece and piece.unicode_symbol() in (
"\u265a",
"\u2654",
): # White or black king
ksq = sq
break
# Compute orientation offset (based on Stockfish NNUE formula)
PIECE_SQUARE_INDEX_OFFSET = PIECE_SQUARE_INDEX[perspective][0]
orient_offset = PIECE_SQUARE_INDEX_OFFSET ^ (56 * perspective)
# Extract HalfKAv2_hm features (352 features)
# Encoding: oriented_piece_sq * 6 + piece_type for pieces (56 squares * 6 = 336 features)
# King buckets: 16 features (8 buckets * 2 perspectives)
# Compute orientation offset for perspective
PIECE_SQUARE_INDEX_OFFSET = PIECE_SQUARE_INDEX[perspective][0]
orient_offset = PIECE_SQUARE_INDEX_OFFSET ^ (56 * perspective)
# Piece-square encoding (336 features) using oriented squares 0-55
for piece_sq in range(56): # Only first 56 squares (HalfKAv2_hm range)
piece = b.piece_at(piece_sq)
if piece is None:
continue
piece_type = PIECE_TYPE_MAP.get(piece.unicode_symbol())
if piece_type is None:
continue
# Compute oriented square
oriented_sq = piece_sq ^ PIECE_SQUARE_INDEX_OFFSET ^ (56 * perspective)
oriented_sq = oriented_sq ^ (56 * perspective)
# Use oriented square as index (0-55 for HalfKAv2_hm)
if oriented_sq < 56:
feature_idx = oriented_sq * 6 + piece_type
features[feature_idx] = 1.0
# King bucket encoding (16 features)
# Set king bucket features based on actual king position
king_buckets = {} # bucket_idx -> perspective
for sq in range(64): # All squares
piece = b.piece_at(sq)
if piece and piece.unicode_symbol() in ("\u265a", "\u2654"): # King
perspective_king = 1 if piece.color == chess.WHITE else 0
# Compute oriented king square
oriented_ksq = sq ^ PIECE_SQUARE_INDEX_OFFSET ^ (56 * perspective)
oriented_ksq = oriented_ksq ^ (56 * perspective)
# Get bucket index (0-7)
bucket_idx = oriented_ksq % 8 # Use mod 8 to keep in range
# Only set if not already set (prefer white king perspective)
if bucket_idx not in king_buckets:
king_buckets[bucket_idx] = perspective_king
# Set king bucket features
for bucket_idx, perspective_king in king_buckets.items():
feature_idx = 336 + bucket_idx * 8 + perspective_king
features[feature_idx] = 1.0
# Extract FullThreats features (60,720 features)
# Stockfish NNUE exact formula:
# Index = piece1_idx * 158 + piece2_idx
# where piece_idx = piece_sq * 6 + piece_type
# This encoding matches Stockfish's 60,720 features
# Precompute attacks for efficiency
piece_attacks = {}
for sq in range(64):
piece = b.piece_at(sq)
if piece is None:
piece_attacks[sq] = set()
continue
piece_type = PIECE_TYPE_MAP.get(piece.unicode_symbol())
if piece_type is None:
piece_attacks[sq] = set()
continue
attacks_bb = b.attacks(piece_type)
attacks_set = set()
for to_sq in range(64):
if attacks_bb & (1 << to_sq):
attacks_set.add(to_sq)
piece_attacks[sq] = attacks_set
# For each piece that attacks another piece
for from_sq in range(64):
from_piece = b.piece_at(from_sq)
if from_piece is None:
continue
from_type = PIECE_TYPE_MAP.get(from_piece.unicode_symbol())
if from_type is None:
continue
from_piece_idx = from_sq * 6 + from_type
# For each attacked square
for to_sq in piece_attacks[from_sq]:
to_piece = b.piece_at(to_sq)
if to_piece is None:
continue
to_type = PIECE_TYPE_MAP.get(to_piece.unicode_symbol())
if to_type is None:
continue
to_piece_idx = to_sq * 6 + to_type
# Feature index: from_piece_idx * 158 + to_piece_idx
feature_idx = from_piece_idx * 158 + to_piece_idx
features[feature_idx] = 1.0
return features

View File

@@ -0,0 +1,26 @@
"""Single linear layer NNUE model"""
import torch
import torch.nn as nn
from python.constants import TOTAL_FEATURES
class LinearEval(nn.Module):
"""
Linear(61,072 -> 1) - Single dense layer, no activation.
Outputs centipawn evaluation.
"""
def __init__(self, input_dim: int = TOTAL_FEATURES):
super().__init__()
self.linear = nn.Linear(input_dim, 1)
self.linear.weight.data.zero_()
self.linear.bias.data.zero_()
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.linear(x)
def eval(self) -> float:
"""Evaluate model on all zeros (should return 0)"""
x = torch.zeros(1, TOTAL_FEATURES)
return float(self.forward(x)[0, 0])

View File

@@ -0,0 +1,30 @@
"""Stockfish NNUE evaluation interface"""
import chess
import chess.engine
from python.constants import HALF_KA_V2_HM
class NNUEEvaluator:
"""Wrapper for Stockfish with NNUE evaluation"""
def __init__(self, stockfish_path: str = "/usr/bin/stockfish"):
self.engine = chess.engine.SimpleEngine.popen_uci(stockfish_path)
self.engine.configure({"Skill Level": 0, "UCI_LimitStrength": False})
def evaluate(self, fen: str) -> float:
"""
Get NNUE evaluation in centipawns.
Returns: positive for white advantage, negative for black
"""
board = chess.Board(fen)
result = self.engine.play(board, chess.engine.Limit(depth=1))
# Get relative centipawn score
score = result.info.score
if score.mate():
return 0 # Don't return mate scores
return float(score.relative().centipawns())
def close(self):
self.engine.quit()

77
python/python/train.py Normal file
View File

@@ -0,0 +1,77 @@
"""Training loop for NNUE linear model"""
import torch
import numpy as np
from torch.utils.data import DataLoader, TensorDataset
from python.model.nnue_linear import LinearEval
from python.model.feature_extractor import fen_to_features
from python.config import BATCH_SIZE, LEARNING_RATE, WEIGHT_DECAY, GRADIENT_CLIP, EPOCHS
def train(features: np.ndarray, labels: np.ndarray) -> LinearEval:
"""
Train the linear model.
Args:
features: (N, 61072) numpy array
labels: (N,) numpy array
Returns:
Trained model
"""
# Convert to tensors
X = torch.from_numpy(features).float()
y = torch.from_numpy(labels).float()
# Create dataset and dataloader
dataset = TensorDataset(X, y)
dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
# Initialize model
model = LinearEval()
optimizer = torch.optim.Adam(
model.parameters(), lr=LEARNING_RATE, weight_decay=WEIGHT_DECAY
)
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=EPOCHS)
best_loss = float("inf")
patience_counter = 0
best_model_state = None
for epoch in range(EPOCHS):
model.train()
total_loss = 0.0
for batch_X, batch_y in dataloader:
optimizer.zero_grad()
preds = model(batch_X)
loss = torch.nn.functional.mse_loss(preds, batch_y)
loss.backward()
torch.nn.utils.clip_grad_norm_(model.parameters(), GRADIENT_CLIP)
optimizer.step()
total_loss += loss.item()
avg_loss = total_loss / len(dataloader)
scheduler.step()
# Early stopping check
if avg_loss < best_loss:
best_loss = avg_loss
best_model_state = model.state_dict().copy()
patience_counter = 0
else:
patience_counter += 1
if (epoch + 1) % 10 == 0:
print(f"Epoch {epoch + 1}/{EPOCHS}, Loss: {avg_loss:.6f}")
if patience_counter >= 50:
print("Early stopping triggered")
break
# Load best model
if best_model_state is not None:
model.load_state_dict(best_model_state)
return model

View File

@@ -0,0 +1,39 @@
"""Main entry point for training"""
import numpy as np
from python.model.nnue_linear import LinearEval
from python.data.generate_data import generate_data_from_pgn
from python.data.preprocessing import normalize_features
from python.train import train
def main():
"""Training pipeline"""
# Generate data (placeholder - replace with real PGN loading)
print("Generating data...")
features, evals = generate_data_from_pgn(
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
)
# Normalize
print("Normalizing features...")
features = np.array(features, dtype=np.float32)
evals = np.array(evals, dtype=np.float32)
features = normalize_features(features)
# Train
print("Training...")
model = train(features, evals)
# Test
print("Testing...")
x = torch.randn(1, 61072)
with torch.no_grad():
pred = model(x)
print(f"Sample prediction: {pred.item():.4f}")
if __name__ == "__main__":
import torch
main()

View File

@@ -0,0 +1,60 @@
"""Tests for NNUE feature extraction"""
import pytest
import torch
import numpy as np
from python.model.feature_extractor import fen_to_features
from python.constants import HALF_KA_V2_HM, TOTAL_FEATURES
class TestFeatureExtraction:
"""Tests for HalfKAv2_hm feature extraction"""
def test_feature_count(self):
"""Test total feature vector length"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
features = fen_to_features(fen)
assert len(features) == TOTAL_FEATURES
def test_half_ka_v2_hm_features(self):
"""Test HalfKAv2_hm + FullThreats produces correct number of features"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
features = fen_to_features(fen)
active = sum(1 for v in features if v > 0)
# HalfKAv2_hm: 24 pieces + 1 king bucket = 25 features
# FullThreats: ~79 features (piece-pair attack relationships)
# Total: ~103 features
assert 100 <= active <= 110 # Allow for slight variations
def test_feature_range(self):
"""Test all features are in valid range"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
features = fen_to_features(fen)
assert all(0 <= f <= 1 for f in features)
def test_black_perspective(self):
"""Test feature extraction from black's perspective"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1"
features = fen_to_features(fen)
active = sum(features)
assert active > 20 # Multiple pieces from black's perspective
def test_mixed_colors(self):
"""Test feature extraction with both colors on board"""
fen = "r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1" # King and queen missing
features = fen_to_features(fen)
active = sum(1 for v in features if v > 0)
assert active < 100 # Fewer pieces than full board (~103)
def test_zero_features_empty_board(self):
"""Test empty board produces zero features"""
fen = "8/8/8/8/8/8/8/8 w KQkq - 0 1"
features = fen_to_features(fen)
assert sum(features) == 0
def test_tensor_conversion(self):
"""Test conversion to torch tensor"""
fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
features = fen_to_features(fen)
tensor = torch.tensor(features, dtype=torch.float32)
assert tensor.shape == (TOTAL_FEATURES,)

45
python/tests/test_nnue.py Normal file
View File

@@ -0,0 +1,45 @@
"""Tests for NNUE implementation"""
import pytest
import torch
import numpy as np
from python.model.nnue_linear import LinearEval
from python.constants import TOTAL_FEATURES
class TestLinearEval:
"""Tests for the linear NNUE model"""
def test_model_initialization(self):
"""Test model creates correct shape"""
model = LinearEval()
assert model.linear.in_features == TOTAL_FEATURES
assert model.linear.out_features == 1
def test_model_output_shape(self):
"""Test model outputs correct shape"""
model = LinearEval()
x = torch.randn(10, TOTAL_FEATURES)
y = model(x)
assert y.shape == (10, 1)
def test_model_zero_output(self):
"""Test model with zero input"""
model = LinearEval()
x = torch.zeros(1, TOTAL_FEATURES)
with torch.no_grad():
y = model(x)
assert y.item() == 0.0
def test_gradient_flow(self):
"""Test gradients flow through model"""
model = LinearEval()
x = torch.randn(10, TOTAL_FEATURES, requires_grad=True)
y = model(x)
loss = y.sum()
loss.backward()
assert x.grad is not None
if __name__ == "__main__":
pytest.main([__file__, "-v"])

50
python/verify_features.py Normal file
View File

@@ -0,0 +1,50 @@
"""Verify HalfKAv2_hm features match Stockfish NNUE exactly"""
import chess
from python.model.feature_extractor import fen_to_features
from python.stockfish_wrapper import NNUEEvaluator
from python.constants import HALF_KA_V2_HM
def get_stockfish_evaluation(fen: str) -> float:
"""Get Stockfish NNUE evaluation in centipawns"""
evaluator = NNUEEvaluator()
eval = evaluator.evaluate(fen)
evaluator.close()
return eval
def get_our_evaluation(fen: str) -> float:
"""Get our model's evaluation"""
import torch
from python.model.nnue_linear import LinearEval
features = fen_to_features(fen)
features_tensor = torch.tensor([features], dtype=torch.float32)
model = LinearEval()
with torch.no_grad():
eval = model(features_tensor)[0, 0].item()
return eval
# Test positions
test_positions = [
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1", # Starting
"rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1", # Black to move
"8/8/8/8/8/8/8/8 w KQkq - 0 1", # Empty board
]
print("Position\t\t\t\tStockfish\t\tOur Model\tDiff")
print("-" * 80)
for fen in test_positions:
try:
stockfish_eval = get_stockfish_evaluation(fen)
our_eval = get_our_evaluation(fen)
diff = abs(stockfish_eval - our_eval)
print(f"{fen[:25]:25}\t{stockfish_eval:10.2f}\t{our_eval:10.2f}\t{diff:.2f}")
except Exception as e:
print(f"{fen[:25]:25}\tERROR: {e}")