From bde3b9726660d81740824ef1e0f691f8092fd15d Mon Sep 17 00:00:00 2001 From: KeshavAnandCode Date: Wed, 18 Mar 2026 18:10:08 -0500 Subject: [PATCH] finished CNN, good accruacy overall but not good for minority and harder classes --- notebooks/tutorial-cnn.ipynb | 79 ++++++++++++++++++++++++++++++++++++ 1 file changed, 79 insertions(+) diff --git a/notebooks/tutorial-cnn.ipynb b/notebooks/tutorial-cnn.ipynb index 945c9f9..2388ce0 100644 --- a/notebooks/tutorial-cnn.ipynb +++ b/notebooks/tutorial-cnn.ipynb @@ -433,6 +433,85 @@ "PATH = '../models/tutorial-cnn.pth'\n", "torch.save(model.state_dict(), PATH)" ] + }, + { + "cell_type": "markdown", + "id": "a24dd4f0", + "metadata": {}, + "source": [ + "Test on test set using Pytorch tutorial method for both total and per class (maybe spot overfitting as well)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "bc158602", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Test Accuracy: 74.53%\n" + ] + } + ], + "source": [ + "correct = 0\n", + "total = 0\n", + "\n", + "with torch.no_grad():\n", + " for data in test_loader:\n", + " images, labels = data\n", + " images, labels = images.to(device), labels.to(device)\n", + " outputs = model(images)\n", + " _, predicted = torch.max(outputs, 1)\n", + " total += labels.size(0)\n", + " correct += (predicted == labels).sum().item()\n", + "\n", + "print(f'Test Accuracy: {100 * correct / total:.2f}%')" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "8cc7ed40", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Accuracy for class: Bicycle is 43.5%\n", + "Accuracy for class: Bus is 53.8%\n", + "Accuracy for class: Car is 87.6%\n", + "Accuracy for class: Motorcycle is 75.2%\n", + "Accuracy for class: NonVehicles is 99.0%\n", + "Accuracy for class: Taxi is 20.0%\n", + "Accuracy for class: Truck is 13.9%\n", + "Accuracy for class: Van is 24.7%\n" + ] + } + ], + "source": [ + "correct_pred = {classname: 0 for classname in dataset.classes}\n", + "total_pred = {classname: 0 for classname in dataset.classes}\n", + "\n", + "with torch.no_grad():\n", + " for data in test_loader:\n", + " images, labels = data\n", + " images, labels = images.to(device), labels.to(device)\n", + " outputs = model(images)\n", + " _, predictions = torch.max(outputs, 1)\n", + " for label, prediction in zip(labels, predictions):\n", + " if label == prediction:\n", + " correct_pred[dataset.classes[label]] += 1\n", + " total_pred[dataset.classes[label]] += 1\n", + "\n", + "for classname, correct_count in correct_pred.items():\n", + " accuracy = 100 * float(correct_count) / total_pred[classname]\n", + " print(f'Accuracy for class: {classname:10s} is {accuracy:.1f}%')" + ] } ], "metadata": {