Tensorflow.js:我在浏览器中实现了迁移学习

2022-07-29 08:42:06 浏览数 (2)

⭐️ 本文首发自 前端修罗场,是一个由资深开发者独立运行的专业技术社区,我专注 Web 技术、答疑解惑、面试辅导以及职业发展。帮你评估知识点的掌握程度,获得更全面的学习指导意见,交个朋友,不走弯路,少吃亏!


迁移学习是将预训练模型与自定义训练数据相结合的能力。这意味着你可以利用模型的功能并添加自己的样本,而无需从头开始创建所有内容。

例如,一种算法已经用数千张图像进行了训练以创建图像分类模型,而不是创建自己的图像分类模型,迁移学习允许你将新的自定义图像样本与预先训练的模型相结合以创建新的图像分类器。这个特性使得拥有一个更加定制化的分类器变得非常快速和容易。

为了提供代码中的示例,让我们重新利用之前的示例并对其进行修改,以便我们可以对新图像进行分类。

以下是此设置最重要部分的一些代码示例,但如果你需要查看整个代码,可以在本文的最后找到它。

我们仍然需要从导入 Tensorflow.js 和 MobileNet 开始,但是这次我们还需要添加一个 KNN(k-nearest neighbor)分类器:

代码语言:javascript复制
<!-- 加载 TensorFlow.js -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs"></script>
<!-- 加载 MobileNet -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet"></script>
<!-- 加载 KNN 分类器 -->
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/knn-classifier"></script>

我们需要分类器的原因是(不仅仅是使用 MobileNet 模块)我们正在添加以前从未见过的自定义样本,因此 KNN 分类器将允许我们将所有内容组合在一起并对组合的数据进行预测

然后,我们可以用视频标签替换猫的图像,以使用来自摄像头的图像。

代码语言:javascript复制
<video autoplay id="webcam" width="227" height="227"></video>

最后,我们需要在页面上添加一些按钮,我们将用作标签来记录一些视频样本并开始预测。

代码语言:javascript复制
<section>
  <button class="button">Left</button>

  <button class="button">Right</button>

  <button class="test-predictions">Test</button>
</section>

现在,让我们转到 JavaScript 文件,我们将从设置几个重要变量开始:

代码语言:javascript复制
//要分类的数量
const NUM_CLASSES = ;
// 分类标签
const classes = ["Left", "Right"];
// 图像大小
const IMAGE_SIZE = ;
// KNN 的 K 值
const TOPK = ;

const video = document.getElementById("webcam");

在这个特定的示例中,我们希望能够在我们的头部向左或向右倾斜之间对网络摄像头输入进行分类,因此我们需要两个标记为 leftright 的类。

设置为 227 的图像大小是视频元素的大小(以像素为单位)。根据 Tensorflow.js 示例,该值需要设置为 227 以匹配用于训练 MobileNet 模型的数据格式。为了能够对我们的新数据进行分类,后者需要适应相同的格式。

如果你真的需要它更大,这是可能的,但你必须在将数据提供给 KNN 分类器之前转换和调整数据大小。

然后,我们将 K 的值设置为 10。KNN 算法中的 K 值很重要,因为它代表了我们在确定新输入的类别时考虑的实例数。

在这种情况下,10 意味着,在预测一些新数据的标签时,我们将查看训练数据中的 10 个最近邻,以确定如何对新输入进行分类。

最后,我们得到了视频元素。

对于逻辑,让我们从加载模型和分类器开始:

代码语言:javascript复制
async load() {
    const knn = knnClassifier.create();
    const mobilenetModule = await mobilenet.load();
    console.log("model loaded");
}

然后,让我们访问视频源:

代码语言:javascript复制
navigator.mediaDevices
  .getUserMedia({ video: true, audio: false })
  .then(stream => {
    video.srcObject = stream;
    video.width = IMAGE_SIZE;
    video.height = IMAGE_SIZE;
  });

接下来,让我们设置一些按钮事件来记录我们的示例数据:

代码语言:javascript复制
setupButtonEvents() {
    for (let i = ; i < NUM_CLASSES; i  ) {
      let button = document.getElementsByClassName("button")[i];

      button.onmousedown = () => {
        this.training = i;
        this.recordSamples = true;
      };
      button.onmouseup = () => (this.training = -1);
    }
  }

让我们编写我们的函数,它将获取网络摄像头图像样本,重新格式化它们并将它们与 MobileNet 模块结合起来:

代码语言:javascript复制
// 从视频元素中获取图像数据
const image = tf.browser.fromPixels(video);

let logits;
// 'conv_preds' 是 MobileNet 的 logits 激活
const infer = () => this.mobilenetModule.infer(image, "conv_preds");

// 如果按住其中一个按钮,则进行训练
if (this.training != -1) {
  logits = infer();

  // 将当前图像添加到分类器
  this.knn.addExample(logits, this.training);
}

最后,一旦我们收集了一些网络摄像头图像,我们就可以使用以下代码测试我们的预测:

代码语言:javascript复制
logits = infer();
const res = await this.knn.predictClass(logits, TOPK);
const prediction = classes[res.classIndex];

最后,您可以处理我们不再需要的网络摄像头数据:

代码语言:javascript复制
// 完成后处理图像
image.dispose();
if (logits != null) {
  logits.dispose();
}

0 人点赞