finished CNN, good accruacy overall but not good for minority and harder classes
This commit is contained in:
@@ -433,6 +433,85 @@
|
||||
"PATH = '../models/tutorial-cnn.pth'\n",
|
||||
"torch.save(model.state_dict(), PATH)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "a24dd4f0",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Test on test set using Pytorch tutorial method for both total and per class (maybe spot overfitting as well)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 15,
|
||||
"id": "bc158602",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Test Accuracy: 74.53%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"correct = 0\n",
|
||||
"total = 0\n",
|
||||
"\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for data in test_loader:\n",
|
||||
" images, labels = data\n",
|
||||
" images, labels = images.to(device), labels.to(device)\n",
|
||||
" outputs = model(images)\n",
|
||||
" _, predicted = torch.max(outputs, 1)\n",
|
||||
" total += labels.size(0)\n",
|
||||
" correct += (predicted == labels).sum().item()\n",
|
||||
"\n",
|
||||
"print(f'Test Accuracy: {100 * correct / total:.2f}%')"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 16,
|
||||
"id": "8cc7ed40",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Accuracy for class: Bicycle is 43.5%\n",
|
||||
"Accuracy for class: Bus is 53.8%\n",
|
||||
"Accuracy for class: Car is 87.6%\n",
|
||||
"Accuracy for class: Motorcycle is 75.2%\n",
|
||||
"Accuracy for class: NonVehicles is 99.0%\n",
|
||||
"Accuracy for class: Taxi is 20.0%\n",
|
||||
"Accuracy for class: Truck is 13.9%\n",
|
||||
"Accuracy for class: Van is 24.7%\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"correct_pred = {classname: 0 for classname in dataset.classes}\n",
|
||||
"total_pred = {classname: 0 for classname in dataset.classes}\n",
|
||||
"\n",
|
||||
"with torch.no_grad():\n",
|
||||
" for data in test_loader:\n",
|
||||
" images, labels = data\n",
|
||||
" images, labels = images.to(device), labels.to(device)\n",
|
||||
" outputs = model(images)\n",
|
||||
" _, predictions = torch.max(outputs, 1)\n",
|
||||
" for label, prediction in zip(labels, predictions):\n",
|
||||
" if label == prediction:\n",
|
||||
" correct_pred[dataset.classes[label]] += 1\n",
|
||||
" total_pred[dataset.classes[label]] += 1\n",
|
||||
"\n",
|
||||
"for classname, correct_count in correct_pred.items():\n",
|
||||
" accuracy = 100 * float(correct_count) / total_pred[classname]\n",
|
||||
" print(f'Accuracy for class: {classname:10s} is {accuracy:.1f}%')"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user