未验证 提交 c727ec4a 编写于 作者: W WangXi 提交者: GitHub

[NPU] add get_float_status op and refine NPU check_nan_inf (#35274)

上级 1635c02b
......@@ -64,7 +64,11 @@ elseif(WITH_ROCM)
hip_library(broadcast_op_handle SRCS broadcast_op_handle.cc DEPS op_handle_base scope ddim memory variable_visitor dynload_cuda)
hip_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
else()
cc_library(nan_inf_utils SRCS nan_inf_utils_detail.cc DEPS framework_proto scope place)
if (WITH_ASCEND_CL)
cc_library(nan_inf_utils SRCS nan_inf_utils_detail.cc DEPS npu_op_runner framework_proto scope place)
else()
cc_library(nan_inf_utils SRCS nan_inf_utils_detail.cc DEPS framework_proto scope place)
endif()
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
......
......@@ -53,6 +53,12 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
}
}
#ifdef PADDLE_WITH_ASCEND_CL
// Prepares the shared NPU float-status tensor before an operator runs so that
// the nan/inf check performed afterwards can read overflow information for
// that operator. Declared only when built with PADDLE_WITH_ASCEND_CL.
//   op    - operator about to run; skipped operators are ignored.
//   scope - scope the operator runs in.
//   place - execution place; nothing is done unless it is an NPU place.
void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
const framework::Scope& scope,
const platform::Place& place);
#endif
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -15,6 +15,11 @@
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/operators/npu_op_runner.h"
#endif
namespace paddle {
namespace framework {
namespace details {
......@@ -123,8 +128,10 @@ static void InitWhiteListFormEnv() {
template <typename T>
static void PrintNanInf(const T* value, const size_t numel, int print_num,
const std::string& op_type,
const std::string& var_name) {
const std::string& op_type, const std::string& var_name,
bool abort = true) {
T min_value = std::numeric_limits<T>::max();
T max_value = std::numeric_limits<T>::min();
size_t nan_count, inf_count, num_count;
nan_count = inf_count = num_count = 0;
......@@ -137,6 +144,8 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
count = inf_count++;
} else {
count = num_count++;
min_value = std::min(min_value, value[i]);
max_value = std::max(max_value, value[i]);
}
if (count < static_cast<size_t>(print_num)) {
......@@ -144,12 +153,17 @@ static void PrintNanInf(const T* value, const size_t numel, int print_num,
static_cast<uint64_t>(i), static_cast<float>(value[i]));
}
}
printf("In cpu, there has %lu,%lu,%lu nan,inf,num\n",
static_cast<uint64_t>(nan_count), static_cast<uint64_t>(inf_count),
static_cast<uint64_t>(num_count));
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
op_type));
printf(
"In cpu, there has %lu,%lu,%lu nan,inf,num. "
"And in num, min_value is %f, max_value is %f\n",
static_cast<uint64_t>(nan_count), static_cast<uint64_t>(inf_count),
static_cast<uint64_t>(num_count), static_cast<double>(min_value),
static_cast<double>(max_value));
if (abort) {
PADDLE_THROW(platform::errors::PreconditionNotMet(
"There are `nan` or `inf` in tensor (%s) of operator (%s).", var_name,
op_type));
}
}
// openmp 4.0, reduction with fp16
......@@ -415,6 +429,134 @@ bool IsSkipOp(const framework::OperatorBase& op) {
return false;
}
#ifdef PADDLE_WITH_ASCEND_CL
using NpuOpRunner = paddle::operators::NpuOpRunner;
constexpr int FLOAT_STATUS_SIZE = 8;
// Accessor for the single float-status tensor shared by the NPU nan/inf
// helpers in this file. The tensor is a function-local static, so every call
// returns a reference to the same object.
static framework::Tensor& npu_float_status() {
  static framework::Tensor status_tensor;
  return status_tensor;
}
// Allocates the shared NPU float-status tensor and resets it, so that a
// subsequent NPUGetFloatStatus reflects only what happened after this call.
//
//   op    - operator about to run; nothing is done if IsSkipOp(op) is true.
//   scope - unused here; kept for a signature parallel to the check routine.
//   place - execution place; nothing is done unless it is an NPU place.
void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
                                 const framework::Scope& scope,
                                 const platform::Place& place) {
  if (!platform::is_npu_place(place)) return;

  // The skip list is populated lazily from the environment exactly once.
  std::call_once(white_list_init_flag, InitWhiteListFormEnv);
  if (IsSkipOp(op)) return;

  // The pool returns the generic device context; for an NPU place it is
  // expected to be an NPUDeviceContext. A downcast within a class hierarchy
  // must use static_cast: reinterpret_cast performs no pointer adjustment and
  // is only correct by accident of object layout.
  auto* dev_ctx = static_cast<platform::NPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  auto stream = dev_ctx->stream();

  auto& flag = npu_float_status();
  flag.mutable_data<float>({FLOAT_STATUS_SIZE}, place);
  NpuOpRunner("NPUAllocFloatStatus", {}, {flag}).Run(stream);

  // NPUClearFloatStatus takes a scratch input tensor and writes to `flag`.
  framework::Tensor tmp;
  tmp.mutable_data<float>({FLOAT_STATUS_SIZE}, place);
  NpuOpRunner("NPUClearFloatStatus", {tmp}, {flag}).Run(stream);
}
// Copies one variable of an operator to host memory and prints a
// nan/inf/number summary of its values via PrintNanInf, without aborting.
//
//   op_type  - type name of the operator the variable belongs to.
//   var_name - name of the variable (used in log output only).
//   var      - variable to inspect; only LoDTensor and SelectedRows holding
//              FP32 or FP16 data with non-zero memory size are printed.
//   place    - not used by this function.
void PrintNpuVarInfo(const std::string& op_type, const std::string& var_name,
                     const framework::Variable* var,
                     const platform::Place& place) {
  // Locate the dense tensor backing the variable, if it has one.
  const Tensor* src = nullptr;
  if (var->IsType<framework::LoDTensor>()) {
    src = &var->Get<framework::LoDTensor>();
  } else if (var->IsType<framework::SelectedRows>()) {
    src = &var->Get<framework::SelectedRows>().value();
  } else {
    VLOG(10) << var_name << " var_name need not to check";
    return;
  }

  // Only floating-point tensors are of interest here.
  const auto dtype = src->type();
  const bool is_fp32 = (dtype == proto::VarType::FP32);
  const bool is_fp16 = (dtype == proto::VarType::FP16);
  if (!is_fp32 && !is_fp16) {
    return;
  }

  if (src->memory_size() == 0) {
    VLOG(10) << var_name << " var_name need not to check, size == 0";
    return;
  }

  VLOG(10) << "begin check " << op_type << " var_name:" << var_name
           << ", place:" << src->place() << ", numel:" << src->numel();

  // Bring the data to the host so the values can be scanned on the CPU.
  const platform::CPUPlace host_place;
  framework::Tensor host_copy;
  host_copy.Resize(src->dims());
  host_copy.mutable_data(host_place, dtype);
  framework::TensorCopySync(*src, host_place, &host_copy);

  LOG(WARNING) << "print [" << var_name << "] tensor info:";

  // Number of sample values printed per category; may become configurable
  // through an environment variable in the future (-1 = print all).
  const int sample_count = 3;
  if (is_fp32) {
    const float* host_values = host_copy.data<float>();
    PrintNanInf(host_values, src->numel(), sample_count, op_type, var_name,
                false);
  } else {
    const paddle::platform::float16* host_values =
        host_copy.data<paddle::platform::float16>();
    PrintNanInf(host_values, src->numel(), sample_count, op_type, var_name,
                false);
  }
}
// Logs a warning that `op` produced nan/inf and then prints value summaries
// for each of its input variables followed by each of its output variables
// (intermediate outputs included). Variables not found in `scope` are
// silently skipped.
void PrintNPUOpValueInfo(const framework::OperatorBase& op,
                         const framework::Scope& scope,
                         const platform::Place& place) {
  LOG(WARNING) << "There are `nan` or `inf` in operator (" << op.Type()
               << "), here we print some tensor value info of this op.";

  // Shared per-variable handling for both the input and the output lists.
  const auto report_var = [&](const std::string& name) {
    auto* found = scope.FindVar(name);
    if (found != nullptr) {
      PrintNpuVarInfo(op.Type(), name, found, place);
    }
  };

  for (const auto& name : op.InputVars()) {
    report_var(name);
  }
  for (const auto& name : op.OutputVars(true)) {
    report_var(name);
  }
}
// Reads the NPU float status after `op` has run and raises
// PreconditionNotMet if it reports an overflow (nan/inf).
//
// Before raising, value summaries of the operator's inputs and outputs are
// printed to help locate the offending tensor.
//
//   op    - operator that has just run.
//   scope - scope holding the operator's variables.
//   place - execution place; nothing is done unless it is an NPU place.
static void NPUCheckOpHasNanOrInf(const framework::OperatorBase& op,
                                  const framework::Scope& scope,
                                  const platform::Place& place) {
  if (!platform::is_npu_place(place)) return;

  // Downcast within the device-context hierarchy: static_cast is the correct
  // conversion here, reinterpret_cast performs no pointer adjustment.
  auto* dev_ctx = static_cast<platform::NPUDeviceContext*>(
      platform::DeviceContextPool::Instance().Get(place));
  auto stream = dev_ctx->stream();

  auto& flag = npu_float_status();
  Tensor tmp;
  tmp.mutable_data<float>({FLOAT_STATUS_SIZE}, place);
  // NPUGetFloatStatus updates data on input in-place.
  // tmp is only placeholder.
  NpuOpRunner("NPUGetFloatStatus", {flag}, {tmp}).Run(stream);

  framework::Tensor cpu_tensor;
  auto cpu_place = platform::CPUPlace();
  cpu_tensor.mutable_data<float>({FLOAT_STATUS_SIZE}, cpu_place);
  framework::TensorCopySync(flag, cpu_place, &cpu_tensor);
  // Fetch the host pointer only after the copy so it is valid even if the
  // copy had to (re)allocate the destination buffer.
  const float* cpu_data = cpu_tensor.data<float>();

  float sum = 0.0f;
  for (int i = 0; i < FLOAT_STATUS_SIZE; ++i) {
    sum += cpu_data[i];
  }

  // Use the exact negation of the enforced condition so that diagnostics are
  // printed in every case where the enforce below fails (including NaN).
  if (!(sum < 1.0)) PrintNPUOpValueInfo(op, scope, place);
  PADDLE_ENFORCE_LT(
      sum, 1.0, platform::errors::PreconditionNotMet(
                    "Operator %s contains Nan/Inf.", op.DebugStringEx(&scope)));
}
#endif
void CheckOpHasNanOrInf(const framework::OperatorBase& op,
const framework::Scope& exec_scope,
const platform::Place& place) {
......@@ -422,6 +564,13 @@ void CheckOpHasNanOrInf(const framework::OperatorBase& op,
if (IsSkipOp(op)) return;
#ifdef PADDLE_WITH_ASCEND_CL
if (platform::is_npu_place(place)) {
NPUCheckOpHasNanOrInf(op, exec_scope, place);
return;
}
#endif
if (op_var_nan_inf_white_list().count(op.Type()) == 0) {
// NOTE. vname may destruct in the end of this func.
for (auto& vname : op.OutputVars(true)) {
......
......@@ -1105,6 +1105,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
auto* dev_ctx = pool.Get(place);
#ifdef PADDLE_WITH_ASCEND_CL
// NOTE(wangxi): nan/inf cannot be detected on NPU by checking the variable
// values, but only through the special `float_status`, which records whether
// an operation has overflowed. For more about `float_status`, see:
// https://gitee.com/ascend/modelzoo/issues/I3NF8V?from=project-issue
if (FLAGS_check_nan_inf) {
framework::details::NPUAllocAndClearFloatStatus(*this, scope, place);
}
#endif
if (kernel_type_.get() == nullptr || kernel_func_.get() == nullptr) {
ChooseKernel(*runtime_ctx, scope, place);
}
......
......@@ -36,7 +36,7 @@ class ClearFloatStatusKernel : public framework::OpKernel<T> {
Tensor tmp;
tmp.mutable_data<float>({8}, ctx.GetPlace());
const auto& runner =
NpuOpRunner("NPUClearFloatStatus", {*float_status}, {tmp});
NpuOpRunner("NPUClearFloatStatus", {tmp}, {*float_status_out});
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cstring>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle {
namespace operators {
// Operator definition for `get_float_status`.
//
// Takes the float-status tensor as input "FloatStatus" and exposes it as
// output "FloatStatusOut" with the same dimensions.
class GetFloatStatusOp : public framework::OperatorWithKernel {
 public:
  using framework::OperatorWithKernel::OperatorWithKernel;

  // Validates that both the input and the output are present, then gives the
  // output the same dimensions as the input.
  void InferShape(framework::InferShapeContext* ctx) const override {
    // The input's dims are read below, so it must be checked as well;
    // otherwise a missing input fails later with a less helpful error.
    OP_INOUT_CHECK(ctx->HasInput("FloatStatus"), "Input", "FloatStatus",
                   "get_float_status");
    OP_INOUT_CHECK(ctx->HasOutput("FloatStatusOut"), "Output", "FloatStatusOut",
                   "get_float_status");
    ctx->SetOutputDim("FloatStatusOut", ctx->GetInputDim("FloatStatus"));
  }

 protected:
  // The kernel is always selected as FP32 on the place of the execution
  // context, independent of the input variable.
  framework::OpKernelType GetExpectedKernelType(
      const framework::ExecutionContext& ctx) const override {
    return framework::OpKernelType(framework::proto::VarType::FP32,
                                   ctx.GetPlace());
  }
};
// Declares the inputs, outputs and documentation string of the
// `get_float_status` operator.
class GetFloatStatusMaker : public framework::OpProtoAndCheckerMaker {
public:
// Registers one input ("FloatStatus"), one output ("FloatStatusOut") and the
// operator comment.
void Make() override {
AddInput("FloatStatus",
"(Tensor) of shape {8} that holds the float status.");
AddOutput("FloatStatusOut",
"(Tensor) of shape {8} that holds the get float status.");
AddComment(R"DOC(
Get the float status
)DOC");
}
};
// CPU kernel for `get_float_status`. The operator has no CPU implementation,
// so running this kernel always raises an Unimplemented error.
template <typename DeviceContext, typename T>
class GetFloatStatusKernel : public framework::OpKernel<T> {
 public:
  // The execution context is intentionally ignored.
  void Compute(const framework::ExecutionContext& /*ctx*/) const override {
    PADDLE_THROW(platform::errors::Unimplemented(
        "Operator get_float_status is not supported on CPU"));
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(
get_float_status, ops::GetFloatStatusOp, ops::GetFloatStatusMaker,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL(get_float_status, ops::GetFloatStatusKernel<CPU, float>);
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <cmath>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// NPU kernel for `get_float_status`.
//
// Runs the NPUGetFloatStatus operator on the "FloatStatus" input. The input
// and the "FloatStatusOut" output must be the very same tensor object; the
// kernel raises PreconditionNotMet otherwise.
template <typename DeviceContext, typename T>
class GetFloatStatusKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* status_in = ctx.Input<framework::Tensor>("FloatStatus");
    auto* status_out = ctx.Output<framework::Tensor>("FloatStatusOut");

    // NPUGetFloatStatus writes its result into its input tensor, so the
    // output can only be meaningful if it aliases the input.
    PADDLE_ENFORCE_EQ(status_out, status_in,
                      platform::errors::PreconditionNotMet(
                          "The input(FloatStatus) and Output(FloatStatusOut) "
                          "should be the same."));

    const auto& npu_ctx =
        ctx.template device_context<paddle::platform::NPUDeviceContext>();
    auto stream = npu_ctx.stream();

    // The runner requires an output tensor; this one is only a placeholder
    // and its contents are never read.
    Tensor placeholder;
    placeholder.mutable_data<float>({8}, ctx.GetPlace());
    NpuOpRunner("NPUGetFloatStatus", {*status_in}, {placeholder}).Run(stream);
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
get_float_status,
ops::GetFloatStatusKernel<paddle::platform::NPUDeviceContext, float>);
......@@ -157,6 +157,8 @@ std::map<std::string, std::set<std::string>> op_passing_outs_map = {
{"ParamOut", "Moment1Out", "Moment2Out", "Beta1PowOut", "Beta2PowOut"}},
{"rnn", {"DropoutState"}},
{"run_program", {"Out", "DOut", "OutScope"}},
{"clear_float_status", {"FloatStatusOut"}},
{"get_float_status", {"FloatStatusOut"}},
};
// NOTE(pangyoki): Tensor View Strategy.
......
......@@ -5,4 +5,12 @@ if (WITH_ASCEND_CL)
foreach(TEST_OP ${TEST_OPS})
py_test_modules(${TEST_OP} MODULES ${TEST_OP})
endforeach(TEST_OP)
# NOTE: NPU `get_float_status` reads the value from a register. During testing,
# it was found that this register can be overwritten by any program running on
# the card. To prevent nan/inf produced by other unittests from interfering,
# the unittests related to `float_status` must be run exclusively.
set_tests_properties(test_amp_check_finite_and_scale_op_npu PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_flags_check_nan_inf_npu PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
set_tests_properties(test_float_status_op_npu PROPERTIES LABELS "RUN_TYPE=EXCLUSIVE")
endif()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle.static as static
import paddle.fluid as fluid
from paddle.static import Program, program_guard
paddle.enable_static()
class TestCheckFiniteAndUnscale(unittest.TestCase):
    """Checks that FLAGS_check_nan_inf reports nan/inf for a static program on NPU."""

    def setUp(self):
        # Enable the framework-level nan/inf check for every test in this class.
        fluid.set_flags({'FLAGS_check_nan_inf': True})

    def get_prog(self):
        """Build a program with an fp32 division and an fp16 addition.

        Returns the program and the fp16 sum variable that is fetched.
        """
        prog = Program()
        with program_guard(prog):
            a = static.data(name="a", shape=[32, 32], dtype='float32')
            b = static.data(name="b", shape=[32, 32], dtype='float32')
            # The quotient itself is never fetched; the division is built only
            # so that it is part of the program being run.
            out = a / b

            half_a = a.cast(paddle.float16)
            half_b = b.cast(paddle.float16)
            out = half_a + half_b
        return prog, out

    def run_prog(self, a, b):
        """Run the program on NPU with the given feeds and return the fetch."""
        prog, fetch_var = self.get_prog()
        executor = static.Executor(paddle.set_device('npu'))
        fetched = executor.run(prog,
                               feed={"a": a,
                                     "b": b},
                               fetch_list=[fetch_var])
        return fetched

    def test_contains_nan(self):
        # 0 / 0
        a = np.zeros((32, 32), dtype='float32')
        b = np.zeros((32, 32), dtype='float32')

        with self.assertRaisesRegex(RuntimeError, "contains Nan/Inf"):
            out = self.run_prog(a, b)
            print(out)

    def test_contains_inf(self):
        # 1 / 0
        a = np.ones((32, 32), dtype='float32')
        b = np.zeros((32, 32), dtype='float32')

        with self.assertRaisesRegex(RuntimeError, "contains Nan/Inf"):
            out = self.run_prog(a, b)
            print(out)

    def test_not_contains_nan_inf(self):
        # 1 / 1 and 1 + 1: must run without raising.
        a = np.ones((32, 32), dtype='float32')
        b = np.ones((32, 32), dtype='float32')

        out = self.run_prog(a, b)
        print(out)

    def test_fp16_overflow(self):
        # 50000 + 50000 is expected to overflow in the float16 addition.
        a = np.ones((32, 32), dtype='float32')
        b = np.ones((32, 32), dtype='float32')
        a[0, 0] = 50000
        b[0, 0] = 50000

        with self.assertRaisesRegex(RuntimeError, "contains Nan/Inf"):
            out = self.run_prog(a, b)
            print(out)
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import sys
sys.path.append("..")
from op_test import OpTest, skip_check_grad_ci
import paddle
import paddle._C_ops as ops
class TestGetFloatStatusOp(unittest.TestCase):
    """Checks get_float_status after a single division on NPU."""

    def setUp(self):
        paddle.set_device('npu')

    def run_prog(self, a, b):
        """Divide a by b between a status clear and a status read.

        Returns the quotient and the float-status values as numpy arrays.
        """
        lhs = paddle.to_tensor(a)
        rhs = paddle.to_tensor(b)

        status = ops.alloc_float_status()
        ops.clear_float_status(status, status)

        quotient = lhs / rhs
        ops.get_float_status(status, status)
        return quotient.numpy(), status.numpy()

    def test_contains_nan(self):
        # 0 / 0
        zeros_a = np.zeros((32, 32), dtype='float32')
        zeros_b = np.zeros((32, 32), dtype='float32')

        out, flag = self.run_prog(zeros_a, zeros_b)
        print(out, flag)
        self.assertGreaterEqual(np.sum(flag), 1.0)

    def test_contains_inf(self):
        # 1 / 0
        ones_a = np.ones((32, 32), dtype='float32')
        zeros_b = np.zeros((32, 32), dtype='float32')

        out, flag = self.run_prog(ones_a, zeros_b)
        print(out, flag)
        self.assertGreaterEqual(np.sum(flag), 1.0)

    def test_not_contains_nan_inf(self):
        # 1 / 1
        ones_a = np.ones((32, 32), dtype='float32')
        ones_b = np.ones((32, 32), dtype='float32')

        out, flag = self.run_prog(ones_a, ones_b)
        print(out, flag)
        self.assertLess(np.sum(flag), 1.0)
class TestClearFloatStatusOp(unittest.TestCase):
    """Checks that clear_float_status resets the status between operations on NPU."""

    def setUp(self):
        paddle.set_device('npu')

    def run_prog(self, a, b):
        """Divide, read and clear the status, then add and read it again.

        Returns the sum and the final float-status values as numpy arrays.
        """
        lhs = paddle.to_tensor(a)
        rhs = paddle.to_tensor(b)

        status = ops.alloc_float_status()
        ops.clear_float_status(status, status)

        # First operation: its status is read and then discarded by the clear.
        out = lhs / rhs
        ops.get_float_status(status, status)
        ops.clear_float_status(status, status)

        # Second operation: only this one should be reflected in the result.
        out = lhs + rhs
        ops.get_float_status(status, status)
        return out.numpy(), status.numpy()

    def test_not_contains_nan_inf(self):
        # 1 / 0 happens before the clear; 1 + 0 after it is finite.
        a = np.ones((32, 32), dtype='float32')
        b = np.zeros((32, 32), dtype='float32')

        out, flag = self.run_prog(a, b)
        print(out, flag)
        self.assertLess(np.sum(flag), 1.0)

    def test_fp16_overflow(self):
        # 50000 + 50000 after the clear is expected to overflow in float16.
        a = np.ones((32, 32), dtype='float16')
        b = np.zeros((32, 32), dtype='float16')
        a[0, 0] = 50000
        b[0, 0] = 50000

        out, flag = self.run_prog(a, b)
        print(out, flag)
        self.assertGreaterEqual(np.sum(flag), 1.0)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册