From df4a978cb27d1875e689fe287dd7b29e7cc061e2 Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Thu, 10 Jun 2021 14:28:20 +0800
Subject: [PATCH] [Debug] Add nan & inf check FLAG for dygraph (#32635)

* add check nan or inf for dygraph

* add unittest for dygraph

* revert error change
---
 .../fluid/framework/details/nan_inf_utils.h   |  20 ++++
 .../framework/details/nan_inf_utils_detail.cc |  15 ++-
 paddle/fluid/imperative/CMakeLists.txt        |   2 +-
 paddle/fluid/imperative/prepared_operator.cc  |   8 ++
 .../unittests/check_nan_inf_base_dygraph.py   | 112 ++++++++++++++++++
 .../fluid/tests/unittests/test_nan_inf.py     |  11 +-
 6 files changed, 161 insertions(+), 7 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py

diff --git a/paddle/fluid/framework/details/nan_inf_utils.h b/paddle/fluid/framework/details/nan_inf_utils.h
index 4d7d9afe701..cf64ccd60f4 100644
--- a/paddle/fluid/framework/details/nan_inf_utils.h
+++ b/paddle/fluid/framework/details/nan_inf_utils.h
@@ -19,6 +19,7 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/scope.h"
+#include "paddle/fluid/imperative/type_defs.h"
 #include "paddle/fluid/platform/place.h"
 
 namespace paddle {
@@ -30,9 +31,28 @@ void CheckVarHasNanOrInf(const std::string& op_type,
                          const std::string& var_name,
                          const platform::Place& place);
 
+void CheckVarHasNanOrInf(const std::string& op_type,
+                         const std::string& var_name,
+                         const framework::Variable* var,
+                         const platform::Place& place);
+
 void CheckOpHasNanOrInf(const framework::OperatorBase& op,
                         const framework::Scope& scope,
                         const platform::Place& place);
+
+template <typename VarType>
+void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
+                                 const imperative::NameVarMap<VarType>& op_outs,
+                                 platform::Place place) {
+  for (const auto& pair : op_outs) {
+    for (const auto& ivar : pair.second) {
+      auto* var = ivar->MutableVar();
+      if (var == nullptr) continue;
+      CheckVarHasNanOrInf(op_type, ivar->Name(), var, place);
+    }
+  }
+}
+
 }  // namespace details
 }  // namespace framework
 }  // namespace paddle
diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc
index f9aa14bf7e8..30231a1799f 100644
--- a/paddle/fluid/framework/details/nan_inf_utils_detail.cc
+++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc
@@ -297,13 +297,12 @@ void tensor_check<platform::CPUDeviceContext>(const std::string& op_type,
 }
 
 void CheckVarHasNanOrInf(const std::string& op_type,
-                         const framework::Scope& scope,
                          const std::string& var_name,
+                         const framework::Variable* var,
                          const platform::Place& place) {
-  auto* var = scope.FindVar(var_name);
   PADDLE_ENFORCE_NOT_NULL(
-      var, platform::errors::NotFound("In op=%s, can't find var:%s", op_type,
-                                      var_name));
+      var, platform::errors::NotFound("Cannot find var: `%s` in op `%s`.",
+                                      var_name, op_type));
 
   const Tensor* tensor{nullptr};
   if (var->IsType<framework::LoDTensor>()) {
@@ -393,6 +392,14 @@ void CheckVarHasNanOrInf(const std::string& op_type,
   tensor_check<platform::CPUDeviceContext>(op_type, var_name, *tensor, place);
 }
 
+void CheckVarHasNanOrInf(const std::string& op_type,
+                         const framework::Scope& scope,
+                         const std::string& var_name,
+                         const platform::Place& place) {
+  auto* var = scope.FindVar(var_name);
+  CheckVarHasNanOrInf(op_type, var_name, var, place);
+}
+
 bool IsSkipOp(const framework::OperatorBase& op) {
   if (op_type_nan_inf_white_list().count(op.Type()) != 0) return true;
 
diff --git a/paddle/fluid/imperative/CMakeLists.txt b/paddle/fluid/imperative/CMakeLists.txt
index 6bee3d44b2e..c9dffe2d76a 100644
--- a/paddle/fluid/imperative/CMakeLists.txt
+++ b/paddle/fluid/imperative/CMakeLists.txt
@@ -1,6 +1,6 @@
 cc_library(imperative_flag SRCS flags.cc DEPS gflags)
-cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform)
+cc_library(prepared_operator SRCS prepared_operator.cc DEPS proto_desc operator device_context lod_tensor selected_rows var_type_traits op_kernel_type data_transform nan_inf_utils)
 cc_library(layer SRCS layer.cc DEPS prepared_operator math_function imperative_flag variable_helper op_registry)
 add_subdirectory(jit)
 cc_library(amp SRCS amp_auto_cast.cc DEPS layer )
diff --git a/paddle/fluid/imperative/prepared_operator.cc b/paddle/fluid/imperative/prepared_operator.cc
index 2a3b6424d4a..4a42751b1c4 100644
--- a/paddle/fluid/imperative/prepared_operator.cc
+++ b/paddle/fluid/imperative/prepared_operator.cc
@@ -15,8 +15,11 @@
 #include "paddle/fluid/imperative/prepared_operator.h"
 
 #include "paddle/fluid/framework/data_type_transform.h"
+#include "paddle/fluid/framework/details/nan_inf_utils.h"
 #include "paddle/fluid/imperative/infer_shape_context.h"
 
+DECLARE_bool(check_nan_inf);
+
 namespace paddle {
 namespace imperative {
 
@@ -175,6 +178,11 @@ static void PreparedOpRunImpl(
   func(DygraphExecutionContext<VarType>(op, scope, *dev_ctx, ctx, ins, outs,
                                         attrs));
 
+  if (FLAGS_check_nan_inf) {
+    framework::details::CheckOpHasNanOrInfInDygraph<VarType>(
+        op.Type(), outs, dev_ctx->GetPlace());
+  }
+
   /**
    * [ Why need handle complex gradient to real gradient? ]
    *
diff --git a/python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py b/python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py
new file mode 100644
index 00000000000..08bab306df1
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/check_nan_inf_base_dygraph.py
@@ -0,0 +1,112 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import unicode_literals
+from __future__ import print_function
+
+import os
+import sys
+import time
+import numpy as np
+
+os.environ[str("FLAGS_check_nan_inf")] = str("1")
+os.environ[str("GLOG_vmodule")] = str("nan_inf_utils_detail=10")
+
+import paddle
+import paddle.nn as nn
+
+np.random.seed(0)
+
+
+def generator():
+    batch_size = 5
+    for i in range(5):
+        curr_train_x = np.random.randint(
+            batch_size, size=(batch_size, 3)).astype("float32")
+        if i >= 2:
+            curr_train_x[0, :] = np.nan
+            curr_train_x[-1, :] = np.inf
+        res = []
+        for i in range(batch_size):
+            y = i % 3
+            res.append([y])
+        y_label = np.array(res).astype('int64')
+        yield [curr_train_x, y_label]
+
+
+class TestLayer(nn.Layer):
+    def __init__(self):
+        super(TestLayer, self).__init__()
+        self.linear1 = nn.Linear(3, 400)
+        self.linear2 = nn.Linear(400, 400)
+        self.linear3 = nn.Linear(400, 3)
+
+    def forward(self, x):
+        x = self.linear1(x)
+        x = nn.functional.sigmoid(x)
+        x = self.linear2(x)
+        x = nn.functional.sigmoid(x)
+        x = self.linear3(x)
+        x = nn.functional.softmax(x)
+
+        return x
+
+
+def check(use_cuda):
+    paddle.set_device('gpu' if use_cuda else 'cpu')
+
+    net = TestLayer()
+    sgd = paddle.optimizer.SGD(learning_rate=0.05, parameters=net.parameters())
+
+    for step, (x, y) in enumerate(generator()):
+        x = paddle.to_tensor(x)
+        y = paddle.to_tensor(y)
+
+        zero = paddle.zeros(shape=[1], dtype='int64')
+        fp16_zero = paddle.cast(zero, dtype='float16')
+
+        y = y + zero
+
+        y_pred = net(x)
+
+        cost = nn.functional.cross_entropy(y_pred, y, use_softmax=False)
+        avg_cost = paddle.mean(cost)
+
+        acc_top1 = paddle.metric.accuracy(input=y_pred, label=y, k=1)
+
+        print('iter={:.0f}, cost={}, acc1={}'.format(
+            step, avg_cost.numpy(), acc_top1.numpy()))
+
+        sgd.step()
+        sgd.clear_grad()
+
+
+if __name__ == '__main__':
+    if paddle.is_compiled_with_cuda():
+        try:
+            check(use_cuda=True)
+            assert False
+        except Exception as e:
+            print(e)
+            print(type(e))
+            # Note: an enforce failure inside a CUDA kernel may not be caught
+            # by Paddle itself, so the exception type can be RuntimeError.
+            assert type(e) == OSError or type(e) == RuntimeError
+    try:
+        check(use_cuda=False)
+        assert False
+    except Exception as e:
+        print(e)
+        print(type(e))
+        assert type(e) == RuntimeError
diff --git a/python/paddle/fluid/tests/unittests/test_nan_inf.py b/python/paddle/fluid/tests/unittests/test_nan_inf.py
index 1673002cb79..cb7e673c6ca 100644
--- a/python/paddle/fluid/tests/unittests/test_nan_inf.py
+++ b/python/paddle/fluid/tests/unittests/test_nan_inf.py
@@ -29,11 +29,10 @@ class TestNanInf(unittest.TestCase):
         self._python_interp = sys.executable
         if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
             self._python_interp += " -m coverage run --branch -p"
-        self._python_interp += " check_nan_inf_base.py"
 
         self.env = os.environ.copy()
 
-    def test_nan_inf(self):
+    def check_nan_inf(self):
         cmd = self._python_interp
 
         proc = subprocess.Popen(
@@ -53,6 +52,14 @@ class TestNanInf(unittest.TestCase):
         assert (out + err
                 ).find('There are `nan` or `inf` in tensor'.encode()) != -1
 
+    def test_nan_inf_in_static_mode(self):
+        self._python_interp += " check_nan_inf_base.py"
+        self.check_nan_inf()
+
+    def test_nan_inf_in_dynamic_mode(self):
+        self._python_interp += " check_nan_inf_base_dygraph.py"
+        self.check_nan_inf()
+
 
 class TestNanInfEnv(TestNanInf):
     def setUp(self):
-- 
GitLab
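
Usage note: with this patch applied, the dygraph NaN/Inf check is driven by the existing FLAGS_check_nan_inf flag, which the new unittest sets through the environment. The sketch below is illustrative only and not part of the patch; the tensor values and the divide op are assumptions chosen to produce an `inf`.

    import os
    # The flag must be set before `import paddle`, mirroring how
    # check_nan_inf_base_dygraph.py sets it at module import time.
    os.environ["FLAGS_check_nan_inf"] = "1"

    import paddle

    x = paddle.to_tensor([1.0, 0.0])
    # Dividing by zero yields `inf`; with the flag enabled, the output check
    # added in PreparedOpRunImpl should raise as soon as this op finishes,
    # reporting the offending op type and variable name.
    y = paddle.divide(paddle.to_tensor([1.0, 1.0]), x)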