Matheus Schmitz
LinkedIn
Github Portfolio
Using an image-to-image GAN to convert satellite imagery to street maps.
Based on the Pix2Pix architecture: https://arxiv.org/abs/1611.07004
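For reference, Pix2Pix trains the generator against a conditional PatchGAN discriminator and adds an L1 reconstruction term, i.e. $G^* = \arg\min_G \max_D \, \mathcal{L}_{cGAN}(G, D) + \lambda\,\mathcal{L}_{L1}(G)$; the training loop below uses the paper's weighting $\lambda = 100$ (L1_LAMBDA).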
# File manipulation imports for Google Colab
from google.colab import drive
drive.mount('/content/drive')
import os
os.chdir("/content/drive/My Drive/Colab Notebooks/Satellite_to_Map_GAN/")
Mounted at /content/drive
import torch
from torch import nn
class DiscBlock(nn.Module):
    def __init__(self, in_channels, out_channels, stride):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, bias=False, padding_mode='reflect'),
            nn.BatchNorm2d(out_channels),
            nn.PReLU()
        )
    def forward(self, x):
        return self.conv(x)
class Discriminator(nn.Module):
    def __init__(self, in_channels=3, features=[64, 128, 256, 512]): 
        super().__init__()
        
        # First block doesn't use batch norm
        self.initial = nn.Sequential(
            nn.Conv2d(in_channels*2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.PReLU()
        )
        # All remaining blocks use batch norm
        layers = []
        for idx in range(len(features)-1):
            layers.append(
                DiscBlock(in_channels=features[idx], out_channels=features[idx+1], 
                         stride = 1 if features[idx+1] == features[-1] else 2),
                )
        self.backbone = nn.Sequential(*layers)
        self.final = nn.Conv2d(in_channels=features[-1], out_channels=1, kernel_size=4,
                               stride=1, padding=1, padding_mode='reflect')
    def forward(self, x, y):
        x = torch.cat([x,y], dim=1)
        x = self.initial(x)
        x = self.backbone(x)
        return self.final(x)
def test():
    x = torch.randn((1, 3, 256, 256))
    y = torch.randn((1, 3, 256, 256))
    model = Discriminator()
    preds = model(x, y)
    print(preds.shape)
test()
torch.Size([1, 1, 26, 26])
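Each spatial position in the 1 x 26 x 26 output scores one overlapping patch of the concatenated (satellite, map) pair as real or fake, which is the PatchGAN idea from the paper. The spatial size can be reproduced with the standard convolution output formula; note that DiscBlock leaves Conv2d's padding at its default of 0, which is why the map is 26 x 26 rather than the 30 x 30 produced when each block pads by 1 (as in the reference Pix2Pix PatchGAN). A small illustrative check (this helper is not part of the original code):
def conv_out(size, kernel=4, stride=2, padding=0):
    # Standard Conv2d output-size formula: floor((size + 2*padding - kernel) / stride) + 1
    return (size + 2 * padding - kernel) // stride + 1
size = 256
size = conv_out(size, stride=2, padding=1)  # initial block      -> 128
size = conv_out(size, stride=2, padding=0)  # DiscBlock 64->128  -> 63
size = conv_out(size, stride=2, padding=0)  # DiscBlock 128->256 -> 30
size = conv_out(size, stride=1, padding=0)  # DiscBlock 256->512 -> 27
size = conv_out(size, stride=1, padding=1)  # final conv         -> 26
print(size)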
import torch
from torch import nn
class GenBlock(nn.Module):
    def __init__(self, in_channels, out_channels, encoder=True, use_dropout=False):
        '''
        encoder = True → Encoder 
        encoder = False → Decoder
        '''
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, padding_mode='reflect') 
            if encoder
            else nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
            nn.BatchNorm2d(out_channels),
            nn.PReLU()
        )
        self.use_dropout = use_dropout
        self.dropout = nn.Dropout(0.5)
    def forward(self, x):
        x = self.conv(x)
        return self.dropout(x) if self.use_dropout else x
class Generator(nn.Module):
    def __init__(self, in_channels=3, features=64):
        super().__init__()
        # Encoder 
        self.encoder1 = nn.Sequential(
            nn.Conv2d(in_channels, out_channels=features, kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.PReLU()
        ) # 128 x 128
        self.encoder2 = GenBlock(features  , features*2, encoder=True, use_dropout=False) # 64 x 64
        self.encoder3 = GenBlock(features*2, features*4, encoder=True, use_dropout=False) # 32 x 32
        self.encoder4 = GenBlock(features*4, features*8, encoder=True, use_dropout=False) # 16 x 16
        self.encoder5 = GenBlock(features*8, features*8, encoder=True, use_dropout=False) # 8 x 8
        self.encoder6 = GenBlock(features*8, features*8, encoder=True, use_dropout=False) # 4 x 4
        self.encoder7 = GenBlock(features*8, features*8, encoder=True, use_dropout=False) # 2 x 2
        # Bottleneck
        self.bottleneck = nn.Sequential(
            nn.Conv2d(features*8, features*8, kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
            nn.PReLU()
        ) # 1 x 1
        # Decoder
        '''
        Each decoder block (except decoder1) receives the previous decoder output
        concatenated with the matching encoder output (skip connection),
        hence the doubled (*2) input channel counts below
        '''
        self.decoder1 = GenBlock(features*8  , features*8, encoder=False, use_dropout=True) # 2 x 2
        self.decoder2 = GenBlock(features*8*2, features*8, encoder=False, use_dropout=True) # 4 x 4
        self.decoder3 = GenBlock(features*8*2, features*8, encoder=False, use_dropout=True) # 8 x 8
        self.decoder4 = GenBlock(features*8*2, features*8, encoder=False, use_dropout=False) # 16 x 16
        self.decoder5 = GenBlock(features*8*2, features*4, encoder=False, use_dropout=False) # 32 x 32
        self.decoder6 = GenBlock(features*4*2, features*2, encoder=False, use_dropout=False) # 64 x 64
        self.decoder7 = GenBlock(features*2*2, features  , encoder=False, use_dropout=False) # 128 x 128
        # Output 
        self.output = nn.Sequential(
            nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
            nn.Tanh()
        ) # 256 x 256
    def forward(self, x):
        e1 = self.encoder1(x)
        e2 = self.encoder2(e1)
        e3 = self.encoder3(e2)
        e4 = self.encoder4(e3)
        e5 = self.encoder5(e4)
        e6 = self.encoder6(e5)
        e7 = self.encoder7(e6)
        bottleneck = self.bottleneck(e7)
        d1 = self.decoder1(bottleneck)
        d2 = self.decoder2(torch.cat([d1, e7], dim=1))
        d3 = self.decoder3(torch.cat([d2, e6], dim=1))
        d4 = self.decoder4(torch.cat([d3, e5], dim=1))
        d5 = self.decoder5(torch.cat([d4, e4], dim=1))
        d6 = self.decoder6(torch.cat([d5, e3], dim=1))
        d7 = self.decoder7(torch.cat([d6, e2], dim=1))
        return self.output(torch.cat([d7, e1], dim=1))
def test():
    x = torch.randn((1, 3, 256, 256))
    model = Generator(in_channels=3, features=64)
    preds = model(x)
    print(preds.shape)
test()
torch.Size([1, 3, 256, 256])
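The per-layer shape comments in the encoder can be verified by tracing a dummy tensor through the blocks; this small loop is illustrative only and not part of the original code:
# Each stage halves the spatial resolution: 128, 64, 32, 16, 8, 4, 2, and finally 1 x 1 at the bottleneck
gen = Generator(in_channels=3, features=64)
feat = torch.randn(1, 3, 256, 256)
with torch.no_grad():
    for name in ['encoder1', 'encoder2', 'encoder3', 'encoder4',
                 'encoder5', 'encoder6', 'encoder7', 'bottleneck']:
        feat = getattr(gen, name)(feat)
        print(name, tuple(feat.shape))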
!pip install --quiet Pillow albumentations --upgrade --force-reinstall
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.0.17 requires typing-extensions<4.2.0,>=3.7.4.1; python_version < "3.8", but you have typing-extensions 4.3.0 which is incompatible.
spacy 3.3.1 requires typing-extensions<4.2.0,>=3.7.4; python_version < "3.8", but you have typing-extensions 4.3.0 which is incompatible.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "data/maps/train"
VAL_DIR = "data/maps/val"
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
LAMBDA_GP = 10  # gradient penalty weight; not used in this notebook
NUM_EPOCHS = 5
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"
both_transform = A.Compose(
    [A.Resize(width=256, height=256), A.HorizontalFlip(p=0.5),], additional_targets={"image0": "image"},
)
transform_only_input = A.Compose(
    [
        A.ColorJitter(p=0.2),
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)
transform_only_mask = A.Compose(
    [
        A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
        ToTensorV2(),
    ]
)
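The additional_targets={"image0": "image"} argument is what keeps a pair aligned: a single call to both_transform draws one set of random parameters (e.g. whether to flip) and applies it to both the satellite photo and its map, while ColorJitter is restricted to the input side. A quick illustrative check with dummy arrays (not part of the original pipeline):
import numpy as np
dummy_sat = np.random.randint(0, 256, (600, 600, 3), dtype=np.uint8)
dummy_map = np.random.randint(0, 256, (600, 600, 3), dtype=np.uint8)
paired = both_transform(image=dummy_sat, image0=dummy_map)  # same resize/flip applied to both
print(paired['image'].shape, paired['image0'].shape)  # (256, 256, 3) (256, 256, 3)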
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
class MapDataset(Dataset):
    def __init__(self, root_dir):
        self.root_dir = root_dir
        self.list_files = os.listdir(self.root_dir)
    def __len__(self):
        return len(self.list_files)
    def __getitem__(self, index):
        img_file = self.list_files[index]
        img_path = os.path.join(self.root_dir, img_file)
        image = np.array(Image.open(img_path))
        input_image = image[:, :600, :]
        target_image = image[:, 600:, :]
        augmentations = both_transform(image=input_image, image0=target_image)
        input_image = augmentations['image']
        target_image = augmentations['image0']
        input_image = transform_only_input(image=input_image)['image']
        target_image = transform_only_mask(image=target_image)['image']
        return input_image, target_image
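Each file in the maps dataset holds a satellite photo (left half) and its street map (right half) side by side, which is why __getitem__ splits the array at column 600. A quick sanity check, assuming the dataset has already been downloaded to TRAIN_DIR:
sanity_ds = MapDataset(TRAIN_DIR)
sat_img, street_map = sanity_ds[0]
print(sat_img.shape, street_map.shape)  # expected: torch.Size([3, 256, 256]) for both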
import torch
from torchvision.utils import save_image
def save_some_examples(gen, val_loader, epoch, folder):
    x, y = next(iter(val_loader))
    x, y = x.to(DEVICE), y.to(DEVICE)
    gen.eval()
    with torch.no_grad():
        y_fake = gen(x)
        y_fake = y_fake * 0.5 + 0.5  # undo the [-1, 1] normalization
        save_image(y_fake, folder + f"/y_gen_{epoch}.png")
        save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
        if epoch == 1:
            save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
    gen.train()
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
    print("=> Saving checkpoint")
    checkpoint = {
        "state_dict": model.state_dict(),
        "optimizer": optimizer.state_dict(),
    }
    torch.save(checkpoint, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
    print("=> Loading checkpoint")
    checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
    model.load_state_dict(checkpoint["state_dict"])
    optimizer.load_state_dict(checkpoint["optimizer"])
    # Reset the learning rate, otherwise the optimizer keeps the one stored in the
    # old checkpoint, which can lead to many hours of debugging
    for param_group in optimizer.param_groups:
        param_group["lr"] = lr
!pip install --quiet adabelief-pytorch
from adabelief_pytorch import AdaBelief
import torch
#from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
# import config
# from dataset import MapDataset
# from generator_model import Generator
# from discriminator_model import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image
def train_model(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scaler, d_scaler):
    loop = tqdm(loader, leave=True)
    for idx, (x, y) in enumerate(loop):
        x, y = x.to(DEVICE), y.to(DEVICE)
        # Train Discriminator with automatic mixed precision
        with torch.cuda.amp.autocast():
            y_fake = gen(x)
            D_real = disc(x, y)
            D_fake = disc(x, y_fake.detach())
            D_real_loss = bce(D_real, torch.ones_like(D_real))
            D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
            D_loss = (D_real_loss + D_fake_loss) / 2
        opt_disc.zero_grad()
        d_scaler.scale(D_loss).backward()
        d_scaler.step(opt_disc)
        d_scaler.update()
        # Train Generator with automatic mixed precision
        with torch.cuda.amp.autocast():
            D_fake = disc(x, y_fake)
            G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
            L1 = l1(y_fake, y) * L1_LAMBDA
            G_loss = G_fake_loss + L1
        opt_gen.zero_grad()
        g_scaler.scale(G_loss).backward()
        g_scaler.step(opt_gen)
        g_scaler.update()
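# Optional smoke test for the training step, mirroring the test() helpers above.
# Illustrative only and not part of the original training script: it runs one
# synthetic batch through train_model with freshly initialized networks.
def test_train_step():
    dummy_loader = [(torch.randn(2, 3, 256, 256), torch.randn(2, 3, 256, 256))]
    disc, gen = Discriminator(in_channels=3).to(DEVICE), Generator(in_channels=3).to(DEVICE)
    opt_disc = optim.Adam(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    opt_gen = optim.Adam(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999))
    train_model(disc, gen, dummy_loader, opt_disc, opt_gen, nn.L1Loss(), nn.BCEWithLogitsLoss(),
                torch.cuda.amp.GradScaler(), torch.cuda.amp.GradScaler())
# test_train_step()  # left commented out so it only runs on demand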
def main():
    disc = Discriminator(in_channels=3).to(DEVICE)
    gen = Generator(in_channels=3).to(DEVICE)
    opt_disc = AdaBelief(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999), print_change_log=False)
    opt_gen = AdaBelief(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999), print_change_log=False)
    BCE = nn.BCEWithLogitsLoss()
    L1_LOSS = nn.L1Loss()
    if LOAD_MODEL:
        load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
        load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)
    train_dataset = MapDataset(TRAIN_DIR)
    train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
    val_dataset = MapDataset(VAL_DIR)
    val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
    # Train with float16 instead of float32
    g_scaler = torch.cuda.amp.GradScaler()
    d_scaler = torch.cuda.amp.GradScaler()
    for epoch in range(NUM_EPOCHS):
        train_model(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler)
        if SAVE_MODEL and epoch % 10 == 0:
            save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
            save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)
        save_some_examples(gen, val_loader, epoch, folder='evaluation')
main()
Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief
Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief
=> Loading checkpoint
=> Loading checkpoint
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
%matplotlib inline
# Plot input satellite images and the street maps generated by the model
num_images = 5
fig, axs = plt.subplots(ncols=2, nrows=num_images, figsize=(10,num_images*3))
# Loop through axes and plot the saved validation examples in order
for axs_row in range(axs.shape[0]):
    
    # Image index matches the epoch at which the example was saved
    img_index = axs_row
    
    # Plot original image
    axs[axs_row][0].imshow(plt.imread(f'evaluation/input_{img_index}.png'))
    axs[axs_row][0].set_xticks([])  
    axs[axs_row][0].set_yticks([])
    axs[axs_row][0].set_title('Original')
    
    # Plot the street map generated by the model
    axs[axs_row][1].imshow(plt.imread(f'evaluation/y_gen_{img_index}.png'))
    axs[axs_row][1].set_xticks([])  
    axs[axs_row][1].set_yticks([])
    axs[axs_row][1].set_title('Prediction')
plt.tight_layout()
Matheus Schmitz
LinkedIn
Github Portfolio