Transferring Style!

Original Notebook : https://www.kaggle.com/code/ohseokkim/transfering-style

Introduction

Picture Credit: https://miro.medium.com

Neural Style Transfer

Neural Style Transfer (NST) refers to a class of software algorithms that manipulate digital images or videos so that they adopt the appearance or visual style of another image. NST algorithms are characterized by their use of deep neural networks for image transformation. A common use is creating artificial artwork from photographs, for example by transferring the appearance of famous paintings to user-supplied photographs. Several notable mobile apps, including DeepArt and Prisma, use NST techniques for this purpose. Artists and designers around the globe have used the method to develop new artwork based on existing styles.

  Given a content image and a style image, style transfer produces an image whose outlines and shapes follow the content image while its colors and textures follow the style image.

  By separating content and style, you can mix content and style of different images.

  A pre-trained VGG19 network is used as the model to extract content and style features. The content and style losses are then used to iteratively update the target image until the desired result is achieved.

from PIL import Image
from io import BytesIO
import matplotlib.pyplot as plt
import numpy as np

import torch
import torch.optim as optim
import requests
from torchvision import transforms, models

1. Load in model

  VGG19 is divided into two parts.

  • vgg19.features: All convolutional layers and pooling layers
  • vgg19.classifier: The last three linear layers, which form the classifier

  We only need the features part. We also “freeze” its weights so that they are not updated during optimization.

# get the "features" portion of VGG19 (we will not need the "classifier" portion)
vgg = models.vgg19(weights="IMAGENET1K_V1").features

# freeze all VGG parameters since we're only optimizing the target image
for param in vgg.parameters():
  param.requires_grad_(False)
# move the model to GPU, if available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

vgg.to(device)
#> Sequential(
#>   (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (1): ReLU(inplace=True)
#>   (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (3): ReLU(inplace=True)
#>   (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#>   (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (6): ReLU(inplace=True)
#>   (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (8): ReLU(inplace=True)
#>   (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#>   (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (11): ReLU(inplace=True)
#>   (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (13): ReLU(inplace=True)
#>   (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (15): ReLU(inplace=True)
#>   (16): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (17): ReLU(inplace=True)
#>   (18): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#>   (19): Conv2d(256, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (20): ReLU(inplace=True)
#>   (21): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (22): ReLU(inplace=True)
#>   (23): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (24): ReLU(inplace=True)
#>   (25): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (26): ReLU(inplace=True)
#>   (27): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#>   (28): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (29): ReLU(inplace=True)
#>   (30): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (31): ReLU(inplace=True)
#>   (32): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (33): ReLU(inplace=True)
#>   (34): Conv2d(512, 512, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
#>   (35): ReLU(inplace=True)
#>   (36): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
#> )
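
The effect of freezing can be checked on a small stand-in network. This is only a sketch: the two-layer `net` below is a hypothetical substitute for the VGG19 features, used so that nothing has to be downloaded. After a backward pass, the frozen weights should carry no gradient while the input image does.

```python
import torch
from torch import nn

# Hypothetical stand-in for vgg.features: a tiny conv stack (not VGG19 itself).
net = nn.Sequential(nn.Conv2d(3, 4, kernel_size=3, padding=1), nn.ReLU())

# Freeze every parameter, as is done for the VGG19 features.
for param in net.parameters():
    param.requires_grad_(False)

# The image is the only tensor that asks for gradients.
image = torch.rand(1, 3, 8, 8, requires_grad=True)
net(image).sum().backward()

# Frozen parameters never receive a .grad; the image does.
frozen = all(p.grad is None for p in net.parameters())
print("weights untouched:", frozen)
print("image gradient shape:", tuple(image.grad.shape))
```

This is the property style transfer relies on: the network acts as a fixed feature extractor, and only the pixels of the image are adjusted.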

2. Load in Content and Style Images

  Load the content image and style image to be used for style transfer. The load_image function transforms the image and loads it in the form of normalized Tensors.

def load_image(img_path, max_size = 128, shape = None):
  '''
  Load in and transform an image. Unless `shape` is given, the shorter side is resized
  to at most max_size pixels (Resize with an int matches the shorter edge).
  '''

  if "http" in img_path:
    response = requests.get(img_path)
    image = Image.open(BytesIO(response.content)).convert('RGB')
  else:
    image = Image.open(img_path).convert('RGB')

  # large images will slow down processing
  if max(image.size) > max_size:
    size = max_size
  else:
    size = max(image.size)

  if shape is not None:
    size = shape

  in_transform = transforms.Compose([
    transforms.Resize(size),
    transforms.ToTensor(),
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225))
  ])

  # discard the transparent, alpha channel (that's the :3) and add the batch dimension
  image = in_transform(image)[:3, :, :].unsqueeze(0)

  return image
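
The size logic in load_image can be traced without opening any image. The helper below, `resized_dims`, is a hypothetical name introduced only for illustration. It assumes torchvision's documented behavior that `Resize` with an int matches the shorter edge and keeps the aspect ratio; the exact rounding of the longer edge is an assumption, so the examples use dimensions that divide evenly.

```python
def resized_dims(width, height, max_size=128, shape=None):
    """Return the (width, height) that load_image's Resize step would produce."""
    if shape is not None:
        # An explicit (height, width) pair is used as-is.
        target_h, target_w = shape
        return target_w, target_h

    # Same choice as load_image: cap the size at max_size.
    size = min(max(width, height), max_size)

    # With an int size, the shorter edge becomes `size` and the longer
    # edge is scaled by the same factor (rounding here is an assumption).
    short, long = sorted((width, height))
    new_short, new_long = size, int(size * long / short)
    return (new_short, new_long) if width <= height else (new_long, new_short)


print(resized_dims(512, 256))   # large landscape image
print(resized_dims(100, 50))    # image already smaller than max_size
```

Under these assumptions a 512x256 image becomes 256x128, so the longer side can exceed max_size. A 100x50 image would be enlarged to 200x100, because `size` falls back to the longer side while `Resize` applies it to the shorter one.
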
# load in content and style image
content = load_image("./data/856047.jpg").to(device)
# Resize style to match content, makes code easier
style = load_image("./data/starry_night.jpg", shape = content.shape[-2:]).to(device)
# helper function for un-normalizing an image
# and converting it from a Tensor image to a NumPy image for display
def im_convert(tensor):
  """
  Convert a normalized image tensor to a NumPy image that can be displayed.
  """

  image = tensor.to("cpu").clone().detach()
  image = image.numpy().squeeze()
  image = image.transpose(1, 2, 0)
  image = image * np.array((0.229, 0.224, 0.225)) + np.array((0.485, 0.456, 0.406))
  image = image.clip(0, 1)

  return image
#display the images
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (20, 10))
# content and style ims side-by-side
ax1.imshow(im_convert(content))
ax2.imshow(im_convert(style))

plt.show()
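
The normalization in load_image and the un-normalization in im_convert are inverses of each other. The sketch below checks this with NumPy on a hypothetical 2x2 random image, reusing the same ImageNet mean and standard deviation.

```python
import numpy as np

mean = np.array((0.485, 0.456, 0.406))
std = np.array((0.229, 0.224, 0.225))

# Hypothetical 2x2 RGB image in height x width x channel layout, values in [0, 1).
rng = np.random.default_rng(0)
pixels = rng.random((2, 2, 3))

# What Normalize does per channel (load_image applies it to a channel-first tensor).
normalized = (pixels - mean) / std

# What im_convert does to undo it before display.
restored = (normalized * std + mean).clip(0, 1)

print(np.allclose(restored, pixels))
```

The final clip matters only during optimization, when the target image can drift outside the valid range.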

def get_features(image, model, layers = None):
  """
  Run an image forward through a model and get the features for a set of layers. Default layers are for
  VGGNet matching Gatys et al (2016)
  """

  if layers is None:
    layers = {'0' : 'conv1_1',
              '5' : 'conv2_1',
              '10': 'conv3_1',
              '19': 'conv4_1',
              '21': 'conv4_2', ## content representation
              '28': 'conv5_1'}

  features = {}
  x = image
  # model._modules is a dictionary holding each module in the model
  for name, layer in model._modules.items():
    x = layer(x)
    if name in layers:
      features[layers[name]] = x

  return features
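
The same name-based lookup can be tried on a small model. In the sketch below, `toy` and the labels in `layers` are hypothetical. It shows that the children of an `nn.Sequential` are named by their string index, which is why the keys above are strings such as '0' and '21'.

```python
import torch
from torch import nn

# Hypothetical three-layer stand-in for vgg.features.
toy = nn.Sequential(
    nn.Conv2d(3, 4, kernel_size=3, padding=1),   # name '0'
    nn.ReLU(),                                   # name '1'
    nn.Conv2d(4, 8, kernel_size=3, padding=1),   # name '2'
)

# Map module names (strings) to readable labels, as get_features does.
layers = {'0': 'first_conv', '2': 'second_conv'}

features = {}
x = torch.rand(1, 3, 16, 16)
for name, layer in toy.named_children():
    x = layer(x)                      # run the image one layer further
    if name in layers:
        features[layers[name]] = x    # keep the activation at the chosen layers

shapes = {label: tuple(t.shape) for label, t in features.items()}
print(shapes)
```

`named_children()` is used here instead of `_modules`; for a plain `nn.Sequential` both yield the same name and layer pairs.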

3. Gram Matrix

Picture Credit: https://miro.medium.com

The Gram matrix captures the correlations between the channels of a feature map. The style loss is defined as the difference between the Gram matrix of the style image and the Gram matrix of the newly generated image, and it is minimized during optimization. To preserve the content, a second loss compares, element by element, the feature maps that the pre-trained CNN produces for the content image and for the generated image. The new image is created by minimizing the sum of the style loss and the content loss.

https://en.wikipedia.org/wiki/Gram_matrix

def gram_matrix(tensor):
  """
  Calculate the Gram Matrix of a given tensor
  """

  # get the batch_size, depth, height, and width of the Tensor
  _, d, h, w = tensor.size()

  # reshape so we're multiplying the features for each channel
  tensor = tensor.view(d, h * w)

  # calculate the gram matrix
  gram = torch.mm(tensor, tensor.t())

  return gram
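
A small worked example makes the computation concrete. The sketch below uses NumPy rather than PyTorch, and `gram_matrix_np` and the 2-channel feature map are hypothetical names and values chosen so the result can be checked by hand.

```python
import numpy as np

def gram_matrix_np(feature_map):
    """NumPy counterpart of gram_matrix for a (channels, height, width) array."""
    d, h, w = feature_map.shape
    flat = feature_map.reshape(d, h * w)   # one row per channel
    return flat @ flat.T                   # (d, d) channel-by-channel dot products

# Hypothetical feature map: 2 channels on a 2x2 grid.
fmap = np.array([[[1.0, 0.0],
                  [0.0, 1.0]],
                 [[1.0, 1.0],
                  [0.0, 0.0]]])

gram = gram_matrix_np(fmap)
print(gram)
# [[2. 1.]
#  [1. 2.]]
```

The diagonal entries are each channel's dot product with itself, and the off-diagonal entry counts the one position where both channels are active. The result is always d x d, regardless of the spatial size of the feature map.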

  With get_features and gram_matrix in place, we put everything together: extract the features of both images and compute the Gram matrix of the style image at each layer of the style representation.

# get content and style features only once before training
content_features = get_features(content, vgg)
style_features = get_features(style, vgg)

# calculate the gram matrices for each layer of our style representation
style_grams = {layer: gram_matrix(style_features[layer]) for layer in style_features}

# create a third "target" image and prep it for change
# it is a good idea to start off with the target as a copy of our *content* image
# then iteratively change its style
target = content.clone().requires_grad_(True).to(device)

4. Define Losses and Weights

Individual Layer Style Weights

  You can weight the style representation at each relevant layer. Layer weights between 0 and 1 are recommended. Giving more weight to the earlier layers conv1_1 and conv2_1 produces larger style artifacts in the final target image.

Content and Style Weight

  Define alpha (content_weight) and beta (style_weight). Their ratio determines how strongly the final image is stylized. It is recommended to leave content_weight = 1 and adjust style_weight to achieve the desired ratio.

# weights for each style layer
# weighting earlier layers more will result in *larger* style artifacts
# notice we are excluding `conv4_2` our content representation
style_weights = {'conv1_1' : 1,
                 'conv2_1' : 0.75,
                 'conv3_1' : 0.2,
                 'conv4_1' : 0.2,
                 'conv5_1' : 0.2}

content_weight = 1 # alpha
style_weight = 1e3 # beta
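
How the layer weights, alpha and beta combine can be followed with plain arithmetic. The helper `total_loss` and the loss values below are hypothetical, chosen only to make the numbers easy to follow. The per-layer style losses are assumed to be already normalized.

```python
def total_loss(content_loss, layer_style_losses, style_weights,
               content_weight=1.0, style_weight=1e3):
    """Combine per-layer style losses and the content loss into one number."""
    # Each layer contributes its loss scaled by that layer's weight.
    style_loss = sum(style_weights[name] * loss
                     for name, loss in layer_style_losses.items())
    # alpha scales the content term, beta scales the summed style term.
    return content_weight * content_loss + style_weight * style_loss

# Hypothetical loss values for two style layers.
weights = {'conv1_1': 1.0, 'conv2_1': 0.75}
losses = {'conv1_1': 0.002, 'conv2_1': 0.004}

result = total_loss(content_loss=0.5, layer_style_losses=losses, style_weights=weights)
print(result)
```

Here the weighted style loss is 0.002 + 0.75 * 0.004 = 0.005, so with beta = 1e3 the style term contributes about 5.0 against a content term of 0.5. Raising or lowering style_weight shifts that balance directly.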

5. Update Target and Calculate Losses

Content Loss

  The content loss is calculated as the MSE between the target features and the content features at the ‘conv4_2’ layer.

Style Loss

  The style loss measures the difference between the Gram matrix of the style image and the Gram matrix of the target image at each style layer. Each layer's difference is computed with MSE.

Total Loss

  Finally, the total loss is the sum of the content and style losses, weighted by the specified alpha and beta.
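
The three losses can be computed on random tensors to see how they behave before any optimization. This is a minimal sketch for a single layer: the random feature maps are hypothetical stand-ins for VGG activations, and gram_matrix is repeated so the block runs on its own.

```python
import torch

def gram_matrix(tensor):
    # Same computation as gram_matrix above, for a batch of one.
    _, d, h, w = tensor.size()
    flat = tensor.view(d, h * w)
    return torch.mm(flat, flat.t())

torch.manual_seed(0)
# Hypothetical feature maps standing in for one VGG layer's output.
content_feat = torch.rand(1, 4, 8, 8)
style_feat = torch.rand(1, 4, 8, 8)
target_feat = content_feat.clone()   # the target starts as a copy of the content

# Content loss: MSE between target and content features.
content_loss = torch.mean((target_feat - content_feat) ** 2)

# Style loss: MSE between Gram matrices, normalized by the layer size.
_, d, h, w = target_feat.shape
gram_diff = gram_matrix(target_feat) - gram_matrix(style_feat)
style_loss = torch.mean(gram_diff ** 2) / (d * h * w)

# Total loss with alpha = 1 and beta = 1e3.
total = 1 * content_loss + 1e3 * style_loss
print(content_loss.item(), style_loss.item(), total.item())
```

Because the target begins as a copy of the content, the content loss starts at zero and the whole initial loss comes from the style term.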

# for displaying the target image, intermittently
show_every = 500

# iteration hyperparameters
optimizer = optim.Adam([target], lr = 0.003)
steps = 5001 # decide how many iterations to update your image (5000)

for ii in range(1, steps+1):

  # get the features from your target image
  target_features = get_features(target, vgg)

  # the content loss
  content_loss = torch.mean((target_features['conv4_2'] - content_features['conv4_2'])**2)

  # the style loss
  # initialize the style loss to 0
  style_loss = 0
  # then add to it for each layer's gram matrix loss
  for layer in style_weights:
    # get the "target" style representation for the layer
    target_feature = target_features[layer]
    target_gram    = gram_matrix(target_feature)
    _, d, h, w = target_feature.shape
    # get the "style" style representation
    style_gram = style_grams[layer]
    # the style loss for one layer, weighted appropriately
    layer_style_loss = style_weights[layer] * torch.mean((target_gram - style_gram) ** 2)
    # add to the style loss
    style_loss += layer_style_loss / (d * h * w)

  # calculate the *total* loss
  total_loss = content_weight * content_loss + style_weight * style_loss

  # update your target image
  optimizer.zero_grad()
  total_loss.backward()
  optimizer.step()

  # display intermediate images and print the loss
  if ii % show_every == 0:
    print('Total loss: ', total_loss.item())
    plt.imshow(im_convert(target))
    plt.show()
#> Total loss:  812.097900390625
#> Total loss:  498.1385498046875
#> Total loss:  365.70355224609375
#> Total loss:  290.7293395996094
#> Total loss:  243.23556518554688
#> Total loss:  210.29429626464844
#> Total loss:  186.3564910888672
#> Total loss:  168.50942993164062
#> Total loss:  154.83030700683594
#> Total loss:  143.84054565429688
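
The pattern of the loop, where the optimizer is given the image itself rather than any network weights, can be reproduced on a toy problem. The sketch below is hypothetical: it drops VGG entirely and simply pulls a blank tensor toward a random `goal` image with Adam. The learning rate and step count are illustrative values, not the ones used above.

```python
import torch
import torch.optim as optim

torch.manual_seed(0)
goal = torch.rand(1, 3, 4, 4)                          # hypothetical image to match
target = torch.zeros(1, 3, 4, 4, requires_grad=True)   # the tensor being optimized

# Adam receives the image tensor, not model parameters.
optimizer = optim.Adam([target], lr=0.05)

history = []
for step in range(200):
    loss = torch.mean((target - goal) ** 2)   # stand-in for the total loss
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    history.append(loss.item())

print(history[0], history[-1])
```

As in the style transfer run, the loss should fall steadily as the pixels of `target` are updated.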

6. Check the last result

# display content and final, target image
fig, (ax1, ax2) = plt.subplots(1, 2, figsize = (20, 10))
ax1.imshow(im_convert(content))
ax2.imshow(im_convert(target))
plt.show()
