From f8b1a98a8b5b9c1d2c16a9dc88916e85e316cf97 Mon Sep 17 00:00:00 2001
From: A Farzat
Date: Mon, 3 Nov 2025 19:42:41 +0300
Subject: Add a blog post about the disaster tweets project

---
 content/blog/csca5642-w4/notebook.html | 9537 ++++++++++++++++++++++++++++++++
 1 file changed, 9537 insertions(+)
 create mode 100644 content/blog/csca5642-w4/notebook.html

(limited to 'content/blog/csca5642-w4/notebook.html')

diff --git a/content/blog/csca5642-w4/notebook.html b/content/blog/csca5642-w4/notebook.html
new file mode 100644
index 0000000..9c6a455
--- /dev/null
+++ b/content/blog/csca5642-w4/notebook.html
@@ -0,0 +1,9537 @@

[Exported-notebook HTML preamble (title "cours3w4submission", embedded styles and scripts) omitted; the relevant code cell follows.]
# This cell can be left without running if the checkpoint files are available
# (np, tf, f1_score, ModelCheckpoint and CSVLogger are assumed to be imported
# in earlier notebook cells; they are repeated here so the cell is self-contained)
import numpy as np
import tensorflow as tf
from sklearn.metrics import f1_score
from tensorflow.keras.callbacks import ModelCheckpoint, CSVLogger


class F1ScoreCallback(tf.keras.callbacks.Callback):
    """Compute the F1 score on the validation set after every epoch and apply
    early stopping (with best-weight restoration) based on it."""

    def __init__(self, validation_data, patience=5):
        super().__init__()
        self.X_val, self.y_val = validation_data
        self.patience = patience
        self.best_f1 = 0
        self.best_weights = None
        self.wait = 0

    def on_epoch_end(self, epoch, logs=None):
        y_pred = self.model.predict(self.X_val, verbose=0)
        y_pred_binary = np.round(y_pred).flatten()
        f1 = f1_score(self.y_val.flatten(), y_pred_binary)

        logs['val_f1'] = f1  # expose val_f1 to later callbacks (ModelCheckpoint, CSVLogger)

        print(" - val_f1: %.4f" % f1, end='')

        # Early stopping logic based on F1
        if f1 > self.best_f1:
            self.best_f1 = f1
            self.best_weights = self.model.get_weights()
            self.wait = 0
        else:
            self.wait += 1
            if self.wait >= self.patience:
                self.model.stop_training = True
                self.model.set_weights(self.best_weights)


for i, model in enumerate(rnn_models):
    # F1ScoreCallback must come first in the callbacks list so that 'val_f1'
    # is already in the logs when ModelCheckpoint and CSVLogger run.
    f1_callback = F1ScoreCallback(validation_data=(X_cv_rnn, y_cv_rnn), patience=10)
    model_checkpoint = ModelCheckpoint(model["checkpoint"], save_best_only=True,
                                       monitor='val_f1', mode='max')
    hist_logger = CSVLogger(model["history_file"])
    print("Training model number", i + 1, "of", len(rnn_models))
    print(model["name"])
    model["instance"].fit(
        X_train_rnn, y_train_rnn,
        batch_size=32, epochs=50,
        validation_data=(X_cv_rnn, y_cv_rnn),
        callbacks=[f1_callback, model_checkpoint, hist_logger],
        verbose=1
    )
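Because ModelCheckpoint monitors the custom val_f1 entry, the callback order matters: F1ScoreCallback runs first and injects the metric into the logs dictionary, and CSVLogger then records it alongside the built-in metrics. As a rough illustration (not part of the original notebook), the sketch below shows how the saved checkpoints and history files could be inspected after training; it reuses the rnn_models list and the X_cv_rnn / y_cv_rnn split from earlier cells and assumes each "checkpoint" path stores a full saved model (ModelCheckpoint's default), so it can be reloaded with tf.keras.models.load_model.

# Rough post-training inspection sketch (not from the original notebook);
# assumes rnn_models, X_cv_rnn and y_cv_rnn are defined as in the cell above
# and that each checkpoint holds a full model rather than weights only.
import numpy as np
import pandas as pd
import tensorflow as tf
from sklearn.metrics import f1_score

for model in rnn_models:
    # Reload the model state that achieved the best val_f1 during training
    best = tf.keras.models.load_model(model["checkpoint"])
    y_pred = np.round(best.predict(X_cv_rnn, verbose=0)).flatten()
    print("%s: val F1 = %.4f" % (model["name"], f1_score(y_cv_rnn.flatten(), y_pred)))

    # CSVLogger wrote one row per epoch, including the injected val_f1 column
    history = pd.read_csv(model["history_file"])
    print(history[["epoch", "loss", "val_loss", "val_f1"]].tail())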
-- cgit v1.2.3-70-g09d2