未验证 提交 c96f06f2 编写于 作者: L Leo Chen 提交者: GitHub

add unary operator __neg__, test=develop (#21787)

adds unary operator __neg__ for VarBase in dygraph mode, and for Variable in static graph mode.
上级 40f09058
......@@ -35,11 +35,7 @@ def monkey_patch_math_varbase():
"""
def safe_get_dtype(var):
try:
dtype = var.dtype
except:
raise ValueError("Cannot get data type from %s", var.name)
return dtype
return var.dtype
@no_grad
def create_tensor(value, dtype, shape):
......@@ -117,6 +113,9 @@ def monkey_patch_math_varbase():
outs = core.ops.scale(inputs, attrs)
return outs['Out'][0]
def _neg_(var):
return _scalar_elementwise_op_(var, -1.0, 0.0)
def _scalar_elementwise_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value)
......@@ -217,6 +216,7 @@ def monkey_patch_math_varbase():
setattr(core.VarBase, method_name,
_elemwise_method_creator_(method_name, op_type, reverse,
scalar_method))
scalar_method)),
# b = -a
core.VarBase.__neg__ = _neg_
core.VarBase.astype = astype
......@@ -157,6 +157,9 @@ def monkey_patch_variable():
"bias": bias})
return out
def _neg_(var):
return _scalar_elementwise_op_(var, -1.0, 0.0)
def _scalar_elementwise_add_(var, value):
return _scalar_elementwise_op_(var, 1.0, value)
......@@ -273,5 +276,6 @@ def monkey_patch_variable():
setattr(Variable, method_name,
_elemwise_method_creator_(method_name, op_type, reverse,
scalar_method))
# b = -a
Variable.__neg__ = _neg_
Variable.astype = astype
......@@ -200,6 +200,19 @@ class TestMathOpPatches(unittest.TestCase):
b_np_actual = (a_np / 7).astype('int64')
self.assertTrue(numpy.array_equal(b_np, b_np_actual))
@prog_scope()
def test_neg(self):
a = fluid.layers.data(name="a", shape=[10, 1])
b = -a
place = fluid.CPUPlace()
exe = fluid.Executor(place)
a_np = numpy.random.uniform(-1, 1, size=[10, 1]).astype('float32')
b_np = exe.run(fluid.default_main_program(),
feed={"a": a_np},
fetch_list=[b])
self.assertTrue(numpy.allclose(-a_np, b_np))
if __name__ == '__main__':
unittest.main()
......@@ -201,6 +201,13 @@ class TestMathOpPatchesVarBase(unittest.TestCase):
res = (a >= b)
self.assertTrue(np.array_equal(res.numpy(), a_np >= b_np))
def test_neg(self):
a_np = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
with fluid.dygraph.guard():
a = fluid.dygraph.to_variable(a_np)
res = -a
self.assertTrue(np.array_equal(res.numpy(), -a_np))
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册