티스토리 뷰
홈페이지에는 Tensorflow Lite로 구현할 수 있는 Image Classification이나 Object detection, Question Answering 같은 예제들이 소개되어 있다. 참고로 Raspberry PI같이 Embedded 환경에서 Test해볼 수 있는 Image Classification이랑 Object Detection 같은 것만 예제로 제공되고 있다. 아무튼 전반적인 Tensorflow Lite가 어떻게 돌아가는지를 확인해볼 수 있는 예제를 살펴보고자 한다.
우선 Keras로 간단한 Linear regression을 위한 model을 한번 만들어본다.
import tensorflow as tf
x = [-1, 0, 1, 2, 3, 4]
y = [-3, -1, 1, 3, 5, 7]
# Create a simple Keras Model
model = tf.keras.models.Sequential([
tf.keras.layers.Dense(units=1, input_shape=[1])
])
model.compile(optimizer='sgd', loss='mean_squared_error')
model.fit(x, y, epochs=500)
현재 만든 것은 node 1개짜리 1 dense layer로 구성된 neural network이고, 이때 optimizer는 sgd(Stochastic Gradient Descent), loss function은 mean_squared_error로 설정하고 500회동안 학습시키도록 했다. 그래서 x에 대해서 y를 맞출 수 있게끔 model을 만드는 것이 목적이다. 직관적으로 봤을때 x와 y간의 관계는 다음과 같은 관계식으로 나타낼 수 있다.
$$ y = 2x - 1 $$
그래서 이전 포스트에서 언급한 바와 같이 model을 tflite용 model로 바꿔주는 converting 과정이 필요하고 tflite에서 converter라는 class를 제공한다. 그래서 해당 model을 저장하는 다음의 코드가 수행되면 생성된다.
import pathlib
# Save model to export_dir
export_dir = '.'
tf.saved_model.save(model, export_dir)
# Convert the model to tflite model format
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
tflite_model = converter.convert()
# Save the model
tflite_model_file = pathlib.Path('./a.tflite')
tflite_model_file.write_bytes(tflite_model)
(참고로 경우에 따라서는 saved_model에 save()함수가 없을 수 있다. 이건 tensorflow 1.14 이전 버전을 사용할 경우 save() 함수가 정의되어 있지 않아 발생하는 문제이므로 tensorflow version을 update해준다.)
(추가로 경우에 따라서는 "Converting unsupported operation: IdentityN" 이라는 오류가 발생하기도 하는데, 이건 tflite converter에서 해당 operation을 지원하지 않아서이다. 이전 포스트에서도 말한것처럼 몇몇 operation은 지원하지 않기 때문에 해당 operation을 다른 걸로 변경해줘야 한다. 링크를 참조해서 다음과 수정하면 정상적으로 수행된다.)
# Save model to export_dir
export_dir = '.'
tf.saved_model.save(model, export_dir)
# Convert the model to tflite model format
converter = tf.lite.TFLiteConverter.from_saved_model(export_dir)
converter.target_ops = [tf.lite.OpsSet.TFLITE_BUILTINS, tf.lite.OpsSet.SELECT_TF_OPS]
tflite_model = converter.convert()
# Save the model
tflite_model_file = pathlib.Path('./a.tflite')
tflite_model_file.write_bytes(tflite_model)
(Tensorflow 2.x 한정)다음으로 다룰 예제는 위처럼 network을 직접 만드는 것이 아니라 concrete function을 tflite model로 만드는 것이다. 참고로 concrete function이란 tensorflow에서 제공하는 tf.function을 명시적으로 표현하는 것으로, 사이트상에 설명된 바로는 이를 통해 graph 방식의 모델 구현이 가능해 성능을 높일 수 있다고 한다. 아무튼 예제는 다음과 같다.
import tensorflow as tf
# Load the MobileNet tf.keras model
model = tf.keras.applications.MobileNetV2(weights='imagenet', input_shape=(224, 224, 3))
# Get the concrete function from the Keras model
run_model = tf.function(lambda x: model(x))
# Save the concrete function
concrete_func = run_model.get_concrete_function(tf.TensorSpec(model.inputs[0].shape,
model.inputs[0].dtype))
# Save the model
converter = tf.lite.TFLiteConverter.from_concrete_functions([concrete_func])
tflite_model = converter.convert()
위의 예제에서는 pre-train된 model 중 MobileNetV2를 가져와서 concrete function으로 만들어 lite용 모델로 변환하는 과정을 거친것이다.
아니면 command line상에서도 기존에 저장해둔 tensorflow용 model을 tflite model로 변환할 수 있다.
#!/bin/bash
# Saving with the command-line from a SavedModel
tflite_converter --output_file=model.tflite --saved_model_dir=./saved_model
# Saving with the command-line from a Keras model
tflite_converter --output_file=model.tflite --keras_model_file=model.h5
이렇게 코드 상이나 command line으로도 기존에 저장해둔 model을 tflite용 model로 바꾸는 과정에 대한 예제를 다뤄보았다.
'Study > EmbeddedSystem' 카테고리의 다른 글
[Embedded][DL] TF Microcontroller Challenge (0) | 2021.06.19 |
---|---|
[Embedded][DL] Tensorflow Lite - Quantization (3) | 2019.11.26 |
[Embedded][DL] Tensorflow Lite - Introduction (0) | 2019.11.15 |
[Embedded] Performance Measures (1) (0) | 2017.07.10 |
[Embedded] Serial Communication (0) | 2017.07.10 |
[GPIO] GPIO Port Mode(Direction) Register (0) | 2017.06.20 |
[MOOC] Obstacle Avoidance Robot in CyberSim (0) | 2014.05.25 |
- Total
- Today
- Yesterday
- RL
- processing
- End-To-End
- 파이썬
- dynamic programming
- arduino
- Kinect SDK
- Windows Phone 7
- bias
- ColorStream
- SketchFlow
- Kinect
- Pipeline
- reward
- Expression Blend 4
- Off-policy
- Offline RL
- ai
- TensorFlow Lite
- Policy Gradient
- Gan
- 딥러닝
- DepthStream
- Kinect for windows
- Variance
- Distribution
- PowerPoint
- 한빛미디어
- 강화학습
- windows 8
일 | 월 | 화 | 수 | 목 | 금 | 토 |
---|---|---|---|---|---|---|
1 | 2 | |||||
3 | 4 | 5 | 6 | 7 | 8 | 9 |
10 | 11 | 12 | 13 | 14 | 15 | 16 |
17 | 18 | 19 | 20 | 21 | 22 | 23 |
24 | 25 | 26 | 27 | 28 | 29 | 30 |