# Import standard dependencies
import cv2
import os
import uuid
import random
import numpy as np
from matplotlib import pyplot as plt

# Import tensorflow dependencies - Functional API
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Conv2D, Dense, MaxPooling2D, Input, Flatten
from tensorflow.keras.metrics import Precision, Recall
import tensorflow as tf

train_mode = False
retrain_mode = False
MODEL_NAME = "siamesemodelv4.h5"

# Avoid OOM errors by setting GPU memory consumption growth
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Setup paths
POS_PATH = os.path.join('data', 'positive')
NEG_PATH = os.path.join('data', 'negative')
ANC_PATH = os.path.join('data', 'anchor')


def data_aug(img):
    # Produce nine randomly perturbed copies of a single image
    data = []
    for i in range(9):
        img = tf.image.stateless_random_brightness(img, max_delta=0.02, seed=(1, 2))
        img = tf.image.stateless_random_contrast(img, lower=0.6, upper=1, seed=(1, 3))
        # img = tf.image.stateless_random_crop(img, size=(20, 20, 3), seed=(1, 2))
        img = tf.image.stateless_random_flip_left_right(img, seed=(np.random.randint(100), np.random.randint(100)))
        img = tf.image.stateless_random_jpeg_quality(img, min_jpeg_quality=90, max_jpeg_quality=100, seed=(np.random.randint(100), np.random.randint(100)))
        img = tf.image.stateless_random_saturation(img, lower=0.9, upper=1, seed=(np.random.randint(100), np.random.randint(100)))
        data.append(img)
    return data


def preprocess(file_path):
    # Read in image from file path
    byte_img = tf.io.read_file(file_path)
    # Decode the JPEG
    img = tf.io.decode_jpeg(byte_img)
    # Resize the image to 100x100x3
    img = tf.image.resize(img, (100, 100))
    # Scale pixel values to the range [0, 1]
    img = img / 255.0
    return img


def preprocess_twin(input_img, validation_img, label):
    return (preprocess(input_img), preprocess(validation_img), label)


# One-off augmentation pass, left commented out after the augmented
# images were written to disk
# for file_name in os.listdir(os.path.join(ANC_PATH)):
#     img_path = os.path.join(ANC_PATH, file_name)
#     img = cv2.imread(img_path)
#     augmented_images = data_aug(img)
#
#     for image in augmented_images:
#         cv2.imwrite(os.path.join(ANC_PATH, 'surya-{}.jpg'.format(uuid.uuid1())), image.numpy())
#
# for file_name in os.listdir(os.path.join(POS_PATH)):
#     img_path = os.path.join(POS_PATH, file_name)
#     img = cv2.imread(img_path)
#     augmented_images = data_aug(img)
#
#     for image in augmented_images:
#         cv2.imwrite(os.path.join(POS_PATH, 'surya-{}.jpg'.format(uuid.uuid1())), image.numpy())

anchor = tf.data.Dataset.list_files(ANC_PATH + '/*.jpg').take(2000)
positive = tf.data.Dataset.list_files(POS_PATH + '/*.jpg').take(2000)
negative = tf.data.Dataset.list_files(NEG_PATH + '/*.jpg').take(2000)

dir_test = anchor.as_numpy_iterator()

# Label anchor/positive pairs with 1 and anchor/negative pairs with 0,
# then combine them into a single labelled dataset
positives = tf.data.Dataset.zip((anchor, positive, tf.data.Dataset.from_tensor_slices(tf.ones(len(anchor)))))
negatives = tf.data.Dataset.zip((anchor, negative, tf.data.Dataset.from_tensor_slices(tf.zeros(len(anchor)))))
data = positives.concatenate(negatives)

# Sanity check: pull one raw example through the preprocessing step
samples = data.as_numpy_iterator()
example = samples.next()
res = preprocess_twin(*example)

# Build the dataloader pipeline
data = data.map(preprocess_twin)
data = data.cache()
data = data.shuffle(buffer_size=10000)

# Training partition (70%)
train_data = data.take(round(len(data) * .7))
train_data = train_data.batch(16)
train_data = train_data.prefetch(8)

# Testing partition (30%)
test_data = data.skip(round(len(data) * .7))
test_data = test_data.take(round(len(data) * .3))
test_data = test_data.batch(16)
test_data = test_data.prefetch(8)


def save_model(model):
    model.save(MODEL_NAME)
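
# Optional sanity-check helper (a sketch, not part of the original pipeline):
# pulls one batch out of train_data and displays the first anchor/validation
# pair with matplotlib, which is imported above but otherwise unused. The
# helper name `show_sample_pair` is illustrative only; it is never called.
def show_sample_pair():
    sample = train_data.as_numpy_iterator().next()
    plt.subplot(1, 2, 1)
    plt.imshow(sample[0][0])   # first anchor image in the batch
    plt.subplot(1, 2, 2)
    plt.imshow(sample[1][0])   # its paired validation image
    plt.title('label: {}'.format(sample[2][0]))  # 1 = same person, 0 = different
    plt.show()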
def load_model():
    if retrain_mode:
        return None
    try:
        model = tf.keras.models.load_model(MODEL_NAME, custom_objects={'L1Dist': L1Dist, 'BinaryCrossentropy': tf.losses.BinaryCrossentropy})
    except (ImportError, IOError):
        return None
    return model


# Build the embedding network
def make_embedding():
    inp = Input(shape=(100, 100, 3), name='input_image')

    # First block
    c1 = Conv2D(64, (10, 10), activation='relu')(inp)
    m1 = MaxPooling2D(pool_size=(2, 2), padding='same')(c1)

    # Second block
    c2 = Conv2D(128, (7, 7), activation='relu')(m1)
    m2 = MaxPooling2D(pool_size=(2, 2), padding='same')(c2)

    # Third block
    c3 = Conv2D(128, (4, 4), activation='relu')(m2)
    m3 = MaxPooling2D(pool_size=(2, 2), padding='same')(c3)

    # Final embedding block
    c4 = Conv2D(256, (4, 4), activation='relu')(m3)
    f1 = Flatten()(c4)
    d1 = Dense(4096, activation='sigmoid')(f1)

    return Model(inputs=[inp], outputs=[d1], name='embedding')


embedding = make_embedding()


# Siamese L1 distance layer
class L1Dist(Layer):
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # Similarity calculation: element-wise absolute difference of the embeddings
    def call(self, input_embedding, validation_embedding):
        return tf.math.abs(input_embedding - validation_embedding)


def make_siamese():
    # Two inputs: the anchor image and the image to validate against it
    input_image = Input(name='input_img', shape=(100, 100, 3))
    validation_image = Input(name='validation_img', shape=(100, 100, 3))

    # Both images pass through the same (shared) embedding network
    inp_embedding = embedding(input_image)
    val_embedding = embedding(validation_image)

    # Combine the two embeddings via the L1 distance layer
    siamese_layer = L1Dist()
    distances = siamese_layer(inp_embedding, val_embedding)

    # Classification layer: 1 = same person, 0 = different person
    classifier = Dense(1, activation='sigmoid')(distances)

    return Model(inputs=[input_image, validation_image], outputs=classifier, name='SiameseNetwork')


# Training setup: reuse a saved model if one exists, otherwise build from scratch
siamese_model = load_model()
if siamese_model is None:
    siamese_model = make_siamese()
    retrain_mode = True

binary_cross_loss = tf.losses.BinaryCrossentropy()
opt = tf.keras.optimizers.Adam(1e-4)

# checkpoint_dir = './training_checkpoints'
# checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
# checkpoint = tf.train.Checkpoint(opt=opt, siamese_model=siamese_model)


@tf.function
def train_step(batch):
    # Record all operations for automatic differentiation
    with tf.GradientTape() as tape:
        # Get anchor and positive/negative images
        X = batch[:2]
        # Get labels
        y = batch[2]

        # Forward pass
        yhat = siamese_model(X, training=True)
        # Calculate loss
        loss = binary_cross_loss(y, yhat)

    # Calculate gradients and apply the weight updates to the siamese model
    grad = tape.gradient(loss, siamese_model.trainable_variables)
    opt.apply_gradients(zip(grad, siamese_model.trainable_variables))

    return loss


def train(data, EPOCHS):
    # Loop through epochs
    for epoch in range(1, EPOCHS + 1):
        print('\n Epoch {}/{}'.format(epoch, EPOCHS))
        progbar = tf.keras.utils.Progbar(len(data))

        # Metric objects
        r = Recall()
        p = Precision()

        # Loop through each batch
        for idx, batch in enumerate(data):
            loss = train_step(batch)
            yhat = siamese_model.predict(batch[:2])
            r.update_state(batch[2], yhat)
            p.update_state(batch[2], yhat)
            progbar.update(idx + 1)
        print(loss.numpy(), r.result().numpy(), p.result().numpy())

        # Save checkpoints
        # if epoch % 10 == 0:
        #     checkpoint.save(file_prefix=checkpoint_prefix)


def predict():
    # Evaluate recall and precision over the whole test partition
    r = Recall()
    p = Precision()
    for test_input, test_val, y_true in test_data.as_numpy_iterator():
        yhat = siamese_model.predict([test_input, test_val])
        r.update_state(y_true, yhat)
        p.update_state(y_true, yhat)
    print(r.result().numpy(), p.result().numpy())
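
# Toy illustration (a hedged sketch, not called anywhere): L1Dist returns the
# element-wise absolute difference of two embeddings, which the final
# Dense(1, sigmoid) layer then scores. Two identical faces should therefore
# produce a near-zero distance vector. The tensors below are made-up values,
# not real embeddings, and the helper name `demo_l1_dist` is illustrative only.
def demo_l1_dist():
    a = tf.constant([[1.0, 2.0, 3.0]])
    b = tf.constant([[0.5, 2.0, 5.0]])
    print(L1Dist()(a, b).numpy())  # -> [[0.5 0.  2. ]]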
def verify(model, verification_threshold):
    # Build results array
    results = []
    paths = os.listdir(os.path.join('application_data', 'verification_images'))
    print(paths)

    # The input image is the same for every comparison, so preprocess it once
    input_img = preprocess(os.path.join('application_data', 'input_image', 'input_image2.jpg'))

    for image in paths:
        validation_img = preprocess(os.path.join('application_data', 'verification_images', image))

        # expand_dims adds the batch dimension; list() splits the pair into the two model inputs
        result = model.predict(list(np.expand_dims([input_img, validation_img], axis=1)))
        print(image, result[0][0])
        results.append(result[0][0])

    # Detection Threshold: metric above which a prediction is considered positive
    pos_array = np.array(results)
    index = np.argmax(pos_array)
    detection = pos_array[index]
    filename = paths[index]

    # Verification Threshold: proportion of positive predictions / total positive samples
    # verification = detection / len(os.listdir(os.path.join('application_data', 'verification_images')))
    verification = detection
    verified = verification > verification_threshold

    return results, verified, detection, filename


EPOCHS = 50

if __name__ == '__main__':
    if train_mode or retrain_mode:
        train(train_data, EPOCHS)
        save_model(siamese_model)
    # predict()
    r, v, d, f = verify(siamese_model, 0.5)
    print(f"Results: {r}\nVerified: {v}\nDetection: {d}\nFilename: {f}")
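
# Expected on-disk layout, inferred from the paths used above (an assumption;
# adjust to your own data):
#
#   data/anchor/*.jpg                               anchor images
#   data/positive/*.jpg                             positives (same person as anchor)
#   data/negative/*.jpg                             negatives (different people)
#   application_data/input_image/input_image2.jpg   image to verify
#   application_data/verification_images/*.jpg      reference images compared against the input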