未验证 提交 6366cffe 编写于 作者: W wanghuancoder 提交者: GitHub

fix check nan bug (#52729)

上级 439551bd
......@@ -276,6 +276,8 @@ FORWARD_ONLY_FUNCTION_TEMPLATE = """
// Before log info
{}
// Forward API Call
{}
// Check NaN and Inf if needed
{}
// Get Outputs
{}
......@@ -1675,6 +1677,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
forward_api_name,
before_log_str,
forward_call_str,
check_nan_inf_str,
get_outputs_str,
forward_api_name,
check_inplace_str,
......
......@@ -122,6 +122,11 @@ void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) {
}
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const paddle::optional<Tensor>& tensor) {
CheckTensorHasNanOrInf(api_name, tensor.get());
}
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTwoTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
......@@ -169,6 +174,14 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
}
}
void CheckTensorHasNanOrInf(
const std::string& api_name,
const paddle::optional<std::vector<Tensor>>& tensors) {
if (tensors) {
CheckTensorHasNanOrInf(api_name, tensors.get());
}
}
void CheckTensorHasNanOrInf(
const std::string& api_name,
const paddle::small_vector<std::vector<paddle::Tensor>,
......
......@@ -20,6 +20,7 @@
#include "paddle/fluid/eager/type_defs.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h"
namespace egr {
......@@ -36,6 +37,9 @@ using TupleOfTensorAndVector =
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);
void CheckTensorHasNanOrInf(const std::string& api_name,
const paddle::optional<Tensor>& tensor);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTwoTensors& tensors);
......@@ -54,6 +58,10 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
void CheckTensorHasNanOrInf(const std::string& api_name,
const std::vector<Tensor>& tensors);
void CheckTensorHasNanOrInf(
const std::string& api_name,
const paddle::optional<std::vector<Tensor>>& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册