# Tuning pass over the candidate architectures in `arch_models`.
# Each entry supplies a display "name", constructor "params" for
# create_cnn_model, a "checkpoint" path and a "history_file" path.
# Skipping this cell is fine when the checkpoint files already exist on disk.
for model in arch_models:
    print(f"\nTraining {model['name']}...")
    model_instance = create_cnn_model(**model["params"])

    # Fresh callbacks for every candidate. They all watch 'val_auc' and treat
    # larger values as better.
    # NOTE(review): assumes the compiled model reports a metric logged as
    # 'val_auc' (i.e. a metric named "auc") - confirm in create_cnn_model.
    early_stopping = EarlyStopping(
        monitor='val_auc',
        patience=5,
        restore_best_weights=True,
        mode='max',
    )
    reduce_lr = ReduceLROnPlateau(
        monitor='val_auc',
        factor=.5,
        patience=3,
        min_lr=1e-7,
        mode='max',
    )
    model_checkpoint = ModelCheckpoint(
        model["checkpoint"],
        save_best_only=True,
        monitor='val_auc',
        mode='max',
    )
    hist_logger = CSVLogger(model["history_file"])

    # Rewind both iterators so every candidate starts from the same position.
    train_generator.reset()
    val_generator.reset()

    # Short run (15 epochs) because this is only a tuning comparison.
    model_instance.fit(
        train_generator,
        epochs=15,
        validation_data=val_generator,
        callbacks=[early_stopping, reduce_lr, model_checkpoint, hist_logger],
        verbose=1,
    )
# Tuning pass over the regularization variants in `reg_models`.
# Each entry supplies a display "name", constructor "params" for
# create_cnn_model, a "checkpoint" path and a "history_file" path.
# Skipping this cell is fine when the checkpoint files already exist on disk.
# The loop variables (including the callback objects) remain bound at module
# scope after the loop finishes, so their names are kept stable.
for model in reg_models:
    print(f"\nTraining {model['name']}...")
    model_instance = create_cnn_model(**model["params"])

    # Fresh callbacks for every variant. They all watch 'val_auc' and treat
    # larger values as better.
    # NOTE(review): assumes the compiled model reports a metric logged as
    # 'val_auc' (i.e. a metric named "auc") - confirm in create_cnn_model.
    early_stopping = EarlyStopping(
        monitor='val_auc',
        patience=5,
        restore_best_weights=True,
        mode='max',
    )
    reduce_lr = ReduceLROnPlateau(
        monitor='val_auc',
        factor=.5,
        patience=3,
        min_lr=1e-7,
        mode='max',
    )
    model_checkpoint = ModelCheckpoint(
        model["checkpoint"],
        save_best_only=True,
        monitor='val_auc',
        mode='max',
    )
    hist_logger = CSVLogger(model["history_file"])

    # Rewind both iterators so every variant starts from the same position.
    train_generator.reset()
    val_generator.reset()

    model_instance.fit(
        train_generator,
        epochs=15,
        validation_data=val_generator,
        callbacks=[early_stopping, reduce_lr, model_checkpoint, hist_logger],
        verbose=1,
    )
# Train the final model using the best regularization settings.
# This cell can be left without running if the checkpoint files are available.
# NOTE(review): `best_reg_model`, `EPOCHS`, `full_train_generator`,
# `full_val_generator`, `final_checkpoint` and `final_csv_logger` are not
# defined here - presumably set up in an earlier cell; confirm before running.

# Create the best regularized model
final_model = create_cnn_model(**best_reg_model["params"])

# Build this run's own early-stopping / LR-schedule callbacks. Previously this
# cell reused `early_stopping` and `reduce_lr` left over from the last
# iteration of the tuning loop. That raised NameError whenever the (optional)
# tuning cells were skipped, and it silently coupled the final run to another
# cell's objects. The settings are the same ones the tuning loops use.
final_early_stopping = EarlyStopping(
    monitor='val_auc',
    patience=5,
    restore_best_weights=True,
    mode='max',
)
final_reduce_lr = ReduceLROnPlateau(
    monitor='val_auc',
    factor=.5,
    patience=3,
    min_lr=1e-7,
    mode='max',
)

# Train the final model
final_model.fit(
    full_train_generator,
    epochs=EPOCHS,
    validation_data=full_val_generator,
    callbacks=[final_early_stopping, final_reduce_lr, final_checkpoint, final_csv_logger],
    verbose=1,
)