A repo for my thoughts

A place for my projects.

Follow me on GitHub

Introduction to segmentation

Image segmentation is the task of dividing, a given image, in multiple “zones”. Each zone corresponds to a different class, for example: sky, tree, road and building can all be different class. The segmented image is often called mask or ground truth. In other terms we can thought a segmentation problem as a multiclass classification problem applied to each pixel in an image. image info

In this post we will explore how to build an image segmentation algorithm in PyTorch. The dataset that we will use is the Scene Understanding Datasets. In this dataset there are 715 different images with relative mask regions. Each pixel can belong to eight different classes: sky, tree, road, grass, water, building, mountain, or foreground object plus a special class for unknow classification. image info

Some data preparation

We need to correctly load the data this is done in the function below:

def get_images_from_id(idx:int, df=df):
    '''this function takes a single image, mask pair from a given dataset'''    
    row = df.iloc[idx]

    image_path = DATA_DIR+'images/'+row.file_name+'.jpg'
    mask_path = DATA_DIR+'labels/'+row.file_name+'.regions.txt'

    image = cv2.imread(image_path)

    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    mask = np.loadtxt(mask_path)

    return image, mask

Right now the masks are encoded with a single channel, and each class is indicated by a different integer. To construct an efficient segmentation algorithm we preferer the one hot encoding where each class has his own channel, and the belonging to that class is indicated by a 1 or a 0 in the specific channel. Therefore we need two function, one to transform mask in the one hot encoding and one to get back.

def one_hot_encode(mask):
    '''this function transform the mask in the one hot encoded version'''
    out = np.zeros(shape=(*mask.shape, 9))
    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            if mask[i,j]>=0:
                out[i, j, int(mask[i,j])] = 1
            else:
                out[i,j, 8] = 1

    return out
def one_hot_decode(mask):
    '''this function transform the one hot encoded version 
    on the mask in a plottable one'''
    out = np.zeros(shape=mask.shape[:-1])

    for i in range(mask.shape[0]):
        for j in range(mask.shape[1]):
            out[i][j] = mask[i, j, :].argmax()
            if out[i][j] == 8:
                out[i][j] = -1

    return out

After the test, train and validation split we have there dataset:

  • train_df size: 457
  • valid_df size: 115
  • test_df size: 143

The train dataset will be augmented with the Albumentation library, the test and validation dataset will just be resized.

def train_aug():
    return A.Compose([        
        A.HorizontalFlip(p=0.5),
        A.augmentations.geometric.rotate.Rotate(limit=[-30,30]),
        A.augmentations.crops.transforms.RandomCrop(
            min_height,
            min_witdh,
            always_apply=False,
            p=0.1
            ),
        A.Resize(HEIGHT, WITDH),        
    ])

def valid_aug():
        return A.Compose([
        A.Resize(HEIGHT, WITDH)
    ])

Some example of augmentations: image info

PyTorch dataset

We will create a PyTorch dataset that will take care of image preprocessing and augmentation. We will also define the dataloaders to load the data at training time.

class SegmentationDataset(Dataset):

    def __init__(self, df, augmentations):
        self.df = df
        self.augmentations = augmentations
        
    def __len__(self):
        return len(self.df)
    
    def __getitem__(self, idx):

        image, mask = get_images_from_id(idx, self.df) #h,w,c
        mask = one_hot_encode(mask)
        
        if self.augmentations:
            data = self.augmentations(image = image, mask = mask)
            image = data['image'] #(h,w,c)
            mask = data['mask'] #(h,w,c)

        # allining with pytorch convention
        image = np.transpose(image, (2,0,1)).astype(np.float32) # (c, h ,w)
        mask = np.transpose(mask, (2,0,1)).astype(np.float32) # (c, h ,w)

        image = torch.Tensor(image) / 255.0
        mask = torch.Tensor(mask)
        return image, mask
  
trainloader = DataLoader(trainset, batch_size= BATCH_SIZE, shuffle=True)
validloader = DataLoader(validset, batch_size= BATCH_SIZE)
testloader = DataLoader(testset, batch_size= BATCH_SIZE)

The model code

The model we will use is a U-net imported from the segmentation library of PyTorch.

import segmentation_models_pytorch as smp

ENCODER = 'timm-efficientnet-b0'
WEIGTHS = 'imagenet'
DEVICE = 'cuda'
class SegmentationModel(nn.Module):
    def __init__(self):
        super(SegmentationModel, self).__init__()

        self.backbone = smp.Unet(
            encoder_name = ENCODER,
            encoder_weights= WEIGTHS,
            in_channels = 3,
            classes = 9,
            activation = 'softmax',
         
        )

    def forward(self, images):
        logits = self.backbone(images)

        return logits

model = SegmentationModel()
model.to(DEVICE)

Now we need to define a train and validation function to be used in the training loop. Note how the validation function does not update the model. This is important to avoid overfitting.

def train_f( dataloader, model, optimizer, loss_f):
    model.train()

    total_loss = 0.0

    for images, masks in tqdm(dataloader):
        images = images.to(DEVICE)
        masks = masks.to(DEVICE)

        optimizer.zero_grad()

        logits = model(images)

        loss = loss_f(masks, logits)
    
        loss.backward()

        optimizer.step()

        total_loss += loss.item()

    return (total_loss/ len(dataloader))

def eval_f( dataloader, model, loss_f):
    model.eval()

    total_loss = 0.0

    with torch.no_grad():
        for images, masks in tqdm(dataloader):
            images = images.to(DEVICE)
            masks = masks.to(DEVICE)

            logits = model(images)

            loss = loss_f(masks, logits)

            total_loss += loss.item()

        return (total_loss/ len(dataloader))

Now let’s look at the training loop.

opt = torch.optim.Adam(model.parameters(), lr = LR)
loss_f = nn.CrossEntropyLoss()
current_loss = np.inf

for i in range(EPOCHS):
    t_loss = train_f(trainloader, model, opt, loss_f)
    v_loss = eval_f(validloader, model, loss_f)

    if v_loss < current_loss:
        torch.save(model.state_dict(), "model.pytorch")
        print('MODEL_SAVED')
        current_loss = v_loss

    print(f'EPOCH: {i + 1}, Train loss: {t_loss}, valid loss: {v_loss}')

we loop on the epochs and save the best model.

Some results

Now that the training is complete we can look at some results.

image info

The model can clearly be improved, but this was just a proof of concept. The code can be found here

back home