kevin 2 years ago
parent
commit
f96ef83ec7
1 changed file with 116 additions and 32 deletions

+ 116 - 32
main.py

@@ -12,6 +12,15 @@ from tensorflow.keras.layers import Layer, Conv2D, Dense, MaxPooling2D, Input, F
 from tensorflow.keras.metrics import Precision, Recall
 import tensorflow as tf
 
+train_mode = False
+retrain_mode = False
+MODEL_NAME = "siamesemodelv4.h5"
+
+# Avoid OOM errors by setting GPU Memory Consumption Growth
+gpus = tf.config.experimental.list_physical_devices('GPU')
+for gpu in gpus:
+    tf.config.experimental.set_memory_growth(gpu, True)
+
 # Setup paths
 POS_PATH = os.path.join('data', 'positive')
 NEG_PATH = os.path.join('data', 'negative')
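A note on the new memory-growth loop: set_memory_growth raises a RuntimeError if any GPU has already been initialized, so a guarded variant is safer for interactive re-runs. A minimal sketch, mirroring the pattern in the TensorFlow GPU guide:

import tensorflow as tf

gpus = tf.config.experimental.list_physical_devices('GPU')
try:
    for gpu in gpus:
        # Must run before the first GPU op touches the device.
        tf.config.experimental.set_memory_growth(gpu, True)
except RuntimeError as e:
    print(e)  # memory growth can only be set at program startup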
@@ -34,6 +43,7 @@ def data_aug(img):
 
     return data
 
+
 def preprocess(file_path):
     # Read in image from file path
     byte_img = tf.io.read_file(file_path)
@@ -48,27 +58,30 @@ def preprocess(file_path):
     # Return image
     return img
 
+
 def preprocess_twin(input_img, validation_img, label):
-    return(preprocess(input_img), preprocess(validation_img), label)
+    return (preprocess(input_img), preprocess(validation_img), label)
 
-# img_path = os.path.join(ANC_PATH, '924e839c-135f-11ec-b54e-a0cec8d2d278.jpg')
-# img = cv2.imread(img_path)
-# augmented_images = data_aug(img)
-#
-# for image in augmented_images:
-#     cv2.imwrite(os.path.join(ANC_PATH, '{}.jpg'.format(uuid.uuid1())), image.numpy())
+
+# for file_name in os.listdir(os.path.join(ANC_PATH)):
+#     img_path = os.path.join(ANC_PATH, file_name)
+#     img = cv2.imread(img_path)
+#     augmented_images = data_aug(img)
 #
+#     for image in augmented_images:
+#         cv2.imwrite(os.path.join(ANC_PATH, 'surya-{}.jpg'.format(uuid.uuid1())), image.numpy())
+# #
 # for file_name in os.listdir(os.path.join(POS_PATH)):
 #     img_path = os.path.join(POS_PATH, file_name)
 #     img = cv2.imread(img_path)
 #     augmented_images = data_aug(img)
 #
 #     for image in augmented_images:
-#         cv2.imwrite(os.path.join(POS_PATH, '{}.jpg'.format(uuid.uuid1())), image.numpy())
+#         cv2.imwrite(os.path.join(POS_PATH, 'surya-{}.jpg'.format(uuid.uuid1())), image.numpy())
 
-anchor = tf.data.Dataset.list_files(ANC_PATH+'/*.jpg').take(3000)
-positive = tf.data.Dataset.list_files(POS_PATH+'/*.jpg').take(3000)
-negative = tf.data.Dataset.list_files(NEG_PATH+'/*.jpg').take(3000)
+anchor = tf.data.Dataset.list_files(ANC_PATH + '/*.jpg').take(2000)
+positive = tf.data.Dataset.list_files(POS_PATH + '/*.jpg').take(2000)
+negative = tf.data.Dataset.list_files(NEG_PATH + '/*.jpg').take(2000)
 
 dir_test = anchor.as_numpy_iterator()
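The labelled pairing that turns these three datasets into the `data` seen in the next hunk sits in unchanged lines the diff elides. For orientation only, a sketch of the usual pattern (an assumption, not part of this commit): anchor/positive pairs are labelled 1, anchor/negative pairs 0.

positives = tf.data.Dataset.zip((anchor, positive,
                                 tf.data.Dataset.from_tensor_slices(tf.ones(len(anchor)))))
negatives = tf.data.Dataset.zip((anchor, negative,
                                 tf.data.Dataset.from_tensor_slices(tf.zeros(len(anchor)))))
data = positives.concatenate(negatives)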
 
@@ -86,30 +99,48 @@ data = data.map(preprocess_twin)
 data = data.cache()
 data = data.shuffle(buffer_size=10000)
 
-train_data = data.take(round(len(data)*.7))
+train_data = data.take(round(len(data) * .7))
 train_data = train_data.batch(16)
 train_data = train_data.prefetch(8)
 
-test_data = data.skip(round(len(data)*.7))
-test_data = test_data.take(round(len(data)*.3))
+test_data = data.skip(round(len(data) * .7))
+test_data = test_data.take(round(len(data) * .3))
 test_data = test_data.batch(16)
 test_data = test_data.prefetch(8)
 
+
+def save_model(model):
+    model.save(MODEL_NAME)
+
+
+def load_model():
+    model = None
+    if retrain_mode:
+        return model
+    try:
+        model = tf.keras.models.load_model(
+            MODEL_NAME, custom_objects={'L1Dist': L1Dist, 'BinaryCrossentropy': tf.losses.BinaryCrossentropy})
+    except (ImportError, IOError):
+        return None
+    return model
+
+
 # Build embedding layer
 
 def make_embedding():
-    inp = Input(shape=(100,100,3), name='input_image')
-    c1 = Conv2D(64, (10,10), activation='relu')(inp)
-    m1 = MaxPooling2D(64, (2,2), padding='same')(c1)
-    c2 = Conv2D(128, (7,7), activation='relu')(m1)
-    m2 = MaxPooling2D(64, (2,2), padding='same')(c2)
-    c3 = Conv2D(128, (4,4), activation='relu')(m2)
-    m3 = MaxPooling2D(64, (2,2), padding='same')(c3)
-    c4 = Conv2D(256, (4,4), activation='relu')(m3)
+    inp = Input(shape=(100, 100, 3), name='input_image')
+    c1 = Conv2D(64, (10, 10), activation='relu')(inp)
+    m1 = MaxPooling2D(64, (2, 2), padding='same')(c1)
+    c2 = Conv2D(128, (7, 7), activation='relu')(m1)
+    m2 = MaxPooling2D(64, (2, 2), padding='same')(c2)
+    c3 = Conv2D(128, (4, 4), activation='relu')(m2)
+    m3 = MaxPooling2D(64, (2, 2), padding='same')(c3)
+    c4 = Conv2D(256, (4, 4), activation='relu')(m3)
     f1 = Flatten()(c4)
     d1 = Dense(4096, activation='sigmoid')(f1)
     return Model(inputs=[inp], outputs=[d1], name='embedding')
 
+
 embedding = make_embedding()
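A quick sanity check on the sizes implied by the take(2000) change and the 70/30 split above, assuming the usual pairing doubles the per-class count so len(data) == 4000:

# round(4000 * .7) = 2800 train pairs -> 175 batches of 16
# round(4000 * .3) = 1200 test pairs  ->  75 batches of 16
print(len(train_data), len(test_data))  # expected: 175 75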
 
 
@@ -124,9 +155,10 @@ class L1Dist(Layer):
     def call(self, input_embedding, validation_embedding):
         return tf.math.abs(input_embedding - validation_embedding)
 
+
 def make_siamese():
-    input_image = Input(name='input_img', shape=(100,100,3))
-    validation_image = Input(name='validation_img', shape=(100,100,3))
+    input_image = Input(name='input_img', shape=(100, 100, 3))
+    validation_image = Input(name='validation_img', shape=(100, 100, 3))
     inp_embedding = embedding(input_image)
     val_embedding = embedding(validation_image)
     siamese_layer = L1Dist()
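Since make_siamese is split across hunks here, a toy check of what L1Dist computes, an element-wise |a - b| over the two 4096-d embeddings:

a = tf.constant([[1.0, 2.0]])
b = tf.constant([[0.5, 3.0]])
print(L1Dist()(a, b).numpy())  # [[0.5 1. ]]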
@@ -134,16 +166,21 @@ def make_siamese():
     classifier = Dense(1, activation='sigmoid')(distances)
     return Model(inputs=[input_image, validation_image], outputs=classifier, name='SiameseNetwork')
 
-# Training
 
-siamese_model = make_siamese()
+# Training
+siamese_model = load_model()
+if siamese_model is None:
+    siamese_model = make_siamese()
+    retrain_mode = True
 
 binary_cross_loss = tf.losses.BinaryCrossentropy()
 opt = tf.keras.optimizers.Adam(1e-4)
 
-checkpoint_dir = './training_checkpoints'
-checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
-checkpoint = tf.train.Checkpoint(opt=opt, siamese_model=siamese_model)
+
+# checkpoint_dir = './training_checkpoints'
+# checkpoint_prefix = os.path.join(checkpoint_dir, 'ckpt')
+# checkpoint = tf.train.Checkpoint(opt=opt, siamese_model=siamese_model)
+
 
 @tf.function
 def train_step(batch):
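The body of train_step is untouched by this commit and elided by the diff. For orientation, a sketch of what such a step typically contains in this pipeline, an assumption based on the loss and optimizer defined above:

@tf.function
def train_step(batch):
    with tf.GradientTape() as tape:
        X = batch[:2]   # anchor and comparison images
        y = batch[2]    # 1 = same person, 0 = different
        yhat = siamese_model(X, training=True)
        loss = binary_cross_loss(y, yhat)
    grad = tape.gradient(loss, siamese_model.trainable_variables)
    opt.apply_gradients(zip(grad, siamese_model.trainable_variables))
    return loss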
@@ -191,10 +228,57 @@ def train(data, EPOCHS):
         print(loss.numpy(), r.result().numpy(), p.result().numpy())
 
         # Save checkpoints
-        if epoch % 10 == 0:
-            checkpoint.save(file_prefix=checkpoint_prefix)
+        # if epoch % 10 == 0:
+        #     checkpoint.save(file_prefix=checkpoint_prefix)
+
+
+def predict():
+    # test_input, test_val, y_true = test_data.as_numpy_iterator().next()
+    # y_hat = siamese_model.predict([test_input, test_val])
+
+    r = Recall()
+    p = Precision()
+
+    for test_input, test_val, y_true in test_data.as_numpy_iterator():
+        yhat = siamese_model.predict([test_input, test_val])
+        r.update_state(y_true, yhat)
+        p.update_state(y_true, yhat)
+
+    print(r.result().numpy(), p.result().numpy())
+
+
+def verify(model, verification_threshold):
+    # Build results array
+    results = []
+    paths = os.listdir(os.path.join('application_data', 'verification_images'))
+    print(paths)
+    for image in paths:
+        input_img = preprocess(os.path.join('application_data', 'input_image', 'input_image2.jpg'))
+        validation_img = preprocess(os.path.join('application_data', 'verification_images', image))
+        result = model.predict(list(np.expand_dims([input_img, validation_img], axis=1)))
+        print(image, result[0][0])
+        results.append(result[0][0])
+
+    # Detection Threshold: Metric above which a prediction is considered positive
+    pos_array = np.array(results)
+    index = np.argmax(pos_array)
+    detection = pos_array[index]
+    filename = paths[index]
+
+    # Verification Threshold: Proportion of positive predictions / total positive samples
+    # verification = detection / len(os.listdir(os.path.join('application_data', 'verification_images')))
+    verification = detection
+    verified = verification > verification_threshold
+
+    return results, verified, detection, filename
+
 
 EPOCHS = 50
 
 if __name__ == '__main__':
-    train(train_data, EPOCHS)
+    if train_mode or retrain_mode:
+        train(train_data, EPOCHS)
+        save_model(siamese_model)
+    # predict()
+    r, v, d, f = verify(siamese_model, 0.5)
+    print(f"Results: {r}\nVerified: {v}\nDetection: {d}\nFilename: {f}")