keras支持模型多输入多输出,本文记录多输出时loss、loss weight和metrics的设置方式。
模型输出
假设模型具有多个输出
- classify: 二维数组,分类softmax输出,需要配置交叉熵损失
- segmentation:与输入同尺寸map,sigmoid输出,需要配置二分类损失
- others:自定义其他输出,需要自定义损失
具体配置
model
- 变量均为模型中网络层
inputs = [input_1 , input_2]
outputs = [classify, segmentation, others]
model = keras.models.Model(inputs, outputs)
loss
代码语言:javascript复制my_loss = {
'classify': 'categorical_crossentropy',
'segmentation':'binary_crossentropy',
'others':my_loss_fun}
loss weight
代码语言:javascript复制my_loss_weights = {
'classify':1,
'segmentation':1,
'others':10}
metrics
代码语言:javascript复制my_metrics ={
'classify':'acc',
'segmentation':[mean_iou,'acc'],
'others':['mse','acc']
}
编译
代码语言:javascript复制model.compile(optimizer=Adam(lr=config.LEARNING_RATE), loss=my_loss, loss_weights= my_loss_weights, metrics= my_metrics)