%%time
import zipfile

local_zip = '/datasets/chestxray/TFRecords.zip'
with zipfile.ZipFile(local_zip, 'r') as zip_ref:
    zip_ref.extractall('/datasets/chestxray/')
CPU times: user 1.18 s, sys: 297 ms, total: 1.47 s
Wall time: 15 s
import tensorflow as tf

BATCH_SIZE = 64
IMAGE_SIZE = [512, 512]
base_path = "/datasets/chestxray/TFRecords"
TRAIN_FILENAMES = tf.io.gfile.glob(base_path + "/train/*.tfrecord")
VALID_FILENAMES = tf.io.gfile.glob(base_path + "/valid/*.tfrecord")
TEST_FILENAMES = tf.io.gfile.glob(base_path + "/test/*.tfrecord")
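The training cells below consume `train_dataset` and `valid_dataset`, which are built from these file lists. A minimal sketch of that pipeline, assuming each record stores a JPEG-encoded "image" bytes feature and an integer "label" (both key names are assumptions about the TFRecord schema):

def parse_example(serialized):
    # "image" and "label" feature keys are assumptions about the record schema.
    features = tf.io.parse_single_example(
        serialized,
        {
            "image": tf.io.FixedLenFeature([], tf.string),
            "label": tf.io.FixedLenFeature([], tf.int64),
        },
    )
    image = tf.image.decode_jpeg(features["image"], channels=3)
    image = tf.image.resize(image, IMAGE_SIZE)
    return image, features["label"]

def make_dataset(filenames, shuffle=False):
    ds = tf.data.TFRecordDataset(filenames, num_parallel_reads=tf.data.AUTOTUNE)
    ds = ds.map(parse_example, num_parallel_calls=tf.data.AUTOTUNE)
    if shuffle:
        ds = ds.shuffle(2048)
    return ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

train_dataset = make_dataset(TRAIN_FILENAMES, shuffle=True)
valid_dataset = make_dataset(VALID_FILENAMES)
test_dataset = make_dataset(TEST_FILENAMES)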
def get_model():
    # Use EfficientNetB0 as the base
    base_model = tf.keras.applications.EfficientNetB0(
        input_shape=(*IMAGE_SIZE, 3), include_top=False, weights="imagenet"
    )
    # Freeze the base so only the new classification head is trained
    base_model.trainable = False
    # Add a small classification head on top
    x = base_model.output
    x = tf.keras.layers.GlobalAveragePooling2D()(x)
    x = tf.keras.layers.Dense(8, activation="relu")(x)
    x = tf.keras.layers.Dropout(0.5)(x)
    outputs = tf.keras.layers.Dense(1, activation="sigmoid")(x)
    # Create a model with EfficientNetB0 as the base
    model = tf.keras.Model(inputs=base_model.input, outputs=outputs)
    return model
initial_learning_rate = 0.01
lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
    initial_learning_rate, decay_steps=20, decay_rate=0.96, staircase=True
)
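With staircase=True, this schedule multiplies the rate by 0.96 once every 20 optimizer steps, i.e. lr(step) = 0.01 * 0.96 ** (step // 20); at 82 batches per epoch that is roughly four drops per epoch. The schedule object is callable, so it can be inspected directly:

# Peek at the decayed learning rate at a few steps.
for step in [0, 20, 100, 500]:
    print(step, float(lr_schedule(step)))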
checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(
    "pneumonia_model.h5", save_best_only=True
)
early_stopping_cb = tf.keras.callbacks.EarlyStopping(
    patience=10, restore_best_weights=True
)
# MirroredStrategy is all that's needed to run on one or more GPUs
strategy = tf.distribute.MirroredStrategy()
with strategy.scope():
    model = get_model()
    model.compile(
        optimizer=tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9),
        loss="binary_crossentropy",
        metrics=[tf.keras.metrics.AUC(name="auc")],
    )
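# Sanity check: how many devices MirroredStrategy is replicating across.
print("Number of replicas:", strategy.num_replicas_in_sync)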
# Assumes a W&B run has already been started with wandb.init() earlier in the notebook.
import wandb
from wandb.keras import WandbCallback

history = model.fit(
    train_dataset,
    epochs=5,
    validation_data=valid_dataset,
    callbacks=[checkpoint_cb, early_stopping_cb, WandbCallback()],
    verbose=1,
)
# Save the trained model as a W&B artifact
model.save('model.h5')
artifact = wandb.Artifact(name="model", type="weights")
artifact.add_file("model.h5")
wandb.log_artifact(artifact)
Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
16711680/16705208 [==============================] - 0s 0us/step
Epoch 1/5
82/82 [==============================] - 121s 1s/step - loss: 0.5185 - auc: 0.8050 - val_loss: 1.2056 - val_auc: 0.7812
Epoch 2/5
82/82 [==============================] - 105s 1s/step - loss: 0.4009 - auc: 0.8965 - val_loss: 1.2381 - val_auc: 0.8125
Epoch 3/5
82/82 [==============================] - 117s 1s/step - loss: 0.3517 - auc: 0.9210 - val_loss: 1.1426 - val_auc: 0.8594
Epoch 4/5
82/82 [==============================] - 102s 1s/step - loss: 0.3278 - auc: 0.9304 - val_loss: 1.2295 - val_auc: 0.8750
Epoch 5/5
82/82 [==============================] - 109s 1s/step - loss: 0.3220 - auc: 0.9341 - val_loss: 1.2585 - val_auc: 0.8594
model.load_weights('pneumonia_model.h5')
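# With the best checkpoint restored, the held-out split can be scored
# before conversion (assumes the test_dataset sketch above).
model.evaluate(test_dataset, verbose=1)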
converter = tf.lite.TFLiteConverter.from_keras_model(model)
converter.optimizations = [tf.lite.Optimize.DEFAULT]
quantized_model = converter.convert()
with open('quantized_model.tflite', 'wb') as f:
    f.write(quantized_model)
# Save the quantized model as a W&B artifact
quantized_artifact = wandb.Artifact(name="quantized_model", type="weights")
quantized_artifact.add_file("quantized_model.tflite")
wandb.log_artifact(quantized_artifact)
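To confirm the quantized model still runs end to end, it can be loaded back with the TFLite interpreter. A minimal sketch using a random placeholder batch in place of a real X-ray:

import numpy as np

# Load the quantized model and run one dummy inference.
interpreter = tf.lite.Interpreter(model_path="quantized_model.tflite")
interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

sample = np.random.rand(1, *IMAGE_SIZE, 3).astype(np.float32)  # placeholder image batch
interpreter.set_tensor(input_details[0]["index"], sample)
interpreter.invoke()
print(interpreter.get_tensor(output_details[0]["index"]))  # sigmoid pneumonia probability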