Commit 8a0f611b authored by WangXi, committed by gongweibao

Rewrite check nan inf tools (#21076)

Parent 019147eb
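This change replaces the per-tensor CheckTensorNANOrInf loop in OperatorWithKernel::RunImpl with framework::details::CheckOpHasNanOrInf, adds dedicated CPU and CUDA check paths, and lets users skip whole ops, op roles, or individual op output vars through environment-variable white lists. A minimal usage sketch, mirroring the new unit tests (the op/var names below are just the ones those tests use):

import os

# Set these before importing paddle, exactly as the new unit tests do.
os.environ[str("FLAGS_check_nan_inf")] = str("1")                 # check every op's outputs
os.environ[str("GLOG_vmodule")] = str("nan_inf_utils_detail=10")  # verbose check logging

# Optional white lists, parsed once by InitWhiteListFormEnv():
os.environ[str("PADDLE_INF_NAN_SKIP_OP")] = str("mul")            # skip whole op types
os.environ[str("PADDLE_INF_NAN_SKIP_ROLE")] = str("loss")         # skip op roles
os.environ[str("PADDLE_INF_NAN_SKIP_VAR")] = str("elementwise_add:fc_0.tmp_1")  # skip op:var pairs

import paddle.fluid as fluid
# Any program run from here on is checked; an op producing NaN or Inf raises
# core.EnforceNotMet with "find nan or inf" in the message.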
......@@ -133,7 +133,7 @@ cc_library(op_kernel_type SRCS op_kernel_type.cc DEPS device_context place)
cc_library(unused_var_check SRCS unused_var_check.cc DEPS glog no_need_buffer_vars_inference)
cc_library(operator SRCS operator.cc DEPS op_info device_context tensor scope glog trainer_desc_proto data_feed_proto
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check)
shape_inference data_transform lod_tensor profiler transfer_scope_cache op_kernel_type op_call_stack unused_var_check nan_inf_utils)
cc_test(operator_test SRCS operator_test.cc DEPS operator op_registry device_context)
cc_test(operator_exception_test SRCS operator_exception_test.cc DEPS operator op_registry device_context)
......
......@@ -22,6 +22,7 @@ endif()
if(WITH_GPU)
nv_library(nan_inf_utils SRCS nan_inf_utils_detail.cc nan_inf_utils_detail.cu DEPS framework_proto scope place)
nv_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
dynload_cuda variable_visitor)
nv_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
......@@ -43,6 +44,7 @@ if(WITH_GPU)
nv_library(fused_broadcast_op_handle SRCS fused_broadcast_op_handle.cc DEPS broadcast_op_handle)
else()
cc_library(nan_inf_utils SRCS nan_inf_utils_detail.cc DEPS framework_proto scope place)
cc_library(all_reduce_op_handle SRCS all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
variable_visitor)
cc_library(fused_all_reduce_op_handle SRCS fused_all_reduce_op_handle.cc DEPS op_handle_base scope lod_tensor ddim memory
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <string>
#include <vector>
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace details {
// asserts false when a NaN or Inf value is found
void CheckVarHasNanOrInf(const std::string& op_type,
const framework::Scope& scope,
const std::string& var_name,
const platform::Place& place);
void CheckOpHasNanOrInf(const framework::OperatorBase& op,
const framework::Scope& scope,
const platform::Place& place);
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include <algorithm>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/selected_rows.h"
namespace paddle {
namespace framework {
namespace details {
static std::once_flag white_list_init_flag;
static int op_role_nan_inf_white_list = 0;
static constexpr int FORWARD = 0x10000;
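// OpRole::kForward is 0x0000 and cannot be tested with a bitmask, so the
// synthetic FORWARD bit stands in for it in op_role_nan_inf_white_list.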
// lazy init
static const std::unordered_map<std::string, int>& role_str2int() {
/* In op_proto_maker.h
* framework::OpRole::kForward = 0x0000,
* framework::OpRole::kBackward = 0x0001,
* framework::OpRole::kOptimize = 0x0002,
* framework::OpRole::kRPC = 0x0004,
* framework::OpRole::kDist = 0x0008,
* framework::OpRole::kLRSched = 0x0010,
* framework::OpRole::kLoss = 0x0100,
* framework::OpRole::kNotSpecified = 0x1000,
*/
static const std::unordered_map<std::string, int> _role_str2int = {
{"forward", FORWARD}, /* kForward=0, can't filter */
{"backward", static_cast<int>(framework::OpRole::kBackward)},
{"optimize", static_cast<int>(framework::OpRole::kOptimize)},
{"rpc", static_cast<int>(framework::OpRole::kRPC)},
{"dist", static_cast<int>(framework::OpRole::kDist)},
{"lrsched", static_cast<int>(framework::OpRole::kLRSched)},
{"loss", static_cast<int>(framework::OpRole::kLoss)},
{"default", static_cast<int>(framework::OpRole::kNotSpecified)},
};
return _role_str2int;
}
static std::unordered_set<std::string>& op_type_nan_inf_white_list() {
static std::unordered_set<std::string> _op_type_nan_inf_white_list = {
"coalesce_tensor", /* This Op will alloc tensor, and may not init space */
};
return _op_type_nan_inf_white_list;
}
static std::unordered_map<std::string, std::vector<std::string>>&
op_var_nan_inf_white_list() {
static std::unordered_map<std::string, std::vector<std::string>>
_op_var_nan_inf_white_list = {
/* the encoded & gathered vars consist of indices and values, so they can't be checked directly */
{"dgc", {"__dgc_encoded__", "__dgc_gather__"}},
};
return _op_var_nan_inf_white_list;
}
static void InitWhiteListFormEnv() {
// op_type_skip and op_var_skip may be NULL.
// Initialize the static containers here first to avoid a race on first use.
// NOTE: role_str2int() needn't do this, since it is only used in this function.
op_type_nan_inf_white_list();
op_var_nan_inf_white_list();
// export PADDLE_INF_NAN_SKIP_OP="op0,op1,op2"
// export PADDLE_INF_NAN_SKIP_ROLE="role1,role2,role3"
// export PADDLE_INF_NAN_SKIP_VAR="op0:var0,op0:var1,op1:var0"
const char* op_type_skip = std::getenv("PADDLE_INF_NAN_SKIP_OP");
const char* op_role_skip = std::getenv("PADDLE_INF_NAN_SKIP_ROLE");
const char* op_var_skip = std::getenv("PADDLE_INF_NAN_SKIP_VAR");
if (op_type_skip != NULL) {
std::stringstream ss(op_type_skip);
std::string op_type;
while (std::getline(ss, op_type, ',')) {
op_type_nan_inf_white_list().emplace(op_type);
}
}
if (op_role_skip != NULL) {
std::stringstream ss(op_role_skip);
std::string op_role;
while (std::getline(ss, op_role, ',')) {
PADDLE_ENFORCE_EQ(role_str2int().find(op_role) != role_str2int().end(),
true,
platform::errors::InvalidArgument(
"Skip role must be one of "
"{forward,backward,optimize,rpc,dist,lrsched,loss,"
"default}, instead of %s",
op_role));
op_role_nan_inf_white_list |= role_str2int().at(op_role);
}
}
if (op_var_skip != NULL) {
std::stringstream ss(op_var_skip);
std::string op_var;
while (std::getline(ss, op_var, ',')) {
auto pos = op_var.find(":");
PADDLE_ENFORCE_EQ(
pos != std::string::npos, true,
platform::errors::InvalidArgument(
"Skip var format must be op:var, instead of %s", op_var));
std::string op = op_var.substr(0, pos);
std::string var = op_var.substr(pos + 1);
op_var_nan_inf_white_list()[op].emplace_back(var);
}
}
}
template <typename T>
static void PrintNanInf(const T* value, const size_t numel, int print_num,
const std::string& op_type,
const std::string& var_name) {
size_t nan_count, inf_count, num_count;
nan_count = inf_count = num_count = 0;
// On CPU, print at most print_num values of each kind (nan, inf, normal)
for (size_t i = 0; i < numel; ++i) {
size_t count = 0;
if (std::isnan(value[i])) {
count = nan_count++;
} else if (std::isinf(value[i])) {
count = inf_count++;
} else {
count = num_count++;
}
if (count < static_cast<size_t>(print_num)) {
printf("numel:%lu index:%lu value:%f\n", static_cast<uint64_t>(numel),
static_cast<uint64_t>(i), static_cast<float>(value[i]));
}
}
bool has_nan_inf = true;  // PrintNanInf is only called after a NaN/Inf was detected
printf("On CPU, found %lu nan, %lu inf, %lu finite values\n",
static_cast<uint64_t>(nan_count), static_cast<uint64_t>(inf_count),
static_cast<uint64_t>(num_count));
PADDLE_ENFORCE_EQ(has_nan_inf, false,
platform::errors::PreconditionNotMet(
"===ERROR: in [op=%s] [tensor=%s] find nan or inf===",
op_type, var_name));
}
// OpenMP 4.0: declare a user-defined reduction so float16 can be reduced
#if defined _OPENMP && _OPENMP >= 201307
// for more detail, see page 180 of
// https://www.openmp.org/wp-content/uploads/OpenMP4.0.0.pdf
#pragma omp declare reduction(+ : paddle::platform::float16 : omp_out += omp_in)
#endif
template <typename T>
static void CheckNanInf(const T* value, const size_t numel, int print_num,
const std::string& op_type,
const std::string& var_name) {
T sum = static_cast<T>(0.0);
#if defined _OPENMP && _OPENMP >= 201307
#pragma omp parallel for simd reduction(+ : sum)
#elif defined _OPENMP
#pragma omp parallel for reduction(+ : sum)
#endif
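// (x - x) is 0 for finite values but NaN for NaN/Inf inputs, so the reduced
// sum is non-finite iff the tensor contains at least one NaN or Inf.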
for (size_t i = 0; i < numel; ++i) {
sum += (value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
#if defined _OPENMP && _OPENMP >= 201307
// OpenMP 4.0+ does not need a float16 specialization
#elif defined _OPENMP
template <>
void CheckNanInf<paddle::platform::float16>(
const paddle::platform::float16* value, const size_t numel, int print_num,
const std::string& op_type, const std::string& var_name) {
float sum = 0.0f;
#pragma omp parallel for reduction(+ : sum)
for (size_t i = 0; i < numel; ++i) {
sum += static_cast<float>(value[i] - value[i]);
}
if (std::isnan(sum) || std::isinf(sum)) {
PrintNanInf(value, numel, print_num, op_type, var_name);
}
}
#endif
template <>
template <typename T>
void TensorCheckerVisitor<platform::CPUDeviceContext>::apply(
typename std::enable_if<std::is_floating_point<T>::value>::type*) const {
// TODO: control this via an environment variable in the future; -1 means print all.
int print_num = 3;
CheckNanInf(tensor_.data<T>(), tensor_.numel(), print_num, op_type_,
var_name_);
}
template <>
void tensor_check<platform::CPUDeviceContext>(const std::string& op_type,
const std::string& var_name,
const framework::Tensor& tensor,
const platform::Place& place) {
TensorCheckerVisitor<platform::CPUDeviceContext> visitor(op_type, var_name,
tensor, place);
VisitDataType(tensor.type(), visitor);
}
void CheckVarHasNanOrInf(const std::string& op_type,
const framework::Scope& scope,
const std::string& var_name,
const platform::Place& place) {
auto* var = scope.FindVar(var_name);
PADDLE_ENFORCE_NOT_NULL(
var, platform::errors::NotFound("In op=%s, can't find var:%s", op_type,
var_name));
const Tensor* tensor{nullptr};
if (var->IsType<framework::LoDTensor>()) {
tensor = &var->Get<framework::LoDTensor>();
} else if (var->IsType<framework::SelectedRows>()) {
tensor = &var->Get<framework::SelectedRows>().value();
} else {
VLOG(10) << var_name << " does not need to be checked, it holds neither LoDTensor nor SelectedRows";
return;
}
if (tensor->memory_size() == 0) {
VLOG(10) << var_name << " does not need to be checked, memory_size == 0";
return;
}
VLOG(10) << "begin check " << op_type << " var_name:" << var_name
<< ", place:" << tensor->place() << ", numel:" << tensor->numel();
if (platform::is_gpu_place(tensor->place())) {
#ifdef PADDLE_WITH_CUDA
tensor_check<platform::CUDADeviceContext>(op_type, var_name, *tensor,
place);
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"Tensor[%s] use gpu place. PaddlePaddle must compile with GPU.",
var_name));
#endif
return;
}
tensor_check<platform::CPUDeviceContext>(op_type, var_name, *tensor, place);
}
bool IsSkipOp(const framework::OperatorBase& op) {
if (op_type_nan_inf_white_list().count(op.Type()) != 0) return true;
int op_role = op.template Attr<int>(
framework::OpProtoAndCheckerMaker::OpRoleAttrName());
// kForward=0, can't filter
if (op_role == static_cast<int>(framework::OpRole::kForward)) {
op_role = FORWARD;
}
if (op_role_nan_inf_white_list & op_role) return true;
return false;
}
void CheckOpHasNanOrInf(const framework::OperatorBase& op,
const framework::Scope& exec_scope,
const platform::Place& place) {
std::call_once(white_list_init_flag, InitWhiteListFormEnv);
if (IsSkipOp(op)) return;
if (op_var_nan_inf_white_list().count(op.Type()) == 0) {
// NOTE: vname may be destructed at the end of this function.
for (auto& vname : op.OutputVars(true)) {
auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue;
CheckVarHasNanOrInf(op.Type(), exec_scope, vname, place);
}
} else {
for (auto& vname : op.OutputVars(true)) {
bool need_check = true;
for (auto& white_vname : op_var_nan_inf_white_list().at(op.Type())) {
if (vname.find(white_vname) != std::string::npos) {
need_check = false;
break;
}
}
if (!need_check) continue;
auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue;
CheckVarHasNanOrInf(op.Type(), exec_scope, vname, place);
}
}
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include <algorithm>
#include <unordered_map>
#include <utility>
#include <vector>
namespace paddle {
namespace framework {
namespace details {
static std::once_flag init_multi_gpu_op_var_map_flag;
// lazy init
static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>&
multi_op_var2gpu_str() {
static std::vector<std::unordered_map<std::string, memory::AllocationPtr>>
_multi_op_var2gpu_str;
return _multi_op_var2gpu_str;
}
static std::vector<std::mutex>& multi_op_var2gpu_str_mutex() {
static std::vector<std::mutex> _multi_op_var2gpu_str_mutex;
return _multi_op_var2gpu_str_mutex;
}
static void InitMultiGPUOpVarMap() {
int dev_count = platform::GetCUDADeviceCount();
PADDLE_ENFORCE_GT(dev_count, 0,
platform::errors::NotFound(
"cuda device must > 0, now dev_count=%d", dev_count));
// https://stackoverflow.com/questions/16465633/how-can-i-use-something-like-stdvectorstdmutex
std::vector<std::unordered_map<std::string, memory::AllocationPtr>> tmp_multi(
dev_count);
std::vector<std::mutex> tmp_multi_mutex(dev_count);
multi_op_var2gpu_str().swap(tmp_multi);
multi_op_var2gpu_str_mutex().swap(tmp_multi_mutex);
}
template <typename T>
__device__ __forceinline__ void PrintNanInfKernel(const T* value,
const size_t numel,
int print_num,
char* debug_info) {
const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
__shared__ unsigned int nan_count, inf_count, num_count;
if (threadIdx.x == 0) nan_count = inf_count = num_count = 0;
__syncthreads();
for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
unsigned int count = 0;
if (isnan(value[i])) {
count = atomicAdd(&nan_count, 1);
} else if (isinf(value[i])) {
count = atomicAdd(&inf_count, 1);
} else {
count = atomicAdd(&num_count, 1);
}
// on CUDA, print at most print_num values of each kind per block
if (count < print_num) {
printf("numel:%lu idx:%lu value:%f\n", static_cast<uint64_t>(numel),
static_cast<uint64_t>(i), static_cast<float>(value[i]));
}
}
__syncthreads();
if (threadIdx.x == 0) {
printf("In block %d: %u nan, %u inf, %u finite values\n", blockIdx.x,
nan_count, inf_count, num_count);
PADDLE_ENFORCE(false, "===ERROR: in %s find nan or inf===", debug_info);
}
}
// ResNet speed test on 2 GPUs: 270 images/s without this check, 229 images/s with it
template <typename T>
__global__ void CheckNanInfKernel(const T* value, const size_t numel,
int print_num, char* debug_info) {
/// step 1: determine whether any value is NaN or Inf
__shared__ volatile int has_nan_inf;
if (threadIdx.x == 0) has_nan_inf = false;
__syncthreads();
const size_t tid = threadIdx.x + blockIdx.x * blockDim.x;
T sum = static_cast<T>(0.0);
// TODO(wangxi): speed up with SIMD
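// same (x - x) trick as the CPU path: the per-thread sum is non-finite iff
// some value visited by that thread is NaN or Inf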
for (size_t i = tid; i < numel; i += blockDim.x * gridDim.x) {
sum += (value[i] - value[i]);
}
if (isnan(sum) || isinf(sum)) has_nan_inf = true;
__syncthreads();
/// Note. different blocks may behave differently
if (!has_nan_inf) return;
PrintNanInfKernel(value, numel, print_num, debug_info);
}
template <>
template <typename T>
void TensorCheckerVisitor<platform::CUDADeviceContext>::apply(
typename std::enable_if<std::is_floating_point<T>::value>::type*) const {
int print_num = 3;
auto* dev_ctx = reinterpret_cast<platform::CUDADeviceContext*>(
platform::DeviceContextPool::Instance().Get(tensor_.place()));
int dev_id = boost::get<platform::CUDAPlace>(tensor_.place()).device;
PADDLE_ENFORCE_EQ(
(dev_id >= 0 && dev_id < multi_op_var2gpu_str_mutex().size()), true,
platform::errors::OutOfRange("GPU dev_id must be >= 0 and < dev_count=%d",
multi_op_var2gpu_str_mutex().size()));
std::string op_var = "[op=" + op_type_ + "] [tensor=" + var_name_ + "]";
char* gpu_str_ptr = NULL;
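// Cache the "[op=...] [tensor=...]" debug string in device memory (one map per
// GPU, guarded by a per-device mutex) so that CheckNanInfKernel can print it.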
{
auto& op_var2gpu_str_mutex = multi_op_var2gpu_str_mutex().at(dev_id);
auto& op_var2gpu_str = multi_op_var2gpu_str().at(dev_id);
std::lock_guard<std::mutex> guard(op_var2gpu_str_mutex);
if (op_var2gpu_str.find(op_var) == op_var2gpu_str.end()) { // insert
auto gpu_str_tensor =
paddle::memory::Alloc(*dev_ctx, op_var.length() + 1);
gpu_str_ptr = reinterpret_cast<char*>(gpu_str_tensor->ptr());
op_var2gpu_str.emplace(op_var, std::move(gpu_str_tensor));
auto iter = op_var2gpu_str.find(op_var);
PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true,
platform::errors::PreconditionNotMet(
"op_var=%s should successed insert into "
"op_var2gpu_str, but now failed",
op_var));
PADDLE_ENFORCE_CUDA_SUCCESS(
cudaMemcpyAsync(gpu_str_ptr, iter->first.c_str(), op_var.length() + 1,
cudaMemcpyHostToDevice, dev_ctx->stream()),
platform::errors::External(
"Async cudaMemcpy op_var info to gpu failed."));
} else { // get
auto iter = op_var2gpu_str.find(op_var);
PADDLE_ENFORCE_EQ(iter != op_var2gpu_str.end(), true,
platform::errors::PreconditionNotMet(
"op_var=%s should be in the op_var2gpu_str, but "
"now can't find it",
op_var));
gpu_str_ptr = reinterpret_cast<char*>(iter->second->ptr());
}
}
const size_t threads = 1024;
size_t blocks = std::min(128ul, (tensor_.numel() + threads - 1) / threads);
CheckNanInfKernel<<<blocks, threads, 0, dev_ctx->stream()>>>(
tensor_.data<T>(), tensor_.numel(), print_num, gpu_str_ptr);
}
template <>
void tensor_check<platform::CUDADeviceContext>(const std::string& op_type,
const std::string& var_name,
const framework::Tensor& tensor,
const platform::Place& place) {
std::call_once(init_multi_gpu_op_var_map_flag, InitMultiGPUOpVarMap);
TensorCheckerVisitor<platform::CUDADeviceContext> visitor(op_type, var_name,
tensor, place);
VisitDataType(tensor.type(), visitor);
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace details {
template <typename DeviceContext>
struct TensorCheckerVisitor {
TensorCheckerVisitor(const std::string& op_type, const std::string& var_name,
const framework::Tensor& tensor,
const platform::Place& place)
: op_type_(op_type),
var_name_(var_name),
tensor_(tensor),
place_(place) {}
template <typename T>
void apply(
typename std::enable_if<std::is_integral<T>::value>::type* = 0) const {
VLOG(10) << var_name << " does not need to be checked, its type is not floating point";
}
template <typename T>
void apply(typename std::enable_if<std::is_floating_point<T>::value>::type* =
0) const;
std::string op_type_;
std::string var_name_;
const framework::Tensor& tensor_;
const platform::Place& place_;
};
template <typename DeviceContext>
void tensor_check(const std::string& op_type, const std::string& var_name,
const framework::Tensor& tensor,
const platform::Place& place);
} // namespace details
} // namespace framework
} // namespace paddle
......@@ -21,6 +21,7 @@ limitations under the License. */
#include <unordered_set>
#include <vector>
#include "paddle/fluid/framework/data_transform.h"
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/op_call_stack.h"
......@@ -1012,16 +1013,7 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
}
if (FLAGS_check_nan_inf) {
for (auto& vname : OutputVars(true)) {
auto* var = exec_scope.FindVar(vname);
if (var == nullptr) continue;
if (var->IsType<framework::LoDTensor>()) {
CheckTensorNANOrInf(type_, vname, var->Get<framework::LoDTensor>());
} else if (var->IsType<framework::SelectedRows>()) {
CheckTensorNANOrInf(type_, vname,
var->Get<framework::SelectedRows>().value());
}
}
framework::details::CheckOpHasNanOrInf(*this, exec_scope, place);
}
// To solve issue #15032, have a discussion with @Luotao for cpu inference,
......
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import unicode_literals
from __future__ import print_function
import os
import sys
import time
import numpy as np
os.environ[str("FLAGS_check_nan_inf")] = str("1")
os.environ[str("GLOG_vmodule")] = str("nan_inf_utils_detail=10")
import paddle.fluid.core as core
import paddle
import paddle.fluid as fluid
import paddle.compat as cpt
np.random.seed(0)
def generator():
batch_size = 5
for i in range(5):
curr_train_x = np.random.randint(
batch_size, size=(batch_size, 3)).astype("float32")
if i >= 2:
curr_train_x[0, :] = np.nan
curr_train_x[-1, :] = np.inf
res = []
for i in range(batch_size):
y = i % 3
res.append([y])
y_label = np.array(res).astype('int64')
yield [curr_train_x, y_label]
def net():
x = fluid.layers.data(name="x", shape=[3], dtype='float32')
y = fluid.layers.data(name="y", shape=[1], dtype='int64')
# test int64 value
zero = fluid.layers.fill_constant(shape=[1], dtype='int64', value=0)
# test float16 value
fp16_zero = fluid.layers.cast(zero, dtype='float16')
y = y + zero
hidden = x
for i in range(2):
hidden = fluid.layers.fc(input=hidden, size=400, act="sigmoid")
hidden = fluid.layers.fc(input=hidden, size=3, act=None)
cost, y_predict = fluid.layers.softmax_with_cross_entropy(
hidden, y, return_softmax=True)
acc_top1 = fluid.layers.accuracy(input=y_predict, label=y, k=1)
avg_cost = fluid.layers.mean(cost)
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.05)
sgd_optimizer.minimize(avg_cost)
return y_predict, avg_cost, acc_top1
def check(use_cuda):
main = fluid.Program()
startup = fluid.Program()
scope = fluid.core.Scope()
with fluid.scope_guard(scope):
with fluid.program_guard(main, startup):
y_predict, avg_cost, acc_top1 = net()
place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
exe = fluid.Executor(place)
exe.run(startup)
step = 0.0
for train_data, y_label in generator():
outs = exe.run(
main,
feed={'x': train_data,
'y': y_label},
fetch_list=[y_predict.name, avg_cost.name, acc_top1.name])
step += 1
print('iter={:.0f},cost={},acc1={}'.format(step, outs[1][0],
outs[2][0]))
if __name__ == '__main__':
if core.is_compiled_with_cuda():
try:
check(use_cuda=True)
assert False
except Exception as e:
print(e)
assert type(e) == core.EnforceNotMet
try:
check(use_cuda=False)
assert False
except Exception as e:
print(e)
assert type(e) == core.EnforceNotMet
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import unicode_literals
from __future__ import print_function
import unittest
import os
import sys
import subprocess
class TestNanInf(unittest.TestCase):
def setUp(self):
self._python_interp = sys.executable
if os.getenv('WITH_COVERAGE', 'OFF') == 'ON':
self._python_interp += " -m coverage run --branch -p"
self._python_interp += " check_nan_inf_base.py"
self.env = os.environ.copy()
def test_nan_inf(self):
cmd = self._python_interp
proc = subprocess.Popen(
cmd.split(" "),
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
env=self.env)
out, err = proc.communicate()
returncode = proc.returncode
print(out)
print(err)
assert returncode == 0
# in python3, out + err is of type 'bytes', so encode the pattern before searching
assert (out + err).find('find nan or inf'.encode()) != -1
class TestNanInfEnv(TestNanInf):
def setUp(self):
super(TestNanInfEnv, self).setUp()
# On Windows, environment values must be plain str to pass CI;
# otherwise Python raises "TypeError: environment can only contain strings"
self.env[str("PADDLE_INF_NAN_SKIP_OP")] = str("mul")
self.env[str("PADDLE_INF_NAN_SKIP_ROLE")] = str("loss")
self.env[str("PADDLE_INF_NAN_SKIP_VAR")] = str(
"elementwise_add:fc_0.tmp_1")
if __name__ == '__main__':
unittest.main()