
Hint: Stereo Loss Function

This hint is organized with a smaller hint first. Then below the divider is a larger hint with more code snippets. Choose the amount of support that you’d like.


Your starting code looks like this:

import torch
import torch.nn as nn

class StereoLoss(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width

        y, x = torch.meshgrid(
            torch.linspace(-1, 1, height),
            torch.linspace(-1, 1, width),
            indexing="ij"
        )

    def forward(self, imgL, imgR, disp):
        # TODO!
        return -1.0  # placeholder

Background: Self-Supervised Stereo Depth

In this project, we don’t have ground truth depth. Instead, we have stereo image pairs and use photometric consistency as our training signal.

Our loss function needs three inputs, matching the forward signature above: the left image imgL, the right image imgR, and the predicted disparity map disp.

The disparity predictions are per-pixel values that specify a shift: a disparity of \(d\) at pixel \((x, y)\) in the left image means that pixel \((x - d, y)\) from the right image sees the same 3D point.

If we know the correct disparity at each pixel, we can warp the right image to match the left image. The loss measures how well the warped image matches the actual left image.
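To make the correspondence concrete, here is a tiny toy check on a single scanline (assuming a constant integer disparity and ignoring the wrapped border pixels):

```python
import torch

# Toy 1-D "scanline" whose pixel value equals its x coordinate,
# so correspondences are easy to read off.
left = torch.arange(8)

# If every pixel has disparity d = 2, left pixel x sees the same
# 3D point as right pixel x - d, i.e. right[x - d] == left[x].
d = 2
right = torch.roll(left, shifts=-d)  # right[i] == left[i + d] (mod 8)

x = 5
assert right[x - d] == left[x]  # pixel 5 in left matches pixel 3 in right
```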

The Steps

The heavy lifting happens in PyTorch's F.grid_sample function. Your loss function needs to do three things:

  1. Build a sampling grid: Use the predicted disparity to compute where each pixel in the left image should sample from in the right image. The starter code gives you a meshgrid in normalized [-1, 1] coordinates—you’ll modify the x coordinates based on disparity.

  2. Warp the right image: Use F.grid_sample to resample the right image according to your grid. Read the PyTorch docs for this function carefully.

  3. Compute photometric loss: Measure how different the actual left image is from the warped right image. L1 loss (mean absolute difference) is a good choice.

Things to Watch Out For

About register_buffer

The starter code creates a meshgrid in __init__, but the local variables x and y disappear when __init__ finishes. You need to store them so forward can use them.

You might think to just write self.x = x, and that would work—except for one problem. When you call .to(device) on a PyTorch module to move it to GPU, only parameters and buffers get moved. Regular instance variables like self.x = x would stay on CPU while your images are on GPU, causing a device mismatch error.

The solution is register_buffer:

self.register_buffer('x', x)
self.register_buffer('y', y)

This tells PyTorch: “these tensors belong to this module.” Now when you call criterion.to(device) or criterion.cuda(), the buffers move along with everything else. You access them as self.x and self.y in your forward method.
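Here is a small CPU-only check of this behavior (Demo is a throwaway module name for illustration). Module-wide conversions like .double() and .to() touch parameters and buffers, but not plain attributes:

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        y, x = torch.meshgrid(
            torch.linspace(-1, 1, 4),
            torch.linspace(-1, 1, 6),
            indexing="ij",
        )
        self.register_buffer('x', x)  # tracked: moves with .to()/.cuda()
        self.y = y                    # plain attribute: left behind

m = Demo()
assert 'x' in m.state_dict()       # buffers are part of the module's state
assert 'y' not in m.state_dict()   # plain attributes are not

# .double() converts parameters and buffers, but not plain attributes:
m.double()
assert m.x.dtype == torch.float64  # the buffer was converted
assert m.y.dtype == torch.float32  # the plain attribute was not
```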

Simple version: If this is a lot, here’s a simpler approach: move the entire meshgrid call into the forward function. Yes, you’re recomputing the meshgrid indices over and over, but you don’t have to deal with PyTorch’s buffer and device machinery. It will still run, just slower.
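A sketch of that simpler approach, pulled out as a standalone helper (make_grid_coords is a name made up for illustration). Creating the grid on the same device as the inputs sidesteps register_buffer entirely:

```python
import torch

def make_grid_coords(height, width, device):
    """Build the normalized meshgrid inside forward instead of __init__.

    Creating it on `device` each call avoids register_buffer entirely,
    at the cost of redoing this small computation every training step.
    """
    y, x = torch.meshgrid(
        torch.linspace(-1, 1, height, device=device),
        torch.linspace(-1, 1, width, device=device),
        indexing="ij",
    )
    return y, x

# In forward you would call this with disp.device; here we check on CPU:
y, x = make_grid_coords(4, 6, torch.device('cpu'))
assert x.shape == (4, 6) and y.shape == (4, 6)
assert x[0, 0] == -1 and x[0, -1] == 1  # x spans -1..1 across the width
```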


Bigger Hint

If you’d like more specific guidance, read on.

Step 1: Store the Meshgrid

The starter code creates a meshgrid but doesn’t save it. You need to store these as instance variables so you can use them in forward:

self.register_buffer('x', x)
self.register_buffer('y', y)

Using register_buffer ensures these tensors move to GPU with the model.

Step 2: Build the Sampling Grid

Disparity tells us the horizontal shift. In normalized coordinates, we need to subtract the disparity from x coordinates:

# disp has shape (N, 1, H, W)
# Normalize disparity to the [-1, 1] coordinate system
disp_normalized = disp / (self.width / 2)  # Convert pixel disparity to normalized coords

# Create sampling coordinates
N = disp.shape[0]                                # batch size
sample_x = self.x - disp_normalized.squeeze(1)   # Shape: (N, H, W)
sample_y = self.y.expand(N, -1, -1)              # Shape: (N, H, W)

Why divide by width/2? The normalized coordinate system spans 2 units (-1 to 1) across the image, so 1 pixel ≈ 2/width normalized units. (Strictly, with align_corners=True the corner pixel centers sit at -1 and +1, so one pixel is 2/(width - 1) units; for typical image widths the difference is negligible.)
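A quick sanity check of the conversion, using a few easy-to-read pixel disparities:

```python
import torch

width = 64
disp = torch.tensor([[0.0, 1.0, width / 2, width]])  # pixel disparities

# 1 pixel = 2/width normalized units, so dividing by width/2 converts
# pixel shifts into the [-1, 1] coordinate system.
disp_normalized = disp / (width / 2)

assert disp_normalized[0, 0] == 0.0  # no shift stays no shift
assert disp_normalized[0, 2] == 1.0  # half the image width = 1 unit
assert disp_normalized[0, 3] == 2.0  # the full width = the 2-unit span
```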

Step 3: Use grid_sample

PyTorch’s F.grid_sample does the actual warping. It expects:

  - Input image: (N, C, H, W)
  - Grid: (N, H, W, 2), where the last dimension holds (x, y) coordinates

# Stack coordinates into the grid format grid_sample expects
grid = torch.stack([sample_x, sample_y], dim=-1)  # Shape: (N, H, W, 2)

# Warp the right image to the left view
warped_R = F.grid_sample(imgR, grid, mode='bilinear', padding_mode='border',
                         align_corners=True)

Important settings:

  - align_corners=True: makes -1 and +1 correspond to the centers of the corner pixels, matching the linspace meshgrid above
  - padding_mode='border': repeats edge pixels for out-of-bounds coordinates
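One way to convince yourself the grid layout and settings are right: with an identity grid (the undisplaced meshgrid, i.e. zero disparity), grid_sample should reproduce the input exactly:

```python
import torch
import torch.nn.functional as F

N, C, H, W = 1, 3, 5, 7
img = torch.rand(N, C, H, W)

# Identity grid: each output pixel samples its own location.
y, x = torch.meshgrid(
    torch.linspace(-1, 1, H),
    torch.linspace(-1, 1, W),
    indexing="ij",
)
grid = torch.stack([x, y], dim=-1).expand(N, -1, -1, -1)  # (N, H, W, 2), (x, y) order!

out = F.grid_sample(img, grid, mode='bilinear',
                    padding_mode='border', align_corners=True)

# With align_corners=True, sampling at exact pixel centers is lossless.
assert torch.allclose(out, img, atol=1e-5)
```

If this check fails, the most common cause is stacking the coordinates as (y, x) instead of (x, y).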

Step 4: Compute Photometric Loss

Compare the warped right image to the actual left image with an L1 loss:

loss = torch.abs(imgL - warped_R).mean()
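Putting the four steps together, one possible arrangement looks like this (a sketch; your version may organize things differently):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StereoLoss(nn.Module):
    """The steps above assembled into one module."""
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        y, x = torch.meshgrid(
            torch.linspace(-1, 1, height),
            torch.linspace(-1, 1, width),
            indexing="ij",
        )
        self.register_buffer('x', x)
        self.register_buffer('y', y)

    def forward(self, imgL, imgR, disp):
        N = disp.shape[0]
        disp_normalized = disp / (self.width / 2)         # pixels -> [-1, 1] units
        sample_x = self.x - disp_normalized.squeeze(1)    # (N, H, W)
        sample_y = self.y.expand(N, -1, -1)               # (N, H, W)
        grid = torch.stack([sample_x, sample_y], dim=-1)  # (N, H, W, 2)
        warped_R = F.grid_sample(imgR, grid, mode='bilinear',
                                 padding_mode='border', align_corners=True)
        return torch.abs(imgL - warped_R).mean()

# Sanity check: identical images and zero disparity give (near) zero loss.
criterion = StereoLoss(5, 7)
img = torch.rand(2, 3, 5, 7)
loss = criterion(img, img, torch.zeros(2, 1, 5, 7))
assert loss < 1e-6
```

The zero-disparity check is worth keeping around as a unit test: it exercises the grid construction, the warp, and the loss in one shot.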

Common Bugs

  1. Forgetting to move meshgrid to GPU: Use register_buffer or manually call .to(device)
  2. Wrong disparity sign: If your disparity is positive, you should subtract it from x
  3. Wrong grid shape: Must be (N, H, W, 2), not (N, 2, H, W)
  4. Forgetting to squeeze disparity: disp is (N, 1, H, W), but you need (N, H, W) for the grid