|
@@ -0,0 +1,200 @@
|
|
|
|
+# Import standard dependencies
|
|
|
|
+import cv2
|
|
|
|
+import os
|
|
|
|
+import uuid
|
|
|
|
+import random
|
|
|
|
+import numpy as np
|
|
|
|
+from matplotlib import pyplot as plt
|
|
|
|
+
|
|
|
|
+# Import tensorflow dependencies - Functional API
|
|
|
|
+from tensorflow.keras.models import Model
|
|
|
|
+from tensorflow.keras.layers import Layer, Conv2D, Dense, MaxPooling2D, Input, Flatten
|
|
|
|
+from tensorflow.keras.metrics import Precision, Recall
|
|
|
|
+import tensorflow as tf
|
|
|
|
+
|
|
|
|
# Dataset folders: positive, negative and anchor images under one root.
_DATA_ROOT = 'data'
POS_PATH, NEG_PATH, ANC_PATH = (
    os.path.join(_DATA_ROOT, folder)
    for folder in ('positive', 'negative', 'anchor')
)
|
|
|
|
+
|
|
|
|
+
|
|
|
|
def data_aug(img, num_augmentations=9):
    """Return a list of progressively augmented copies of ``img``.

    Each pass applies random brightness, contrast, horizontal flip, JPEG
    quality and saturation adjustments to the output of the *previous* pass,
    so the augmentations accumulate from one returned image to the next.

    Brightness and contrast use fixed seeds; the flip, JPEG-quality and
    saturation seeds are drawn from NumPy's global random state on every pass.

    Args:
        img: Image accepted by the ``tf.image.stateless_random_*`` ops.
        num_augmentations: Number of augmentation passes, i.e. the length of
            the returned list. Defaults to 9.

    Returns:
        A list with one augmented image per pass, in the order produced.
    """
    def _random_seed():
        # Two draws per seed, left to right.
        return (np.random.randint(100), np.random.randint(100))

    data = []
    for _ in range(num_augmentations):
        img = tf.image.stateless_random_brightness(img, max_delta=0.02, seed=(1, 2))
        img = tf.image.stateless_random_contrast(img, lower=0.6, upper=1, seed=(1, 3))
        # img = tf.image.stateless_random_crop(img, size=(20,20,3), seed=(1,2))
        img = tf.image.stateless_random_flip_left_right(img, seed=_random_seed())
        img = tf.image.stateless_random_jpeg_quality(
            img, min_jpeg_quality=90, max_jpeg_quality=100, seed=_random_seed())
        img = tf.image.stateless_random_saturation(
            img, lower=0.9, upper=1, seed=_random_seed())

        data.append(img)

    return data
|
|
|
|
+
|
|
|
|
def preprocess(file_path):
    """Load a JPEG file and turn it into a 100x100x3 tensor scaled to [0, 1].

    Args:
        file_path: Path of the JPEG file to read.

    Returns:
        The decoded image, resized to 100x100 and divided by 255.
    """
    # Read the raw bytes from disk.
    byte_img = tf.io.read_file(file_path)
    # Always decode to three channels so grayscale JPEGs still produce the
    # 100x100x3 shape the model input expects.
    img = tf.io.decode_jpeg(byte_img, channels=3)

    # Resize to the model's input resolution.
    img = tf.image.resize(img, (100, 100))
    # Scale pixel values from 0..255 down to 0..1.
    img = img / 255.0

    return img
|
|
|
|
+
|
|
|
|
def preprocess_twin(input_img, validation_img, label):
    """Preprocess both images of a pair and pass the label through unchanged.

    Args:
        input_img: File path of the first image of the pair.
        validation_img: File path of the second image of the pair.
        label: Label attached to the pair; returned as-is.

    Returns:
        A tuple ``(preprocessed input, preprocessed validation, label)``.
    """
    first_image = preprocess(input_img)
    second_image = preprocess(validation_img)
    return first_image, second_image, label
