How to Create Custom Model For Android Using TensorFlow?
Last Updated: 05 Oct, 2021
TensorFlow is an open-source library for machine learning. Android devices have limited computing power and resources, so we use TensorFlow Lite, which is specifically designed to run models on devices with limited power. In this post, we are going to build a classification example using the Iris dataset. The dataset contains 3 classes of 50 instances each, where each class refers to a type of iris plant.
Attribute information:
- sepal length in cm
- sepal width in cm
- petal length in cm
- petal width in cm
Based on these four measurements, we will predict whether the plant is Iris Setosa, Iris Versicolour, or Iris Virginica. You can refer to this link for more information.
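The label encoding in Step 2 (LabelEncoder followed by to_categorical) decides which output neuron stands for which species. The plain-Python sketch below mimics that behaviour under the assumption that, like LabelEncoder, classes are ordered alphabetically; it is an illustration rather than scikit-learn's implementation, and the short label list is made up for the example.

```python
# Hypothetical sample of class names, written as they appear in the last column of iris.data.
labels = ["Iris-virginica", "Iris-setosa", "Iris-versicolor", "Iris-setosa"]


def encode_labels(names):
    """Map class names to integer ids (alphabetical order) and to one-hot rows."""
    classes = sorted(set(names))  # alphabetical ordering, the same rule LabelEncoder uses
    index = {name: i for i, name in enumerate(classes)}
    ids = [index[name] for name in names]
    # one-hot: a 1.0 at the class position, 0.0 elsewhere (what to_categorical produces)
    one_hot = [[1.0 if j == i else 0.0 for j in range(len(classes))] for i in ids]
    return classes, ids, one_hot


classes, ids, one_hot = encode_labels(labels)
print(classes)     # ['Iris-setosa', 'Iris-versicolor', 'Iris-virginica']
print(ids)         # [2, 0, 1, 0]
print(one_hot[0])  # [0.0, 0.0, 1.0]
```

Because the sorted order is setosa, versicolor, virginica, output index 0 corresponds to Iris-setosa, index 1 to Iris-versicolor and index 2 to Iris-virginica, which is the order in which the app displays them.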
Step-by-Step Implementation
Step 1:
Download the Iris dataset (file name: iris.data) from this link: https://archive.ics.uci.edu/ml/machine-learning-databases/iris/
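It is worth checking the file before training. Each line of iris.data is a record of four comma-separated measurements followed by the class name, with no header row, so pandas should read it with header=None (the default header=0 would turn the first sample into column names). The helper below, summarize_iris, is a hypothetical name used only for this sketch; it counts rows per class and, for the full file, should report 50 rows for each of the three species.

```python
import csv
from collections import Counter


def summarize_iris(path):
    """Return a Counter mapping class name -> row count for a file in iris.data format."""
    counts = Counter()
    with open(path, newline="") as f:
        for row in csv.reader(f):
            if not row:  # skip blank lines (the file may end with one)
                continue
            if len(row) != 5:
                raise ValueError(f"expected 5 fields, got {len(row)}: {row}")
            for value in row[:4]:
                float(value)  # raises ValueError if a measurement is not numeric
            counts[row[4]] += 1
    return counts
```

Usage: `summarize_iris("iris.data")` after placing the downloaded file next to your notebook.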
Step 2:
Create a new Jupyter notebook named iris.ipynb and put the iris.data file in the same directory as the notebook. Copy the following code into the notebook.
iris.ipynb
Python
import pandas as pd
import numpy as np
import tensorflow as tf
from sklearn.preprocessing import LabelEncoder
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Sequential

# reading the CSV file into a data frame
# iris.data has no header row, so header=None keeps the first sample as data
df = pd.read_csv('iris.data', header=None)

# columns 0-3 are the four measurements (features), column 4 is the class name
X = df.iloc[:, :4].values.astype(np.float32)
y = df.iloc[:, 4].values

# encoding labels: class names -> integers (alphabetical order) -> one-hot vectors
le = LabelEncoder()
y = le.fit_transform(y)
y = to_categorical(y)

model = Sequential()
# input layer: 64 neurons with ReLU activation, expecting 4 input features
model.add(Dense(64, activation='relu', input_shape=[4]))
# hidden layer: another dense layer of 64 neurons
# (ReLU added; without an activation this layer would only be a linear transform)
model.add(Dense(64, activation='relu'))
# output layer: 3 neurons (one per class); softmax turns them into probabilities
model.add(Dense(3, activation='softmax'))

# compiling the model
model.compile(optimizer='sgd', loss='categorical_crossentropy',
              metrics=['accuracy'])

# training the model for a fixed number of iterations (epochs)
model.fit(X, y, epochs=200)

# converting the trained Keras model to TensorFlow Lite and saving it
converter = tf.lite.TFLiteConverter.from_keras_model(model)
tfmodel = converter.convert()
with open('iris.tflite', 'wb') as f:
    f.write(tfmodel)
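As a sanity check on the architecture in Step 2, the arithmetic below counts the trainable parameters of each Dense layer (a weight per input-output pair plus one bias per output) and shows the softmax that the final layer applies. Running model.summary() in the notebook should report the same total; the logits passed to softmax here are made-up numbers for illustration.

```python
import math


def dense_params(n_in, n_out):
    """Weights (n_in * n_out) plus one bias per output unit."""
    return n_in * n_out + n_out


# Layer sizes of the Step 2 model: 4 inputs -> 64 -> 64 -> 3 outputs.
sizes = [4, 64, 64, 3]
per_layer = [dense_params(a, b) for a, b in zip(sizes, sizes[1:])]
print(per_layer, sum(per_layer))  # [320, 4160, 195] 4675


def softmax(logits):
    """Numerically stable softmax: subtract the largest logit before exponentiating."""
    m = max(logits)
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


probs = softmax([2.0, 1.0, 0.1])  # hypothetical logits for the three classes
print(round(sum(probs), 6))  # 1.0
```

The parameter count does not depend on which activation functions are used, only on the layer sizes.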
Step 3:
After the conversion code runs, a new file named iris.tflite is created in the same directory as iris.data.
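Before moving to Android Studio, you can quickly confirm that iris.tflite is a TensorFlow Lite FlatBuffer rather than an empty or truncated file. This sketch relies on the assumption that the file carries the TFLite schema's FlatBuffer identifier b"TFL3" at byte offset 4; it is a heuristic, not full validation (loading the file with tf.lite.Interpreter is the thorough check). The function name looks_like_tflite is hypothetical.

```python
def looks_like_tflite(path):
    """Heuristic: TFLite FlatBuffer files carry the identifier b'TFL3' at bytes 4-8."""
    with open(path, "rb") as f:
        header = f.read(8)  # 4-byte root offset followed by the 4-byte file identifier
    return len(header) == 8 and header[4:8] == b"TFL3"
```

Usage: `looks_like_tflite("iris.tflite")` should return True for a successfully converted model.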
Step 4:
A) Open Android Studio and create a new Kotlin Android project. (You can refer here for creating a project.)
B) Right-click on app > New > Other > TensorFlow Lite Model.
C) Click on the folder icon.
D) Navigate to the iris.tflite file and select it.
E) Click on OK.
F) Click on Finish. Android Studio then opens a summary page for the model that includes sample code for running it (it may take some time to load).
Copy the sample code and paste it into the click listener of a button in MainActivity.kt, as shown below.
Step 5: Create the XML layout for prediction
Navigate to app > res > layout > activity_main.xml and add the code below to that file.
XML
<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout
    xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context=".MainActivity">

    <ScrollView
        android:layout_width="match_parent"
        android:layout_height="match_parent"
        android:layout_marginBottom="50dp">

        <!-- the child of a ScrollView should wrap its content so it can scroll -->
        <LinearLayout
            android:layout_width="match_parent"
            android:layout_height="wrap_content"
            android:orientation="vertical">

            <!-- creating EditTexts for the four inputs -->
            <EditText
                android:id="@+id/tf1"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="70dp"
                android:ems="10"
                android:hint="Sepal length (cm)"
                android:inputType="numberDecimal" />

            <EditText
                android:id="@+id/tf2"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:hint="Sepal width (cm)"
                android:inputType="numberDecimal" />

            <EditText
                android:id="@+id/tf3"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:hint="Petal length (cm)"
                android:inputType="numberDecimal" />

            <EditText
                android:id="@+id/tf4"
                android:layout_width="175dp"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="20dp"
                android:ems="10"
                android:hint="Petal width (cm)"
                android:inputType="numberDecimal" />

            <!-- creating the Button; clicking it runs the prediction -->
            <Button
                android:id="@+id/button"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="100dp"
                android:text="Predict" />

            <!-- creating the TextView on which we will see the prediction -->
            <TextView
                android:id="@+id/textView"
                android:layout_width="wrap_content"
                android:layout_height="wrap_content"
                android:layout_gravity="center"
                android:layout_marginTop="50dp"
                android:text="TextView"
                android:textSize="20sp" />
        </LinearLayout>
    </ScrollView>
</androidx.constraintlayout.widget.ConstraintLayout>
Step 6: Working with the MainActivity.kt file
Go to the MainActivity.kt file and add the following code. Comments are added inside the code to explain it in more detail.
Kotlin
// package name must match your own project
package com.example.gfgtfdemo

import android.os.Bundle
import android.widget.Button
import android.widget.EditText
import android.widget.TextView
import android.widget.Toast
import androidx.appcompat.app.AppCompatActivity
// wrapper class generated by Android Studio when iris.tflite was imported;
// its package prefix follows your app's package name
import com.example.gfgtfdemo.ml.Iris
import org.tensorflow.lite.DataType
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.nio.ByteBuffer
import java.nio.ByteOrder

class MainActivity : AppCompatActivity() {
    override fun onCreate(savedInstanceState: Bundle?) {
        super.onCreate(savedInstanceState)
        setContentView(R.layout.activity_main)

        // getting the EditText objects for the four measurements
        val ed1: EditText = findViewById(R.id.tf1)
        val ed2: EditText = findViewById(R.id.tf2)
        val ed3: EditText = findViewById(R.id.tf3)
        val ed4: EditText = findViewById(R.id.tf4)
        // getting the TextView that shows the result
        val txtView: TextView = findViewById(R.id.textView)
        val b: Button = findViewById(R.id.button)

        // registering the click listener
        b.setOnClickListener {
            // getting values from the EditTexts and converting them to Float;
            // toFloatOrNull() returns null for empty or invalid input instead of crashing
            val v1 = ed1.text.toString().toFloatOrNull()
            val v2 = ed2.text.toString().toFloatOrNull()
            val v3 = ed3.text.toString().toFloatOrNull()
            val v4 = ed4.text.toString().toFloatOrNull()
            if (v1 == null || v2 == null || v3 == null || v4 == null) {
                Toast.makeText(this, "Please enter all four values", Toast.LENGTH_SHORT).show()
                return@setOnClickListener
            }

            /*************************ML MODEL CODE STARTS HERE******************/
            val model = Iris.newInstance(this)
            try {
                // byte buffer holding 4 float32 values (4 bytes each) as input for the model;
                // ByteBuffer defaults to big-endian, so switch it to the device's native order
                val byteBuffer: ByteBuffer =
                    ByteBuffer.allocateDirect(4 * 4).order(ByteOrder.nativeOrder())
                byteBuffer.putFloat(v1)
                byteBuffer.putFloat(v2)
                byteBuffer.putFloat(v3)
                byteBuffer.putFloat(v4)
                byteBuffer.rewind()

                // creates the 1 x 4 input tensor
                val inputFeature0 = TensorBuffer.createFixedSize(intArrayOf(1, 4), DataType.FLOAT32)
                inputFeature0.loadBuffer(byteBuffer)

                // runs model inference and gets the three class probabilities
                val outputs = model.process(inputFeature0)
                val outputFeature0 = outputs.outputFeature0AsTensorBuffer.floatArray

                // setting the result to the output TextView
                // (index order follows LabelEncoder's alphabetical class order)
                txtView.text = "Iris-setosa: ${outputFeature0[0]}\n" +
                    "Iris-versicolor: ${outputFeature0[1]}\n" +
                    "Iris-virginica: ${outputFeature0[2]}"
            } finally {
                // releases model resources
                model.close()
            }
        }
    }
}
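The ByteBuffer built in MainActivity.kt holds the four measurements as consecutive float32 values, 4 * 4 = 16 bytes. Java's ByteBuffer uses big-endian order by default, while ByteOrder.nativeOrder() on typical Android hardware is little-endian, so a buffer whose order does not match what the interpreter expects can be misread; calling .order(ByteOrder.nativeOrder()) on the buffer is the usual safeguard. The Python sketch below uses the standard struct module to illustrate that layout and what the wrong byte order does to a value, and also shows how the three softmax probabilities could be reduced to a single label. pack_features and predict_label are hypothetical helpers; the app itself simply prints all three probabilities.

```python
import struct


def pack_features(values, byteorder="<"):
    """Pack four features as consecutive float32 values ('<' little-endian, '>' big-endian)."""
    return struct.pack(byteorder + "4f", *values)


features = (5.1, 3.5, 1.4, 0.2)  # sepal length, sepal width, petal length, petal width
little = pack_features(features, "<")
big = pack_features(features, ">")
print(len(little))  # 16 bytes = 4 floats * 4 bytes, like allocateDirect(4 * 4)
# Reading big-endian bytes as little-endian garbles the value:
print(struct.unpack("<4f", big)[0] == struct.unpack("<4f", little)[0])  # False

# Output index order produced by LabelEncoder's alphabetical sorting.
CLASSES = ["Iris-setosa", "Iris-versicolor", "Iris-virginica"]


def predict_label(probs):
    """Return the class name with the highest softmax probability."""
    best = max(range(len(probs)), key=lambda i: probs[i])
    return CLASSES[best]


print(predict_label([0.02, 0.90, 0.08]))  # Iris-versicolor
```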
You can download this project from here.