在Android上使用YOLOv8目标检测(步骤+代码)

2024-07-25 18:39:23 浏览数 (3)

视觉/图像重磅干货,第一时间送达!

前 言

Yolov8 是一种流行的物体检测 AI。Android是世界上用户最多的移动操作系统。

本文介绍如何在 Android 设备上执行 yolov8 物体检测。

步骤1:从Pytorch格式转换为tflite格式

YOLOv8 以pytorch格式构建。将其转换为tflite,以便在 android 上使用。

安装YOLOv8

安装一个名为Ultralytics的框架。Yolov8包含在此框架中。

代码语言:javascript复制
pip install ultralytics

转换为 tflite

使用转换代码进行转换。以下代码将下载预训练模型的权重。

如果您有使用自己的自定义数据训练的模型的权重检查点文件,请替换 yolov8s.pt 部分。

代码语言:javascript复制
from ultralytics import YOLO
model = YOLO('yolov8s.pt')
model.export(format="tflite")

将生成yolov8s_saved_model/yolov8s_float16.tflite,因此请使用它。

如果发生转换错误...

如果出现以下错误,则是由于tensorflow的版本问题,因此请安装兼容的版本。

ImportError:generic_type:无法初始化类型“StatusCode”:具有该名称的对象已定义

例如将tensorflow改为如下版本。

代码语言:javascript复制
pip install tensorflow==2.13.0

在 Android 上运行 tflite 文件

从这里开始,我们将在android studio项目中运行yolov8 tflite文件。

将 tflite 文件添加到项目中

在android studio项目的app目录下创建assets目录(File → New → Folder → Asset Folder),添加tflite文件(yolov8s_float32.tflite)和labels.txt,可以通过复制粘贴的方式添加。

labels.txt 是一个文本文件,其中描述了 YOLOv8 模型的类名,如下所示。

如果您设置了自定义类,请写入该类。

默认的 YOLOv8 预训练模型如下。

labels.txt内容如下:

代码语言:javascript复制
person
bicycle
car
motorcycle
airplane
bus
train
truck
boat
traffic light
fire hydrant
stop sign
parking meter
bench
bird
cat
dog
horse
sheep
cow
elephant
bear
zebra
giraffe
backpack
umbrella
handbag
tie
suitcase
frisbee
skis
snowboard
sports ball
kite
baseball bat
baseball glove
skateboard
surfboard
tennis racket
bottle
wine glass
cup
fork
knife
spoon
bowl
banana
apple
sandwich
orange
broccoli
carrot
hot dog
pizza
donut
cake
chair
couch
potted plant
bed
dining table
toilet
tv
laptop
mouse
remote
keyboard
cell phone
microwave
oven
toaster
sink
refrigerator
book
clock
vase
scissors
teddy bear
hair drier
toothbrush

安装 tflite

将以下内容添加到 app/build.gradle.kts 中的依赖项中以安装 tflite 框架。

应用程序/build.gradle.kts

代码语言:javascript复制
implementation("org.tensorflow:tensorflow-lite:2.14.0")
implementation("org.tensorflow:tensorflow-lite-support:0.4.4")

添加完以上内容后,按立即同步进行安装。

导入所需模块

代码语言:javascript复制
import org.tensorflow.lite.DataType
import org.tensorflow.lite.Interpreter
import org.tensorflow.lite.gpu.CompatibilityList
import org.tensorflow.lite.gpu.GpuDelegate
import org.tensorflow.lite.support.common.FileUtil
import org.tensorflow.lite.support.common.ops.CastOp
import org.tensorflow.lite.support.common.ops.NormalizeOp
import org.tensorflow.lite.support.image.ImageProcessor
import org.tensorflow.lite.support.image.TensorImage
import org.tensorflow.lite.support.tensorbuffer.TensorBuffer
import java.io.BufferedReader
import java.io.IOException
import java.io.InputStream
import java.io.InputStreamReader

必需的类属性

