【机器学习】Tensorflow.js:在浏览器中使用机器学习实现图像分类

2023-10-07 19:30:25 浏览数 (3)

使用 JavaScript 和 Tensorflow.js 等框架是入门和了解更多机器学习的好方法。 在本文中,我会介绍当前使用 Tensorflow.js 可用的三个主要功能,并阐明在前端使用机器学习的局限性。

机器学习通常感觉它属于数据科学家和 Python 开发人员的领域。 然而,在过去的几年中,已经创建了开源框架,以使其更易于在不同的编程语言中访问,包括 JavaScript。 在本文中,我们将使用 Tensorflow.js 通过几个示例项目来探索在浏览器中使用机器学习的不同可能性。

机器学习

对于机器学习,一个常见的定义是:计算机无需明确编程即可从数据中学习的能力。

如果我们将其与传统编程进行比较,这意味着我们让计算机识别数据中的模式并生成预测,而无需我们确切地告诉它要寻找什么。

让我们以欺诈检测为例。 没有确定的标准可以知道是什么使交易具有欺诈性; 欺诈可以在任何国家、任何账户、针对任何客户、任何时间等进行。 手动跟踪所有这些几乎是不可能的。

然而,使用多年来收集的有关欺诈费用的先前数据,我们可以训练机器学习算法来理解这些数据中的模式,从而生成一个模型,该模型可以给出任何新交易并预测它是否为欺诈的可能性,而无需 准确地告诉它要寻找什么。

几个核心概念

如果你是一个初学者,要理解以下我们代码中的示例,需要先了解下面的一些常用术语。

Model

当你使用数据集训练机器学习算法时,模型是此训练过程的输出。 它有点像一个将新数据作为输入并产生预测作为输出的函数。

标签和特征

标签和特征与你在训练过程中提供给算法的数据相关。

标签表示你将如何对数据集中的每个条目进行分类以及如何标记它。 例如,如果我们的数据集是一个描述不同动物的 CSV 文件,我们的标签可以是“猫”、“狗”或“蛇”之类的词(取决于每种动物代表什么)。

另一方面,特征是数据集中每个条目的特征。 对于我们的动物示例,它可能是“胡须、喵喵”、“顽皮、吠叫”、“爬行动物、猖獗”等。

使用这一点,机器学习算法将能够找到特征与其标签之间的某种相关性,并将用于未来的预测。

神经网络

神经网络是一组机器学习算法,它试图通过使用人工神经元层来模仿大脑的工作方式。

我们不需要在本文中深入了解它们的工作原理,但是如果您想了解更多信息,这里有一个非常好的视频:

现在我们已经定义了一些机器学习中常用的术语,让我们来谈谈使用 JavaScript 和 Tensorflow.js 框架可以做什么。

1. 使用预训练模型

根据你尝试解决的问题,可能已经有一个模型已经使用特定数据集和用于特定目的进行了训练,你可以在代码中加以利用和导入。

例如,假设我们正在构建一个网站来预测一张图片是否是一张猫的图片。 一种流行的图像分类模型称为 MobileNet,可作为带有 Tensorflow.js 的预训练模型使用。

代码如下所示:

代码语言:javascript复制
<html lang="en">
  <head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <meta http-equiv="X-UA-Compatible" content="ie=edge">
    <title>Cat detection</title>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow/tfjs@1.0.1"> </script>
    <script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"> </script>
  </head>
  <body>
    <img id="image" alt="cat laying down" src="cat.jpeg"/>

    <script>
      const img = document.getElementById('image');

      const predictImage = async () => {
        console.log("Model loading...");
        const model = await mobilenet.load();
        console.log("Model is loaded!")

        const predictions = await model.classify(img);
        console.log('Predictions: ', predictions);
      }
      predictImage();
    </script>
  </body>
</html>

解释:

上面代码中,我们首先在 HTML 的头部导入 Tensorflow.js 和 MobileNet 模型:

代码语言:javascript复制
<script src="https://cdnjs.cloudflare.com/ajax/libs/tensorflow/1.0.1/tf.js"> </script>
<script src="https://cdn.jsdelivr.net/npm/@tensorflow-models/mobilenet@1.0.0"> </script>

然后,在 body 内部,我们有一个用于预测的图像元素:

代码语言:javascript复制
<img id="image" alt="cat laying down" src="cat.jpeg"/>

最后,在 script 标签内,我们有 JavaScript 代码,它加载预训练的 MobileNet 模型并对在图像标签中找到的图像进行分类。 它返回一个由 3 个预测组成的数组,这些预测按概率分数排序(第一个元素是最佳预测)。

代码语言:javascript复制
const predictImage = async () => {
  console.log("Model loading...");
  const model = await mobilenet.load();
  console.log("Model is loaded!")
  const predictions = await model.classify(img);
  console.log('Predictions: ', predictions);
}

predictImage();

上面这个示例,就是你可以在浏览器中通过 Tensorflow.js 使用预训练模型的方式!

注意:如果你想看看 MobileNet 模型还能分类什么,你可以在 Github 上找到可用的不同类的列表。

需要了解的重要一点是,在浏览器中加载预训练模型可能需要一些时间(有时长达 10 秒),因此你可能需要预加载或调整界面,以免影响用户的体验。

如果你更喜欢使用 Tensorflow.js 作为 NPM 模块,您可以通过以下方式导入模块:

代码语言:javascript复制
import * as mobilenet from '@tensorflow-models/mobilenet';

本文我们讲解了如何使用 TensorFlow.js 在浏览器中实现对图像的分类,并介绍了什么是机器学习。下一篇中,我还会为大家介绍更多 TensorFlow.js 在浏览器端的应用案例,关注我,少走弯路,不吃亏~

1 人点赞