diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt
index bbb781c8664baff5a260ad9b2d8f8f3348ea089b..72f7f0e6011c1bdbf50482c8e35b6c1207f5aa73 100644
--- a/paddle/fluid/framework/details/CMakeLists.txt
+++ b/paddle/fluid/framework/details/CMakeLists.txt
@@ -64,7 +64,11 @@ elseif(WITH_ROCM)
   hip_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
   hip_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
 else()
-  cc_library(nan_inf_utils SRCS nan_inf_utils_detail.cc DEPS framework_proto scope place)
+  if (WITH_ASCEND_CL)
+    cc_library(nan_inf_utils SRCS nan_inf_utils_detail.cc DEPS npu_op_runner framework_proto scope place)
+  else()
+    cc_library(nan_inf_utils SRCS nan_inf_utils_detail.cc DEPS framework_proto scope place)
+  endif()
   cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor)
   cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
diff --git a/paddle/fluid/framework/details/nan_inf_utils.h b/paddle/fluid/framework/details/nan_inf_utils.h
index cf64ccd60f45a40b6c9ca83dcdd473686d03904f..5a592f22dc494e00c1bea0e42f22bfe9c6210a46 100644
--- a/paddle/fluid/framework/details/nan_inf_utils.h
+++ b/paddle/fluid/framework/details/nan_inf_utils.h
@@ -53,6 +53,12 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
   }
 }
 
+#ifdef PADDLE_WITH_ASCEND_CL
+void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
+                                 const framework::Scope& scope,
+                                 const platform::Place& place);
+#endif
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc
index 30231a1799fd3714646a81bba2afb5de03045850..f22f008c19896a3dba4626ed57016b3fc41e2059 100644
--- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc
+++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc
@@ -15,6 +15,11 @@
 #include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
 #include "paddle/fluid/framework/op_proto_maker.h"
+
+#ifdef PADDLE_WITH_ASCEND_CL
+#include "paddle/fluid/operators/npu_op_runner.h"
+#endif
+
 namespace paddle {
 namespace framework {
 namespace details {
@@ -123,8 +128,10 @@ static void InitWhiteListFormEnv() {
 
 template <typename T>
 static void PrintNanInf(const T* value, const size_t numel, int print_num,
-                        const std::string& op_type,
-                        const std::string& var_name) {
+                        const std::string& op_type, const std::string& var_name,
+                        bool abort = true) {
+  T min_value = std::numeric_limits<T>::max();
+  T max_value = std::numeric_limits<T>::min();
   size_t nan_count, inf_count, num_count;
   nan_count = inf_count = num_count = 0;
@@ -137,6 +144,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
       count = inf_count++;
     } else {
       count = num_count++;
+      min_value = std::min(min_value, value[i]);
+      max_value = std::max(max_value, value[i]);
     }
 
     if (count < static_cast<size_t>(print_num)) {
@@ -144,12 +153,17 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
              static_cast<uint64_t>(i), static_cast<float>(value[i]));
     }
   }
-  printf("In cpu, there has %lu,%lu,%lu nan,inf,num\n",
-         static_cast<uint64_t>(nan_count), static_cast<uint64_t>(inf_count),
-         static_cast<uint64_t>(num_count));
-  PADDLE_THROW(platform::errors::PreconditionNotMet(
-      "There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
-      op_type));
+  printf(
+      "In cpu, there has %lu,%lu,%lu nan,inf,num. "
+      "And in num, min_value is %f, max_value is %f\n",
+      static_cast<uint64_t>(nan_count), static_cast<uint64_t>(inf_count),
+      static_cast<uint64_t>(num_count), static_cast<double>(min_value),
+      static_cast<double>(max_value));
+  if (abort) {
+    PADDLE_THROW(platform::errors::PreconditionNotMet(
+        "There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
+        op_type));
+  }
 }
 
 // openmp 4.0, reduction with fp16
@@ -415,6 +429,134 @@ bool IsSkipOp(const framework::OperatorBase& op) {
   return false;
 }
 
+#ifdef PADDLE_WITH_ASCEND_CL
+using NpuOpRunner = paddle::operators::NpuOpRunner;
+
+constexpr int FLOAT_STATUS_SIZE = 8;
+
+static framework::Tensor& npu_float_status() {
+  static framework::Tensor float_status;
+  return float_status;
+}
+
+void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
+                                 const framework::Scope& scope,
+                                 const platform::Place& place) {
+  if (!platform::is_npu_place(place)) return;
+
+  std::call_once(white_list_init_flag, InitWhiteListFormEnv);
+  if (IsSkipOp(op)) return;
+
+  auto* dev_ctx = reinterpret_cast<platform::NPUDeviceContext*>(
+      platform::DeviceContextPool::Instance().Get(place));
+  auto stream = dev_ctx->stream();
+
+  auto& flag = npu_float_status();
+  flag.mutable_data<float>({FLOAT_STATUS_SIZE}, place);
+  NpuOpRunner("NPUAllocFloatStatus", {}, {flag}).Run(stream);
+
+  framework::Tensor tmp;
+  tmp.mutable_data<float>({FLOAT_STATUS_SIZE}, place);
+  NpuOpRunner("NPUClearFloatStatus", {tmp}, {flag}).Run(stream);
+}
+
+void PrintNpuVarInfo(const std::string& op_type, const std::string& var_name,
+                     const framework::Variable* var,
+                     const platform::Place& place) {
+  const Tensor* tensor{nullptr};
+  if (var->IsType<framework::LoDTensor>()) {
+    tensor = &var->Get<framework::LoDTensor>();
+  } else if (var->IsType<framework::SelectedRows>()) {
+    tensor = &var->Get<framework::SelectedRows>().value();
+  } else {
+    VLOG(10) << var_name << " var_name need not to check";
+    return;
+  }
+
+  if ((tensor->type() != proto::VarType::FP32) &&
+      (tensor->type() != proto::VarType::FP16)) {
+    return;
+  }
+
+  if (tensor->memory_size() == 0) {
+    VLOG(10) << var_name << " var_name need not to check, size == 0";
+    return;
+  }
+
+  VLOG(10) << "begin check " << op_type << " var_name:" << var_name
+           << ", place:" << tensor->place() << ", numel:" << tensor->numel();
+
+  framework::Tensor cpu_tensor;
+  cpu_tensor.Resize(tensor->dims());
+  cpu_tensor.mutable_data(platform::CPUPlace(), tensor->type());
+  framework::TensorCopySync(*tensor, platform::CPUPlace(), &cpu_tensor);
+
+  LOG(WARNING) << "print [" << var_name << "] tensor info:";
+  // use env strategy to control this in the future; -1 = print all.
+  int print_num = 3;
+  if (tensor->type() == proto::VarType::FP32) {
+    const float* value = cpu_tensor.data<float>();
+    PrintNanInf(value, tensor->numel(), print_num, op_type, var_name, false);
+  } else if (tensor->type() == proto::VarType::FP16) {
+    const paddle::platform::float16* value =
+        cpu_tensor.data<paddle::platform::float16>();
+    PrintNanInf(value, tensor->numel(), print_num, op_type, var_name, false);
+  }
+}
+
+void PrintNPUOpValueInfo(const framework::OperatorBase& op,
+                         const framework::Scope& scope,
+                         const platform::Place& place) {
+  LOG(WARNING) << "There are `nan` or `inf` in operator (" << op.Type()
+               << "), here we print some tensor value info of this op.";
+  for (auto& vname : op.InputVars()) {
+    auto* var = scope.FindVar(vname);
+    if (var == nullptr) continue;
+    PrintNpuVarInfo(op.Type(), vname, var, place);
+  }
+
+  for (auto& vname : op.OutputVars(true)) {
+    auto* var = scope.FindVar(vname);
+    if (var == nullptr) continue;
+    PrintNpuVarInfo(op.Type(), vname, var, place);
+  }
+}
+
+static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op,
+                                  const framework::Scope& scope,
+                                  const platform::Place& place) {
+  if (!platform::is_npu_place(place)) return;
+
+  auto* dev_ctx = reinterpret_cast<platform::NPUDeviceContext*>(
+      platform::DeviceContextPool::Instance().Get(place));
+  auto stream = dev_ctx->stream();
+
+  auto& flag = npu_float_status();
+  Tensor tmp;
+  tmp.mutable_data<float>({FLOAT_STATUS_SIZE}, place);
+  // NPUGetFloatStatus updates data on its input in-place;
+  // tmp is only a placeholder.
+  NpuOpRunner("NPUGetFloatStatus", {flag}, {tmp}).Run(stream);
+
+  framework::Tensor cpu_tensor;
+  auto cpu_place = platform::CPUPlace();
+  float* cpu_data = static_cast<float*>(
+      cpu_tensor.mutable_data<float>({FLOAT_STATUS_SIZE}, cpu_place));
+
+  framework::TensorCopySync(flag, cpu_place, &cpu_tensor);
+  float sum = 0.0;
+  for (int i = 0; i < FLOAT_STATUS_SIZE; ++i) {
+    sum += cpu_data[i];
+  }
+
+  if (sum >= 1.0) PrintNPUOpValueInfo(op, scope, place);
+
+  PADDLE_ENFORCE_LT(
+      sum, 1.0,
+      platform::errors::PreconditionNotMet("Operator %s contains Nan/Inf.",
+                                           op.DebugStringEx(&scope)));
+}
+#endif
+
 void CheckOpHasNanOrInf(const framework::OperatorBase& op,
                         const framework::Scope& exec_scope,
                         const platform::Place& place) {
@@ -422,6 +564,13 @@ void CheckOpHasNanOrInf(const framework::OperatorBase& op,
 
   if (IsSkipOp(op)) return;
 
+#ifdef PADDLE_WITH_ASCEND_CL
+  if (platform::is_npu_place(place)) {
+    NPUCheckOpHasNanOrInf(op, exec_scope, place);
+    return;
+  }
+#endif
+
   if (op_var_nan_inf_white_list().count(op.Type()) == 0) {
     // NOTE. vname may destruct in the end of this func.
     for (auto& vname : op.OutputVars(true)) {
diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc
index 57f9d094ac80d788555f5fa47c0b7e98b0bdbac0..17003157b645a2cb7529543ef9016a26526d5ccb 100644
--- a/paddle/fluid/framework/operator.cc
+++ b/paddle/fluid/framework/operator.cc
@@ -1105,6 +1105,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
   platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
   auto* dev_ctx = pool.Get(place);
 
+#ifdef PADDLE_WITH_ASCEND_CL
+  // NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable
+  // values, but only through the special `float_status` register, which
+  // records whether an overflow occurred. More about `float_status`, see:
+  // https://gitee.com/ascend/modelzoo/issues/I3NF8V?from=project-issue
+  if (FLAGS_check_nan_inf) {
+    framework::details::NPUAllocAndClearFloatStatus(*this, scope, place);
+  }
+#endif
+
   if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
     ChooseKernel(*runtime_ctx, scope, place);
   }
diff --git a/paddle/fluid/operators/amp/clear_float_status_op_npu.cc b/paddle/fluid/operators/amp/clear_float_status_op_npu.cc
index d5bdcc37c2a8237bdaf54d3c3e8190175c9dec31..468b0f04cae3d543ff9da5661c1b90f55020bdab 100644
--- a/paddle/fluid/operators/amp/clear_float_status_op_npu.cc
+++ b/paddle/fluid/operators/amp/clear_float_status_op_npu.cc
@@ -36,7 +36,7 @@ class ClearFloatStatusKernel : public framework::OpKernel<T> {
     Tensor tmp;
     tmp.mutable_data<float>({8}, ctx.GetPlace());
     const auto& runner =
-        NpuOpRunner("NPUClearFloatStatus", {*float_status}, {tmp});
+        NpuOpRunner("NPUClearFloatStatus", {tmp}, {*float_status_out});
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
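
Note for reviewers: the `alloc_float_status` / `clear_float_status` / `get_float_status` ops follow a fixed protocol, which the framework-level hook above automates: allocate the 8-float status tensor once, clear the hardware register before the region of interest, run the computation, then fold the register back into the tensor. A minimal dygraph sketch of that protocol, assuming an NPU build and device and using the imperative `paddle._C_ops` bindings exercised by the new unit tests below:

    import numpy as np
    import paddle
    import paddle._C_ops as ops

    paddle.set_device('npu')
    a = paddle.to_tensor(np.ones((4, 4), dtype='float32'))
    b = paddle.to_tensor(np.zeros((4, 4), dtype='float32'))

    flag = ops.alloc_float_status()     # reserve the {8}-shaped status tensor
    ops.clear_float_status(flag, flag)  # reset the register before the region

    out = a / b                         # 1/0 -> inf, which sets the status bits

    ops.get_float_status(flag, flag)    # read the register back into `flag`
    # As in NPUCheckOpHasNanOrInf: a sum >= 1.0 means nan/inf occurred.
    print('overflow detected:', float(np.sum(flag.numpy())) >= 1.0)
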
diff --git a/paddle/fluid/operators/amp/get_float_status_op.cc b/paddle/fluid/operators/amp/get_float_status_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a7cdfbf19360cb90615c4bdf881302b0d1f07da0
--- /dev/null
+++ b/paddle/fluid/operators/amp/get_float_status_op.cc
@@ -0,0 +1,75 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <cstring>
+#include <string>
+#include <vector>
+
+#include "paddle/fluid/framework/op_registry.h"
+
+namespace paddle {
+namespace operators {
+
+class GetFloatStatusOp : public framework::OperatorWithKernel {
+ public:
+  using framework::OperatorWithKernel::OperatorWithKernel;
+
+  void InferShape(framework::InferShapeContext* ctx) const override {
+    OP_INOUT_CHECK(ctx->HasOutput("FloatStatusOut"), "Output", "FloatStatusOut",
+                   "get_float_status");
+    ctx->SetOutputDim("FloatStatusOut", ctx->GetInputDim("FloatStatus"));
+  }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext& ctx) const override {
+    return framework::OpKernelType(framework::proto::VarType::FP32,
+                                   ctx.GetPlace());
+  }
+};
+
+class GetFloatStatusMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  void Make() override {
+    AddInput("FloatStatus",
+             "(Tensor) of shape {8} that holds the float status.");
+    AddOutput("FloatStatusOut",
+              "(Tensor) of shape {8} that holds the get float status.");
+    AddComment(R"DOC(
+  Get the float status
+)DOC");
+  }
+};
+
+template <typename DeviceContext, typename T>
+class GetFloatStatusKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Operator get_float_status is not supported on CPU"));
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+using CPU = paddle::platform::CPUDeviceContext;
+
+REGISTER_OPERATOR(
+    get_float_status, ops::GetFloatStatusOp, ops::GetFloatStatusMaker,
+    paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
+    paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
+
+REGISTER_OP_CPU_KERNEL(get_float_status,
+                       ops::GetFloatStatusKernel<CPU, float>);
diff --git a/paddle/fluid/operators/amp/get_float_status_op_npu.cc b/paddle/fluid/operators/amp/get_float_status_op_npu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..f9c37137f32fae2a0317bacf938cfd9c2ffcc7d9
--- /dev/null
+++ b/paddle/fluid/operators/amp/get_float_status_op_npu.cc
@@ -0,0 +1,54 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+
+namespace paddle {
+namespace operators {
+
+using Tensor = framework::Tensor;
+
+template <typename DeviceContext, typename T>
+class GetFloatStatusKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    const auto* float_status = ctx.Input<Tensor>("FloatStatus");
+    auto* float_status_out = ctx.Output<Tensor>("FloatStatusOut");
+    // NPUGetFloatStatus modifies the input in-place.
+    PADDLE_ENFORCE_EQ(float_status_out, float_status,
+                      platform::errors::PreconditionNotMet(
+                          "The input(FloatStatus) and Output(FloatStatusOut) "
+                          "should be the same."));
+    Tensor tmp;
+    tmp.mutable_data<float>({8}, ctx.GetPlace());
+    auto stream =
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
+    // NPUGetFloatStatus updates data on its input in-place;
+    // tmp is only a placeholder.
+    NpuOpRunner("NPUGetFloatStatus", {*float_status}, {tmp}).Run(stream);
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_NPU_KERNEL(
+    get_float_status,
+    ops::GetFloatStatusKernel<paddle::platform::NPUDeviceContext, float>);
diff --git a/paddle/fluid/pybind/op_function_generator.cc b/paddle/fluid/pybind/op_function_generator.cc
index a706cc49f5cf6edaab277d4be92d7bd5d7454bc7..fb6ec0b32d9916c938f8236ed101fe3afb83000d 100644
--- a/paddle/fluid/pybind/op_function_generator.cc
+++ b/paddle/fluid/pybind/op_function_generator.cc
@@ -157,6 +157,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
      {"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
     {"rnn", {"DropoutState"}},
     {"run_program", {"Out", "DOut", "OutScope"}},
+    {"clear_float_status", {"FloatStatusOut"}},
+    {"get_float_status", {"FloatStatusOut"}},
 };
 
 // NOTE(pangyoki): Tensor View Strategy.
diff --git a/python/paddle/fluid/tests/unittests/npu/CMakeLists.txt b/python/paddle/fluid/tests/unittests/npu/CMakeLists.txt
index 4ab9262f248a21d454f9d32520fe084553cadfe7..c8fe050a5dc6b38cdd91387c1a5b46b40ced8b7e 100644
--- a/python/paddle/fluid/tests/unittests/npu/CMakeLists.txt
+++ b/python/paddle/fluid/tests/unittests/npu/CMakeLists.txt
@@ -5,4 +5,12 @@ if (WITH_ASCEND_CL)
   foreach(TEST_OP ${TEST_OPS})
     py_test_modules(${TEST_OP} MODULES ${TEST_OP})
   endforeach(TEST_OP)
+
+  # NOTE: NPU `get_float_status` reads the value from a hardware register that
+  # can be overwritten by any other program running on the same card. To keep
+  # nan/inf produced by other unittests from interfering, the unittests
+  # related to `float_status` must run exclusively.
+  set_tests_properties(test_amp_check_finite_and_scale_op_npu PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
+  set_tests_properties(test_flags_check_nan_inf_npu PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
+  set_tests_properties(test_float_status_op_npu PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
 endif()
diff --git a/python/paddle/fluid/tests/unittests/npu/test_flags_check_nan_inf_npu.py b/python/paddle/fluid/tests/unittests/npu/test_flags_check_nan_inf_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..66c39062dc7ed64dee80cf2c8af7264293a90f35
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_flags_check_nan_inf_npu.py
@@ -0,0 +1,86 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle.static as static
+import paddle.fluid as fluid
+from paddle.static import Program, program_guard
+
+paddle.enable_static()
+
+
+class TestCheckFiniteAndUnscale(unittest.TestCase):
+    def setUp(self):
+        fluid.set_flags({'FLAGS_check_nan_inf': True})
+
+    def get_prog(self):
+        main_program = Program()
+        with program_guard(main_program):
+            a = static.data(name="a", shape=[32, 32], dtype='float32')
+            b = static.data(name="b", shape=[32, 32], dtype='float32')
+            out = a / b
+            fp16_a = a.cast(paddle.float16)
+            fp16_b = b.cast(paddle.float16)
+            out = fp16_a + fp16_b
+        return main_program, out
+
+    def run_prog(self, a, b):
+        main_program, out = self.get_prog()
+        place = paddle.set_device('npu')
+
+        exe = static.Executor(place)
+        out_ = exe.run(main_program, feed={"a": a, "b": b}, fetch_list=[out])
+        return out_
+
+    def test_contains_nan(self):
+        a = np.zeros((32, 32)).astype('float32')
+        b = np.zeros((32, 32)).astype('float32')
+
+        with self.assertRaisesRegex(RuntimeError, "contains Nan/Inf"):
+            out = self.run_prog(a, b)
+            print(out)
+
+    def test_contains_inf(self):
+        a = np.ones((32, 32)).astype('float32')
+        b = np.zeros((32, 32)).astype('float32')
+
+        with self.assertRaisesRegex(RuntimeError, "contains Nan/Inf"):
+            out = self.run_prog(a, b)
+            print(out)
+
+    def test_not_contains_nan_inf(self):
+        a = np.ones((32, 32)).astype('float32')
+        b = np.ones((32, 32)).astype('float32')
+
+        out = self.run_prog(a, b)
+        print(out)
+
+    def test_fp16_overflow(self):
+        a = np.ones((32, 32)).astype('float32')
+        b = np.ones((32, 32)).astype('float32')
+        a[0][0] = 50000
+        b[0][0] = 50000
+
+        with self.assertRaisesRegex(RuntimeError, "contains Nan/Inf"):
+            out = self.run_prog(a, b)
+            print(out)
+
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/npu/test_float_status_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_float_status_op_npu.py
new file mode 100644
index 0000000000000000000000000000000000000000..206641dab5c13a99f7aa3503a73d38994c5b48b1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_float_status_op_npu.py
@@ -0,0 +1,103 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest, skip_check_grad_ci
+import paddle
+import paddle._C_ops as ops
+
+
+class TestGetFloatStatusOp(unittest.TestCase):
+    def setUp(self):
+        device = paddle.set_device('npu')
+
+    def run_prog(self, a, b):
+        a = paddle.to_tensor(a)
+        b = paddle.to_tensor(b)
+
+        flag = ops.alloc_float_status()
+        ops.clear_float_status(flag, flag)
+
+        out = a / b
+        ops.get_float_status(flag, flag)
+        return out.numpy(), flag.numpy()
+
+    def test_contains_nan(self):
+        a = np.zeros((32, 32)).astype('float32')
+        b = np.zeros((32, 32)).astype('float32')
+
+        out, flag = self.run_prog(a, b)
+        print(out, flag)
+        self.assertGreaterEqual(np.sum(flag), 1.0)
+
+    def test_contains_inf(self):
+        a = np.ones((32, 32)).astype('float32')
+        b = np.zeros((32, 32)).astype('float32')
+
+        out, flag = self.run_prog(a, b)
+        print(out, flag)
+        self.assertGreaterEqual(np.sum(flag), 1.0)
+
+    def test_not_contains_nan_inf(self):
+        a = np.ones((32, 32)).astype('float32')
+        b = np.ones((32, 32)).astype('float32')
+
+        out, flag = self.run_prog(a, b)
+        print(out, flag)
+        self.assertLess(np.sum(flag), 1.0)
+
+
+class TestClearFloatStatusOp(unittest.TestCase):
+    def setUp(self):
+        device = paddle.set_device('npu')
+
+    def run_prog(self, a, b):
+        a = paddle.to_tensor(a)
+        b = paddle.to_tensor(b)
+
+        flag = ops.alloc_float_status()
+        ops.clear_float_status(flag, flag)
+
+        out = a / b
+        ops.get_float_status(flag, flag)
+
+        ops.clear_float_status(flag, flag)
+        out = a + b
+        ops.get_float_status(flag, flag)
+        return out.numpy(), flag.numpy()
+
+    def test_not_contains_nan_inf(self):
+        a = np.ones((32, 32)).astype('float32')
+        b = np.zeros((32, 32)).astype('float32')
+
+        out, flag = self.run_prog(a, b)
+        print(out, flag)
+        self.assertLess(np.sum(flag), 1.0)
+
+    def test_fp16_overflow(self):
+        a = np.ones((32, 32)).astype('float16')
+        b = np.zeros((32, 32)).astype('float16')
+        a[0][0] = 50000
+        b[0][0] = 50000
+
+        out, flag = self.run_prog(a, b)
+        print(out, flag)
+        self.assertGreaterEqual(np.sum(flag), 1.0)
+
+
+if __name__ == '__main__':
+    unittest.main()
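
Note for reviewers: no new user-facing API is needed to enable the end-to-end checker; the pre-existing `FLAGS_check_nan_inf` flag now covers NPU as well. A condensed sketch of the static-graph flow that `test_flags_check_nan_inf_npu.py` above exercises (NPU build and device assumed):

    import numpy as np
    import paddle
    import paddle.static as static

    paddle.enable_static()
    paddle.fluid.set_flags({'FLAGS_check_nan_inf': True})

    main = static.Program()
    with static.program_guard(main):
        x = static.data(name='x', shape=[4, 4], dtype='float32')
        y = static.data(name='y', shape=[4, 4], dtype='float32')
        out = x / y  # feeding y = 0 produces inf on the device

    exe = static.Executor(paddle.set_device('npu'))
    try:
        exe.run(main,
                feed={'x': np.ones((4, 4), 'float32'),
                      'y': np.zeros((4, 4), 'float32')},
                fetch_list=[out])
    except RuntimeError as e:
        print(e)  # "Operator ... contains Nan/Inf."
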