Unverified commit 2f50ae99 authored by Zhanlue Yang, committed by GitHub

Support initializing specific grad tensors to zero for selected operators (#39963)

* Supported Complex2Real Conversion for Eager Dygraph

* Supported Complex2Real Conversion for Eager Dygraph

* Enabled complex type promotion test for matmul_v2

* Fix CI issues

* Support initializing specific grad tensors to zero for selected operators

* Merged adj_edges_ with GradSlotMeta

* Fixed minor issue

* Adjusted num runs

* Recovered Eager performance tests configurations

* Recovered Eager performance tests configurations

* Adjusted performance tests configurations

* Fixed Minor Issues with performance tests

* Moved out Edge from GradSlotMeta

* Fixed issues from merge

* Fixed typo

* Addressed review comments

* Fixed merge issues

* Fixed minor issues

* Fixed minor issue

* Fixed major issues and enabled auto_prune test cases

* Fixed issues from merge
Parent 8991e9ae
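The mechanics, in brief: operators on the fill-zero list (currently only split) get a generated backward whose operator() first replaces any never-produced grad slot with a zero tensor built from the GradSlotMeta recorded at forward time. A rough sketch of that flow, using only names introduced in the diff below (the call site stands for the body of a generated grad node):

// 1. Codegen emits this call only for ops in ops_to_fill_zero_for_empty_grads.
egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());
// 2. Inside, every uninitialized slot becomes a zero tensor shaped by the
//    GradSlotMeta captured in SetGradInMeta (dims, dtype, and place).
// 3. The remainder of the generated grad function then runs on fully
//    initialized inputs.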
@@ -39,8 +39,9 @@ static void CopyOrAddTensor(paddle::experimental::Tensor* tensor,
 }
 
 std::vector<std::vector<paddle::experimental::Tensor>> GradNodeAccumulation::
-operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
-           bool create_graph) {
+operator()(
+    std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
+    bool create_graph) {
   VLOG(3) << "Running Eager Backward Node: GradNodeAccumulation";
   PADDLE_ENFORCE(grads.size() == 1,
                  paddle::platform::errors::Fatal(
...
@@ -35,7 +35,7 @@ class GradNodeAccumulation : public GradNodeBase {
 
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
       bool create_graph = false) override;
 
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
...
@@ -145,8 +145,9 @@ void GradNodeScale::SetTensorWrappers_X(
 void GradNodeScale::SetAttributes_scale(float scale) { scale_ = scale; }
 
 std::vector<std::vector<paddle::experimental::Tensor>> GradNodeScale::
-operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
-           bool create_graph) {
+operator()(
+    std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
+    bool create_graph) {
   // 1. Check Output Size
   PADDLE_ENFORCE(
       ((grads.size() == 1) && (grads[0].size() == 1)),
...
@@ -39,7 +39,7 @@ class GradNodeScale : public GradNodeBase {
 
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
       bool create_graph = false) override;
 
   void ClearTensorWrappers() override { VLOG(6) << "Do nothing here now"; }
...
@@ -47,6 +47,9 @@ std::unordered_map<std::string, std::vector<std::string>>
 static std::unordered_map<std::string, paddle::framework::AttributeMap>
     operators_with_attrs = {};
 
+static std::unordered_set<std::string> ops_to_fill_zero_for_empty_grads = {
+    "split"};
+
 /* --- Black Ops list that's NO NEED to apply code generation --- */
 static std::unordered_set<std::string> black_ops_list = {"run_program"};
@@ -2243,11 +2246,21 @@ static std::string GenerateGradNodeCCContents(
   // [Generation] Get Full Grad Function
   const char* GRAD_FUNCTION_TEMPLATE =
       "std::vector<std::vector<paddle::experimental::Tensor>> "
-      "GradNode%s::operator()(const "
-      "std::vector<std::vector<paddle::experimental::Tensor>>& grads, "
-      "bool create_graph) {\n%s\n}";
-  std::string grad_function_str = paddle::string::Sprintf(
-      GRAD_FUNCTION_TEMPLATE, fwd_op_type, generated_grad_function_body);
+      "GradNode%s::operator()("
+      "std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool "
+      "create_graph) {\n"
+      "%s"
+      "%s"
+      "\n}";
+
+  std::string fill_zero_str = "";
+  if (ops_to_fill_zero_for_empty_grads.count(fwd_op_type)) {
+    fill_zero_str =
+        "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, "
+        "this->InputMeta());\n";
+  }
+
+  std::string grad_function_str =
+      paddle::string::Sprintf(GRAD_FUNCTION_TEMPLATE, fwd_op_type,
+                              fill_zero_str, generated_grad_function_body);
 
   VLOG(6) << "Generated returns";
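For concreteness, a hand-expanded sketch of what this template yields for an op on the fill-zero list; the node name follows the GradNode%s pattern (GradNodesplit here is illustrative) and the grad-function body is elided:

// Hypothetical expansion of GRAD_FUNCTION_TEMPLATE for the "split" op.
std::vector<std::vector<paddle::experimental::Tensor>> GradNodesplit::operator()(
    std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {
  // fill_zero_str: emitted only for ops in ops_to_fill_zero_for_empty_grads
  egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());
  // ... generated_grad_function_body ...
}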
@@ -2279,9 +2292,9 @@ static std::string GenerateGradNodeHeaderContents(
       "  ~GradNode%s() override { VLOG(6) << \" Destruct GradNode%s \"; }\n"
       "\n"
       "  virtual std::vector<std::vector<paddle::experimental::Tensor>> "
-      "operator()(const "
-      "std::vector<std::vector<paddle::experimental::Tensor>>& grads, const "
-      "bool create_graph = false) "
+      "operator()("
+      "std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool "
+      "create_graph = false) "
       "override;\n"
       "\n"
       "  void ClearTensorWrappers() override { \n"
...
@@ -17,6 +17,8 @@ import re
 import argparse
 import os
 
+ops_to_fill_zero_for_empty_grads = set(["split"])
+
 # For API dispatch used at python-level
 # { op_name : [arg_name, ...] }
 core_ops_returns_info = {}
@@ -598,7 +600,8 @@ class {} : public egr::GradNodeBase {{
   ~{}() override = default;
 
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph = false) override;
 
   std::string name() override {{ return \" {} \"; }}
 
   void ClearTensorWrappers() override {{
@@ -656,10 +659,11 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
     for _, (ttype, fwd_position,
             grad_api_position) in backward_grad_input_map.items():
         if IsPlainTensorType(ttype):
-            grad_api_args[grad_api_position] = f"grads[{fwd_position}][0]"
+            grad_api_args[
+                grad_api_position] = f"hooked_grads[{fwd_position}][0]"
         else:
             assert IsVectorTensorType(ttype)
-            grad_api_args[grad_api_position] = f"grads[{fwd_position}]"
+            grad_api_args[grad_api_position] = f"hooked_grads[{fwd_position}]"
 
     for name, _, _, grad_api_position in backward_attrs_list:
         saved_attribute_name = GetSavedName(name)
@@ -687,23 +691,30 @@ def GenerateNodeDefinition(fwd_api_name, bwd_api_name, backward_fwd_input_map,
 
     grad_node_name = GetGradNodeName(fwd_api_name)
 
+    fill_zero_str = ""
+    if fwd_api_name in ops_to_fill_zero_for_empty_grads:
+        fill_zero_str = "egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());\n"
+
     if len(namespace) > 0:
         grad_api_namespace = f"paddle::experimental::{namespace}"
     else:
         grad_api_namespace = f"paddle::experimental"
 
     FUNCTION_TEMPLATE = """
-std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
+std::vector<std::vector<paddle::experimental::Tensor>> {}::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {{
+    {}
+    auto hooked_grads = ApplyGradientHooks(grads);
+
     // Call grad_api function
-    VLOG(3) << \"Finally State Running: \" << \"{}\";
+    VLOG(3) << \"Final State Running: \" << \"{}\";
     auto grad_api_returns = {}::{}({});
     {}
 }}
 """
 
     node_definition_str = FUNCTION_TEMPLATE.format(
-        grad_node_name, grad_node_name, grad_api_namespace, bwd_api_name,
-        grad_api_args_str, returns_str)
+        grad_node_name, fill_zero_str, grad_node_name, grad_api_namespace,
+        bwd_api_name, grad_api_args_str, returns_str)
 
     return node_definition_str
...
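With those format arguments substituted, the final-state generated definition for a listed op reads roughly as below; GradNodefoo, foo_grad, and the call arguments are hypothetical placeholders, and only the first two statements are new in this change:

std::vector<std::vector<paddle::experimental::Tensor>> GradNodefoo::operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads, bool create_graph) {
    // fill_zero_str (only present when "foo" is in ops_to_fill_zero_for_empty_grads)
    egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());
    auto hooked_grads = ApplyGradientHooks(grads);

    // Call grad_api function
    VLOG(3) << "Final State Running: " << "GradNodefoo";
    auto grad_api_returns = paddle::experimental::foo_grad(hooked_grads[0][0]);
    // ... returns_str: pack grad_api_returns into the per-slot output vector ...
}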
@@ -20,8 +20,8 @@
 namespace egr {
 
 std::vector<std::vector<paddle::experimental::Tensor>> RunCustomOpNode::
-operator()(const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
-           bool create_graph) {
+operator()(std::vector<std::vector<paddle::experimental::Tensor>>& grads,
+           bool create_graph) {  // NOLINT
   paddle::CustomOpKernelContext ctx;
   auto grad_inputs_name = paddle::framework::OpMetaInfoHelper::GetInputs(
       egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
...
@@ -37,8 +37,9 @@ class RunCustomOpNode : public GradNodeBase {
 
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
-      bool create_graph) override;
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads,
+      bool create_graph = false)  // NOLINT
+      override;
 
   std::string name() {
     return paddle::string::Sprintf("RunCustomOpNode: %s_grad", op_type_);
...
@@ -102,6 +102,7 @@ const std::vector<std::vector<GradSlotMeta>>& GradNodeBase::OutputMeta() const {
 
 void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
                                  size_t slot_rank) {
+  VLOG(6) << "Set GradSlotMeta for Grad Inputs";
   auto* fwd_out_meta = egr::EagerUtils::nullable_autograd_meta(fwd_out);
   PADDLE_ENFORCE_LE(
       slot_rank, (bwd_in_meta_.size() - 1),
@@ -117,6 +118,12 @@ void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
   auto& meta = metas[0];
   meta.SetStopGradient(fwd_out_meta->StopGradient());
 
+  if (!fwd_out.is_initialized()) {
+    VLOG(6)
+        << "Skip Configuring GradSlotMeta for uninitialized GradInput Tensor";
+    return;
+  }
+
   // Record TensorMeta
   if (phi::DenseTensor::classof(fwd_out.impl().get())) {
     // Only Copy Meta
@@ -128,7 +135,9 @@ void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
           paddle::platform::errors::Fatal(
               "Attempting to copy DenseTensorMeta with phi::DataType::UNDEFINED,"
               "which is illegal."));
+
       meta.SetTensorMeta(dense_tensor->meta());
+      meta.SetPlace(fwd_out.inner_place());
 
       if (paddle::framework::IsComplexType(
               paddle::framework::TransToProtoVarType(dense_tensor->type()))) {
@@ -143,6 +152,7 @@ void GradNodeBase::SetGradInMeta(const paddle::experimental::Tensor& fwd_out,
 void GradNodeBase::SetGradInMeta(
     const std::vector<paddle::experimental::Tensor>& fwd_out,
     size_t slot_rank) {
+  VLOG(6) << "Set GradSlotMeta for Grad Inputs";
   size_t slot_size = fwd_out.size();
   PADDLE_ENFORCE_LE(
       slot_rank, (bwd_in_meta_.size() - 1),
@@ -172,6 +182,12 @@ void GradNodeBase::SetGradInMeta(
       meta.SetStopGradient(fwd_out_meta->StopGradient());
     }
 
+    if (!fwd_out_tensor.is_initialized()) {
+      VLOG(6)
+          << "Skip Configuring GradSlotMeta for uninitialized GradInput Tensor";
+      return;
+    }
+
     // Record TensorMeta
     if (phi::DenseTensor::classof(fwd_out_tensor.impl().get())) {
       // Only Copy Meta
@@ -184,6 +200,8 @@ void GradNodeBase::SetGradInMeta(
               "with phi::DataType::UNDEFINED,"
               "which is illegal."));
       meta.SetTensorMeta(dense_tensor->meta());
+      meta.SetPlace(fwd_out_tensor.inner_place());
+
       if (paddle::framework::IsComplexType(
               paddle::framework::TransToProtoVarType(dense_tensor->type()))) {
         need_complex_to_real_ = true;
@@ -228,6 +246,7 @@ void GradNodeBase::SetGradOutMeta(const paddle::experimental::Tensor& fwd_in,
               "with phi::DataType::UNDEFINED,"
               "which is illegal."));
       meta.SetTensorMeta(dense_tensor->meta());
+      meta.SetPlace(fwd_in.inner_place());
     }
   } else {
     VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta with "
@@ -272,6 +291,7 @@ void GradNodeBase::SetGradOutMeta(
               "phi::DataType::UNDEFINED,"
               "which is illegal."));
       meta.SetTensorMeta(dense_tensor->meta());
+      meta.SetPlace(fwd_in_tensor.inner_place());
    }
  } else {
    VLOG(6) << "Unable to initialize the DenseTensorMeta of GradSlotMeta "
...
@@ -76,8 +76,12 @@ class GradSlotMeta {
     return *meta_.get();
   }
 
+  void SetPlace(const phi::Place& place) { place_ = place; }
+  const phi::Place& GetPlace() const { return place_; }
+
  private:
   bool stop_gradient_{false};
+  phi::Place place_;
   std::shared_ptr<phi::DenseTensorMeta> meta_ = nullptr;
 };
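A populated GradSlotMeta now carries shape, dtype, and device, which is exactly what is needed to materialize a zero gradient. A minimal sketch, assuming `slot_meta` is a GradSlotMeta that SetGradInMeta has filled in; it mirrors the logic of EagerUtils::FillZeroForEmptyGradInputs further below:

// Build a zero tensor with the recorded dims/dtype on the recorded place.
const phi::DenseTensorMeta& tensor_meta = slot_meta.GetTensorMeta();
auto zero_grad = paddle::experimental::full(
    phi::vectorize(tensor_meta.dims), 0.0, tensor_meta.dtype, slot_meta.GetPlace());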
@@ -102,7 +106,7 @@ class GradNodeBase {
    * is better choice to fit this format.
    * **/
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads,  // NOLINT
       bool create_graph = false) = 0;
 
   virtual void ClearTensorWrappers() = 0;
...
@@ -53,7 +53,7 @@ class GradTensorHolder {
     return buffer_[pos];
   }
 
-  const std::vector<std::vector<paddle::experimental::Tensor>>& Buffers() {
+  std::vector<std::vector<paddle::experimental::Tensor>>& Buffers() {
     return buffer_;
   }
...
@@ -80,13 +80,15 @@ TEST(AccumulationNode, Tensor) {
   grad_meta->SetStopGradient(false);
 
   // operator()
-  paddle::experimental::Tensor ret_et0 = node->operator()({{et0}})[0][0];
+  std::vector<std::vector<paddle::experimental::Tensor>> et0_vec = {{et0}};
+  paddle::experimental::Tensor ret_et0 = node->operator()(et0_vec)[0][0];
   auto* ret_et0_ptr =
       std::dynamic_pointer_cast<phi::DenseTensor>(ret_et0.impl())
           ->data<paddle::platform::float16>();
   CHECK_EQ(ret_et0_ptr[0], paddle::platform::float16(10.0f));
 
-  paddle::experimental::Tensor ret_et1 = node->operator()({{et1}})[0][0];
+  std::vector<std::vector<paddle::experimental::Tensor>> et1_vec = {{et1}};
+  paddle::experimental::Tensor ret_et1 = node->operator()(et1_vec)[0][0];
   auto* ret_et1_ptr =
       std::dynamic_pointer_cast<phi::DenseTensor>(ret_et1.impl())
@@ -121,7 +123,7 @@ TEST(AccumulationNode, Tensor) {
       std::make_shared<egr::CppTensorVoidHook>(reduce_hook_1));
 
   // operator()
-  paddle::experimental::Tensor _ret = node->operator()({{et0}})[0][0];
+  paddle::experimental::Tensor _ret = node->operator()(et0_vec)[0][0];
 
   // Check operator() result, should be 36.0
   auto* _ret_ptr = std::dynamic_pointer_cast<phi::DenseTensor>(_ret.impl())
...
@@ -32,7 +32,7 @@ class GradTestNode : public egr::GradNodeBase {
   GradTestNode() : GradNodeBase() { val_ = 1.0; }
   std::string name() override { return "GradTestNode"; }
   std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>>& grads,
+      std::vector<std::vector<paddle::experimental::Tensor>>& grads,
       bool create_graph = false) override {
     val_ = std::dynamic_pointer_cast<phi::DenseTensor>(grads[0][0].impl())
                ->data<float>()[0];
...
@@ -247,4 +247,20 @@ TEST(EagerUtils, GetGradAccumulationNode) {
   ASSERT_ANY_THROW(egr::EagerUtils::GetGradAccumulationNode(t0));
 }
 
+TEST(EagerUtils, FillZeroForEmptyGradInputs) {
+  std::vector<std::vector<paddle::experimental::Tensor>> grads = {
+      std::vector<paddle::experimental::Tensor>(1)};
+  std::vector<std::vector<GradSlotMeta>> slot_metas = {
+      std::vector<GradSlotMeta>(1)};
+
+  phi::DenseTensorMeta tensor_meta;
+  tensor_meta.dtype = paddle::experimental::DataType::FLOAT32;
+  tensor_meta.dims = {2, 4};
+  slot_metas[0][0].SetTensorMeta(tensor_meta);
+  slot_metas[0][0].SetPlace(phi::CPUPlace());
+
+  EagerUtils::FillZeroForEmptyGradInputs(&grads, slot_metas);
+  eager_test::CompareTensorWithValue<float>(grads[0][0], 0.0);
+}
+
 }  // namespace egr
@@ -370,7 +370,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
   ~GradNodeRunProgram() override = default;
   // Functor: perform backward computations
   virtual std::vector<std::vector<paddle::experimental::Tensor>> operator()(
-      const std::vector<std::vector<paddle::experimental::Tensor>> &grads,
+      std::vector<std::vector<paddle::experimental::Tensor>> &grads,  // NOLINT
       bool create_graph) override {
     VLOG(3) << "Running Eager Backward Node: GradNodeRunProgram";
     PADDLE_ENFORCE_EQ(
...
@@ -20,6 +20,7 @@
 #include "paddle/phi/api/all.h"
 #include "paddle/phi/common/layout.h"
+#include "paddle/phi/core/compat/convert_utils.h"
 #include "paddle/phi/core/tensor_meta.h"
 
 #include "paddle/fluid/framework/data_layout.h"
@@ -392,4 +393,28 @@ std::shared_ptr<egr::GradNodeBase> EagerUtils::GetGradAccumulationNode(
   }
 }
 
+void EagerUtils::FillZeroForEmptyGradInputs(
+    std::vector<std::vector<paddle::experimental::Tensor>>* in_grads,
+    const std::vector<std::vector<GradSlotMeta>>& grad_in_metas) {
+  for (size_t i = 0; i < in_grads->size(); i++) {
+    for (size_t j = 0; j < (*in_grads)[i].size(); j++) {
+      paddle::experimental::Tensor& grad = (*in_grads)[i][j];
+      if (!grad.is_initialized()) {
+        const GradSlotMeta& grad_in_meta = grad_in_metas[i][j];
+        PADDLE_ENFORCE(
+            grad_in_meta.HasTensorMeta(),
+            paddle::platform::errors::Fatal(
+                "Unable to fill empty grad inputs due to empty GradSlotMeta"));
+
+        const auto& tensor_meta = grad_in_meta.GetTensorMeta();
+        phi::Place place = grad_in_meta.GetPlace();
+
+        auto tensor_with_zero = paddle::experimental::full(
+            phi::vectorize(tensor_meta.dims), 0.0, tensor_meta.dtype, place);
+        grad.set_impl(tensor_with_zero.impl());
+      }
+    }
+  }
+}
+
 }  // namespace egr
@@ -217,6 +217,13 @@ class EagerUtils {
       const std::vector<paddle::experimental::Tensor>& tensors);
   static std::shared_ptr<egr::GradNodeBase> GetGradAccumulationNode(
       const paddle::experimental::Tensor& tensor);
+
+  /**
+   * Fill Zero
+   * **/
+  static void FillZeroForEmptyGradInputs(
+      std::vector<std::vector<paddle::experimental::Tensor>>* out_grads,
+      const std::vector<std::vector<GradSlotMeta>>& grad_out_metas);
 };
 
 }  // namespace egr
@@ -182,7 +182,7 @@ class TestImperativeAutoPrune(unittest.TestCase):
             self.func_auto_prune2()
 
     # TODO(jiabin): Support this when we support better split tensor
-    def test_auto_prune3(self):
+    def func_auto_prune3(self):
         with fluid.dygraph.guard():
             case3 = AutoPruneLayer3(input_size=784)
             value1 = np.arange(784).reshape(1, 784).astype("float32")
@@ -194,7 +194,12 @@ class TestImperativeAutoPrune(unittest.TestCase):
             self.assertTrue(case3.linear.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 0).all())
 
-    def test_auto_prune4(self):
+    def test_auto_prune3(self):
+        with _test_eager_guard():
+            self.func_auto_prune3()
+        self.func_auto_prune3()
+
+    def func_auto_prune4(self):
         with fluid.dygraph.guard():
             case4 = AutoPruneLayer3(input_size=784)
             value1 = np.arange(784).reshape(1, 784).astype("float32")
@@ -206,7 +211,12 @@ class TestImperativeAutoPrune(unittest.TestCase):
             self.assertTrue(case4.linear.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 1).all())
 
-    def test_auto_prune5(self):
+    def test_auto_prune4(self):
+        with _test_eager_guard():
+            self.func_auto_prune4()
+        self.func_auto_prune4()
+
+    def func_auto_prune5(self):
         with fluid.dygraph.guard():
             case4 = AutoPruneLayer3(input_size=784)
             value1 = np.arange(784).reshape(1, 784).astype("float32")
@@ -218,6 +228,11 @@ class TestImperativeAutoPrune(unittest.TestCase):
             self.assertTrue(case4.linear.weight._grad_ivar() is not None)
             self.assertTrue((part2.gradient() == 0).all())
 
+    def test_auto_prune5(self):
+        with _test_eager_guard():
+            self.func_auto_prune5()
+        self.func_auto_prune5()
+
     def func_auto_prune6(self):
         with fluid.dygraph.guard():
             value0 = np.arange(26).reshape(2, 13).astype("float32")
...