未验证 提交 0bbff93c 编写于 作者: P pangyoki 提交者: GitHub

fix elementwise_div npu op (#35700)

上级 9861af7a
...@@ -63,15 +63,22 @@ class ElementwiseDivGradNPUKernel : public framework::OpKernel<T> { ...@@ -63,15 +63,22 @@ class ElementwiseDivGradNPUKernel : public framework::OpKernel<T> {
ctx.template device_context<paddle::platform::NPUDeviceContext>() ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream(); .stream();
Tensor y_power(y->type());
y_power.mutable_data<T>(y->dims(), place);
const auto& runner_y_power = NpuOpRunner(
"Power", {*y}, {y_power}, {{"power", static_cast<float>(-1)}});
runner_y_power.Run(stream);
if (dx) { if (dx) {
dx->mutable_data<T>(place); dx->mutable_data<T>(place);
Tensor tensor_one(y->type());
tensor_one.mutable_data<float>({1}, place);
FillNpuTensorWithConstant<float>(&tensor_one, static_cast<float>(1.0));
// Use `Div` CANN OP to achieve `1/y` instead of `Power` CANN OP.
// Because `Power` will cause precision overflow, that is, `float_status`
// will be set to 1.
Tensor y_div(y->type());
y_div.mutable_data<T>(y->dims(), place);
const auto& runner_one_div_y =
NpuOpRunner("Div", {tensor_one, *y}, {y_div}, {});
runner_one_div_y.Run(stream);
Tensor tensor_zeros(x->type()); Tensor tensor_zeros(x->type());
tensor_zeros.mutable_data<T>(x->dims(), place); tensor_zeros.mutable_data<T>(x->dims(), place);
const auto& runner_tensor_zeros = const auto& runner_tensor_zeros =
...@@ -100,7 +107,7 @@ class ElementwiseDivGradNPUKernel : public framework::OpKernel<T> { ...@@ -100,7 +107,7 @@ class ElementwiseDivGradNPUKernel : public framework::OpKernel<T> {
Tensor x_grad_w(x->type()); Tensor x_grad_w(x->type());
x_grad_w.mutable_data<T>(x->dims(), place); x_grad_w.mutable_data<T>(x->dims(), place);
const auto& runner_x_grad_w = const auto& runner_x_grad_w =
NpuOpRunner("Mul", {x_nozero_f, y_power}, {x_grad_w}, {}); NpuOpRunner("Mul", {x_nozero_f, y_div}, {x_grad_w}, {});
runner_x_grad_w.Run(stream); runner_x_grad_w.Run(stream);
const auto& runner_x_grad = const auto& runner_x_grad =
......
...@@ -21,6 +21,7 @@ sys.path.append("..") ...@@ -21,6 +21,7 @@ sys.path.append("..")
from op_test import OpTest from op_test import OpTest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from paddle.fluid.core import ops
paddle.enable_static() paddle.enable_static()
SEED = 2021 SEED = 2021
...@@ -173,5 +174,30 @@ class TestElementwiseDivNet(unittest.TestCase): ...@@ -173,5 +174,30 @@ class TestElementwiseDivNet(unittest.TestCase):
self.assertTrue(np.allclose(npu_loss, cpu_loss)) self.assertTrue(np.allclose(npu_loss, cpu_loss))
class TestFloatStatus(unittest.TestCase):
def test_overflow(self):
paddle.disable_static()
paddle.set_device('npu')
flag = paddle.zeros([8])
ops.clear_float_status(flag, flag)
self.assertEqual(flag.numpy().sum(), 0.0)
x = paddle.to_tensor([12.564], stop_gradient=False)
y = paddle.to_tensor([2.], stop_gradient=False)
z = x / y
out = 32768. * z
ops.get_float_status(flag, flag)
self.assertEqual(flag.numpy().sum(), 0.0)
out.sum().backward()
ops.get_float_status(flag, flag)
self.assertEqual(flag.numpy().sum(), 0.0)
paddle.enable_static()
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册