How to Create Custom Model For Android Using TensorFlow?
Last Updated :
05 Oct, 2021
TensorFlow is an open-source library for machine learning. On Android, computing power and other resources are limited, so we use TensorFlow Lite, which is specifically designed to run on devices with limited resources. In this post, we build a classification model for the iris dataset. The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
Attribute information:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
Given these four measurements as input, the model predicts whether the plant is Iris Setosa, Iris Versicolour, or Iris Virginica. You can refer to this link for more information.
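In iris.data, each line holds the four measurements followed by the class name, for example 5.1,3.5,1.4,0.2,Iris-setosa. The notebook in Step 2 turns the class names into integers with LabelEncoder, which assigns codes in sorted (alphabetical) order, and then into one-hot vectors with to_categorical. The minimal sketch below reproduces that mapping in plain Python so you can see which model output index corresponds to which species. The helper names (parse_row, encode_label, one_hot, predict_label) are hypothetical and not part of any library.

```python
from __future__ import annotations

# Class names as they appear in iris.data. LabelEncoder sorts the unique
# labels, so the integer code of each class is its position in this list.
CLASSES = sorted(["Iris-virginica", "Iris-setosa", "Iris-versicolor"])


def parse_row(line: str) -> tuple[list[float], str]:
    """Split one line of iris.data into four features and the class name."""
    *features, label = line.strip().split(",")
    return [float(v) for v in features], label


def encode_label(name: str) -> int:
    """Integer code for a class name (same order LabelEncoder uses)."""
    return CLASSES.index(name)


def one_hot(index: int, num_classes: int = len(CLASSES)) -> list[float]:
    """One-hot row, like a single row produced by to_categorical."""
    row = [0.0] * num_classes
    row[index] = 1.0
    return row


def predict_label(probabilities: list[float]) -> str:
    """Pick the class with the highest softmax probability."""
    best = max(range(len(probabilities)), key=lambda i: probabilities[i])
    return CLASSES[best]


features, label = parse_row("5.1,3.5,1.4,0.2,Iris-setosa")
print(features, label, encode_label(label), one_hot(encode_label(label)))
# [5.1, 3.5, 1.4, 0.2] Iris-setosa 0 [1.0, 0.0, 0.0]
print(predict_label([0.05, 0.15, 0.80]))  # Iris-virginica
```

Because the codes follow alphabetical order, output index 0 is Iris-setosa, 1 is Iris-versicolor and 2 is Iris-virginica, which is the order the Android app uses when it displays the three probabilities.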
Step by step Implementation
Step 1:
Download the iris dataset (file name: iris.data) from this link: https://archive.ics.uci.edu/ml/machine-learning-databases/iris/
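iris.data is a plain comma-separated file with no header row: every non-empty line is one sample, and the file may end with a blank line. A quick way to check the download before training is to parse it with Python's standard csv module, as in the minimal sketch below. The function name load_iris_data is hypothetical. For the real file you would pass its text (for example open('iris.data').read()) and expect 150 rows, 50 per class.

```python
import csv
import io
from collections import Counter


def load_iris_data(text: str):
    """Parse iris.data text into feature rows and class labels.

    The file has no header row, so every non-blank line is a sample.
    Blank lines are skipped.
    """
    features, labels = [], []
    for row in csv.reader(io.StringIO(text)):
        if not row:  # csv.reader yields [] for an empty line
            continue
        features.append([float(v) for v in row[:4]])  # four measurements in cm
        labels.append(row[4])                          # class name
    return features, labels


# A few lines in the same format as iris.data, plus a trailing blank line.
sample = (
    "5.1,3.5,1.4,0.2,Iris-setosa\n"
    "7.0,3.2,4.7,1.4,Iris-versicolor\n"
    "6.3,3.3,6.0,2.5,Iris-virginica\n"
    "\n"
)
X, y = load_iris_data(sample)
print(len(X), Counter(y))
```

The same point applies to pandas: pd.read_csv treats the first line as column names unless you pass header=None, which would silently drop one sample from iris.data.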
Step 2:
Create a new Jupyter notebook named iris (iris.ipynb) and put the iris.data file in the same directory as iris.ipynb. Copy the following code into the notebook.
iris.ipynb
Python
import pandas as pd
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential
# reading the csv into a data frame
# iris.data has no header row, so header=None keeps the first sample as data
df = pd.read_csv('iris.data', header=None)
# selecting the columns for the x and y variables
# iloc is position based: columns 0-3 are the four measurements, column 4 is the class
X = df.iloc[:, :4].values
y = df.iloc[:, 4].values
# encoding the class names as integers (LabelEncoder sorts them alphabetically:
# 0 = Iris-setosa, 1 = Iris-versicolor, 2 = Iris-virginica)
le = LabelEncoder()
y = le.fit_transform(y)
# converting the integer labels to one-hot vectors for categorical_crossentropy
y = to_categorical(y)
model = Sequential()
# input layer: 64 neurons with relu activation, each sample has 4 features
model.add(Dense(64, activation='relu', input_shape=[4]))
# hidden layer: another dense layer of size 64 (no activation given, so it is linear)
model.add(Dense(64))
# output layer: 3 neurons, softmax turns them into class probabilities
model.add(Dense(3, activation='softmax'))
# compiling the model
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['acc'])
# training the model for a fixed number of iterations (epochs)
model.fit(X, y, epochs=200)
# converting the trained Keras model to TensorFlow Lite format
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tfmodel = converter.convert()
# writing the model to iris.tflite (the with block closes the file)
with open('iris.tflite', 'wb') as f:
    f.write(tfmodel)
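The Keras model above maps one row of four measurements to three class probabilities. The NumPy sketch below spells out the same forward pass, Dense(64, relu), then Dense(64) with no activation (so it is linear), then Dense(3, softmax), together with the categorical cross-entropy loss passed to model.compile. The weights are random placeholders with the same shapes as the Keras layers, not trained values, and the function names are hypothetical; the point is only to show the shapes and what each layer computes.

```python
import numpy as np

rng = np.random.default_rng(0)

# Random placeholder weights with the same shapes as the Keras layers.
W1, b1 = rng.normal(size=(4, 64)), np.zeros(64)   # Dense(64, activation='relu')
W2, b2 = rng.normal(size=(64, 64)), np.zeros(64)  # Dense(64), linear
W3, b3 = rng.normal(size=(64, 3)), np.zeros(3)    # Dense(3, activation='softmax')


def relu(z):
    return np.maximum(z, 0.0)


def softmax(z):
    # Subtracting the row maximum avoids overflow and leaves the result unchanged.
    e = np.exp(z - z.max(axis=-1, keepdims=True))
    return e / e.sum(axis=-1, keepdims=True)


def forward(x):
    """x has shape (batch, 4); the result has shape (batch, 3)."""
    h1 = relu(x @ W1 + b1)
    h2 = h1 @ W2 + b2  # no activation on the second hidden layer
    return softmax(h2 @ W3 + b3)


def categorical_crossentropy(y_true, y_pred, eps=1e-7):
    """Mean loss over the batch for one-hot targets; clipping avoids log(0)."""
    return float(-np.mean(np.sum(y_true * np.log(np.clip(y_pred, eps, 1.0)), axis=-1)))


x = np.array([[5.1, 3.5, 1.4, 0.2]])  # one sample, four features
probs = forward(x)
print(probs.shape, probs.sum())
print(categorical_crossentropy(np.array([[1.0, 0.0, 0.0]]), probs))
```

The (1, 4) input and (1, 3) output shapes are the same ones the Android code uses when it creates a TensorBuffer of shape [1, 4] and reads three output values.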
Step 3:
After the notebook finishes running, a new file named iris.tflite is created in the same directory as iris.data.
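Before copying iris.tflite into Android Studio, you can do a cheap sanity check that the file was written completely. TensorFlow Lite models are FlatBuffers whose schema declares the file identifier "TFL3", and FlatBuffers stores that identifier in bytes 4 to 7 of the file. The sketch below checks only that identifier and does not validate the model itself; the function names are hypothetical.

```python
from pathlib import Path

TFLITE_IDENTIFIER = b"TFL3"  # file_identifier declared in the TensorFlow Lite schema


def looks_like_tflite(data: bytes) -> bool:
    """True if the FlatBuffer file identifier (bytes 4-7) is 'TFL3'."""
    return len(data) >= 8 and data[4:8] == TFLITE_IDENTIFIER


def check_model_file(path: str) -> bool:
    """Read a model file from disk and check its identifier."""
    return looks_like_tflite(Path(path).read_bytes())


# Synthetic headers, for illustration only:
print(looks_like_tflite(b"\x1c\x00\x00\x00TFL3" + b"\x00" * 16))  # True
print(looks_like_tflite(b"not a model"))                          # False
```

In the notebook, check_model_file('iris.tflite') should return True once the conversion cell has run.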
A) Open Android Studio and create a new Android project with Kotlin as the language. (You can refer here for creating a project.)
B) Right-click on app > New > Other > TensorFlow Lite Model.
C) Click on the folder icon.
D) Navigate to the iris.tflite file and select it.
E) Click on OK.
F) Click on Finish. Android Studio imports the model and opens a page describing it (this may take some time to load).
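Step 4: Use the generated model code
The page for the imported model includes sample Kotlin code for running it. Copy that code and paste it into the click listener of a button in MainActivity.kt, as shown in Step 6.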
Step 5: Create XML layout for prediction
Navigate to app > res > layout > activity_main.xml and replace its contents with the code below.
XML
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">
    <ScrollView
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:layout_marginBottom="50dp">
        <!-- the single child of a ScrollView should wrap its content vertically -->
        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:orientation="vertical">
            <!-- creating EditTexts for the four measurements (in cm) -->
            <EditText
                android:id="@+id/tf1"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="70dp"
                android:ems="10"
                android:hint="Sepal length (cm)"
                android:inputType="numberDecimal" />
            <EditText
                android:id="@+id/tf2"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:hint="Sepal width (cm)"
                android:inputType="numberDecimal" />
            <EditText
                android:id="@+id/tf3"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:hint="Petal length (cm)"
                android:inputType="numberDecimal" />
            <EditText
                android:id="@+id/tf4"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:hint="Petal width (cm)"
                android:inputType="numberDecimal" />
            <!-- creating the Button; clicking it runs the prediction -->
            <Button
                android:id="@+id/button"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="100dp"
                android:text="Predict"
                app:layout_constraintBottom_toTopOf="@+id/textView"
                app:layout_constraintEnd_toEndOf="parent"
                app:layout_constraintHorizontal_bias="0.0"
                app:layout_constraintStart_toStartOf="parent" />
            <!-- creating the TextView that shows the prediction -->
            <TextView
                android:id="@+id/textView"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="50dp"
                android:text="TextView"
                android:textSize="20sp"
                app:layout_constraintEnd_toEndOf="parent" />
        </LinearLayout>
    </ScrollView>
</androidx.constraintlayout.widget.ConstraintLayout>
Step 6: Working with the MainActivity.kt file
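Go to the MainActivity.kt file and replace its contents with the following code. Comments are added inside the code to explain it in more detail. The import com.example.gfgtfdemo.ml.Iris refers to the class Android Studio generated from iris.tflite in Step 3; if your app uses a different package name, change com.example.gfgtfdemo to match it.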
Kotlin
package com.example.gfgtfdemo // replace with your app's package name
import android.os.Bundle
import android.widget.Button
import android.widget.EditText
import android.widget.TextView
import androidx.appcompat.app.AppCompatActivity
import com.example.gfgtfdemo.ml.Iris
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.ByteBuffer
import java.nio.ByteOrder
class MainActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)
        // getting the EditText objects for the four measurements
        val ed1: EditText = findViewById(R.id.tf1)
        val ed2: EditText = findViewById(R.id.tf2)
        val ed3: EditText = findViewById(R.id.tf3)
        val ed4: EditText = findViewById(R.id.tf4)
        // getting the TextView that shows the result
        val txtView: TextView = findViewById(R.id.textView)
        val b: Button = findViewById(R.id.button)
        // registering the click listener
        b.setOnClickListener {
            // reading the EditText values; toFloatOrNull() returns null for empty or invalid input
            val v1 = ed1.text.toString().toFloatOrNull()
            val v2 = ed2.text.toString().toFloatOrNull()
            val v3 = ed3.text.toString().toFloatOrNull()
            val v4 = ed4.text.toString().toFloatOrNull()
            if (v1 == null || v2 == null || v3 == null || v4 == null) {
                txtView.text = "Please enter all four measurements"
                return@setOnClickListener
            }
            /*************************ML MODEL CODE STARTS HERE******************/
            val model = Iris.newInstance(this)
            try {
                // byte buffer for 4 floats of 4 bytes each; TensorFlow Lite reads
                // tensor bytes in the device's native byte order
                val byteBuffer: ByteBuffer = ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder())
                byteBuffer.putFloat(v1)
                byteBuffer.putFloat(v2)
                byteBuffer.putFloat(v3)
                byteBuffer.putFloat(v4)
                // Creates inputs for reference (shape [1, 4]: one sample, four features).
                val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 4), DataType.FLOAT32)
                inputFeature0.loadBuffer(byteBuffer)
                // Runs model inference and gets the three softmax probabilities.
                val outputs = model.process(inputFeature0)
                val outputFeature0 = outputs.outputFeature0AsTensorBuffer.floatArray
                // setting the result to the output TextView
                // (index order matches LabelEncoder's alphabetical class order)
                txtView.text = "Iris-setosa: ${outputFeature0[0]}\n" +
                    "Iris-versicolor: ${outputFeature0[1]}\n" +
                    "Iris-virginica: ${outputFeature0[2]}"
            } finally {
                // Releases model resources even if inference throws.
                model.close()
            }
        }
    }
}
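A common cause of nonsensical predictions in code like this is byte order. In Java and Kotlin, a ByteBuffer created with ByteBuffer.allocateDirect uses big-endian order unless you call order(ByteOrder.nativeOrder()), while TensorFlow Lite reads the raw tensor bytes in the device's native order, which is little-endian on typical Android hardware. The Python sketch below uses the struct module to show what happens when four floats written big-endian are read back as little-endian; it only illustrates the effect and is not Android code.

```python
import struct

features = (5.1, 3.5, 1.4, 0.2)  # one iris sample

# ">" = big-endian (a ByteBuffer's default), "<" = little-endian (typical device order);
# "4f" = four 32-bit floats, matching the model's FLOAT32 input of shape [1, 4].
big_endian = struct.pack(">4f", *features)
little_endian = struct.pack("<4f", *features)

# Reading big-endian bytes as little-endian scrambles every value.
misread = struct.unpack("<4f", big_endian)
correct = struct.unpack("<4f", little_endian)

print([round(v, 4) for v in correct])  # [5.1, 3.5, 1.4, 0.2]
print(misread)                          # values bear no relation to the inputs
```

The model would then see these scrambled numbers as its input, so its probabilities would look random even though training and conversion were correct.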
You can download this project from here.