提交 436fbb0d 编写于 作者: Q qijun

Merge remote-tracking branch 'baidu/develop' into refine_op_test

...@@ -11,9 +11,9 @@ __all__ = ['get_numeric_gradient'] ...@@ -11,9 +11,9 @@ __all__ = ['get_numeric_gradient']
def create_op(op_type): def create_op(op_type):
# TODO need to set attrs # TODO need to set attrs
kwargs = dict() kwargs = dict()
for in_name, _ in Operator.get_op_input_names(op_type): for in_name in Operator.get_op_input_names(op_type):
kwargs[in_name] = in_name kwargs[in_name] = in_name
for out_name, _ in Operator.get_op_output_names(op_type): for out_name in Operator.get_op_output_names(op_type):
kwargs[out_name] = out_name kwargs[out_name] = out_name
return Operator(op_type, **kwargs) return Operator(op_type, **kwargs)
......
...@@ -27,9 +27,7 @@ class OpTestMeta(type): ...@@ -27,9 +27,7 @@ class OpTestMeta(type):
places.append(core.GPUPlace(0)) places.append(core.GPUPlace(0))
for place in places: for place in places:
for ins in Operator.get_op_input_names(self.type): for in_name, in_dup in Operator.get_op_inputs(self.type):
in_name = ins[0]
in_dup = ins[1]
if hasattr(self, 'inputs') and in_name in self.inputs: if hasattr(self, 'inputs') and in_name in self.inputs:
kwargs[in_name] = [] kwargs[in_name] = []
if in_dup: if in_dup:
...@@ -49,8 +47,7 @@ class OpTestMeta(type): ...@@ -49,8 +47,7 @@ class OpTestMeta(type):
else: else:
kwargs[in_name] = "@EMPTY@" kwargs[in_name] = "@EMPTY@"
for out_name, out_dup in Operator.get_op_output_names( for out_name, out_dup in Operator.get_op_outputs(self.type):
self.type):
if not hasattr(self, "outputs"): if not hasattr(self, "outputs"):
raise ValueError( raise ValueError(
"The test op must set self.outputs dict.") "The test op must set self.outputs dict.")
...@@ -73,8 +70,7 @@ class OpTestMeta(type): ...@@ -73,8 +70,7 @@ class OpTestMeta(type):
ctx = core.DeviceContext.create(place) ctx = core.DeviceContext.create(place)
op.run(scope, ctx) op.run(scope, ctx)
for out_name, out_dup in Operator.get_op_output_names( for out_name, out_dup in Operator.get_op_outputs(self.type):
self.type):
actual = numpy.array(scope.find_var(out_name).get_tensor()) actual = numpy.array(scope.find_var(out_name).get_tensor())
expect = self.outputs[out_name] expect = self.outputs[out_name]
self.assertTrue( self.assertTrue(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册