Files
chess-engine/python/tests/test_features.py
KeshavAnandCode 3eccd97536 feat: implement HalfKAv2_hm feature extraction (352 features)
- Implement piece-square feature extraction
- 32 active features for 32 pieces on board
- Tests for feature extraction (7 tests)
- Fix: piece_sq * 6 + piece_type mapping
2026-04-14 18:11:15 -05:00

58 lines
2.2 KiB
Python

"""Tests for NNUE feature extraction"""
import pytest
import torch
import numpy as np
from python.model.feature_extractor import fen_to_features
from python.constants import HALF_KA_V2_HM, TOTAL_FEATURES
class TestFeatureExtraction:
    """Tests for HalfKAv2_hm feature extraction.

    Every test passes a FEN string to ``fen_to_features`` and checks the
    length, the value range, or the number of active entries of the returned
    feature vector. The suite expects one active feature per piece on the
    board (32 for the starting position, 0 for an empty board).
    """

    # Standard starting position (32 pieces) with white / black to move.
    START_FEN_WHITE = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    START_FEN_BLACK = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR b KQkq - 0 1"

    # Starting position with knights, bishops and queens removed. Both kings,
    # all four rooks and all sixteen pawns remain: 2 + 4 + 16 = 22 pieces.
    REDUCED_FEN = "r3k2r/pppppppp/8/8/8/8/PPPPPPPP/R3K2R w KQkq - 0 1"
    REDUCED_PIECE_COUNT = 22

    # Empty board. The castling field is "-" because castling rights cannot
    # exist without kings and rooks.
    EMPTY_FEN = "8/8/8/8/8/8/8/8 w - - 0 1"

    def test_feature_count(self):
        """Feature vector length equals TOTAL_FEATURES."""
        features = fen_to_features(self.START_FEN_WHITE)
        assert len(features) == TOTAL_FEATURES

    def test_half_ka_hm_features(self):
        """Starting position activates exactly 32 features (one per piece)."""
        features = fen_to_features(self.START_FEN_WHITE)
        active = sum(features)
        assert active == 32  # 32 pieces on full board

    def test_feature_range(self):
        """Every entry of the feature vector lies in [0, 1]."""
        features = fen_to_features(self.START_FEN_WHITE)
        assert all(0 <= f <= 1 for f in features)

    def test_black_perspective(self):
        """Starting position with black to move still activates 32 features."""
        features = fen_to_features(self.START_FEN_BLACK)
        active = sum(features)
        assert active == 32  # 32 pieces

    def test_mixed_colors(self):
        """A reduced position activates exactly one feature per remaining piece."""
        features = fen_to_features(self.REDUCED_FEN)
        active = sum(features)
        # The position is a strict subset of the starting position (same
        # squares, same king squares), so an exact count is expected. A loose
        # upper bound would also pass for an extractor that returns all zeros.
        assert active == self.REDUCED_PIECE_COUNT

    def test_zero_features_empty_board(self):
        """Empty board produces no active features."""
        features = fen_to_features(self.EMPTY_FEN)
        assert sum(features) == 0

    def test_tensor_conversion(self):
        """Feature vector converts to a 1-D float32 tensor of TOTAL_FEATURES."""
        features = fen_to_features(self.START_FEN_WHITE)
        tensor = torch.tensor(features, dtype=torch.float32)
        assert tensor.shape == (TOTAL_FEATURES,)