diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..1385a3182fd34bc6ea5a4f6521eaef95d854eb27
--- /dev/null
+++ b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
@@ -0,0 +1,219 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/fluid/operators/amp/update_loss_scaling_op.h"
+#include <cmath>
+#include <vector>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+void Update(const platform::NPUDeviceContext& ctx,
+            const std::vector<bool> found_inf_vec,
+            const Tensor* pre_loss_scaling_tensor, const Tensor* good_in_tensor,
+            const Tensor* bad_in_tensor, const int incr_every_n_steps,
+            const int decr_every_n_nan_or_inf, const float incr_ratio,
+            const float decr_ratio, Tensor* updated_loss_scaling_tensor,
+            Tensor* good_out_tensor, Tensor* bad_out_tensor) {
+  auto place = ctx.GetPlace();
+  auto stream = ctx.stream();
+  if (found_inf_vec[0]) {
+    // good_out_data = 0
+    auto g = good_out_tensor->mutable_data<int>(place);
+    platform::NPUMemsetAsync(static_cast<void*>(g), 0,
+                             good_out_tensor->numel() * sizeof(int), stream);
+    // bad_out_data = bad_in_data + 1
+    Tensor factor_tensor(bad_out_tensor->type());
+    factor_tensor.mutable_data<int>({1}, place);
+    TensorFromVector(std::vector<int>{1}, ctx, &factor_tensor);
+    auto runner_p2 = NpuOpRunner("Add", {*bad_in_tensor, factor_tensor},
+                                 {*bad_out_tensor}, {});
+    runner_p2.Run(stream);
+
+    std::vector<int> bad_out_data;
+    TensorToVector(*bad_out_tensor, ctx, &bad_out_data);
+    if (bad_out_data[0] == decr_every_n_nan_or_inf) {
+      // updated_loss_scaling_data = pre_loss_scaling_data * decr_ratio
+      auto runner_p3 = NpuOpRunner("Power", {*pre_loss_scaling_tensor},
+                                   {*updated_loss_scaling_tensor},
+                                   {{"power", static_cast<float>(1)},
+                                    {"scale", decr_ratio},
+                                    {"shift", static_cast<float>(0)}});
+
+      runner_p3.Run(stream);
+
+      std::vector<T> new_loss_scaling;
+      TensorToVector(*updated_loss_scaling_tensor, ctx, &new_loss_scaling);
+      if (new_loss_scaling[0] < static_cast<T>(1)) {
+        // updated_loss_scaling_data = 1
+        auto runner_p4 = NpuOpRunner("Power", {*pre_loss_scaling_tensor},
+                                     {*updated_loss_scaling_tensor},
+                                     {{"power", static_cast<float>(1)},
+                                      {"scale", static_cast<float>(0)},
+                                      {"shift", static_cast<float>(1)}});
+
+        runner_p4.Run(stream);
+      }
+
+      // bad_out_data = 0
+      auto b = bad_out_tensor->mutable_data<int>(place);
+      platform::NPUMemsetAsync(static_cast<void*>(b), 0,
+                               bad_out_tensor->numel() * sizeof(int), stream);
+    }
+  } else {
+    // bad_out_data = 0
+    auto b = bad_out_tensor->mutable_data<int>(place);
+    platform::NPUMemsetAsync(static_cast<void*>(b), 0,
+                             bad_out_tensor->numel() * sizeof(int), stream);
+
+    // good_out_data = good_in_data + 1
+    Tensor factor_tensor(good_out_tensor->type());
+    factor_tensor.mutable_data<int>({1}, place);
+    TensorFromVector(std::vector<int>{1}, ctx, &factor_tensor);
+    auto runner_p2 = NpuOpRunner("Add", {*good_in_tensor, factor_tensor},
+                                 {*good_out_tensor}, {});
+    runner_p2.Run(stream);
+
+    std::vector<int> good_out_data;
+    TensorToVector(*good_out_tensor, ctx, &good_out_data);
+
+    if (good_out_data[0] == incr_every_n_steps) {
+      // updated_loss_scaling_data = pre_loss_scaling_data * incr_ratio
+      auto runner_p3 = NpuOpRunner("Power", {*pre_loss_scaling_tensor},
+                                   {*updated_loss_scaling_tensor},
+                                   {{"power", static_cast<float>(1)},
+                                    {"scale", incr_ratio},
+                                    {"shift", static_cast<float>(0)}});
+      runner_p3.Run(stream);
+
+      std::vector<T> new_loss_scaling;
+      TensorToVector(*updated_loss_scaling_tensor, ctx, &new_loss_scaling);
+      if (!std::isfinite(new_loss_scaling[0])) {
+        // updated_loss_scaling_data = pre_loss_scaling_data
+        auto runner_p4 = NpuOpRunner("Power", {*pre_loss_scaling_tensor},
+                                     {*updated_loss_scaling_tensor},
+                                     {{"power", static_cast<float>(1)},
+                                      {"scale", static_cast<float>(1)},
+                                      {"shift", static_cast<float>(0)}});
+
+        runner_p4.Run(stream);
+      }
+      // good_out_data = 0
+      auto g = good_out_tensor->mutable_data<int>(place);
+      platform::NPUMemsetAsync(static_cast<void*>(g), 0,
+                               good_out_tensor->numel() * sizeof(int), stream);
+    }
+  }
+}
+
+template <typename T>
+class UpdateLossScalingFunctor<platform::NPUDeviceContext, T> {
+ public:
+  void operator()(const platform::NPUDeviceContext& dev_ctx,
+                  const std::vector<bool> found_inf_vec,
+                  const Tensor* pre_loss_scaling_tensor,
+                  const Tensor* good_in_tensor, const Tensor* bad_in_tensor,
+                  const int incr_every_n_steps,
+                  const int decr_every_n_nan_or_inf, const float incr_ratio,
+                  const float decr_ratio, Tensor* updated_loss_scaling_tensor,
+                  Tensor* good_out_tensor, Tensor* bad_out_tensor) const {
+    Update<T>(dev_ctx, found_inf_vec, pre_loss_scaling_tensor, good_in_tensor,
+              bad_in_tensor, incr_every_n_steps, decr_every_n_nan_or_inf,
+              incr_ratio, decr_ratio, updated_loss_scaling_tensor,
+              good_out_tensor, bad_out_tensor);
+  }
+};
+
+template <typename T>
+class LazyZerosNPU {
+ public:
+  void operator()(const platform::NPUDeviceContext& dev_ctx,
+                  const std::vector<bool> found_inf_vec,
+                  const std::vector<const framework::Tensor*>& xs,
+                  const std::vector<framework::Tensor*>& outs) const {
+    for (size_t i = 0; i < xs.size(); ++i) {
+      auto* out = outs[i];
+      if (found_inf_vec[0]) {
+        VLOG(1) << "-- UpdateLossScaling: Found infinite grads. --";
+
+        auto place = dev_ctx.GetPlace();
+        auto stream = dev_ctx.stream();
+        auto g = out->mutable_data<T>(place);
+        platform::NPUMemsetAsync(static_cast<void*>(g), 0,
+                                 out->numel() * sizeof(T), stream);
+      }
+    }
+  }
+};
+
+template <typename DeviceContext, typename T>
+class UpdateLossScalingNPUKernel : public framework::OpKernel<T> {
+  using MPDType = typename details::MPTypeTrait<T>::Type;
+
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+
+    const auto xs = ctx.MultiInput<framework::Tensor>("X");
+    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
+    const auto* found_inf = ctx.Input<Tensor>("FoundInfinite");
+    PADDLE_ENFORCE_EQ(found_inf->numel(), 1,
+                      platform::errors::InvalidArgument(
+                          "FoundInfinite must have only one element."));
+
+    std::vector<bool> found_inf_vec;
+    TensorToVector(*found_inf, ctx.device_context(), &found_inf_vec);
+
+    LazyZerosNPU<T>{}(dev_ctx, found_inf_vec, xs, outs);
+    const bool stop_update = ctx.Attr<bool>("stop_update");
+    if (stop_update) {
+      return;
+    }
+
+    const auto* pre_loss_scaling = ctx.Input<Tensor>("PrevLossScaling");
+    const auto* good_in = ctx.Input<Tensor>("InGoodSteps");
+    const auto* bad_in = ctx.Input<Tensor>("InBadSteps");
+    auto* updated_loss_scaling = ctx.Output<Tensor>("LossScaling");
+    auto* good_out = ctx.Output<Tensor>("OutGoodSteps");
+    auto* bad_out = ctx.Output<Tensor>("OutBadSteps");
+
+    updated_loss_scaling->mutable_data<MPDType>(dev_ctx.GetPlace());
+    good_out->mutable_data<int>(dev_ctx.GetPlace());
+    bad_out->mutable_data<int>(dev_ctx.GetPlace());
+
+    const int incr_every_n_steps = ctx.Attr<int>("incr_every_n_steps");
+    const int decr_every_n_nan_or_inf =
+        ctx.Attr<int>("decr_every_n_nan_or_inf");
+    const float incr_ratio = ctx.Attr<float>("incr_ratio");
+    const float decr_ratio = ctx.Attr<float>("decr_ratio");
+    UpdateLossScalingFunctor<platform::NPUDeviceContext, MPDType>{}(
+        dev_ctx, found_inf_vec, pre_loss_scaling, good_in, bad_in,
+        incr_every_n_steps, decr_every_n_nan_or_inf, incr_ratio, decr_ratio,
+        updated_loss_scaling, good_out, bad_out);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    update_loss_scaling,
+    ops::UpdateLossScalingNPUKernel<paddle::platform::NPUDeviceContext, float>,
+    ops::UpdateLossScalingNPUKernel<paddle::platform::NPUDeviceContext,
+                                    double>);
diff --git a/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..1060e67078f8d827618c782c8b413e861bf4d68a
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_update_loss_scaling_op_npu.py
@@ -0,0 +1,268 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.contrib.mixed_precision.amp_nn as amp_nn
+
+paddle.enable_static()
+SEED = 2021
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestUpdateLossScalingOp(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "update_loss_scaling"
+        self.place = paddle.NPUPlace(0)
+
+        self.init()
+        found_inf = np.array([False], dtype=np.bool)
+        x = np.random.random((1024, 1024)).astype(self.dtype)
+
+        self.inputs = {
+            'X': [('x0', x)],
+            'FoundInfinite': found_inf,
+            'PrevLossScaling': self.prev_loss_scaling,
+            'InGoodSteps': self.num_good_steps,
+            'InBadSteps': self.num_bad_steps
+        }
+
+        self.outputs = {
+            'Out': [('out0', x)],
+            'LossScaling': self.prev_loss_scaling * self.incr_ratio,
+            'OutGoodSteps': self.zero_steps,
+            'OutBadSteps': self.zero_steps
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init(self):
+        self.incr_ratio = 2.0
+        self.decr_ratio = 0.8
+        self.dtype = np.float32
+        self.prev_loss_scaling = np.array([2048]).astype(self.dtype)
+        self.num_good_steps = np.array([999], dtype=np.int32)
+        self.num_bad_steps = np.array([1], dtype=np.int32)
+        self.zero_steps = np.array([0], dtype=np.int32)
+        self.attrs = {
+            'incr_every_n_steps': 1000,
+            'decr_every_n_nan_or_inf': 2,
+            'incr_ratio': self.incr_ratio,
+            'decr_ratio': self.decr_ratio,
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(
+            self.place, check_dygraph=False, no_check_set=['Out'])
+
+
+class TestUpdateLossScalingOpBad(TestUpdateLossScalingOp):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "update_loss_scaling"
+        self.place = paddle.NPUPlace(0)
+
+        self.init()
+        found_inf = np.array([True], dtype=np.bool)
+        x = np.random.random((1024, 1024)).astype(self.dtype)
+        i = np.random.randint(0, 1024, 1)
+        j = np.random.randint(0, 1024, 1)
+        x[i[0]][j[0]] = np.inf
+
+        self.inputs = {
+            'X': [('x0', x)],
+            'FoundInfinite': found_inf,
+            'PrevLossScaling': self.prev_loss_scaling,
+            'InGoodSteps': self.num_good_steps,
+            'InBadSteps': self.num_bad_steps
+        }
+
+        self.outputs = {
+            'Out': [('out0', np.zeros_like(x))],
+            'LossScaling': self.prev_loss_scaling * self.decr_ratio,
+            'OutGoodSteps': self.zero_steps,
+            'OutBadSteps': self.zero_steps
+        }
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestUpdateLossScalingLayer(unittest.TestCase):
+    def loss_scaling_check(self, use_npu=True, scope=fluid.Scope()):
+        a = fluid.data(name="a", shape=[1024, 1024], dtype='float32')
+        b = fluid.data(name="b", shape=[512, 128], dtype='float32')
+        x = [a, b]
+        found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool')
+        prev_loss_scaling = fluid.data(
+            name="prev_loss_scaling", shape=[1], dtype='float32')
+        num_good_steps = fluid.data(
+            name="num_good_steps", shape=[1], dtype='int32')
+        num_bad_steps = fluid.data(
+            name="num_bad_steps", shape=[1], dtype='int32')
+
+        a_v = np.random.random([1024, 1024]).astype('float32')
+        b_v = np.random.random([512, 128]).astype('float32')
+        found_inf_v = np.array([False]).astype('bool')
+        prev_loss_scaling_v = np.array([2048]).astype('float32')
+        num_good_steps_v = np.array([999], dtype=np.int32)
+        num_bad_steps_v = np.array([1], dtype=np.int32)
+
+        incr_every_n_steps = 1000
+        decr_every_n_nan_or_inf = 2
+        incr_ratio = 2
+        decr_ratio = 0.8
+
+        result = amp_nn.update_loss_scaling(
+            x,
+            found_inf,
+            prev_loss_scaling,
+            num_good_steps,
+            num_bad_steps,
+            incr_every_n_steps,
+            decr_every_n_nan_or_inf,
+            incr_ratio,
+            decr_ratio,
+            name="update_loss_scaling")
+
+        place = paddle.NPUPlace(0) if use_npu else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        with fluid.scope_guard(scope):
+            exe.run(fluid.default_startup_program())
+            result_v = exe.run(feed={
+                'a': a_v,
+                'b': b_v,
+                'found_inf': found_inf_v,
+                'prev_loss_scaling': prev_loss_scaling_v,
+                'num_good_steps': num_good_steps_v,
+                'num_bad_steps': num_bad_steps_v
+            },
+                               fetch_list=[
+                                   result, x, found_inf, prev_loss_scaling,
+                                   num_good_steps, num_bad_steps
+                               ])
+        assert np.array_equal(result_v[0], a_v)
+        assert np.array_equal(result_v[1], b_v)
+        assert np.array_equal(result_v[0], result_v[2])
+        assert np.array_equal(result_v[1], result_v[3])
+        assert np.array_equal(result_v[4], found_inf_v)
+        assert np.array_equal(result_v[5], prev_loss_scaling_v * incr_ratio)
+        assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v))
+        assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v))
+
+    def loss_scaling_check_inf(self, use_npu=True, scope=fluid.Scope()):
+        a = fluid.data(name="a", shape=[1024, 1024], dtype='float32')
+        b = fluid.data(name="b", shape=[512, 128], dtype='float32')
+        x = [a, b]
+        found_inf = fluid.data(name="found_inf", shape=[1], dtype='bool')
+        prev_loss_scaling = fluid.data(
+            name="prev_loss_scaling", shape=[1], dtype='float32')
+        num_good_steps = fluid.data(
+            name="num_good_steps", shape=[1], dtype='int32')
+        num_bad_steps = fluid.data(
+            name="num_bad_steps", shape=[1], dtype='int32')
+
+        a_v = np.random.random([1024, 1024]).astype('float32')
+        b_v = np.random.random([512, 128]).astype('float32')
+        i = np.random.randint(0, 1024, 1)
+        j = np.random.randint(0, 1024, 1)
+        a_v[i[0]][j[0]] = np.inf
+        found_inf_v = np.array([True]).astype('bool')
+        prev_loss_scaling_v = np.array([2048]).astype('float32')
+        num_good_steps_v = np.array([999], dtype=np.int32)
+        num_bad_steps_v = np.array([1], dtype=np.int32)
+
+        incr_every_n_steps = 1000
+        decr_every_n_nan_or_inf = 2
+        incr_ratio = 2
+        decr_ratio = 0.8
+
+        result = amp_nn.update_loss_scaling(
+            x,
+            found_inf,
+            prev_loss_scaling,
+            num_good_steps,
+            num_bad_steps,
+            incr_every_n_steps,
+            decr_every_n_nan_or_inf,
+            incr_ratio,
+            decr_ratio,
+            name="update_loss_scaling")
+
+        place = paddle.NPUPlace(0) if use_npu else fluid.CPUPlace()
+        exe = fluid.Executor(place)
+        with fluid.scope_guard(scope):
+            exe.run(fluid.default_startup_program())
+            result_v = exe.run(feed={
+                'a': a_v,
+                'b': b_v,
+                'found_inf': found_inf_v,
+                'prev_loss_scaling': prev_loss_scaling_v,
+                'num_good_steps': num_good_steps_v,
+                'num_bad_steps': num_bad_steps_v
+            },
+                               fetch_list=[
+                                   result, x, found_inf, prev_loss_scaling,
+                                   num_good_steps, num_bad_steps
+                               ])
+        assert np.array_equal(result_v[0], np.zeros_like(a_v))
+        assert np.array_equal(result_v[1], np.zeros_like(b_v))
+        assert np.array_equal(result_v[2], np.zeros_like(a_v))
+        assert np.array_equal(result_v[3], np.zeros_like(b_v))
+        assert np.array_equal(result_v[4], found_inf_v)
+        assert np.array_equal(result_v[5], prev_loss_scaling_v * decr_ratio)
+        assert np.array_equal(result_v[6], np.zeros_like(num_good_steps_v))
+        assert np.array_equal(result_v[7], np.zeros_like(num_bad_steps_v))
+
+    def test_loss_scaling_cpu(self):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, startup):
+                self.loss_scaling_check(use_npu=False)
+
+    def test_loss_scaling_cpu_inf(self):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, startup):
+                self.loss_scaling_check_inf(use_npu=False)
+
+    def test_loss_scaling_npu(self):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, startup):
+                self.loss_scaling_check(use_npu=True)
+
+    def test_loss_scaling_npu_inf(self):
+        main = fluid.Program()
+        startup = fluid.Program()
+        with fluid.unique_name.guard():
+            with fluid.program_guard(main, startup):
+                self.loss_scaling_check_inf(use_npu=True)
+
+
+if __name__ == '__main__':
+    unittest.main()
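
Reference note: the kernel above implements the usual dynamic loss-scaling rule. On an inf/nan step it resets the good-step counter, increments the bad-step counter, and after decr_every_n_nan_or_inf consecutive bad steps shrinks the scale by decr_ratio (clamped to at least 1). On a clean step it resets the bad-step counter, increments the good-step counter, and after incr_every_n_steps consecutive good steps grows the scale by incr_ratio unless that would overflow. The host-side NumPy sketch below is illustrative only; the helper name update_loss_scaling_reference is not part of this patch, and the final assertion simply mirrors the expectation in TestUpdateLossScalingOp.

import numpy as np

def update_loss_scaling_reference(found_inf, prev_loss_scaling, num_good_steps,
                                  num_bad_steps, incr_every_n_steps,
                                  decr_every_n_nan_or_inf, incr_ratio,
                                  decr_ratio):
    # Host-side restatement of the update rule; all array arguments hold one element.
    loss_scaling = prev_loss_scaling.copy()
    if found_inf:
        # inf/nan step: reset the good counter, count one more bad step.
        good_out = np.zeros_like(num_good_steps)
        bad_out = num_bad_steps + 1
        if bad_out[0] == decr_every_n_nan_or_inf:
            # Too many bad steps in a row: shrink the scale, clamped at 1.
            loss_scaling = np.maximum(prev_loss_scaling * decr_ratio, 1.0)
            bad_out = np.zeros_like(num_bad_steps)
    else:
        # Clean step: reset the bad counter, count one more good step.
        bad_out = np.zeros_like(num_bad_steps)
        good_out = num_good_steps + 1
        if good_out[0] == incr_every_n_steps:
            # Enough good steps in a row: grow the scale unless it overflows.
            candidate = prev_loss_scaling * incr_ratio
            if np.isfinite(candidate[0]):
                loss_scaling = candidate
            good_out = np.zeros_like(num_good_steps)
    return loss_scaling, good_out, bad_out


# Mirrors TestUpdateLossScalingOp: 999 good steps + 1 reaches
# incr_every_n_steps == 1000, so the scale doubles and both counters reset.
scale, good, bad = update_loss_scaling_reference(
    False, np.array([2048.0], np.float32), np.array([999], np.int32),
    np.array([1], np.int32), 1000, 2, 2.0, 0.8)
assert scale[0] == 4096.0 and good[0] == 0 and bad[0] == 0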