[Deep Learning] Sequence Generation Models (6): Evaluation Method Worked Example: Computing the ROUGE-N Score [From Theory to Code]

2024-07-30 11:29:08

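I. BLEU-N Score (Bilingual Evaluation Understudy)

See the previous post: [Deep Learning] Sequence Generation Models (5): Evaluation Method Worked Example: Computing the BLEU-N Score

II. ROUGE-N Score (Recall-Oriented Understudy for Gisting Evaluation)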

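1. Definition

  Let $\mathbf{x}$ be a candidate sequence generated from the model distribution $p_{\theta}$, let $\mathbf{s}^{(1)}, \cdots, \mathbf{s}^{(K)}$ be a set of reference sequences sampled from the true data distribution, and let $\mathcal{W}$ be the set of N-grams extracted from the reference sequences. ROUGE-N is defined as:

$$\text{ROUGE-N}(\mathbf{x}) = \frac{\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} \min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(k)}))}{\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} c_w(\mathbf{s}^{(k)})}$$

where $c_w(\mathbf{x})$ is the number of times the N-gram $w$ occurs in the generated sequence $\mathbf{x}$, and $c_w(\mathbf{s}^{(k)})$ is the number of times $w$ occurs in the reference sequence $\mathbf{s}^{(k)}$.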
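
The definition can be turned into a short function. The following is a minimal sketch, not a reference implementation: the names `ngram_counts` and `rouge_n` are illustrative, tokenization by whitespace is an assumption, and returning 0.0 when the references contain no N-grams is a choice made here rather than part of the definition.

```python
from collections import Counter


def ngram_counts(text, n):
    """Count the n-grams of a whitespace-tokenized string (assumed tokenization)."""
    tokens = text.split()
    return Counter(tuple(tokens[i:i + n]) for i in range(len(tokens) - n + 1))


def rouge_n(candidate, references, n):
    """ROUGE-N of one candidate string against a list of reference strings."""
    cand = ngram_counts(candidate, n)
    matched = 0  # numerator: clipped matches, summed over references
    total = 0    # denominator: reference n-gram counts, summed over references
    for ref in references:
        ref_counts = ngram_counts(ref, n)
        # N-grams absent from this reference have count 0 and add nothing,
        # so iterating over ref_counts is the same as summing over all of W.
        matched += sum(min(cand[w], c) for w, c in ref_counts.items())
        total += sum(ref_counts.values())
    return matched / total if total else 0.0
```

Because the denominator counts reference N-grams only, the score is recall-oriented: a candidate is rewarded for covering the references, not penalized for extra N-grams of its own.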

2. Computation

N=1
  • Generated sequence
$\mathbf{x}=\text{the cat sat on the mat}$
  • Reference sequences
    $\mathbf{s}^{(1)}=\text{the cat is on the mat}$
    $\mathbf{s}^{(2)}=\text{the bird sat on the bush}$
$\mathcal{W}=\{\text{the, cat, is, on, mat, bird, sat, bush}\}$

| $w$ | $c_w(\mathbf{x})$ | $c_w(\mathbf{s}^{(1)})$ | $c_w(\mathbf{s}^{(2)})$ | $\min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(1)}))$ | $\min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(2)}))$ |
|---|---|---|---|---|---|
| the | 2 | 2 | 2 | 2 | 2 |
| cat | 1 | 1 | 0 | 1 | 0 |
| is | 0 | 1 | 0 | 0 | 0 |
| on | 1 | 1 | 1 | 1 | 1 |
| mat | 1 | 1 | 0 | 1 | 0 |
| bird | 0 | 0 | 1 | 0 | 0 |
| sat | 1 | 0 | 1 | 0 | 1 |
| bush | 0 | 0 | 1 | 0 | 0 |

  • Numerator
$\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} \min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(k)})) = 5 + 4 = 9$
  • Denominator
$\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} c_w(\mathbf{s}^{(k)}) = 6 + 6 = 12$

$$\text{ROUGE-N}(\mathbf{x}) = \frac{\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} \min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(k)}))}{\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} c_w(\mathbf{s}^{(k)})}=\frac{5+4}{6+6}=\frac{9}{12}=0.75$$

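
The N=1 count table can be rebuilt mechanically to check the column sums. This is a small sketch under the assumption that words are separated by single spaces; the variable names (`cx`, `c1`, `c2`, `vocab`, `rows`) are illustrative only.

```python
from collections import Counter

x = "the cat sat on the mat".split()
s1 = "the cat is on the mat".split()
s2 = "the bird sat on the bush".split()

cx, c1, c2 = Counter(x), Counter(s1), Counter(s2)
vocab = list(dict.fromkeys(s1 + s2))  # W, kept in first-seen order

# One row per unigram:
# (w, c_w(x), c_w(s1), c_w(s2), min with s1, min with s2)
rows = [(w, cx[w], c1[w], c2[w], min(cx[w], c1[w]), min(cx[w], c2[w]))
        for w in vocab]
for row in rows:
    print(*row, sep="\t")

numerator = sum(r[4] + r[5] for r in rows)    # 5 + 4
denominator = sum(r[2] + r[3] for r in rows)  # 6 + 6
print(numerator, denominator, numerator / denominator)
```

The last line reports 9, 12 and 0.75, matching the hand calculation.
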
N=2
  • Generated sequence
$\mathbf{x}=\text{the cat sat on the mat}$
  • Reference sequences
    $\mathbf{s}^{(1)}=\text{the cat is on the mat}$
    $\mathbf{s}^{(2)}=\text{the bird sat on the bush}$
$\mathcal{W}=\{\text{the cat, cat is, is on, on the, the mat, the bird, bird sat, sat on, the bush}\}$

| $w$ | $c_w(\mathbf{x})$ | $c_w(\mathbf{s}^{(1)})$ | $c_w(\mathbf{s}^{(2)})$ | $\min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(1)}))$ | $\min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(2)}))$ |
|---|---|---|---|---|---|
| the cat | 1 | 1 | 0 | 1 | 0 |
| cat is | 0 | 1 | 0 | 0 | 0 |
| is on | 0 | 1 | 0 | 0 | 0 |
| on the | 1 | 1 | 1 | 1 | 1 |
| the mat | 1 | 1 | 0 | 1 | 0 |
| the bird | 0 | 0 | 1 | 0 | 0 |
| bird sat | 0 | 0 | 1 | 0 | 0 |
| sat on | 1 | 0 | 1 | 0 | 1 |
| the bush | 0 | 0 | 1 | 0 | 0 |

  • Numerator
$\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} \min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(k)})) = 3 + 2 = 5$
  • Denominator
$\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} c_w(\mathbf{s}^{(k)}) = 5 + 5 = 10$

$$\text{ROUGE-N}(\mathbf{x}) = \frac{\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} \min(c_w(\mathbf{x}), c_w(\mathbf{s}^{(k)}))}{\sum_{k=1}^{K} \sum_{w \in \mathcal{W}} c_w(\mathbf{s}^{(k)})}=\frac{3+2}{5+5}=\frac{5}{10}=0.5$$
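
For N=2 it helps to see which bigrams actually contribute to the numerator. The sketch below assumes single-space tokenization; the helper name `bigrams` and the variable names are illustrative. It uses `list.count`, which compares whole bigram strings.

```python
def bigrams(text):
    """Return the list of adjacent word pairs in a space-separated string."""
    tokens = text.split()
    return [" ".join(pair) for pair in zip(tokens, tokens[1:])]


x = bigrams("the cat sat on the mat")
s1 = bigrams("the cat is on the mat")
s2 = bigrams("the bird sat on the bush")

W = sorted(set(s1) | set(s2))  # bigram set from the references

# Bigrams with a non-zero clipped count against each reference
hits_s1 = [w for w in W if min(x.count(w), s1.count(w)) > 0]
hits_s2 = [w for w in W if min(x.count(w), s2.count(w)) > 0]
print(hits_s1)  # bigrams shared by x and s1
print(hits_s2)  # bigrams shared by x and s2

match_s1 = sum(min(x.count(w), s1.count(w)) for w in W)
match_s2 = sum(min(x.count(w), s2.count(w)) for w in W)
total = sum(s1.count(w) + s2.count(w) for w in W)
print((match_s1 + match_s2) / total)
```

The candidate shares "on the", "the cat" and "the mat" with the first reference and "on the" and "sat on" with the second, which gives the 3 + 2 in the numerator.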

3. Program

```python
main_string = 'the cat sat on the mat'
string1 = 'the cat is on the mat'
string2 = 'the bird sat on the bush'

# ROUGE-1: collect the unigrams of both references
words = list(set(string1.split(' ') + string2.split(' ')))  # remove duplicates

total_occurrences, matching_occurrences = 0, 0
for word in words:
    # Numerator: clipped count against each reference
    matching_occurrences += min(main_string.count(word), string1.count(word)) + min(main_string.count(word), string2.count(word))
    # Denominator: count in each reference
    total_occurrences += string1.count(word) + string2.count(word)

print(matching_occurrences / total_occurrences)

# ROUGE-2: collect the bigrams of both references
bigrams = []
split1 = string1.split(' ')
for i in range(len(split1) - 1):
    bigrams.append(split1[i] + ' ' + split1[i + 1])

split2 = string2.split(' ')
for i in range(len(split2) - 1):
    bigrams.append(split2[i] + ' ' + split2[i + 1])

bigrams = list(set(bigrams))  # remove duplicates

total_occurrences, matching_occurrences = 0, 0
for bigram in bigrams:
    matching_occurrences += min(main_string.count(bigram), string1.count(bigram)) + min(main_string.count(bigram), string2.count(bigram))
    total_occurrences += string1.count(bigram) + string2.count(bigram)

print(matching_occurrences / total_occurrences)
```

Output:

```
0.75
0.5
```
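
One caveat about the program: `str.count` counts substring occurrences, not whole words. It agrees with the hand calculation here only because no N-gram in this example appears inside a longer word. The short check below illustrates the difference with the fragment "at", which is not a word in the sentence; for general text, counting over the token list is likely the safer choice.

```python
sentence = "the cat sat on the mat"

# Substring count: "at" is found inside "cat", "sat" and "mat"
substring_count = sentence.count("at")

# Token count: "at" never occurs as a whole word
token_count = sentence.split().count("at")

print(substring_count, token_count)
```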
