diff --git a/paddle/fluid/operators/amp/alloc_float_status_op.cc b/paddle/fluid/operators/amp/alloc_float_status_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..181dd6eabe22d7d0c82b7c8f17625d787008f00b --- /dev/null +++ b/paddle/fluid/operators/amp/alloc_float_status_op.cc @@ -0,0 +1,74 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include + +#include "paddle/fluid/framework/op_registry.h" + +namespace paddle { +namespace operators { + +class AllocFloatStatusOp : public framework::OperatorWithKernel { + public: + using framework::OperatorWithKernel::OperatorWithKernel; + + void InferShape(framework::InferShapeContext* ctx) const override { + OP_INOUT_CHECK(ctx->HasOutput("FloatStatus"), "Output", "FloatStatus", + "alloc_float_status"); + ctx->SetOutputDim("FloatStatus", {8}); + } + + protected: + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext& ctx) const override { + return framework::OpKernelType(framework::proto::VarType::FP32, + ctx.GetPlace()); + } +}; + +class AllocFloatStatusMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddOutput("FloatStatus", + "(Tensor) of shape {8} that holds the float status."); + AddComment(R"DOC( + Produces a float Tensor that holds the float status +)DOC"); + } +}; + +template +class AllocFloatStatusKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW(platform::errors::Unimplemented( + "Operator alloc_float_status is not supported on CPU")); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +using CPU = paddle::platform::CPUDeviceContext; + +REGISTER_OPERATOR( + alloc_float_status, ops::AllocFloatStatusOp, ops::AllocFloatStatusMaker, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL(alloc_float_status, + ops::AllocFloatStatusKernel); diff --git a/paddle/fluid/operators/amp/alloc_float_status_op_npu.cc b/paddle/fluid/operators/amp/alloc_float_status_op_npu.cc new file mode 100644 index 0000000000000000000000000000000000000000..fe5b08af52a624b29100635ee34cfac7c2d2a859 --- /dev/null +++ b/paddle/fluid/operators/amp/alloc_float_status_op_npu.cc @@ -0,0 +1,47 @@ +/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include +#include +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/npu_op_runner.h" + +namespace paddle { +namespace operators { + +using Tensor = framework::Tensor; + +template +class AllocFloatStatusKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + auto* float_status = ctx.Output("FloatStatus"); + float_status->mutable_data(ctx.GetPlace()); + + auto runner = NpuOpRunner("NPUAllocFloatStatus", {}, {*float_status}); + auto stream = + ctx.template device_context() + .stream(); + runner.Run(stream); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_NPU_KERNEL( + alloc_float_status, + ops::AllocFloatStatusKernel); diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc index 9d78936ad5f7f2618eb766d84de2c631fc0cf8c5..c7520dbd34f6a92afb5c2fe320197fdad8e95379 100644 --- a/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op.cc @@ -60,6 +60,12 @@ class CheckFiniteAndUnscaleOpMaker : public framework::OpProtoAndCheckerMaker { AddInput("Scale", "(Tensor) 1-dim tensor, the scale of check_finite_and_unscale " "operator."); +#ifdef PADDLE_WITH_ASCEND_CL + AddInput("FloatStatus", + "(Tensor) 1-dim tensor of shape [8], allocated by " + "alloc_float_status op") + .AsDispensable(); +#endif AddOutput("Out", "(Tensors) The scaled output tensor of " "check_finite_and_unscale operator.") diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc index 21968dcb05dd122d1c5705a3f4f1abb085db8fd0..8fd45326e4ec6134cf4b98be12212ce8d7d74541 100644 --- a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc +++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc @@ -24,12 +24,19 @@ namespace operators { using Tensor = framework::Tensor; +// NOTE(zhiqiu): The CheckFiniteAndUnscaleNPUKernel is different from CUDA. +// On NPU, we do not really check the data of input tensors, +// but use NPUGetFloatStatus to check whether the nan/inf occurs on device, +// and clear it after this op. +// Which may leads to wrong result if the input tensors is not calculated +// on NPU device, but got from other way, for example, feeding. template class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext& ctx) const { const auto xs = ctx.MultiInput("X"); const auto* scale = ctx.Input("Scale"); + const auto* float_status = ctx.Input("FloatStatus"); auto outs = ctx.MultiOutput("Out"); auto* found_inf = ctx.Output("FoundInfinite"); @@ -56,58 +63,60 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel { runner_inverse.Run(stream); tmp_inverse_out = &inverse_out; - size_t x_size = xs.size(); - for (size_t i = 0; i < x_size; ++i) { + // NOTE(zhiqiu): + Tensor tmp; + tmp.mutable_data({8}, ctx.GetPlace()); + + // NOTE(zhiqiu): NPUGetFloatStatus updates data on input in-place. + // tmp is only placeholder. + auto runner_float_status = + NpuOpRunner("NPUGetFloatStatus", {*float_status}, {tmp}, + {{"message", std::string("check_nan_and_inf")}}); + runner_float_status.Run(stream); + + Tensor sum; + sum.mutable_data({1}, ctx.GetPlace()); + auto runner_reduce_sum = + NpuOpRunner("ReduceSumD", {*float_status}, {sum}, + {{"axes", std::vector{0}}, {"keep_dims", true}}); + runner_reduce_sum.Run(stream); + + std::vector sum_vec; + TensorToVector( + sum, ctx.template device_context(), + &sum_vec); + found_inf_data = (sum_vec[0] > 1); + + VLOG(4) << "found_inf_data:" << found_inf_data; + + for (size_t i = 0; i < xs.size(); ++i) { const auto* x = xs[i]; auto* out = outs[i]; out->mutable_data(ctx.GetPlace()); - - // step2: CheckNumerics - // CheckNumerics runs on the Ascend AI CPU, which delivers poor - // performance. - Tensor check_xout(x->type()); - check_xout.Resize(x->dims()); - check_xout.mutable_data(ctx.GetPlace()); - try { - auto runner_checknumerics = - NpuOpRunner("CheckNumerics", {*x}, {check_xout}, - {{"message", std::string("check_nan_and_inf")}}); - runner_checknumerics.Run(stream); - ctx.template device_context() - .Wait(); - } catch (platform::EnforceNotMet& exception) { - LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!"; - found_inf_data = true; - } catch (...) { - LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!"; - found_inf_data = true; - } - if (!found_inf_data) { // MatMul auto runner_matmul = NpuOpRunner("Mul", {*x, *tmp_inverse_out}, {*out}, {}); runner_matmul.Run(stream); - } else { - // ZerosLike - auto runner_zeroslike = NpuOpRunner("ZerosLike", {*x}, {*out}, {}); - runner_zeroslike.Run(stream); - } // end if - } // end for + } + } // set found_inf to true - if (found_inf_data) { - Tensor found_inf_tensor; - found_inf_tensor.Resize({1}); - bool* is_found_inf = - found_inf_tensor.mutable_data(paddle::platform::CPUPlace()); - *is_found_inf = true; - - framework::TensorCopy( - found_inf_tensor, ctx.GetPlace(), - ctx.template device_context(), found_inf); - ctx.template device_context().Wait(); - } + VLOG(4) << "found overflow:" << found_inf_data; + Tensor found_inf_tensor; + found_inf_tensor.Resize({1}); + bool* is_found_inf = + found_inf_tensor.mutable_data(paddle::platform::CPUPlace()); + *is_found_inf = found_inf_data; + + framework::TensorCopy( + found_inf_tensor, ctx.GetPlace(), + ctx.template device_context(), found_inf); + ctx.template device_context().Wait(); + + auto runner_clear_status = + NpuOpRunner("NPUClearFloatStatus", {*float_status}, {tmp}); + runner_clear_status.Run(stream); } }; diff --git a/paddle/fluid/operators/scale_op_npu.cc b/paddle/fluid/operators/scale_op_npu.cc index ee7210a7784d72e1cec297ad8ba194b36fae8fba..cbfd11834ae47710bc8b80df15400689a50af6bc 100644 --- a/paddle/fluid/operators/scale_op_npu.cc +++ b/paddle/fluid/operators/scale_op_npu.cc @@ -34,6 +34,8 @@ class ScaleNPUKernel : public framework::OpKernel { ctx.template device_context() .stream(); float _power = 1.0; + VLOG(4) << "scale:" << scale << ", bias:" << bias + << " ,bias_after_scale:" << bias_after_scale; if (bias_after_scale) { out->mutable_data(ctx.GetPlace()); auto runner = diff --git a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py index 3bfc078971d7a4f18fcd37ff9de2740edb9778e0..588eb2a29f555a09a7c1bf5c7512198b999eeccd 100644 --- a/python/paddle/fluid/contrib/mixed_precision/amp_nn.py +++ b/python/paddle/fluid/contrib/mixed_precision/amp_nn.py @@ -20,7 +20,7 @@ from paddle.fluid import core __all__ = ['check_finite_and_unscale', 'update_loss_scaling'] -def check_finite_and_unscale(x, scale, name=None): +def check_finite_and_unscale(x, scale, name=None, float_status=None): """ Check if input X contains all finite data, if yes, scale it by input Scale. @@ -30,9 +30,11 @@ def check_finite_and_unscale(x, scale, name=None): FoundInfinite will be 1 (True), and Out will not be scaled. In this case, the data of Out should not be used, and its data may not be deterministic. Otherwise, FoundInfinite will be 0 (False). + Args: x(list|tuple): The input tensors of check_finite_and_unscale operator. scale: The scale of check_finite_and_unscale operator. + float_status(Tensor): (Only used on NPU) The float status to check overflow. """ check_type(x, 'x', (tuple, list), 'check_finite_and_unscale') for e in x: @@ -43,6 +45,11 @@ def check_finite_and_unscale(x, scale, name=None): found_inf = helper.create_variable_for_type_inference(dtype='bool') inputs = {'X': x, 'Scale': scale} + if core.is_compiled_with_npu(): + check_variable_and_dtype(float_status, "float_status", + ['float16', 'float32'], + 'check_finite_and_unscale') + inputs['FloatStatus'] = float_status outputs = {'Out': x, 'FoundInfinite': found_inf} helper.append_op( type='check_finite_and_unscale', inputs=inputs, outputs=outputs) diff --git a/python/paddle/fluid/contrib/mixed_precision/decorator.py b/python/paddle/fluid/contrib/mixed_precision/decorator.py index 16cba2ce36b20ef2d8b97c046b52a8df64fe0d49..3cb9fe75559b1615f2ed1a01bd31742c2996e090 100644 --- a/python/paddle/fluid/contrib/mixed_precision/decorator.py +++ b/python/paddle/fluid/contrib/mixed_precision/decorator.py @@ -29,6 +29,7 @@ from .amp_nn import check_finite_and_unscale from .amp_nn import update_loss_scaling import types import warnings +import paddle __all__ = ["decorate"] @@ -165,6 +166,17 @@ class OptimizerWithMixedPrecision(object): train_program = loss.block.program self._train_program = train_program + # NOTE(zhiqiu): _float_status is only used for NPU. + if core.is_compiled_with_npu(): + float_status = paddle.static.data( + name="float_status", shape=[8], dtype='float32') + self._train_program.global_block().append_op( + type="alloc_float_status", + outputs={"FloatStatus": float_status}, ) + self._float_status = float_status + else: + self._float_status = None + with program_guard(self._train_program, startup_program): self._init_amp_var() @@ -294,7 +306,10 @@ class OptimizerWithMixedPrecision(object): for p, g in params_grads: with self._train_program._optimized_guard([p, g]): _, found_inf = check_finite_and_unscale( - [g, ], self._loss_scaling, name="find_infinite_scale") + [g, ], + self._loss_scaling, + name="find_infinite_scale", + float_status=self._float_status) found_infs.append(found_inf) elif self._use_pure_fp16: if fp32_grads: @@ -302,19 +317,24 @@ class OptimizerWithMixedPrecision(object): _, fp32_found_inf = check_finite_and_unscale( fp32_grads, self._loss_scaling, - name="find_infinite_scale_fp32") + name="find_infinite_scale_fp32", + float_status=self._float_status) found_infs.append(fp32_found_inf) if fp16_grads: with self._train_program._optimized_guard(fp16_grads): _, fp16_found_inf = check_finite_and_unscale( fp16_grads, self._loss_scaling, - name="find_infinite_scale_fp16") + name="find_infinite_scale_fp16", + float_status=self._float_status) found_infs.append(fp16_found_inf) else: with self._train_program._optimized_guard(grads): _, found_inf = check_finite_and_unscale( - grads, self._loss_scaling, name="find_infinite_scale") + grads, + self._loss_scaling, + name="find_infinite_scale", + float_status=self._float_status) if self._use_dynamic_loss_scaling: if self._is_distributed or self._use_pure_fp16: @@ -394,6 +414,7 @@ class OptimizerWithMixedPrecision(object): The scaled loss by scaling factor, the list of optimize ops, and a list of scaled parameters and gradients. """ + opt_dict = self._optimizer.__class__.__dict__ if 'minimize' in opt_dict and isinstance(opt_dict['minimize'], types.FunctionType): diff --git a/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py index ac80ea4c62cbfef0e313beac24391893d306d2f7..8828892dca3ccc04dcf926cf5462a282fd442c51 100644 --- a/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py +++ b/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py @@ -19,106 +19,128 @@ sys.path.append("..") from op_test import OpTest, skip_check_grad_ci import paddle import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.contrib.mixed_precision.amp_nn import check_finite_and_unscale paddle.enable_static() @unittest.skipIf(not paddle.is_compiled_with_npu(), "core is not compiled with NPU") -class TestCheckFiniteAndUnscaleOp(OpTest): - def setUp(self): - self.set_npu() - self.op_type = "check_finite_and_unscale" - self.place = paddle.NPUPlace(0) - self.init_dtype() - x = np.random.random((1024, 1024)).astype(self.dtype) - scale = np.random.random((1)).astype(self.dtype) - - self.inputs = {'X': [('x0', x)], 'Scale': scale} - self.outputs = { - 'FoundInfinite': np.array([0]), - 'Out': [('out0', x / scale)], - } - - def set_npu(self): - self.__class__.use_npu = True - - def init_kernel_type(self): - self.use_mkldnn = False - - def init_dtype(self): - self.dtype = np.float32 - - def test_check_output(self): - self.check_output_with_place(self.place, check_dygraph=False) - - -@unittest.skipIf(not paddle.is_compiled_with_npu(), - "core is not compiled with NPU") -class TestCheckFiniteAndUnscaleOpWithNan(OpTest): - def setUp(self): - self.set_npu() - self.op_type = "check_finite_and_unscale" - self.place = paddle.NPUPlace(0) - self.init_dtype() - x = np.random.random((1024, 1024)).astype(self.dtype) - x[128][128] = np.nan - scale = np.random.random((1)).astype(self.dtype) - - self.inputs = {'X': [('x0', x)], 'Scale': scale} - self.outputs = { - 'FoundInfinite': np.array([1]), - 'Out': [('out0', x)], - } - - def set_npu(self): - self.__class__.use_npu = True - - def init_kernel_type(self): - self.use_mkldnn = False - - def init_dtype(self): - self.dtype = np.float32 - - def test_check_output(self): - # When input contains nan, do not check the output, - # since the output may be nondeterministic and will be discarded. - self.check_output_with_place( - self.place, check_dygraph=False, no_check_set=['Out']) +class TestCheckFiniteAndUnscale(unittest.TestCase): + def get_prog(self): + paddle.enable_static() + main_program = paddle.static.Program() + with program_guard(main_program): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + scale = paddle.static.data(name="scale", shape=[1], dtype='float32') + float_status = paddle.static.data( + name="status", shape=[8], dtype='float32') + main_program.global_block().append_op( + type="alloc_float_status", + outputs={"FloatStatus": float_status}, ) + c = paddle.fluid.layers.elementwise_div(a, b) + out, found_inf = check_finite_and_unscale( + [c], scale, float_status=float_status) + + return main_program, out, found_inf, float_status + + def run_prog(self, a, b, scale): + main_program, out, found_inf, float_status = self.get_prog() + place = fluid.NPUPlace(0) + exe = fluid.Executor(place) + out_, founf_inf_, float_status_ = exe.run( + main_program, + feed={"a": a, + "b": b, + "scale": scale}, + fetch_list=[out, found_inf, float_status]) + print(float_status_) + return out_, founf_inf_ + + def test_contains_nan(self): + a = np.zeros((32, 32)).astype('float32') + b = np.zeros((32, 32)).astype('float32') + scale = np.array([2.0]).astype('float32') + + out, found_inf = self.run_prog(a, b, scale) + print(out, found_inf) + + self.assertTrue(found_inf[0]) + + def test_contains_inf(self): + a = np.ones((32, 32)).astype('float32') + b = np.zeros((32, 32)).astype('float32') + scale = np.array([2.0]).astype('float32') + + out, found_inf = self.run_prog(a, b, scale) + print(out, found_inf) + + self.assertTrue(found_inf[0]) + + def test_not_contains_nan_inf(self): + a = np.ones((32, 32)).astype('float32') + b = np.ones((32, 32)).astype('float32') + scale = np.array([2.0]).astype('float32') + + out, found_inf = self.run_prog(a, b, scale) + print(out, found_inf) + + self.assertTrue(np.allclose(out, (a / b) / scale[0])) + self.assertFalse(found_inf[0]) @unittest.skipIf(not paddle.is_compiled_with_npu(), "core is not compiled with NPU") -class TestCheckFiniteAndUnscaleOpWithInf(OpTest): - def setUp(self): - self.set_npu() - self.op_type = "check_finite_and_unscale" - self.place = paddle.NPUPlace(0) - self.init_dtype() - x = np.random.random((1024, 1024)).astype(self.dtype) - x[128][128] = np.inf - scale = np.random.random((1)).astype(self.dtype) - - self.inputs = {'X': [('x0', x)], 'Scale': scale} - self.outputs = { - 'FoundInfinite': np.array([1]), - 'Out': [('out0', x)], - } - - def set_npu(self): - self.__class__.use_npu = True - - def init_kernel_type(self): - self.use_mkldnn = False - - def init_dtype(self): - self.dtype = np.float32 - - def test_check_output(self): - # When input contains inf, do not check the output, - # since the output may be nondeterministic and will be discarded. - self.check_output_with_place( - self.place, check_dygraph=False, no_check_set=['Out']) +class TestCheckFiniteAndUnscaleClearFloatStatus(unittest.TestCase): + def get_prog(self): + paddle.enable_static() + main_program = paddle.static.Program() + with program_guard(main_program): + a = paddle.static.data(name="a", shape=[32, 32], dtype='float32') + b = paddle.static.data(name="b", shape=[32, 32], dtype='float32') + scale = paddle.static.data(name="scale", shape=[1], dtype='float32') + float_status = paddle.static.data( + name="status", shape=[8], dtype='float32') + main_program.global_block().append_op( + type="alloc_float_status", + outputs={"FloatStatus": float_status}, ) + c = paddle.fluid.layers.elementwise_div(a, b) + out, found_inf = check_finite_and_unscale( + [c], scale, float_status=float_status) + main_program.global_block().append_op( + type="alloc_float_status", + outputs={"FloatStatus": float_status}, ) + d = paddle.fluid.layers.elementwise_add(a, b) + out, found_inf = check_finite_and_unscale( + [d], scale, float_status=float_status) + + return main_program, out, found_inf, float_status + + def run_prog(self, a, b, scale): + main_program, out, found_inf, float_status = self.get_prog() + place = fluid.NPUPlace(0) + exe = fluid.Executor(place) + out_, founf_inf_, float_status_ = exe.run( + main_program, + feed={"a": a, + "b": b, + "scale": scale}, + fetch_list=[out, found_inf, float_status]) + print(float_status_) + return out_, founf_inf_ + + def test_not_contains_nan_inf(self): + a = np.ones((32, 32)).astype('float32') + b = np.zeros((32, 32)).astype('float32') + scale = np.array([2.0]).astype('float32') + + out, found_inf = self.run_prog(a, b, scale) + print(out, found_inf) + + self.assertTrue(np.allclose(out, (a + b) / scale[0])) + self.assertFalse(found_inf[0]) if __name__ == '__main__':