Datawhale干货
作者:王浩,结行科技算法工程师
参加了“世界人工智能创新大赛”——手写体 OCR 识别竞赛(任务一),取得了Top1的成绩。队伍随机组的,有人找我我就加了进来,这是我第一次做OCR相关的项目,所以随意起了个名字。下面通过这篇文章来详细介绍我们的方案。
实践背景
赛题背景
银行日常业务中涉及到各类凭证的识别录入,例如身份证录入、支票录入、对账单录入等。以往的录入方式主要是以人工录入为主,效率较低,人力成本较高。近几年来,OCR相关技术以其自动执行、人为干预较少等特点正逐步替代传统的人工录入方式。但OCR技术在实际应用中也存在一些问题,在各类凭证字段的识别中,手写体由于其字体差异性大、字数不固定、语义关联性较低、凭证背景干扰等原因,导致OCR识别率准确率不高,需要大量人工校正,对日常的银行录入业务造成了一定的影响。
赛题地址:http://ailab.aiwin.org.cn/competitions/65
赛题任务
本次赛题将提供手写体图像切片数据集,数据集从真实业务场景中,经过切片脱敏得到,参赛队伍通过识别技术,获得对应的识别结果。即:
- 输入:手写体图像切片数据集
- 输出:对应的识别结果
本任务提供开放可下载的训练集及测试集,允许线下建模或线上提供 Notebook 环境及 Terminal 容器环境(脱网)建模,输出识别结果完成赛题。
赛题数据
A. 数据规模和内容覆盖
B.数据示例
原始手写体图像共分为三类,分别涉及银行名称、年月日、金额三大类,分别示意如下:
相应图片切片中可能混杂有一定量的干扰信息,分别示例如下:
识别结果 JSON 在训练集中的格式如下:
代码语言:javascript复制json 文件内容规范:
{
"image1": "陆万捌千零贰拾伍元整",
"image2": "付经管院工资",
"image3": "",
...
}
实践方案
通过在网上查阅资料,得知OCR比赛最常用的模型是CRNN CTC。所以我最开始也是采用这个方案。
上图是我找到的资料,有好多个版本。因为是第一次做OCR的项目,所以我优先选择有数据集的项目,这样可以快速的了解模型的输入输出。
所以我选择的第一个Attention_ocr.pytorch-master.zip,从名字上可以看出这个是加入注意力机制,感觉效果会好一些。
构建数据集
下图是Attention_ocr.pytorch-master.zip自带的数据集截图,从截图上可以看出,数据的格式:“图片路径 空格 标签”。我们也需要按照这样的格式构建数据集。
新建makedata.py文件,插入下面的代码。
代码语言:javascript复制import os
import json
#官方给的数据集
image_path_amount = "./data/train/amount/images"
image_path_date = "./data/train/date/images"
#增强数据集
image_path_test='./data/gan_test_15000/images/0'
image_path_train='./data/gan_train_15500_0/images/0'
amount_list = os.listdir(image_path_amount)
amount_list = os.listdir(image_path_amount)
new_amount_list = []
for filename in amount_list:
new_amount_list.append(image_path_amount "/" filename)
date_list = os.listdir(image_path_date)
new_date_list = []
for filename in date_list:
new_date_list.append(image_path_date "/" filename)
new_test_list = []
for filename in amount_list:
new_test_list.append(image_path_amount "/" filename)
new_train_list = []
for filename in amount_list:
new_train_list.append(image_path_amount "/" filename)
image_path_amount和image_path_date是官方给定的数据集路径。
image_path_test和image_path_train是增强的数据集(在后面会讲如何做增强)
创建建立list,保存图片的路径。
代码语言:javascript复制amount_json = "./data/train/amount/gt.json"
date_json = "./data/train/date/gt.json"
train_json = "train_data.json"
test_json = "test_data.json"
with open(amount_json, "r", encoding='utf-8') as f:
load_dict_amount = json.load(f)
with open(date_json, "r", encoding='utf-8') as f:
load_dict_date = json.load(f)
with open(train_json, "r", encoding='utf-8') as f:
load_dict_train = json.load(f)
with open(test_json, "r", encoding='utf-8') as f:
load_dict_test = json.load(f)
四个json文件对应上面的四个list,json文件存储的是图片的名字和图片的标签,把json解析出来存到字典中。
代码语言:javascript复制#聚合list
all_list = new_amount_list new_date_list new_test_list new_train_list
from sklearn.model_selection import train_test_split
#切分训练集合和验证集
train_list, test_list = train_test_split(all_list, test_size=0.15, random_state=42)
#聚合字典
all_dic = {}
all_dic.update(load_dict_amount)
all_dic.update(load_dict_date)
all_dic.update(load_dict_train)
all_dic.update(load_dict_test)
with open('train.txt', 'w') as f:
for line in train_list:
f.write(line " " all_dic[line.split('/')[-1]] "n")
with open('val.txt', 'w') as f:
for line in test_list:
f.write(line " " all_dic[line.split('/')[-1]] "n")
将四个list聚合为一个list。
使用train_test_split切分训练集和验证集。
聚合字典。
然后分别遍历trainlist和testlist,将其写入train.txt和val.txt。
到这里数据集就制作完成了。得到train.txt和val.txt
查看train.txt
数据集和自带的数据集格式一样了,然后我们就可以开始训练了。
获取class
新建getclass.py文件夹,加入以下代码:
代码语言:javascript复制import json
amount_json = "./data/train/amount/gt.json"
date_json = "./data/train/date/gt.json"
with open(amount_json, "r", encoding='utf-8') as f:
load_dict_amount = json.load(f)
with open(date_json, "r", encoding='utf-8') as f:
load_dict_date = json.load(f)
all_dic = {}
all_dic.update(load_dict_amount)
all_dic.update(load_dict_date)
list_key=[]
for keyline in all_dic.values():
for key in keyline:
if key not in list_key:
list_key.append(key)
with open('data/char_std_5990.txt', 'w') as f:
for line in list_key:
f.write(line "n")
执行完就可以得到存储class的txt文件。打开char_std_5990.txt,看到有21个类。
模型改进
crnn的卷积部分类似VGG,我对模型的改进主要有一下几个方面:
1、加入激活函数Swish。
2、加入BatchNorm。
3、加入SE注意力机制。
4、适当加深模型。
代码如下:
代码语言:javascript复制self.cnn = nn.Sequential(
nn.Conv2d(nc, 64, 3, 1, 1), Swish(), nn.BatchNorm2d(64),
nn.MaxPool2d(2, 2), # 64x16x50
nn.Conv2d(64, 128, 3, 1, 1), Swish(), nn.BatchNorm2d(128),
nn.MaxPool2d(2, 2), # 128x8x25
nn.Conv2d(128, 256, 3, 1, 1), nn.BatchNorm2d(256), Swish(), # 256x8x25
nn.Conv2d(256, 256, 3, 1, 1), nn.BatchNorm2d(256), Swish(), # 256x8x25
SELayer(256, 16),
nn.MaxPool2d((2, 2), (2, 1), (0, 1)), # 256x4x25
nn.Conv2d(256, 512, 3, 1, 1), nn.BatchNorm2d(512), Swish(), # 512x4x25
nn.Conv2d(512, 512, 1), nn.BatchNorm2d(512), Swish(),
nn.Conv2d(512, 512, 3, 1, 1), nn.BatchNorm2d(512), Swish(), # 512x4x25
SELayer(512, 16),
nn.MaxPool2d((2, 2), (2, 1), (0, 1)), # 512x2x25
nn.Conv2d(512, 512, 2, 1, 0), nn.BatchNorm2d(512), Swish()) # 512x1x25
SE和Swish
代码语言:javascript复制class SELayer(nn.Module):
def __init__(self, channel, reduction=16):
super(SELayer, self).__init__()
self.avg_pool = nn.AdaptiveAvgPool2d(1)
self.fc = nn.Sequential(
nn.Linear(channel, channel // reduction, bias=True),
nn.LeakyReLU(inplace=True),
nn.Linear(channel // reduction, channel, bias=True),
nn.Sigmoid()
)
def forward(self, x):
b, c, _, _ = x.size()
y = self.avg_pool(x).view(b, c)
y = self.fc(y).view(b, c, 1, 1)
return x * y.expand_as(x)
class Swish(nn.Module):
def forward(self, x):
return x * torch.sigmoid(x)
模型训练
打开train.py ,在训练之前,我们还要调节一下参数。
代码语言:javascript复制parser = argparse.ArgumentParser()
parser.add_argument('--trainlist', default='train.txt')
parser.add_argument('--vallist', default='val.txt')
parser.add_argument('--workers', type=int, help='number of data loading workers', default=0)
parser.add_argument('--batchSize', type=int, default=4, help='input batch size')
parser.add_argument('--imgH', type=int, default=32, help='the height of the input image to network')
parser.add_argument('--imgW', type=int, default=512, help='the width of the input image to network')
parser.add_argument('--nh', type=int, default=512, help='size of the lstm hidden state')
parser.add_argument('--niter', type=int, default=300, help='number of epochs to train for')
parser.add_argument('--lr', type=float, default=0.00005, help='learning rate for Critic, default=0.00005')
parser.add_argument('--beta1', type=float, default=0.5, help='beta1 for adam. default=0.5')
parser.add_argument('--cuda', action='store_true', help='enables cuda', default=True)
parser.add_argument('--ngpu', type=int, default=1, help='number of GPUs to use')
parser.add_argument('--encoder', type=str, default='', help="path to encoder (to continue training)")
parser.add_argument('--decoder', type=str, default='', help='path to decoder (to continue training)')
parser.add_argument('--experiment', default='./expr/attentioncnn', help='Where to store samples and models')
parser.add_argument('--displayInterval', type=int, default=100, help='Interval to be displayed')
parser.add_argument('--valInterval', type=int, default=1, help='Interval to be displayed')
parser.add_argument('--saveInterval', type=int, default=1, help='Interval to be displayed')
parser.add_argument('--adam', default=True, action='store_true', help='Whether to use adam (default is rmsprop)')
parser.add_argument('--adadelta', action='store_true', help='Whether to use adadelta (default is rmsprop)')
parser.add_argument('--keep_ratio',default=True, action='store_true', help='whether to keep ratio for image resize')
parser.add_argument('--random_sample', default=True, action='store_true', help='whether to sample the dataset with random sampler')
parser.add_argument('--teaching_forcing_prob', type=float, default=0.5, help='where to use teach forcing')
parser.add_argument('--max_width', type=int, default=129, help='the width of the featuremap out from cnn')
parser.add_argument("--output_file", default='deep_model.log', type=str, required=False)
opt = parser.parse_args()
- trainlist:训练集,默认是train.txt。
- vallist:验证集路径,默认是val.txt。
- batchSize:批大小,根据显存大小设置。
- imgH:图片的高度,crnn模型默认为32,这里不需要修改。
- imgW:图片宽度,我在这里设置为512。
- keep_ratio:设置为True,设置为True后,程序会保持图片的比率,然后在一个batch内统一尺寸,这样训练的模型精度更高。
- lr:学习率,设置为0.00005,这里要注意,不要太大,否则不收敛。
其他的参数就不一一介绍了,大家可以自行尝试。
运行结果:
运行结果
训练完成后,可以在expr文件夹下面找到模型。
训练的模型
结果预测
在推理之前,我们还需要确认最长的字符串,新建getmax.py,添加如下代码:
代码语言:javascript复制import os
import json
image_path_amount = "./data/train/amount/images"
image_path_date = "./data/train/date/images"
amount_list = os.listdir(image_path_amount)
new_amount_list = []
for filename in amount_list:
new_amount_list.append(image_path_amount "/" filename)
date_list = os.listdir(image_path_date)
new_date_list = []
for filename in date_list:
new_date_list.append(image_path_date "/" filename)
amount_json = "./data/train/amount/gt.json"
date_json = "./data/train/date/gt.json"
with open(amount_json, "r", encoding='utf-8') as f:
load_dict_amount = json.load(f)
with open(date_json, "r", encoding='utf-8') as f:
load_dict_date = json.load(f)
all_list = new_amount_list new_date_list
from sklearn.model_selection import train_test_split
all_dic = {}
all_dic.update(load_dict_amount)
all_dic.update(load_dict_date)
maxLen = 0
for i in all_dic.values():
if (len(i) > maxLen):
maxLen = len(i)
print(maxLen)
运行结果:28
将test.py中的max_length设置为28。
修改模型的路径,包括encoder_path和decoder_path。
代码语言:javascript复制 encoder_path = './expr/attentioncnn/encoder_22.pth'
decoder_path = './expr/attentioncnn/decoder_22.pth'
修改测试集的路径:
代码语言:javascript复制 for path in tqdm(glob.glob('./data/测试集/date/images/*.jpg')):
text, prob = test(path)
if prob<0.8:
count =1
result_dict[os.path.basename(path)] = {
'result': text,
'confidence': prob
}
for path in tqdm(glob.glob('./data/测试集/amount/images/*.jpg')):
text, prob = test(path)
if prob<0.8:
count =1
result_dict[os.path.basename(path)] = {
'result': text,
'confidence': prob
}
写到最后
作者第一次参加OCR相关的赛事,在任务一中取得Top1的好成绩,背后的付出和努力通过方案分享也能看到。近期接触到很多在比赛中拿到不错成绩的小伙伴,不少是第一次尝试。所以,努力后还是可以得到自己满意的结果的。
整理不易,点赞三连↓