Colouring Manga Using Deep Learning
In this project I explore the use of deep learning, specifically Convolutional Neural Networks (CNNs), to colour manga images. The project is inspired by the paper "Colorful Image Colorization" by Richard Zhang, Phillip Isola and Alexei A. Efros. The paper proposes a general framework for colourising images, and I wanted to see how well it would work on manga. The results were quite impressive: the model produces convincing colourings for many of the images.
The model in the paper is trained on ImageNet, a far more general dataset than the one needed for colouring manga. Fortunately, a colourisation dataset is easy to curate: only colour images are required, because both the features (the lightness channel) and the labels (the colour channels) can be extracted from each image.
Creating Training Data:
The images are obtained with a Python web-scraping script that automatically searches for and downloads tagged images from the image hosting website Danbooru. Danbooru is ideal for this purpose, as it hosts a vast collection of high-quality, tagged images. I downloaded approximately 10,000 images for the task. The images are preprocessed, resized to 256 × 256 pixels and converted to the Lab colour space. The L (lightness) channel is used as the input to the model, and the AB (colour) channels are used as the labels.
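As a rough sketch of the colour-space step, the functions below convert an sRGB image to Lab with plain NumPy, using the standard sRGB/D65 formulas, and split the result into the L input and the AB labels. The names `rgb_to_lab` and `split_features_labels` are hypothetical and not taken from the project; in practice `skimage.color.rgb2lab` performs the same conversion, and resizing can be done beforehand with Pillow's `Image.resize((256, 256))`.

```python
import numpy as np

# Hypothetical helpers (not the project's code): sRGB -> CIE Lab with a D65 white point.
# Linear sRGB -> XYZ matrix and the D65 reference white.
_RGB_TO_XYZ = np.array([
    [0.4124564, 0.3575761, 0.1804375],
    [0.2126729, 0.7151522, 0.0721750],
    [0.0193339, 0.1191920, 0.9503041],
])
_WHITE_D65 = np.array([0.95047, 1.0, 1.08883])


def rgb_to_lab(rgb):
    """Convert an (H, W, 3) uint8 sRGB image to an (H, W, 3) Lab image (L in [0, 100])."""
    c = rgb.astype(np.float64) / 255.0
    # Undo the sRGB gamma curve to get linear RGB
    linear = np.where(c <= 0.04045, c / 12.92, ((c + 0.055) / 1.055) ** 2.4)
    # Linear RGB -> XYZ, normalised by the reference white
    xyz = linear @ _RGB_TO_XYZ.T / _WHITE_D65
    delta = 6 / 29
    f = np.where(xyz > delta**3, np.cbrt(xyz), xyz / (3 * delta**2) + 4 / 29)
    L = 116 * f[..., 1] - 16
    a = 500 * (f[..., 0] - f[..., 1])
    b = 200 * (f[..., 1] - f[..., 2])
    return np.stack([L, a, b], axis=-1)


def split_features_labels(rgb):
    """Return the L channel (H, W, 1) as model input and the AB channels (H, W, 2) as labels."""
    lab = rgb_to_lab(rgb)
    return lab[..., :1], lab[..., 1:]
```

A white pixel maps to L ≈ 100 with a ≈ b ≈ 0, and a black pixel maps to L = 0, so the L channel alone carries the greyscale image the model sees.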
Model Details:

The model is a Convolutional Neural Network designed to predict the ‘a’ and ‘b’ channels of the Lab colour space given the ‘L’ channel. Previous work in this area treats colourisation as a regression task, using a mean squared error loss to predict the exact values of the ‘a’ and ‘b’ channels for every pixel.

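The model in the paper instead treats the problem as a classification task, with a custom loss function that is a variation of categorical cross-entropy. It quantises the ‘a’ and ‘b’ channels into bins and predicts the probability of each bin for every pixel. Because an object can plausibly take several different colours, a mean squared error loss pulls the prediction towards the average of those colours, which gives greyish, desaturated results; predicting a distribution over bins avoids this and produces more vivid, realistic results.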
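A toy one-dimensional example, with made-up numbers rather than data from this project, shows why the regression formulation desaturates colours. Suppose the same kind of region is red (a = +60) in 55 training images and green (a = -60) in 45:

```python
import numpy as np

# Hypothetical 'a' values for one kind of region: red in 55 images, green in 45
a_values = np.array([60.0] * 55 + [-60.0] * 45)

# Regression with an L2 loss: the best constant prediction is the mean
l2_prediction = a_values.mean()

# Classification: quantise 'a' into bins of width 10 and take the most likely bin
bins = np.floor_divide(a_values, 10) * 10
values, counts = np.unique(bins, return_counts=True)
mode_prediction = values[np.argmax(counts)]

print(l2_prediction)    # 6.0: close to neutral grey
print(mode_prediction)  # 60.0: a saturated red
```

The mean lands near grey even though no training example is grey, while the per-bin distribution keeps both plausible colours available.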

I train a CNN to map from a grayscale input to a distribution over quantised colour value outputs using the architecture shown below.

[Figure: network architecture]
Architecture as described in the paper. Each conv layer refers to a block of 2 or 3 repeated conv and ReLU layers, followed by a BatchNorm [30] layer. The net has no pool layers. All changes in resolution are achieved through spatial downsampling or upsampling between conv blocks.
Key components of the implementation:
1. Quantisation of the AB channels:
The AB channels are quantised into 274 bins. In the code below, each (a, b) value is snapped to a grid with a step size of 10, and each grid cell that occurs in the training images is assigned its own bin index. The model then predicts the probability of each bin for every pixel.
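import os
import pickle

import numpy as np
from PIL import Image
from skimage.color import rgb2lab


def quantize(ab_img):
    # Snap every (a, b) value onto a grid with a step size of 10
    return np.floor_divide(ab_img, 10) * 10


def extract_ab(rgb_image):
    # Assumed helper (not shown in the original): RGB -> Lab, keep the a and b channels
    return rgb2lab(rgb_image)[:, :, 1:]


def get_color_key(a, b):
    # Assumed helper (not shown in the original): hashable key for one quantised bin
    return (int(a), int(b))


def set_possible_colors(image_dir="images", save_files=True):
    # Give every quantised (a, b) value that occurs in the training images its own bin index
    ab_to_possible_color = {}
    color_index = 0
    for image in os.listdir(image_dir):
        rgb_image = np.array(
            Image.open(os.path.join(image_dir, image)).convert("RGB"))
        quantized_ab_image = quantize(extract_ab(rgb_image))
        for a, b in quantized_ab_image.reshape(-1, 2):
            ck = get_color_key(a, b)
            if ck not in ab_to_possible_color:
                ab_to_possible_color[ck] = color_index
                color_index += 1
    if save_files:
        with open("ab_to_q_index_dict.p", "wb") as f:
            pickle.dump(ab_to_possible_color, f)
    return ab_to_possible_color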
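The dictionary of bins can be turned into the two things the later steps need: an array of bin values (the `ab_domain` used when decoding predictions) and a per-pixel lookup from AB values to bin indices. This is a sketch that assumes the dictionary keys are integer `(a, b)` tuples of the quantised values; `build_ab_domain` and `ab_to_bin_indices` are hypothetical helper names, not functions from the project.

```python
import numpy as np


def build_ab_domain(ab_to_index):
    """Hypothetical helper: (Q, 2) array whose row i is the quantised (a, b) value of bin i."""
    domain = np.zeros((len(ab_to_index), 2))
    for (a, b), idx in ab_to_index.items():
        domain[idx] = (a, b)
    return domain


def ab_to_bin_indices(ab_img, ab_to_index):
    """Hypothetical helper: map an (H, W, 2) ab image to an (H, W) array of bin indices."""
    # Same grid quantisation as above (step size 10)
    quantized = np.floor_divide(ab_img, 10) * 10
    h, w, _ = quantized.shape
    indices = [ab_to_index[(int(a), int(b))] for a, b in quantized.reshape(-1, 2)]
    return np.array(indices).reshape(h, w)
```

Because the bin keys are the floored grid values, every AB value that was seen during `set_possible_colors` maps to exactly one bin.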
2. Converting AB channels to probability distribution:
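Each pixel's ground-truth AB value is converted into a soft probability distribution over the bins: the 5 nearest bins are weighted with a Gaussian kernel (σ = 5) on their distance to the true value, the weights are normalised to sum to 1, and every other bin gets zero probability.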
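import numpy as np
from sklearn.neighbors import NearestNeighbors


def ab_to_z(ab_img, ab_domain, sigma=5):
    # ab_img: (H, W, 2) ground-truth ab values
    # ab_domain: (Q, 2) bin values (what get_ab_domain() returns in the repo)
    h, w = ab_img.shape[0], ab_img.shape[1]
    Q = len(ab_domain)
    pts = h * w
    ab_img_flat = np.reshape(ab_img, (pts, 2))
    # Bins that are not among the 5 nearest neighbours keep probability zero
    pts_flat = np.zeros((pts, Q))
    pts_ind = np.arange(0, pts)[:, np.newaxis]
    nbrs = NearestNeighbors(n_neighbors=5, algorithm="ball_tree").fit(ab_domain)
    distances, ind = nbrs.kneighbors(ab_img_flat)
    # Gaussian weights on the distance to each neighbouring bin, normalised per pixel
    gaussian_kernel = np.exp(-distances**2 / (2 * sigma**2))
    gaussian_kernel /= np.sum(gaussian_kernel, axis=1, keepdims=True)
    pts_flat[pts_ind, ind] = gaussian_kernel
    return np.reshape(pts_flat, (h, w, Q))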
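Written out, the soft encoding that `ab_to_z` computes for a pixel with ground-truth AB value $y$ is shown below, where $c_q$ is the AB value of bin $q$, $\mathcal{N}_5(y)$ is the set of the 5 bins nearest to $y$, and $\sigma = 5$. Bins outside the neighbourhood get zero probability, and the weights of the 5 neighbours sum to 1.

```latex
Z_q =
\begin{cases}
\dfrac{\exp\!\left(-\lVert y - c_q \rVert^2 / 2\sigma^2\right)}
      {\sum_{k \in \mathcal{N}_5(y)} \exp\!\left(-\lVert y - c_k \rVert^2 / 2\sigma^2\right)}
  & \text{if } q \in \mathcal{N}_5(y), \\[1ex]
0 & \text{otherwise.}
\end{cases}
```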
3. Loss function:
The loss function is a variation of categorical cross-entropy: each pixel's cross-entropy term is multiplied by a class-rebalancing weight, so that rare colours are not swamped by the most frequent ones in the dataset.
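import pickle

import tensorflow as tf
from scipy.ndimage import gaussian_filter


def multinomial_crossentropy_loss(Z, Z_hat):
    # Z: soft-encoded targets, Z_hat: predicted distributions, both (batch, H, W, Q)
    Q = len(get_ab_to_q_dict())  # helper from the repo (not shown)
    # Presumably the empirical bin probabilities over the training set, then smoothed
    with open("p_10000.p", "rb") as f:
        p = pickle.load(f)
    p_tilde = gaussian_filter(p, sigma=5)
    eps = 0.0001
    # Per-pixel class-rebalancing weights (get_loss_weights is from the repo, not shown)
    weights = get_loss_weights(Z, Q, p_tilde)
    log = tf.math.log(Z_hat + eps)
    # Cross-entropy summed over the colour bins, i.e. the last axis
    summ = tf.reduce_sum(log * Z, axis=-1)
    return -tf.reduce_sum(weights * summ)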
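`get_loss_weights` is not shown here, so below is a minimal sketch of the rebalancing scheme described in the paper, which may differ from what the repo actually does. Each bin q gets a weight w_q proportional to ((1 - λ) p̃_q + λ/Q)^-1, normalised so that the expected weight under p̃ is 1 (the paper uses λ = 0.5), and each pixel takes the weight of its most probable ground-truth bin. The names `rebalancing_weights` and `pixel_weights` are hypothetical.

```python
import numpy as np


def rebalancing_weights(p_tilde, lam=0.5):
    """Hypothetical sketch: per-bin weights w ~ 1 / ((1 - lam) * p_tilde + lam / Q), with E[w] = 1."""
    p_tilde = np.asarray(p_tilde, dtype=np.float64)
    p_tilde = p_tilde / p_tilde.sum()  # ensure it is a probability distribution
    Q = p_tilde.size
    # Mix the empirical distribution with a uniform one, then invert
    w = 1.0 / ((1.0 - lam) * p_tilde + lam / Q)
    # Normalise so the expected weight under p_tilde is 1
    w /= np.sum(p_tilde * w)
    return w


def pixel_weights(Z, w):
    """Hypothetical sketch: each pixel gets the weight of its most probable ground-truth bin."""
    return w[np.argmax(Z, axis=-1)]
```

Rare bins receive larger weights than common ones, and with λ = 1 every weight collapses to 1, i.e. no rebalancing.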
4. Model architecture:
The model is the CNN described above, trained on the training data for 1000 epochs.
from tensorflow.keras.layers import BatchNormalization, Conv2D, Input, UpSampling2D
from tensorflow.keras.models import Model
from tensorflow.keras.regularizers import l2

# Assumed regularisation strength: the original does not state the value of l2_reg
l2_reg = l2(1e-3)


def conv(filters, **kwargs):
    # 3x3 conv + ReLU with He initialisation and L2 weight decay;
    # extra keyword arguments (strides, dilation_rate) are passed through
    return Conv2D(filters, (3, 3), activation="relu", padding="same",
                  kernel_initializer="he_normal", kernel_regularizer=l2_reg,
                  **kwargs)


def ColourNet(num_bins=274):
    input_tensor = Input(shape=(256, 256, 1))  # L channel
    # Three blocks, each ending in a stride-2 conv: 256 -> 128 -> 64 -> 32
    x = conv(64)(input_tensor)
    x = conv(64, strides=(2, 2))(x)
    x = BatchNormalization()(x)
    x = conv(128)(x)
    x = conv(128, strides=(2, 2))(x)
    x = BatchNormalization()(x)
    x = conv(256)(x)
    x = conv(256)(x)
    x = conv(256, strides=(2, 2))(x)
    x = BatchNormalization()(x)
    # 512-filter block at 32 x 32
    for _ in range(3):
        x = conv(512)(x)
    x = BatchNormalization()(x)
    # Two dilated blocks enlarge the receptive field without downsampling
    for _ in range(2):
        for _ in range(3):
            x = conv(512, dilation_rate=2)(x)
        x = BatchNormalization()(x)
    for _ in range(3):
        x = conv(256)(x)
    x = BatchNormalization()(x)
    # Upsample 32 -> 64
    x = UpSampling2D(size=(2, 2))(x)
    for _ in range(3):
        x = conv(128)(x)
    x = BatchNormalization()(x)
    # Per-pixel softmax over the colour bins, then bilinear upsampling 64 -> 256
    x = Conv2D(num_bins, (1, 1), activation="softmax", padding="same")(x)
    outputs = UpSampling2D(size=(4, 4), interpolation="bilinear")(x)
    return Model(input_tensor, outputs)
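As a quick sanity check on the architecture, the sketch below traces the spatial resolution through the blocks of `ColourNet` in plain Python, without TensorFlow: three stride-2 convolutions take 256 down to 32, the dilated blocks keep 32 × 32, and the two upsampling layers bring the 274-bin output back to 256 × 256. The stage list is a hand-written summary of the model code, not something extracted from the model itself.

```python
import math

# Hand-written summary of ColourNet: (description, operation, factor).
# With padding='same', a stride-s convolution outputs ceil(n / s) pixels per side.
STAGES = [
    ("64-filter block, stride-2 conv", "down", 2),
    ("128-filter block, stride-2 conv", "down", 2),
    ("256-filter block, stride-2 conv", "down", 2),
    ("512-filter block", "same", 1),
    ("two dilated 512-filter blocks", "same", 1),
    ("256-filter block", "same", 1),
    ("UpSampling2D(2, 2)", "up", 2),
    ("128-filter block + 1x1 softmax (274 bins)", "same", 1),
    ("UpSampling2D(4, 4), bilinear", "up", 4),
]


def trace_resolution(size=256):
    """Return (description, side length) after each stage."""
    trace = []
    for name, op, factor in STAGES:
        if op == "down":
            size = math.ceil(size / factor)
        elif op == "up":
            size *= factor
        trace.append((name, size))
    return trace


for name, side in trace_resolution():
    print(f"{name:45s} {side} x {side}")
```

The softmax itself runs at 64 × 64, so the final bilinear upsampling interpolates probability distributions, which keeps each pixel's probabilities summing to 1.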
5. Converting probability distribution to AB channels:
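The probability distribution output by the model is converted back to AB values with the paper's "annealed mean": each pixel's distribution is sharpened with temperature scaling (T = 0.38, the value used in the paper), and the expected AB value is then computed over the bins.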
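import numpy as np


def temperature_scaling(Z, T, eps=1e-8):
    # Sharpen (T < 1) or flatten (T > 1) each pixel's distribution over the bins;
    # eps avoids log(0) for bins with zero probability
    Z = np.exp(np.log(Z + eps) / T)
    Z /= np.sum(Z, axis=-1, keepdims=True)
    return Z


def convert_Z_to_ab(Z, ab_domain, T=0.38):
    # Annealed mean: expectation of the bin ab values under the re-tempered distribution
    # Z: (H, W, Q) predicted probabilities, ab_domain: (Q, 2) bin values -> (H, W, 2)
    Z = temperature_scaling(Z, T)
    ab_output = np.dot(Z, ab_domain)
    return ab_output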
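In equation form, with $\mathbf{z}$ a pixel's predicted distribution over the $Q = 274$ bins and $c_q$ the AB value of bin $q$ (row $q$ of `ab_domain`), `convert_Z_to_ab` computes:

```latex
f_T(\mathbf{z})_q = \frac{\exp(\log z_q / T)}{\sum_{k=1}^{Q} \exp(\log z_k / T)},
\qquad
\hat{y} = \sum_{q=1}^{Q} f_T(\mathbf{z})_q \, c_q
```

Setting $T = 1$ gives the plain mean of the distribution, and letting $T \to 0$ approaches its mode; $T = 0.38$ sits between the two, trading the vividness of the mode against the spatial consistency of the mean.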
For more details on the implementation, check out the code in my GitHub repo here.
Results and Model Evaluation:

The paper evaluates its model with a "colourisation Turing test", in which human participants try to tell real colour images apart from the model's output. Here I instead evaluate my model on quantitative metrics: Mean Squared Error (MSE), Mean Absolute Error (MAE) and Peak Signal-to-Noise Ratio (PSNR).
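The three metrics can be computed as in the minimal NumPy sketch below. It assumes predictions and targets are arrays scaled to [0, 1] (hence `data_range=1.0`); the write-up does not say which scale or colour space the reported numbers use, and PSNR in particular depends on that choice.

```python
import numpy as np


def mse(pred, target):
    """Mean squared error over all pixels and channels."""
    return float(np.mean((pred - target) ** 2))


def mae(pred, target):
    """Mean absolute error over all pixels and channels."""
    return float(np.mean(np.abs(pred - target)))


def psnr(pred, target, data_range=1.0):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range^2 / MSE)."""
    err = mse(pred, target)
    if err == 0:
        return float("inf")  # identical images
    return 10.0 * np.log10(data_range**2 / err)
```

Averaging per-image PSNR values over a test set gives a number at least as high as the PSNR computed from the dataset-wide mean MSE, so it is worth stating which of the two is reported.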

The model performs fairly well, with an MSE of 0.0012, an MAE of 0.0034 and a PSNR of 37 dB. Below are some results that indicate how the model performs on anime-inspired art.

[Figures: sample colourised outputs (result1, result2, result3)]
Improvements and Future Work:

The model performs fairly well on the task of colouring manga images, but there is still room for improvement. It correctly identifies the regions to colour but often fails to predict the exact colour; this may be due to a lack of colour diversity in the dataset. Training on a larger, more diverse dataset could improve its generalisation.
