未验证 提交 c96f06f2 编写于 作者: L Leo Chen 提交者: GitHub

add unary operator __neg__, test=develop (#21787)

adds unary operator __neg__ for VarBase in dygraph mode, and for Variable in static graph mode.
上级 40f09058
...@@ -35,11 +35,7 @@ def monkey_patch_math_varbase(): ...@@ -35,11 +35,7 @@ def monkey_patch_math_varbase():
""" """
def safe_get_dtype(var): def safe_get_dtype(var):
try: return var.dtype
dtype = var.dtype
except:
raise ValueError("Cannot get data type from %s", var.name)
return dtype
@no_grad @no_grad
def create_tensor(value, dtype, shape): def create_tensor(value, dtype, shape):
...@@ -117,6 +113,9 @@ def monkey_patch_math_varbase(): ...@@ -117,6 +113,9 @@ def monkey_patch_math_varbase():
outs = core.ops.scale(inputs, attrs) outs = core.ops.scale(inputs, attrs)
return outs['Out'][0] return outs['Out'][0]
def _neg_(var):
return _scalar_elementwise_op_(var, -1.0, 0.0)
def _scalar_elementwise_add_(var, value): def _scalar_elementwise_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value) return _scalar_elementwise_op_(var, 1.0, value)
...@@ -217,6 +216,7 @@ def monkey_patch_math_varbase(): ...@@ -217,6 +216,7 @@ def monkey_patch_math_varbase():
setattr(core.VarBase, method_name, setattr(core.VarBase, method_name,
_elemwise_method_creator_(method_name, op_type, reverse, _elemwise_method_creator_(method_name, op_type, reverse,
scalar_method)) scalar_method)),
# b = -a
core.VarBase.__neg__ = _neg_
core.VarBase.astype = astype core.VarBase.astype = astype
...@@ -157,6 +157,9 @@ def monkey_patch_variable(): ...@@ -157,6 +157,9 @@ def monkey_patch_variable():
"bias": bias}) "bias": bias})
return out return out
def _neg_(var):
return _scalar_elementwise_op_(var, -1.0, 0.0)
def _scalar_elementwise_add_(var, value): def _scalar_elementwise_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value) return _scalar_elementwise_op_(var, 1.0, value)
...@@ -273,5 +276,6 @@ def monkey_patch_variable(): ...@@ -273,5 +276,6 @@ def monkey_patch_variable():
setattr(Variable, method_name, setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse, _elemwise_method_creator_(method_name, op_type, reverse,
scalar_method)) scalar_method))
# b = -a
Variable.__neg__ = _neg_
Variable.astype = astype Variable.astype = astype
...@@ -200,6 +200,19 @@ class TestMathOpPatches(unittest.TestCase): ...@@ -200,6 +200,19 @@ class TestMathOpPatches(unittest.TestCase):
b_np_actual = (a_np / 7).astype('int64') b_np_actual = (a_np / 7).astype('int64')
self.assertTrue(numpy.array_equal(b_np, b_np_actual)) self.assertTrue(numpy.array_equal(b_np, b_np_actual))
@prog_scope()
def test_neg(self):
a = fluid.layers.data(name="a", shape=[10, 1])
b = -a
place = fluid.CPUPlace()
exe = fluid.Executor(place)
a_np = numpy.random.uniform(-1, 1, size=[10, 1]).astype('float32')
b_np = exe.run(fluid.default_main_program(),
feed={"a": a_np},
fetch_list=[b])
self.assertTrue(numpy.allclose(-a_np, b_np))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -201,6 +201,13 @@ class TestMathOpPatchesVarBase(unittest.TestCase): ...@@ -201,6 +201,13 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
res = (a >= b) res = (a >= b)
self.assertTrue(np.array_equal(res.numpy(), a_np >= b_np)) self.assertTrue(np.array_equal(res.numpy(), a_np >= b_np))
def test_neg(self):
a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
with fluid.dygraph.guard():
a = fluid.dygraph.to_variable(a_np)
res = -a
self.assertTrue(np.array_equal(res.numpy(), -a_np))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册