finished CNN, good accruacy overall but not good for minority and harder classes

This commit is contained in:
KeshavAnandCode
2026-03-18 18:10:08 -05:00
parent 4556d8c780
commit bde3b97266

View File

@@ -433,6 +433,85 @@
"PATH = '../models/tutorial-cnn.pth'\n",
"torch.save(model.state_dict(), PATH)"
]
},
{
"cell_type": "markdown",
"id": "a24dd4f0",
"metadata": {},
"source": [
"Test on test set using Pytorch tutorial method for both total and per class (maybe spot overfitting as well)"
]
},
{
"cell_type": "code",
"execution_count": 15,
"id": "bc158602",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Test Accuracy: 74.53%\n"
]
}
],
"source": [
"correct = 0\n",
"total = 0\n",
"\n",
"with torch.no_grad():\n",
" for data in test_loader:\n",
" images, labels = data\n",
" images, labels = images.to(device), labels.to(device)\n",
" outputs = model(images)\n",
" _, predicted = torch.max(outputs, 1)\n",
" total += labels.size(0)\n",
" correct += (predicted == labels).sum().item()\n",
"\n",
"print(f'Test Accuracy: {100 * correct / total:.2f}%')"
]
},
{
"cell_type": "code",
"execution_count": 16,
"id": "8cc7ed40",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Accuracy for class: Bicycle is 43.5%\n",
"Accuracy for class: Bus is 53.8%\n",
"Accuracy for class: Car is 87.6%\n",
"Accuracy for class: Motorcycle is 75.2%\n",
"Accuracy for class: NonVehicles is 99.0%\n",
"Accuracy for class: Taxi is 20.0%\n",
"Accuracy for class: Truck is 13.9%\n",
"Accuracy for class: Van is 24.7%\n"
]
}
],
"source": [
"correct_pred = {classname: 0 for classname in dataset.classes}\n",
"total_pred = {classname: 0 for classname in dataset.classes}\n",
"\n",
"with torch.no_grad():\n",
" for data in test_loader:\n",
" images, labels = data\n",
" images, labels = images.to(device), labels.to(device)\n",
" outputs = model(images)\n",
" _, predictions = torch.max(outputs, 1)\n",
" for label, prediction in zip(labels, predictions):\n",
" if label == prediction:\n",
" correct_pred[dataset.classes[label]] += 1\n",
" total_pred[dataset.classes[label]] += 1\n",
"\n",
"for classname, correct_count in correct_pred.items():\n",
" accuracy = 100 * float(correct_count) / total_pred[classname]\n",
" print(f'Accuracy for class: {classname:10s} is {accuracy:.1f}%')"
]
}
],
"metadata": {