未验证 提交 9deb1756 编写于 作者: Q Qiao Longfei 提交者: GitHub

fix while_grad_op first step loss lod problem (#7490)

* fix while_grad_op first step loss lod problem

* optimize code
上级 59bc4c46
......@@ -138,6 +138,7 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
math::set_constant(dev_ctx, &rest_tensor, 0.0f);
}
}
dx_tensor.set_lod(x_tensor.lod());
}
};
......
......@@ -121,8 +121,8 @@ class WhileGradOp : public framework::OperatorBase {
for (size_t i = 0; i < outside_og_names.size(); ++i) {
auto outside_og_name = outside_og_names[i];
auto inside_og_name = inside_og_names[i];
VLOG(10) << "Linking outside " << outside_og_name << " --> inside "
<< inside_og_name;
VLOG(8) << "Linking outside " << outside_og_name << " --> inside "
<< inside_og_name;
auto &og_outside =
detail::Ref(scope.FindVar(outside_og_name),
"Cannot find Outside Gradient %s", outside_og_name);
......@@ -141,11 +141,11 @@ class WhileGradOp : public framework::OperatorBase {
auto &outside_array = og_outside.Get<framework::LoDTensorArray>();
auto &inside_array =
detail::Ref(og_inside.GetMutable<framework::LoDTensorArray>());
VLOG(10) << outside_og_name << " size = " << outside_array.size();
VLOG(8) << outside_og_name << " size = " << outside_array.size();
inside_array.resize(outside_array.size());
for (size_t j = 0; j < inside_array.size(); ++j) {
VLOG(10) << j << " " << outside_array[j].numel();
VLOG(8) << j << " " << outside_array[j].numel();
if (outside_array[j].numel() != 0) {
inside_array[j].set_lod(outside_array[j].lod());
inside_array[j].ShareDataWith(outside_array[j]);
......@@ -187,10 +187,14 @@ class WhileGradOp : public framework::OperatorBase {
attrs["shape"] = framework::vectorize2int(inside_tensor.dims());
attrs["value"] = 0.0f;
auto var_name = pg_names[param_id];
auto zero_op = framework::OpRegistry::CreateOp(
"fill_constant", framework::VariableNameMap{},
{{"Out", {pg_names[param_id]}}}, attrs);
{{"Out", {var_name}}}, attrs);
zero_op->Run(scope, dev_place);
scope.FindVar(var_name)
->GetMutable<framework::LoDTensor>()
->set_lod(inside_tensor.lod());
}
}
......@@ -231,7 +235,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
for (auto &each_ig : igs) {
if (inner_op_outputs.find(each_ig) == inner_op_outputs.end()) {
VLOG(10) << "Ignore " << each_ig;
VLOG(8) << "Ignore " << each_ig;
each_ig = framework::kEmptyVarName;
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册