Files
chess-engine/python/tests/test_features.py
KeshavAnandCode 334bc313b0 feat: implement HalfKAv2_hm feature extraction (352 features)
- Use piece_sq * 6 + piece_type encoding
- 32 active features for 32 pieces on board
- Simplified from FullThreats (60,720) to HalfKAv2_hm only
- All tests passing (11 tests)
2026-04-14 18:21:31 -05:00

60 lines
2.3 KiB
Python

"""Tests for NNUE feature extraction"""
import pytest
import torch
import numpy as np
from python.model.feature_extractor import fen_to_features
from python.constants import HALF_KA_V2_HM, TOTAL_FEATURES
def _count_pieces(fen: str) -> int:
    """Return the number of pieces in the piece-placement field of *fen*.

    In FEN, every letter in the first (placement) field is one piece;
    digits encode runs of empty squares and ``/`` separates ranks.
    """
    placement = fen.split(" ", 1)[0]
    return sum(ch.isalpha() for ch in placement)


class TestFeatureExtraction:
    """Tests for HalfKAv2_hm feature extraction.

    The HalfKAv2_hm encoding activates one feature per piece on the board
    (32 for the starting position), so active-feature counts are checked
    against the number of pieces in each FEN's placement field.
    """

    # Standard starting position, shared by most tests below.
    START_FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"

    def test_feature_count(self):
        """The feature vector has exactly TOTAL_FEATURES entries."""
        features = fen_to_features(self.START_FEN)
        assert len(features) == TOTAL_FEATURES

    def test_full_threats_features(self):
        """The start position activates exactly one feature per piece.

        The method name predates the switch from FullThreats to HalfKAv2_hm
        and is kept so existing pytest selections/node IDs stay valid.
        """
        features = fen_to_features(self.START_FEN)
        # 32 pieces -> 32 active features. An exact count (rather than a
        # lower bound) catches a regression to per-attack features.
        assert sum(features) == _count_pieces(self.START_FEN) == 32

    def test_feature_range(self):
        """Every feature value lies in [0, 1]."""
        features = fen_to_features(self.START_FEN)
        assert all(0 <= f <= 1 for f in features)

    def test_black_perspective(self):
        """Black to move still activates one feature per piece."""
        fen = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1"
        features = fen_to_features(fen)
        assert sum(features) == _count_pieces(fen) == 32

    def test_mixed_colors(self):
        """Both colors present, with knights, bishops and queens removed."""
        # Each side keeps its king, both rooks and eight pawns: 22 pieces.
        fen = "r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1"
        features = fen_to_features(fen)
        assert sum(features) == _count_pieces(fen) == 22

    def test_zero_features_empty_board(self):
        """An empty board produces no active features."""
        # No kings or rooks on the board, so no castling rights ("-").
        fen = "8/8/8/8/8/8/8/8 w - - 0 1"
        features = fen_to_features(fen)
        assert sum(features) == 0

    def test_tensor_conversion(self):
        """The feature list converts to a 1-D float32 tensor of full length."""
        features = fen_to_features(self.START_FEN)
        tensor = torch.tensor(features, dtype=torch.float32)
        assert tensor.shape == (TOTAL_FEATURES,)