未验证 提交 3abe8450 编写于 作者: W whs 提交者: GitHub

Fix data type of variable in edit distance evaluator (#19618)

* Fix data type of variable in edit distance evaluator.
test=develop

* Add unitest for edit_distance python API.
test=develop
上级 42b5bec6
......@@ -263,7 +263,7 @@ class EditDistance(Evaluator):
zero = layers.fill_constant(shape=[1], value=0.0, dtype='float32')
compare_result = layers.equal(distances, zero)
compare_result_int = layers.cast(x=compare_result, dtype='int')
compare_result_int = layers.cast(x=compare_result, dtype='int64')
seq_right_count = layers.reduce_sum(compare_result_int)
instance_error_count = layers.elementwise_sub(
x=seq_num, y=seq_right_count)
......
......@@ -2345,6 +2345,15 @@ class TestBook(LayerTest):
label_length=label_length)
return (output)
def test_edit_distance(self):
with self.static_graph():
predict = layers.data(
name='predict', shape=[-1, 1], dtype='int64', lod_level=1)
label = layers.data(
name='label', shape=[-1, 1], dtype='int64', lod_level=1)
evaluator = fluid.evaluator.EditDistance(predict, label)
return evaluator.metrics
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册