Unverified commit 660f781b authored by niuliling123, committed by GitHub

Print the forward's stack when backward op has nan/inf and FLAGS_check_nan_inf_level = 0 (#52639)

Parent 7b5065ab
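For orientation: when FLAGS_check_nan_inf is on, the forward API call records its Python stack, the generated node-creation code stores that stack on the grad node, and if the corresponding backward op later produces nan/inf the stored stack is printed before the error is re-raised. The condensed sketch below mirrors the new test script added at the end of this diff (device placement omitted for brevity); it is illustrative only, not part of the commit.

import paddle

# The two flags below are the ones this change keys on.
paddle.set_flags({"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 0})

x = paddle.to_tensor([1.0, 0.0, 3.0], stop_gradient=False)
y = paddle.to_tensor([0.2, 0.0, 0.5])
z = paddle.pow(x, y)           # forward succeeds; its Python stack is recorded
paddle.autograd.backward([z])  # pow_grad produces nan/inf; the forward stack is printed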
......@@ -55,9 +55,12 @@ paddle::Tensor add_n_ad_func(const std::vector<paddle::Tensor>& x) {
VLOG(3) << "Final State Running: "
<< "add_n_ad_func";
auto api_result = paddle::experimental::add_n(x);
std::string forward_trace = "";
// Check NaN and Inf if needed
if (FLAGS_check_nan_inf) {
egr::CheckTensorHasNanOrInf("add_n", api_result);
forward_trace = egr::Controller::Instance().GetPythonStack();
}
// Get Outputs
......@@ -83,6 +86,12 @@ paddle::Tensor add_n_ad_func(const std::vector<paddle::Tensor>& x) {
// Node Construction
auto grad_node =
std::shared_ptr<AddNGradNodeFinal>(new AddNGradNodeFinal(1, 1));
// Set forward's stack
if (FLAGS_check_nan_inf) {
grad_node->SetForwardTrace(forward_trace);
}
// SetAttributes if needed
// Set TensorWrappers for Forward Inputs if needed
......
......@@ -110,9 +110,11 @@ paddle::Tensor conv2d_ad_func(const paddle::Tensor& input,
dilations,
groups,
data_format);
std::string forward_trace = "";
// Check NaN and Inf if needed
if (FLAGS_check_nan_inf) {
egr::CheckTensorHasNanOrInf("conv2d", api_result);
forward_trace = egr::Controller::Instance().GetPythonStack();
}
// Get Outputs
......@@ -138,6 +140,12 @@ paddle::Tensor conv2d_ad_func(const paddle::Tensor& input,
// Node Construction
auto grad_node =
std::shared_ptr<Conv2dGradNodeFinal>(new Conv2dGradNodeFinal(1, 2));
// Set forward's stack
if (FLAGS_check_nan_inf) {
grad_node->SetForwardTrace(forward_trace);
}
// SetAttributes if needed
grad_node->SetAttributestrides(strides);
grad_node->SetAttributepaddings(paddings);
......
......@@ -172,9 +172,11 @@ sync_batch_norm__ad_func(const paddle::Tensor& x,
data_layout,
use_global_stats,
trainable_statistics);
std::string forward_trace = "";
// Check NaN and Inf if needed
if (FLAGS_check_nan_inf) {
egr::CheckTensorHasNanOrInf("sync_batch_norm_", api_result);
forward_trace = egr::Controller::Instance().GetPythonStack();
}
// Get Outputs
......@@ -226,6 +228,12 @@ sync_batch_norm__ad_func(const paddle::Tensor& x,
// Node Construction
auto grad_node =
std::shared_ptr<SyncBatchNormGradNode>(new SyncBatchNormGradNode(6, 5));
// Set forward's stack
if (FLAGS_check_nan_inf) {
grad_node->SetForwardTrace(forward_trace);
}
egr::Controller::Instance().PushBackForceSequentialNodes(grad_node.get());
// SetAttributes if needed
grad_node->SetAttributemomentum(momentum);
......
......@@ -74,6 +74,12 @@ class Controller {
void EnableLayoutAutoTune() { tracer_->EnableLayoutAutoTune(); }
void SetPythonStack(std::string stack_str) {
tracer_->SetPythonStack(stack_str);
}
std::string GetPythonStack() { return tracer_->GetPythonStack(); }
bool HasGrad() const { return tracer_->HasGrad(); }
void SetHasGrad(bool has_grad) { tracer_->SetHasGrad(has_grad); }
std::string GenerateUniqueName(std::string key = "eager_in_tmp") {
......
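Putting the Controller, Tracer, and grad-node pieces together, the hand-off added by this commit can be mocked in a few lines of plain Python. This is a schematic only; the names mirror the C++ but nothing here is Paddle API.

import traceback

class Tracer:
    python_stack = ""  # stands in for thread_local std::string Tracer::python_stack_

class GradNode:
    def __init__(self):
        self.forward_trace = ""  # stands in for the new forward_trace_ member

    def set_forward_trace(self, trace):
        self.forward_trace = trace

def set_python_stack():
    # pybind's SetPythonStack(): capture the current Python stack
    # (the real code only does this when FLAGS_check_nan_inf is on).
    Tracer.python_stack = "".join(traceback.format_stack())

def forward_ad_func():
    # generated *_ad_func: copy the recorded stack onto the freshly built grad node
    set_python_stack()
    node = GradNode()
    node.set_forward_trace(Tracer.python_stack)
    return node

def run_backward(node):
    # generated backward: on nan/inf, print the forward stack, then re-raise
    try:
        raise ValueError("nan/inf detected")  # stand-in for CheckTensorHasNanOrInf throwing
    except ValueError:
        print(node.forward_trace)
        raise

node = forward_ad_func()
# run_backward(node) would print the captured stack and then re-raise ValueError.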
......@@ -297,6 +297,10 @@ FORWARD_BODY_TEMPLATE = """ if(require_any_grad) {{
// Node Construction
{}
// Set for forward trace
if (FLAGS_check_nan_inf) {{
{}
}}
// SetAttributes if needed
{}
// Set TensorWrappers for Forward Inputs if needed
......@@ -485,7 +489,25 @@ CHECK_BACKWARD_INPLACE_TEMPLATE = """
}}
}}"""
CHECK_NAN_AND_INF_TEMPLATE = """ if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }}
CHECK_NAN_AND_INF_TEMPLATE_FORWARD = """
std::string forward_trace ="";
if (FLAGS_check_nan_inf) {{
egr::CheckTensorHasNanOrInf("{}", {});
forward_trace = egr::Controller::Instance().GetPythonStack();
}}
"""
CHECK_NAN_AND_INF_TEMPLATE_BACKWARD = """
if (FLAGS_check_nan_inf) {{
try{{
egr::CheckTensorHasNanOrInf("{}", {});
}} catch(...) {{
LOG(WARNING) << "There are nan/inf in ({})";
auto forward_trace = GetForwardTrace();
std::cout<<forward_trace<<std::endl;
std::rethrow_exception(std::current_exception());
}}
}}
"""
inplace_optional_out_type_map = {
......@@ -1048,11 +1070,15 @@ class DygraphFunctionGeneratorBase(FunctionGeneratorBase):
node_event_name = forward_api_name + " node_creation"
node_creation_event_str = f"{indent}paddle::platform::RecordEvent node_creation_record_event(\"{node_event_name}\", paddle::platform::TracerEventType::OperatorInner, 1);\n"
set_forward_trace = (
f"{indent} grad_node->SetForwardTrace(forward_trace);"
)
if not for_backward:
self.node_creation_str = FORWARD_BODY_TEMPLATE.format(
node_creation_event_str,
pass_stop_gradient_args_str,
node_construction_str,
set_forward_trace,
set_attributes_str,
set_input_tensor_wrappers_str,
set_grad_out_meta_str,
......@@ -1427,7 +1453,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
)
# Check Nan and Inf
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE_FORWARD.format(
function_name, "api_result"
)
......@@ -2322,8 +2348,8 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"""
# Check Nan and Inf
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(
backward_api_name, "returns"
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE_BACKWARD.format(
backward_api_name, "returns", backward_api_name
)
# Prepare for Node Creation if Necessary
......
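An abbreviated, standalone sketch (not code from the commit) of how the two call sites above render their templates: the doubled braces become literal C++ braces after str.format, and the placeholders receive the op name and the tensors to check. "conv2d" / "conv2d_grad" are example op names, and the template text is shortened to the nan/inf portions; note the backward call site passes the op name twice, once for the check and once for the warning message.

FORWARD_TPL = 'if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); forward_trace = egr::Controller::Instance().GetPythonStack(); }}'
BACKWARD_TPL = 'try {{ egr::CheckTensorHasNanOrInf("{}", {}); }} catch (...) {{ LOG(WARNING) << "There are nan/inf in ({})"; }}'

print(FORWARD_TPL.format("conv2d", "api_result"))
print(BACKWARD_TPL.format("conv2d_grad", "returns", "conv2d_grad"))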
......@@ -121,7 +121,10 @@ static PyObject * eager_api_{}(PyObject *self, PyObject *args, PyObject *kwargs)
NOAMP_DYGRAPH_FUNCTION_TEMPLATE = "decltype({}({})) out = {}({});"
FUNCTION_SET_DEVICE_TEMPLATE = """{} if (paddle::platform::is_gpu_place(place)) {{
FUNCTION_SET_DEVICE_TEMPLATE = """{}
LOG(INFO)<<"this is SetPythonStack";
SetPythonStack();
if (paddle::platform::is_gpu_place(place)) {{
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
phi::backends::gpu::SetDeviceId(place.device);
VLOG(4) <<"CurrentDeviceId: " << phi::backends::gpu::GetCurrentDeviceId() << " from " << (int)place.device;
......@@ -170,7 +173,6 @@ PYTHON_C_WRAPPER_TEMPLATE = """
#include "paddle/fluid/pybind/eager.h"
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
namespace paddle {{
namespace pybind {{
......
......@@ -292,6 +292,10 @@ class GradNodeBase {
is_tensor_wrappers_cleared_ = is_tensor_wrappers_cleared;
}
void SetForwardTrace(std::string trace) { forward_trace_ = trace; }
std::string GetForwardTrace() { return forward_trace_; }
private:
// bwd_out_meta_ is used to record Grad output info for backward
paddle::small_vector<std::vector<GradSlotMeta>, kSlotSmallVectorSize>
......@@ -317,6 +321,8 @@ class GradNodeBase {
bool need_complex_to_real_ = false;
bool is_tensor_wrappers_cleared_ = false;
// The trace of forward function
std::string forward_trace_ = "";
};
} // namespace egr
......@@ -37,6 +37,7 @@ DECLARE_string(tracer_mkldnn_ops_off);
namespace paddle {
namespace imperative {
thread_local std::string Tracer::python_stack_ = "";
thread_local bool Tracer::enable_program_desc_tracing_ = false;
......
......@@ -199,7 +199,8 @@ class Tracer {
use_layout_autotune_ = false;
return false;
}
void SetPythonStack(std::string stack_str) { python_stack_ = stack_str; }
std::string GetPythonStack() { return python_stack_; }
phi::KernelSignature GetExpectedKernelSignature(
const std::string& type,
const NameTensorMap& ins,
......@@ -215,6 +216,7 @@ class Tracer {
std::unique_ptr<UniqueNameGenerator> generator_;
platform::Place expected_place_;
GarbageCollectorMap gcs_;
static thread_local std::string python_stack_;
static thread_local bool enable_program_desc_tracing_;
static thread_local bool use_layout_autotune_;
static thread_local bool has_grad_;
......
......@@ -38,6 +38,7 @@ limitations under the License. */
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/core/compat/convert_utils.h"
#include "paddle/phi/core/dense_tensor.h"
DECLARE_bool(check_nan_inf);
namespace paddle {
namespace pybind {
......@@ -215,6 +216,21 @@ std::shared_ptr<imperative::VarBase> CastPyArg2VarBase(PyObject* obj,
return py::cast<std::shared_ptr<imperative::VarBase>>(obj);
}
void SetPythonStack() {
if (FLAGS_check_nan_inf) {
pybind11::gil_scoped_acquire gil;
PyObject* mod = PyImport_ImportModule("traceback");
PyObject* traceback_list = PyObject_CallMethod(mod, "format_stack", "");
std::string str = "";
for (Py_ssize_t i = 0; i < PyList_Size(traceback_list); i++) {
PyObject* line = PyList_GetItem(traceback_list, i);
str += py::str(PyUnicode_AsUTF8(line));
}
std::string last = str + egr::Controller::Instance().GetPythonStack();
egr::Controller::Instance().SetPythonStack(last);
}
}
std::shared_ptr<jit::Function> CastPyArg2JitFunction(PyObject* obj,
ssize_t arg_pos) {
if (PyObject_IsInstance(obj,
......
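The SetPythonStack helper above builds the stack string with CPython's traceback module; expressed in Python, the capture amounts to the sketch below (illustrative only, not code from the commit).

import traceback

# Rough Python equivalent of what SetPythonStack() feeds into the Controller:
# the concatenation of the frames returned by traceback.format_stack().
# Each frame ends with the source line itself, which is why the test below
# can search the output for '    z = paddle.pow(x, y)'.
stack_str = "".join(traceback.format_stack())
print(stack_str)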
......@@ -78,6 +78,7 @@ std::vector<std::string> CastPyArg2VectorOfString(PyObject* obj,
ssize_t arg_pos);
std::shared_ptr<jit::Function> CastPyArg2JitFunction(PyObject* obj,
ssize_t arg_pos);
void SetPythonStack();
PyObject* ToPyObject(int value);
PyObject* ToPyObject(uint32_t value);
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import paddle
def main():
    paddle.set_flags({"FLAGS_check_nan_inf": 1, "FLAGS_check_nan_inf_level": 0})
    cpu_place = paddle.CPUPlace()
    x = paddle.to_tensor([1, 0.0, 3], stop_gradient=False, place=cpu_place)
    y = paddle.to_tensor([0.2, 0.0, 0.5], place=cpu_place)
    z = paddle.pow(x, y)
    paddle.autograd.backward([z])


if __name__ == "__main__":
    main()
......@@ -78,6 +78,13 @@ class TestCheckSkipEnv(TestNanInf):
class TestNanInfCheckResult(unittest.TestCase):
    def setUp(self):
        self._python_interp = sys.executable
        if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
            self._python_interp += " -m coverage run --branch -p"

        self.env = os.environ.copy()

    def generate_inputs(self, shape, dtype="float32"):
        data = np.random.random(size=shape).astype(dtype)
        # [-10, 10)
......@@ -141,6 +148,25 @@ class TestNanInfCheckResult(unittest.TestCase):
        if paddle.fluid.core.is_compiled_with_cuda():
            _check_num_nan_inf(use_cuda=True)

    def test_check_stack(self):
        self._python_interp += " check_nan_inf_backward_stack.py"
        cmd = self._python_interp
        proc = subprocess.Popen(
            cmd.split(" "),
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
            env=self.env,
        )
        out, err = proc.communicate()
        returncode = proc.returncode
        print(out)
        print(err)
        # In Python 3, out + err is bytes, so compare against a bytes literal.
        assert (out + err).find(b' z = paddle.pow(x, y)') != -1

    def check_nan_inf_level(self, use_cuda, dtype):
        shape = [8, 8]
        x_np, y_np = self.generate_inputs(shape, dtype)
......