0% found this document useful (0 votes)
5 views · 2 pages

Stuff

Uploaded by

phongth779
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as TXT or PDF, or read online on Scribd
0% found this document useful (0 votes)
5 views · 2 pages

Stuff

Uploaded by

phongth779
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as TXT or PDF, or read online on Scribd
You are on page 1/2

import os

import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow.keras.applications.efficientnet import EfficientNetB0, preprocess_input
from tensorflow.keras.callbacks import EarlyStopping, ModelCheckpoint
from tensorflow.keras.layers import GlobalAveragePooling2D, Dense, Dropout
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# Set parameters
default_size = 224  # images are resized to default_size x default_size
batch_size = 32
base_path = '/content/drive/My Drive/lfw-deepfunneled'
validation_split = 0.2  # fraction of each class folder held out for validation

# flow_from_directory expects base_path to contain one sub-folder per class,
# which is how most image-classification datasets are laid out.
#
# Training data: random on-the-fly augmentation to enlarge the effective
# training set and reduce overfitting.
train_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    rotation_range=20,
    width_shift_range=0.2,
    height_shift_range=0.2,
    shear_range=0.2,
    zoom_range=0.2,
    horizontal_flip=True,
    brightness_range=[0.8, 1.2],
    channel_shift_range=0.2,
    validation_split=validation_split,
)

# Validation data: the same preprocessing but NO augmentation, so validation
# metrics are measured on the real images rather than on random distortions.
# It uses the same directory and the same validation_split as train_datagen,
# so the 'training' and 'validation' subsets below select the same
# complementary sets of files.
val_datagen = ImageDataGenerator(
    preprocessing_function=preprocess_input,
    validation_split=validation_split,
)

# Train and validation generators
train_generator = train_datagen.flow_from_directory(
    directory=base_path,
    target_size=(default_size, default_size),
    batch_size=batch_size,
    class_mode='categorical',  # one-hot labels; each sub-folder is a class
    subset='training',  # the remaining 80% (1 - validation_split) of each class
)

validation_generator = val_datagen.flow_from_directory(
    directory=base_path,
    target_size=(default_size, default_size),
    batch_size=batch_size,
    class_mode='categorical',
    subset='validation',  # the held-out 20% (validation_split) of each class
    shuffle=False,  # order does not affect metrics; keeps evaluation deterministic
)

# Model creation via transfer learning: reuse EfficientNetB0 pretrained on
# ImageNet, without its ImageNet classifier (include_top=False), and attach a
# new classification head sized to the classes found by train_generator.
base_model = EfficientNetB0(
    weights='imagenet',
    include_top=False,
    input_shape=(default_size, default_size, 3),
)

# New head: global average pooling -> 1024-unit ReLU layer -> dropout -> softmax.
pooled = GlobalAveragePooling2D()(base_model.output)
hidden = Dense(1024, activation='relu')(pooled)
regularized = Dropout(rate=0.2)(hidden)  # helps prevent overfitting
predictions = Dense(train_generator.num_classes, activation='softmax')(regularized)
model = Model(inputs=base_model.input, outputs=predictions)

# Freeze every pretrained layer so that only the new head is trained.
for pretrained_layer in base_model.layers:
    pretrained_layer.trainable = False

# Categorical cross-entropy pairs with the one-hot labels produced by
# class_mode='categorical' and the softmax output above.
model.compile(
    optimizer=Adam(learning_rate=1e-4),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Early stopping, model checkpointing, training and saving.
save_dir = '/content/drive/My Drive/saved_model'
# Both the checkpoint callback and model.save below write into save_dir;
# create it up front so the first write cannot fail on a missing directory.
os.makedirs(save_dir, exist_ok=True)

# Stop when val_loss has not improved for 3 consecutive epochs. With
# restore_best_weights=True, stopping also rolls the weights back to the best
# val_loss epoch, so the model saved at the end is not the weights from the
# `patience` non-improving epochs that triggered the stop.
early_stopping = EarlyStopping(
    monitor='val_loss',
    patience=3,
    restore_best_weights=True,
)

# Keep the best-so-far model on disk.
# NOTE: this tracks val_accuracy while early stopping tracks val_loss, so the
# checkpoint file and the final saved model may come from different epochs.
model_checkpoint = ModelCheckpoint(
    filepath=os.path.join(save_dir, 'best_save.keras'),
    monitor='val_accuracy',
    save_best_only=True,
)

epochs = 10  # upper bound; early stopping may end training sooner

# Train the model
history = model.fit(
    train_generator,
    epochs=epochs,
    validation_data=validation_generator,
    callbacks=[early_stopping, model_checkpoint],
)

# Save the entire model
model.save(os.path.join(save_dir, 'my_model_011324.keras'))

# Plot per-epoch training/validation accuracy and loss recorded by model.fit.
acc = history.history['accuracy']
val_acc = history.history['val_accuracy']

loss = history.history['loss']
val_loss = history.history['val_loss']

# Use the number of epochs that actually ran, not the configured `epochs`.
# Early stopping can end training sooner, and plotting against range(epochs)
# would then fail with an x/y length mismatch.
epochs_range = range(len(acc))

plt.figure(figsize=(8, 8))
plt.subplot(1, 2, 1)
plt.plot(epochs_range, acc, label='Training Accuracy')
plt.plot(epochs_range, val_acc, label='Validation Accuracy')
plt.legend(loc='lower right')
plt.title('Training and Validation Accuracy')

plt.subplot(1, 2, 2)
plt.plot(epochs_range, loss, label='Training Loss')
plt.plot(epochs_range, val_loss, label='Validation Loss')
plt.legend(loc='upper right')
plt.title('Training and Validation Loss')
plt.show()

You might also like