未验证 提交 6366cffe 编写于 作者: W wanghuancoder 提交者: GitHub

fix check nan bug (#52729)

上级 439551bd
...@@ -276,6 +276,8 @@ FORWARD_ONLY_FUNCTION_TEMPLATE = """ ...@@ -276,6 +276,8 @@ FORWARD_ONLY_FUNCTION_TEMPLATE = """
// Before log info // Before log info
{} {}
// Forward API Call // Forward API Call
{}
// Check NaN and Inf if needed
{} {}
// Get Outputs // Get Outputs
{} {}
...@@ -1675,6 +1677,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase): ...@@ -1675,6 +1677,7 @@ class DygraphForwardFunctionGenerator(DygraphFunctionGeneratorBase):
forward_api_name, forward_api_name,
before_log_str, before_log_str,
forward_call_str, forward_call_str,
check_nan_inf_str,
get_outputs_str, get_outputs_str,
forward_api_name, forward_api_name,
check_inplace_str, check_inplace_str,
......
...@@ -122,6 +122,11 @@ void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) { ...@@ -122,6 +122,11 @@ void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor) {
} }
} }
void CheckTensorHasNanOrInf(const std::string& api_name,
const paddle::optional<Tensor>& tensor) {
CheckTensorHasNanOrInf(api_name, tensor.get());
}
void CheckTensorHasNanOrInf(const std::string& api_name, void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTwoTensors& tensors) { const TupleOfTwoTensors& tensors) {
CheckTensorHasNanOrInf(api_name, std::get<0>(tensors)); CheckTensorHasNanOrInf(api_name, std::get<0>(tensors));
...@@ -169,6 +174,14 @@ void CheckTensorHasNanOrInf(const std::string& api_name, ...@@ -169,6 +174,14 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
} }
} }
void CheckTensorHasNanOrInf(
const std::string& api_name,
const paddle::optional<std::vector<Tensor>>& tensors) {
if (tensors) {
CheckTensorHasNanOrInf(api_name, tensors.get());
}
}
void CheckTensorHasNanOrInf( void CheckTensorHasNanOrInf(
const std::string& api_name, const std::string& api_name,
const paddle::small_vector<std::vector<paddle::Tensor>, const paddle::small_vector<std::vector<paddle::Tensor>,
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "paddle/fluid/eager/type_defs.h" #include "paddle/fluid/eager/type_defs.h"
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/optional.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
namespace egr { namespace egr {
...@@ -36,6 +37,9 @@ using TupleOfTensorAndVector = ...@@ -36,6 +37,9 @@ using TupleOfTensorAndVector =
void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor); void CheckTensorHasNanOrInf(const std::string& api_name, const Tensor& tensor);
void CheckTensorHasNanOrInf(const std::string& api_name,
const paddle::optional<Tensor>& tensor);
void CheckTensorHasNanOrInf(const std::string& api_name, void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTwoTensors& tensors); const TupleOfTwoTensors& tensors);
...@@ -54,6 +58,10 @@ void CheckTensorHasNanOrInf(const std::string& api_name, ...@@ -54,6 +58,10 @@ void CheckTensorHasNanOrInf(const std::string& api_name,
void CheckTensorHasNanOrInf(const std::string& api_name, void CheckTensorHasNanOrInf(const std::string& api_name,
const std::vector<Tensor>& tensors); const std::vector<Tensor>& tensors);
void CheckTensorHasNanOrInf(
const std::string& api_name,
const paddle::optional<std::vector<Tensor>>& tensors);
void CheckTensorHasNanOrInf(const std::string& api_name, void CheckTensorHasNanOrInf(const std::string& api_name,
const TupleOfTensorAndVector& tensors); const TupleOfTensorAndVector& tensors);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册