未验证 提交 3228fc34 编写于 作者: 0 0x45f 提交者: GitHub

Fix loop index for FillZeroForEmptyGradInputs (#40909)

* Fix the inner loop bound in FillZeroForEmptyGradInputs: iterate up to (*in_grads)[i].size() instead of (*in_grads)[0].size()

* Call FillZeroForEmptyGradInputs in the run_program grad node (GradNodeRunProgram)
上级 c7b69fd2
...@@ -48,7 +48,7 @@ static std::unordered_map<std::string, paddle::framework::AttributeMap> ...@@ -48,7 +48,7 @@ static std::unordered_map<std::string, paddle::framework::AttributeMap>
operators_with_attrs = {}; operators_with_attrs = {};
static std::unordered_set<std::string> ops_to_fill_zero_for_empty_grads = { static std::unordered_set<std::string> ops_to_fill_zero_for_empty_grads = {
"split"}; "split", "rnn"};
/* --- Black Ops list that's NO NEED to apply code generation --- */ /* --- Black Ops list that's NO NEED to apply code generation --- */
static std::unordered_set<std::string> black_ops_list = {"run_program"}; static std::unordered_set<std::string> black_ops_list = {"run_program"};
......
...@@ -21,7 +21,7 @@ import os ...@@ -21,7 +21,7 @@ import os
######################## ########################
### Global Variables ### ### Global Variables ###
######################## ########################
ops_to_fill_zero_for_empty_grads = set(list("split")) ops_to_fill_zero_for_empty_grads = set(["split", "rnn"])
# For API dispatch used at python-level # For API dispatch used at python-level
# { op_name : [arg_name, ...] } # { op_name : [arg_name, ...] }
......
...@@ -367,6 +367,7 @@ class GradNodeRunProgram : public egr::GradNodeBase { ...@@ -367,6 +367,7 @@ class GradNodeRunProgram : public egr::GradNodeBase {
paddle::platform::errors::InvalidArgument( paddle::platform::errors::InvalidArgument(
"The out_grads.size() of RunProgramGradOp should be equal to 1.")); "The out_grads.size() of RunProgramGradOp should be equal to 1."));
egr::EagerUtils::FillZeroForEmptyGradInputs(&grads, this->InputMeta());
VLOG(3) << "out_grads[0].size() : " << grads[0].size(); VLOG(3) << "out_grads[0].size() : " << grads[0].size();
std::vector<paddle::experimental::Tensor> x_grad; std::vector<paddle::experimental::Tensor> x_grad;
std::vector<paddle::experimental::Tensor> params_grad; std::vector<paddle::experimental::Tensor> params_grad;
......
...@@ -398,7 +398,7 @@ void EagerUtils::FillZeroForEmptyGradInputs( ...@@ -398,7 +398,7 @@ void EagerUtils::FillZeroForEmptyGradInputs(
std::vector<std::vector<paddle::experimental::Tensor>>* in_grads, std::vector<std::vector<paddle::experimental::Tensor>>* in_grads,
const std::vector<std::vector<GradSlotMeta>>& grad_in_metas) { const std::vector<std::vector<GradSlotMeta>>& grad_in_metas) {
for (size_t i = 0; i < in_grads->size(); i++) { for (size_t i = 0; i < in_grads->size(); i++) {
for (size_t j = 0; j < (*in_grads)[0].size(); j++) { for (size_t j = 0; j < (*in_grads)[i].size(); j++) {
paddle::experimental::Tensor& grad = (*in_grads)[i][j]; paddle::experimental::Tensor& grad = (*in_grads)[i][j];
if (!grad.is_initialized()) { if (!grad.is_initialized()) {
const GradSlotMeta& grad_in_meta = grad_in_metas[i][j]; const GradSlotMeta& grad_in_meta = grad_in_metas[i][j];
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册