未验证 提交 4628b6f8 编写于 作者: P pangyoki 提交者: GitHub

[NPU] refine update_loss_scaling npu kernel (#32580)

* refine update_loss_scaling npu kernel

* add mutable_data

* change Zerolike op to MemcpyAsync

* delete useless code

* add found_inf_vec

* add memcpy if not finite

* fix unittest
上级 8a42b1f8
@@ -15,6 +15,7 @@ limitations under the License. */
#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
@@ -145,16 +146,43 @@ class LazyZerosNPU {
const std::vector<bool> found_inf_vec,
                const std::vector<const framework::Tensor*>& xs,
                const std::vector<framework::Tensor*>& outs) const {
  // Lazily zeros (or copies) the gradient outputs of update_loss_scaling:
  //  - if no inf/nan was found, each output is a plain copy of its input;
  //  - if an inf/nan was found, every output is filled with zeros.
  //
  // Nothing to do when there are no gradient tensors.
  if (!xs.size()) {
    return;
  }
  auto place = dev_ctx.GetPlace();
  auto stream = dev_ctx.stream();
  Tensor* zero_tensor = nullptr;  // largest output, reused as the zero source
  void* zero_ptr = nullptr;       // device pointer to the zeroed buffer
  if (found_inf_vec[0]) {
    VLOG(4) << "-- UpdateLossScaling: Find infinite grads. --";
    // Run ZerosLike once on the largest output, then memcpy its buffer
    // into every other output. This replaces one ZerosLike launch per
    // tensor with a single launch plus cheap device-to-device copies.
    int max_num = -1;
    for (size_t i = 0; i < xs.size(); ++i) {
      auto* out = outs[i];
      int num = out->numel();
      if (max_num < num) {
        max_num = num;
        zero_tensor = out;
      }
    }
    zero_tensor->mutable_data<T>(place);
    auto runner_zeros =
        NpuOpRunner("ZerosLike", {*zero_tensor}, {*zero_tensor});
    runner_zeros.Run(stream);
    zero_tensor->check_memory_size();
    zero_ptr = zero_tensor->data<void>();
  }
  for (size_t i = 0; i < xs.size(); ++i) {
    auto* out = outs[i];
    auto* x = xs[i];
    auto dst_ptr = out->mutable_data<T>(place);
    if (!found_inf_vec[0]) {
      // Grads are finite: outputs are just copies of the inputs.
      framework::TensorCopy(*x, place, dev_ctx, out);
    } else if (zero_ptr != dst_ptr) {
      // Grads contain inf/nan: zero-fill by copying from the buffer that
      // ZerosLike already cleared (skip the tensor that IS that buffer).
      auto size = out->numel() * framework::SizeOfType(out->type());
      memory::Copy(BOOST_GET_CONST(platform::NPUPlace, place), dst_ptr,
                   BOOST_GET_CONST(platform::NPUPlace, place), zero_ptr, size,
                   stream);
    }
  }
}
......
@@ -71,8 +71,7 @@ class TestUpdateLossScalingOp(OpTest):
}
def test_check_output(self):
    # After the NPU kernel fix, 'Out' is verified together with the other
    # outputs (the old no_check_set=['Out'] workaround is removed).
    # Dygraph checking stays disabled for this static-graph NPU op test.
    self.check_output_with_place(self.place, check_dygraph=False)
class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
@@ -103,9 +102,6 @@ class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
'OutBadSteps': self.zero_steps
}
def test_check_output(self):
    # Verify all op outputs on the NPU place; the dygraph path is not
    # exercised for this static-graph op test.
    self.check_output_with_place(self.place, check_dygraph=False)
@unittest.skipIf(not paddle.is_compiled_with_npu(),
"core is not compiled with NPU")
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册