From 3bf8a34c69a5e8a859d75699462e6e66ef27ddce Mon Sep 17 00:00:00 2001
From: xiayanming <41795079@qq.com>
Date: Fri, 12 Mar 2021 15:33:45 +0800
Subject: [PATCH] [NPU] Support npu kernel for amp_check_finite_and_unscale_npu op (#31457)

* Support npu kernel for amp_check_finite_and_unscale_npu op

* support EnforceNotMet exception

* fix exception bug

* modify python unittest

* precommit

* update c++ unittest

* fix review

* fix review
---
 paddle/fluid/operators/amp/CMakeLists.txt     |   4 +
 .../amp/check_finite_and_unscale_op_npu.cc    | 119 ++++++++++++++++
 .../check_finite_and_unscale_op_npu_test.cc   | 131 ++++++++++++++++++
 .../test_amp_check_finite_and_scale_op_npu.py | 123 ++++++++++++++++
 4 files changed, 377 insertions(+)
 create mode 100644 paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc
 create mode 100644 paddle/fluid/operators/amp/check_finite_and_unscale_op_npu_test.cc
 create mode 100644 python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py

diff --git a/paddle/fluid/operators/amp/CMakeLists.txt b/paddle/fluid/operators/amp/CMakeLists.txt
index b3ff52a7ae1..2ea8bbcbc61 100644
--- a/paddle/fluid/operators/amp/CMakeLists.txt
+++ b/paddle/fluid/operators/amp/CMakeLists.txt
@@ -4,3 +4,7 @@ if(WITH_UNITY_BUILD)
     include(unity_build_rule.cmake)
 endif()
 register_operators()
+
+if(WITH_ASCEND_CL)
+    cc_test(check_finite_and_unscale_op_npu_test SRCS check_finite_and_unscale_op_npu_test.cc DEPS op_registry check_finite_and_unscale_op scope device_context enforce executor)
+endif()
diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc
new file mode 100644
index 00000000000..46f9f7ff089
--- /dev/null
+++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc
@@ -0,0 +1,119 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+
+#include "paddle/fluid/framework/tensor_util.h"
+#include "paddle/fluid/operators/amp/check_finite_and_unscale_op.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename T>
+class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const {
+    const auto xs = ctx.MultiInput<framework::Tensor>("X");
+    const auto* scale = ctx.Input<framework::Tensor>("Scale");
+    auto outs = ctx.MultiOutput<framework::Tensor>("Out");
+    auto* found_inf = ctx.Output<framework::Tensor>("FoundInfinite");
+
+    found_inf->mutable_data<bool>(ctx.GetPlace());
+
+    bool found_inf_data = false;
+
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+
+    // step1: inverse scale(RealDiv)
+    Tensor const_tensor;
+    const_tensor.mutable_data<T>({1}, ctx.GetPlace());
+    TensorFromVector(std::vector<T>{static_cast<T>(1.0)}, ctx.device_context(),
+                     &const_tensor);
+
+    ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
+
+    // Inverse(1.0/scale)
+    Tensor* tmp_inverse_out = const_cast<Tensor*>(scale);
+    Tensor inverse_out(scale->type());
+    inverse_out.Resize(scale->dims());
+    inverse_out.mutable_data<T>(ctx.GetPlace());
+    auto runner_inverse =
+        NpuOpRunner("Div", {const_tensor, *scale}, {inverse_out}, {});
+    runner_inverse.Run(stream);
+    tmp_inverse_out = &inverse_out;
+
+    size_t x_size = xs.size();
+    for (size_t i = 0; i < x_size; ++i) {
+      found_inf_data = true;
+      const auto* x = xs[i];
+      auto* out = outs[i];
+      out->mutable_data<T>(ctx.GetPlace());
+
+      // step2: CheckNumerics
+      // CheckNumerics runs on the Ascend AI CPU, which delivers poor
+      // performance.
+      Tensor check_xout(x->type());
+      check_xout.Resize(x->dims());
+      check_xout.mutable_data<T>(ctx.GetPlace());
+      try {
+        auto runner_checknumerics =
+            NpuOpRunner("CheckNumerics", {*x}, {check_xout},
+                        {{"message", std::string("check_nan_and_inf")}});
+        runner_checknumerics.Run(stream);
+      } catch (platform::EnforceNotMet& exception) {
+        LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
+        found_inf_data = true;
+      } catch (...) {
+        LOG(WARNING) << "[check_nan_and_inf] detected contains NaN or INF!!!";
+        found_inf_data = true;
+      }
+
+      if (!found_inf_data) {
+        // MatMul
+        auto runner_matmul =
+            NpuOpRunner("Mul", {*x, *tmp_inverse_out}, {*out}, {});
+        runner_matmul.Run(stream);
+      } else {
+        // ZerosLike
+        auto runner_zeroslike = NpuOpRunner("ZerosLike", {*x}, {*out}, {});
+        runner_zeroslike.Run(stream);
+      }  // end if
+    }    // end for
+
+    // set found_inf to true
+    if (found_inf_data) {
+      Tensor found_inf_tensor;
+      found_inf_tensor.Resize({1});
+      bool* is_found_inf =
+          found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
+      *is_found_inf = true;
+      framework::TensorCopySync(found_inf_tensor, ctx.GetPlace(), found_inf);
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+namespace plat = paddle::platform;
+REGISTER_OP_NPU_KERNEL(check_finite_and_unscale,
+                       ops::CheckFiniteAndUnscaleNPUKernel<float>,
+                       ops::CheckFiniteAndUnscaleNPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu_test.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu_test.cc
new file mode 100644
index 00000000000..99e81a4757d
--- /dev/null
+++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu_test.cc
@@ -0,0 +1,131 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef _WIN32
+#include <unistd.h>
+#endif
+
+#include <cmath>
+#include <memory>
+#include <string>
+#include <vector>
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace f = paddle::framework;
+namespace p = paddle::platform;
+namespace m = paddle::operators::math;
+
+using Tensor = paddle::framework::Tensor;
+
+USE_OP(check_finite_and_unscale);
+USE_OP_DEVICE_KERNEL(check_finite_and_unscale, NPU);
+
+struct InputVars {
+  std::string name;
+  f::LoDTensor *tensor;
+};
+
+template <typename T>
+void Compare(f::Scope *scope, const p::DeviceContext &ctx) {
+  const f::DDim dims = f::make_ddim({2, 2});
+  auto place = ctx.GetPlace();
+
+  // init input
+  std::vector<InputVars> input_names = {
+      {"x", scope->Var("x")->GetMutable<f::LoDTensor>()},
+      {"x1", scope->Var("x1")->GetMutable<f::LoDTensor>()}};
+
+  auto *scale = scope->Var("scale")->GetMutable<f::LoDTensor>();
+
+  // init output
+  auto *out = scope->Var("out")->GetMutable<f::LoDTensor>();
+  auto *out1 = scope->Var("out1")->GetMutable<f::LoDTensor>();
+  auto *found_inf = scope->Var("found_inf")->GetMutable<f::LoDTensor>();
+
+  // Initialize input data
+  const int num_inputs = input_names.size();
+  size_t numel = static_cast<size_t>(f::product(dims));
+
+  for (int i = 0; i < num_inputs; ++i) {
+    std::vector<T> init_xs;
+    for (size_t j = 0; j < numel; ++j) {
+      if (j == 0) {
+        init_xs.push_back(static_cast<T>(NAN));
+      } else {
+        init_xs.push_back(static_cast<T>(j + 1));
+      }
+    }
+    f::TensorFromVector(init_xs, ctx, input_names[i].tensor);
+    input_names[i].tensor->Resize(dims);
+  }
+
+  f::TensorFromVector(std::vector<T>{static_cast<T>(0.5)}, ctx, scale);
+
+  ctx.Wait();
+
+  // run
+  f::AttributeMap attrs;
+  auto op = f::OpRegistry::CreateOp(
+      "check_finite_and_unscale", {{"X", {"x", "x1"}}, {"Scale", {"scale"}}},
+      {{"Out", {"out", "out1"}}, {"FoundInfinite", {"found_inf"}}}, attrs);
+  op->Run(*scope, place);
+  ctx.Wait();
+
+  // out0
+  std::vector<T> out_vec;
+  f::TensorToVector(*out, ctx, &out_vec);
+  EXPECT_EQ(out_vec.size(), static_cast<size_t>(4));
+  for (size_t j = 0; j < out_vec.size(); ++j) {
+    VLOG(3) << "out_vec[" << j << "]:" << out_vec[j];
+  }
+
+  ctx.Wait();
+
+  // out1
+  std::vector<T> out1_vec;
+  f::TensorToVector(*out1, ctx, &out1_vec);
+  EXPECT_EQ(out1_vec.size(), static_cast<size_t>(4));
+  for (size_t j = 0; j < out1_vec.size(); ++j) {
+    VLOG(3) << "out1_vec[" << j << "]:" << out1_vec[j];
+  }
+
+  ctx.Wait();
+
+  // out found_inf
+  Tensor found_inf_tensor;
+  found_inf_tensor.Resize({1});
+  bool *is_finite_data =
+      found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
+  f::TensorCopy(*found_inf, place, &found_inf_tensor);
+  EXPECT_FALSE(*is_finite_data);
+
+  ctx.Wait();
+}
+
+TEST(check_finite_and_unscale, NPU_fp32) {
+  f::Scope scope;
+  p::NPUDeviceContext ctx(p::NPUPlace(0));
+  Compare<float>(&scope, ctx);
+}
+
+TEST(check_finite_and_unscale, NPU_fp16) {
+  f::Scope scope;
+  p::NPUDeviceContext ctx(p::NPUPlace(0));
+  Compare<p::float16>(&scope, ctx);
+}
diff --git a/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py
new file mode 100644
index 00000000000..4cda0ceeccf
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_amp_check_finite_and_scale_op_npu.py
@@ -0,0 +1,123 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+from op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestCheckFiniteAndUnscaleOp(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "check_finite_and_unscale"
+        self.place = paddle.NPUPlace(0)
+        self.init_dtype()
+        x = np.random.random((1024, 1024)).astype(self.dtype)
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([0]),
+            'Out': [('out0', x / scale)],
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_kernel_type(self):
+        self.use_mkldnn = False
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, check_dygraph=False)
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestCheckFiniteAndUnscaleOpWithNan(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "check_finite_and_unscale"
+        self.place = paddle.NPUPlace(0)
+        self.init_dtype()
+        x = np.random.random((1024, 1024)).astype(self.dtype)
+        x[128][128] = np.nan
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([1]),
+            'Out': [('out0', x)],
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_kernel_type(self):
+        self.use_mkldnn = False
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        # When input contains nan, do not check the output,
+        # since the output may be nondeterministic and will be discarded.
+        self.check_output_with_place(
+            self.place, check_dygraph=False, no_check_set=['Out'])
+
+
+@unittest.skipIf(not paddle.is_compiled_with_npu(),
+                 "core is not compiled with NPU")
+class TestCheckFiniteAndUnscaleOpWithInf(OpTest):
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "check_finite_and_unscale"
+        self.place = paddle.NPUPlace(0)
+        self.init_dtype()
+        x = np.random.random((1024, 1024)).astype(self.dtype)
+        x[128][128] = np.inf
+        scale = np.random.random((1)).astype(self.dtype)
+
+        self.inputs = {'X': [('x0', x)], 'Scale': scale}
+        self.outputs = {
+            'FoundInfinite': np.array([1]),
+            'Out': [('out0', x)],
+        }
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+
+    def init_kernel_type(self):
+        self.use_mkldnn = False
+
+    def init_dtype(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        # When input contains inf, do not check the output,
+        # since the output may be nondeterministic and will be discarded.
+        self.check_output_with_place(
+            self.place, check_dygraph=False, no_check_set=['Out'])
+
+
+if __name__ == '__main__':
+    unittest.main()
-- 
GitLab
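A minimal usage sketch (not part of the patch): the snippet below wires the check_finite_and_unscale op registered above into a static-graph program by hand and runs it once. The variable names, the 2x2 shape, and the NPUPlace(0) device index are assumptions for illustration; in real mixed-precision training the op is inserted automatically by Paddle's AMP utilities rather than appended manually.

import numpy as np
import paddle
import paddle.fluid as fluid

paddle.enable_static()

main_prog = fluid.Program()
startup_prog = fluid.Program()
with fluid.program_guard(main_prog, startup_prog):
    x = fluid.data(name='x', shape=[2, 2], dtype='float32')
    scale = fluid.data(name='scale', shape=[1], dtype='float32')
    block = main_prog.global_block()
    # Output variables; the kernel writes the unscaled tensors and the flag here.
    out = block.create_var(name='out', dtype='float32')
    found_inf = block.create_var(name='found_inf', dtype='bool')
    block.append_op(
        type='check_finite_and_unscale',
        inputs={'X': [x], 'Scale': scale},
        outputs={'Out': [out], 'FoundInfinite': found_inf})

# NPUPlace(0) assumes an Ascend device and a WITH_ASCEND_CL build.
place = paddle.NPUPlace(0)
exe = fluid.Executor(place)
exe.run(startup_prog)
out_np, found_np = exe.run(
    main_prog,
    feed={'x': np.ones((2, 2), dtype='float32'),
          'scale': np.array([0.5], dtype='float32')},
    fetch_list=[out, found_inf])
print(out_np)    # expected: x * (1 / scale) for finite inputs
print(found_np)  # expected: [False] when no NaN/Inf is present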