diff --git a/paddle/operators/shrink_rnn_memory_op.cc b/paddle/operators/shrink_rnn_memory_op.cc index ade94b40bed91c64d3074036c067de34323bdaa7..bf870115a4d7b6f4d578df7707826973d4363ba6 100644 --- a/paddle/operators/shrink_rnn_memory_op.cc +++ b/paddle/operators/shrink_rnn_memory_op.cc @@ -138,6 +138,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp { math::set_constant(dev_ctx, &rest_tensor, 0.0f); } } + dx_tensor.set_lod(x_tensor.lod()); } }; diff --git a/paddle/operators/while_op.cc b/paddle/operators/while_op.cc index 7a3400919efe6f3bed40e45a245b556beab6fce4..2fdd25dbbe68659f8a0a9da13a87148ed259127a 100644 --- a/paddle/operators/while_op.cc +++ b/paddle/operators/while_op.cc @@ -121,8 +121,8 @@ class WhileGradOp : public framework::OperatorBase { for (size_t i = 0; i < outside_og_names.size(); ++i) { auto outside_og_name = outside_og_names[i]; auto inside_og_name = inside_og_names[i]; - VLOG(10) << "Linking outside " << outside_og_name << " --> inside " - << inside_og_name; + VLOG(8) << "Linking outside " << outside_og_name << " --> inside " + << inside_og_name; auto &og_outside = detail::Ref(scope.FindVar(outside_og_name), "Cannot find Outside Gradient %s", outside_og_name); @@ -141,11 +141,11 @@ class WhileGradOp : public framework::OperatorBase { auto &outside_array = og_outside.Get(); auto &inside_array = detail::Ref(og_inside.GetMutable()); - VLOG(10) << outside_og_name << " size = " << outside_array.size(); + VLOG(8) << outside_og_name << " size = " << outside_array.size(); inside_array.resize(outside_array.size()); for (size_t j = 0; j < inside_array.size(); ++j) { - VLOG(10) << j << " " << outside_array[j].numel(); + VLOG(8) << j << " " << outside_array[j].numel(); if (outside_array[j].numel() != 0) { inside_array[j].set_lod(outside_array[j].lod()); inside_array[j].ShareDataWith(outside_array[j]); @@ -187,10 +187,14 @@ class WhileGradOp : public framework::OperatorBase { attrs["shape"] = framework::vectorize2int(inside_tensor.dims()); attrs["value"] = 0.0f; + auto var_name = pg_names[param_id]; auto zero_op = framework::OpRegistry::CreateOp( "fill_constant", framework::VariableNameMap{}, - {{"Out", {pg_names[param_id]}}}, attrs); + {{"Out", {var_name}}}, attrs); zero_op->Run(scope, dev_place); + scope.FindVar(var_name) + ->GetMutable() + ->set_lod(inside_tensor.lod()); } } @@ -231,7 +235,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { auto igs = InputGrad(kX, /*do not drop empty gradient*/ false); for (auto &each_ig : igs) { if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) { - VLOG(10) << "Ignore " << each_ig; + VLOG(8) << "Ignore " << each_ig; each_ig = framework::kEmptyVarName; } }