Matheus Schmitz
LinkedIn
Github Portfolio
Using an image-to-image GAN to convert satellite imagery to street maps.
Based on the Pix2Pix architecture: https://arxiv.org/abs/1611.07004
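For reference, Pix2Pix trains the generator against a conditional adversarial loss plus an L1 reconstruction term, which is the same combination used in the training loop below (BCE-with-logits for the adversarial part, weight λ = L1_LAMBDA = 100):
G* = arg min_G max_D L_cGAN(G, D) + λ · L_L1(G)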
# File manipulation imports for Google Colab
from google.colab import drive
drive.mount('/content/drive')
import os
os.chdir("/content/drive/My Drive/Colab Notebooks/Satellite_to_Map_GAN/")
Mounted at /content/drive
import torch
from torch import nn
class DiscBlock(nn.Module):
def __init__(self, in_channels, out_channels, stride):
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=stride, bias=False, padding_mode='reflect'),
nn.BatchNorm2d(out_channels),
nn.PReLU()
)
def forward(self, x):
return self.conv(x)
class Discriminator(nn.Module):
def __init__(self, in_channels=3, features=[64, 128, 256, 512]):
super().__init__()
# First block doesn't use batch norm
self.initial = nn.Sequential(
nn.Conv2d(in_channels*2, features[0], kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
nn.PReLU()
)
# All remaining blocks use batch norm
layers = []
for idx in range(len(features)-1):
layers.append(
DiscBlock(in_channels=features[idx], out_channels=features[idx+1],
stride = 1 if features[idx+1] == features[-1] else 2),
)
self.backbone = nn.Sequential(*layers)
self.final = nn.Conv2d(in_channels=features[-1], out_channels=1, kernel_size=4,
stride=1, padding=1, padding_mode='reflect')
def forward(self, x, y):
# Conditional GAN: concatenate the input and target images along the channel dimension
x = torch.cat([x, y], dim=1)
x = self.initial(x)
x = self.backbone(x)
return self.final(x)
def test():
x = torch.randn((1, 3, 256, 256))
y = torch.randn((1, 3, 256, 256))
model = Discriminator()
preds = model(x, y)
print(preds.shape)
test()
torch.Size([1, 1, 26, 26])
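The discriminator is a PatchGAN: rather than a single real/fake score, it outputs a grid of scores, each judging one patch of the concatenated (input, target) pair. With the layers above (a padded stride-2 initial conv, three unpadded DiscBlocks, then the final conv) the spatial size shrinks 256 → 128 → 63 → 30 → 27 → 26, hence the 26×26 map here instead of the 30×30 produced by the paper's 70×70 PatchGAN.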
import torch
from torch import nn
class GenBLock(nn.Module):
def __init__(self, in_channels, out_channels, encoder=True, use_dropout=False):
'''
encoder = True → Encoder
encoder = False → Decoder
'''
super().__init__()
self.conv = nn.Sequential(
nn.Conv2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False, padding_mode='reflect')
if encoder
else nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=2, padding=1, bias=False),
nn.BatchNorm2d(out_channels),
nn.PReLU()
)
self.use_dropout = use_dropout
self.dropout = nn.Dropout(0.5)
def forward(self, x):
x = self.conv(x)
return self.dropout(x) if self.use_dropout else x
class Generator(nn.Module):
def __init__(self, in_channels=3, features=64):
super().__init__()
# Encoder
self.encoder1 = nn.Sequential(
nn.Conv2d(in_channels, out_channels=features, kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
nn.PReLU()
) # 128 x 128
self.encoder2 = GenBLock(features , features*2, encoder=True, use_dropout=False) # 64 x 64
self.encoder3 = GenBLock(features*2, features*4, encoder=True, use_dropout=False) # 32 x 32
self.encoder4 = GenBLock(features*4, features*8, encoder=True, use_dropout=False) # 16 x 16
self.encoder5 = GenBLock(features*8, features*8, encoder=True, use_dropout=False) # 8 x 8
self.encoder6 = GenBLock(features*8, features*8, encoder=True, use_dropout=False) # 4 x 4
self.encoder7 = GenBLock(features*8, features*8, encoder=True, use_dropout=False) # 2 x 2
# Bottleneck
self.bottleneck = nn.Sequential(
nn.Conv2d(features*8, features*8, kernel_size=4, stride=2, padding=1, padding_mode='reflect'),
nn.PReLU()
) # 1 x 1
# Decoder
'''
Each decoder block receives a skip connection from the matching encoder layer,
hence its input channel count is doubled
'''
self.decoder1 = GenBLock(features*8 , features*8, encoder=False, use_dropout=True) # 2 x 2
self.decoder2 = GenBLock(features*8*2, features*8, encoder=False, use_dropout=True) # 4 x 4
self.decoder3 = GenBLock(features*8*2, features*8, encoder=False, use_dropout=True) # 8 x 8
self.decoder4 = GenBLock(features*8*2, features*8, encoder=False, use_dropout=False) # 16 x 16
self.decoder5 = GenBLock(features*8*2, features*4, encoder=False, use_dropout=False) # 32 x 32
self.decoder6 = GenBLock(features*4*2, features*2, encoder=False, use_dropout=False) # 64 x 64
self.decoder7 = GenBLock(features*2*2, features , encoder=False, use_dropout=False) # 128 x 128
# Output
self.output = nn.Sequential(
nn.ConvTranspose2d(features*2, in_channels, kernel_size=4, stride=2, padding=1),
nn.Tanh()
) # 256 x 256
def forward(self, x):
e1 = self.encoder1(x)
e2 = self.encoder2(e1)
e3 = self.encoder3(e2)
e4 = self.encoder4(e3)
e5 = self.encoder5(e4)
e6 = self.encoder6(e5)
e7 = self.encoder7(e6)
bottleneck = self.bottleneck(e7)
d1 = self.decoder1(bottleneck)
d2 = self.decoder2(torch.cat([d1, e7], dim=1))
d3 = self.decoder3(torch.cat([d2, e6], dim=1))
d4 = self.decoder4(torch.cat([d3, e5], dim=1))
d5 = self.decoder5(torch.cat([d4, e4], dim=1))
d6 = self.decoder6(torch.cat([d5, e3], dim=1))
d7 = self.decoder7(torch.cat([d6, e2], dim=1))
return self.output(torch.cat([d7, e1], dim=1))
def test():
x = torch.randn((1, 3, 256, 256))
model = Generator(in_channels=3, features=64)
preds = model(x)
print(preds.shape)
test()
torch.Size([1, 3, 256, 256])
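The U-Net encoder halves the resolution at every stage down to the 1×1 bottleneck, and the decoder mirrors it back to 256×256 using the skip connections. As a quick sanity check (not from the original notebook), the generator defined above can be sized up by counting its parameters:
# Illustrative only: rough parameter count for the Generator defined above
gen = Generator(in_channels=3, features=64)
n_params = sum(p.numel() for p in gen.parameters())
print(f"Generator parameters: {n_params / 1e6:.1f}M")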
!pip install --quiet Pillow albumentations --upgrade --force-reinstall
ERROR: pip's dependency resolver does not currently take into account all the packages that are installed. This behaviour is the source of the following dependency conflicts.
thinc 8.0.17 requires typing-extensions<4.2.0,>=3.7.4.1; python_version < "3.8", but you have typing-extensions 4.3.0 which is incompatible.
spacy 3.3.1 requires typing-extensions<4.2.0,>=3.7.4; python_version < "3.8", but you have typing-extensions 4.3.0 which is incompatible.
datascience 0.10.6 requires folium==0.2.1, but you have folium 0.8.3 which is incompatible.
import torch
import albumentations as A
from albumentations.pytorch import ToTensorV2
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
TRAIN_DIR = "data/maps/train"
VAL_DIR = "data/maps/val"
LEARNING_RATE = 2e-4
BATCH_SIZE = 16
NUM_WORKERS = 2
IMAGE_SIZE = 256
CHANNELS_IMG = 3
L1_LAMBDA = 100
LAMBDA_GP = 10
NUM_EPOCHS = 5
LOAD_MODEL = True
SAVE_MODEL = True
CHECKPOINT_DISC = "disc.pth.tar"
CHECKPOINT_GEN = "gen.pth.tar"
both_transform = A.Compose(
[A.Resize(width=256, height=256), A.HorizontalFlip(p=0.5),], additional_targets={"image0": "image"},
)
transform_only_input = A.Compose(
[
A.ColorJitter(p=0.2),
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
ToTensorV2(),
]
)
transform_only_mask = A.Compose(
[
A.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5], max_pixel_value=255.0,),
ToTensorV2(),
]
)
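Both Normalize calls map pixel values from [0, 255] to [-1, 1] via (x/255 − 0.5)/0.5, matching the Tanh output range of the generator. A minimal check (not from the original notebook), using a random uint8 image:
# Illustrative check: the normalized output should land in [-1, 1],
# the same range the generator's Tanh produces.
import numpy as np
dummy = np.random.randint(0, 256, size=(256, 256, 3), dtype=np.uint8)
out = transform_only_mask(image=dummy)['image']
print(out.min().item(), out.max().item())  # approximately -1.0 and 1.0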
from torch.utils.data import Dataset
from PIL import Image
import numpy as np
import os
class MapDataset(Dataset):
def __init__(self, root_dir):
self.root_dir = root_dir
self.list_files = os.listdir(self.root_dir)
def __len__(self):
return len(self.list_files)
def __getitem__(self, index):
img_file = self.list_files[index]
img_path = os.path.join(self.root_dir, img_file)
image = np.array(Image.open(img_path))
input_image = image[:, :600, :]
target_image = image[:, 600:, :]
augmentations = both_transform(image=input_image, image0=target_image)
input_image = augmentations['image']
target_image = augmentations['image0']
input_image = transform_only_input(image=input_image)['image']
target_image = transform_only_mask(image=target_image)['image']
return input_image, target_image
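Each file in the maps dataset is a single 600×1200 image with the satellite photo on the left half and the street map on the right half, which is why __getitem__ splits at column 600. A quick sanity check (not from the original notebook; assumes the dataset is already present in TRAIN_DIR and uses the imports above):
# Illustrative check of the paired-image layout
sample_file = os.listdir(TRAIN_DIR)[0]
sample = np.array(Image.open(os.path.join(TRAIN_DIR, sample_file)))
print(sample.shape)           # expected: (600, 1200, 3)
print(sample[:, :600].shape)  # satellite half, fed to the generator
print(sample[:, 600:].shape)  # map half, used as the target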
import torch
from torchvision.utils import save_image
def save_some_examples(gen, val_loader, epoch, folder):
x, y = next(iter(val_loader))
x, y = x.to(DEVICE), y.to(DEVICE)
gen.eval()
with torch.no_grad():
y_fake = gen(x)
y_fake = y_fake * 0.5 + 0.5  # undo the [-1, 1] normalization
save_image(y_fake, folder + f"/y_gen_{epoch}.png")
save_image(x * 0.5 + 0.5, folder + f"/input_{epoch}.png")
if epoch == 1:
save_image(y * 0.5 + 0.5, folder + f"/label_{epoch}.png")
gen.train()
def save_checkpoint(model, optimizer, filename="my_checkpoint.pth.tar"):
print("=> Saving checkpoint")
checkpoint = {
"state_dict": model.state_dict(),
"optimizer": optimizer.state_dict(),
}
torch.save(checkpoint, filename)
def load_checkpoint(checkpoint_file, model, optimizer, lr):
print("=> Loading checkpoint")
checkpoint = torch.load(checkpoint_file, map_location=DEVICE)
model.load_state_dict(checkpoint["state_dict"])
optimizer.load_state_dict(checkpoint["optimizer"])
# Reset the learning rate after loading, otherwise the optimizer keeps the
# learning rate stored in the old checkpoint, which can silently derail training
for param_group in optimizer.param_groups:
param_group["lr"] = lr
!pip install --quiet adabelief-pytorch
from adabelief_pytorch import AdaBelief
import torch
#from utils import save_checkpoint, load_checkpoint, save_some_examples
import torch.nn as nn
import torch.optim as optim
# import config
# from dataset import MapDataset
# from generator_model import Generator
# from discriminator_model import Discriminator
from torch.utils.data import DataLoader
from tqdm import tqdm
from torchvision.utils import save_image
def train_model(disc, gen, loader, opt_disc, opt_gen, l1, bce, g_scaler, d_scaler):
loop = tqdm(loader, leave=True)
for idx, (x, y) in enumerate(loop):
x, y = x.to(DEVICE), y.to(DEVICE)
# Train Discriminator with automatic mixed precision
with torch.cuda.amp.autocast():
y_fake = gen(x)
D_real = disc(x, y)
D_fake = disc(x, y_fake.detach())  # detach so this step doesn't backprop into the generator
D_real_loss = bce(D_real, torch.ones_like(D_real))
D_fake_loss = bce(D_fake, torch.zeros_like(D_fake))
D_loss = (D_real_loss + D_fake_loss) / 2
opt_disc.zero_grad()
d_scaler.scale(D_loss).backward()
d_scaler.step(opt_disc)
d_scaler.update()
# Train Generator with automatic mixed precision
with torch.cuda.amp.autocast():
D_fake = disc(x, y_fake)  # no detach here: gradients must flow back into the generator
G_fake_loss = bce(D_fake, torch.ones_like(D_fake))
L1 = l1(y_fake, y) * L1_LAMBDA
G_loss = G_fake_loss + L1
opt_gen.zero_grad()
g_scaler.scale(G_loss).backward()
g_scaler.step(opt_gen)
g_scaler.update()
def main():
disc = Discriminator(in_channels=3).to(DEVICE)
gen = Generator(in_channels=3).to(DEVICE)
opt_disc = AdaBelief(disc.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999), print_change_log=False)
opt_gen = AdaBelief(gen.parameters(), lr=LEARNING_RATE, betas=(0.5, 0.999), print_change_log=False)
BCE = nn.BCEWithLogitsLoss()
L1_LOSS = nn.L1Loss()
if LOAD_MODEL:
load_checkpoint(CHECKPOINT_GEN, gen, opt_gen, LEARNING_RATE)
load_checkpoint(CHECKPOINT_DISC, disc, opt_disc, LEARNING_RATE)
train_dataset = MapDataset(TRAIN_DIR)
train_loader = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
val_dataset = MapDataset(VAL_DIR)
val_loader = DataLoader(val_dataset, batch_size=1, shuffle=True)
# Train with float16 instead of float32
g_scaler = torch.cuda.amp.GradScaler()
d_scaler = torch.cuda.amp.GradScaler()
for epoch in range(NUM_EPOCHS):
train_model(disc, gen, train_loader, opt_disc, opt_gen, L1_LOSS, BCE, g_scaler, d_scaler)
if SAVE_MODEL and epoch % 10 == 0:
save_checkpoint(gen, opt_gen, filename=CHECKPOINT_GEN)
save_checkpoint(disc, opt_disc, filename=CHECKPOINT_DISC)
save_some_examples(gen, val_loader, epoch, folder='evaluation')
main()
Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief
Weight decoupling enabled in AdaBelief
Rectification enabled in AdaBelief
=> Loading checkpoint
=> Loading checkpoint
import matplotlib.image as mpimg
import matplotlib.pyplot as plt
%matplotlib inline
# Plot input satellite images alongside the maps generated by the model
num_images = 5
fig, axs = plt.subplots(ncols=2, nrows=num_images, figsize=(10,num_images*3))
first_images = 0
# Loop through axes and plot the saved evaluation images
for axs_row in range(axs.shape[0]):
# set image index
img_index = first_images
first_images += 1
# Plot original image
axs[axs_row][0].imshow(plt.imread(f'evaluation/input_{img_index}.png'))
axs[axs_row][0].set_xticks([])
axs[axs_row][0].set_yticks([])
axs[axs_row][0].set_title('Original')
# Plot generated map (model prediction)
axs[axs_row][1].imshow(plt.imread(f'evaluation/y_gen_{img_index}.png'))
axs[axs_row][1].set_xticks([])
axs[axs_row][1].set_yticks([])
axs[axs_row][1].set_title('Prediction')
plt.tight_layout()
Matheus Schmitz
LinkedIn
Github Portfolio