From 3abe84500ef62aff063e0a74c9bd907ef3f404c6 Mon Sep 17 00:00:00 2001 From: whs Date: Thu, 5 Sep 2019 15:22:02 +0800 Subject: [PATCH] Fix data type of variable in edit distance evaluator (#19618) * Fix data type of variable in edit distance evaluator. test=develop * Add unitest for edit_distance python API. test=develop --- python/paddle/fluid/evaluator.py | 2 +- python/paddle/fluid/tests/unittests/test_layers.py | 9 +++++++++ 2 files changed, 10 insertions(+), 1 deletion(-) diff --git a/python/paddle/fluid/evaluator.py b/python/paddle/fluid/evaluator.py index 5a4c838f75b..80ac91575f6 100644 --- a/python/paddle/fluid/evaluator.py +++ b/python/paddle/fluid/evaluator.py @@ -263,7 +263,7 @@ class EditDistance(Evaluator): zero = layers.fill_constant(shape=[1], value=0.0, dtype='float32') compare_result = layers.equal(distances, zero) - compare_result_int = layers.cast(x=compare_result, dtype='int') + compare_result_int = layers.cast(x=compare_result, dtype='int64') seq_right_count = layers.reduce_sum(compare_result_int) instance_error_count = layers.elementwise_sub( x=seq_num, y=seq_right_count) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 4a5e6b1a78d..cf3821b78fd 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -2345,6 +2345,15 @@ class TestBook(LayerTest): label_length=label_length) return (output) + def test_edit_distance(self): + with self.static_graph(): + predict = layers.data( + name='predict', shape=[-1, 1], dtype='int64', lod_level=1) + label = layers.data( + name='label', shape=[-1, 1], dtype='int64', lod_level=1) + evaluator = fluid.evaluator.EditDistance(predict, label) + return evaluator.metrics + if __name__ == '__main__': unittest.main() -- GitLab