圆月山庄资源网 Design By www.vgjia.com
目标是想把在服务器上用pytorch训练好的模型转换为可以在移动端运行的tflite模型。
最直接的思路是想把pytorch模型转换为tensorflow的模型,然后转换为tflite。但是这个转换目前没有发现比较靠谱的方法。
经过调研发现最新的tflite已经支持直接从keras模型的转换,所以可以采用keras作为中间转换的桥梁,这样就能充分利用keras高层API的便利性。
转换的基本思想就是用pytorch中的各层网络的权重取出来后直接赋值给keras网络中的对应layer层的权重。
转换为Keras模型后,再通过tf.contrib.lite.TocoConverter把模型直接转为tflite.
下面是一个例子,假设转换的是一个两层的CNN网络。
import tensorflow as tf from tensorflow import keras import numpy as np import torch from torchvision import models import torch.nn as nn # import torch.nn.functional as F from torch.autograd import Variable class PytorchNet(nn.Module): def __init__(self): super(PytorchNet, self).__init__() conv1 = nn.Sequential( nn.Conv2d(3, 32, 3, 2), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)) conv2 = nn.Sequential( nn.Conv2d(32, 64, 3, 1, groups=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2)) self.feature = nn.Sequential(conv1, conv2) self.init_weights() def forward(self, x): return self.feature(x) def init_weights(self): for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_( m.weight.data, mode='fan_out', nonlinearity='relu') if m.bias is not None: m.bias.data.zero_() if isinstance(m, nn.BatchNorm2d): m.weight.data.fill_(1) m.bias.data.zero_() def KerasNet(input_shape=(224, 224, 3)): image_input = keras.layers.Input(shape=input_shape) # conv1 network = keras.layers.Conv2D( 32, (3, 3), strides=(2, 2), padding="valid")(image_input) network = keras.layers.BatchNormalization( trainable=False, fused=False)(network) network = keras.layers.Activation("relu")(network) network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network) # conv2 network = keras.layers.Conv2D( 64, (3, 3), strides=(1, 1), padding="valid")(network) network = keras.layers.BatchNormalization( trainable=False, fused=True)(network) network = keras.layers.Activation("relu")(network) network = keras.layers.MaxPool2D(pool_size=(2, 2), strides=(2, 2))(network) model = keras.Model(inputs=image_input, outputs=network) return model class PytorchToKeras(object): def __init__(self, pModel, kModel): super(PytorchToKeras, self) self.__source_layers = [] self.__target_layers = [] self.pModel = pModel self.kModel = kModel tf.keras.backend.set_learning_phase(0) def __retrieve_k_layers(self): for i, layer in enumerate(self.kModel.layers): if len(layer.weights) > 0: self.__target_layers.append(i) def __retrieve_p_layers(self, input_size): input = torch.randn(input_size) input = Variable(input.unsqueeze(0)) hooks = [] def add_hooks(module): def hook(module, input, output): if hasattr(module, "weight"): # print(module) self.__source_layers.append(module) if not isinstance(module, nn.ModuleList) and not isinstance(module, nn.Sequential) and module != self.pModel: hooks.append(module.register_forward_hook(hook)) self.pModel.apply(add_hooks) self.pModel(input) for hook in hooks: hook.remove() def convert(self, input_size): self.__retrieve_k_layers() self.__retrieve_p_layers(input_size) for i, (source_layer, target_layer) in enumerate(zip(self.__source_layers, self.__target_layers)): print(source_layer) weight_size = len(source_layer.weight.data.size()) transpose_dims = [] for i in range(weight_size): transpose_dims.append(weight_size - i - 1) if isinstance(source_layer, nn.Conv2d): transpose_dims = [2,3,1,0] self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy( ).transpose(transpose_dims), source_layer.bias.data.numpy()]) elif isinstance(source_layer, nn.BatchNorm2d): self.kModel.layers[target_layer].set_weights([source_layer.weight.data.numpy(), source_layer.bias.data.numpy(), source_layer.running_mean.data.numpy(), source_layer.running_var.data.numpy()]) def save_model(self, output_file): self.kModel.save(output_file) def save_weights(self, output_file): self.kModel.save_weights(output_file, save_format='h5') pytorch_model = PytorchNet() keras_model = KerasNet(input_shape=(224, 224, 3)) torch.save(pytorch_model, 'test.pth') #Load the pretrained model pytorch_model = torch.load('test.pth') # #Time to transfer weights converter = PytorchToKeras(pytorch_model, keras_model) converter.convert((3, 224, 224)) # #Save the converted keras model for later use # converter.save_weights("keras.h5") converter.save_model("keras_model.h5") # convert keras model to tflite model converter = tf.contrib.lite.TocoConverter.from_keras_model_file( "keras_model.h5") tflite_model = converter.convert() open("convert_model.tflite", "wb").write(tflite_model)
补充知识:tensorflow模型转换成tensorflow lite模型
1.把graph和网络模型打包在一个文件中
bazel build tensorflow/python/tools:freeze_graph && bazel-bin/tensorflow/python/tools/freeze_graph --input_graph=eval_graph_def.pb --input_checkpoint=checkpoint --output_graph=frozen_eval_graph.pb --output_node_names=outputs
For example:
bazel-bin/tensorflow/python/tools/freeze_graph \ --input_graph=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224_eval.pbtxt --input_checkpoint=./mobilenet_v1_1.0_224/mobilenet_v1_1.0_224.ckpt --output_graph=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb --output_node_names=MobilenetV1/Predictions/Reshape_1
2.把第一步中生成的tensorflow pb模型转换为tf lite模型
转换前需要先编译转换工具
bazel build tensorflow/contrib/lite/toco:toco
转换分两种,一种的转换为float的tf lite,另一种可以转换为对模型进行unit8的量化版本的模型。两种方式如下:
非量化的转换:
./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco \ 官网给的这个路径不对 ./bazel-bin/tensorflow/contrib/lite/toco/toco \ —input_file=./mobilenet_v1_1.0_224/frozen_eval_graph_test.pb \ —output_file=./mobilenet_v1_1.0_224/tflite_model_test.tflite \ --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \ --inference_type=FLOAT \ --input_shape="1,224, 224,3" \ --input_array=input \ --output_array=MobilenetV1/Predictions/Reshape_1
量化方式的转换(注意,只有量化训练的模型才能进行量化的tf_lite转换):
./bazel-bin/third_party/tensorflow/contrib/lite/toco/toco ./bazel-bin/tensorflow/contrib/lite/toco/toco --input_file=frozen_eval_graph.pb --output_file=tflite_model.tflite --input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE --inference_type=QUANTIZED_UINT8 --input_shape="1,224, 224,3" --input_array=input --output_array=outputs --std_value=127.5 --mean_value=127.5
以上这篇Pytorch转tflite方式就是小编分享给大家的全部内容了,希望能给大家一个参考,也希望大家多多支持。
标签:
Pytorch,tflite
圆月山庄资源网 Design By www.vgjia.com
广告合作:本站广告合作请联系QQ:858582 申请时备注:广告合作(否则不回)
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
免责声明:本站文章均来自网站采集或用户投稿,网站不提供任何软件下载或自行开发的软件! 如有用户或公司发现本站内容信息存在侵权行为,请邮件告知! 858582#qq.com
圆月山庄资源网 Design By www.vgjia.com
暂无评论...
更新日志
2025年01月24日
2025年01月24日
- 小骆驼-《草原狼2(蓝光CD)》[原抓WAV+CUE]
- 群星《欢迎来到我身边 电影原声专辑》[320K/MP3][105.02MB]
- 群星《欢迎来到我身边 电影原声专辑》[FLAC/分轨][480.9MB]
- 雷婷《梦里蓝天HQⅡ》 2023头版限量编号低速原抓[WAV+CUE][463M]
- 群星《2024好听新歌42》AI调整音效【WAV分轨】
- 王思雨-《思念陪着鸿雁飞》WAV
- 王思雨《喜马拉雅HQ》头版限量编号[WAV+CUE]
- 李健《无时无刻》[WAV+CUE][590M]
- 陈奕迅《酝酿》[WAV分轨][502M]
- 卓依婷《化蝶》2CD[WAV+CUE][1.1G]
- 群星《吉他王(黑胶CD)》[WAV+CUE]
- 齐秦《穿乐(穿越)》[WAV+CUE]
- 发烧珍品《数位CD音响测试-动向效果(九)》【WAV+CUE】
- 邝美云《邝美云精装歌集》[DSF][1.6G]
- 吕方《爱一回伤一回》[WAV+CUE][454M]