0% found this document useful (0 votes)
8 views · 2 pages

Training Code (TensorFlow Lite)

Uploaded by

20701025
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as TXT or PDF, or read online on Scribd
0% found this document useful (0 votes)
8 views · 2 pages

Training Code (TensorFlow Lite)

Uploaded by

20701025
Copyright
© All Rights Reserved
We take content rights seriously. If you suspect this is your content, claim it here.
Available Formats
Download as TXT or PDF, or read online on Scribd
You are on page 1/2

import numpy as np

import tensorflow as tf
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

# Expanded dataset: 15 training samples with two numeric input features each.
# NOTE(review): the meaning of the two columns is not stated in this file —
# presumably (temperature, humidity) given the rain-prediction naming below;
# confirm before relying on it.
data = np.array([
    [30, 70], [25, 50], [20, 90], [35, 40], [28, 65],
    [22, 75], [32, 60], [26, 80], [29, 55], [31, 85],
    [24, 45], [27, 70], [33, 50], [23, 85], [34, 65],
])

# Target rain likelihood per sample, listed on a 0-100 scale and divided by
# 100 so it lies in [0, 1] like the model's sigmoid output. The expression is
# parenthesized because a line break directly after a bare "/" is a
# SyntaxError in Python.
labels = (
    np.array([80, 50, 90, 20, 70, 75, 60, 85, 55, 95, 45, 70, 50, 90, 65])
    / 100.0
)

# Fail early with a clear message if samples are added to one list but not
# the other, instead of surfacing as an opaque error inside model.fit().
if len(data) != len(labels):
    raise ValueError(
        f"data has {len(data)} samples but labels has {len(labels)}"
    )

# Build the ANN model: two 16-unit ReLU hidden layers feeding a single sigmoid
# unit, so the network maps the 2 input features to one score in (0, 1).
model = Sequential()
# First hidden layer; input_shape=(2,) declares the two-feature input.
model.add(Dense(16, activation='relu', input_shape=(2,)))
# Second hidden layer of the same width.
model.add(Dense(16, activation='relu'))
# Output layer: sigmoid squashes the result into the rain-likelihood range.
model.add(Dense(1, activation='sigmoid'))

# Compile the model: Adam optimizer, mean-squared-error loss against the
# 0-1 targets, and mean absolute error tracked as an additional metric.
model.compile(
    optimizer='adam',
    loss='mse',
    metrics=['mae'],
)

# Train the model on the full dataset for 500 epochs; verbose=1 prints
# per-epoch progress to stdout.
model.fit(
    x=data,
    y=labels,
    epochs=500,
    verbose=1,
)

# Persist the trained model to an HDF5 (.h5) file so it can be reloaded
# from disk before TensorFlow Lite conversion.
model.save(filepath="rain_prediction_model.h5")

from tensorflow.keras.losses import MeanSquaredError

# Map the 'mse' loss identifier stored in the HDF5 training config to a loss
# object, for Keras versions that cannot resolve the bare string on load.
custom_objects = {'mse': MeanSquaredError()}

# Reload the saved model for conversion. The TFLite converter only needs the
# inference graph and weights, so compile=False skips restoring the training
# configuration (optimizer, loss, metrics). That restore step is where string
# identifiers such as 'mse' — and the 'mae' metric, which the mapping above
# does not cover — can fail to deserialize on some Keras versions.
model = tf.keras.models.load_model(
    "rain_prediction_model.h5",
    custom_objects=custom_objects,
    compile=False,
)

# Convert the (reloaded) Keras model into a TensorFlow Lite FlatBuffer.
# convert() hands back the serialized model, which is written out in binary
# mode and hex-dumped byte by byte further down.
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tflite_model = converter.convert()

# Save the TensorFlow Lite model to disk, then read the file straight back so
# that `tflite_model` holds exactly the bytes that were written.
with open("rain_prediction_model.tflite", "wb") as tflite_out:
    tflite_out.write(tflite_model)

with open("rain_prediction_model.tflite", "rb") as tflite_in:
    tflite_model = tflite_in.read()

# Write the model to a C header file so the FlatBuffer can be compiled
# directly into firmware as a byte array.


def _format_c_array(data, bytes_per_line=12, indent="  "):
    """Return *data* as C hex literals, *bytes_per_line* per row.

    Rows are prefixed with *indent* and joined with ",\\n". There is no
    leading blank line and no trailing whitespace. The last literal is not
    followed by a comma.
    """
    rows = []
    for start in range(0, len(data), bytes_per_line):
        chunk = data[start:start + bytes_per_line]
        rows.append(indent + ", ".join(f"0x{byte:02x}" for byte in chunk))
    return ",\n".join(rows)


# NOTE(review): TensorFlow Lite Micro's own generated model arrays carry an
# alignment attribute (e.g. alignas(16)); confirm whether the target needs one.
# NOTE(review): the array is *defined* in this header. In C (unlike C++, where
# namespace-scope const has internal linkage), including model.h from more
# than one .c file yields duplicate definitions of `model` / `model_len`.
header = (
    "#ifndef MODEL_H_\n"
    "#define MODEL_H_\n\n"
    "const unsigned char model[] = {\n"
    f"{_format_c_array(tflite_model)}\n"
    "};\n\n"
    f"const unsigned int model_len = {len(tflite_model)};\n"
    "\n#endif // MODEL_H_\n"
)

# One buffered write instead of one call per byte.
with open("model.h", "w") as f:
    f.write(header)

You might also like