3D Detection: 3D Box和点云 快速可视化

2022-09-07 16:59:33 浏览数 (1)

用于 3D 检测中 3D box 与点云的快速可视化,辅助 debug 和分析(适用于 nuScenes、mmdet3d、OpenPCDet 等)。

注意:代码适用于多种模型,但 BEVDet/CenterPoint 系列与 FCOS3D(front view)系列要选用不同的 box 转换函数:前者使用 box3d2x0y0wh,后者使用 box3d2x0y0wh_2。

3D box投影到bev可视化:

代码语言:python
import matplotlib.pyplot as plt
import torch

def box3d2x0y0wh(boxes_3d):
    """Project 3D boxes onto the BEV plane as axis-aligned rectangles.

    Intended for BEVDet / CenterPoint style boxes. Columns 0 and 1 are
    used as the box centre, columns 3 and 4 as the rectangle width and
    height. No other column is read, so any rotation of the box is
    ignored and the result is always axis-aligned.

    Args:
        boxes_3d: 2-D array-like with at least 5 columns, one box per row.
            A torch tensor is also accepted; it is detached and moved to
            the CPU first.

    Returns:
        ``numpy.ndarray`` of shape ``(n, 4)`` holding ``[x0, y0, w, h]``
        where ``(x0, y0)`` is the lower-left corner of each rectangle.
    """
    import numpy as np

    # Duck-typed tensor handling so this block does not depend on torch.
    if hasattr(boxes_3d, "detach"):
        boxes_3d = boxes_3d.detach().cpu().numpy()
    boxes_3d = np.asarray(boxes_3d)

    n = boxes_3d.shape[0]
    box2d = np.zeros((n, 4))

    # 3D box --> centre-based (x, y, w, h)
    box2d[:, :2] = boxes_3d[:, :2]
    box2d[:, 2] = boxes_3d[:, 3]
    box2d[:, 3] = boxes_3d[:, 4]

    # Shift from the centre to the lower-left corner, which is the anchor
    # matplotlib's Rectangle expects.
    box2d[:, 0] = box2d[:, 0] - box2d[:, 2] / 2
    box2d[:, 1] = box2d[:, 1] - box2d[:, 3] / 2

    return box2d

def box3d2x0y0wh_2(boxes_3d):
    """Project front-view (FCOS3D style) 3D boxes onto the BEV plane.

    Columns 0 and 2 are used as the rectangle centre, columns 4 and 5 as
    its width and height. No other column is read, so any rotation of the
    box is ignored and the result is always axis-aligned.

    Args:
        boxes_3d: 2-D array-like with at least 6 columns, one box per row.
            A torch tensor is also accepted; it is detached and moved to
            the CPU first.

    Returns:
        ``numpy.ndarray`` of shape ``(n, 4)`` holding ``[x0, y0, w, h]``
        where ``(x0, y0)`` is the lower-left corner of each rectangle.
    """
    import numpy as np

    # Duck-typed tensor handling so this block does not depend on torch.
    if hasattr(boxes_3d, "detach"):
        boxes_3d = boxes_3d.detach().cpu().numpy()
    boxes_3d = np.asarray(boxes_3d)

    n = boxes_3d.shape[0]
    box2d = np.zeros((n, 4))

    # 3D box --> centre-based (x, y, w, h)
    box2d[:, 0] = boxes_3d[:, 0]
    box2d[:, 1] = boxes_3d[:, 2]
    box2d[:, 2] = boxes_3d[:, 4]
    box2d[:, 3] = boxes_3d[:, 5]

    # Shift from the centre to the lower-left corner, which is the anchor
    # matplotlib's Rectangle expects.
    box2d[:, 0] = box2d[:, 0] - box2d[:, 2] / 2
    box2d[:, 1] = box2d[:, 1] - box2d[:, 3] / 2

    return box2d

# 根据坐标作图
def draw_boxes(pred_boxes_3d, target_boxes_3d, path, box_converter=None):
    """Draw predicted and target boxes as BEV rectangles and save the figure.

    Predicted boxes are outlined in red, target boxes in green. Any file
    already present at ``path`` is removed before the figure is written.

    Args:
        pred_boxes_3d: Predicted 3D boxes, in the layout expected by
            ``box_converter``.
        target_boxes_3d: Ground-truth 3D boxes, same layout.
        path: Output image path passed to ``Figure.savefig``.
        box_converter: Callable mapping a batch of 3D boxes to an
            ``(n, 4)`` array of ``[x0, y0, w, h]`` rectangles. Defaults to
            ``box3d2x0y0wh``; pass ``box3d2x0y0wh_2`` for front-view boxes.
    """
    import os

    import matplotlib.patches as patches
    import matplotlib.pyplot as plt

    if box_converter is None:
        box_converter = box3d2x0y0wh

    pred_boxes = box_converter(pred_boxes_3d)
    target_boxes = box_converter(target_boxes_3d)

    fig, ax = plt.subplots()
    try:
        for boxes, color in ((pred_boxes, 'r'), (target_boxes, 'g')):
            for x0, y0, w, h in boxes:
                rect = patches.Rectangle((x0, y0), w, h, linewidth=1,
                                         edgecolor=color, facecolor='none')
                ax.add_patch(rect)

        # Patches update the data limits but do not rescale the view by
        # themselves, so fit the axes to the rectangles explicitly.
        ax.autoscale_view()

        if os.path.exists(path):
            os.remove(path)
        fig.savefig(path, dpi=90, bbox_inches='tight')
    finally:
        # Always release the figure, even when saving fails.
        plt.close(fig)
    print('Successfully saved')

点云快速可视化:

代码语言:python
def draw_pts(points, save_path, show=False):
    """Scatter-plot a point cloud in 3D, coloured by the third column.

    Args:
        points: 2-D ``numpy.ndarray`` or ``torch.Tensor`` with at least
            three columns. Columns 0, 1 and 2 are plotted as x, y and z;
            any further columns are ignored.
        save_path: Image path to write, or ``None`` to skip saving.
        show: If true, also display the figure with ``plt.show()``.
    """
    assert len(points.shape) == 2
    if isinstance(points, torch.Tensor):
        # detach() so tensors that require grad can be converted too.
        points = points.detach().cpu().numpy()

    xs, ys, zs = points[:, 0], points[:, 1], points[:, 2]

    fig = plt.figure()
    try:
        ax = fig.add_subplot(111, projection='3d')
        ax.scatter(xs, ys, zs,
                   c=zs,  # height data for color
                   cmap=plt.get_cmap("Spectral"),
                   marker="x")
        ax.axis('auto')  # {equal, scaled}

        # Save before showing: on interactive backends the figure may be
        # torn down once the window opened by plt.show() is closed.
        if save_path is not None:
            fig.savefig(save_path, dpi=90, bbox_inches='tight')
        if show:
            plt.show()
    finally:
        # Always release the figure, even when plotting or saving fails.
        plt.close(fig)

0 人点赞