Sketch Image Translation
the discriminator. Our discriminator now has an output of 2N classes: N classes for real images and N classes for generated images. Our hypothesis is that by utilizing this extra information, our network will compute a more nuanced loss and thus improve both the generator and the discriminator more efficiently, leading to a more effective GAN.

The "Image-to-Image" paper [2] presents a novel way to train a conditional GAN as a solution for image-to-image translation. Specifically, the network uses a generator to first encode an image into a high-level representation and subsequently decode that representation into a generated image. By training the cGAN on input-output image pairs, one can train the generator to create images that are, in theory, indistinguishable from the given output population. Our approach relies heavily on the U-Net architecture of the generator and the convolutional architecture of the discriminator proposed in that paper.

3.1. Architecture

We use a pre-existing GAN implementation provided by the authors of "Image-to-Image" [2] as the basis for our model. The generator has two components: an encoder, which takes the given sketch s and downsamples it to create a lower-dimensional representation φ(s), and a decoder, which takes a vector containing φ(s) and produces an image. The generator contains skip connections between the ith layer of the decoder and layer 8 − i of the encoder. The architecture of the generator is as follows:

• Encoder: C64-C128-C256-C512-C512-C512-C512-C512

• Decoder: CD512-CD512-CD512-C512-C512-C256-C128-C64

where C stands for a convolution and CD stands for a deconvolution. All of the convolutions use 4x4 spatial filters applied with stride 2. Convolutions in the encoder downsample by a factor of 2, and convolutions in the decoder upsample by a factor of 2. Leaky ReLU activations with a leak of 0.2 are used between layers in the encoder, and standard ReLU activations are used between layers in the decoder.

All of the convolutions in the discriminator use 4x4 spatial filters with a stride of 2, except for the final layer, which uses a stride of 1. Leaky ReLU activations with a leak of 0.2 are used between the convolutional layers. Both the generator and the discriminator are trained using the Adam update rule [3] with a learning rate of 0.0002 and momentum of 0.5. The discriminator produces a 30x30 output that corresponds to the realness of different patches of the input image. This out-of-the-box implementation was used as a baseline network to compare against our own models.
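To make the encoder-decoder structure above concrete, the following is a minimal PyTorch sketch of a U-Net generator with the listed filter counts and skip connections. It is an illustration under stated assumptions (RGB input and output, 256x256 resolution, no normalization layers shown), not the authors' released implementation.

```python
# Illustrative U-Net generator: encoder C64-...-C512, decoder CD512-...-C64,
# skip connections between mirrored encoder/decoder layers.
import torch
import torch.nn as nn

def down(in_ch, out_ch):
    # Ck encoder block: 4x4 conv, stride 2, LeakyReLU(0.2)
    return nn.Sequential(
        nn.Conv2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
        nn.LeakyReLU(0.2, inplace=True),
    )

def up(in_ch, out_ch, dropout=False):
    # Ck / CDk decoder block: 4x4 transposed conv, stride 2, ReLU (+ dropout for CD)
    layers = [
        nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, stride=2, padding=1),
        nn.ReLU(inplace=True),
    ]
    if dropout:
        layers.append(nn.Dropout(0.5))
    return nn.Sequential(*layers)

class UNetGenerator(nn.Module):
    def __init__(self, in_channels=3, out_channels=3):
        super().__init__()
        enc_out = [64, 128, 256, 512, 512, 512, 512, 512]   # C64-...-C512
        dec_out = [512, 512, 512, 512, 512, 256, 128, 64]   # CD512-...-C64
        self.encoders = nn.ModuleList()
        prev = in_channels
        for ch in enc_out:
            self.encoders.append(down(prev, ch))
            prev = ch
        self.decoders = nn.ModuleList()
        prev = enc_out[-1]
        for i, ch in enumerate(dec_out):
            # Every decoder block except the first also receives the mirrored
            # encoder activation via a skip connection (channel concatenation).
            in_ch = prev if i == 0 else prev + enc_out[-(i + 1)]
            self.decoders.append(up(in_ch, ch, dropout=(i < 3)))
            prev = ch
        # Map the final 64-channel feature map to an image; tanh output range
        # is an assumption of this sketch.
        self.to_image = nn.Sequential(
            nn.Conv2d(prev, out_channels, kernel_size=3, padding=1),
            nn.Tanh(),
        )

    def forward(self, sketch):  # sketch assumed to be (B, 3, 256, 256)
        skips = []
        x = sketch
        for enc in self.encoders:
            x = enc(x)
            skips.append(x)
        for i, dec in enumerate(self.decoders):
            if i > 0:
                x = torch.cat([x, skips[-(i + 1)]], dim=1)
            x = dec(x)
        return self.to_image(x)
```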
We modify the "Image-to-Image" [2] discriminator described above by adding a fully-connected layer to the end of the network, which outputs a 2N-dimensional vector of logits. By including N fake classes in our output, rather than a single fake class as described in "Improved Techniques" [6], we increase the discriminator's power to learn lower-level features that differentiate between fake images of different objects. In contrast, having only a single class that represents fake images of all object categories forces the discriminator to look for high-level features shared by all generated images that indicate an image is fake. Our discriminator has the following architecture:

C128-C256-C512-C1-FC125

The 2N-dimensional output vector has the form

[R_1, \dots, R_N, F_1, \dots, F_N]

where R denotes a real object class, F denotes a fake object class, and N is the number of classes in our dataset. In this formulation, a class represents a real or fake photo of a particular type of object.

These logits can be turned into class probabilities using a softmax:

p_{\mathrm{model}}(y = j \mid x) = \frac{\exp(l_j)}{\sum_{i=1}^{2N} \exp(l_i)}

We use these class probabilities to calculate our 2N loss and penalty loss, which we describe in the following sections.
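The modified discriminator head can be sketched roughly as follows: a convolutional trunk followed by a fully-connected layer producing 2N logits, with a softmax giving class probabilities. This is an illustration only; the conditioning by channel-wise concatenation of photo and sketch, the 256x256 input resolution, and the 2N = 250 output size (the architecture string above lists FC125) are assumptions of this sketch.

```python
# Illustrative class-aware discriminator: Ck trunk plus a fully-connected
# layer that outputs 2N logits (N real classes followed by N fake classes).
import torch
import torch.nn as nn

N_CLASSES = 125  # N object categories in the Sketchy Database

class ClassDiscriminator(nn.Module):
    def __init__(self, in_channels=6, n_classes=N_CLASSES):
        super().__init__()
        def block(i, o):
            # 4x4 filters, stride 2, LeakyReLU(0.2)
            return nn.Sequential(
                nn.Conv2d(i, o, kernel_size=4, stride=2, padding=1),
                nn.LeakyReLU(0.2, inplace=True),
            )
        self.trunk = nn.Sequential(
            block(in_channels, 128),
            block(128, 256),
            block(256, 512),
            nn.Conv2d(512, 1, kernel_size=4, stride=1, padding=1),  # C1, stride 1
        )
        # Fully-connected layer producing 2N logits; LazyLinear infers the
        # flattened input size from the first forward pass.
        self.fc = nn.LazyLinear(2 * n_classes)

    def forward(self, photo, sketch):
        # Condition on the sketch by concatenating it with the (real or
        # generated) photo along the channel dimension.
        x = torch.cat([photo, sketch], dim=1)
        x = self.trunk(x)
        return self.fc(x.flatten(1))  # shape: (batch, 2N)

# Class probabilities as in the softmax above:
# probs = torch.softmax(logits, dim=1)
```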
3.2. 2N Cross Entropy Loss

We will first discuss the 2N cross entropy loss function, which is the simpler of the two loss functions used in our experiments. This loss is inspired by the supervised component of the N+1 loss function outlined in the introduction and proposed in "Improved Techniques" [6]. The discriminator loss L_D contains two terms. The first term is a cross entropy loss for a real image x and sketch s pair taken from our training data distribution p_data, with ground truth class y and target class y. The second term is a cross entropy loss for the image G(s) generated from a sketch s with ground truth class y and target class y. The loss L_D is described by the following equation:

L_D = -\Big( \mathbb{E}_{x,s,y \sim p_{\mathrm{data}}(x,s,y)}\big[ \log p_{\mathrm{model}}(y \mid x, s,\ y \le N) \big] + \mathbb{E}_{s,y \sim p_{\mathrm{data}}(s,y)}\big[ \log p_{\mathrm{model}}(y \mid G(s), s,\ N < y \le 2N) \big] \Big)    (1)
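A minimal sketch of how Eq. (1) can be computed, assuming a hypothetical `disc` that returns 2N logits with real classes indexed 0..N-1 and fake classes N..2N-1 (not the authors' exact code):

```python
# 2N cross-entropy discriminator loss, Eq. (1).
import torch.nn.functional as F

N_CLASSES = 125

def discriminator_2n_loss(disc, gen, photo, sketch, obj_class):
    """obj_class: ground-truth object category indices in [0, N)."""
    # Real pair (x, s): target is the real class y (index < N).
    real_logits = disc(photo, sketch)
    loss_real = F.cross_entropy(real_logits, obj_class)

    # Generated pair (G(s), s): target is the fake class of the same object,
    # i.e. index y + N, so that N <= target < 2N.
    fake_image = gen(sketch).detach()  # do not backprop into the generator here
    fake_logits = disc(fake_image, sketch)
    loss_fake = F.cross_entropy(fake_logits, obj_class + N_CLASSES)

    # L_D is the sum of the two negative log-likelihood terms.
    return loss_real + loss_fake
```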
Similarly, the generator loss L_G contains two terms. The first term is a cross entropy loss for the image G(s) generated from a sketch s with ground truth class y and target class y − N. The target class is y − N in this case because the generator wants the image G(s) to be classified as a real image of the object depicted in sketch s, and y − N is the index of that class in the output vector. The second term in this loss is the L1 distance between the generated image G(s) and the ground truth image x, weighted by a hyperparameter λ. This L1 term encourages the generator to produce images that are close to the ground truth photo. The loss L_G is given by the equation:

L_G = -\mathbb{E}_{s,y \sim p_{\mathrm{data}}(s,y)}\big[ \log p_{\mathrm{model}}(y - N \mid G(s), s,\ N < y \le 2N) \big] + \lambda L_{L1}(G)    (2)
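Eq. (2) can be read as a standard cross-entropy term that pushes G(s) toward the real class of its object, plus an L1 reconstruction term. A minimal sketch under the same indexing assumptions as above; the λ value shown is illustrative, not taken from the paper:

```python
# Generator loss, Eq. (2): 2N cross-entropy term + weighted L1 term.
import torch.nn.functional as F

LAMBDA_L1 = 100.0  # assumed value for the L1 weight (a hyperparameter)

def generator_2n_loss(disc, gen, photo, sketch, obj_class):
    """obj_class: ground-truth object category indices in [0, N)."""
    fake_image = gen(sketch)
    fake_logits = disc(fake_image, sketch)

    # The generator is rewarded when D assigns the generated image to the
    # real class of the sketched object (target index y - N in Eq. (2),
    # i.e. obj_class under the 0-based indexing used here).
    ce_term = F.cross_entropy(fake_logits, obj_class)

    # L1 term keeps the generated photo close to the ground-truth photo.
    l1_term = F.l1_loss(fake_image, photo)
    return ce_term + LAMBDA_L1 * l1_term
```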
3.3. Penalty Loss

The 2N cross entropy loss makes use of our 2N-dimensional output; however, it does not take into account much of the additional information provided by the 2N representation. For instance, it does not differentiate a misclassification of the object category from a misclassification of realness. The penalty loss aims to make use of this additional information by weighting the cross entropy terms used in the 2N losses by constant penalty values, which vary depending on the type of misclassification. For a class prediction ŷ with target class y, our penalty function pen(y, ŷ) is as follows:

\mathrm{pen}(y, \hat{y}) = \begin{cases} a, & \mathrm{obj}(y) = \mathrm{obj}(\hat{y}),\ \text{is-fake}(y) \ne \text{is-fake}(\hat{y}) \\ b, & \mathrm{obj}(y) \ne \mathrm{obj}(\hat{y}),\ \text{is-fake}(y) = \text{is-fake}(\hat{y}) \\ c, & \mathrm{obj}(y) \ne \mathrm{obj}(\hat{y}),\ \text{is-fake}(y) \ne \text{is-fake}(\hat{y}) \end{cases}

where obj() returns the type of object represented by the given class, is-fake() determines whether the given class represents a fake image, and a, b, c are hyperparameters that can be chosen through cross validation. Using this penalty function, we define our discriminator loss L_D to be:

L_D = -\Big( \mathbb{E}_{x,s,y \sim p_{\mathrm{data}}(x,s,y)}\big[ \log p_{\mathrm{model}}(y \mid x, s,\ y \le N) \big] \times \mathrm{pen}(y, \hat{y}) + \mathbb{E}_{s,y \sim p_{\mathrm{data}}(s,y)}\big[ \log p_{\mathrm{model}}(y \mid G(s), s,\ N < y \le 2N) \big] \times \mathrm{pen}(y, \hat{y}) \Big)    (3)

Our generator also weights the cross entropy term by the output of the penalty function and is given by the equation below. Note that we pass y − N into the penalty function as the target class for the generated image G(s) because the generator wants the image to be classified as a real image of the object depicted in the sketch.

L_G = -\mathbb{E}_{s,y \sim p_{\mathrm{data}}(s,y)}\big[ \log p_{\mathrm{model}}(y - N \mid G(s), s,\ N < y \le 2N) \big] \times \mathrm{pen}(y - N, \hat{y}) + \lambda L_{L1}(G)    (4)
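A minimal sketch of the penalty weighting in Eqs. (3) and (4), assuming the 2N indexing used above. The weight of 1 for a fully correct prediction and the particular a, b, c values are assumptions of this sketch, since the text leaves a, b, c as hyperparameters:

```python
# Penalty-weighted 2N cross entropy (Eqs. (3)-(4)).
# Indexing assumption: 0..N-1 are real classes, N..2N-1 are fake classes.
import torch
import torch.nn.functional as F

N_CLASSES = 125
A, B, C = 2.0, 2.0, 4.0  # hypothetical penalty values for the three cases

def penalty(target, pred):
    """Per-example pen(y, y_hat) for integer class tensors."""
    same_obj = (target % N_CLASSES) == (pred % N_CLASSES)
    same_realness = (target < N_CLASSES) == (pred < N_CLASSES)
    pen = torch.ones_like(target, dtype=torch.float)  # assumed weight 1 when correct
    pen = torch.where(same_obj & ~same_realness, torch.full_like(pen, A), pen)
    pen = torch.where(~same_obj & same_realness, torch.full_like(pen, B), pen)
    pen = torch.where(~same_obj & ~same_realness, torch.full_like(pen, C), pen)
    return pen

def penalty_weighted_ce(logits, target):
    """Cross entropy per example, weighted by pen(target, argmax prediction)."""
    ce = F.cross_entropy(logits, target, reduction="none")
    pred = logits.argmax(dim=1)
    return (ce * penalty(target, pred)).mean()

# Discriminator loss, Eq. (3): weighted CE over real and generated pairs.
# loss_d = penalty_weighted_ce(disc(photo, sketch), obj_class) \
#        + penalty_weighted_ce(disc(gen(sketch).detach(), sketch),
#                              obj_class + N_CLASSES)
```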
1 http://sketchy.eye.gatech.edu/
Figure 2: Images generated during training of penalty loss model. Successful generations on the left, unsuccessful generations
on the right.
Model                       Accuracy
2N Loss (50k steps)         26.98%
Penalty Loss (50k steps)    29.09%
Penalty Loss (134k steps)   10.26%

Table 1: Accuracy when classifying validation photos using the standalone discriminator.

Model                         Top 1     Top 5
Ground Truth Photos           71.90%    79.04%
Baseline Model                 0.83%     2.36%
Class Conditional Generator    1.05%     3.13%
2N Loss Model                  0.48%     1.90%
Penalty Model                  0.85%     2.44%
Segmented Photos              40.58%    60.51%
Trained on Segmented Photos    1.99%     4.42%
[6] T. Salimans, I. Goodfellow, W. Zaremba, V. Cheung, A. Rad-
ford, and X. Chen. Improved techniques for training gans.
arXiv preprint arXiv:1606.03498, 2016.
[7] C. Szegedy, V. Vanhoucke, S. Ioffe, J. Shlens, and Z. Wojna.
Rethinking the inception architecture for computer vision. In
Proceedings of the IEEE Conference on Computer Vision and
Pattern Recognition, pages 2818–2826, 2016.
[8] D. Warde-Farley and Y. Bengio. Improving generative adver-
sarial networks with denoising feature matching. In Proceed-
ings of the International Conference on Learning Representa-
tions (ICLR), 2017.
[9] J.-Y. Zhu, P. Krähenbühl, E. Shechtman, and A. A. Efros.
Generative visual manipulation on the natural image mani-
fold. In European Conference on Computer Vision, pages
597–613. Springer, 2016.
Figure 3: Examples of inputs and outputs of the various models. Each row corresponds to a sketch. The columns, from left
to right, correspond to: 1. Input Sketches; 2. Target Photos; 3. Segmented Target Photos; 4. Baseline Model; 5. Class
Conditional Generator; 6. 2N Loss Model; 7. Penalty Loss Model; 8. Trained on Segmented Photos.