Hint: Loss Function

Your starting code looks like this:

class MaskedL1Loss(nn.Module):
    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, gt):
        # TODO: write this function
        return None

There’s one task for you here: writing the forward() function.

Background

In PyTorch, loss functions are network layers (nn.Module), just like the other components of a neural network. The job of a loss function is to quantitatively compare two inputs (pred and gt) and produce a single non-negative number. A result of zero means that pred and gt are identical; larger numbers mean a greater difference.

PyTorch Modules all have a function called forward. This is where the programmer specifies the computation that should be performed when data flows “forwards” through this network layer. (There is also a notion of backwards flow, to compute gradients, but that is handled by autodifferentiation. You don’t have to worry about that.)

forward() takes two tensor inputs, pred and gt.

In our use, we’re comparing depth images, which have only one channel. In the starting code, the batch size is 8 images at a time, and the images are 320 pixels wide by 240 pixels tall. So you should be seeing tensors with dimensions (8, 1, 240, 320).
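As a quick sanity check (this snippet is illustrative, not part of the starting code), you can build random tensors of that shape and confirm the dimensions and element count:

```python
import torch

# Tensors shaped like one training batch:
# (batch, channels, height, width) = (8, 1, 240, 320)
pred = torch.rand(8, 1, 240, 320)
gt = torch.rand(8, 1, 240, 320)

print(pred.shape)    # torch.Size([8, 1, 240, 320])
print(pred.numel())  # 8 * 1 * 240 * 320 = 614400
```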

L1 Averaging

L1 loss functions are popular for depth prediction. The code looks like this:

def forward(self, pred, gt):
    loss = torch.abs(pred - gt)       # per-pixel absolute error
    return loss.sum() / loss.numel()  # average over all N elements

Subtract the prediction from the ground truth. Take an absolute value. Sum up and divide by N.
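To see the arithmetic on a toy example (the 2x2 tensors below are made up for illustration, not real depth data):

```python
import torch
import torch.nn as nn

class L1Loss(nn.Module):  # unmasked version, for illustration
    def forward(self, pred, gt):
        loss = torch.abs(pred - gt)
        return loss.sum() / loss.numel()

pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
gt = torch.ones(2, 2)

# Absolute differences are 0, 1, 2, 3; their sum is 6 and N is 4.
print(L1Loss()(pred, gt))  # tensor(1.5000)
```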

Masked L1 Averaging

Our ground truth data has holes in it. Some pixels in the gt depth images have depth 0. This doesn’t really mean a depth of zero. It’s a signal that the sensor used to build the dataset wasn’t able to figure out the depth of that pixel. The data is missing.

If we naively use L1 Averaging for our loss function it will try to force our model to predict depths of zero at those points, which isn’t what we want. Instead, we need to tweak our loss function to only sum up over places where gt > 0.
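One possible shape for that tweak (a sketch of the idea, not necessarily the exact answer your assignment expects; the variable name mask is our own) is to index with a boolean mask and average only over the valid pixels:

```python
import torch
import torch.nn as nn

class MaskedL1Loss(nn.Module):
    def forward(self, pred, gt):
        mask = gt > 0                      # True only where depth is valid
        diff = torch.abs(pred - gt)[mask]  # keep errors at valid pixels
        return diff.sum() / diff.numel()   # average over valid pixels only
```

Note that diff.numel() counts only the unmasked pixels, so the average isn’t dragged down by the holes. (If a batch could be entirely masked, you’d also need to guard against dividing by zero.)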

One last hint

If you’re struggling with this, here are two pieces of coding advice: