Colouring SAR Images Using Deep Learning
In this project I explore the use of deep learning, specifically Convolutional Neural Networks (CNNs), to colour SAR images. The project is inspired by the paper "Colorful Image Colorization" by Richard Zhang, Phillip Isola and Alexei A. Efros. The paper proposes a general framework for colouring images, and I wanted to see how well it would work on SAR images. The results were encouraging: the model colours the right regions of each image, although the exact hues are not always accurate.
The model in the paper is trained on ImageNet, which is a far more general dataset than this task needs. Fortunately, a dataset for colourisation is easy to curate: only colour images are required, because both the features and the labels can be extracted from them.
Creating Training Data:
The images come from the Sentinel-1&2 Image Pairs (SAR & Optical) dataset. They are preprocessed and resized to 256 × 256 pixels, then converted to the LAB colour space. The L channel is used as the input to the model, and the AB channels are used as the labels.
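The split into features and labels can be sketched as follows. This is a dependency-free illustration using the standard sRGB to CIE LAB formulas with a D65 white point, not the project's own preprocessing code; the function names `rgb_to_lab` and `split_features_labels` are hypothetical. In practice a library routine such as `skimage.color.rgb2lab` does the same conversion.

```python
import numpy as np

def rgb_to_lab(rgb):
    """Convert an sRGB image with values in [0, 1] to CIE LAB (D65 white point)."""
    rgb = np.asarray(rgb, dtype=float)
    # Undo the sRGB gamma curve.
    linear = np.where(rgb <= 0.04045, rgb / 12.92, ((rgb + 0.055) / 1.055) ** 2.4)
    # Linear RGB -> XYZ.
    m = np.array([[0.4124564, 0.3575761, 0.1804375],
                  [0.2126729, 0.7151522, 0.0721750],
                  [0.0193339, 0.1191920, 0.9503041]])
    xyz = linear @ m.T
    # Normalise by the D65 reference white.
    xyz = xyz / np.array([0.95047, 1.0, 1.08883])
    delta = 6.0 / 29.0
    f = np.where(xyz > delta ** 3, np.cbrt(xyz), xyz / (3 * delta ** 2) + 4.0 / 29.0)
    L = 116.0 * f[..., 1] - 16.0
    a = 500.0 * (f[..., 0] - f[..., 1])
    b = 200.0 * (f[..., 1] - f[..., 2])
    return np.stack([L, a, b], axis=-1)

def split_features_labels(rgb):
    """Return (L, ab): the model input and the colour labels for one image."""
    lab = rgb_to_lab(rgb)
    return lab[..., :1], lab[..., 1:]
```

For a 256 × 256 RGB image this yields an input of shape (256, 256, 1) and labels of shape (256, 256, 2).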
Model Details:

The model is a Convolutional Neural Network designed to predict the ‘a’ and ‘b’ channels of the LAB colour space given the ‘L’ channel. Previous works in this area treat colouring as a regression task and use the mean squared error loss function: they aim to predict the exact values of the ‘a’ and ‘b’ channels for every pixel.

However, the model in the paper treats the problem as a classification task and uses a custom loss function that is a variation of categorical cross-entropy. The authors quantise the ‘a’ and ‘b’ channels into bins and predict the probability of each bin for every pixel. A regression loss favours the average of all plausible colours, which tends to give desaturated results; predicting a distribution lets the model keep distinct colour modes and produces more realistic and visually appealing results.
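A toy example of the difference between the two formulations, using made-up numbers rather than values from the dataset: when a pixel could plausibly take either of two opposite colours, the prediction that minimises squared error is their mean, which sits at the grey centre of the ab plane, whereas a classifier can still pick one of the two saturated colours.

```python
import numpy as np

# Two equally plausible colours for the same grey value, as (a, b) pairs.
# These numbers are illustrative only.
candidates = np.array([[60.0, 40.0],     # a reddish colour
                       [-60.0, -40.0]])  # the opposite point in the ab plane
probs = np.array([0.5, 0.5])

# Regression with squared error is minimised by the expected value.
mse_optimal = probs @ candidates

# Classification keeps the whole distribution, so a mode can be picked.
mode = candidates[np.argmax(probs)]

def saturation(ab):
    """Distance from the grey axis in the ab plane."""
    return float(np.hypot(ab[0], ab[1]))

print(saturation(mse_optimal), saturation(mode))
```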

I train a CNN to map from a grayscale input to a distribution over quantised colour value outputs using the architecture shown below.

[Figure: architecture]
Architecture as described in the paper. Each conv layer refers to a block of 2 or 3 repeated conv and ReLU layers, followed by a BatchNorm [30] layer. The net has no pool layers. All changes in resolution are achieved through spatial downsampling or upsampling between conv blocks.
Key components of implementation:
1. Quantisation of the AB channels:
The AB channels are quantised into 274 bins. The bins are determined by clustering the AB channels of the training data into 274 clusters using K-means clustering. The model then predicts the probability of each bin for every pixel.
quantisation.py
    import os
    import pickle

    import numpy as np
    from PIL import Image

    # extract_ab and get_color_key are helpers defined elsewhere in the project.

    def quantize(ab_img):
        # Snap every a/b value down to a multiple of 10.
        return np.floor_divide(ab_img, 10) * 10

    def set_possible_colors(save_files=True):
        # save_files was an undefined name in the original listing;
        # it is exposed here as a parameter.
        height, width = 256, 256
        ab_to_possible_color = {}
        color_index = 0

        for image in os.listdir("images"):
            rgb_image = np.array(
                Image.open(f"images/{image}").convert("RGB"))
            ab_img = extract_ab(rgb_image)
            quantized_ab_image = quantize(ab_img)

            for y in range(height):
                for x in range(width):
                    a = quantized_ab_image[y][x][0]
                    b = quantized_ab_image[y][x][1]
                    ck = get_color_key(a, b)

                    # Give each newly seen colour bin the next free index.
                    if ck not in ab_to_possible_color:
                        ab_to_possible_color[ck] = color_index
                        color_index += 1

        if save_files:
            with open("ab_to_q_index_dict.p", "wb") as f:
                pickle.dump(ab_to_possible_color, f)

        return ab_to_possible_color
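The per-pixel Python loop above can be slow on a large image set. A possible vectorised alternative is sketched below, assuming the AB images are already available as NumPy arrays; `build_bin_index` is a hypothetical name, and unlike the original it numbers the bins in sorted order rather than in first-seen order.

```python
import numpy as np

GRID = 10  # bin width, matching the floor-divide-by-10 step above

def quantize(ab_img):
    """Snap every a/b value down to a multiple of GRID."""
    return np.floor_divide(ab_img, GRID) * GRID

def build_bin_index(ab_images):
    """Map each occupied (a, b) bin to a consecutive integer index."""
    stacked = np.concatenate([quantize(img).reshape(-1, 2) for img in ab_images])
    unique_bins = np.unique(stacked, axis=0)  # sorted, one row per occupied bin
    return {tuple(int(v) for v in row): i for i, row in enumerate(unique_bins)}
```

The length of the returned dictionary is the number of colour bins Q that the classifier has to predict.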
2. Converting AB channels to probability distribution:
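The AB channels are converted to a probability distribution over the colour bins. For every pixel, the 5 nearest bins to its (a, b) value are found and weighted with a Gaussian kernel (sigma = 5) of their distance; the weights are normalised to sum to 1 and all other bins are set to 0.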
ab_to_z.py
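    import numpy as np
    from sklearn.neighbors import NearestNeighbors

    # get_ab_domain is a helper defined elsewhere in the project; it is
    # assumed to return the (Q, 2) array of bin centres.

    def ab_to_z(ab_img, Q, sigma=5):
        w, h = ab_img.shape[0], ab_img.shape[1]
        pts = w * h
        ab_img_flat = np.reshape(ab_img, (pts, 2))
        # Zeros, so that bins outside the 5 nearest neighbours get probability 0.
        pts_flat = np.zeros((pts, Q))
        pts_ind = np.arange(0, pts)[:, np.newaxis]

        # Find the 5 nearest colour bins for every pixel.
        nbrs = NearestNeighbors(
            n_neighbors=5, algorithm='ball_tree'
        ).fit(get_ab_domain())
        distances, ind = nbrs.kneighbors(ab_img_flat)

        # Weight the neighbours with a Gaussian kernel and normalise per pixel.
        gaussian_kernel = np.exp(
            -distances**2 / (2 * sigma**2)
        )
        gaussian_kernel /= np.sum(gaussian_kernel, axis=1, keepdims=True)

        pts_flat[pts_ind, ind] = gaussian_kernel
        z = np.reshape(pts_flat, (w, h, Q))

        return z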
3. Loss function:
The loss function is a variation of categorical cross-entropy that adds class rebalancing: each pixel's loss is weighted to compensate for the imbalance of colours in the dataset.
loss_function.py
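    import pickle

    import tensorflow as tf
    from scipy.ndimage import gaussian_filter

    # get_ab_to_q_dict and get_loss_weights are helpers defined elsewhere in the project.

    def multinomial_crossentropy_loss(Z, Z_hat):
        Q = len(get_ab_to_q_dict())
        # Empirical colour distribution, smoothed with a Gaussian.
        with open("p_10000.p", "rb") as f:
            p = pickle.load(f)
        p_tilde = gaussian_filter(p, sigma=5)
        eps = 0.0001  # avoids log(0)
        weights = get_loss_weights(Z, Q, p_tilde)
        log = tf.math.log(Z_hat + eps)
        mul = log * Z
        # Sum over the colour bins (the last axis); this is the same as
        # axis 1 when Z is flattened to (pixels, Q).
        summ = tf.reduce_sum(mul, axis=-1)
        loss = -tf.reduce_sum(weights * summ)

        return loss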
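The `get_loss_weights` helper is not shown in this post. The sketch below follows the class-rebalancing scheme described in the paper, where each bin's weight is proportional to the inverse of a mix of the smoothed colour distribution and a uniform distribution, and the weights are normalised to have an expected value of 1. The function names and the default `lam=0.5` (the paper's value) are assumptions; the project's own helper may differ.

```python
import numpy as np

def rebalancing_weights(p_tilde, lam=0.5):
    """Per-bin weights w proportional to 1 / ((1 - lam) * p_tilde + lam / Q)."""
    p_tilde = np.asarray(p_tilde, dtype=float)
    p_tilde = p_tilde / p_tilde.sum()
    Q = p_tilde.size
    w = 1.0 / ((1.0 - lam) * p_tilde + lam / Q)
    # Normalise so that the expected weight under p_tilde is 1.
    w /= np.sum(p_tilde * w)
    return w

def pixel_weights(Z, w):
    """Weight of each pixel: the weight of its most probable target bin."""
    return w[np.argmax(Z, axis=-1)]
```

Rare colours receive weights above 1 and common colours weights below 1, so the loss is not dominated by the most frequent, usually desaturated, bins.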
4. Model architecture:
The model is the CNN described above. It is trained for 1000 epochs.
model.py
    from tensorflow.keras.layers import (
        BatchNormalization, Conv2D, Input, UpSampling2D)
    from tensorflow.keras.models import Model

    def ColourNet(l2_reg=None):
        # l2_reg: optional Keras kernel regulariser; it was an undefined
        # name in the original listing and is exposed here as a parameter.

        def conv(x, filters, strides=1, dilation=1, reg=l2_reg):
            # 3x3 convolution + ReLU with He-normal initialisation.
            return Conv2D(filters, (3, 3), strides=strides, padding="same",
                          dilation_rate=dilation, activation="relu",
                          kernel_initializer="he_normal",
                          kernel_regularizer=reg)(x)

        input_tensor = Input(shape=(256, 256, 1))

        # Block 1: 256 -> 128
        x = conv(input_tensor, 64)
        x = conv(x, 64, strides=2)
        x = BatchNormalization()(x)

        # Block 2: 128 -> 64
        x = conv(x, 128)
        x = conv(x, 128, strides=2)
        x = BatchNormalization()(x)

        # Block 3: 64 -> 32
        x = conv(x, 256)
        x = conv(x, 256)
        x = conv(x, 256, strides=2, reg=None)  # the original omits l2_reg here
        x = BatchNormalization()(x)

        # Block 4
        for _ in range(3):
            x = conv(x, 512)
        x = BatchNormalization()(x)

        # Blocks 5 and 6: dilated convolutions
        for _block in range(2):
            for _ in range(3):
                x = conv(x, 512, dilation=2)
            x = BatchNormalization()(x)

        # Block 7
        for _ in range(3):
            x = conv(x, 256)
        x = BatchNormalization()(x)

        # Block 8: 32 -> 64
        x = UpSampling2D(size=(2, 2))(x)
        for _ in range(3):
            x = conv(x, 128)
        x = BatchNormalization()(x)

        # Per-pixel distribution over the 274 colour bins, upsampled to 256 x 256.
        x = Conv2D(274, (1, 1), activation="softmax", padding="same")(x)
        outputs = UpSampling2D(size=(4, 4),
                               interpolation="bilinear")(x)
        model = Model(input_tensor, outputs)

        return model
5. Converting probability distribution to AB channels:
The probability distribution output by the model is converted back to AB channels by applying temperature scaling and then taking the expected value over the bin centres.
convert_to_ab.py
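    import numpy as np

    def temperature_scaling(Z, T):
        # Sharpen (T < 1) or flatten (T > 1) the distribution, then renormalise.
        # The small constant avoids log(0) warnings for empty bins.
        Z = np.exp(np.log(Z + 1e-8) / T)
        Z /= np.sum(Z, axis=-1, keepdims=True)

        return Z

    def convert_Z_to_ab(Z, ab_domain, T=0.38):
        Z = temperature_scaling(Z, T)

        # Expected (a, b) value under the rescaled distribution.
        ab_output = np.dot(Z, ab_domain)

        return ab_output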
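The effect of the temperature T can be seen on a single pixel. The sketch below repeats `temperature_scaling` so that it runs on its own; the three-bin colour domain and the predicted distribution are made-up values for illustration.

```python
import numpy as np

def temperature_scaling(Z, T):
    Z = np.exp(np.log(Z + 1e-8) / T)
    return Z / np.sum(Z, axis=-1, keepdims=True)

# Hypothetical three-bin colour domain and one pixel's predicted distribution.
ab_domain = np.array([[-40.0, 0.0], [0.0, 0.0], [40.0, 0.0]])
Z = np.array([[0.2, 0.3, 0.5]])

mean_ab = temperature_scaling(Z, 1.0) @ ab_domain       # plain expectation
annealed_ab = temperature_scaling(Z, 0.38) @ ab_domain  # sharper, nearer the mode
cold_ab = temperature_scaling(Z, 0.01) @ ab_domain      # effectively the mode

print(mean_ab[0, 0], annealed_ab[0, 0], cold_ab[0, 0])
```

At T = 1 the output is the plain mean, which tends to be desaturated; as T approaches 0 it approaches the most probable bin, which is vivid but can be spatially inconsistent. T = 0.38 sits between the two.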
For more details on the implementation, check out the code on my GitHub repo.
Results and Model Evaluation:

In the paper, the authors evaluate their model with a process similar to a Turing test, in which people judge whether a colourisation is real. In this project I instead evaluate my model with quantitative metrics: Mean Squared Error (MSE), Mean Absolute Error (MAE) and Peak Signal-to-Noise Ratio (PSNR).
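These three metrics can be computed as in the sketch below. It assumes the predicted and ground-truth images are arrays scaled to [0, 1] (so `max_val=1.0`); the function name and the scaling are assumptions, not necessarily what produced the numbers reported here.

```python
import numpy as np

def colour_metrics(pred, target, max_val=1.0):
    """Return MSE, MAE and PSNR (in dB) between two images of the same shape."""
    pred = np.asarray(pred, dtype=float)
    target = np.asarray(target, dtype=float)
    err = pred - target
    mse = float(np.mean(err ** 2))
    mae = float(np.mean(np.abs(err)))
    # PSNR is undefined for identical images; report infinity in that case.
    psnr = float("inf") if mse == 0 else float(10.0 * np.log10(max_val ** 2 / mse))
    return {"mse": mse, "mae": mae, "psnr": psnr}
```

Lower MSE and MAE and higher PSNR indicate a prediction closer to the reference image.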

The model performs fairly well, with an MSE of 0.0012, an MAE of 0.0034 and a PSNR of 37 dB. Below are some results that are indicative of how the model performs on SAR images.

[Figures: result1, result2, result3]
Improvements and Future Work:

The model performs fairly well on the task of colouring SAR images, but there is still room for improvement. It correctly predicts which areas to colour but cannot always predict the exact colour, which may be due to a lack of colour diversity in the dataset. Training on a larger and more diverse dataset could improve its generalisation.
