# This cell can be left without running if the checkpoint files are available

import numpy as np
import tensorflow as tf
from sklearn.metrics import f1_score
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger


class F1ScoreCallback(tf.keras.callbacks.Callback):
    """Computes the F1 score on the validation set after every epoch and
    applies early stopping (with best-weight restoration) based on it."""

    def __init__(self, validation_data, patience=5):
        super().__init__()
        self.X_val, self.y_val = validation_data
        self.patience = patience
        self.best_f1 = 0
        self.best_weights = None
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        logs = logs if logs is not None else {}
        y_pred = self.model.predict(self.X_val, verbose=0)
        y_pred_binary = np.round(y_pred).flatten()
        f1 = f1_score(self.y_val.flatten(), y_pred_binary)
        logs['val_f1'] = f1  # Expose val_f1 to downstream callbacks (ModelCheckpoint, CSVLogger)
        print(" - val_f1: %.4f" % f1, end='')

        # Early stopping logic based on F1
        if f1 > self.best_f1:
            self.best_f1 = f1
            self.best_weights = self.model.get_weights()
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True
                self.model.set_weights(self.best_weights)


for i, model in enumerate(rnn_models):
    # F1ScoreCallback must come first in the callbacks list so that 'val_f1'
    # is already in the logs when ModelCheckpoint and CSVLogger run.
    f1_callback = F1ScoreCallback(validation_data=(X_cv_rnn, y_cv_rnn), patience=10)
    model_checkpoint = ModelCheckpoint(model["checkpoint"], save_best_only=True,
                                       monitor='val_f1', mode='max')
    hist_logger = CSVLogger(model["history_file"])

    print("Training model number", i + 1, "of", len(rnn_models))
    print(model["name"])
    model["instance"].fit(
        X_train_rnn, y_train_rnn,
        batch_size=32,
        epochs=50,
        validation_data=(X_cv_rnn, y_cv_rnn),
        callbacks=[f1_callback, model_checkpoint, hist_logger],
        verbose=1
    )
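
# If the training cell above is skipped, the saved artifacts can be picked up
# from disk instead. This is only a minimal sketch, assuming the "checkpoint"
# paths in rnn_models point at full models saved by ModelCheckpoint (the Keras
# default) and that pandas is available; the "history" key is introduced here
# purely for illustration and is not used elsewhere in the notebook.
import pandas as pd
import tensorflow as tf

for model in rnn_models:
    # Restore the best model written by ModelCheckpoint (highest val_f1).
    model["instance"] = tf.keras.models.load_model(model["checkpoint"])
    # Load the per-epoch metrics written by CSVLogger.
    model["history"] = pd.read_csv(model["history_file"])
    print(model["name"], "- best val_f1:", model["history"]["val_f1"].max())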