From 64ed10e95c065e1b95e46f3de43aad2c468123e2 Mon Sep 17 00:00:00 2001 From: KeshavAnandCode Date: Wed, 18 Mar 2026 17:26:52 -0500 Subject: [PATCH] Torch working --- notebooks/model-training-v1.ipynb | 82 +++++++++++++++++++++++++++++++ requirements.txt | 29 +++++++++++ 2 files changed, 111 insertions(+) diff --git a/notebooks/model-training-v1.ipynb b/notebooks/model-training-v1.ipynb index bb46749..a0234b4 100644 --- a/notebooks/model-training-v1.ipynb +++ b/notebooks/model-training-v1.ipynb @@ -51,6 +51,14 @@ " print(f\"{class_name}: {count} images\")" ] }, + { + "cell_type": "markdown", + "id": "64122ad4", + "metadata": {}, + "source": [ + "Check out sample image from dataset" + ] + }, { "cell_type": "code", "execution_count": 3, @@ -91,6 +99,80 @@ "plt.title(first_class)\n", "plt.show()" ] + }, + { + "cell_type": "markdown", + "id": "c19ec00a", + "metadata": {}, + "source": [ + "Ensure that all images are RGB, all of same resolution" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "3cedd586", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Unique sizes: {(64, 64)}\n", + "Unique modes: {'RGB'}\n" + ] + } + ], + "source": [ + "sizes = set()\n", + "modes = set()\n", + "\n", + "for class_name in os.listdir(data_dir):\n", + " class_path = os.path.join(data_dir, class_name)\n", + " if not os.path.isdir(class_path):\n", + " continue\n", + " for img_name in os.listdir(class_path):\n", + " img = Image.open(os.path.join(class_path, img_name))\n", + " sizes.add(img.size)\n", + " modes.add(img.mode)\n", + "\n", + "print(f\"Unique sizes: {sizes}\")\n", + "print(f\"Unique modes: {modes}\")" + ] + }, + { + "cell_type": "markdown", + "id": "88ac961b", + "metadata": {}, + "source": [ + "Ensure that torch works with GPU (5080) [Credit: Claude]" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "8f556b22", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "True\n", + "NVIDIA GeForce RTX 5080\n", + "GPU works! torch.Size([1000, 1000])\n" + ] + } + ], + "source": [ + "import torch\n", + "print(torch.cuda.is_available()) \n", + "print(torch.cuda.get_device_name(0)) \n", + "\n", + "x = torch.randn(1000, 1000).cuda()\n", + "y = x @ x\n", + "print(\"GPU works!\", y.shape)" + ] } ], "metadata": { diff --git a/requirements.txt b/requirements.txt index dff06e2..18e4972 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,22 +1,46 @@ asttokens==3.0.1 comm==0.2.3 contourpy==1.3.3 +cuda-bindings==12.9.4 +cuda-pathfinder==1.2.2 +cuda-toolkit==12.8.1 cycler==0.12.1 debugpy==1.8.20 decorator==5.2.1 executing==2.2.1 +filelock==3.25.2 fonttools==4.62.1 +fsspec==2026.2.0 ipykernel==7.2.0 ipython==9.10.0 ipython_pygments_lexers==1.1.1 jedi==0.19.2 +Jinja2==3.1.6 jupyter_client==8.8.0 jupyter_core==5.9.1 kiwisolver==1.5.0 +MarkupSafe==3.0.2 matplotlib==3.10.8 matplotlib-inline==0.2.1 +mpmath==1.3.0 nest-asyncio==1.6.0 +networkx==3.6.1 numpy==2.4.3 +nvidia-cublas-cu12==12.8.4.1 +nvidia-cuda-cupti-cu12==12.8.90 +nvidia-cuda-nvrtc-cu12==12.8.93 +nvidia-cuda-runtime-cu12==12.8.90 +nvidia-cudnn-cu12==9.20.0.48 +nvidia-cufft-cu12==11.3.3.83 +nvidia-cufile-cu12==1.13.1.3 +nvidia-curand-cu12==10.3.9.90 +nvidia-cusolver-cu12==11.7.3.90 +nvidia-cusparse-cu12==12.5.8.93 +nvidia-cusparselt-cu12==0.7.1 +nvidia-nccl-cu12==2.29.7 +nvidia-nvjitlink-cu12==12.8.93 +nvidia-nvshmem-cu12==3.4.5 +nvidia-nvtx-cu12==12.8.90 packaging==26.0 parso==0.8.6 pexpect==4.9.0 @@ -32,7 +56,12 @@ python-dateutil==2.9.0.post0 pyzmq==27.1.0 six==1.17.0 stack-data==0.6.3 +sympy==1.14.0 +torch==2.12.0.dev20260318+cu128 +torchaudio==2.11.0.dev20260318+cu128 +torchvision==0.26.0.dev20260318+cu128 tornado==6.5.5 traitlets==5.14.3 +triton==3.6.0+git9844da95 typing_extensions==4.15.0 wcwidth==0.6.0