From d756b6bd1a580960dc1d8fdbbe08165837da5d3a Mon Sep 17 00:00:00 2001 From: KeshavAnandCode Date: Wed, 18 Mar 2026 17:33:10 -0500 Subject: [PATCH] split --- notebooks/model-training-v1.ipynb | 71 +++++++++++++++++++++++++++++++ 1 file changed, 71 insertions(+) diff --git a/notebooks/model-training-v1.ipynb b/notebooks/model-training-v1.ipynb index a0234b4..c171d74 100644 --- a/notebooks/model-training-v1.ipynb +++ b/notebooks/model-training-v1.ipynb @@ -173,6 +173,77 @@ "y = x @ x\n", "print(\"GPU works!\", y.shape)" ] + }, + { + "cell_type": "markdown", + "id": "870fadbe", + "metadata": {}, + "source": [ + "Transform the image dataset into tensors. Normalized with a mean of 0.5 and a standard deviation of 0.5 per channel (these values can be different in order to use pretrained weights)." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "37793c77", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "['Bicycle', 'Bus', 'Car', 'Motorcycle', 'NonVehicles', 'Taxi', 'Truck', 'Van']\n", + "26378\n" + ] + } + ], + "source": [ + "from torchvision import datasets, transforms\n", + "from torch.utils.data import random_split, DataLoader\n", + "\n", + "transform = transforms.Compose([\n", + " transforms.ToTensor(),\n", + " transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5])\n", + "])\n", + "\n", + "dataset = datasets.ImageFolder(root=data_dir, transform=transform)\n", + "\n", + "print(dataset.classes) # should print 8 classes\n", + "print(len(dataset)) # total image count" + ] + }, + { + "cell_type": "markdown", + "id": "ac33bc21", + "metadata": {}, + "source": [ + "Split 80-20 and save into 2 variables" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "f68c1a25", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train: 21102, Test: 5276\n" + ] + } + ], + "source": [ + "import math \n", + "\n", + "train_size = math.floor(len(dataset) * 0.8)\n", + "test_size = len(dataset) - train_size\n", + 
"\n", + "train_dataset, test_dataset = random_split(dataset, [train_size, test_size])\n", + "\n", + "print(f\"Train: {len(train_dataset)}, Test: {len(test_dataset)}\")" + ] } ], "metadata": {