提交 a8f118ca 编写于 作者: W wanghaoshuang

Add EditDistance to evaluator.py

上级 680aec21
......@@ -218,21 +218,23 @@ class EditDistance(Evaluator):
raise ValueError("You can only invoke Evaluator in root block")
self.total_error = self.create_state(
dtype='int64', shape=[1], suffix='total')
self.batch_num = 0
dtype='float32', shape=[1], suffix='total')
self.batch_num = self.create_state(
dtype='float32', shape=[1], suffix='total')
error = layers.edit_distance(input=input, label=label)
mean_error = layers.mean(input=error)
error = layers.cast(x=error, dtype='float32')
mean_error = layers.mean(x=error)
layers.sums(input=[self.total_error, mean_error], out=self.total_error)
const1 = layers.fill_constant(shape=[1], value=1.0, dtype="float32")
layers.sums(input=[self.batch_num, const1], out=self.batch_num)
self.metrics.append(mean_error)
def eval(self, executor, eval_program=None):
self.batch_num += 1
if eval_program is None:
eval_program = Program()
block = eval_program.current_block()
with program_guard(main_program=eval_program):
total_error = _clone_var_(block, self.total_error)
batch_num = layers.fill_constant(
shape=[1], value=self.batch_num, dtype="float32")
batch_num = _clone_var_(block, self.batch_num)
out = layers.elementwise_div(x=total_error, y=batch_num)
return np.array(executor.run(eval_program, fetch_list=[out])[0])
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册