main.py

# Import standard dependencies
import cv2
import os
import uuid
import random
import numpy as np
from matplotlib import pyplot as plt

# Import tensorflow dependencies - Functional API
from tensorflow.keras.models import Model
from tensorflow.keras.layers import Layer, Conv2D, Dense, MaxPooling2D, Input, Flatten
from tensorflow.keras.metrics import Precision, Recall
import tensorflow as tf

train_mode = False
retrain_mode = False
MODEL_NAME = "siamesemodelv4.h5"

# Avoid OOM errors by setting GPU Memory Consumption Growth
gpus = tf.config.experimental.list_physical_devices('GPU')
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)

# Setup paths
POS_PATH = os.path.join('data', 'positive')
NEG_PATH = os.path.join('data', 'negative')
ANC_PATH = os.path.join('data', 'anchor')


def data_aug(img):
    data = []
    for i in range(9):
        img = tf.image.stateless_random_brightness(img, max_delta=0.02, seed=(1, 2))
        img = tf.image.stateless_random_contrast(img, lower=0.6, upper=1, seed=(1, 3))
        # img = tf.image.stateless_random_crop(img, size=(20, 20, 3), seed=(1, 2))
        img = tf.image.stateless_random_flip_left_right(img, seed=(np.random.randint(100), np.random.randint(100)))
        img = tf.image.stateless_random_jpeg_quality(img, min_jpeg_quality=90, max_jpeg_quality=100,
                                                     seed=(np.random.randint(100), np.random.randint(100)))
        img = tf.image.stateless_random_saturation(img, lower=0.9, upper=1,
                                                   seed=(np.random.randint(100), np.random.randint(100)))
        data.append(img)
    return data
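
# Note on the seeds above: the tf.image.stateless_random_* ops are deterministic for a given
# seed pair, so the fixed seeds (1, 2) and (1, 3) apply the same brightness/contrast jitter
# on every call, while the np.random.randint seeds vary per iteration. Each call to data_aug
# therefore returns nine augmented copies of one input image.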


def preprocess(file_path):
    # Read in image from file path
    byte_img = tf.io.read_file(file_path)
    # Load in the image
    img = tf.io.decode_jpeg(byte_img)
    # Preprocessing steps - resizing the image to be 100x100x3
    img = tf.image.resize(img, (100, 100))
    # Scale image to be between 0 and 1
    img = img / 255.0
    # Return image
    return img


def preprocess_twin(input_img, validation_img, label):
    return (preprocess(input_img), preprocess(validation_img), label)

# for file_name in os.listdir(os.path.join(ANC_PATH)):
#     img_path = os.path.join(ANC_PATH, file_name)
#     img = cv2.imread(img_path)
#     augmented_images = data_aug(img)
#
#     for image in augmented_images:
#         cv2.imwrite(os.path.join(ANC_PATH, 'surya-{}.jpg'.format(uuid.uuid1())), image.numpy())
#
# for file_name in os.listdir(os.path.join(POS_PATH)):
#     img_path = os.path.join(POS_PATH, file_name)
#     img = cv2.imread(img_path)
#     augmented_images = data_aug(img)
#
#     for image in augmented_images:
#         cv2.imwrite(os.path.join(POS_PATH, 'surya-{}.jpg'.format(uuid.uuid1())), image.numpy())

anchor = tf.data.Dataset.list_files(ANC_PATH + '/*.jpg').take(2000)
positive = tf.data.Dataset.list_files(POS_PATH + '/*.jpg').take(2000)
negative = tf.data.Dataset.list_files(NEG_PATH + '/*.jpg').take(2000)

dir_test = anchor.as_numpy_iterator()

# Label anchor/positive pairs with 1 and anchor/negative pairs with 0
positives = tf.data.Dataset.zip((anchor, positive, tf.data.Dataset.from_tensor_slices(tf.ones(len(anchor)))))
negatives = tf.data.Dataset.zip((anchor, negative, tf.data.Dataset.from_tensor_slices(tf.zeros(len(anchor)))))
data = positives.concatenate(negatives)

samples = data.as_numpy_iterator()
example = samples.next()
res = preprocess_twin(*example)

# Build dataloader pipeline
data = data.map(preprocess_twin)
data = data.cache()
data = data.shuffle(buffer_size=10000)

# Training partition (70%)
train_data = data.take(round(len(data) * .7))
train_data = train_data.batch(16)
train_data = train_data.prefetch(8)

# Testing partition (30%)
test_data = data.skip(round(len(data) * .7))
test_data = test_data.take(round(len(data) * .3))
test_data = test_data.batch(16)
test_data = test_data.prefetch(8)
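
# Optional sanity check (not part of the pipeline itself): peek at one training batch.
# Expected shapes: two image tensors of (16, 100, 100, 3) plus a label vector of (16,).
# sample_batch = train_data.as_numpy_iterator().next()
# print(sample_batch[0].shape, sample_batch[1].shape, sample_batch[2].shape)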


def save_model(model):
    model.save(MODEL_NAME)


def load_model():
    model = None
    if retrain_mode:
        return model
    try:
        model = tf.keras.models.load_model(MODEL_NAME, custom_objects={'L1Dist': L1Dist,
                                                                       'BinaryCrossentropy': tf.losses.BinaryCrossentropy})
    except (ImportError, IOError):
        return None
    return model
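
# load_model passes custom_objects because L1Dist is a user-defined layer; without that
# mapping, Keras cannot rebuild the saved SiameseNetwork from the .h5 file.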


# Build embedding layer
def make_embedding():
    inp = Input(shape=(100, 100, 3), name='input_image')

    c1 = Conv2D(64, (10, 10), activation='relu')(inp)
    m1 = MaxPooling2D(64, (2, 2), padding='same')(c1)

    c2 = Conv2D(128, (7, 7), activation='relu')(m1)
    m2 = MaxPooling2D(64, (2, 2), padding='same')(c2)

    c3 = Conv2D(128, (4, 4), activation='relu')(m2)
    m3 = MaxPooling2D(64, (2, 2), padding='same')(c3)

    c4 = Conv2D(256, (4, 4), activation='relu')(m3)
    f1 = Flatten()(c4)
    d1 = Dense(4096, activation='sigmoid')(f1)

    return Model(inputs=[inp], outputs=[d1], name='embedding')


embedding = make_embedding()


# Siamese L1 Distance class
class L1Dist(Layer):
    # Init method - pass kwargs through so Keras can restore layer config (e.g. name) on load
    def __init__(self, **kwargs):
        super().__init__(**kwargs)

    # Magic happens here - similarity calculation
    def call(self, input_embedding, validation_embedding):
        return tf.math.abs(input_embedding - validation_embedding)
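
# A quick illustration of what L1Dist produces (hypothetical values): for embeddings
# [1.0, 2.0] and [0.5, 3.0] the layer returns [0.5, 1.0]; the Dense(1, sigmoid) head in
# make_siamese below turns that distance vector into a single match score.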


def make_siamese():
    input_image = Input(name='input_img', shape=(100, 100, 3))
    validation_image = Input(name='validation_img', shape=(100, 100, 3))

    inp_embedding = embedding(input_image)
    val_embedding = embedding(validation_image)

    siamese_layer = L1Dist()
    distances = siamese_layer(inp_embedding, val_embedding)

    classifier = Dense(1, activation='sigmoid')(distances)

    return Model(inputs=[input_image, validation_image], outputs=classifier, name='SiameseNetwork')
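
# Both image inputs go through the same `embedding` model instance, so the two branches
# share weights, which is the defining property of a Siamese network. The final sigmoid
# score is trained towards 1 for matching pairs and 0 for non-matching pairs.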


# Training
siamese_model = load_model()
if siamese_model is None:
    siamese_model = make_siamese()
    retrain_mode = True

binary_cross_loss = tf.losses.BinaryCrossentropy()
opt = tf.keras.optimizers.Adam(1e-4)

# checkpoint_dir = './training_checkpoints'
# checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
# checkpoint = tf.train.Checkpoint(opt=opt, siamese_model=siamese_model)


@tf.function
def train_step(batch):
    # Record all of our operations
    with tf.GradientTape() as tape:
        # Get anchor and positive/negative image
        X = batch[:2]
        # Get label
        y = batch[2]

        # Forward pass
        yhat = siamese_model(X, training=True)
        # Calculate loss
        loss = binary_cross_loss(y, yhat)
    print(loss)

    # Calculate gradients
    grad = tape.gradient(loss, siamese_model.trainable_variables)

    # Calculate updated weights and apply to siamese model
    opt.apply_gradients(zip(grad, siamese_model.trainable_variables))

    # Return loss
    return loss
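
# Note: print(loss) above only fires while the @tf.function is being traced, not on every
# training step; tf.print(loss) would be needed for per-step logging.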


def train(data, EPOCHS):
    # Loop through epochs
    for epoch in range(1, EPOCHS + 1):
        print('\n Epoch {}/{}'.format(epoch, EPOCHS))
        progbar = tf.keras.utils.Progbar(len(data))

        # Creating metric objects
        r = Recall()
        p = Precision()

        # Loop through each batch
        for idx, batch in enumerate(data):
            # Run train step here
            loss = train_step(batch)
            yhat = siamese_model.predict(batch[:2])
            r.update_state(batch[2], yhat)
            p.update_state(batch[2], yhat)
            progbar.update(idx + 1)
        print(loss.numpy(), r.result().numpy(), p.result().numpy())

        # Save checkpoints
        # if epoch % 10 == 0:
        #     checkpoint.save(file_prefix=checkpoint_prefix)


def predict():
    # test_input, test_val, y_true = test_data.as_numpy_iterator().next()
    # y_hat = siamese_model.predict([test_input, test_val])
    r = Recall()
    p = Precision()
    for test_input, test_val, y_true in test_data.as_numpy_iterator():
        yhat = siamese_model.predict([test_input, test_val])
        r.update_state(y_true, yhat)
        p.update_state(y_true, yhat)
    print(r.result().numpy(), p.result().numpy())


def verify(model, verification_threshold):
    # Build results array
    results = []
    paths = os.listdir(os.path.join('application_data', 'verification_images'))
    print(paths)
    for image in paths:
        input_img = preprocess(os.path.join('application_data', 'input_image', 'input_image2.jpg'))
        validation_img = preprocess(os.path.join('application_data', 'verification_images', image))

        result = model.predict(list(np.expand_dims([input_img, validation_img], axis=1)))
        print(image, result[0][0])
        results.append(result[0][0])

    # Detection Threshold: Metric above which a prediction is considered positive
    pos_array = np.array(results)
    index = np.argmax(pos_array)
    detection = pos_array[index]
    filename = paths[index]

    # Verification Threshold: Proportion of positive predictions / total positive samples
    # verification = detection / len(os.listdir(os.path.join('application_data', 'verification_images')))
    verification = detection
    verified = verification > verification_threshold

    return results, verified, detection, filename
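
# verify() compares a single input capture (application_data/input_image/input_image2.jpg)
# against every image in application_data/verification_images, takes the highest similarity
# score as the detection, and reports verified=True once that score exceeds
# verification_threshold.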


EPOCHS = 50

if __name__ == '__main__':
    if train_mode or retrain_mode:
        train(train_data, EPOCHS)
        save_model(siamese_model)
    # predict()
    r, v, d, f = verify(siamese_model, 0.5)
    print(f"Results: {r}\nVerified: {v}\nDetection: {d}\nFilename: {f}")