未验证 提交 ae93d9c2 编写于 作者: C chentianyu03 提交者: GitHub

change '/' method from scale Op to elementwise_div Op (#33279)

* fix the bug where the result of the div operation using the scale method does not exactly equal the result of the elementwise_div method

* remove the __div__ and __rdiv__ methods, which are not defined in Python 3

* modify the note

* add test case

* add test case
上级 44054bad
...@@ -46,9 +46,7 @@ _supported_promote_complex_types_ = [ ...@@ -46,9 +46,7 @@ _supported_promote_complex_types_ = [
'__rsub__', '__rsub__',
'__mul__', '__mul__',
'__rmul__', '__rmul__',
'__div__',
'__truediv__', '__truediv__',
'__rdiv__',
'__rtruediv__', '__rtruediv__',
'__matmul__', '__matmul__',
] ]
...@@ -168,9 +166,6 @@ def monkey_patch_math_varbase(): ...@@ -168,9 +166,6 @@ def monkey_patch_math_varbase():
def _scalar_mul_(var, value): def _scalar_mul_(var, value):
return _scalar_elementwise_op_(var, value, 0.0) return _scalar_elementwise_op_(var, value, 0.0)
def _scalar_div_(var, value):
return _scalar_elementwise_op_(var, 1.0 / value, 0.0)
# for binary operator such as elementwise, compare # for binary operator such as elementwise, compare
def _binary_creator_(method_name, def _binary_creator_(method_name,
op_type, op_type,
...@@ -201,7 +196,10 @@ def monkey_patch_math_varbase(): ...@@ -201,7 +196,10 @@ def monkey_patch_math_varbase():
if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_: if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
self = astype(self, 'float32') self = astype(self, 'float32')
# here use `scale` replace `elementwise` to get better performance # here use `scale` replace `elementwise` to get better performance
# but only +, -, *, / can use this method # but only +, -, * can use this method
# NOTE(chentianyu03): / can not use `scale` method,because the result of
# `scale` method (self*(1/other_var)) do not exactly equal with the result
# of `elementwise_div` method.
if scalar_method is not None: if scalar_method is not None:
return scalar_method(self, other_var) return scalar_method(self, other_var)
else: else:
...@@ -288,12 +286,8 @@ def monkey_patch_math_varbase(): ...@@ -288,12 +286,8 @@ def monkey_patch_math_varbase():
## a*b == b*a. Do not need to reverse explicitly ## a*b == b*a. Do not need to reverse explicitly
('__rmul__', ('__rmul__',
_binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)), _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
_scalar_div_)),
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', ('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
False, _scalar_div_)), False, None)),
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
None)),
('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True, ('__rtruediv__', _binary_creator_('rtruediv__', 'elementwise_div', True,
None)), None)),
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, ('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
......
...@@ -39,9 +39,7 @@ EXPRESSION_MAP = { ...@@ -39,9 +39,7 @@ EXPRESSION_MAP = {
"__rsub__": "A -= B", "__rsub__": "A -= B",
"__mul__": "A * B", "__mul__": "A * B",
"__rmul__": "A *= B", "__rmul__": "A *= B",
"__div__": "A / B",
"__truediv__": "A / B", "__truediv__": "A / B",
"__rdiv__": "A /= B",
"__rtruediv__": "A /= B", "__rtruediv__": "A /= B",
"__pow__": "A ** B", "__pow__": "A ** B",
"__rpow__": "A **= B", "__rpow__": "A **= B",
...@@ -209,9 +207,6 @@ def monkey_patch_variable(): ...@@ -209,9 +207,6 @@ def monkey_patch_variable():
def _scalar_mul_(var, value): def _scalar_mul_(var, value):
return _scalar_op_(var, value, 0.0) return _scalar_op_(var, value, 0.0)
def _scalar_div_(var, value):
return _scalar_op_(var, 1.0 / value, 0.0)
def _binary_creator_(method_name, def _binary_creator_(method_name,
op_type, op_type,
reverse=False, reverse=False,
...@@ -241,7 +236,10 @@ def monkey_patch_variable(): ...@@ -241,7 +236,10 @@ def monkey_patch_variable():
if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_: if op_type == 'elementwise_div' and self.dtype in _supported_int_dtype_:
self = astype(self, 'float32') self = astype(self, 'float32')
# here use `scale` replace `elementwise` to get better performance # here use `scale` replace `elementwise` to get better performance
# but only +, -, *, / can use this method # but only +, -, * can use this method
# NOTE(chentianyu03): / can not use `scale` method,because the result of
# `scale` method (self*(1/other_var)) do not exactly equal with the result
# of `elementwise_div` method.
if scalar_method is not None: if scalar_method is not None:
return scalar_method(self, other_var) return scalar_method(self, other_var)
else: else:
...@@ -337,12 +335,8 @@ def monkey_patch_variable(): ...@@ -337,12 +335,8 @@ def monkey_patch_variable():
# a*b == b*a. Do not need to reverse explicitly # a*b == b*a. Do not need to reverse explicitly
('__rmul__', ('__rmul__',
_binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)), _binary_creator_('__rmul__', 'elementwise_mul', False, _scalar_mul_)),
('__div__', _binary_creator_('__div__', 'elementwise_div', False,
_scalar_div_)),
('__truediv__', _binary_creator_('__truediv__', 'elementwise_div', ('__truediv__', _binary_creator_('__truediv__', 'elementwise_div',
False, _scalar_div_)), False, None)),
('__rdiv__', _binary_creator_('__rdiv__', 'elementwise_div', True,
None)),
('__rtruediv__', _binary_creator_('__rtruediv__', 'elementwise_div', ('__rtruediv__', _binary_creator_('__rtruediv__', 'elementwise_div',
True, None)), True, None)),
('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False, ('__pow__', _binary_creator_('__pow__', 'elementwise_pow', False,
......
...@@ -187,6 +187,13 @@ class TestTensorScalarTypePromotionDynamic(unittest.TestCase): ...@@ -187,6 +187,13 @@ class TestTensorScalarTypePromotionDynamic(unittest.TestCase):
c = paddle.full([2, 2, 2], 0.5, dtype="float32") c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/') self.check_operation(a, b, c, '/')
# tensor(float32) / scalar(int)
# this behavior should be equal to elementwise_div Op
a = paddle.to_tensor([99, 99, 99], dtype='float32')
b = 100
c = a / paddle.to_tensor([100, 100, 100], dtype='float32')
self.check_operation(a, b, c, '/')
# tensor(int64) / scalar(float, .0) # tensor(int64) / scalar(float, .0)
a = paddle.ones([2, 2, 2], dtype='int64') a = paddle.ones([2, 2, 2], dtype='int64')
b = 2.0 b = 2.0
......
...@@ -218,6 +218,12 @@ class TestTensorScalarTypePromotionStatic(unittest.TestCase): ...@@ -218,6 +218,12 @@ class TestTensorScalarTypePromotionStatic(unittest.TestCase):
c = paddle.full([2, 2, 2], 0.5, dtype="float32") c = paddle.full([2, 2, 2], 0.5, dtype="float32")
self.check_operation(a, b, c, '/') self.check_operation(a, b, c, '/')
# this behavior should be equal to elementwise_div Op
a = paddle.full([2, 2, 2], 99, dtype="float32")
b = 100
c = a / paddle.full([2, 2, 2], 100, dtype="float32")
self.check_operation(a, b, c, '/')
# tensor(int64) / scalar(float, .0) # tensor(int64) / scalar(float, .0)
with program_guard(Program()): with program_guard(Program()):
a = paddle.ones([2, 2, 2], dtype='int64') a = paddle.ones([2, 2, 2], dtype='int64')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册