未验证 提交 1c0db09a 编写于 作者: N niuliling123 提交者: GitHub

Add enable/disable_model_check_nan_inf op (#54081)

上级 a15fec8b
......@@ -758,6 +758,28 @@
kernel :
func : triu_grad
- backward_op: disable_check_model_nan_inf_grad
forward: disable_check_model_nan_inf (Tensor x, int flag=0) -> Tensor(out)
args: (Tensor out_grad, int unsetflag = 1)
output : Tensor(x_grad)
infer_meta:
func: UnchangedInferMeta
param : [out_grad]
kernel:
func: check_model_nan_inf
data_type: out_grad
- backward_op: enable_check_model_nan_inf_grad
forward: enable_check_model_nan_inf (Tensor x, int flag=1) -> Tensor(out)
args: (Tensor out_grad, int unsetflag = 0)
output : Tensor(x_grad)
infer_meta:
func: UnchangedInferMeta
param : [out_grad]
kernel:
func: check_model_nan_inf
data_type: out_grad
- backward_op: unpool_grad
forward: unpool (Tensor x, Tensor indices, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format) -> Tensor(out)
args: (Tensor x, Tensor indices, Tensor out, Tensor out_grad, int[] ksize, int[] strides, int[] padding, IntArray output_size, str data_format)
......
......@@ -209,6 +209,17 @@
data_type : x
backward : depthwise_conv2d_transpose_grad
- op : disable_check_model_nan_inf
args: (Tensor x, int flag = 0)
output: Tensor(out)
infer_meta:
func: UnchangedInferMeta
param : [x]
kernel:
func: check_model_nan_inf
data_type: x
backward: disable_check_model_nan_inf_grad
- op : distribute_fpn_proposals
args : (Tensor fpn_rois, Tensor rois_num, int min_level, int max_level, int refer_level, int refer_scale, bool pixel_offset)
output : Tensor[](multi_fpn_rois){max_level - min_level + 1}, Tensor[](multi_level_rois_num){max_level - min_level + 1}, Tensor(restore_index)
......@@ -305,6 +316,17 @@
data_type : dtype > x
backend : place > x
- op : enable_check_model_nan_inf
args: (Tensor x, int flag = 1)
output: Tensor(out)
infer_meta:
func: UnchangedInferMeta
param : [x]
kernel:
func: check_model_nan_inf
data_type: x
backward: enable_check_model_nan_inf_grad
- op : equal
args : (Tensor x, Tensor y)
output : Tensor(out)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/impl/debug_tools_impl.h"
PD_REGISTER_KERNEL(check_model_nan_inf,
CPU,
ALL_LAYOUT,
phi::CheckModelNanInfKernel,
bool,
float,
double,
int32_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/impl/debug_tools_impl.h"
PD_REGISTER_KERNEL(check_model_nan_inf,
GPU,
ALL_LAYOUT,
phi::CheckModelNanInfKernel,
bool,
float,
double,
int32_t,
int64_t,
phi::dtype::float16,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "glog/logging.h"
#include "paddle/phi/core/flags.h"
#include "paddle/phi/kernels/cast_kernel.h"
#include "paddle/phi/kernels/funcs/math_function.h"
PHI_DECLARE_bool(check_nan_inf);
namespace phi {
template <typename T, typename Context>
void CheckModelNanInfKernel(const Context& dev_ctx,
const DenseTensor& x,
int flag,
DenseTensor* out) {
phi::CastKernel<T>(dev_ctx, x, x.dtype(), out);
VLOG(6) << "model_check_nan_inf: Change FLAGS_check_nan_inf "
<< FLAGS_check_nan_inf << " to " << flag;
FLAGS_check_nan_inf = flag;
}
} // namespace phi
......@@ -35,6 +35,7 @@ __all__ = [
"enable_tensor_checker",
"disable_tensor_checker",
"compare_accuracy",
"check_layer_numerics",
]
......@@ -60,6 +61,77 @@ class DebugMode(Enum):
# DUMP_ALL = 5
def check_layer_numerics(func):
"""
This decorator is used to check the numerical values of the layer's input and output data.
Args:
func (callable): The function to be decorated.
Returns:
callable: The decorated function.
Raises:
None.
Example:
import paddle
class MyLayer(paddle.nn.Layer):
def __init__(self, dtype):
super().__init__()
self._w = self.create_parameter([2, 3], dtype=dtype)
self._b = self.create_parameter([2, 3], dtype=dtype)
@paddle.amp.debugging.check_layer_numerics
def forward(self, x):
# return 1/x * self._w + self._b open it you will see the error log
return x * self._w + self._b
dtype = 'float32'
x = paddle.rand([10, 2, 2], dtype=dtype)
model = MyLayer(dtype)
x[0] = float(0)
loss = model(x)
adam = paddle.optimizer.Adam(parameters=model.parameters())
loss.backward()
adam.step()
#error log
#[PRECISION] [ERROR] in [device=gpu:0, op=divide, tensor=, dtype=fp32], numel=40, num_nan=0, num_inf=4, num_zero=0, max=inf, min=1.048930e+00, mean=inf
#Traceback (most recent call last):
# File "tmp.py", line 16, in <module>
# loss = model(x)
# File "/paddle/nn/layer/layers.py", line 1254, in __call__
# return self.forward(*inputs, **kwargs)
# File "/paddle/amp/debugging.py", line 116, in wrapper
# out_data = func(self, *modified_args, **kwargs)
# File "test.py", line 10, in forward
# return 1/x * self._w+ self._b
#RuntimeError: (PreconditionNotMet) There are NAN or INF (num_nan=0, num_inf=4, num_zero=0) in [device=gpu:0, op=divide, tensor=, dtype=fp32].
"""
def wrapper(self, *args, **kwargs):
if args:
# Set temp data and temp.gradient = False
start_data = args[0]
start_data.stop_gradient = False
modified_args = list(args) # Convert args to a mutable list
# Set FLAGS_check_nan_inf = 1
modified_args[0] = _C_ops.enable_check_model_nan_inf(start_data, 1)
# Call the forward function
out_data = func(self, *modified_args, **kwargs)
# Set FLAGS_check_nan_inf = 0
out = _C_ops.disable_check_model_nan_inf(out_data, 0)
return out
else:
print("No elements found in args")
out = func(self, *args, **kwargs)
return out
return wrapper
def set_checked_op_list(checked_op_list):
# check checked_op_list
if checked_op_list is not None:
......
......@@ -118,5 +118,26 @@ class TestTensorChecker(unittest.TestCase):
_assert_flag(False)
class TestCheckLayerNumerics(unittest.TestCase):
def test_layer_checker(self):
class MyLayer(paddle.nn.Layer):
def __init__(self, dtype):
super().__init__()
self._w = self.create_parameter([2, 3], dtype=dtype)
self._b = self.create_parameter([2, 3], dtype=dtype)
@paddle.amp.debugging.check_layer_numerics
def forward(self, x):
return x * self._w + self._b
dtype = 'float32'
x = paddle.rand([10, 2, 3], dtype=dtype)
model = MyLayer(dtype)
loss = model(x)
adam = paddle.optimizer.Adam(parameters=model.parameters())
loss.backward()
adam.step()
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册