main.py

# Import standard dependencies
import cv2
import os
import uuid
import random
import numpy as np
from matplotlib import pyplot as plt

# Import tensorflow dependencies - Functional API
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Conv2D, Dense, MaxPooling2D, Input, Flatten
from tensorflow.keras.metrics import Precision, Recall
import tensorflow as tf

# Setup paths
POS_PATH = os.path.join('data', 'positive')
NEG_PATH = os.path.join('data', 'negative')
ANC_PATH = os.path.join('data', 'anchor')


def data_aug(img):
    data = []
    for i in range(9):
        img = tf.image.stateless_random_brightness(img, max_delta=0.02, seed=(1, 2))
        img = tf.image.stateless_random_contrast(img, lower=0.6, upper=1, seed=(1, 3))
        # img = tf.image.stateless_random_crop(img, size=(20, 20, 3), seed=(1, 2))
        img = tf.image.stateless_random_flip_left_right(img, seed=(np.random.randint(100), np.random.randint(100)))
        img = tf.image.stateless_random_jpeg_quality(img, min_jpeg_quality=90, max_jpeg_quality=100,
                                                     seed=(np.random.randint(100), np.random.randint(100)))
        img = tf.image.stateless_random_saturation(img, lower=0.9, upper=1,
                                                   seed=(np.random.randint(100), np.random.randint(100)))
        data.append(img)
    return data
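
# Note: each stateless op above takes a two-element seed; a fixed seed such
# as the (1, 2) on brightness applies the same perturbation on every call,
# while the flip/quality/saturation steps draw fresh numpy seeds so they vary.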


def preprocess(file_path):
    # Read in image from file path
    byte_img = tf.io.read_file(file_path)
    # Load in the image
    img = tf.io.decode_jpeg(byte_img)
    # Preprocessing steps - resizing the image to be 100x100x3
    img = tf.image.resize(img, (100, 100))
    # Scale image to be between 0 and 1
    img = img / 255.0
    # Return image
    return img
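
# Quick sanity check for preprocess (commented out; assumes at least one
# .jpg exists under data/anchor):
# sample = preprocess(os.path.join(ANC_PATH, os.listdir(ANC_PATH)[0]))
# print(sample.shape, sample.numpy().min(), sample.numpy().max())  # (100, 100, 3) in [0, 1]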


def preprocess_twin(input_img, validation_img, label):
    return preprocess(input_img), preprocess(validation_img), label

# img_path = os.path.join(ANC_PATH, '924e839c-135f-11ec-b54e-a0cec8d2d278.jpg')
# img = cv2.imread(img_path)
# augmented_images = data_aug(img)
#
# for image in augmented_images:
#     cv2.imwrite(os.path.join(ANC_PATH, '{}.jpg'.format(uuid.uuid1())), image.numpy())
#
# for file_name in os.listdir(os.path.join(POS_PATH)):
#     img_path = os.path.join(POS_PATH, file_name)
#     img = cv2.imread(img_path)
#     augmented_images = data_aug(img)
#
#     for image in augmented_images:
#         cv2.imwrite(os.path.join(POS_PATH, '{}.jpg'.format(uuid.uuid1())), image.numpy())

anchor = tf.data.Dataset.list_files(ANC_PATH + '/*.jpg').take(3000)
positive = tf.data.Dataset.list_files(POS_PATH + '/*.jpg').take(3000)
negative = tf.data.Dataset.list_files(NEG_PATH + '/*.jpg').take(3000)
dir_test = anchor.as_numpy_iterator()

# Label anchor/positive pairs 1 and anchor/negative pairs 0
positives = tf.data.Dataset.zip((anchor, positive, tf.data.Dataset.from_tensor_slices(tf.ones(len(anchor)))))
negatives = tf.data.Dataset.zip((anchor, negative, tf.data.Dataset.from_tensor_slices(tf.zeros(len(anchor)))))
data = positives.concatenate(negatives)

# Sanity-check one raw example through the twin preprocessor
samples = data.as_numpy_iterator()
example = samples.next()
res = preprocess_twin(*example)

# Build dataloader pipeline; shuffle once (reshuffle_each_iteration=False)
# so the take/skip split below stays disjoint across epochs
data = data.map(preprocess_twin)
data = data.cache()
data = data.shuffle(buffer_size=10000, reshuffle_each_iteration=False)

# Training partition (70%)
train_data = data.take(round(len(data) * .7))
train_data = train_data.batch(16)
train_data = train_data.prefetch(8)

# Testing partition (30%)
test_data = data.skip(round(len(data) * .7))
test_data = test_data.take(round(len(data) * .3))
test_data = test_data.batch(16)
test_data = test_data.prefetch(8)
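
# Optional visual check of one training pair (commented out; uses the
# matplotlib import above):
# anc, val, label = train_data.as_numpy_iterator().next()
# fig, axes = plt.subplots(1, 2)
# axes[0].imshow(anc[0])
# axes[1].imshow(val[0])
# fig.suptitle('label = {}'.format(label[0]))
# plt.show()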


# Build embedding layer
def make_embedding():
    inp = Input(shape=(100, 100, 3), name='input_image')

    c1 = Conv2D(64, (10, 10), activation='relu')(inp)
    m1 = MaxPooling2D(pool_size=(2, 2), padding='same')(c1)

    c2 = Conv2D(128, (7, 7), activation='relu')(m1)
    m2 = MaxPooling2D(pool_size=(2, 2), padding='same')(c2)

    c3 = Conv2D(128, (4, 4), activation='relu')(m2)
    m3 = MaxPooling2D(pool_size=(2, 2), padding='same')(c3)

    c4 = Conv2D(256, (4, 4), activation='relu')(m3)
    f1 = Flatten()(c4)
    d1 = Dense(4096, activation='sigmoid')(f1)

    return Model(inputs=[inp], outputs=[d1], name='embedding')


embedding = make_embedding()
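
# Expected activation shapes through the embedding (2x2 pooling, 'valid' convs):
# 100x100x3 -> 91x91x64 -> 46x46x64 -> 40x40x128 -> 20x20x128
#           -> 17x17x128 -> 9x9x128 -> 6x6x256 -> 9216 (flatten) -> 4096
# embedding.summary() prints the same breakdown for verification.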


# Siamese L1 Distance class
class L1Dist(Layer):
    # Init method - inheritance
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # Magic happens here - similarity calculation
    def call(self, input_embedding, validation_embedding):
        return tf.math.abs(input_embedding - validation_embedding)
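
# Quick check of the distance layer on dummy embeddings (commented out):
# a, b = tf.ones((1, 4096)), tf.zeros((1, 4096))
# print(L1Dist()(a, b))  # element-wise |a - b| -> all ones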


def make_siamese():
    input_image = Input(name='input_img', shape=(100, 100, 3))
    validation_image = Input(name='validation_img', shape=(100, 100, 3))

    # Both images pass through the same embedding network (shared weights)
    inp_embedding = embedding(input_image)
    val_embedding = embedding(validation_image)

    # Combine the embeddings via component-wise L1 distance
    siamese_layer = L1Dist()
    distances = siamese_layer(inp_embedding, val_embedding)

    # Classification layer: 1 = same person, 0 = different
    classifier = Dense(1, activation='sigmoid')(distances)

    return Model(inputs=[input_image, validation_image], outputs=classifier, name='SiameseNetwork')
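
# Smoke test of the wiring on random inputs (commented out; builds a
# throwaway model instance):
# dummy = tf.random.uniform((1, 100, 100, 3))
# print(make_siamese()([dummy, dummy]).shape)  # -> (1, 1) similarity score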


# Training
siamese_model = make_siamese()
binary_cross_loss = tf.losses.BinaryCrossentropy()
opt = tf.keras.optimizers.Adam(1e-4)

checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
checkpoint = tf.train.Checkpoint(opt=opt, siamese_model=siamese_model)
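
# To resume training later, restore the latest checkpoint:
# checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))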


@tf.function
def train_step(batch):
    # Record all of our operations
    with tf.GradientTape() as tape:
        # Get anchor and positive/negative image
        X = batch[:2]
        # Get label
        y = batch[2]

        # Forward pass
        yhat = siamese_model(X, training=True)
        # Calculate loss
        loss = binary_cross_loss(y, yhat)
    # A bare print only fires while the tf.function traces, not every step
    print(loss)

    # Calculate gradients
    grad = tape.gradient(loss, siamese_model.trainable_variables)
    # Calculate updated weights and apply to siamese model
    opt.apply_gradients(zip(grad, siamese_model.trainable_variables))

    # Return loss
    return loss


def train(data, EPOCHS):
    # Loop through epochs
    for epoch in range(1, EPOCHS + 1):
        print('\nEpoch {}/{}'.format(epoch, EPOCHS))
        progbar = tf.keras.utils.Progbar(len(data))

        # Creating the metric objects
        r = Recall()
        p = Precision()

        # Loop through each batch
        for idx, batch in enumerate(data):
            # Run train step here
            loss = train_step(batch)
            yhat = siamese_model.predict(batch[:2])
            r.update_state(batch[2], yhat)
            p.update_state(batch[2], yhat)
            progbar.update(idx + 1)
        print(loss.numpy(), r.result().numpy(), p.result().numpy())

        # Save checkpoints every 10 epochs
        if epoch % 10 == 0:
            checkpoint.save(file_prefix=checkpoint_prefix)


EPOCHS = 50

if __name__ == '__main__':
    train(train_data, EPOCHS)
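
    # Minimal evaluation sketch over the held-out test split (assumes the
    # model was just trained above; Recall/Precision threshold the sigmoid
    # outputs at their default 0.5):
    test_r, test_p = Recall(), Precision()
    for test_input, test_val, y_true in test_data.as_numpy_iterator():
        y_hat = siamese_model.predict([test_input, test_val])
        test_r.update_state(y_true, y_hat)
        test_p.update_state(y_true, y_hat)
    print('Test Recall:', test_r.result().numpy(), 'Precision:', test_p.result().numpy())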