From 4628b6f8d6da3303207eb8af9ffe4ff67f9ef2f4 Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Sat, 8 May 2021 14:51:05 +0800
Subject: [PATCH] [NPU] refine update_loss_scaling npu kernel (#32580)

* refine update_loss_scaling npu kernel

* add mutable_data

* change Zerolike op to MemcpyAsync

* delete useless code

* add found_inf_vec

* add memcpy if not finite

* fix unittest
---
 .../amp/update_loss_scaling_op_npu.cc          | 44 +++++++++++++++----
 .../npu/test_update_loss_scaling_op_npu.py     |  6 +--
 2 files changed, 37 insertions(+), 13 deletions(-)

diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
index 45b28bf61e5..820966addfc 100644
--- a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
+++ b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
@@ -15,6 +15,7 @@ limitations under the License. */
 #include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
 #include <cmath>
 #include <vector>
+#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
 
@@ -145,16 +146,43 @@ class LazyZerosNPU {
                   const std::vector<bool> found_inf_vec,
                   const std::vector<const framework::Tensor*>& xs,
                   const std::vector<framework::Tensor*>& outs) const {
+    if (!xs.size()) {
+      return;
+    }
+    auto place = dev_ctx.GetPlace();
+    auto stream = dev_ctx.stream();
+    Tensor* zero_tensor;
+    void* zero_ptr;
+    if (found_inf_vec[0]) {
+      int max_num = -1;
+      for (size_t i = 0; i < xs.size(); ++i) {
+        auto* out = outs[i];
+        int num = out->numel();
+        if (max_num < num) {
+          max_num = num;
+          zero_tensor = out;
+        }
+      }
+
+      zero_tensor->mutable_data<T>(place);
+      auto runner_zeros =
+          NpuOpRunner("ZerosLike", {*zero_tensor}, {*zero_tensor});
+      runner_zeros.Run(stream);
+      zero_tensor->check_memory_size();
+      zero_ptr = zero_tensor->data<void>();
+    }
+
     for (size_t i = 0; i < xs.size(); ++i) {
       auto* out = outs[i];
-      if (found_inf_vec[0]) {
-        VLOG(4) << "-- UpdateLossScaling: Find infinite grads. --";
-
-        auto place = dev_ctx.GetPlace();
-        auto stream = dev_ctx.stream();
-        auto g = out->mutable_data<T>(place);
-        platform::NPUMemsetAsync(static_cast<void*>(g), 0,
-                                 out->numel() * sizeof(T), stream);
+      auto* x = xs[i];
+      auto dst_ptr = out->mutable_data<T>(place);
+      if (!found_inf_vec[0]) {
+        framework::TensorCopy(*x, place, dev_ctx, out);
+      } else if (zero_ptr != dst_ptr) {
+        auto size = out->numel() * framework::SizeOfType(out->type());
+        memory::Copy(BOOST_GET_CONST(platform::NPUPlace, place), dst_ptr,
+                     BOOST_GET_CONST(platform::NPUPlace, place), zero_ptr, size,
+                     stream);
       }
     }
   }
diff --git a/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_op_npu.py
index 1060e67078f..cae3239229f 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_op_npu.py
@@ -71,8 +71,7 @@ class TestUpdateLossScalingOp(OpTest):
         }
 
     def test_check_output(self):
-        self.check_output_with_place(
-            self.place, check_dygraph=False, no_check_set=['Out'])
+        self.check_output_with_place(self.place, check_dygraph=False)
 
 
 class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
@@ -103,9 +102,6 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
             'OutBadSteps': self.zero_steps
         }
 
-    def test_check_output(self):
-        self.check_output_with_place(self.place, check_dygraph=False)
-
 
 @unittest.skipIf(not paddle.is_compiled_with_npu(),
                  "core is not compiled with NPU")
-- 
GitLab
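
Reviewer note (not part of the patch): the kernel change above swaps the old per-tensor NPUMemsetAsync for a single on-device ZerosLike on the largest output tensor, whose zeroed buffer is then replicated into every other output by device-to-device memory::Copy; when no inf/nan was found, the inputs are simply passed through with TensorCopy. The short host-side C++ sketch below models that control flow, with std::memset/std::memcpy standing in for the NPU runner and memory::Copy. It is a minimal illustration only; the function name lazy_zeros_sketch and the use of std::vector<float> as a stand-in for Tensor are assumptions, not Paddle APIs.

#include <cstring>
#include <vector>

// Host-side model of the LazyZerosNPU strategy in the patch:
// zero the largest buffer once, then copy that zero block into
// every other output instead of issuing one memset per tensor.
// NOTE: lazy_zeros_sketch and the float-vector "tensors" are
// hypothetical stand-ins, not Paddle types.
void lazy_zeros_sketch(bool found_inf,
                       const std::vector<std::vector<float>*>& xs,
                       const std::vector<std::vector<float>*>& outs) {
  if (xs.empty()) return;

  float* zero_ptr = nullptr;
  if (found_inf) {
    // Pick the largest output as the shared zero source
    // (one "ZerosLike" instead of N memsets).
    size_t max_idx = 0;
    for (size_t i = 1; i < outs.size(); ++i) {
      if (outs[i]->size() > outs[max_idx]->size()) max_idx = i;
    }
    std::memset(outs[max_idx]->data(), 0,
                outs[max_idx]->size() * sizeof(float));
    zero_ptr = outs[max_idx]->data();
  }

  for (size_t i = 0; i < xs.size(); ++i) {
    float* dst = outs[i]->data();
    if (!found_inf) {
      // Finite grads: pass the input through (TensorCopy in the patch).
      std::memcpy(dst, xs[i]->data(), xs[i]->size() * sizeof(float));
    } else if (dst != zero_ptr) {
      // Non-finite: replicate the shared zero block (memory::Copy in
      // the patch). The largest output covers every other size.
      std::memcpy(dst, zero_ptr, outs[i]->size() * sizeof(float));
    }
  }
}

The dst != zero_ptr guard mirrors the patch's zero_ptr != dst_ptr check: the largest output doubles as the zero source, so copying it onto itself is skipped.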