未验证 提交 2e1bca99 编写于 作者: G guofei 提交者: GitHub

Refine the gradient calculation errors caused by renaming in while_grad (#27814)

test=develop
上级 8fa4c098
......@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <set>
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
......@@ -70,6 +72,23 @@ class WhileOp : public framework::OperatorBase {
auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program();
bool is_test = Attr<bool>("is_test");
std::set<std::string> no_copy_var_names;
if (!is_test) {
const std::vector<framework::OpDesc *> &all_ops = block->AllOps();
for (const framework::OpDesc *op : all_ops) {
const framework::VariableNameMap &input_var_names = op->Inputs();
const framework::VariableNameMap &output_var_names = op->Outputs();
for (auto &ipt : input_var_names) {
for (const std::string &var_name : ipt.second) {
if (StrInVaraiableNameMap(var_name, output_var_names)) {
no_copy_var_names.insert(var_name);
}
}
}
}
}
auto step_scopes =
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
......@@ -89,7 +108,6 @@ class WhileOp : public framework::OperatorBase {
"The Output(StepScope) of WhileOp should be empty."));
bool cond_data = GetCondData(cond);
bool is_test = Attr<bool>("is_test");
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
......@@ -98,8 +116,32 @@ class WhileOp : public framework::OperatorBase {
while (cond_data) {
auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope);
std::vector<std::string> rename_vars;
for (const std::string &input_var_name : Inputs(kX)) {
if (no_copy_var_names.find(input_var_name) ==
no_copy_var_names.end()) {
std::string input_var_rename = input_var_name + kSuffix;
framework::Variable *input_var = scope.FindVar(input_var_name);
if (input_var->IsType<framework::LoDTensor>()) {
rename_vars.push_back(input_var_rename);
auto input_var_tensor = input_var->Get<LoDTensor>();
auto *rename_input_var_tensor =
current_scope.Var(input_var_rename)->GetMutable<LoDTensor>();
framework::TensorCopy(input_var_tensor, dev_place,
rename_input_var_tensor);
rename_input_var_tensor->set_lod(input_var_tensor.lod());
}
}
}
executor.RunPreparedContext(ctx.get(), &current_scope, false, true,
true);
for (auto &var_rename : rename_vars) {
std::string input_var_name =
var_rename.substr(0, var_rename.size() - strlen(kSuffix));
current_scope.Rename(var_rename, input_var_name);
}
cond_data =
GetCondData(scope.FindVar(Input(kCondition))->Get<LoDTensor>());
}
......@@ -312,6 +354,10 @@ class WhileGradOp : public framework::OperatorBase {
// continue;
// }
auto var_iter =
std::find(outside_og_names.begin(), outside_og_names.end(),
pg_ig_names[param_id]);
// zero gradient variable in step 0
if (cur_scope_iter == step_scopes->rbegin()) {
auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
......@@ -326,7 +372,8 @@ class WhileGradOp : public framework::OperatorBase {
"or LoDTensor, but the received var[%s] is %s.",
inside_grad_name, framework::ToTypeName(var->Type())));
if (var->IsType<LoDTensor>()) {
if ((var_iter == outside_og_names.end()) &&
var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs;
attrs["dtype"] = inside_tensor.type();
......@@ -343,13 +390,18 @@ class WhileGradOp : public framework::OperatorBase {
->set_lod(inside_tensor.lod());
}
}
auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
{{"Out", {pg_ig_names[param_id]}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
auto var_outside = scope.FindVar(pg_ig_names[param_id]);
if ((var_iter == outside_og_names.end()) ||
((var_iter != outside_og_names.end()) &&
var_outside->IsType<framework::LoDTensorArray>())) {
auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
{{"Out", {pg_ig_names[param_id]}}},
framework::AttributeMap{{"use_mkldnn", {false}}});
sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name);
}
}
dev_ctx.Wait();
const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
......
......@@ -232,5 +232,16 @@ bool GetCondData(const framework::LoDTensor &cond) {
return cpu_cond->data<bool>()[0];
}
bool StrInVaraiableNameMap(const std::string &name,
const framework::VariableNameMap &var_names) {
for (auto &ipt : var_names) {
if (std::find(ipt.second.begin(), ipt.second.end(), name) !=
ipt.second.end()) {
return true;
}
}
return false;
}
} // namespace operators
} // namespace paddle
......@@ -38,6 +38,7 @@ static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
static constexpr char kSuffix[] = "@TMP_COPY";
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const framework::ProgramDesc &program, int block_id,
......@@ -50,5 +51,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
bool GetCondData(const framework::LoDTensor &cond);
bool StrInVaraiableNameMap(const std::string &,
const framework::VariableNameMap &);
} // namespace operators
} // namespace paddle
......@@ -16,6 +16,7 @@ from __future__ import print_function
import numpy as np
import unittest
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import paddle.fluid.layers as layers
......@@ -24,6 +25,8 @@ from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, program_guard
from paddle.fluid.backward import append_backward
paddle.enable_static()
class TestApiWhileLoop(unittest.TestCase):
def test_var_tuple(self):
......@@ -199,16 +202,10 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
def cond(i, x):
return layers.less_than(i, eleven)
def body(j, x):
# TODO: In while block, if the var created in parent block
# participates in the calculation of gradient, the result of gradient
# is incorrect because each step scope always returns the same value
# generated by last step.
# Here we call `assign` op in while block to avoid this bug, and working on fixing it in next PR.
i = layers.assign(j)
def body(i, x):
x = layers.elementwise_mul(x=i, y=i)
j = layers.increment(j)
return [j, x]
i = layers.increment(i)
return [i, x]
main_program = Program()
startup_program = Program()
......@@ -244,10 +241,10 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
def test_while_loop_backward2(self):
def cond(i, x):
return i < 5
return i < 3
def body(i, x):
x = x + i
x = x * i
i = i + 1
return [i, x]
......@@ -269,17 +266,21 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
feed_i = np.ones(1).astype('float32')
feed_x = np.ones(1).astype('float32')
data = np.asarray([11]).astype('float32')
i_grad = np.asarray([1]).astype('float32')
data = np.asarray([2]).astype('float32')
i_grad = np.asarray([3]).astype('float32')
x_grad = np.asarray([2]).astype('float32')
res = exe.run(main_program,
feed={'i': feed_i,
'x': feed_x},
fetch_list=[mean.name, i.grad_name])
fetch_list=[mean.name, i.grad_name, x.grad_name])
self.assertTrue(np.allclose(np.asarray(res[0]), data))
self.assertTrue(
np.allclose(np.asarray(res[1]), i_grad),
msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad))
self.assertTrue(
np.allclose(np.asarray(res[2]), x_grad),
msg=" \nres = \n{} \n\n ans = \n{}".format(res[2], x_grad))
class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
......
......@@ -24,6 +24,8 @@ from paddle.fluid.backward import append_backward
import numpy
from paddle.fluid import compiler, Program, program_guard
paddle.enable_static()
class TestWhileOp(unittest.TestCase):
def simple_net(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册