diff --git a/pytorch/TIPSv2_Segmentation.ipynb b/pytorch/TIPSv2_Segmentation.ipynb new file mode 100644 index 0000000..9b9d3cd --- /dev/null +++ b/pytorch/TIPSv2_Segmentation.ipynb @@ -0,0 +1,722 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "T4" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "metadata": { + "id": "DjlILc7jN0MO" + }, + "source": [ + "Copyright 2026 Google LLC.\n", + "\n", + "SPDX-License-Identifier: Apache-2.0" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "id": "qs1CcsHIxtfx" + }, + "outputs": [], + "source": [ + "# @title TIPS Zero-Shot Segmentation notebook\n", + "\n", + "# Licensed under the Apache License, Version 2.0 (the \"License\");\n", + "# you may not use this file except in compliance with the License.\n", + "# You may obtain a copy of the License at\n", + "\n", + "# https://www.apache.org/licenses/LICENSE-2.0\n", + "\n", + "# Unless required by applicable law or agreed to in writing, software\n", + "# distributed under the License is distributed on an \"AS IS\" BASIS,\n", + "# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.\n", + "# See the License for the specific language governing permissions and\n", + "# limitations under the License." 
# @title Install dependencies and clone TIPS repo
#
# NOTE: this cell relies on IPython/Colab shell magics (`!pip`, `!git clone`
# with `{var}` interpolation) and must run before any `tips.*` import below.
import dataclasses
import json
import math
import os
import shutil
import subprocess
import sys
import warnings
import zipfile
import gdown
from typing import Callable

import numpy as np
import PIL.Image
import torch
import torch.nn.functional as F
import torchvision.transforms as TVT
import torchvision.transforms.functional as TVTF
import tqdm
from torch import Tensor, nn

# Root directory for all files (Colab default is /content).
ROOT_DIR = os.getcwd()
TIPS_DIR = os.path.join(ROOT_DIR, 'tips')

# Install required packages.
# NOTE(review): dataclasses/json/shutil/subprocess/warnings/gdown appear unused
# in this notebook's visible cells — presumably kept for interactive use;
# confirm before removing.
!pip install -q torch torchvision torchaudio
!pip install -q tensorflow_text mediapy jax jaxlib scikit-learn

# Clone the TIPS repository (skipped if a previous run already cloned it).
if not os.path.exists(TIPS_DIR):
  !git clone https://github.com/google-deepmind/tips.git {TIPS_DIR}

# Add the root directory to PYTHONPATH so that `tips.*` imports work.
if ROOT_DIR not in sys.path:
  sys.path.insert(0, ROOT_DIR)

print(f'ROOT_DIR: {ROOT_DIR}')
print(f'TIPS_DIR: {TIPS_DIR}')
print('Installation complete!')
# @title Download the checkpoints, sample images, and ADE20k Dataset
import urllib.request

variant = 'L'  # @param ["B", "L", "So", "g"]

CHECKPOINT_BASE_URL = 'https://storage.googleapis.com/tips_data/v2_0/checkpoints/pytorch'
TOKENIZER_URL = 'https://storage.googleapis.com/tips_data/v1_0/checkpoints/tokenizer.model'
IMAGE_BASE_URL = 'https://raw.githubusercontent.com/google-deepmind/tips/main/scenic/images'
ADE20K_URL = 'http://data.csail.mit.edu/places/ADEchallenge/ADEChallengeData2016.zip'
ADE20K_TMP_PATH = '/content/ADEChallengeData2016.zip'

# Directories for checkpoints and images (under ROOT_DIR).
CKPT_DIR = os.path.join(ROOT_DIR, 'checkpoints')
IMG_DIR = os.path.join(ROOT_DIR, 'images')
os.makedirs(CKPT_DIR, exist_ok=True)
os.makedirs(IMG_DIR, exist_ok=True)

# Vision/text checkpoint filenames for each model variant.
V2_CHECKPOINT_MAP = {
    'B': ('tips_v2_oss_b14_vision.npz', 'tips_v2_oss_b14_text.npz'),
    'L': ('tips_v2_oss_l14_vision.npz', 'tips_v2_oss_l14_text.npz'),
    'So': ('tips_v2_oss_so14_vision.npz', 'tips_v2_oss_so14_text.npz'),
    'g': ('tips_v2_oss_g14_vision.npz', 'tips_v2_oss_g14_text.npz'),
}
vision_ckpt_name, text_ckpt_name = V2_CHECKPOINT_MAP[variant]


def download_if_missing(url: str, dest_path: str) -> None:
  """Fetches `url` to `dest_path`, skipping the download if it already exists.

  Replaces four copy-pasted download loops with one helper. Prints progress
  in the same style as the original cell.

  Args:
    url: Source URL.
    dest_path: Local destination path (parent directory must exist).
  """
  name = os.path.basename(dest_path)
  if os.path.exists(dest_path):
    print(f'  {name} already exists.')
    return
  print(f'\nDownloading {name}...')
  urllib.request.urlretrieve(url, dest_path)
  print(f'  Saved to {dest_path}')


# Download checkpoints for the selected variant.
for ckpt_name in [vision_ckpt_name, text_ckpt_name]:
  download_if_missing(f'{CHECKPOINT_BASE_URL}/{ckpt_name}',
                      os.path.join(CKPT_DIR, ckpt_name))

# Download tokenizer.
tokenizer_file = os.path.join(CKPT_DIR, 'tokenizer.model')
download_if_missing(TOKENIZER_URL, tokenizer_file)

# Download sample images.
sample_images = ['example_image.jpg', 'example_image_2.jpg']
for img_name in sample_images:
  download_if_missing(f'{IMAGE_BASE_URL}/{img_name}',
                      os.path.join(IMG_DIR, img_name))

# Download ADE20K dataset (zip is extracted in a later cell).
download_if_missing(ADE20K_URL, ADE20K_TMP_PATH)

print('\nAll downloads complete!')
#@title Extract the ADE20K Dataset

ADE20K_DIR = "/tmp/ADEChallengeData2016"

# Extract the zip downloaded by the previous cell (idempotent).
if not os.path.isdir(ADE20K_DIR):
  zip_path = "/content/ADEChallengeData2016.zip"
  print("Extracting...")
  with zipfile.ZipFile(zip_path, "r") as zf:
    zf.extractall("/tmp/")
  print(f"Extracted to {ADE20K_DIR}")
else:
  print(f"ADE20K already at {ADE20K_DIR}")

# The 150 ADE20K semantic classes, in label order (annotation value 1 -> "wall").
ADE20K_CLASS_NAMES = (
    "wall", "building", "sky", "floor", "tree", "ceiling", "road",
    "bed", "windowpane", "grass", "cabinet", "sidewalk", "person", "earth",
    "door", "table", "mountain", "plant", "curtain", "chair", "car", "water",
    "painting", "sofa", "shelf", "house", "sea", "mirror", "rug", "field",
    "armchair", "seat", "fence", "desk", "rock", "wardrobe", "lamp", "bathtub",
    "railing", "cushion", "base", "box", "column", "signboard",
    "chest of drawers", "counter", "sand", "sink", "skyscraper", "fireplace",
    "refrigerator", "grandstand", "path", "stairs", "runway", "case",
    "pool table", "pillow", "screen door", "stairway", "river", "bridge",
    "bookcase", "blind", "coffee table", "toilet", "flower", "book", "hill",
    "bench", "countertop", "stove", "palm", "kitchen island", "computer",
    "swivel chair", "boat", "bar", "arcade machine", "hovel", "bus", "towel",
    "light", "truck", "tower", "chandelier", "awning", "streetlight", "booth",
    "television receiver", "airplane", "dirt track", "apparel", "pole", "land",
    "bannister", "escalator", "ottoman", "bottle", "buffet", "poster", "stage",
    "van", "ship", "fountain", "conveyer belt", "canopy", "washer", "plaything",
    "swimming pool", "stool", "barrel", "basket", "waterfall", "tent", "bag",
    "minibike", "cradle", "oven", "ball", "food", "step", "tank", "trade name",
    "microwave", "pot", "animal", "bicycle", "lake", "dishwasher", "screen",
    "blanket", "sculpture", "hood", "sconce", "vase", "traffic light", "tray",
    "ashcan", "fan", "pier", "crt screen", "plate", "monitor", "bulletin board",
    "shower", "radiator", "glass", "clock", "flag",
)

# Identity normalization: TIPS consumes [0, 1] ToTensor output unchanged.
NORMALIZE_TIPS = TVT.Normalize(mean=[0.0, 0.0, 0.0], std=[1.0, 1.0, 1.0])


class Ade20kDataset(torch.utils.data.Dataset):
  """ADE20K semantic-segmentation dataset (images + per-pixel class masks).

  Labels are remapped so class indices are 0-based and both the original
  "unlabeled" value 0 and 255 become the ignore index 255.
  """

  CLASS_NAMES = ADE20K_CLASS_NAMES

  def __init__(self, root: str, split: str, transform: Callable) -> None:
    """Indexes images/annotations under `root` for `split` ('training'/'validation').

    Args:
      root: ADE20K root directory (ADEChallengeData2016 layout).
      split: Sub-directory name under images/ and annotations/.
      transform: Callable applied to each PIL image.
    """
    self.transform = transform
    img_dir = os.path.join(root, "images", split)
    ann_dir = os.path.join(root, "annotations", split)
    self.images = sorted([
        os.path.join(img_dir, f) for f in os.listdir(img_dir) if f.endswith(".jpg")
    ])
    self.annotations = sorted([
        os.path.join(ann_dir, f) for f in os.listdir(ann_dir) if f.endswith(".png")
    ])
    assert len(self.images) == len(self.annotations), (
        f"Mismatch: {len(self.images)} images vs {len(self.annotations)} annotations"
    )
    # Robustness fix: equal lengths do not guarantee correct pairing. Verify
    # that each sorted image/annotation pair shares the same basename stem
    # (e.g. ADE_val_00000001.jpg <-> ADE_val_00000001.png).
    for img_path, ann_path in zip(self.images, self.annotations):
      img_stem = os.path.splitext(os.path.basename(img_path))[0]
      ann_stem = os.path.splitext(os.path.basename(ann_path))[0]
      assert img_stem == ann_stem, (
          f"Unpaired files: {img_path} vs {ann_path}"
      )
    print(f"Loaded {len(self.images)} ADE20K {split} images.")

  def __getitem__(self, idx: int) -> tuple:
    """Returns (transformed image, long mask with 0-based labels, 255 = ignore)."""
    img = PIL.Image.open(self.images[idx]).convert("RGB")
    mask = np.array(PIL.Image.open(self.annotations[idx]))
    img = self.transform(img)
    mask = torch.from_numpy(mask).long()
    # Shift labels to 0-based; map unlabeled (0) and 255 to the ignore index.
    mask = torch.where((mask == 0) | (mask == 255), 255, mask - 1)
    return img, mask

  def __len__(self) -> int:
    return len(self.images)
#@title Define Segmentation Helper Functions — Value Attention

#######
# In zero-shot segmentation eval, we use the "values trick", where we use the
# value embeddings from the last transformer block's attention layer as the
# patch representations, rather than the final patch embeddings of the encoder.
# This approach was introduced in MaskCLIP.

def _get_all_blocks(model_image):
  """Returns the encoder's transformer blocks as a flat list.

  Unwraps chunked block containers and drops nn.Identity placeholders.
  """
  if model_image.chunked_blocks:
    blocks = []
    for chunk in model_image.blocks:
      for blk in chunk:
        if not isinstance(blk, nn.Identity):
          blocks.append(blk)
    return blocks
  return list(model_image.blocks)


def encode_image_value_attention(model_image, img: Tensor) -> Tensor:
  """Encodes `img` into dense patch features using the MaskCLIP values trick.

  All but the last transformer block run normally; in the last block the
  attention output is replaced by the projected value embeddings.

  Args:
    model_image: TIPS vision encoder (DINOv2-style ViT).
    img: Float tensor of shape (B, 3, H, W).

  Returns:
    Patch features of shape (B, H'//P, W'//P, C), where H'/W' are the spatial
    dims rounded up to a multiple of the patch size P.
  """
  B, _, H, W = img.shape
  P = model_image.patch_size
  # Round spatial dims up to a multiple of the patch size.
  new_H = math.ceil(H / P) * P
  new_W = math.ceil(W / P) * P

  if (H, W) != (new_H, new_W):
    img = F.interpolate(img, size=(new_H, new_W), mode="bicubic", align_corners=False)

  B, _, h_i, w_i = img.shape

  x = model_image.prepare_tokens_with_masks(img)

  num_register = model_image.num_register_tokens
  all_blocks = _get_all_blocks(model_image)
  for i, blk in enumerate(all_blocks):
    if i < len(all_blocks) - 1:
      x = blk(x)
    else:
      # Last block: recompute qkv and keep only the value embeddings,
      # then apply the block's usual projection / layerscale / MLP path.
      x_normed = blk.norm1(x)
      b_dim, n_dim, c_dim = x_normed.shape
      qkv = (
          blk.attn.qkv(x_normed)
          .reshape(b_dim, n_dim, 3, blk.attn.num_heads, c_dim // blk.attn.num_heads)
          .permute(2, 0, 3, 1, 4)
      )
      v = qkv[2]
      v_out = v.transpose(1, 2).reshape(b_dim, n_dim, c_dim)
      v_out = blk.attn.proj(v_out)
      v_out = blk.ls1(v_out)
      x_val = v_out + x

      y_val = blk.norm2(x_val)
      y_val = blk.ls2(blk.mlp(y_val))
      x_val = x_val + y_val

  x_val = model_image.norm(x_val)
  # Drop the CLS token and register tokens; keep spatial patch tokens only.
  patch_tokens = x_val[:, 1 + num_register:, :]
  blocks_patches = patch_tokens.reshape(B, h_i // P, w_i // P, -1).contiguous()
  return blocks_patches


class ShortSideResize(nn.Module):
  """Resizes an image so its shorter side equals `size`, preserving aspect ratio."""

  def __init__(self, size: int, interpolation: TVT.InterpolationMode) -> None:
    super().__init__()
    self.size = size
    self.interpolation = interpolation

  def forward(self, img: Tensor) -> Tensor:
    _, h, w = TVTF.get_dimensions(img)
    # Already at target size on the short side: no-op.
    if (w <= h and w == self.size) or (h <= w and h == self.size):
      return img
    if w < h:
      new_w = self.size
      new_h = int(self.size * h / w)
    else:
      new_h = self.size
      new_w = int(self.size * w / h)
    return TVTF.resize(img, [new_h, new_w], self.interpolation)


def predict_whole(model_image, img: Tensor, text_features: Tensor) -> Tensor:
  """Cosine similarity between text features and patch features.

  Args:
    model_image: TIPS vision encoder.
    img: (3, H, W) image tensor (unbatched).
    text_features: (num_classes, D) L2-normalized class embeddings.

  Returns:
    (num_classes, h, w) cosine-similarity map at patch resolution.
  """
  blocks_feats = encode_image_value_attention(model_image, img.unsqueeze(0))
  blocks_feats = blocks_feats.squeeze(0)
  blocks_feats = F.normalize(blocks_feats, p=2, dim=-1)
  cos = torch.einsum("cd,hwd->chw", text_features, blocks_feats)
  return cos


def predict_slide(model_image, img: Tensor, text_features: Tensor, side: int, stride: int) -> Tensor:
  """Sliding-window inference averaging per-window softmax probabilities.

  Fix: accumulators are allocated on `img.device` instead of a hard-coded
  "cuda", so the function also works on CPU (identical behavior on GPU input).

  Args:
    model_image: TIPS vision encoder.
    img: (3, H, W) image tensor.
    text_features: (num_classes, D) class embeddings.
    side: Square window side length in pixels.
    stride: Window stride in pixels.

  Returns:
    (num_classes, H, W) averaged class-probability map.
  """
  _, H, W = img.shape
  num_classes, _ = text_features.shape
  device = img.device
  probs = torch.zeros([num_classes, H, W], device=device)
  counts = torch.zeros([H, W], device=device)
  h_grids = max(H - side + stride - 1, 0) // stride + 1
  w_grids = max(W - side + stride - 1, 0) // stride + 1
  for i in range(h_grids):
    for j in range(w_grids):
      y1 = i * stride
      x1 = j * stride
      y2 = min(y1 + side, H)
      x2 = min(x1 + side, W)
      # Clamp the window back inside the image near the bottom/right edges.
      y1 = max(y2 - side, 0)
      x1 = max(x2 - side, 0)

      img_window = img[:, y1:y2, x1:x2]
      cos = predict_whole(model_image, img_window, text_features)

      cos = F.interpolate(
          cos.unsqueeze(0),
          size=img_window.shape[1:],
          mode="bilinear",
          align_corners=False,
      ).squeeze(0)
      probs[:, y1:y2, x1:x2] += cos.softmax(dim=0)
      counts[y1:y2, x1:x2] += 1
  # Every pixel is covered by at least one window, so counts >= 1 everywhere.
  probs /= counts

  return probs
# @title Configure Prompt Templates
# Subset of ImageNet prompts used in the TCL paper.
_SUB_IMAGENET_PROMPTS = [
    'itap of a {}.',
    'a bad photo of a {}.',
    'a origami {}.',
    'a photo of the large {}.',
    'a {} in a video game.',
    'art of the {}.',
    'a photo of the small {}.',
]

# Templates used in the TCL paper.
# https://github.com/khanrc/tcl/blob/main/datasets/templates.py#L145
_TCL_PROMPTS = _SUB_IMAGENET_PROMPTS + [
    'a photo of many {}.',
    'a photo of {}s.',
]

PROMPT_TEMPLATES = _TCL_PROMPTS
print(f"Using {len(PROMPT_TEMPLATES)} TCL prompt templates")


#@title Import and Patch TIPS Model

import io
import os
import numpy as np
import tensorflow as tf
import tensorflow_text
from tips.pytorch import image_encoder
from tips.pytorch import text_encoder
import torch
from torch import nn
import torch.nn.functional as F

PATCH_SIZE = 14
VOCAB_SIZE = 32000
MAX_SEQ_LEN = 64

device = "cuda"

# Load the tokenizer directly from the downloaded checkpoint.
# Fix: derive the path from CKPT_DIR (defined in the download cell) instead of
# the hard-coded '/content/checkpoints/...', so the notebook still works when
# ROOT_DIR is not /content.
tokenizer_path = os.path.join(CKPT_DIR, 'tokenizer.model')
with tf.io.gfile.GFile(tokenizer_path, 'rb') as f:
  tokenizer = tensorflow_text.SentencepieceTokenizer(f.read())


def get_pretokenized_batch(class_idx: int) -> tuple[torch.Tensor, torch.Tensor]:
  """Tokenizes every prompt template for one ADE20K class.

  Args:
    class_idx: Index into ADE20K_CLASS_NAMES.

  Returns:
    (token_ids, paddings) tensors on `device`; paddings is 0.0 at real tokens
    and 1.0 at padded positions. Sequences are truncated to MAX_SEQ_LEN.
  """
  # Generate texts on the fly using ADE20K_CLASS_NAMES and PROMPT_TEMPLATES.
  class_name = ADE20K_CLASS_NAMES[class_idx]
  texts = [template.format(class_name) for template in PROMPT_TEMPLATES]

  # Tokenize the texts.
  tokens = tokenizer.tokenize(texts).to_list()

  max_l = min(max(len(ids) for ids in tokens), MAX_SEQ_LEN)
  num = len(tokens)
  token_ids = np.zeros((num, max_l), dtype=np.int64)
  paddings = np.ones((num, max_l), dtype=np.float32)

  for i, ids in enumerate(tokens):
    length = min(len(ids), max_l)
    token_ids[i, :length] = ids[:length]
    paddings[i, :length] = 0.0

  return (
      torch.from_numpy(token_ids).to(device, non_blocking=True),
      torch.from_numpy(paddings).to(device, non_blocking=True),
  )


# Monkey-patch to address the device mismatch in TextEncoder.
def new_text_encoder_call(self, ids, paddings):
  """Applies TextEncoder module with device-aware positional embeddings."""
  _, seq_length = ids.shape
  mask = (paddings == 0).type(torch.float32)
  mask = mask.permute(1, 0)  # NL -> LN
  x = self.token_embedding(ids)
  if self.scale_sqrt_depth:
    x = x * (self.embedding_dim**0.5)

  # The actual fix: move positional embeddings onto the input's device.
  pos_embeddings = self.pos_embedder(seq_length=seq_length)
  x = x + pos_embeddings.to(x.device)

  x = x.permute(1, 0, 2)  # NLD -> LND
  x = self.transformer(x, mask)
  x = x.permute(1, 0, 2)  # LND -> NLD
  x = self.ln_final(x)
  x = self.pooling(x, compatible_paddings=paddings[:, :, None])
  return x

text_encoder.TextEncoder.__call__ = new_text_encoder_call
print("TextEncoder has been patched and tokenizer is loaded for on-the-fly tokenization!")
#@title Configure the TIPS model.

image_size = 448 # @param {type: "number"}
stride = 336 # @param {type: "number"}
side = 448 # @param {type: "number"}
resize = 448 # @param {type: "number"}
mode = "slide" # @param {type: "string"}

# Checkpoint and tokenizer paths (absolute paths).
image_encoder_checkpoint = os.path.join(CKPT_DIR, vision_ckpt_name)
text_encoder_checkpoint = os.path.join(CKPT_DIR, text_ckpt_name)
tokenizer_path = os.path.join(CKPT_DIR, 'tokenizer.model')

print(f'Image encoder checkpoint: {image_encoder_checkpoint}')
print(f'Text encoder checkpoint: {text_encoder_checkpoint}')
print(f'Tokenizer path: {tokenizer_path}')


# @title Load TIPS Model

## Load Vision Model
weights_image = dict(np.load(image_encoder_checkpoint, allow_pickle=False))
for key in weights_image:
  weights_image[key] = torch.tensor(weights_image[key])

# ViT-g-based models use swiglu instead of mlp
ffn_layer = 'swiglu' if variant == 'g' else 'mlp'

# NOTE(review): these two lists are never used in this notebook — confirm
# before removing.
embeddings_image, spatial_features = [], []

with torch.no_grad():
  # Load the vision encoder.
  # NOTE(review): the architecture is hard-coded to `vit_large`, but the
  # `variant` form field offers B/So/g checkpoints; loading a non-L checkpoint
  # into a ViT-L skeleton will fail (or silently mismatch) in load_state_dict.
  # The correct per-variant constructor lives in tips.pytorch.image_encoder —
  # confirm and dispatch on `variant` there.
  model_image = image_encoder.vit_large(
      img_size=image_size,
      patch_size=PATCH_SIZE,
      ffn_layer=ffn_layer,
      block_chunks=0,
      init_values=1.0,
      interpolate_antialias=True,
      interpolate_offset=0.0,
  )
  model_image.load_state_dict(weights_image)
  model_image = model_image.to(device)

## Load Text Model
# NOTE(review): these dimensions match the L-variant text tower; other
# variants presumably need different hidden/mlp/layer sizes — verify against
# the tips.pytorch.text_encoder configs if supporting B/So/g.
text_config = {
    'hidden_size': 1024,
    'mlp_dim': 4096,
    'num_heads': 16,
    'num_layers': 12,
}

with open(text_encoder_checkpoint, 'rb') as fin:
  inbuffer = io.BytesIO(fin.read())
np_weights_text = dict(np.load(inbuffer, allow_pickle=False))

weights_text = {}
for key, value in np_weights_text.items():
  weights_text[key] = torch.from_numpy(value)

# `temperature` is stored alongside the weights but is not a module parameter;
# it is popped so load_state_dict does not see an unexpected key.
# NOTE(review): `temperature` is unused afterwards in this notebook.
temperature = weights_text.pop('temperature')
with torch.no_grad():
  # Load the text encoder.
  model_text = text_encoder.TextEncoder(
      text_config,
      vocab_size=VOCAB_SIZE,
  )

  model_text.load_state_dict(weights_text)
  model_text = model_text.to(device)
# @title Compute mIoU
def compute_miou(confusion_matrix: torch.Tensor) -> float:
  """Mean intersection-over-union from a (C, C) confusion matrix.

  Rows are ground-truth classes, columns are predictions. Classes whose union
  is empty are excluded from the mean.

  Args:
    confusion_matrix: Integer (C, C) tensor of pixel counts.

  Returns:
    Mean IoU over valid classes; 0.0 if no class is valid (robustness fix:
    the original returned NaN from an empty mean in that case).
  """
  intersection = confusion_matrix.diag()
  union = confusion_matrix.sum(dim=1) + confusion_matrix.sum(dim=0) - intersection
  iou = intersection / (union + 1e-10)
  valid = union > 0
  if not valid.any():
    return 0.0
  return iou[valid].mean().item()


transform = TVT.Compose(
    [
        ShortSideResize(resize, TVT.InterpolationMode.BICUBIC),
        TVT.ToTensor(),
        NORMALIZE_TIPS,
    ]
)
dataset = Ade20kDataset(ADE20K_DIR, "validation", transform)
class_names = dataset.CLASS_NAMES
num_classes = len(class_names)
print(f"Dataset: {len(dataset)} images, {num_classes} classes")
# batch_size=None: images have varying sizes, so iterate one sample at a time.
dataloader = torch.utils.data.DataLoader(
    dataset,
    batch_size=None,
    num_workers=0,
    shuffle=False,
    pin_memory=True,
)

# Build one text embedding per class: encode every prompt template,
# L2-normalize, average across templates, then re-normalize (prompt ensembling).
text_feats = []
with torch.no_grad():
  for ci, class_name in enumerate(tqdm.tqdm(class_names, desc="Class names", unit="name", ncols=0)):
    # get_pretokenized_batch already places tensors on `device`; the original
    # redundant .to(device) calls were removed.
    ids, paddings = get_pretokenized_batch(ci)
    feats = model_text(ids, paddings)
    feats = F.normalize(feats, p=2, dim=-1)
    feats = feats.mean(dim=0)
    feats = F.normalize(feats, p=2, dim=-1)
    text_feats.append(feats)
text_feats = torch.stack(text_feats).float()
print(f"Text features shape: {text_feats.shape}, dtype: {text_feats.dtype}")

confusion_matrix = torch.zeros(num_classes, num_classes, dtype=torch.long, device=device)
for idx, (img, target) in enumerate(tqdm.tqdm(dataloader, desc="Segmentation", unit="img", ncols=0)):
  H_target, W_target = target.shape
  img = img.to(device, dtype=torch.float, non_blocking=True)
  target = target.to(device, non_blocking=True)
  if idx == 0:
    tqdm.tqdm.write(f"Image shape: {img.shape}")
    tqdm.tqdm.write(f"Target shape: {target.shape}")

  with torch.inference_mode():
    if mode == "whole":
      pred = predict_whole(model_image, img, text_feats)
    elif mode == "slide":
      pred = predict_slide(model_image, img, text_feats, side, stride)
    else:
      raise ValueError(f"Unknown mode {mode}")

  # Upsample class scores to ground-truth resolution, then take the argmax.
  pred = F.interpolate(pred.unsqueeze(0), size=(H_target, W_target), mode="bilinear", align_corners=False)
  pred = pred.squeeze(0).argmax(dim=0)

  # Accumulate confusion counts over non-ignored (label != 255) pixels.
  mask = target != 255
  pred_valid = pred[mask]
  target_valid = target[mask]
  if pred_valid.numel() > 0:
    indices = target_valid * num_classes + pred_valid
    confusion_matrix += torch.bincount(indices, minlength=num_classes * num_classes).reshape(num_classes, num_classes)

miou = compute_miou(confusion_matrix)
print(f"Segmentation mIoU: {100 * miou:.2f}")