未验证 提交 a51c492c 编写于 作者: C Chen Weihang 提交者: GitHub

[Eager] Add nan and inf check utils (#42763)

* add nan_inf_utils for eager

* support check nan and inf

* add unittest for coverage
上级 9b15efce
set(eager_deps phi_api phi_dygraph_api hook_utils tensor_utils utils global_utils backward phi_tensor tracer layer autograd_meta grad_node_info grad_tensor_holder accumulation_node custom_operator_node)
set(eager_deps phi_api phi_dygraph_api hook_utils tensor_utils utils global_utils backward phi_tensor tracer layer autograd_meta eager_nan_inf_utils grad_node_info grad_tensor_holder accumulation_node custom_operator_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
set(generated_deps final_dygraph_function final_dygraph_node dygraph_function dygraph_node)
......@@ -18,6 +18,7 @@ if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info switch_autotune)
endif()
cc_library(eager_nan_inf_utils SRCS nan_inf_utils.cc DEPS phi_tensor nan_inf_utils enforce)
cc_library(grad_node_info SRCS grad_node_info.cc DEPS phi_api phi_tensor)
cc_library(autograd_meta SRCS autograd_meta.cc DEPS phi_api phi_tensor)
......
......@@ -17,11 +17,12 @@
#include <atomic>
#include <memory>
#include "paddle/fluid/eager/type_defs.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/phi/api/ext/op_meta_info.h"
#include "paddle/utils/small_vector.h"
namespace egr {
constexpr size_t kSlotSmallVectorSize = 15U;
class UniqueNameGenerator {
public:
explicit UniqueNameGenerator(std::string prefix = "") : prefix_(prefix) {}
......
......@@ -147,6 +147,8 @@ paddle::small_vector<std::vector<paddle::experimental::Tensor>, egr::kSlotSmallV
// Call grad_api function
VLOG(3) << \"Final State Running: {}\";
{}
// Check NaN and Inf if needed
{}
// Get GradIn autograd_meta
{}
......@@ -172,6 +174,8 @@ FORWARD_FUNCTION_TEMPLATE = \
{}
// Forward API Call
VLOG(3) << \"Final State Running: \" << \"{}\";
{}
// Check NaN and Inf if needed
{}
// Get Outputs
{}
......@@ -232,9 +236,11 @@ NODE_CC_FILE_TEMPLATE = \
#include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/api/generated/eager_generated/backwards/nodes.h"
#include "paddle/fluid/eager/to_static/run_program_op_node.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/phi/api/include/sparse_api.h"
DECLARE_bool(check_nan_inf);
{}
"""
......@@ -259,7 +265,9 @@ FORWARD_CC_FILE_TEMPLATE = \
#include "paddle/fluid/eager/amp_utils.h"
#include "paddle/fluid/eager/eager_amp_auto_cast.h"
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
DECLARE_bool(check_nan_inf);
{}
{}
"""
......@@ -339,6 +347,10 @@ CREATE_RECOVER_OPTIONAL_TENSOR_TEMPLATE = \
if( {}.impl() ) {}_optional = paddle::make_optional<const paddle::experimental::Tensor&>({});
"""
# C++ snippet inserted into the generated forward and grad functions: when
# FLAGS_check_nan_inf is set it calls egr::CheckTensorHasNanOrInf.
# Format arguments, in order:
#   1. the API name, embedded as a C++ string literal;
#   2. the C++ expression holding the API outputs to check
#      (the generators in this file pass "api_result" or "returns").
CHECK_NAN_AND_INF_TEMPLATE = \
""" if (FLAGS_check_nan_inf) {{ egr::CheckTensorHasNanOrInf("{}", {}); }}
"""
#######################
## Generator Helpers ##
......@@ -909,6 +921,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
num_outputs = len(forward_outputs_position_map.keys()) - len(
intermediate_outputs)
# Check Nan and Inf
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(function_name,
"api_result")
# Get Outputs
get_outputs_str = ""
for name, (rtype, pos) in forward_outputs_position_map.items():
......@@ -1032,10 +1048,10 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
self.forward_definition_str += FORWARD_FUNCTION_TEMPLATE.format(
returns_type_str, forward_function_name, inputs_args_definition_str,
dygraph_event_str, amp_logic_str, inputs_autograd_meta_str,
forward_function_name, forward_call_str, get_outputs_str,
outputs_autograd_meta_str, compute_require_grad_args_str,
check_inplace_str, bump_inplace_version_str, node_creation_str,
returns_str)
forward_function_name, forward_call_str, check_nan_inf_str,
get_outputs_str, outputs_autograd_meta_str,
compute_require_grad_args_str, check_inplace_str,
bump_inplace_version_str, node_creation_str, returns_str)
self.forward_declaration_str += f"{returns_type_str} {forward_function_name}({inputs_args_declaration_str});\n"
def GenerateInplacedForwardDygraphFunctions(self):
......@@ -1338,6 +1354,10 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
grad_function_call_str = grad_function_call_str + f"{indent}{grad_api_namespace}{backward_api_name}({grad_api_args_str});"
# Check Nan and Inf
check_nan_inf_str = CHECK_NAN_AND_INF_TEMPLATE.format(backward_api_name,
"returns")
# Prepare for Node Creation if Necessary
inputs_autograd_meta_str = ""
outputs_autograd_meta_str = ""
......@@ -1425,7 +1445,7 @@ class DygraphNodeGenerator(DygraphFunctionGeneratorBase):
self.node_definition_str = GRAD_FUNCTION_TEMPLATE.format(
grad_node_name, fill_zero_str, get_grad_in_args_str, grad_node_name,
grad_function_call_str, inputs_autograd_meta_str,
grad_function_call_str, check_nan_inf_str, inputs_autograd_meta_str,
outputs_autograd_meta_str, compute_require_grad_str,
grad_node_creation_str, returns_str)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/phi/core/dense_tensor.h"
#include "paddle/phi/core/selected_rows.h"
namespace egr {
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) {
if (tensor.initialized()) {
auto& tensor_name = tensor.name();
const phi::DenseTensor* dense_tensor{nullptr};
if (tensor.is_dense_tensor()) {
dense_tensor = static_cast<const phi::DenseTensor*>(tensor.impl().get());
} else if (tensor.is_selected_rows()) {
dense_tensor = &(
static_cast<const phi::SelectedRows*>(tensor.impl().get())->value());
} else {
VLOG(10) << "Only DenseTensor or SelectedRows need to check, "
<< tensor_name << " is no need.";
return;
}
auto& place = dense_tensor->place();
if (paddle::platform::is_gpu_place(place)) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
paddle::framework::details::tensor_check<
paddle::platform::CUDADeviceContext>(api_name, tensor_name,
*dense_tensor, place);
#else
PADDLE_THROW(paddle::platform::errors::PreconditionNotMet(
"Tensor[%s] use gpu place. PaddlePaddle must compile with GPU.",
tensor_name));
#endif
return;
}
paddle::framework::details::tensor_check<
paddle::platform::CPUDeviceContext>(api_name, tensor_name,
*dense_tensor, place);
}
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTwoTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfThreeTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfFourTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<3>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfFiveTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<3>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<4>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfSixTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<1>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<2>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<3>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<4>(tensors));
CheckTensorHasNanOrInf(api_name, std::get<5>(tensors));
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const std::vector<Tensor>& tensors) {
for (auto& tensor : tensors) {
CheckTensorHasNanOrInf(api_name, tensor);
}
}
// Runs the single-tensor check on every tensor of every inner vector, visiting
// the outer container front to back and each inner vector front to back.
void CheckTensorHasNanOrInf(
    const std::string& api_name,
    const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                               egr::kSlotSmallVectorSize>& tensors) {
  for (auto slot = tensors.begin(); slot != tensors.end(); ++slot) {
    for (const auto& tensor : *slot) {
      CheckTensorHasNanOrInf(api_name, tensor);
    }
  }
}
} // namespace egr
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <tuple>
#include <vector>
#include "paddle/fluid/eager/type_defs.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/small_vector.h"
namespace egr {
// Shorthand used by the declarations below.
using paddle::experimental::Tensor;
// Fixed-arity tuple aliases accepted by the CheckTensorHasNanOrInf overloads.
using TupleOfTwoTensors = std::tuple<Tensor, Tensor>;
using TupleOfThreeTensors = std::tuple<Tensor, Tensor, Tensor>;
using TupleOfFourTensors = std::tuple<Tensor, Tensor, Tensor, Tensor>;
using TupleOfFiveTensors = std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor>;
using TupleOfSixTensors =
std::tuple<Tensor, Tensor, Tensor, Tensor, Tensor, Tensor>;
// Checks one tensor for NaN/Inf values. `api_name` identifies the API whose
// output is being checked and is passed through to the underlying checker.
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);
// Tuple overloads: apply the single-tensor check to every tuple element.
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTwoTensors& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfThreeTensors& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfFourTensors& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfFiveTensors& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfSixTensors& tensors);
// Applies the single-tensor check to every element of the vector.
void CheckTensorHasNanOrInf(const std::string& api_name,
const std::vector<Tensor>& tensors);
// Applies the vector overload to every inner vector of the small_vector.
void CheckTensorHasNanOrInf(
const std::string& api_name,
const paddle::small_vector<std::vector<paddle::experimental::Tensor>,
egr::kSlotSmallVectorSize>& tensors);
} // namespace egr
cc_test(test_egr_task_tensor_utils SRCS tensor_utils_test.cc DEPS ${eager_deps})
cc_test(test_egr_task_eager_utils SRCS eager_utils_test.cc DEPS ${eager_deps})
cc_test(test_egr_task_forward_autograd SRCS forward_autograd_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_nan_inf_utils SRCS nan_inf_utils_test.cc DEPS eager_nan_inf_utils)
if(NOT ((NOT WITH_PYTHON) AND ON_INFER))
cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps} eager_scale scale_node)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <limits>
#include <tuple>
#include "gtest/gtest.h"
#include "paddle/fluid/eager/nan_inf_utils.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/phi/api/include/api.h"
#include "paddle/phi/api/include/strings_api.h"
#include "paddle/phi/core/kernel_registry.h"
PD_DECLARE_KERNEL(full, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(strings_empty, CPU, ALL_LAYOUT);
namespace egr {
// Expects CheckTensorHasNanOrInf("nan_inf_test", tensors) to throw
// paddle::platform::EnforceNotMet with a message that contains
// "There are `nan` or `inf` in tensor".
//
// The body is wrapped in do { ... } while (0) so the macro expands to exactly
// one statement and stays correct when followed by ';' inside if/else. The
// exception is caught by const reference because it is only read.
#define CHECK_NAN_INF(tensors)                                         \
  do {                                                                 \
    bool caught_exception = false;                                     \
    try {                                                              \
      CheckTensorHasNanOrInf("nan_inf_test", tensors);                 \
    } catch (const paddle::platform::EnforceNotMet& error) {           \
      caught_exception = true;                                         \
      std::string ex_msg = error.what();                               \
      EXPECT_TRUE(ex_msg.find("There are `nan` or `inf` in tensor") != \
                  std::string::npos);                                  \
    }                                                                  \
    EXPECT_TRUE(caught_exception);                                     \
  } while (0)
// Expects CheckTensorHasNanOrInf("nan_inf_test", tensors) NOT to throw
// paddle::platform::EnforceNotMet. If it does throw, the message is still
// matched against "There are `nan` or `inf` in tensor" before the final
// expectation fails.
//
// The body is wrapped in do { ... } while (0) so the macro expands to exactly
// one statement and stays correct when followed by ';' inside if/else. The
// exception is caught by const reference because it is only read.
#define CHECK_NO_NAN_INF(tensors)                                      \
  do {                                                                 \
    bool caught_exception = false;                                     \
    try {                                                              \
      CheckTensorHasNanOrInf("nan_inf_test", tensors);                 \
    } catch (const paddle::platform::EnforceNotMet& error) {           \
      caught_exception = true;                                         \
      std::string ex_msg = error.what();                               \
      EXPECT_TRUE(ex_msg.find("There are `nan` or `inf` in tensor") != \
                  std::string::npos);                                  \
    }                                                                  \
    EXPECT_FALSE(caught_exception);                                    \
  } while (0)
// Exercises every CheckTensorHasNanOrInf overload with NaN-filled tensors,
// then the SelectedRows path, and finally a strings tensor, which is expected
// to pass without an exception.
TEST(NanInfUtils, Functions) {
  // Builds a fresh 3x4 FLOAT64 tensor filled with quiet NaN.
  const auto make_nan_tensor = []() {
    return paddle::experimental::full(
        {3, 4}, std::numeric_limits<double>::quiet_NaN(),
        phi::DataType::FLOAT64);
  };

  // Single tensor.
  auto nan0 = make_nan_tensor();
  CHECK_NAN_INF(nan0);

  // Tuples of two to six tensors.
  auto nan1 = make_nan_tensor();
  auto pair_of_nans = std::make_tuple(nan0, nan1);
  CHECK_NAN_INF(pair_of_nans);

  auto nan2 = make_nan_tensor();
  auto triple_of_nans = std::make_tuple(nan0, nan1, nan2);
  CHECK_NAN_INF(triple_of_nans);

  auto nan3 = make_nan_tensor();
  auto quad_of_nans = std::make_tuple(nan0, nan1, nan2, nan3);
  CHECK_NAN_INF(quad_of_nans);

  auto nan4 = make_nan_tensor();
  auto quint_of_nans = std::make_tuple(nan0, nan1, nan2, nan3, nan4);
  CHECK_NAN_INF(quint_of_nans);

  auto nan5 = make_nan_tensor();
  auto sextet_of_nans = std::make_tuple(nan0, nan1, nan2, nan3, nan4, nan5);
  CHECK_NAN_INF(sextet_of_nans);

  // std::vector of tensors.
  std::vector<paddle::experimental::Tensor> nan_list;
  nan_list.push_back(nan0);
  nan_list.push_back(nan1);
  CHECK_NAN_INF(nan_list);

  // small_vector of tensor vectors.
  paddle::small_vector<std::vector<paddle::experimental::Tensor>,
                       egr::kSlotSmallVectorSize>
      nan_slots;
  nan_slots.push_back(nan_list);
  CHECK_NAN_INF(nan_slots);

  // SelectedRows whose value is a copy of the NaN dense tensor.
  auto selected_rows = std::make_shared<phi::SelectedRows>();
  const auto* nan_dense =
      static_cast<const phi::DenseTensor*>(nan0.impl().get());
  *selected_rows->mutable_value() = *nan_dense;
  paddle::experimental::Tensor selected_rows_tensor;
  selected_rows_tensor.set_impl(selected_rows);
  CHECK_NAN_INF(selected_rows_tensor);

  // A tensor that is neither DenseTensor nor SelectedRows must not throw.
  auto strings_tensor = paddle::experimental::strings::empty({3, 4});
  CHECK_NO_NAN_INF(strings_tensor);
}
} // namespace egr
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once

#include <cstddef>
namespace egr {

// Size passed as the second template argument of paddle::small_vector for the
// per-slot containers of tensor vectors used across the eager code.
// Spelled std::size_t (from <cstddef>) so this header does not rely on another
// header having declared the unqualified size_t first.
constexpr std::size_t kSlotSmallVectorSize = 15U;

}  // namespace egr
......@@ -15,6 +15,7 @@
#include "paddle/fluid/framework/details/nan_inf_utils.h"
#include "paddle/fluid/framework/details/nan_inf_utils_detail.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/scope.h"
#ifdef PADDLE_WITH_ASCEND_CL
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......
......@@ -19,7 +19,9 @@
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/scope.h"
namespace paddle {
namespace framework {
......
......@@ -16,7 +16,7 @@
#include <string>
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
namespace phi {
......
......@@ -25,8 +25,7 @@ os.environ[str("GLOG_vmodule")] = str("nan_inf_utils_detail=10")
import paddle
import paddle.nn as nn
from paddle.fluid.framework import _enable_legacy_dygraph
_enable_legacy_dygraph()
from paddle.fluid.framework import _test_eager_guard
np.random.seed(0)
......@@ -94,7 +93,7 @@ def check(use_cuda):
sgd.clear_grad()
if __name__ == '__main__':
def run_check():
if paddle.is_compiled_with_cuda():
try:
check(use_cuda=True)
......@@ -112,3 +111,9 @@ if __name__ == '__main__':
print(e)
print(type(e))
assert type(e) == RuntimeError
if __name__ == '__main__':
with _test_eager_guard():
run_check()
run_check()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册