split
This commit is contained in:
@@ -173,6 +173,77 @@
|
|||||||
"y = x @ x\n",
|
"y = x @ x\n",
|
||||||
"print(\"GPU works!\", y.shape)"
|
"print(\"GPU works!\", y.shape)"
|
||||||
]
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "870fadbe",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Transform image dataset into tensors. Normalized with 0.5 mean and 0.5 STD (can be differente to use pretrained weights)."
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 6,
|
||||||
|
"id": "37793c77",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"['Bicycle', 'Bus', 'Car', 'Motorcycle', 'NonVehicles', 'Taxi', 'Truck', 'Van']\n",
|
||||||
|
"26378\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"from torchvision import datasets, transforms\n",
|
||||||
|
"from torch.utils.data import random_split, DataLoader\n",
|
||||||
|
"\n",
|
||||||
|
"transform = transforms.Compose([\n",
|
||||||
|
" transforms.ToTensor(),\n",
|
||||||
|
" transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n",
|
||||||
|
"])\n",
|
||||||
|
"\n",
|
||||||
|
"dataset = datasets.ImageFolder(root=data_dir, transform=transform)\n",
|
||||||
|
"\n",
|
||||||
|
"print(dataset.classes) # should print 8 classes\n",
|
||||||
|
"print(len(dataset)) # total image count"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "markdown",
|
||||||
|
"id": "ac33bc21",
|
||||||
|
"metadata": {},
|
||||||
|
"source": [
|
||||||
|
"Split 80-20 and save into 2 variables"
|
||||||
|
]
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"cell_type": "code",
|
||||||
|
"execution_count": 7,
|
||||||
|
"id": "f68c1a25",
|
||||||
|
"metadata": {},
|
||||||
|
"outputs": [
|
||||||
|
{
|
||||||
|
"name": "stdout",
|
||||||
|
"output_type": "stream",
|
||||||
|
"text": [
|
||||||
|
"Train: 21102, Test: 5276\n"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
],
|
||||||
|
"source": [
|
||||||
|
"import math \n",
|
||||||
|
"\n",
|
||||||
|
"train_size = math.floor(len(dataset) * 0.8)\n",
|
||||||
|
"test_size = len(dataset) - train_size\n",
|
||||||
|
"\n",
|
||||||
|
"train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n",
|
||||||
|
"\n",
|
||||||
|
"print(f\"Train: {len(train_dataset)}, Test: {len(test_dataset)}\")"
|
||||||
|
]
|
||||||
}
|
}
|
||||||
],
|
],
|
||||||
"metadata": {
|
"metadata": {
|
||||||
|
|||||||
Reference in New Issue
Block a user