This hint is organized with a smaller hint first. Then below the divider is a larger hint with more code snippets. Choose the amount of support that you’d like.
Your starting code looks like this:

```python
class StereoLoss(nn.Module):
    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        y, x = torch.meshgrid(
            torch.linspace(-1, 1, height),
            torch.linspace(-1, 1, width),
            indexing="ij"
        )

    def forward(self, imgL, imgR, disp):
        # TODO!
        return -1.0  # placeholder
```

In this project, we don't have ground-truth depth. Instead, we have stereo image pairs and use photometric consistency as our training signal.
Our loss function now needs three inputs: the left image `imgL`, the right image `imgR`, and the predicted disparity map `disp`.
The disparity predictions are per-pixel values that specify a shift: a disparity of \(d\) at pixel \((x, y)\) in the left image means that pixel \((x - d, y)\) from the right image sees the same 3D point.
If we know the correct disparity at each pixel, we can warp the right image to match the left image. The loss measures how well the warped image matches the actual left image.
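To make the correspondence concrete, here is a toy example (not part of the starter code) with a single 1-D scanline and a constant integer disparity, showing that the left pixel at `x` sees the same value as the right pixel at `x - d`:

```python
import torch

# A toy 1-D scanline: each value stands for the color seen at that pixel.
right = torch.tensor([10., 20., 30., 40., 50.])
d = 2  # constant disparity of 2 pixels

# Left pixel x corresponds to right pixel x - d, so the left scanline is
# the right scanline shifted: left[x] = right[x - d] (valid for x >= d).
left = torch.full((5,), float('nan'))
for x in range(d, 5):
    left[x] = right[x - d]
# left is now [nan, nan, 10., 20., 30.]
```

Warping does exactly this shift, but per-pixel and with bilinear interpolation for fractional disparities.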
The heavy lifting will happen in PyTorch's `grid_sample` function. Your loss function needs to do three things:
1. **Build a sampling grid:** Use the predicted disparity to compute where each pixel in the left image should sample from in the right image. The starter code gives you a meshgrid in normalized [-1, 1] coordinates; you'll modify the x coordinates based on disparity.
2. **Warp the right image:** Use `F.grid_sample` to resample the right image according to your grid. Read the PyTorch docs for this function carefully.
3. **Compute the photometric loss:** Measure how different the actual left image is from the warped right image. L1 loss (mean absolute difference) is a good choice.
Useful tools:

- `torch.meshgrid`: Already in your starter code. Creates coordinate grids.
- `self.register_buffer(name, tensor)`: Stores a tensor as part of the module so it moves to GPU automatically.
- `torch.stack`: Combines tensors along a new dimension.
- `F.grid_sample`: The key function for warping. Read its documentation.

Things to watch out for:

- `grid_sample` expects coordinates in the [-1, 1] range, not pixel coordinates. You'll need to convert your disparity values appropriately.
- `grid_sample` expects a grid of a specific shape.

`register_buffer`

The starter code creates a meshgrid in `__init__`, but the local variables `x` and `y` disappear when `__init__` finishes. You need to store them so `forward` can use them.
You might think to just write `self.x = x`, and that would work, except for one problem. When you call `.to(device)` on a PyTorch module to move it to GPU, only parameters and buffers get moved. Regular instance variables like `self.x = x` would stay on CPU while your images are on GPU, causing a device mismatch error.
The solution is `register_buffer`:

```python
self.register_buffer('x', x)
self.register_buffer('y', y)
```

This tells PyTorch: "these tensors belong to this module." Now when you call `criterion.to(device)` or `criterion.cuda()`, the buffers move along with everything else. You access them as `self.x` and `self.y` in your `forward` method.
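You can see the difference directly: buffers show up in the module's `named_buffers()` and `state_dict()`, which is exactly why `.to(device)` knows to move them. Here is a small check with a hypothetical `Demo` module (the names are made up for illustration):

```python
import torch
import torch.nn as nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        grid = torch.linspace(-1, 1, 5)
        self.register_buffer('x', grid)  # tracked by the module
        self.plain = grid.clone()        # just a Python attribute

demo = Demo()
# Buffers are part of the module: they move with .to(device) / .cuda()
# and are saved in state_dict. Plain attributes are invisible to PyTorch.
buffer_names = [name for name, _ in demo.named_buffers()]  # ['x']
has_x = 'x' in demo.state_dict()          # True
has_plain = 'plain' in demo.state_dict()  # False
```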
Simple version: If this is a lot, here's a simpler approach: move the entire meshgrid call into the `forward` function. Yes, you're recomputing the meshgrid indices over and over. But you don't have to deal with PyTorch's memory model. It should still run, just slower.
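A minimal sketch of that simpler approach, with the grid rebuilt on the inputs' device each call (the helper name `make_grid` is ours, not from the starter code):

```python
import torch

# The "simple version": rebuild the grid inside forward on every call,
# directly on the right device, so no register_buffer bookkeeping is needed.
def make_grid(height, width, device):
    y, x = torch.meshgrid(
        torch.linspace(-1, 1, height, device=device),
        torch.linspace(-1, 1, width, device=device),
        indexing="ij",
    )
    return y, x

# Inside forward you would call: y, x = make_grid(self.height, self.width,
# disp.device)
y, x = make_grid(4, 5, torch.device('cpu'))
# x has shape (4, 5) and runs from -1 to 1 across each row
```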
If you'd like more specific guidance, read on.

---
The starter code creates a meshgrid but doesn't save it. You need to store these as instance variables so you can use them in `forward`:

```python
self.register_buffer('x', x)
self.register_buffer('y', y)
```

Using `register_buffer` ensures these tensors move to GPU with the model.
Disparity tells us the horizontal shift. In normalized coordinates, we need to subtract the disparity from the x coordinates:

```python
# disp has shape (N, 1, H, W)
N = disp.shape[0]

# Convert pixel disparity to normalized [-1, 1] coordinates: the grid
# built by linspace(-1, 1, width) spaces adjacent pixels 2/(width - 1) apart
disp_normalized = disp / ((self.width - 1) / 2)

# Create sampling coordinates
sample_x = self.x - disp_normalized.squeeze(1)  # Shape: (N, H, W)
sample_y = self.y.expand(N, -1, -1)             # Shape: (N, H, W)
```

Why divide by (width - 1) / 2? The normalized coordinate system spans 2 units (-1 to 1) over the width of the image, and with `align_corners=True` (see below) the endpoints -1 and +1 sit on the first and last pixel centers. So adjacent pixels are 2/(width - 1) normalized units apart, and one pixel of disparity is a shift of 2/(width - 1).
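You can verify the conversion factor empirically by measuring the spacing of the grid the starter code builds:

```python
import torch

# With a linspace(-1, 1, width) grid, -1 and +1 sit on the first and last
# pixel centers, so adjacent pixels are 2 / (width - 1) apart in
# normalized coordinates.
width = 5
xs = torch.linspace(-1, 1, width)
spacing = xs[1] - xs[0]
# spacing == 0.5 == 2 / (5 - 1)
```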
PyTorch's `F.grid_sample` does the actual warping. It expects:

- Input image: `(N, C, H, W)`
- Grid: `(N, H, W, 2)`, where the last dimension is `(x, y)` coordinates
```python
# Stack coordinates into the grid format grid_sample expects
grid = torch.stack([sample_x, sample_y], dim=-1)  # Shape: (N, H, W, 2)

# Warp the right image to the left view
warped_R = F.grid_sample(imgR, grid, mode='bilinear',
                         padding_mode='border', align_corners=True)
```

Important settings:

- `align_corners=True`: matches our coordinate system, where -1 and +1 are at pixel centers
- `padding_mode='border'`: repeats edge pixels for out-of-bounds coordinates
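A quick sanity check of these settings: with `align_corners=True`, a grid built from `linspace(-1, 1, ·)` lands exactly on pixel centers, so sampling with zero disparity should reproduce the input image:

```python
import torch
import torch.nn.functional as F

# Identity warp: zero disparity means the grid is just the meshgrid itself.
N, C, H, W = 1, 1, 3, 4
img = torch.arange(float(N * C * H * W)).reshape(N, C, H, W)
y, x = torch.meshgrid(torch.linspace(-1, 1, H),
                      torch.linspace(-1, 1, W), indexing="ij")
grid = torch.stack([x, y], dim=-1).unsqueeze(0)  # (1, H, W, 2), (x, y) order
out = F.grid_sample(img, grid, mode='bilinear',
                    padding_mode='border', align_corners=True)
# out matches img (up to floating point)
```

If this identity check fails in your code, the usual culprits are swapped `(x, y)` order in the grid or a mismatched `align_corners` setting.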
Compare the warped right image to the actual left image with an L1 loss:

```python
loss = torch.abs(imgL - warped_R).mean()
```

Common mistakes:

- Forgetting `register_buffer` (or manually calling `.to(device)` on the grid tensors)
- The grid must have shape `(N, H, W, 2)`, not `(N, 2, H, W)`
- `disp` has shape `(N, 1, H, W)`, but you need `(N, H, W)` for the grid
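Putting the pieces together, one way the finished module might look is sketched below. This is our assembly of the snippets above, not the official solution, and it includes a quick zero-disparity sanity check:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class StereoLoss(nn.Module):
    """Photometric stereo loss: warp the right image with the predicted
    disparity and compare it to the left image with L1."""

    def __init__(self, height, width):
        super().__init__()
        self.height = height
        self.width = width
        y, x = torch.meshgrid(
            torch.linspace(-1, 1, height),
            torch.linspace(-1, 1, width),
            indexing="ij",
        )
        # Buffers move with .to(device) / .cuda()
        self.register_buffer('x', x)
        self.register_buffer('y', y)

    def forward(self, imgL, imgR, disp):
        N = disp.shape[0]
        # Pixel disparity -> normalized units (grid spacing is 2/(width-1))
        disp_normalized = disp / ((self.width - 1) / 2)
        sample_x = self.x - disp_normalized.squeeze(1)    # (N, H, W)
        sample_y = self.y.expand(N, -1, -1)               # (N, H, W)
        grid = torch.stack([sample_x, sample_y], dim=-1)  # (N, H, W, 2)
        warped_R = F.grid_sample(imgR, grid, mode='bilinear',
                                 padding_mode='border', align_corners=True)
        return torch.abs(imgL - warped_R).mean()

# Sanity check: identical images and zero disparity should give ~0 loss.
criterion = StereoLoss(4, 5)
img = torch.rand(2, 3, 4, 5)
loss = criterion(img, img, torch.zeros(2, 1, 4, 5))
```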