From 8a0f611b642d06b97097aac0a5ac0a955b440b5d Mon Sep 17 00:00:00 2001 From: WangXi Date: Thu, 12 Dec 2019 16:37:20 +0800 Subject: [PATCH] Rewrite check nan inf tools (#21076) --- paddle/fluid/framework/CMakeLists.txt | 2 +- paddle/fluid/framework/details/CMakeLists.txt | 2 + .../fluid/framework/details/nan_inf_utils.h | 38 +++ .../framework/details/nan_inf_utils_detail.cc | 320 ++++++++++++++++++ .../framework/details/nan_inf_utils_detail.cu | 189 +++++++++++ .../framework/details/nan_inf_utils_detail.h | 59 ++++ paddle/fluid/framework/operator.cc | 12 +- .../tests/unittests/check_nan_inf_base.py | 116 +++++++ .../fluid/tests/unittests/test_nan_inf.py | 65 ++++ 9 files changed, 792 insertions(+), 11 deletions(-) create mode 100644 paddle/fluid/framework/details/nan_inf_utils.h create mode 100644 paddle/fluid/framework/details/nan_inf_utils_detail.cc create mode 100644 paddle/fluid/framework/details/nan_inf_utils_detail.cu create mode 100644 paddle/fluid/framework/details/nan_inf_utils_detail.h create mode 100644 python/paddle/fluid/tests/unittests/check_nan_inf_base.py create mode 100644 python/paddle/fluid/tests/unittests/test_nan_inf.py diff --git a/paddle/fluid/framework/CMakeLists.txt b/paddle/fluid/framework/CMakeLists.txt index 97c183a7d2..6aba7c685a 100644 --- a/paddle/fluid/framework/CMakeLists.txt +++ b/paddle/fluid/framework/CMakeLists.txt @@ -133,7 +133,7 @@ cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place) cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_vars_inference) cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto - shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check) + shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils) cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context) cc_test(operator_exception_test SRCS operator_exception_test.cc DEPS operator op_registry device_context) diff --git a/paddle/fluid/framework/details/CMakeLists.txt b/paddle/fluid/framework/details/CMakeLists.txt index 6789c54210..6366a2b3e5 100644 --- a/paddle/fluid/framework/details/CMakeLists.txt +++ b/paddle/fluid/framework/details/CMakeLists.txt @@ -22,6 +22,7 @@ endif() if(WITH_GPU) + nv_library(nan_inf_utils SRCS nan_inf_utils_detail.cc nan_inf_utils_detail.cu DEPS framework_proto scope place) nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory dynload_cuda variable_visitor) nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory @@ -43,6 +44,7 @@ if(WITH_GPU) nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle) else() + cc_library(nan_inf_utils SRCS nan_inf_utils_detail.cc DEPS framework_proto scope place) cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory variable_visitor) cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory diff --git a/paddle/fluid/framework/details/nan_inf_utils.h b/paddle/fluid/framework/details/nan_inf_utils.h new file mode 100644 index 0000000000..4d7d9afe70 --- /dev/null +++ b/paddle/fluid/framework/details/nan_inf_utils.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once +#include +#include +#include + +#include "paddle/fluid/framework/operator.h" +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { +namespace details { +// assert false when meets NAN or inf +void CheckVarHasNanOrInf(const std::string& op_type, + const framework::Scope& scope, + const std::string& var_name, + const platform::Place& place); + +void CheckOpHasNanOrInf(const framework::OperatorBase& op, + const framework::Scope& scope, + const platform::Place& place); +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cc b/paddle/fluid/framework/details/nan_inf_utils_detail.cc new file mode 100644 index 0000000000..956b099e88 --- /dev/null +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cc @@ -0,0 +1,320 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/nan_inf_utils.h" +#include "paddle/fluid/framework/details/nan_inf_utils_detail.h" + +#include +#include +#include +#include + +#include "paddle/fluid/framework/op_proto_maker.h" +#include "paddle/fluid/framework/selected_rows.h" + +namespace paddle { +namespace framework { +namespace details { + +static std::once_flag white_list_init_flag; + +static int op_role_nan_inf_white_list = 0; + +static constexpr int FORWARD = 0x10000; + +// lazy init +static const std::unordered_map& role_str2int() { + /* In op_proto_maker.h + * framework::OpRole::kForward = 0x0000, + * framework::OpRole::kBackward = 0x0001, + * framework::OpRole::kOptimize = 0x0002, + * framework::OpRole::kRPC = 0x0004, + * framework::OpRole::kDist = 0x0008, + * framework::OpRole::kLRSched = 0x0010, + * framework::OpRole::kLoss = 0x0100, + * framework::OpRole::kNotSpecified = 0x1000, + */ + static const std::unordered_map _role_str2int = { + {"forward", FORWARD}, /* kForward=0, can't filter */ + {"backward", static_cast(framework::OpRole::kBackward)}, + {"optimize", static_cast(framework::OpRole::kOptimize)}, + {"rpc", static_cast(framework::OpRole::kRPC)}, + {"dist", static_cast(framework::OpRole::kDist)}, + {"lrsched", static_cast(framework::OpRole::kLRSched)}, + {"loss", static_cast(framework::OpRole::kLoss)}, + {"default", static_cast(framework::OpRole::kNotSpecified)}, + }; + return _role_str2int; +} + +static std::unordered_set& op_type_nan_inf_white_list() { + static std::unordered_set _op_type_nan_inf_white_list = { + "coalesce_tensor", /* This Op will alloc tensor, and may not init space */ + }; + return _op_type_nan_inf_white_list; +} + +static std::unordered_map>& +op_var_nan_inf_white_list() { + static std::unordered_map> + _op_var_nan_inf_white_list = { + /* encoded & gather var consist of idx&val, can't judge directly */ + {"dgc", {"__dgc_encoded__", "__dgc_gather__"}}, + }; + return _op_var_nan_inf_white_list; +} + +static void InitWhiteListFormEnv() { + // op_type_skip and op_var_skip may be NULL. + // So need init static value in there, prevent thread competition. + // NOTE. role_str2int needn't do this for it only used in this func. + op_type_nan_inf_white_list(); + op_var_nan_inf_white_list(); + + // export PADDLE_INF_NAN_SKIP_OP="op0,op1,op2" + // export PADDLE_INF_NAN_SKIP_ROLE="role1,role2,role3" + // export PADDLE_INF_NAN_SKIP_VAR="op0:var0,op0:var1,op1:var0" + const char* op_type_skip = std::getenv("PADDLE_INF_NAN_SKIP_OP"); + const char* op_role_skip = std::getenv("PADDLE_INF_NAN_SKIP_ROLE"); + const char* op_var_skip = std::getenv("PADDLE_INF_NAN_SKIP_VAR"); + + if (op_type_skip != NULL) { + std::stringstream ss(op_type_skip); + std::string op_type; + while (std::getline(ss, op_type, ',')) { + op_type_nan_inf_white_list().emplace(op_type); + } + } + + if (op_role_skip != NULL) { + std::stringstream ss(op_role_skip); + std::string op_role; + while (std::getline(ss, op_role, ',')) { + PADDLE_ENFORCE_EQ(role_str2int().find(op_role) != role_str2int().end(), + true, + platform::errors::InvalidArgument( + "Skip role must be one of " + "{forward,backward,optimize,rpc,dist,lrsched,loss," + "default}, instead of %s", + op_role)); + op_role_nan_inf_white_list |= role_str2int().at(op_role); + } + } + + if (op_var_skip != NULL) { + std::stringstream ss(op_var_skip); + std::string op_var; + while (std::getline(ss, op_var, ',')) { + auto pos = op_var.find(":"); + PADDLE_ENFORCE_EQ( + pos != std::string::npos, true, + platform::errors::InvalidArgument( + "Skip var format must be op:var, instead of %s", op_var)); + std::string op = op_var.substr(0, pos); + std::string var = op_var.substr(pos + 1); + + op_var_nan_inf_white_list()[op].emplace_back(var); + } + } +} + +template +static void PrintNanInf(const T* value, const size_t numel, int print_num, + const std::string& op_type, + const std::string& var_name) { + size_t nan_count, inf_count, num_count; + nan_count = inf_count = num_count = 0; + + // CPU print num value + for (size_t i = 0; i < numel; ++i) { + size_t count = 0; + if (std::isnan(value[i])) { + count = nan_count++; + } else if (std::isinf(value[i])) { + count = inf_count++; + } else { + count = num_count++; + } + + if (count < static_cast(print_num)) { + printf("numel:%lu index:%lu value:%f\n", static_cast(numel), + static_cast(i), static_cast(value[i])); + } + } + bool has_nan_inf = true; + printf("In cpu, there has %lu,%lu,%lu nan,inf,num\n", + static_cast(nan_count), static_cast(inf_count), + static_cast(num_count)); + PADDLE_ENFORCE_EQ(has_nan_inf, false, + platform::errors::PreconditionNotMet( + "===ERROR: in [op=%s] [tensor=%s] find nan or inf===", + op_type, var_name)); +} + +// openmp 4.0, reduction with fp16 +#if defined _OPENMP && _OPENMP >= 201307 +// more detail see: 180 page of +// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf +#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in) +#endif + +template +static void CheckNanInf(const T* value, const size_t numel, int print_num, + const std::string& op_type, + const std::string& var_name) { + T sum = static_cast(0.0); +#if defined _OPENMP && _OPENMP >= 201307 +#pragma omp parallel for simd reduction(+ : sum) +#elif defined _OPENMP +#pragma omp parallel for reduction(+ : sum) +#endif + for (size_t i = 0; i < numel; ++i) { + sum += (value[i] - value[i]); + } + + if (std::isnan(sum) || std::isinf(sum)) { + PrintNanInf(value, numel, print_num, op_type, var_name); + } +} + +#if defined _OPENMP && _OPENMP >= 201307 +// openmp4.0 not need to specialization fp16 +#elif defined _OPENMP +template <> +void CheckNanInf( + const paddle::platform::float16* value, const size_t numel, int print_num, + const std::string& op_type, const std::string& var_name) { + float sum = 0.0f; +#pragma omp parallel for reduction(+ : sum) + for (size_t i = 0; i < numel; ++i) { + sum += static_cast(value[i] - value[i]); + } + + if (std::isnan(sum) || std::isinf(sum)) { + PrintNanInf(value, numel, print_num, op_type, var_name); + } +} +#endif + +template <> +template +void TensorCheckerVisitor::apply( + typename std::enable_if::value>::type*) const { + // use env strategy control in future, -1=print_all. + int print_num = 3; + CheckNanInf(tensor_.data(), tensor_.numel(), print_num, op_type_, + var_name_); +} + +template <> +void tensor_check(const std::string& op_type, + const std::string& var_name, + const framework::Tensor& tensor, + const platform::Place& place) { + TensorCheckerVisitor vistor(op_type, var_name, + tensor, place); + VisitDataType(tensor.type(), vistor); +} + +void CheckVarHasNanOrInf(const std::string& op_type, + const framework::Scope& scope, + const std::string& var_name, + const platform::Place& place) { + auto* var = scope.FindVar(var_name); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::NotFound("In op=%s, can't find var:%s", op_type, + var_name)); + + const Tensor* tensor{nullptr}; + if (var->IsType()) { + tensor = &var->Get(); + } else if (var->IsType()) { + tensor = &var->Get().value(); + } else { + VLOG(10) << var_name << " var_name need not to check"; + return; + } + + if (tensor->memory_size() == 0) { + VLOG(10) << var_name << " var_name need not to check, size == 0"; + return; + } + + VLOG(10) << "begin check " << op_type << " var_name:" << var_name + << ", place:" << tensor->place() << ", numel:" << tensor->numel(); + + if (platform::is_gpu_place(tensor->place())) { +#ifdef PADDLE_WITH_CUDA + tensor_check(op_type, var_name, *tensor, + place); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "Tensor[%s] use gpu place. PaddlePaddle must compile with GPU.", + var_name)); +#endif + return; + } + + tensor_check(op_type, var_name, *tensor, place); +} + +bool IsSkipOp(const framework::OperatorBase& op) { + if (op_type_nan_inf_white_list().count(op.Type()) != 0) return true; + + int op_role = op.template Attr( + framework::OpProtoAndCheckerMaker::OpRoleAttrName()); + + // kForward=0, can't filter + if (op_role == static_cast(framework::OpRole::kForward)) { + op_role = FORWARD; + } + if (op_role_nan_inf_white_list & op_role) return true; + + return false; +} + +void CheckOpHasNanOrInf(const framework::OperatorBase& op, + const framework::Scope& exec_scope, + const platform::Place& place) { + std::call_once(white_list_init_flag, InitWhiteListFormEnv); + + if (IsSkipOp(op)) return; + + if (op_var_nan_inf_white_list().count(op.Type()) == 0) { + // NOTE. vname may destruct in the end of this func. + for (auto& vname : op.OutputVars(true)) { + auto* var = exec_scope.FindVar(vname); + if (var == nullptr) continue; + CheckVarHasNanOrInf(op.Type(), exec_scope, vname, place); + } + } else { + for (auto& vname : op.OutputVars(true)) { + bool need_check = true; + for (auto& white_vname : op_var_nan_inf_white_list().at(op.Type())) { + if (vname.find(white_vname) != std::string::npos) { + need_check = false; + break; + } + } + if (!need_check) continue; + auto* var = exec_scope.FindVar(vname); + if (var == nullptr) continue; + CheckVarHasNanOrInf(op.Type(), exec_scope, vname, place); + } + } +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.cu b/paddle/fluid/framework/details/nan_inf_utils_detail.cu new file mode 100644 index 0000000000..0317e55909 --- /dev/null +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.cu @@ -0,0 +1,189 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/framework/details/nan_inf_utils.h" +#include "paddle/fluid/framework/details/nan_inf_utils_detail.h" + +#include +#include +#include +#include + +namespace paddle { +namespace framework { +namespace details { + +static std::once_flag init_multi_gpu_op_var_map_flag; + +// lazy init +static std::vector>& +multi_op_var2gpu_str() { + static std::vector> + _multi_op_var2gpu_str; + return _multi_op_var2gpu_str; +} + +static std::vector& multi_op_var2gpu_str_mutex() { + static std::vector _multi_op_var2gpu_str_mutex; + return _multi_op_var2gpu_str_mutex; +} + +static void InitMultiGPUOpVarMap() { + int dev_count = platform::GetCUDADeviceCount(); + PADDLE_ENFORCE_GT(dev_count, 0, + platform::errors::NotFound( + "cuda device must > 0, now dev_count=%d", dev_count)); + + // https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex + std::vector> tmp_multi( + dev_count); + std::vector tmp_multi_mutex(dev_count); + + multi_op_var2gpu_str().swap(tmp_multi); + multi_op_var2gpu_str_mutex().swap(tmp_multi_mutex); +} + +template +__device__ __forceinline__ void PrintNanInfKernel(const T* value, + const size_t numel, + int print_num, + char* debug_info) { + const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; + + __shared__ unsigned int nan_count, inf_count, num_count; + if (threadIdx.x == 0) nan_count = inf_count = num_count = 0; + __syncthreads; + + for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) { + unsigned int count = 0; + if (isnan(value[i])) { + count = atomicAdd(&nan_count, 1); + } else if (isinf(value[i])) { + count = atomicAdd(&inf_count, 1); + } else { + count = atomicAdd(&num_count, 1); + } + // for cuda, print in every block + if (count < print_num) { + printf("numel:%lu idx:%lu value:%f\n", static_cast(numel), + static_cast(i), static_cast(value[i])); + } + } + __syncthreads; + + if (true && threadIdx.x == 0) { + printf("In block %d, there has %u,%u,%u nan,inf,num\n", blockIdx.x, + nan_count, inf_count, num_count); + PADDLE_ENFORCE(false, "===ERROR: in %s find nan or inf===", debug_info); + } +} + +// Resnet 2gpus speed test, no check 270 images/s, this check 229 images/s +template +__global__ void CheckNanInfKernel(const T* value, const size_t numel, + int print_num, char* debug_info) { + /// step 1, judge wheater has nan or inf + __shared__ volatile int has_nan_inf; + if (threadIdx.x == 0) has_nan_inf = false; + __syncthreads(); + + const size_t tid = threadIdx.x + blockIdx.x * blockDim.x; + T sum = static_cast(0.0); + // Todo(wangxi). simd speed up + for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) { + sum += (value[i] - value[i]); + } + + if (isnan(sum) || isinf(sum)) has_nan_inf = true; + __syncthreads(); + + /// Note. different blocks may behave differently + if (!has_nan_inf) return; + + PrintNanInfKernel(value, numel, print_num, debug_info); +} + +template <> +template +void TensorCheckerVisitor::apply( + typename std::enable_if::value>::type*) const { + int print_num = 3; + + auto* dev_ctx = reinterpret_cast( + platform::DeviceContextPool::Instance().Get(tensor_.place())); + int dev_id = boost::get(tensor_.place()).device; + PADDLE_ENFORCE_EQ( + (dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()), true, + platform::errors::OutOfRange("GPU dev_id must >=0 and < dev_count=%d", + multi_op_var2gpu_str_mutex().size())); + + std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ + "]"; + char* gpu_str_ptr = NULL; + + { + auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id); + auto& op_var2gpu_str = multi_op_var2gpu_str().at(dev_id); + + std::lock_guard guard(op_var2gpu_str_mutex); + if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) { // insert + auto gpu_str_tensor = + paddle::memory::Alloc(*dev_ctx, op_var.length() + 1); + gpu_str_ptr = reinterpret_cast(gpu_str_tensor->ptr()); + + op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor)); + + auto iter = op_var2gpu_str.find(op_var); + PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true, + platform::errors::PreconditionNotMet( + "op_var=%s should successed insert into " + "op_var2gpu_str, but now failed", + op_var)); + + PADDLE_ENFORCE_CUDA_SUCCESS( + cudaMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1, + cudaMemcpyHostToDevice, dev_ctx->stream()), + platform::errors::External( + "Async cudaMemcpy op_var info to gpu failed.")); + } else { // get + auto iter = op_var2gpu_str.find(op_var); + PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true, + platform::errors::PreconditionNotMet( + "op_var=%s should be in the op_var2gpu_str, but " + "now can't find it", + op_var)); + gpu_str_ptr = reinterpret_cast(iter->second->ptr()); + } + } + + const size_t threads = 1024; + size_t blocks = std::min(128ul, (tensor_.numel() + threads - 1) / threads); + CheckNanInfKernel<<stream()>>>( + tensor_.data(), tensor_.numel(), print_num, gpu_str_ptr); +} + +template <> +void tensor_check(const std::string& op_type, + const std::string& var_name, + const framework::Tensor& tensor, + const platform::Place& place) { + std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap); + + TensorCheckerVisitor vistor(op_type, var_name, + tensor, place); + VisitDataType(tensor.type(), vistor); +} + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/details/nan_inf_utils_detail.h b/paddle/fluid/framework/details/nan_inf_utils_detail.h new file mode 100644 index 0000000000..15d00932f1 --- /dev/null +++ b/paddle/fluid/framework/details/nan_inf_utils_detail.h @@ -0,0 +1,59 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include + +#include "paddle/fluid/framework/scope.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace framework { +namespace details { + +template +struct TensorCheckerVisitor { + TensorCheckerVisitor(const std::string& op_type, const std::string& var_name, + const framework::Tensor& tensor, + const platform::Place& place) + : op_type_(op_type), + var_name_(var_name), + tensor_(tensor), + place_(place) {} + + template + void apply( + typename std::enable_if::value>::type* = 0) const { + VLOG(10) << var_name_ << " need not to check, it's type is not float point"; + } + + template + void apply(typename std::enable_if::value>::type* = + 0) const; + + std::string op_type_; + std::string var_name_; + const framework::Tensor& tensor_; + const platform::Place& place_; +}; + +template +void tensor_check(const std::string& op_type, const std::string& var_name, + const framework::Tensor& tensor, + const platform::Place& place); + +} // namespace details +} // namespace framework +} // namespace paddle diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 74b1f78a0a..fca7b32c99 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -21,6 +21,7 @@ limitations under the License. */ #include #include #include "paddle/fluid/framework/data_transform.h" +#include "paddle/fluid/framework/details/nan_inf_utils.h" #include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/lod_tensor.h" #include "paddle/fluid/framework/op_call_stack.h" @@ -1012,16 +1013,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope, } if (FLAGS_check_nan_inf) { - for (auto& vname : OutputVars(true)) { - auto* var = exec_scope.FindVar(vname); - if (var == nullptr) continue; - if (var->IsType()) { - CheckTensorNANOrInf(type_, vname, var->Get()); - } else if (var->IsType()) { - CheckTensorNANOrInf(type_, vname, - var->Get().value()); - } - } + framework::details::CheckOpHasNanOrInf(*this, exec_scope, place); } // To solve issue #15032, have a discussion with @Luotao for cpu inference, diff --git a/python/paddle/fluid/tests/unittests/check_nan_inf_base.py b/python/paddle/fluid/tests/unittests/check_nan_inf_base.py new file mode 100644 index 0000000000..6486a4d236 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/check_nan_inf_base.py @@ -0,0 +1,116 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import unicode_literals +from __future__ import print_function + +import os +import sys +import time +import numpy as np + +os.environ[str("FLAGS_check_nan_inf")] = str("1") +os.environ[str("GLOG_vmodule")] = str("nan_inf_utils_detail=10") + +import paddle.fluid.core as core +import paddle +import paddle.fluid as fluid +import paddle.compat as cpt + +np.random.seed(0) + + +def generator(): + batch_size = 5 + for i in range(5): + curr_train_x = np.random.randint( + batch_size, size=(batch_size, 3)).astype("float32") + if i >= 2: + curr_train_x[0, :] = np.nan + curr_train_x[-1, :] = np.inf + res = [] + for i in range(batch_size): + y = i % 3 + res.append([y]) + y_label = np.array(res).astype('int64') + yield [curr_train_x, y_label] + + +def net(): + x = fluid.layers.data(name="x", shape=[3], dtype='float32') + y = fluid.layers.data(name="y", shape=[1], dtype='int64') + + # test int64 value + zero = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0) + + # test float16 value + fp16_zero = fluid.layers.cast(zero, dtype='float16') + + y = y + zero + + hidden = x + + for i in range(2): + hidden = fluid.layers.fc(input=hidden, size=400, act="sigmoid") + + hidden = fluid.layers.fc(input=hidden, size=3, act=None) + cost, y_predict = fluid.layers.softmax_with_cross_entropy( + hidden, y, return_softmax=True) + acc_top1 = fluid.layers.accuracy(input=y_predict, label=y, k=1) + avg_cost = fluid.layers.mean(cost) + + sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.05) + sgd_optimizer.minimize(avg_cost) + return y_predict, avg_cost, acc_top1 + + +def check(use_cuda): + main = fluid.Program() + startup = fluid.Program() + scope = fluid.core.Scope() + + with fluid.scope_guard(scope): + with fluid.program_guard(main, startup): + y_predict, avg_cost, acc_top1 = net() + + place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace() + exe = fluid.Executor(place) + exe.run(startup) + + step = 0.0 + for train_data, y_label in generator(): + outs = exe.run( + main, + feed={'x': train_data, + 'y': y_label}, + fetch_list=[y_predict.name, avg_cost.name, acc_top1.name]) + step += 1 + print('iter={:.0f},cost={},acc1={}'.format(step, outs[1][0], + outs[2][0])) + + +if __name__ == '__main__': + if core.is_compiled_with_cuda(): + try: + check(use_cuda=True) + assert False + except Exception as e: + print(e) + assert type(e) == core.EnforceNotMet + try: + check(use_cuda=False) + assert False + except Exception as e: + print(e) + assert type(e) == core.EnforceNotMet diff --git a/python/paddle/fluid/tests/unittests/test_nan_inf.py b/python/paddle/fluid/tests/unittests/test_nan_inf.py new file mode 100644 index 0000000000..d4a971d25b --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_nan_inf.py @@ -0,0 +1,65 @@ +# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import unicode_literals +from __future__ import print_function + +import unittest +import os +import sys +import subprocess + + +class TestNanInf(unittest.TestCase): + def setUp(self): + self._python_interp = sys.executable + if os.getenv('WITH_COVERAGE', 'OFF') == 'ON': + self._python_interp += " -m coverage run --branch -p" + self._python_interp += " check_nan_inf_base.py" + + self.env = os.environ.copy() + + def test_nan_inf(self): + cmd = self._python_interp + + proc = subprocess.Popen( + cmd.split(" "), + stdout=subprocess.PIPE, + stderr=subprocess.PIPE, + env=self.env) + + out, err = proc.communicate() + returncode = proc.returncode + + print(out) + print(err) + + assert returncode == 0 + # in python3, type(out+err) is 'bytes', need use encode + assert (out + err).find('find nan or inf'.encode()) != -1 + + +class TestNanInfEnv(TestNanInf): + def setUp(self): + super(TestNanInfEnv, self).setUp() + # windows python have some bug with env, so need use str to pass ci + # otherwise, "TypeError: environment can only contain strings" + self.env[str("PADDLE_INF_NAN_SKIP_OP")] = str("mul") + self.env[str("PADDLE_INF_NAN_SKIP_ROLE")] = str("loss") + self.env[str("PADDLE_INF_NAN_SKIP_VAR")] = str( + "elementwise_add:fc_0.tmp_1") + + +if __name__ == '__main__': + unittest.main() -- GitLab