The model is a Convolutional Neural Network (CNN) designed to predict the ‘a’ and ‘b’ channels of the Lab colour space given the ‘L’ channel. Previous works in this area treat colourisation as a regression task and use a mean squared error loss, aiming to predict the exact values of the ‘a’ and ‘b’ channels for every pixel.
However, the model in the paper treats the problem as a classification task and uses a custom loss function, a variation of categorical cross-entropy. The ‘a’ and ‘b’ channels are quantised into bins, and the network predicts a probability for each bin at every pixel. Because colourisation is inherently ambiguous (a grey object could plausibly take many colours), regression averages the plausible options into desaturated results, while predicting a distribution over bins is more robust and produces more realistic, visually appealing output.
I train a CNN to map from a grayscale input to a distribution over quantised colour value outputs using the architecture shown below.
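The first step is to quantise the ‘a’ and ‘b’ channels into bins of width 10 and collect the set of bins that actually occur in the training images. The snippets below share a common set of imports; note that the l2 weight-decay strength is an assumed value, not necessarily what was used:

import os
import pickle
import numpy as np
import tensorflow as tf
from PIL import Image
from sklearn.neighbors import NearestNeighbors
from scipy.ndimage import gaussian_filter
from tensorflow.keras.layers import (Input, Conv2D,
                                     BatchNormalization, UpSampling2D)
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

l2_reg = l2(1e-3)  # weight decay used throughout; exact value assumed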
def quantize(ab_img):
    # Snap each ab value to the lower edge of its 10-wide bin
    return np.floor_divide(ab_img, 10) * 10
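np.floor_divide rounds towards negative infinity, so an ‘a’ value of 37 maps to bin 30 and a value of −42 maps to bin −50; each bin is identified by its lower edge.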
def set_possible_colors(save_files=True):
    # Collect every quantised (a, b) pair that occurs in the
    # training set and assign it a class index 0..Q-1
    height, width = 256, 256
    ab_to_possible_color = {}
    color_index = 0
    for image in os.listdir("images"):
        rgb_image = np.array(
            Image.open(f"images/{image}").convert("RGB"))
        ab_img = extract_ab(rgb_image)
        quantized_ab_image = quantize(ab_img)
        for y in range(height):
            for x in range(width):
                a = quantized_ab_image[y][x][0]
                b = quantized_ab_image[y][x][1]
                ck = get_color_key(a, b)
                if ck not in ab_to_possible_color:
                    ab_to_possible_color[ck] = color_index
                    color_index += 1
    if save_files:
        pickle.dump(ab_to_possible_color,
                    open("ab_to_q_index_dict.p", "wb"))
def ab_to_z(ab_img, Q, sigma=5):
    # Soft-encode each pixel's ab value over its 5 nearest bins
    w, h = ab_img.shape[0], ab_img.shape[1]
    pts = w * h
    ab_img_flat = np.reshape(ab_img, (pts, 2))
    pts_flat = np.zeros((pts, Q))
    pts_ind = np.arange(pts)[:, np.newaxis]
    nbrs = NearestNeighbors(
        n_neighbors=5, algorithm='ball_tree').fit(get_ab_domain())
    distances, ind = nbrs.kneighbors(ab_img_flat)
    # Gaussian kernel over the distances, normalised per pixel
    gaussian_kernel = np.exp(-distances**2 / (2 * sigma**2))
    gaussian_kernel /= np.sum(gaussian_kernel, axis=1, keepdims=True)
    pts_flat[pts_ind, ind] = gaussian_kernel
    z = np.reshape(pts_flat, (w, h, Q))
    return z
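The loss is a weighted multinomial cross-entropy between the soft targets Z and the predictions Z_hat. Because images are dominated by desaturated colours, each pixel is weighted by the rarity of its colour bin: the empirical bin distribution p is smoothed with a Gaussian, and get_loss_weights derives per-pixel weights from its inverse, so rare, saturated colours still contribute to the gradient.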
def multinomial_crossentropy_loss(Z, Z_hat):
    # Weighted cross-entropy between soft targets Z and
    # predictions Z_hat, rebalanced by colour rarity
    Q = len(get_ab_to_q_dict())
    p = pickle.load(open("p_10000.p", "rb"))  # empirical bin frequencies
    p_tilde = gaussian_filter(p, sigma=5)     # smoothed distribution
    eps = 0.0001
    weights = get_loss_weights(Z, Q, p_tilde)
    log = tf.math.log(Z_hat + eps)            # eps avoids log(0)
    summ = tf.reduce_sum(Z * log, axis=-1)    # sum over the Q bins
    loss = -tf.reduce_sum(weights * summ)
    return loss
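get_loss_weights is not shown above. A minimal sketch, assuming it follows the paper’s class-rebalancing scheme (blend the smoothed distribution p_tilde with a uniform prior using λ = 0.5, invert, normalise so the expected weight is 1, then look up each pixel’s most probable target bin); the exact implementation may differ:

def get_loss_weights(Z, Q, p_tilde, lam=0.5):
    # w_q ~ ((1 - lam) * p_tilde_q + lam / Q)^-1, normalised so
    # the expected weight under p_tilde equals 1
    w = 1.0 / ((1 - lam) * p_tilde + lam / Q)
    w /= np.sum(p_tilde * w)
    w = tf.constant(w, dtype=tf.float32)
    # weight every pixel by the weight of its most probable target bin
    return tf.gather(w, tf.argmax(Z, axis=-1))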
def ColourNet():
    def conv(filters, strides=(1, 1), dilation_rate=1):
        # all convolutions share the same configuration
        return Conv2D(filters, (3, 3), strides=strides,
                      dilation_rate=dilation_rate,
                      activation='relu', padding='same',
                      kernel_initializer='he_normal',
                      kernel_regularizer=l2_reg)

    input_tensor = Input(shape=(256, 256, 1))
    x = conv(64)(input_tensor)
    x = conv(64, strides=(2, 2))(x)
    x = BatchNormalization()(x)
    x = conv(128)(x)
    x = conv(128, strides=(2, 2))(x)
    x = BatchNormalization()(x)
    x = conv(256)(x)
    x = conv(256)(x)
    x = conv(256, strides=(2, 2))(x)
    x = BatchNormalization()(x)
    x = conv(512)(x)
    x = conv(512)(x)
    x = conv(512)(x)
    x = BatchNormalization()(x)
    x = conv(512, dilation_rate=2)(x)
    x = conv(512, dilation_rate=2)(x)
    x = conv(512, dilation_rate=2)(x)
    x = BatchNormalization()(x)
    x = conv(512, dilation_rate=2)(x)
    x = conv(512, dilation_rate=2)(x)
    x = conv(512, dilation_rate=2)(x)
    x = BatchNormalization()(x)
    x = conv(256)(x)
    x = conv(256)(x)
    x = conv(256)(x)
    x = BatchNormalization()(x)
    x = UpSampling2D(size=(2, 2))(x)
    x = conv(128)(x)
    x = conv(128)(x)
    x = conv(128)(x)
    x = BatchNormalization()(x)
    # per-pixel distribution over the Q = 274 colour bins
    x = Conv2D(274, (1, 1), activation='softmax', padding='same')(x)
    # upsample the 64x64 predictions back to 256x256
    outputs = UpSampling2D(size=(4, 4),
                           interpolation='bilinear')(x)
    model = Model(input_tensor, outputs)
    return model
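The network outputs a distribution over colour bins at every pixel; to decode it to ab values, the paper uses the annealed mean. Taking the arg-max of the distribution gives vibrant but spatially inconsistent colours, while taking the plain mean gives desaturated ones; sharpening the distribution with a temperature T before taking its expectation interpolates between the two (T = 1 leaves the distribution unchanged, T → 0 approaches the arg-max):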
def temperature_scaling(Z, T):
    # sharpen the distribution: exp(log(Z) / T), renormalised;
    # the small eps avoids log(0)
    Z = np.exp(np.log(Z + 1e-8) / T)
    Z /= np.sum(Z, axis=-1, keepdims=True)
    return Z
def convert_Z_to_ab(Z, ab_domain, T=0.38):
    # annealed mean: expectation of ab under the sharpened distribution
    Z = temperature_scaling(Z, T)
    ab_output = np.dot(Z, ab_domain)  # (h, w, Q) . (Q, 2) -> (h, w, 2)
    return ab_output
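Putting the pieces together, colourising a single image looks roughly like this. This is a sketch: colourize is a hypothetical wrapper, the L-channel scaling must match whatever the network was trained on, and skimage (not used elsewhere here) handles the Lab-to-RGB conversion:

from skimage.color import lab2rgb

def colourize(l_channel, model, ab_domain):
    # l_channel: (256, 256, 1) array holding the Lab L channel,
    # scaled as the network expects (assumed here)
    Z_hat = model.predict(l_channel[np.newaxis, ...])[0]
    ab = convert_Z_to_ab(Z_hat, np.array(ab_domain))  # (256, 256, 2)
    lab = np.concatenate([l_channel, ab], axis=-1)    # reassemble Lab
    return lab2rgb(lab)  # RGB floats in [0, 1]; expects L in [0, 100]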
In the paper, the team evaluated their model with a ‘colourisation Turing test’, asking human participants to distinguish generated from real colour images. Here I instead evaluate my model on quantitative metrics: Mean Squared Error (MSE), Mean Absolute Error (MAE), and Peak Signal-to-Noise Ratio (PSNR).
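For reference, a minimal sketch of these metrics in NumPy, assuming the predicted and ground-truth images are float RGB arrays scaled to [0, 1] (so the PSNR peak value is 1.0):

def evaluate(y_true, y_pred):
    # y_true, y_pred: float arrays in [0, 1]
    mse = np.mean((y_true - y_pred) ** 2)
    mae = np.mean(np.abs(y_true - y_pred))
    psnr = 10 * np.log10(1.0 / mse)  # PSNR in dB, peak value 1.0
    return mse, mae, psnr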
The model performs fairly well, with an MSE of 0.0012, an MAE of 0.0034, and a PSNR of 37 dB. Below are some results that are indicative of how the model performs on anime-inspired art.
The model performs fairly well on the task of colouring manga images, but there is still room for improvement. It correctly identifies the regions to colour but is often unable to predict the exact colour; this may be due to a lack of colour diversity in the dataset. Training on a larger and more diverse dataset should improve generalisation.