提交 436fbb0d 编写于 作者: Q qijun

Merge remote-tracking branch 'baidu/develop' into refine_op_test

......@@ -11,9 +11,9 @@ __all__ = ['get_numeric_gradient']
def create_op(op_type):
# TODO need to set attrs
kwargs = dict()
for in_name, _ in Operator.get_op_input_names(op_type):
for in_name in Operator.get_op_input_names(op_type):
kwargs[in_name] = in_name
for out_name, _ in Operator.get_op_output_names(op_type):
for out_name in Operator.get_op_output_names(op_type):
kwargs[out_name] = out_name
return Operator(op_type, **kwargs)
......
......@@ -27,9 +27,7 @@ class OpTestMeta(type):
places.append(core.GPUPlace(0))
for place in places:
for ins in Operator.get_op_input_names(self.type):
in_name = ins[0]
in_dup = ins[1]
for in_name, in_dup in Operator.get_op_inputs(self.type):
if hasattr(self, 'inputs') and in_name in self.inputs:
kwargs[in_name] = []
if in_dup:
......@@ -49,8 +47,7 @@ class OpTestMeta(type):
else:
kwargs[in_name] = "@EMPTY@"
for out_name, out_dup in Operator.get_op_output_names(
self.type):
for out_name, out_dup in Operator.get_op_outputs(self.type):
if not hasattr(self, "outputs"):
raise ValueError(
"The test op must set self.outputs dict.")
......@@ -73,8 +70,7 @@ class OpTestMeta(type):
ctx = core.DeviceContext.create(place)
op.run(scope, ctx)
for out_name, out_dup in Operator.get_op_output_names(
self.type):
for out_name, out_dup in Operator.get_op_outputs(self.type):
actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name]
self.assertTrue(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册