|
|
|
|
+
|
|
|
|
+# img_path = os.path.join(ANC_PATH, '924e839c-135f-11ec-b54e-a0cec8d2d278.jpg')
|
|
|
|
+# img = cv2.imread(img_path)
|
|
|
|
+# augmented_images = data_aug(img)
|
|
|
|
+#
|
|
|
|
+# for image in augmented_images:
|
|
|
|
+# cv2.imwrite(os.path.join(ANC_PATH, '{}.jpg'.format(uuid.uuid1())), image.numpy())
|
|
|
|
+#
|
|
|
|
+# for file_name in os.listdir(os.path.join(POS_PATH)):
|
|
|
|
+# img_path = os.path.join(POS_PATH, file_name)
|
|
|
|
+# img = cv2.imread(img_path)
|
|
|
|
+# augmented_images = data_aug(img)
|
|
|
|
+#
|
|
|
|
+# for image in augmented_images:
|
|
|
|
+# cv2.imwrite(os.path.join(POS_PATH, '{}.jpg'.format(uuid.uuid1())), image.numpy())
|
|
|
|
+
|
|
|
|
# Collect up to 3000 JPEG file paths per image folder.
anchor = tf.data.Dataset.list_files(ANC_PATH+'/*.jpg').take(3000)
positive = tf.data.Dataset.list_files(POS_PATH+'/*.jpg').take(3000)
negative = tf.data.Dataset.list_files(NEG_PATH+'/*.jpg').take(3000)

dir_test = anchor.as_numpy_iterator()

# (anchor, positive) pairs are labelled 1, (anchor, negative) pairs 0.
pair_count = len(anchor)
positives = tf.data.Dataset.zip(
    (anchor, positive, tf.data.Dataset.from_tensor_slices(tf.ones(pair_count))))
negatives = tf.data.Dataset.zip(
    (anchor, negative, tf.data.Dataset.from_tensor_slices(tf.zeros(pair_count))))
data = positives.concatenate(negatives)

# Pull one raw (path, path, label) sample and run it through preprocessing.
samples = data.as_numpy_iterator()
example = next(samples)
res = preprocess_twin(*example)

# Build the image pipeline.
data = data.map(preprocess_twin)
data = data.cache()
# The order must stay fixed across iterations: take()/skip() below only give
# disjoint train and test sets when every pass over `data` sees the same
# ordering. With the default reshuffling, samples leak between the two splits.
data = data.shuffle(buffer_size=10000, reshuffle_each_iteration=False)

# First 70% of the samples for training.
train_size = round(len(data) * .7)
train_data = data.take(train_size)
# Re-shuffle only the training split so batches still differ between epochs.
train_data = train_data.shuffle(buffer_size=1024)
train_data = train_data.batch(16)
train_data = train_data.prefetch(8)

# Remaining samples for testing.
test_data = data.skip(train_size)
test_data = test_data.take(round(len(data) * .3))
test_data = test_data.batch(16)
test_data = test_data.prefetch(8)
|
|
|
|
+
|
|
|
|
+# Build embedding layer
|
|
|
|
+
|
|
|
|
def make_embedding():
    """Build the convolutional embedding model.

    The network takes a 100x100x3 image through three convolution +
    max-pooling stages, a fourth convolution, a flatten and a 4096-unit
    sigmoid dense layer.

    Returns:
        A Keras ``Model`` named ``'embedding'`` mapping ``input_image`` to the
        4096-unit dense output.
    """
    image_in = Input(shape=(100, 100, 3), name='input_image')

    features = image_in
    # (filters, kernel size) of each convolution that is followed by pooling.
    for filters, kernel in ((64, (10, 10)), (128, (7, 7)), (128, (4, 4))):
        features = Conv2D(filters, kernel, activation='relu')(features)
        # NOTE(review): 64 is MaxPooling2D's pool_size (its first argument),
        # not a filter count - confirm whether a 2x2 window was intended.
        features = MaxPooling2D(pool_size=64, strides=(2, 2), padding='same')(features)

    # Final convolution, then flatten into the dense embedding vector.
    features = Conv2D(256, (4, 4), activation='relu')(features)
    flattened = Flatten()(features)
    vector = Dense(4096, activation='sigmoid')(flattened)

    return Model(inputs=[image_in], outputs=[vector], name='embedding')


# Single shared embedding model, used for both branches in make_siamese().
embedding = make_embedding()
|
|
|
|
+
|
|
|
|
+
|
|
|
|
+# Siamese L1 Distance class
|
|
|
|
class L1Dist(Layer):
    """Keras layer that returns the element-wise absolute difference of two inputs."""

    def __init__(self, **kwargs):
        """Create the layer.

        Args:
            **kwargs: Standard ``Layer`` keyword arguments (for example
                ``name``), forwarded to the base class.
        """
        # Forward the kwargs instead of discarding them, so options passed by
        # callers actually reach the base Layer.
        super().__init__(**kwargs)

    def call(self, input_embedding, validation_embedding):
        """Return ``|input_embedding - validation_embedding|`` element-wise."""
        return tf.math.abs(input_embedding - validation_embedding)
