def unet_model(input_shape=(256, 256, 3), num_classes=1, output_activation=None):
    """Build a U-Net style segmentation model.

    The network chains four encoder blocks, a two-convolution bottleneck,
    four decoder blocks (each fed one encoder output as its skip input) and
    a final 1x1 convolution that produces ``num_classes`` channels.

    ``encoder_block`` and ``decoder_block`` are defined elsewhere in this
    module; this function only wires them together.

    Args:
        input_shape: Shape of one input sample (no batch axis), forwarded to
            ``tf.keras.layers.Input``.
        num_classes: Number of filters (output channels) of the final 1x1
            convolution.
        output_activation: Activation of the final layer. ``None`` (the
            default) picks ``'sigmoid'`` when ``num_classes == 1`` and
            ``'softmax'`` otherwise, so that multi-class outputs are
            normalised across the class channels. Pass ``'sigmoid'``
            explicitly to keep independent per-channel (multi-label)
            outputs for ``num_classes > 1``.

    Returns:
        A ``tf.keras.models.Model`` named ``'U-Net'`` mapping ``inputs`` to
        the activated 1x1 convolution output.
    """
    inputs = tf.keras.layers.Input(shape=input_shape)

    # Contracting path (encoder): filter count doubles at every stage.
    s1 = encoder_block(inputs, 64)
    s2 = encoder_block(s1, 128)
    s3 = encoder_block(s2, 256)
    s4 = encoder_block(s3, 512)

    # Bottleneck: two unpadded 3x3 convolutions, each followed by ReLU.
    # With padding='valid' every convolution trims 2 pixels per spatial axis.
    # NOTE(review): decoder_block presumably crops or resizes the skip
    # tensors to compensate -- confirm against its implementation.
    b1 = tf.keras.layers.Conv2D(1024, 3, padding='valid')(s4)
    b1 = tf.keras.layers.Activation('relu')(b1)
    b1 = tf.keras.layers.Conv2D(1024, 3, padding='valid')(b1)
    b1 = tf.keras.layers.Activation('relu')(b1)

    # Expansive path (decoder): skip inputs are consumed deepest-first.
    d1 = decoder_block(b1, s4, 512)
    d2 = decoder_block(d1, s3, 256)
    d3 = decoder_block(d2, s2, 128)
    d4 = decoder_block(d3, s1, 64)

    # A sigmoid over several class channels does not yield a per-pixel
    # distribution; softmax (applied over the last axis, i.e. the class
    # channel for channels-last data) is the appropriate choice when the
    # classes are mutually exclusive.
    if output_activation is None:
        output_activation = 'sigmoid' if num_classes == 1 else 'softmax'

    outputs = tf.keras.layers.Conv2D(
        num_classes, 1, padding='valid', activation=output_activation
    )(d4)
    model = tf.keras.models.Model(inputs=inputs, outputs=outputs, name='U-Net')
    return model
if __name__ == '__main__':
    # Script entry point: build a two-class model for 572x572 three-channel
    # inputs and print its layer-by-layer summary.
    demo_input_shape = (572, 572, 3)
    model = unet_model(num_classes=2, input_shape=demo_input_shape)
    model.summary()