微信小程序|调用tensorflow自定义模型

2020-06-17 15:36:38 浏览数 (1)

问题描述

在成功调用官网打包好的tensorflowjs模型后,怎么调用自己的模型呢?又需要做哪些处理呢?

解决方案

1)安装好python和tensorflow

2)安装tensorflowjs : pip install tensorflowjs

注:如果你的tensorflow版本是2.0的,在下载tfjs时可能会被更新为1.15版本的。可以考虑新建个python环境。

3)准备已经训练好的模型,并通过 model.save(“模型命名.h5”) 代码将模型保存为h5格式的文件。

下面是本文使用的mnist手写数字集的模型代码案例:

代码语言:javascript复制

import tensorflow as tf

mnist = tf.keras.datasets.mnist

(x_train, y_train),(x_test, y_test) = mnist.load_data()

x_train, x_test = x_train / 255.0, x_test / 255.0

model = tf.keras.models.Sequential([

  tf.keras.layers.Flatten(input_shape=(28, 28)),  

  tf.keras.layers.Dense(128, activation='relu'), 

  tf.keras.layers.Dropout(0.2),  tf.keras.layers.Dense(10, activation='softmax')

])

model.compile(optimizer='adam',

                              loss='sparse_categorical_crossentropy',  

                              metrics=['accuracy'])



model.fit(x_train, y_train, epochs=5)

model.evaluate(x_test, y_test)

model.save('D:\test/mnist.h5')

4)通过tensorflowjs_converter命令将h5格式的模型文件转换为json格式的文件。

1.打开pycharm的Terminal指令框

2. 输入转换指令:

tensorflowjs_converter--input_format=keras D:\test/mnist.h5 D:\test

注释:tensorflowjs_converter –模型格式 模型地址 保存地址

3.查看model.json是否生成

5)将模型放在服务器上,如果没有可以在本地创建,步骤如下 :

1.打开pycharm的Terminal的指令框

2.输入python3 -m http.server 8000

3.打开浏览器输入 localhost:8000 输出如下界面

如果出现localhost拒绝访问,可能是你的系统没开启iis服务,只能手动开启了。

未开启的建议依次按以下步骤来:

1 .百度:如何安装iss服务

2 .打开管理工具

3.进入管理工具界面,单击“Internet Information Services (IIS)管理器”。

4.右键单击“网站”,选择“添加网站”。

5.在弹出的界面中输入网站名称、选择物理路径(model.json所在的文件地址)、IP地址输入为127.0.0.1、端口为8000,然后点击确定。

6.打开目录展示功能:目录浏览—打开功能—启用

6) 在项目中安装相应的库

详细过程请参考之前发布的博客《微信小程序与tensorflow.js准备工作》在项目目录下使用npm安装对应包,安装代码如下:

代码语言:javascript复制

npm install fetch-wechat

npm install @tensorflow/tfjs-converter

npm install @tensorflow/tfjs-core

npm install @tensorflow/tfjs-layers

npm install regenerator-runtime

7) 效果

8) 代码部分

代码较为简单,说明以注释方式放在代码旁边(只展示主体代码部分,完成项目代码下载链接:

https://pan.baidu.com/s/18VcMiNaiEjC_Y_Yz1gJ_1g)

代码语言:javascript复制
const regeneratorRuntime = require('regenerator-runtime')

const tf = require('@tensorflow/tfjs-core')

const tfl = require('@tensorflow/tfjs-layers')

 

//index.js

Page({

  async onReady() {

    //加载相机

    const camera = wx.createCameraContext(this)

    // 加载模型

    const net = await this.loadModel()

    this.setData({result: 'Loading'})

    let count = 0

    //每隔10帧获取一张相机捕捉到的图片

    const listener = camera.onCameraFrame((frame) => {

      count  

      if (count === 10) {

        if (net) {

          //对图片内容进行预测

          this.predict(net, frame)

        }

        count = 0

      }

    })

    listener.start()

  },

  //加载模型

  async loadModel() {

    const net = await tfl.loadLayersModel('https://yuantao.store/model.json')

    net.summary()

    return net

  },

  async predict(net, frame){

    //图像预处理,API说明和用发可到tensorflow.google.cn查看

    const imgData = {data: new Uint8Array(frame.data), width: frame.width, height: frame.height}

    const x = tf.tidy(() => {

      const imgTensor = tf.browser.fromPixels(imgData, 4)

      //转换为Tensor,微信小程序相机获取的图片有4维

      const d = Math.floor((frame.height - frame.width) / 2)

      const imgSlice = imgTensor.slice([d, 0, 0], [frame.width, -1, 3])

      //截取正方形区域,并丢掉最后一个维度,只保留3个维度

      const imgResize = tf.image.resizeBilinear(imgSlice, [28, 28])

      return imgResize.mean(2)//对最后一个维度去均值,将三通道转换为单通道

    })

    // console.log(x)

    const y = await net.predict(x.expandDims(0)).argMax(1)

    //预测,并获取预测值最大的下标,及预测结果

    const res = y.dataSync()[0]//预测结果为一个对象,我们只需要值部分

    this.setData({result: res})

  }

  

})

END

主 编 | 王文星

责 编 | 马原涛

where2go 团队

0 人点赞