Your starting code looks like this:
```python
import torch.nn as nn

class MaskedL1Loss(nn.Module):
    def __init__(self):
        super(MaskedL1Loss, self).__init__()

    def forward(self, pred, gt):
        # TODO: write this function
        return None
```
There’s one task for you here: writing the forward() function.
In PyTorch, loss functions are network layers (nn.Module) like all of the other components of a neural network. The job of a loss function is to quantitatively compare two inputs (pred and gt) and produce a single non-negative number. A result of zero means that pred and gt are identical, and larger numbers mean a greater difference.
PyTorch Modules all have a function called forward. This is where the programmer specifies the computation that should be performed when data flows “forwards” through this network layer. (There is also a backwards flow, used to compute gradients, but PyTorch’s automatic differentiation (autograd) handles that for you. You don’t have to worry about it.)
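To make the forward/backward split concrete, here is a minimal sketch using a hypothetical toy layer called Scale (not part of the assignment code) that multiplies its input by one learnable number. You only write forward(); calling the layer runs it, and autograd fills in the gradients when you call backward() on the result.

```python
import torch
import torch.nn as nn

class Scale(nn.Module):
    """Hypothetical toy layer: multiplies its input by a learnable scalar w."""
    def __init__(self):
        super().__init__()
        self.w = nn.Parameter(torch.tensor(2.0))

    def forward(self, x):
        # The "forwards" computation: what happens as data flows through the layer.
        return self.w * x

layer = Scale()
x = torch.tensor([1.0, 2.0, 3.0])
y = layer(x)         # calling the module runs forward() for us; y holds [2., 4., 6.]
y.sum().backward()   # autograd computes gradients; we never wrote a backward() method
print(layer.w.grad)  # d(sum(w * x))/dw = sum(x) = 6
```

Note that you call the layer itself (layer(x)) rather than layer.forward(x); PyTorch routes the call to forward() for you.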
There are three inputs:
- self: Python’s reference to this layer. You don’t need to use this.
- pred: a torch.Tensor with dimensions (N, C, H, W)
- gt: a torch.Tensor with dimensions (N, C, H, W)

Here N is the batch size, C the number of channels, H the image height and W the image width.
In our use, we’re comparing depth images, which have only one channel. In the starting code, the batch size is 8 images at a time, and the images are 320 pixels wide by 240 pixels tall. So you should be seeing tensors with dimensions (8, 1, 240, 320).
L1 loss functions are popular for depth prediction. In words, the computation is: subtract the prediction from the ground truth, take the absolute value, then sum up and divide by the number of values you summed (in other words, take the mean).
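As a sketch of those three steps, here is a plain (unmasked) L1 loss written out by hand. The function name l1_loss is hypothetical, and the inputs are random dummy tensors with the (8, 1, 240, 320) shape described above rather than real depth data. The result is compared against PyTorch’s built-in torch.nn.functional.l1_loss, whose default reduction is the mean.

```python
import torch
import torch.nn.functional as F

def l1_loss(pred, gt):
    """Hypothetical helper: plain (unmasked) mean absolute error."""
    diff = gt - pred                          # 1. subtract the prediction from the ground truth
    abs_diff = torch.abs(diff)                # 2. take the absolute value
    return abs_diff.sum() / abs_diff.numel()  # 3. sum up and divide by the count

# Dummy tensors shaped like one batch of depth images: (N, C, H, W)
pred = torch.rand(8, 1, 240, 320)
gt = torch.rand(8, 1, 240, 320)

loss = l1_loss(pred, gt)
# Should agree with the built-in mean-reduced L1 loss up to float rounding.
print(loss.item(), F.l1_loss(pred, gt).item())
```

The result is a zero-dimensional tensor (a single number), which is what backward() expects from a loss.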
Our ground truth data has holes in it. Some pixels in the gt depth images have depth 0. This doesn’t really mean a depth of zero. It’s a signal that the sensor used to build the dataset wasn’t able to figure out the depth of that pixel. The data is missing.
If we naively use L1 averaging as our loss function, it will try to force our model to predict depths of zero at those points, which isn’t what we want. Instead, we need to tweak our loss function to only sum up over places where gt > 0, and divide by the number of such places.
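One possible way to write the masked version is sketched below; it is not necessarily the intended reference solution. It builds a boolean mask from gt > 0, keeps only the absolute errors at valid pixels, and averages over those. Returning zero for a batch with no valid pixels at all is an assumption made here to avoid dividing by zero (the mean of an empty tensor is NaN); the tiny tensors in the usage lines are made-up values chosen so the answer is easy to check by hand.

```python
import torch
import torch.nn as nn

class MaskedL1Loss(nn.Module):
    """Sketch: L1 loss averaged only over pixels with valid (non-zero) ground truth."""
    def __init__(self):
        super().__init__()

    def forward(self, pred, gt):
        mask = gt > 0                       # True where the sensor produced a depth
        diff = torch.abs(gt - pred)[mask]   # absolute errors at valid pixels only (1-D)
        if diff.numel() == 0:
            # Assumption: a batch with no valid pixels contributes zero loss.
            # Multiplying by 0.0 keeps the result connected to pred's graph.
            return pred.sum() * 0.0
        return diff.mean()                  # divide by the number of valid pixels

# Made-up 1x1x2x2 example: the two zeros in gt are "holes" and are ignored.
criterion = MaskedL1Loss()
gt = torch.tensor([[[[0.0, 2.0], [3.0, 0.0]]]])
pred = torch.tensor([[[[5.0, 1.0], [3.5, 9.0]]]])
loss = criterion(pred, gt)
print(loss.item())  # valid errors are |2 - 1| = 1.0 and |3 - 3.5| = 0.5, mean 0.75
```

Dividing by the number of valid pixels (rather than by all pixels) keeps the loss on the same scale whether an image has few holes or many.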
If you’re struggling with this, here are two pieces of coding advice:
- What does gt > 0 do?
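A quick way to explore that question is to try the comparison on a small made-up tensor. Comparing a tensor with a number produces a boolean tensor of the same shape, and indexing with that boolean tensor keeps only the entries where it is True.

```python
import torch

gt = torch.tensor([[0.0, 1.5],
                   [2.0, 0.0]])

mask = gt > 0
print(mask)        # boolean tensor, same shape as gt: True where gt is non-zero
print(gt[mask])    # boolean indexing keeps only the True entries, flattened to 1-D
print(mask.sum())  # True counts as 1, so this is the number of valid pixels
```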