代码语言:javascript复制
private val modelPath = "yolov8s_float32.tflite"
private val labelPath = "labels.txt"
private var interpreter: Interpreter? = null
private var tensorWidth = 0
private var tensorHeight = 0
private var numChannel = 0
private var numElements = 0
private var labels = mutableListOf<String>()
private val imageProcessor = ImageProcessor.Builder()
    .add(NormalizeOp(INPUT_MEAN, INPUT_STANDARD_DEVIATION))
    .add(CastOp(INPUT_IMAGE_TYPE))
    .build() // preprocess input

companion object {
    private const val INPUT_MEAN = 0f
    private const val INPUT_STANDARD_DEVIATION = 255f
    private val INPUT_IMAGE_TYPE = DataType.FLOAT32
    private val OUTPUT_IMAGE_TYPE = DataType.FLOAT32
    private const val CONFIDENCE_THRESHOLD = 0.3F
    private const val IOU_THRESHOLD = 0.5F
}

初始化模型

初始化 tflite 模型。获取模型文件并将其传递给 tflite 的Interpreter。可选地传递要使用的线程数。

如果您在 Activity 以外的类中使用它,则需要将上下文传递给该类。

代码语言:javascript复制
val model = FileUtil.loadMappedFile(context, modelPath)
val options = Interpreter.Options()
options.numThreads = 4
interpreter = Interpreter(model, options)

从解释器获取 yolov8s 输入和输出shape。

代码语言:javascript复制
val inputShape = interpreter.getInputTensor(0).shape()
val outputShape = interpreter.getOutputTensor(0).shape()

tensorWidth = inputShape[1]
tensorHeight = inputShape[2]
numChannel = outputShape[1]
numElements = outputShape[2]

从 label.txt 文件中读取类名。

必须明确关闭 InputStream 和 InputStreamReader。

代码语言:javascript复制
try {
    val inputStream: InputStream = context.assets.open(labelPath)
    val reader = BufferedReader(InputStreamReader(inputStream))
    var line: String? = reader.readLine()
    while (line != null && line != "") {
        labels.add(line)
        line = reader.readLine()
    }
    reader.close()
    inputStream.close()
} catch (e: IOException) {
    e.printStackTrace()
}

输入图像并执行

输入是位图,但根据模型的输入格式进行下面的预处理。

1. 调整大小以匹配模型的输入形状

2. 使其成为张量

3. 通过将像素值除以 255 来标准化像素值(使其成为 0 到 1 范围内的值)

4. 转换为模型的输入类型

5. 输入获取 imageBuffer 以

代码语言:javascript复制
val resizedBitmap = Bitmap.createScaledBitmap(bitmap, tensorWidth, tensorHeight, false)
val tensorImage = TensorImage(DataType.FLOAT32)
tensorImage.load(resizedBitmap)
val processedImage = imageProcessor.process(tensorImage)
val imageBuffer = processedImage.buffer

创建一个与模型输出形状相匹配的输出张量缓冲区,并将其与上面的输入 imageBuffer 一起传递给解释器进行执行。

代码语言:javascript复制
val output = TensorBuffer.createFixedSize(intArrayOf(1 , numChannel, numElements), OUTPUT_IMAGE_TYPE)
interpreter.run(imageBuffer, output.buffer)

对输出进行后处理

输出框被视为 BoudingBox 类。

它是一个具有类、框和置信度的类。

x1, y1 是起点。x2, y2 是终点。cx, cy 是中心。w是宽度, h是高度。

代码语言:javascript复制
data class BoundingBox(
    val x1: Float,
    val y1: Float,
    val x2: Float,
    val y2: Float,
    val cx: Float,
    val cy: Float,
    val w: Float,
    val h: Float,
    val cnf: Float,
    val cls: Int,
    val clsName: String
)

接下来的过程是从众多输出框候选中选择一个可靠性较高的框。

1. 提取置信度高于置信度阈值的框。

2. 在重叠框中,保留可靠性最高的框。(nms)

