未验证 提交 d50ae7ec 编写于 作者: Z Zhanlue Yang 提交者: GitHub

Enabled AutoCodeGen for Eager Dygraph (#37639)

上级 87b97776
generated/**
autocodegen/generated_example/
\ No newline at end of file
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor legacy autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
set(generated_deps dygraph_function dygraph_node)
if(NOT DEFINED ON_INFER)
message("Performing Eager Dygraph Auto Code Generation")
add_subdirectory(auto_code_generator)
endif()
add_subdirectory(api) add_subdirectory(api)
add_subdirectory(accumulation) add_subdirectory(accumulation)
add_subdirectory(tests) add_subdirectory(legacy)
cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api) cc_library(grad_node_info SRCS grad_node_info.cc DEPS pten pten_api)
cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation) cc_library(grad_tensor_holder SRCS grad_tensor_holder.cc DEPS grad_node_info gradient_accumulation)
cc_library(autograd_meta SRCS autograd_meta.cc DEPS pten pten_api)
cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta) cc_library(utils SRCS utils.cc DEPS pten pten_api global_utils layer proto_desc operator op_registry variable_helper memcpy scale_op autograd_meta)
cc_library(legacy SRCS ${DYGRAPH_LEGACY} DEPS global_utils proto_desc operator pten pten_api op_registry variable_helper memcpy)
cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info) cc_library(backward SRCS backward.cc DEPS grad_tensor_holder utils autograd_meta grad_node_info)
add_subdirectory(tests)
add_subdirectory(eager_generated) add_subdirectory(eager_generated)
if(NOT DEFINED ON_INFER)
add_subdirectory(fluid_generated)
endif()
...@@ -6,13 +6,24 @@ target_link_libraries(eager_generator ${EAGER_GENERETOR_DEPS}) ...@@ -6,13 +6,24 @@ target_link_libraries(eager_generator ${EAGER_GENERETOR_DEPS})
get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES) get_property (os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(eager_generator ${os_dependency_modules}) target_link_libraries(eager_generator ${os_dependency_modules})
if(WITH_ROCM)
target_link_libraries(eager_generator ${ROCM_HIPRTC_LIB})
endif()
# Prepare file structure # Prepare file structure
message("Generate dygraph file structure at path: ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/generated") message("Generate dygraph file structure at path: ${PADDLE_SOURCE_DIR}/paddle/fluid/eager/generated")
execute_process( execute_process(
COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generate_file_structures.py" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/" COMMAND "${PYTHON_EXECUTABLE}" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/auto_code_generator/generate_file_structures.py" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/"
) )
add_custom_target(eager_codegen if(WIN32)
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated" add_custom_target(eager_codegen
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator.exe" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
DEPENDS eager_generator DEPENDS eager_generator
VERBATIM) VERBATIM)
else()
add_custom_target(eager_codegen
COMMAND "${CMAKE_CURRENT_BINARY_DIR}/eager_generator" "${PADDLE_SOURCE_DIR}/paddle/fluid/eager/api/generated/fluid_generated"
DEPENDS eager_generator
VERBATIM)
endif()
...@@ -577,11 +577,6 @@ static std::string GenerateGradNodeCreationContent( ...@@ -577,11 +577,6 @@ static std::string GenerateGradNodeCreationContent(
// If single output slotname and not duplicable, // If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out = // then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")" // egr::EagerUtils::autograd_meta("op_proto->outputs()[0].name()")"
// TODO(zhanlve): in case of multiple slotname but none of which are
// duplicable,
// avoid constructing vector<AutogradMeta*>, generate seperate
// AutogradMeta* objects respectively.
std::string get_autograd_meta_str = " // Prepare Autograd Meta \n"; std::string get_autograd_meta_str = " // Prepare Autograd Meta \n";
for (const proto::OpProto::Var& input : op_proto.inputs()) { for (const proto::OpProto::Var& input : op_proto.inputs()) {
const std::string& input_name = input.name(); const std::string& input_name = input.name();
...@@ -607,11 +602,6 @@ static std::string GenerateGradNodeCreationContent( ...@@ -607,11 +602,6 @@ static std::string GenerateGradNodeCreationContent(
// If single output slotname and not duplicable, // If single output slotname and not duplicable,
// then generate: "egr::AutogradMeta* p_autograd_out = // then generate: "egr::AutogradMeta* p_autograd_out =
// egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")" // egr::EagerUtils::autograd_meta("op_proto.outputs()[0].name()")"
// TODO(zhanlve): in case of multiple slotname but none of which are
// duplicable,
// avoid constructing vector<AutogradMeta*>, generate seperate
// AutogradMeta* objects respectively.
for (const proto::OpProto::Var& output : op_proto.outputs()) { for (const proto::OpProto::Var& output : op_proto.outputs()) {
const std::string& output_name = output.name(); const std::string& output_name = output.name();
const std::string& output_autograd_name = "p_autograd_" + output_name; const std::string& output_autograd_name = "p_autograd_" + output_name;
...@@ -725,9 +715,9 @@ static std::string GenerateGradNodeCreationContent( ...@@ -725,9 +715,9 @@ static std::string GenerateGradNodeCreationContent(
// [Generation] GradNode Creation // [Generation] GradNode Creation
const char* GRAD_NODE_CREATION_TEMPLATE = const char* GRAD_NODE_CREATION_TEMPLATE =
" %s" " %s"
" bool require_any_grad = egr::ComputeRequireGrad(%s);\n" " bool require_any_grad = egr::EagerUtils::ComputeRequireGrad(%s);\n"
" if(require_any_grad) {\n" " if(require_any_grad) {\n"
" egr::PassStopGradient(%s);\n" " egr::EagerUtils::PassStopGradient(%s);\n"
"%s\n }"; "%s\n }";
std::string grad_node_creation_body_str = paddle::string::Sprintf( std::string grad_node_creation_body_str = paddle::string::Sprintf(
GRAD_NODE_CREATION_TEMPLATE, prepare_autograd_meta_str, GRAD_NODE_CREATION_TEMPLATE, prepare_autograd_meta_str,
...@@ -793,7 +783,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -793,7 +783,7 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
Controller.Instance().GetExpectedPlace(), {}); Controller.Instance().GetExpectedPlace(), {});
// According to fwd_outputs_names // According to fwd_outputs_names
std::vector<egr::EagerTensor> Out0 = GetOutputs(outs["Out0"]); std::vector<egr::EagerTensor> Out0 = GGetOutputetOutputs(outs["Out0"]);
egr::EagerTensor Out1 = GetOutputs(outs["Out1"][0]); egr::EagerTensor Out1 = GetOutputs(outs["Out1"][0]);
std::vector<egr::EagerTensor> Out2 = GetOutputs(outs["Out2"]); std::vector<egr::EagerTensor> Out2 = GetOutputs(outs["Out2"]);
...@@ -830,7 +820,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -830,7 +820,8 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
input_args_str_list[input_position] = input_args_str_list[input_position] =
paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name); paddle::string::Sprintf(FWD_INS_ARG_TEMPLATE, input_name);
} }
const char* FWD_INS_CONTENT_TEMPLATE = "{ \"%s\", egr::SyncToVars(%s) },"; const char* FWD_INS_CONTENT_TEMPLATE =
"{ \"%s\", egr::EagerUtils::SyncToVars(%s) },";
ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE, ins_contents_str += paddle::string::Sprintf(FWD_INS_CONTENT_TEMPLATE,
input_name, input_name); input_name, input_name);
} }
...@@ -925,14 +916,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents( ...@@ -925,14 +916,14 @@ static std::pair<std::string, std::string> GenerateForwardFunctionContents(
if (output.duplicable()) { if (output.duplicable()) {
const char* FWD_OUT_TENSORS_TEMPLATE = const char* FWD_OUT_TENSORS_TEMPLATE =
" std::vector<egr::EagerTensor> %s = " " std::vector<egr::EagerTensor> %s = "
"egr::GetOutputs(outs[\"%s\"]);\n"; "egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE, out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSORS_TEMPLATE,
output_name, output_name); output_name, output_name);
return_types[return_position] = "std::vector<egr::EagerTensor>"; return_types[return_position] = "std::vector<egr::EagerTensor>";
} else { } else {
const char* FWD_OUT_TENSOR_TEMPLATE = const char* FWD_OUT_TENSOR_TEMPLATE =
" egr::EagerTensor %s = " " egr::EagerTensor %s = "
"egr::GetOutput(outs[\"%s\"][0]);\n"; "egr::EagerUtils::GetOutput(outs[\"%s\"][0]);\n";
out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE, out_tensor_str = paddle::string::Sprintf(FWD_OUT_TENSOR_TEMPLATE,
output_name, output_name); output_name, output_name);
return_types[return_position] = "egr::EagerTensor"; return_types[return_position] = "egr::EagerTensor";
...@@ -1093,7 +1084,8 @@ static std::string GenerateGradNodeCCContents( ...@@ -1093,7 +1084,8 @@ static std::string GenerateGradNodeCCContents(
grad_ins_fwd_slotname_map.at(grad_input_name) + "_"; grad_ins_fwd_slotname_map.at(grad_input_name) + "_";
const char* GRAD_INS_FWD_CONTENT_TEMPLATE = const char* GRAD_INS_FWD_CONTENT_TEMPLATE =
"{ \"%s\", " "{ \"%s\", "
"egr::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&this->%s, " "egr::EagerUtils::SyncToVars(egr::EagerUtils::RecoverTensorWrapper(&"
"this->%s, "
"nullptr)) },"; "nullptr)) },";
ins_contents_str += ins_contents_str +=
paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE, paddle::string::Sprintf(GRAD_INS_FWD_CONTENT_TEMPLATE,
...@@ -1104,7 +1096,7 @@ static std::string GenerateGradNodeCCContents( ...@@ -1104,7 +1096,7 @@ static std::string GenerateGradNodeCCContents(
size_t fwd_output_position = fwd_outputs_name_pos_map.at( size_t fwd_output_position = fwd_outputs_name_pos_map.at(
grad_ins_grad_slotname_map.at(grad_input_name)); grad_ins_grad_slotname_map.at(grad_input_name));
const char* GRAD_INS_GRAD_CONTENT_TEMPLATE = const char* GRAD_INS_GRAD_CONTENT_TEMPLATE =
"{ \"%s\", egr::SyncToVars(grads[%d]) },"; "{ \"%s\", egr::EagerUtils::SyncToVars(grads[%d]) },";
ins_contents_str += paddle::string::Sprintf( ins_contents_str += paddle::string::Sprintf(
GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position); GRAD_INS_GRAD_CONTENT_TEMPLATE, grad_input_name, fwd_output_position);
...@@ -1206,7 +1198,7 @@ static std::string GenerateGradNodeCCContents( ...@@ -1206,7 +1198,7 @@ static std::string GenerateGradNodeCCContents(
fwd_inputs_name_pos_map.at(grad_outs_slotname_map.at(grad_out_name)); fwd_inputs_name_pos_map.at(grad_outs_slotname_map.at(grad_out_name));
const char* BWD_OUTPUT_TEMPLATE = const char* BWD_OUTPUT_TEMPLATE =
" outputs[%d] = GetOutputs(outs[\"%s\"]);\n"; " outputs[%d] = egr::EagerUtils::GetOutputs(outs[\"%s\"]);\n";
outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE, outputs_str += paddle::string::Sprintf(BWD_OUTPUT_TEMPLATE,
fwd_input_position, grad_out_name); fwd_input_position, grad_out_name);
} }
...@@ -1526,6 +1518,9 @@ static void DygraphCodeGeneration(const std::string& output_dir) { ...@@ -1526,6 +1518,9 @@ static void DygraphCodeGeneration(const std::string& output_dir) {
GenerateForwardHFile(output_dir, dygraph_forward_api_str); GenerateForwardHFile(output_dir, dygraph_forward_api_str);
} }
} // namespace framework
} // namespace paddle
int main(int argc, char* argv[]) { int main(int argc, char* argv[]) {
if (argc != 2) { if (argc != 2) {
std::cerr << "argc must be 2" << std::endl; std::cerr << "argc must be 2" << std::endl;
...@@ -1537,6 +1532,3 @@ int main(int argc, char* argv[]) { ...@@ -1537,6 +1532,3 @@ int main(int argc, char* argv[]) {
return 0; return 0;
} }
} // namespace framework
} // namespace paddle
...@@ -20,7 +20,7 @@ ...@@ -20,7 +20,7 @@
#include "paddle/fluid/framework/pten_utils.h" #include "paddle/fluid/framework/pten_utils.h"
#include "paddle/utils/small_vector.h" #include "paddle/utils/small_vector.h"
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
#include "paddle/fluid/platform/xpu/xpu_op_list.h" #include "paddle/fluid/platform/device/xpu/xpu_op_list.h"
#endif #endif
DECLARE_bool(check_nan_inf); DECLARE_bool(check_nan_inf);
DECLARE_bool(run_pten_kernel); DECLARE_bool(run_pten_kernel);
......
set(eager_deps pten pten_api hook_utils tensor_utils utils global_utils backward pten_tensor autograd_meta grad_node_info grad_tensor_holder gradient_accumulation accumulation_node)
set(fluid_deps tracer layer proto_desc operator op_registry variable_helper memcpy)
add_subdirectory(data_structure_tests) add_subdirectory(data_structure_tests)
add_subdirectory(task_tests) add_subdirectory(task_tests)
...@@ -5,3 +5,7 @@ cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_ ...@@ -5,3 +5,7 @@ cc_test(test_egr_task_backward SRCS backward_test.cc DEPS ${eager_deps} ${fluid_
cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) cc_test(test_egr_task_hook SRCS hook_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) cc_test(test_egr_task_cross_batch SRCS cross_batch_accumulation_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node) cc_test(test_egr_task_fwd_bwd_joint SRCS fwd_bwd_joint_test.cc DEPS ${eager_deps} ${fluid_deps} eager_scale scale_node)
if(NOT DEFINED ON_INFER)
cc_test(test_egr_task_autocodegen SRCS generated_test.cc DEPS ${eager_deps} ${fluid_deps} ${generated_deps})
endif()
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// Eager Dygraph
#include <chrono>
#include "gtest/gtest.h"
#include "paddle/fluid/eager/api/all.h"
#include "paddle/fluid/eager/api/utils/tensor_utils.h"
#include "paddle/fluid/eager/autograd_meta.h"
#include "paddle/fluid/eager/backward.h"
#include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/tests/test_utils.h"
#include "paddle/fluid/imperative/tracer.h"
#include "paddle/fluid/eager/api/generated/fluid_generated/dygraph_forward_api.h"
#include "paddle/pten/core/kernel_registry.h"
// TODO(jiabin): remove nolint here!!!
using namespace egr; // NOLINT
namespace eager_test {
TEST(Generated, Sigmoid) {
// Prepare Device Contexts
InitEnv(paddle::platform::CPUPlace());
VLOG(6) << "Init Env";
// 1. Prepare Input
paddle::framework::DDim ddim = paddle::framework::make_ddim({2, 4, 4, 4});
VLOG(6) << "Make Dim";
egr::EagerTensor tensor = CreateTensorWithValue(
ddim, paddle::platform::CPUPlace(), pten::DataType::FLOAT32,
pten::DataLayout::NCHW, 0.0, true);
VLOG(6) << "Make EagerTensor";
RetainGradForTensor(tensor);
VLOG(6) << "Retain Grad for Tensor";
auto output_tensor = sigmoid_dygraph_function(tensor, {});
VLOG(6) << "Run Backward";
CompareVariableWithValue<float>(output_tensor, 0.5);
std::vector<egr::EagerTensor> target_tensors = {output_tensor};
VLOG(6) << "Runing Backward";
RunBackward(target_tensors, {});
VLOG(6) << "Finish Backward";
CompareGradVariableWithValue<float>(tensor, 0.25);
}
TEST(Generated, Matmul_v2) {
// Prepare Device Contexts
InitEnv(paddle::platform::CPUPlace());
auto tracer = std::make_shared<paddle::imperative::Tracer>();
paddle::imperative::SetCurrentTracer(tracer);
// 1. Prepare Input
paddle::framework::DDim ddimX = paddle::framework::make_ddim({4, 16});
egr::EagerTensor X = CreateTensorWithValue(
ddimX, paddle::platform::CPUPlace(), pten::DataType::FLOAT32,
pten::DataLayout::NCHW, 3.0, true);
RetainGradForTensor(X);
paddle::framework::DDim ddimY = paddle::framework::make_ddim({16, 20});
egr::EagerTensor Y = CreateTensorWithValue(
ddimY, paddle::platform::CPUPlace(), pten::DataType::FLOAT32,
pten::DataLayout::NCHW, 2.0, true);
RetainGradForTensor(Y);
auto output_tensor = matmul_v2_dygraph_function(
X, Y, {{"trans_x", false}, {"trans_y", false}});
CompareVariableWithValue<float>(output_tensor, 96);
std::vector<egr::EagerTensor> target_tensors = {output_tensor};
RunBackward(target_tensors, {});
CompareGradVariableWithValue<float>(X, 2.0 * 20);
CompareGradVariableWithValue<float>(Y, 3.0 * 4);
}
} // namespace eager_test
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include "paddle/fluid/eager/utils.h" #include "paddle/fluid/eager/utils.h"
#include "paddle/fluid/eager/api/utils/global_utils.h" #include "paddle/fluid/eager/api/utils/global_utils.h"
#include "paddle/fluid/eager/tensor_wrapper.h"
#include "paddle/pten/api/all.h" #include "paddle/pten/api/all.h"
#include "paddle/pten/common/layout.h" #include "paddle/pten/common/layout.h"
...@@ -188,4 +189,19 @@ egr::EagerTensor EagerUtils::GetOutput( ...@@ -188,4 +189,19 @@ egr::EagerTensor EagerUtils::GetOutput(
return EagerTensor((*(out.get()))); return EagerTensor((*(out.get())));
} }
EagerTensor EagerUtils::RecoverTensorWrapper(
TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node) {
return tw->recover(grad_node);
}
std::vector<EagerTensor> EagerUtils::RecoverTensorWrapper(
std::vector<TensorWrapper>* tw,
const std::shared_ptr<GradNodeBase>& grad_node) {
std::vector<EagerTensor> ret;
for (auto& t : *tw) {
ret.emplace_back(t.recover(grad_node));
}
return ret;
}
} // namespace egr } // namespace egr
...@@ -22,6 +22,8 @@ ...@@ -22,6 +22,8 @@
namespace egr { namespace egr {
class TensorWrapper;
/** /**
* EagerUtils is utils used to do some static conversion or autograd * EagerUtils is utils used to do some static conversion or autograd
* members access, this class is desinged to be a full static functional * members access, this class is desinged to be a full static functional
...@@ -131,6 +133,13 @@ class EagerUtils { ...@@ -131,6 +133,13 @@ class EagerUtils {
iter.apply(std::forward<Args>(args)...); iter.apply(std::forward<Args>(args)...);
} }
// TensorWrapper Utils
static egr::EagerTensor RecoverTensorWrapper(
egr::TensorWrapper* tw, const std::shared_ptr<GradNodeBase>& grad_node);
static std::vector<egr::EagerTensor> RecoverTensorWrapper(
std::vector<egr::TensorWrapper>* tw,
const std::shared_ptr<GradNodeBase>& grad_node);
// Intermidate needed remove this once we don't need legacy // Intermidate needed remove this once we don't need legacy
static std::vector<std::shared_ptr<egr::EagerTensor>> SyncToVars( static std::vector<std::shared_ptr<egr::EagerTensor>> SyncToVars(
const egr::EagerTensor& tensor); const egr::EagerTensor& tensor);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include <string> #include <string>
#include <vector> #include <vector>
#include "paddle/fluid/eager/legacy/type_def.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/imperative/type_defs.h" #include "paddle/fluid/imperative/type_defs.h"
...@@ -53,6 +54,19 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type, ...@@ -53,6 +54,19 @@ void CheckOpHasNanOrInfInDygraph(const std::string& op_type,
} }
} }
template <typename TensorType>
static void CheckOpHasNanOrInfInEager(const std::string& op_type,
const egr::NameMap<TensorType>& op_outs,
platform::Place place) {
for (const auto& pair : op_outs) {
for (const auto& tensor : pair.second) {
auto* var = tensor->MutableVar();
if (var == nullptr) continue;
CheckVarHasNanOrInf(op_type, tensor->name(), var, place);
}
}
}
#ifdef PADDLE_WITH_ASCEND_CL #ifdef PADDLE_WITH_ASCEND_CL
void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op, void NPUAllocAndClearFloatStatus(const framework::OperatorBase& op,
const framework::ScopeBase& scope, const framework::ScopeBase& scope,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册