|
|
|
|
+
|
|
|
|
def make_siamese():
    """Assemble the two-input Siamese classifier.

    Both 100x100x3 inputs go through the shared module-level ``embedding``
    model; their embeddings are combined by ``L1Dist`` and fed to a single
    sigmoid unit.

    Returns:
        A Keras ``Model`` named ``'SiameseNetwork'`` with inputs
        ``[input_img, validation_img]`` and the sigmoid output.
    """
    image_shape = (100, 100, 3)
    input_image = Input(name='input_img', shape=image_shape)
    validation_image = Input(name='validation_img', shape=image_shape)

    # Same embedding model (shared weights) applied to each branch.
    twin_embeddings = [embedding(branch) for branch in (input_image, validation_image)]

    # Element-wise distance between the two embeddings, then the classifier.
    distances = L1Dist()(*twin_embeddings)
    classifier = Dense(1, activation='sigmoid')(distances)

    return Model(
        inputs=[input_image, validation_image],
        outputs=classifier,
        name='SiameseNetwork',
    )
|
|
|
|
+
|
|
|
|
+# Training
|
|
|
|
+
|
|
|
|
# Model, loss and optimizer used by train_step() and train().
siamese_model = make_siamese()
binary_cross_loss = tf.losses.BinaryCrossentropy()
opt = tf.keras.optimizers.Adam(learning_rate=1e-4)

# Checkpoint object tracking both the optimizer and the model.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(opt=opt, siamese_model=siamese_model)
|
|
|
|
+
|
|
|
|
@tf.function
def train_step(batch):
    """Run one optimization step of ``siamese_model`` on a single batch.

    Args:
        batch: Sequence whose first two entries are the two image batches and
            whose third entry is the label batch.

    Returns:
        The binary cross-entropy loss for this batch.
    """
    # Split the batch into model inputs and labels.
    X = batch[:2]
    y = batch[2]

    # Record the forward pass so gradients can be taken afterwards.
    with tf.GradientTape() as tape:
        yhat = siamese_model(X, training=True)
        loss = binary_cross_loss(y, yhat)

    # No Python print() here: inside a tf.function it only runs while the
    # function is traced and shows a symbolic tensor, not the loss value.

    # Back-propagate and apply the weight update.
    grad = tape.gradient(loss, siamese_model.trainable_variables)
    opt.apply_gradients(zip(grad, siamese_model.trainable_variables))

    return loss
|
|
|
|
+
|
|
|
|
+
|
|
|
|
def train(data, EPOCHS):
    """Train ``siamese_model`` for ``EPOCHS`` passes over ``data``.

    After every epoch the last batch loss and the epoch's recall and
    precision are printed; a checkpoint is saved every 10th epoch.

    Args:
        data: Batched dataset; each batch holds two image batches followed by
            the label batch. Must support ``len()``.
        EPOCHS: Number of epochs to run.
    """
    for epoch in range(1, EPOCHS + 1):
        print('\n Epoch {}/{}'.format(epoch, EPOCHS))
        progbar = tf.keras.utils.Progbar(len(data))

        # Fresh metric objects so each epoch is reported on its own.
        r = Recall()
        p = Precision()

        # Stays None when the dataset yields no batches.
        loss = None

        for idx, batch in enumerate(data):
            loss = train_step(batch)
            # Call the model directly rather than Model.predict(): predict()
            # is meant for large inputs and adds per-call overhead (and its
            # own progress output) when used on one batch inside a loop.
            yhat = siamese_model(batch[:2], training=False)
            r.update_state(batch[2], yhat)
            p.update_state(batch[2], yhat)
            progbar.update(idx + 1)

        # Nothing to report for an empty dataset.
        if loss is not None:
            print(loss.numpy(), r.result().numpy(), p.result().numpy())

        # Save checkpoints
        if epoch % 10 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)
|
|
|
|
+
|
|
|
|
# Number of epochs passed to train() when the file is run as a script.
EPOCHS = 50

if __name__ == "__main__":
    train(data=train_data, EPOCHS=EPOCHS)
|