finished CNN, good accruacy overall but not good for minority and harder classes
This commit is contained in:
@@ -433,6 +433,85 @@
|
|||||||
"PATH = '../models/tutorial-cnn.pth'\n",
|
"PATH = '../models/tutorial-cnn.pth'\n",
|
||||||
"torch.save(model.state_dict(), PATH)"
|
"torch.save(model.state_dict(), PATH)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "a24dd4f0",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Test on test set using Pytorch tutorial method for both total and per class (maybe spot overfitting as well)"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 15,
|
||||||
|
"id": "bc158602",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Test Accuracy: 74.53%\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"correct = 0\n",
|
||||||
|
"total = 0\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" for data in test_loader:\n",
|
||||||
|
" images, labels = data\n",
|
||||||
|
" images, labels = images.to(device), labels.to(device)\n",
|
||||||
|
" outputs = model(images)\n",
|
||||||
|
" _, predicted = torch.max(outputs, 1)\n",
|
||||||
|
" total += labels.size(0)\n",
|
||||||
|
" correct += (predicted == labels).sum().item()\n",
|
||||||
|
"\n",
|
||||||
|
"print(f'Test Accuracy: {100 * correct / total:.2f}%')"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 16,
|
||||||
|
"id": "8cc7ed40",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Accuracy for class: Bicycle is 43.5%\n",
|
||||||
|
"Accuracy for class: Bus is 53.8%\n",
|
||||||
|
"Accuracy for class: Car is 87.6%\n",
|
||||||
|
"Accuracy for class: Motorcycle is 75.2%\n",
|
||||||
|
"Accuracy for class: NonVehicles is 99.0%\n",
|
||||||
|
"Accuracy for class: Taxi is 20.0%\n",
|
||||||
|
"Accuracy for class: Truck is 13.9%\n",
|
||||||
|
"Accuracy for class: Van is 24.7%\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"correct_pred = {classname: 0 for classname in dataset.classes}\n",
|
||||||
|
"total_pred = {classname: 0 for classname in dataset.classes}\n",
|
||||||
|
"\n",
|
||||||
|
"with torch.no_grad():\n",
|
||||||
|
" for data in test_loader:\n",
|
||||||
|
" images, labels = data\n",
|
||||||
|
" images, labels = images.to(device), labels.to(device)\n",
|
||||||
|
" outputs = model(images)\n",
|
||||||
|
" _, predictions = torch.max(outputs, 1)\n",
|
||||||
|
" for label, prediction in zip(labels, predictions):\n",
|
||||||
|
" if label == prediction:\n",
|
||||||
|
" correct_pred[dataset.classes[label]] += 1\n",
|
||||||
|
" total_pred[dataset.classes[label]] += 1\n",
|
||||||
|
"\n",
|
||||||
|
"for classname, correct_count in correct_pred.items():\n",
|
||||||
|
" accuracy = 100 * float(correct_count) / total_pred[classname]\n",
|
||||||
|
" print(f'Accuracy for class: {classname:10s} is {accuracy:.1f}%')"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
Reference in New Issue
Block a user