未验证 提交 2e1bca99 编写于 作者: G guofei 提交者: GitHub

Refine the gradient calculation errors caused by renaming in while_grad (#27814)

test=develop
上级 8fa4c098
...@@ -12,6 +12,8 @@ ...@@ -12,6 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include <set>
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h" #include "paddle/fluid/framework/operator.h"
...@@ -70,6 +72,23 @@ class WhileOp : public framework::OperatorBase { ...@@ -70,6 +72,23 @@ class WhileOp : public framework::OperatorBase {
auto *block = Attr<framework::BlockDesc *>(kStepBlock); auto *block = Attr<framework::BlockDesc *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
bool is_test = Attr<bool>("is_test");
std::set<std::string> no_copy_var_names;
if (!is_test) {
const std::vector<framework::OpDesc *> &all_ops = block->AllOps();
for (const framework::OpDesc *op : all_ops) {
const framework::VariableNameMap &input_var_names = op->Inputs();
const framework::VariableNameMap &output_var_names = op->Outputs();
for (auto &ipt : input_var_names) {
for (const std::string &var_name : ipt.second) {
if (StrInVaraiableNameMap(var_name, output_var_names)) {
no_copy_var_names.insert(var_name);
}
}
}
}
}
auto step_scopes = auto step_scopes =
scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>(); scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
...@@ -89,7 +108,6 @@ class WhileOp : public framework::OperatorBase { ...@@ -89,7 +108,6 @@ class WhileOp : public framework::OperatorBase {
"The Output(StepScope) of WhileOp should be empty.")); "The Output(StepScope) of WhileOp should be empty."));
bool cond_data = GetCondData(cond); bool cond_data = GetCondData(cond);
bool is_test = Attr<bool>("is_test");
auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars); auto &skip_vars = Attr<std::vector<std::string>>(kSkipEagerDeletionVars);
VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars); VLOG(2) << GetSkipEagerDeletionVarsDebugString(skip_vars);
...@@ -98,8 +116,32 @@ class WhileOp : public framework::OperatorBase { ...@@ -98,8 +116,32 @@ class WhileOp : public framework::OperatorBase {
while (cond_data) { while (cond_data) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
step_scopes->push_back(&current_scope); step_scopes->push_back(&current_scope);
std::vector<std::string> rename_vars;
for (const std::string &input_var_name : Inputs(kX)) {
if (no_copy_var_names.find(input_var_name) ==
no_copy_var_names.end()) {
std::string input_var_rename = input_var_name + kSuffix;
framework::Variable *input_var = scope.FindVar(input_var_name);
if (input_var->IsType<framework::LoDTensor>()) {
rename_vars.push_back(input_var_rename);
auto input_var_tensor = input_var->Get<LoDTensor>();
auto *rename_input_var_tensor =
current_scope.Var(input_var_rename)->GetMutable<LoDTensor>();
framework::TensorCopy(input_var_tensor, dev_place,
rename_input_var_tensor);
rename_input_var_tensor->set_lod(input_var_tensor.lod());
}
}
}
executor.RunPreparedContext(ctx.get(), &current_scope, false, true, executor.RunPreparedContext(ctx.get(), &current_scope, false, true,
true); true);
for (auto &var_rename : rename_vars) {
std::string input_var_name =
var_rename.substr(0, var_rename.size() - strlen(kSuffix));
current_scope.Rename(var_rename, input_var_name);
}
cond_data = cond_data =
GetCondData(scope.FindVar(Input(kCondition))->Get<LoDTensor>()); GetCondData(scope.FindVar(Input(kCondition))->Get<LoDTensor>());
} }
...@@ -312,6 +354,10 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -312,6 +354,10 @@ class WhileGradOp : public framework::OperatorBase {
// continue; // continue;
// } // }
auto var_iter =
std::find(outside_og_names.begin(), outside_og_names.end(),
pg_ig_names[param_id]);
// zero gradient variable in step 0 // zero gradient variable in step 0
if (cur_scope_iter == step_scopes->rbegin()) { if (cur_scope_iter == step_scopes->rbegin()) {
auto *var = (*cur_scope_iter)->FindVar(inside_grad_name); auto *var = (*cur_scope_iter)->FindVar(inside_grad_name);
...@@ -326,7 +372,8 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -326,7 +372,8 @@ class WhileGradOp : public framework::OperatorBase {
"or LoDTensor, but the received var[%s] is %s.", "or LoDTensor, but the received var[%s] is %s.",
inside_grad_name, framework::ToTypeName(var->Type()))); inside_grad_name, framework::ToTypeName(var->Type())));
if (var->IsType<LoDTensor>()) { if ((var_iter == outside_og_names.end()) &&
var->IsType<LoDTensor>()) {
auto &inside_tensor = var->Get<framework::LoDTensor>(); auto &inside_tensor = var->Get<framework::LoDTensor>();
framework::AttributeMap attrs; framework::AttributeMap attrs;
attrs["dtype"] = inside_tensor.type(); attrs["dtype"] = inside_tensor.type();
...@@ -343,6 +390,10 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -343,6 +390,10 @@ class WhileGradOp : public framework::OperatorBase {
->set_lod(inside_tensor.lod()); ->set_lod(inside_tensor.lod());
} }
} }
auto var_outside = scope.FindVar(pg_ig_names[param_id]);
if ((var_iter == outside_og_names.end()) ||
((var_iter != outside_og_names.end()) &&
var_outside->IsType<framework::LoDTensorArray>())) {
auto new_inside_name = cur_scope.Rename(inside_grad_name); auto new_inside_name = cur_scope.Rename(inside_grad_name);
auto sum_op = framework::OpRegistry::CreateOp( auto sum_op = framework::OpRegistry::CreateOp(
"sum", {{"X", {pg_ig_names[param_id], new_inside_name}}}, "sum", {{"X", {pg_ig_names[param_id], new_inside_name}}},
...@@ -351,6 +402,7 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -351,6 +402,7 @@ class WhileGradOp : public framework::OperatorBase {
sum_op->Run(cur_scope, dev_place); sum_op->Run(cur_scope, dev_place);
cur_scope.Rename(new_inside_name, inside_grad_name); cur_scope.Rename(new_inside_name, inside_grad_name);
} }
}
dev_ctx.Wait(); dev_ctx.Wait();
const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope); const_cast<framework::Scope &>(scope).DeleteScope(&cur_scope);
} }
......
...@@ -232,5 +232,16 @@ bool GetCondData(const framework::LoDTensor &cond) { ...@@ -232,5 +232,16 @@ bool GetCondData(const framework::LoDTensor &cond) {
return cpu_cond->data<bool>()[0]; return cpu_cond->data<bool>()[0];
} }
bool StrInVaraiableNameMap(const std::string &name,
const framework::VariableNameMap &var_names) {
for (auto &ipt : var_names) {
if (std::find(ipt.second.begin(), ipt.second.end(), name) !=
ipt.second.end()) {
return true;
}
}
return false;
}
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -38,6 +38,7 @@ static constexpr char kX[] = "X"; ...@@ -38,6 +38,7 @@ static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD"; static constexpr char kXGRAD[] = "X@GRAD";
static constexpr char kOutputs[] = "Out"; static constexpr char kOutputs[] = "Out";
static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars"; static constexpr char kSkipEagerDeletionVars[] = "skip_eager_deletion_vars";
static constexpr char kSuffix[] = "@TMP_COPY";
void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
const framework::ProgramDesc &program, int block_id, const framework::ProgramDesc &program, int block_id,
...@@ -50,5 +51,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp( ...@@ -50,5 +51,8 @@ void PrepareSafeEagerDeletionOnWhileOpAndWhileGradOp(
bool GetCondData(const framework::LoDTensor &cond); bool GetCondData(const framework::LoDTensor &cond);
bool StrInVaraiableNameMap(const std::string &,
const framework::VariableNameMap &);
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -16,6 +16,7 @@ from __future__ import print_function ...@@ -16,6 +16,7 @@ from __future__ import print_function
import numpy as np import numpy as np
import unittest import unittest
import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import paddle.fluid.core as core import paddle.fluid.core as core
import paddle.fluid.layers as layers import paddle.fluid.layers as layers
...@@ -24,6 +25,8 @@ from paddle.fluid.executor import Executor ...@@ -24,6 +25,8 @@ from paddle.fluid.executor import Executor
from paddle.fluid.framework import Program, program_guard from paddle.fluid.framework import Program, program_guard
from paddle.fluid.backward import append_backward from paddle.fluid.backward import append_backward
paddle.enable_static()
class TestApiWhileLoop(unittest.TestCase): class TestApiWhileLoop(unittest.TestCase):
def test_var_tuple(self): def test_var_tuple(self):
...@@ -199,16 +202,10 @@ class TestApiWhileLoop_Backward(unittest.TestCase): ...@@ -199,16 +202,10 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
def cond(i, x): def cond(i, x):
return layers.less_than(i, eleven) return layers.less_than(i, eleven)
def body(j, x): def body(i, x):
# TODO: In while block, if the var created in parent block
# participates in the calculation of gradient, the result of gradient
# is incorrect because each step scope always returns the same value
# generated by last step.
# Here we call `assign` op in while block to avoid this bug, and working on fixing it in next PR.
i = layers.assign(j)
x = layers.elementwise_mul(x=i, y=i) x = layers.elementwise_mul(x=i, y=i)
j = layers.increment(j) i = layers.increment(i)
return [j, x] return [i, x]
main_program = Program() main_program = Program()
startup_program = Program() startup_program = Program()
...@@ -244,10 +241,10 @@ class TestApiWhileLoop_Backward(unittest.TestCase): ...@@ -244,10 +241,10 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
def test_while_loop_backward2(self): def test_while_loop_backward2(self):
def cond(i, x): def cond(i, x):
return i < 5 return i < 3
def body(i, x): def body(i, x):
x = x + i x = x * i
i = i + 1 i = i + 1
return [i, x] return [i, x]
...@@ -269,17 +266,21 @@ class TestApiWhileLoop_Backward(unittest.TestCase): ...@@ -269,17 +266,21 @@ class TestApiWhileLoop_Backward(unittest.TestCase):
feed_i = np.ones(1).astype('float32') feed_i = np.ones(1).astype('float32')
feed_x = np.ones(1).astype('float32') feed_x = np.ones(1).astype('float32')
data = np.asarray([11]).astype('float32') data = np.asarray([2]).astype('float32')
i_grad = np.asarray([1]).astype('float32') i_grad = np.asarray([3]).astype('float32')
x_grad = np.asarray([2]).astype('float32')
res = exe.run(main_program, res = exe.run(main_program,
feed={'i': feed_i, feed={'i': feed_i,
'x': feed_x}, 'x': feed_x},
fetch_list=[mean.name, i.grad_name]) fetch_list=[mean.name, i.grad_name, x.grad_name])
self.assertTrue(np.allclose(np.asarray(res[0]), data)) self.assertTrue(np.allclose(np.asarray(res[0]), data))
self.assertTrue( self.assertTrue(
np.allclose(np.asarray(res[1]), i_grad), np.allclose(np.asarray(res[1]), i_grad),
msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad)) msg=" \nres = \n{} \n\n ans = \n{}".format(res[1], i_grad))
self.assertTrue(
np.allclose(np.asarray(res[2]), x_grad),
msg=" \nres = \n{} \n\n ans = \n{}".format(res[2], x_grad))
class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase): class TestApiWhileLoop_NestedWithBackwardAndLoDTensorArray(unittest.TestCase):
......
...@@ -24,6 +24,8 @@ from paddle.fluid.backward import append_backward ...@@ -24,6 +24,8 @@ from paddle.fluid.backward import append_backward
import numpy import numpy
from paddle.fluid import compiler, Program, program_guard from paddle.fluid import compiler, Program, program_guard
paddle.enable_static()
class TestWhileOp(unittest.TestCase): class TestWhileOp(unittest.TestCase):
def simple_net(self): def simple_net(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册