[Object Detection] Deploying YOLOv5 on Android

2022-11-12 11:09:50

Preface

This post walks through an example of deploying YOLOv5 on Android. It is mainly based on the official PyTorch demo: https://github.com/pytorch/android-demo-app/tree/master/PyTorchDemoApp

Feature Overview

The app's home screen is shown in the figure below:

Main features:

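  • Switch test images: three (or any number of) images are specified directly in the code; tap the test-image button to cycle through them
  • Select image: tap Select Image to pick a photo from the gallery or take one with the camera
  • Live video: tap Live Video to open the camera and show detection results directly in the camera preview
  • Switch model (a feature I added): tap Switch Model to choose a different model for detection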

Quick Start

First, get the official demo running. Download the official yolov5s.torchscript.ptl model from: https://pytorch-mobile-demo-apps.s3.us-east-2.amazonaws.com/yolov5s.torchscript.ptl

After downloading, put it in the app's assets folder (app/src/main/assets).

If you run the app as-is, picking an image from the gallery fails with:

Unable to decode stream: java.io.FileNotFoundException:/…/open failed: EACCES (Permission denied)

To fix this, add the following attribute to the <application> tag in AndroidManifest.xml:

android:requestLegacyExternalStorage="true"

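After that, the app runs normally. (Note that on Android 11 and later this flag is ignored once the app targets API level 30 or higher.)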

Training Your Own Model

Next, train your own model with YOLOv5 v6.0. The training process itself is not covered here; see the earlier posts in this series.

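Then modify the export_torchscript function in export.py. Essentially three lines are added so that it also exports a model with the .torchscript.ptl suffix, saved for the PyTorch Lite interpreter that the Android app loads through LiteModuleLoader.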

def export_torchscript(model, im, file, optimize, prefix=colorstr('TorchScript:')):
    # YOLOv5 TorchScript model export; also writes a .torchscript.ptl file for the PyTorch Lite interpreter
    try:
        print(f'\n{prefix} starting export with torch {torch.__version__}...')
        f = str(file.with_suffix('.torchscript.pt'))    # regular TorchScript model
        fl = str(file.with_suffix('.torchscript.ptl'))  # Lite-interpreter model loaded by the Android app

        ts = torch.jit.trace(model, im, strict=False)
        ts = optimize_for_mobile(ts) if optimize else ts  # mobile optimization only when --optimize is passed
        ts.save(f)
        ts._save_for_lite_interpreter(fl)
        print(f'{prefix} export success, saved as {f} ({file_size(f):.1f} MB) and {fl} ({file_size(fl):.1f} MB)')
    except Exception as e:
        print(f'{prefix} export failure: {e}')

Then run the following in a terminal:

python export.py --weights runs/train/exp/weights/best.pt --include torchscript

This produces the best.torchscript.ptl model, next to best.pt in runs/train/exp/weights/.

Switching to Your Own Model

Next, add a model-switching feature and use the model you trained yourself.

First, change the PyTorch dependency versions in build.gradle:

implementation 'org.pytorch:pytorch_android_lite:1.9.0'
implementation 'org.pytorch:pytorch_android_torchvision:1.9.0'

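Try to keep these versions in line with the PyTorch version you use later for training and export. For example, if you use PyTorch 1.9.0, write 1.9.0 here; a model exported by a newer PyTorch may fail to load in an older runtime.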

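Then edit PrePostProcessor.java, the class that declares mOutputColumn (MainActivity assigns it as PrePostProcessor.mOutputColumn), and remove the private modifier from mOutputColumn so it can be set from outside the class.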
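For reference, this is why the column count has to be num_class + 5. Assuming the standard YOLOv5 detection head, the model returns a flat float array in which every candidate box occupies one row of [cx, cy, w, h, objectness, one score per class], and mOutputColumn is the stride between rows. The class below is a plain-Java sketch of that layout on a toy 3-class output; YoloRow and its methods are hypothetical names, not part of the demo.

```java
/**
 * Sketch of the YOLOv5 output row layout, assuming the standard detection head:
 * [cx, cy, w, h, objectness, score_0 ... score_{n-1}] per row, rows stored back to back.
 * Hypothetical helper for illustration; not the demo's PrePostProcessor code.
 */
class YoloRow {
    static final int BOX_AND_OBJ = 5; // 4 box values + 1 objectness score

    /** Values per row, i.e. what mOutputColumn must be set to. */
    static int outputColumns(int numClasses) {
        return numClasses + BOX_AND_OBJ;
    }

    /** Index of the highest class score in row {@code row} of the flat output array. */
    static int bestClass(float[] outputs, int row, int numClasses) {
        int stride = outputColumns(numClasses);
        int base = row * stride + BOX_AND_OBJ; // first class score of this row
        int best = 0;
        for (int j = 1; j < numClasses; j++) {
            if (outputs[base + j] > outputs[base + best]) {
                best = j;
            }
        }
        return best;
    }

    public static void main(String[] args) {
        int numClasses = 3;
        // Two rows of 3 + 5 = 8 values each
        float[] outputs = {
            10f, 20f, 4f, 6f, 0.9f, 0.1f, 0.7f, 0.2f,
            50f, 60f, 8f, 8f, 0.4f, 0.8f, 0.1f, 0.1f
        };
        System.out.println(outputColumns(numClasses));         // 8
        System.out.println(bestClass(outputs, 0, numClasses)); // 1
        System.out.println(bestClass(outputs, 1, numClasses)); // 0
    }
}
```

With the stride set to num_class instead of num_class + 5, every row after the first starts at the wrong offset, so box coordinates end up being read as class scores.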

Next, edit the layout: add a Switch Model button to activity_main.xml and adjust the layout accordingly.

<Button
    android:id="@+id/select"
    android:layout_width="100dp"
    android:layout_height="wrap_content"
    android:layout_marginTop="32dp"
    android:textAllCaps="false"
    android:text="@string/select_model"
    app:layout_constraintEnd_toEndOf="parent"
    app:layout_constraintEnd_toStartOf="@+id/selectButton"
    app:layout_constraintHorizontal_bias="0.5"
    app:layout_constraintStart_toStartOf="parent"
    app:layout_constraintTop_toBottomOf="@+id/detectButton"
    android:background="@drawable/button_selector"/>


<Button
    android:id="@+id/testButton"
    android:layout_width="wrap_content"
    android:layout_height="wrap_content"
    android:layout_marginTop="180dp"
    android:textAllCaps="false"
    app:layout_constraintEnd_toEndOf="parent"
    app:layout_constraintHorizontal_bias="0.5"
    app:layout_constraintStart_toStartOf="parent"
    app:layout_constraintTop_toBottomOf="@+id/imageView"
    android:background="@drawable/button_selector"/>

Then modify MainActivity.java and add the following three fields:

private String model_name = "yolov5s.torchscript.ptl";
private String model_class = "classes.txt";
private int num_class = 80;
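These three fields always change together, so one option is to group them into a single object per model. The sketch below is a hypothetical refactoring (ModelSpec, MODELS and displayNames are illustrative names, not from the demo), and the second display name is written in English here whereas the app shows 王者荣耀模型. It also derives the output-column count from the class count so the two cannot drift apart.

```java
import java.util.Arrays;
import java.util.List;

/**
 * Hypothetical value class bundling model_name, model_class and num_class,
 * so that switching models updates all three at once.
 */
final class ModelSpec {
    final String displayName;
    final String modelAsset;   // e.g. "yolov5s.torchscript.ptl"
    final String labelsAsset;  // e.g. "classes.txt", one label per line
    final int numClasses;

    ModelSpec(String displayName, String modelAsset, String labelsAsset, int numClasses) {
        this.displayName = displayName;
        this.modelAsset = modelAsset;
        this.labelsAsset = labelsAsset;
        this.numClasses = numClasses;
    }

    /** Values per output row: 4 box values + objectness + one score per class. */
    int outputColumns() {
        return numClasses + 5;
    }

    // The two models used in this post; adding a model means adding an entry, not an if-branch
    static final List<ModelSpec> MODELS = Arrays.asList(
        new ModelSpec("YOLOv5s", "yolov5s.torchscript.ptl", "classes.txt", 80),
        new ModelSpec("Honor of Kings model", "mymodel.ptl", "classes_wzry.txt", 10)
    );

    /** Labels for the selection dialog, in the same order as MODELS. */
    static String[] displayNames() {
        String[] names = new String[MODELS.size()];
        for (int i = 0; i < names.length; i++) {
            names[i] = MODELS.get(i).displayName;
        }
        return names;
    }

    public static void main(String[] args) {
        ModelSpec chosen = MODELS.get(1); // "which" from the dialog's onClick
        System.out.println(chosen.modelAsset + " " + chosen.outputColumns()); // mymodel.ptl 15
        System.out.println(String.join(", ", displayNames()));
    }
}
```

With this in place, the dialog entries could come from ModelSpec.displayNames(), and onClick could look up ModelSpec.MODELS.get(which) instead of branching with if.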

Add the click handler for the model-selection button:

private void ShowChoise()
{
    AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
    // builder.setIcon(R.drawable.ic_launcher_foreground);
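    builder.setTitle("选择一个模型"); // "Select a model"
    // Entries shown in the dialog ("王者荣耀模型" = Honor of Kings model)
    final String[] models = {"YOLOv5s", "王者荣耀模型"};
    // Show the entries as a selectable list; "which" is the index of the tapped entry
    builder.setItems(models, new DialogInterface.OnClickListener()
    {
        @Override
        public void onClick(DialogInterface dialog, int which)
        {
            // Toast text: "Selected model: <name>"
            Toast.makeText(MainActivity.this, "选择的模型为:" + models[which], Toast.LENGTH_SHORT).show();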
            if (which==0){
                model_name = "yolov5s.torchscript.ptl";
                model_class = "classes.txt";
                num_class = 80;
            }
            else {
                model_name = "mymodel.ptl";
                model_class = "classes_wzry.txt";
                num_class = 10;
            }
            // Reload the selected model and its class labels
            try {
                mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
                BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)));
                String line;
                List<String> classes = new ArrayList<>();
                while ((line = br.readLine()) != null) {
                    classes.add(line);
                }
                PrePostProcessor.mClasses = new String[classes.size()];
                PrePostProcessor.mOutputColumn = num_class + 5; // 4 box values + objectness + num_class scores per row
                classes.toArray(PrePostProcessor.mClasses);
            } catch (IOException e) {
                Log.e("Object Detection", "Error reading assets", e);
                finish();
            }
        }
    });
    builder.show();
}

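Add one if branch per selectable model. model_class is the class-label file for that model, which you need to create yourself following the format of classes.txt (one class name per line, in the same order as the names list in your training data YAML); num_class is the number of classes.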
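A mismatch between the label file and num_class is easy to introduce, and the loading code above does not detect it. Below is a minimal sketch of a stricter loader, assuming one label per line as in classes.txt. LabelFile, readLabels and readAndCheck are hypothetical names; in the app the Reader would wrap getAssets().open(model_class), and the labels in the example are placeholders rather than real classes.

```java
import java.io.BufferedReader;
import java.io.IOException;
import java.io.Reader;
import java.io.StringReader;
import java.io.UncheckedIOException;
import java.util.ArrayList;
import java.util.List;

/** Hypothetical loader for a classes.txt-style label file (one label per line). */
class LabelFile {
    static List<String> readLabels(Reader source) {
        List<String> labels = new ArrayList<>();
        try (BufferedReader br = new BufferedReader(source)) {
            String line;
            while ((line = br.readLine()) != null) {
                String label = line.trim();
                if (!label.isEmpty()) { // skip blank lines such as a trailing newline
                    labels.add(label);
                }
            }
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
        return labels;
    }

    /** Reads the labels and fails fast if their count differs from numClasses. */
    static List<String> readAndCheck(Reader source, int numClasses) {
        List<String> labels = readLabels(source);
        if (labels.size() != numClasses) {
            throw new IllegalStateException(
                "expected " + numClasses + " labels but found " + labels.size());
        }
        return labels;
    }

    public static void main(String[] args) {
        String file = "class_a\nclass_b\nclass_c\n";
        System.out.println(readAndCheck(new StringReader(file), 3)); // [class_a, class_b, class_c]
    }
}
```

Unlike the loop in the post, this version ignores blank lines, so an extra newline at the end of the file does not add an empty class name.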

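Finally, copy the best.torchscript.ptl from the export step into the assets folder and rename it to mymodel.ptl by hand; if you skip the rename, the app fails with a file-not-found error. Then run the app.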
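The rename matters because assetFilePath() looks the file up by name. It also copies an asset into getFilesDir() only when no non-empty copy exists there yet, so if you later replace mymodel.ptl in assets and reinstall over the previous build, the app will likely keep loading the old cached copy until you clear the app's data or uninstall it. The plain-Java sketch below reproduces that copy-once logic, with temporary directories standing in for the APK assets and getFilesDir(); AssetCache and demo are hypothetical names.

```java
import java.io.IOException;
import java.io.UncheckedIOException;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardCopyOption;

/** Plain-Java model of the copy-once behaviour of MainActivity.assetFilePath(). */
class AssetCache {
    static Path assetFilePath(Path assetsDir, Path filesDir, String name) {
        try {
            Path cached = filesDir.resolve(name);
            if (Files.exists(cached) && Files.size(cached) > 0) {
                return cached; // an existing copy is reused even if the asset has changed
            }
            // A missing asset fails here, like the file-not-found error on Android
            Files.copy(assetsDir.resolve(name), cached, StandardCopyOption.REPLACE_EXISTING);
            return cached;
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /** Loads "v1", replaces the asset with "v2", loads again, and returns what was loaded. */
    static String demo() {
        try {
            Path assets = Files.createTempDirectory("assets");
            Path files = Files.createTempDirectory("files");
            Files.write(assets.resolve("mymodel.ptl"), "v1".getBytes(StandardCharsets.UTF_8));
            assetFilePath(assets, files, "mymodel.ptl");
            Files.write(assets.resolve("mymodel.ptl"), "v2".getBytes(StandardCharsets.UTF_8));
            Path p = assetFilePath(assets, files, "mymodel.ptl");
            return new String(Files.readAllBytes(p), StandardCharsets.UTF_8);
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    public static void main(String[] args) {
        System.out.println(demo()); // v1
    }
}
```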

Full Code

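Besides the changes above, I also localized the UI into Chinese and made small adjustments to image loading. The full source of the modified files follows: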

activity_main.xml

<?xml version="1.0" encoding="utf-8"?>
<androidx.constraintlayout.widget.ConstraintLayout xmlns:android="http://schemas.android.com/apk/res/android"
    xmlns:app="http://schemas.android.com/apk/res-auto"
    xmlns:tools="http://schemas.android.com/tools"
    android:layout_width="match_parent"
    android:layout_height="match_parent"
    tools:context="org.pytorch.demo.objectdetection.MainActivity">

    <ImageView
        android:id="@+id/imageView"
        android:layout_width="0dp"
        android:layout_height="0dp"
        android:layout_marginTop="0dp"
        android:background="#FFFFFF"
        android:contentDescription="@string/image_view"
        app:layout_constraintDimensionRatio="1:1"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <org.pytorch.demo.objectdetection.ResultView
        android:id="@+id/resultView"
        android:layout_width="0dp"
        android:layout_height="0dp"
        android:layout_marginTop="0dp"
        app:layout_constraintDimensionRatio="1:1"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toTopOf="parent" />

    <Button
        android:id="@+id/detectButton"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="20dp"
        android:text="@string/detect"
        android:textAllCaps="false"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.498"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/imageView"
        android:background="@drawable/button_selector"/>


    <ProgressBar
        android:id="@+id/progressBar"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="20dp"
        android:visibility="invisible"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.498"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/imageView" />


    <Button
        android:id="@+id/selectButton"
        android:layout_width="100dp"
        android:layout_height="wrap_content"
        android:text="@string/select"
        android:textAllCaps="false"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintEnd_toStartOf="@+id/liveButton"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toEndOf="@+id/select"
        app:layout_constraintTop_toTopOf="@+id/select"
        android:background="@drawable/button_selector"/>

    <Button
        android:id="@+id/liveButton"
        android:layout_width="100dp"
        android:layout_height="wrap_content"
        android:text="@string/live"
        android:textAllCaps="false"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toEndOf="@+id/selectButton"
        app:layout_constraintTop_toTopOf="@+id/selectButton"
        android:background="@drawable/button_selector"/>

    <Button
        android:id="@+id/select"
        android:layout_width="100dp"
        android:layout_height="wrap_content"
        android:layout_marginTop="32dp"
        android:textAllCaps="false"
        android:text="@string/select_model"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintEnd_toStartOf="@+id/selectButton"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/detectButton"
        android:background="@drawable/button_selector"/>


    <Button
        android:id="@+id/testButton"
        android:layout_width="wrap_content"
        android:layout_height="wrap_content"
        android:layout_marginTop="180dp"
        android:textAllCaps="false"
        app:layout_constraintEnd_toEndOf="parent"
        app:layout_constraintHorizontal_bias="0.5"
        app:layout_constraintStart_toStartOf="parent"
        app:layout_constraintTop_toBottomOf="@+id/imageView"
        android:background="@drawable/button_selector"/>

</androidx.constraintlayout.widget.ConstraintLayout>

MainActivity.java

// Copyright (c) 2020 Facebook, Inc. and its affiliates.
// All rights reserved.
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

package org.pytorch.demo.objectdetection;

import androidx.appcompat.app.AlertDialog;
import androidx.appcompat.app.AppCompatActivity;
import androidx.core.app.ActivityCompat;
import androidx.core.content.ContextCompat;

import android.Manifest;
import android.content.Context;
import android.content.DialogInterface;
import android.content.Intent;
import android.content.pm.PackageManager;
import android.database.Cursor;
import android.graphics.Bitmap;
import android.graphics.BitmapFactory;
import android.graphics.Matrix;
import android.net.Uri;
import android.os.Bundle;
import android.provider.MediaStore;
import android.util.Log;
import android.view.View;
import android.widget.Button;
import android.widget.ImageView;
import android.widget.ProgressBar;
import android.widget.Toast;

import org.pytorch.IValue;
import org.pytorch.LiteModuleLoader;
import org.pytorch.Module;
import org.pytorch.Tensor;
import org.pytorch.torchvision.TensorImageUtils;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.InputStreamReader;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.List;

public class MainActivity extends AppCompatActivity implements Runnable {
    private int mImageIndex = 0;
    private String[] mTestImages = {"test1.png", "test2.jpg", "test3.png"};

    private ImageView mImageView;
    private ResultView mResultView;
    private Button mButtonDetect;
    private Button mButtonSelect;
    private ProgressBar mProgressBar;
    private Bitmap mBitmap = null;
    private Module mModule = null;
    private float mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY;

    private String model_name = "yolov5s.torchscript.ptl";
    private String model_class = "classes.txt";
    private int num_class = 80;

    public static String assetFilePath(Context context, String assetName) throws IOException {
        File file = new File(context.getFilesDir(), assetName);
        if (file.exists() && file.length() > 0) {
            return file.getAbsolutePath();
        }

        try (InputStream is = context.getAssets().open(assetName)) {
            try (OutputStream os = new FileOutputStream(file)) {
                byte[] buffer = new byte[4 * 1024];
                int read;
                while ((read = is.read(buffer)) != -1) {
                    os.write(buffer, 0, read);
                }
                os.flush();
            }
            return file.getAbsolutePath();
        }
    }

    private void ShowChoise()
    {
        AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
        // builder.setIcon(R.drawable.ic_launcher_foreground);
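        builder.setTitle("选择一个模型"); // "Select a model"
        // Entries shown in the dialog ("王者荣耀模型" = Honor of Kings model)
        final String[] models = {"YOLOv5s", "王者荣耀模型"};
        // Show the entries as a selectable list; "which" is the index of the tapped entry
        builder.setItems(models, new DialogInterface.OnClickListener()
        {
            @Override
            public void onClick(DialogInterface dialog, int which)
            {
                // Toast text: "Selected model: <name>"
                Toast.makeText(MainActivity.this, "选择的模型为:" + models[which], Toast.LENGTH_SHORT).show();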
                if (which==0){
                    model_name = "yolov5s.torchscript.ptl";
                    model_class = "classes.txt";
                    num_class = 80;
                }
                else {
                    model_name = "mymodel.ptl";
                    model_class = "classes_wzry.txt";
                    num_class = 10;
                }
                // Reload the selected model and its class labels
                try {
                    mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
                    BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)));
                    String line;
                    List<String> classes = new ArrayList<>();
                    while ((line = br.readLine()) != null) {
                        classes.add(line);
                    }
                    PrePostProcessor.mClasses = new String[classes.size()];
                    PrePostProcessor.mOutputColumn = num_class + 5; // 4 box values + objectness + num_class scores per row
                    classes.toArray(PrePostProcessor.mClasses);
                } catch (IOException e) {
                    Log.e("Object Detection", "Error reading assets", e);
                    finish();
                }
            }
        });
        builder.show();
    }


    @Override
    protected void onCreate(Bundle savedInstanceState) {
        super.onCreate(savedInstanceState);

        // Request storage and camera permissions in one call: only one permission request
        // can be pending at a time, so a second back-to-back request is treated as cancelled
        if (ContextCompat.checkSelfPermission(this, Manifest.permission.READ_EXTERNAL_STORAGE) != PackageManager.PERMISSION_GRANTED
                || ContextCompat.checkSelfPermission(this, Manifest.permission.CAMERA) != PackageManager.PERMISSION_GRANTED) {
            ActivityCompat.requestPermissions(this,
                    new String[]{Manifest.permission.READ_EXTERNAL_STORAGE, Manifest.permission.CAMERA}, 1);
        }

        setContentView(R.layout.activity_main);

        try {
            mBitmap = BitmapFactory.decodeStream(getAssets().open(mTestImages[mImageIndex]));
        } catch (IOException e) {
            Log.e("Object Detection", "Error reading assets", e);
            finish();
        }

        mImageView = findViewById(R.id.imageView);
        mImageView.setImageBitmap(mBitmap);
        mResultView = findViewById(R.id.resultView);
        mResultView.setVisibility(View.INVISIBLE);

        final Button buttonTest = findViewById(R.id.testButton);
        buttonTest.setText(("测试图片 1/3"));
        buttonTest.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mResultView.setVisibility(View.INVISIBLE);
                mImageIndex = (mImageIndex + 1) % mTestImages.length;
                buttonTest.setText(String.format("测试图片 %d/%d", mImageIndex + 1, mTestImages.length));

                try {
                    mBitmap = BitmapFactory.decodeStream(getAssets().open(mTestImages[mImageIndex]));
                    mImageView.setImageBitmap(mBitmap);
                } catch (IOException e) {
                    Log.e("Object Detection", "Error reading assets", e);
                    finish();
                }
            }
        });


        final Button buttonSelect = findViewById(R.id.selectButton);
        buttonSelect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mResultView.setVisibility(View.INVISIBLE);

                final CharSequence[] options = { "从相册选择", "拍照", "取消" };
                AlertDialog.Builder builder = new AlertDialog.Builder(MainActivity.this);
                builder.setTitle("新测试图片");

                builder.setItems(options, new DialogInterface.OnClickListener() {
                    @Override
                    public void onClick(DialogInterface dialog, int item) {
                        if (options[item].equals("拍照")) {
                            Intent takePicture = new Intent(android.provider.MediaStore.ACTION_IMAGE_CAPTURE);
                            startActivityForResult(takePicture, 0);
                        }
                        else if (options[item].equals("从相册选择")) {
                            Intent pickPhoto = new Intent(Intent.ACTION_PICK, android.provider.MediaStore.Images.Media.INTERNAL_CONTENT_URI);
                            startActivityForResult(pickPhoto , 1);
                        }
                        else if (options[item].equals("取消")) {
                            dialog.dismiss();
                        }
                    }
                });
                builder.show();
            }
        });

        final Button buttonLive = findViewById(R.id.liveButton);
        buttonLive.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
              final Intent intent = new Intent(MainActivity.this, ObjectDetectionActivity.class);
              startActivity(intent);
            }
        });

        mButtonDetect = findViewById(R.id.detectButton);
        mProgressBar = (ProgressBar) findViewById(R.id.progressBar);
        mButtonDetect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                mButtonDetect.setEnabled(false);
                mProgressBar.setVisibility(ProgressBar.VISIBLE);
                mButtonDetect.setText(getString(R.string.run_model));

                mImgScaleX = (float)mBitmap.getWidth() / PrePostProcessor.mInputWidth;
                mImgScaleY = (float)mBitmap.getHeight() / PrePostProcessor.mInputHeight;

                mIvScaleX = (mBitmap.getWidth() > mBitmap.getHeight() ? (float)mImageView.getWidth() / mBitmap.getWidth() : (float)mImageView.getHeight() / mBitmap.getHeight());
                mIvScaleY  = (mBitmap.getHeight() > mBitmap.getWidth() ? (float)mImageView.getHeight() / mBitmap.getHeight() : (float)mImageView.getWidth() / mBitmap.getWidth());

                mStartX = (mImageView.getWidth() - mIvScaleX * mBitmap.getWidth())/2;
                mStartY = (mImageView.getHeight() -  mIvScaleY * mBitmap.getHeight())/2;

                Thread thread = new Thread(MainActivity.this);
                thread.start();
            }
        });

        // New: model-selection button
        mButtonSelect = findViewById(R.id.select);
        mButtonSelect.setOnClickListener(new View.OnClickListener() {
            public void onClick(View v) {
                ShowChoise();
            }
        });


        try {
            mModule = LiteModuleLoader.load(MainActivity.assetFilePath(getApplicationContext(), model_name));
            BufferedReader br = new BufferedReader(new InputStreamReader(getAssets().open(model_class)));
            String line;
            List<String> classes = new ArrayList<>();
            while ((line = br.readLine()) != null) {
                classes.add(line);
            }
            PrePostProcessor.mClasses = new String[classes.size()];
            PrePostProcessor.mOutputColumn = num_class + 5; // 4 box values + objectness + num_class scores per row
            classes.toArray(PrePostProcessor.mClasses);
        } catch (IOException e) {
            Log.e("Object Detection", "Error reading assets", e);
            finish();
        }
    }

    @Override
    protected void onActivityResult(int requestCode, int resultCode, Intent data) {
        super.onActivityResult(requestCode, resultCode, data);
        if (resultCode != RESULT_CANCELED) {
            switch (requestCode) {
                case 0:
                    if (resultCode == RESULT_OK && data != null) {
                        mBitmap = (Bitmap) data.getExtras().get("data");
                        Matrix matrix = new Matrix();
                        //matrix.postRotate(90.0f);
                        matrix.postRotate(0);
                        mBitmap = Bitmap.createBitmap(mBitmap, 0, 0, mBitmap.getWidth(), mBitmap.getHeight(), matrix, true);
                        mImageView.setImageBitmap(mBitmap);
                    }
                    break;
                case 1:
                    if (resultCode == RESULT_OK && data != null) {
                        Uri selectedImage = data.getData();
                        String[] filePathColumn = {MediaStore.Images.Media.DATA};
                        if (selectedImage != null) {
                            Cursor cursor = getContentResolver().query(selectedImage,
                                    filePathColumn, null, null, null);
                            if (cursor != null) {
                                cursor.moveToFirst();
                                int columnIndex = cursor.getColumnIndex(filePathColumn[0]);
                                String picturePath = cursor.getString(columnIndex);
                                mBitmap = BitmapFactory.decodeFile(picturePath);
                                Matrix matrix = new Matrix();
                                //matrix.postRotate(90.0f);
                                matrix.postRotate(0);
                                mBitmap = Bitmap.createBitmap(mBitmap, 0, 0, mBitmap.getWidth(), mBitmap.getHeight(), matrix, true);
                                mImageView.setImageBitmap(mBitmap);
                                cursor.close();
                            }
                        }
                    }
                    break;
            }
        }
    }

    @Override
    public void run() {
        Bitmap resizedBitmap = Bitmap.createScaledBitmap(mBitmap, PrePostProcessor.mInputWidth, PrePostProcessor.mInputHeight, true);
        final Tensor inputTensor = TensorImageUtils.bitmapToFloat32Tensor(resizedBitmap, PrePostProcessor.NO_MEAN_RGB, PrePostProcessor.NO_STD_RGB);
        IValue[] outputTuple = mModule.forward(IValue.from(inputTensor)).toTuple();
        final Tensor outputTensor = outputTuple[0].toTensor();
        final float[] outputs = outputTensor.getDataAsFloatArray();
        final ArrayList<Result> results =  PrePostProcessor.outputsToNMSPredictions(outputs, mImgScaleX, mImgScaleY, mIvScaleX, mIvScaleY, mStartX, mStartY);

        runOnUiThread(() -> {
            mButtonDetect.setEnabled(true);
            mButtonDetect.setText(getString(R.string.detect));
            mProgressBar.setVisibility(ProgressBar.INVISIBLE);
            mResultView.setResults(results);
            mResultView.invalidate();
            mResultView.setVisibility(View.VISIBLE);
        });
    }
}
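The detect button's onClick() computes three sets of factors before run() executes: mImgScaleX/Y map from the 640×640 model input back to the original bitmap (run() stretches the bitmap to the input size with createScaledBitmap, without letterboxing), mIvScaleX/Y map from the bitmap to the ImageView, and mStartX/Y center the scaled image inside the view. The sketch below chains them for a single box. The scale-factor formulas are copied from onClick(); how the box corners are combined with them is an assumption about what PrePostProcessor.outputsToNMSPredictions() does, since its source is not shown in this post, and BoxMapping is a hypothetical name.

```java
/**
 * Sketch of the coordinate chain prepared in onClick():
 * model input (640x640) -> original bitmap pixels -> ImageView pixels.
 * Scale factors follow MainActivity; the corner mapping is an assumption.
 */
class BoxMapping {
    static final int INPUT_W = 640, INPUT_H = 640; // assumed model input size

    /** Returns {left, top, right, bottom} in ImageView pixels for a model-space box (cx, cy, w, h). */
    static float[] toViewRect(float cx, float cy, float w, float h,
                              int bmpW, int bmpH, int viewW, int viewH) {
        // Model space -> bitmap space (same as mImgScaleX / mImgScaleY)
        float imgScaleX = (float) bmpW / INPUT_W;
        float imgScaleY = (float) bmpH / INPUT_H;
        // Bitmap space -> view space, fitting the longer side (same as mIvScaleX / mIvScaleY)
        float ivScaleX = bmpW > bmpH ? (float) viewW / bmpW : (float) viewH / bmpH;
        float ivScaleY = bmpH > bmpW ? (float) viewH / bmpH : (float) viewW / bmpW;
        // Offsets that center the scaled bitmap in the view (same as mStartX / mStartY)
        float startX = (viewW - ivScaleX * bmpW) / 2;
        float startY = (viewH - ivScaleY * bmpH) / 2;

        // Box corners in bitmap pixels
        float left = imgScaleX * (cx - w / 2);
        float top = imgScaleY * (cy - h / 2);
        float right = imgScaleX * (cx + w / 2);
        float bottom = imgScaleY * (cy + h / 2);
        // Box corners in view pixels
        return new float[] {
            startX + ivScaleX * left, startY + ivScaleY * top,
            startX + ivScaleX * right, startY + ivScaleY * bottom
        };
    }

    public static void main(String[] args) {
        // 64x64 box at the center of the model input; 1280x720 photo in a 1080x1080 view
        float[] r = toViewRect(320f, 320f, 64f, 64f, 1280, 720, 1080, 1080);
        System.out.println(java.util.Arrays.toString(r)); // [486.0, 509.625, 594.0, 570.375]
    }
}
```

In this example the box stays centered at (540, 540), the middle of the view, and the 236.25 px vertical offset is the empty band above and below the letterboxed 16:9 photo inside the square ImageView.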

strings.xml

<resources>
    <string name="app_name">YOLOv5</string>
    <string name="image_view">Image View</string>
    <string name="detect">检测</string>
    <string name="run_model">正在运行,请稍候</string>
    <string name="restart">Restart</string>
    <string name="select">选择图片</string>
    <string name="live">实时视频</string>
    <string name="select_model">切换模型</string>
</resources>

button_selector.xml

<?xml version="1.0" encoding="utf-8"?>
<selector xmlns:android="http://schemas.android.com/apk/res/android">
    <item android:state_pressed="true">
        <shape>
            <solid android:color="#64AFFA"/>
            <corners android:radius="10dp"/>
            <padding
                android:bottom="2dp"
                android:left="3dp"
                android:right="3dp"
                android:top="2dp">
            </padding>
        </shape>
    </item>
    <item android:state_pressed="false">
        <shape>
            <solid android:color="#99CCFF"/>
            <corners android:radius="10dp"/>
            <padding
                android:bottom="2dp"
                android:left="3dp"
                android:right="3dp"
                android:top="2dp">
            </padding>
        </shape>
    </item>
</selector>

Summary

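In my tests, the packaged APK came out at just over 1 GB, which shows how much bundling the PyTorch framework inflates the app; slimming it down is left for further study. Real-time video detection also runs at a very low frame rate, essentially a slideshow, probably because of the phone's limited compute; optimizing that is likewise left for future work.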
