split
This commit is contained in:
@@ -173,6 +173,77 @@
|
||||
"y = x @ x\n",
|
||||
"print(\"GPU works!\", y.shape)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "870fadbe",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Transform image dataset into tensors. Normalized with 0.5 mean and 0.5 STD (can be differente to use pretrained weights)."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "37793c77",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"['Bicycle', 'Bus', 'Car', 'Motorcycle', 'NonVehicles', 'Taxi', 'Truck', 'Van']\n",
|
||||
"26378\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"from torchvision import datasets, transforms\n",
|
||||
"from torch.utils.data import random_split, DataLoader\n",
|
||||
"\n",
|
||||
"transform = transforms.Compose([\n",
|
||||
" transforms.ToTensor(),\n",
|
||||
" transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
||||
"])\n",
|
||||
"\n",
|
||||
"dataset = datasets.ImageFolder(root=data_dir, transform=transform)\n",
|
||||
"\n",
|
||||
"print(dataset.classes) # should print 8 classes\n",
|
||||
"print(len(dataset)) # total image count"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"id": "ac33bc21",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"Split 80-20 and save into 2 variables"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 7,
|
||||
"id": "f68c1a25",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stdout",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"Train: 21102, Test: 5276\n"
|
||||
]
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"import math \n",
|
||||
"\n",
|
||||
"train_size = math.floor(len(dataset) * 0.8)\n",
|
||||
"test_size = len(dataset) - train_size\n",
|
||||
"\n",
|
||||
"train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
|
||||
"\n",
|
||||
"print(f\"Train: {len(train_dataset)}, Test: {len(test_dataset)}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
||||
Reference in New Issue
Block a user