代码语言:javascript复制
private fun bestBox(array: FloatArray) : List<BoundingBox>? {

    val boundingBoxes = mutableListOf<BoundingBox>()

    for (c in 0 until numElements) {
        var maxConf = -1.0f
        var maxIdx = -1
        var j = 4
        var arrayIdx = c   numElements * j
        while (j < numChannel){
            if (array[arrayIdx] > maxConf) {
                maxConf = array[arrayIdx]
                maxIdx = j - 4
            }
            j  
            arrayIdx  = numElements
        }

        if (maxConf > CONFIDENCE_THRESHOLD) {
            val clsName = labels[maxIdx]
            val cx = array[c] // 0
            val cy = array[c   numElements] // 1
            val w = array[c   numElements * 2]
            val h = array[c   numElements * 3]
            val x1 = cx - (w/2F)
            val y1 = cy - (h/2F)
            val x2 = cx   (w/2F)
            val y2 = cy   (h/2F)
            if (x1 < 0F || x1 > 1F) continue
            if (y1 < 0F || y1 > 1F) continue
            if (x2 < 0F || x2 > 1F) continue
            if (y2 < 0F || y2 > 1F) continue

            boundingBoxes.add(
                BoundingBox(
                    x1 = x1, y1 = y1, x2 = x2, y2 = y2,
                    cx = cx, cy = cy, w = w, h = h,
                    cnf = maxConf, cls = maxIdx, clsName = clsName
                )
            )
        }
    }

    if (boundingBoxes.isEmpty()) return null

    return applyNMS(boundingBoxes)
}

private fun applyNMS(boxes: List<BoundingBox>) : MutableList<BoundingBox> {
    val sortedBoxes = boxes.sortedByDescending { it.cnf }.toMutableList()
    val selectedBoxes = mutableListOf<BoundingBox>()

    while(sortedBoxes.isNotEmpty()) {
        val first = sortedBoxes.first()
        selectedBoxes.add(first)
        sortedBoxes.remove(first)

        val iterator = sortedBoxes.iterator()
        while (iterator.hasNext()) {
            val nextBox = iterator.next()
            val iou = calculateIoU(first, nextBox)
            if (iou >= IOU_THRESHOLD) {
                iterator.remove()
            }
        }
    }

    return selectedBoxes
}

private fun calculateIoU(box1: BoundingBox, box2: BoundingBox): Float {
    val x1 = maxOf(box1.x1, box2.x1)
    val y1 = maxOf(box1.y1, box2.y1)
    val x2 = minOf(box1.x2, box2.x2)
    val y2 = minOf(box1.y2, box2.y2)
    val intersectionArea = maxOf(0F, x2 - x1) * maxOf(0F, y2 - y1)
    val box1Area = box1.w * box1.h
    val box2Area = box2.w * box2.h
    return intersectionArea / (box1Area   box2Area - intersectionArea)
}

此时你会得到yolov8的输出。

代码语言:javascript复制
val bestBoxes = bestBox(output.floatArray)

将输出框绘制到图像上

代码语言:javascript复制
fun drawBoundingBoxes(bitmap: Bitmap, boxes: List<BoundingBox>): Bitmap {
    val mutableBitmap = bitmap.copy(Bitmap.Config.ARGB_8888, true)
    val canvas = Canvas(mutableBitmap)
    val paint = Paint().apply {
        color = Color.RED
        style = Paint.Style.STROKE
        strokeWidth = 8f
    }
    val textPaint = Paint().apply {
        color = Color.WHITE
        textSize = 40f
        typeface = Typeface.DEFAULT_BOLD
    }

    for (box in boxes) {
        val rect = RectF(
            box.x1 * mutableBitmap.width,
            box.y1 * mutableBitmap.height,
            box.x2 * mutableBitmap.width,
            box.y2 * mutableBitmap.height
        )
        canvas.drawRect(rect, paint)
        canvas.drawText(box.clsName, rect.left, rect.bottom, textPaint)
    }

    return mutableBitmap
}

在一些情况下,解释器为空时需要模型路径是否正确。

1 人点赞