提交 dafd449c 编写于 作者: F fengjiayi

Unify `step_block` and `block` to `sub_block`

上级 7ab48aec
...@@ -430,14 +430,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward( ...@@ -430,14 +430,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
std::vector<std::unique_ptr<OpDescBind>> op_grads; std::vector<std::unique_ptr<OpDescBind>> op_grads;
if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") { if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
int step_block_idx = (*it)->GetBlockAttr("step_block"); int step_block_idx = (*it)->GetBlockAttr("sub_block");
BlockDescBind* backward_block = CreateStepBlock( BlockDescBind* backward_block = CreateStepBlock(
program_desc, no_grad_vars, grad_to_var, step_block_idx); program_desc, no_grad_vars, grad_to_var, step_block_idx);
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else if ((*it)->Type() == "conditional_block") { } else if ((*it)->Type() == "conditional_block") {
BlockDescBind* backward_block = BlockDescBind* backward_block =
CreateStepBlock(program_desc, no_grad_vars, grad_to_var, CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
(*it)->GetBlockAttr("block")); (*it)->GetBlockAttr("sub_block"));
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block}); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
} else { } else {
op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var); op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var);
......
...@@ -65,7 +65,7 @@ class ConditionalBlockOp : public ConditionalOp { ...@@ -65,7 +65,7 @@ class ConditionalBlockOp : public ConditionalOp {
scopes->front() = &scope.NewScope(); scopes->front() = &scope.NewScope();
auto &cur_scope = *scopes->front(); auto &cur_scope = *scopes->front();
auto *block = Attr<framework::BlockDescBind *>("block"); auto *block = Attr<framework::BlockDescBind *>("sub_block");
framework::Executor exec(dev_ctx); framework::Executor exec(dev_ctx);
exec.Run(*block->Program(), &cur_scope, block->ID(), false); exec.Run(*block->Program(), &cur_scope, block->ID(), false);
} }
...@@ -88,7 +88,7 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -88,7 +88,7 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"unify the conditional block, rnn and while op, the type of " "unify the conditional block, rnn and while op, the type of "
"scope is std::vector<Scope*>"); "scope is std::vector<Scope*>");
AddAttr<framework::BlockDescBind *>( AddAttr<framework::BlockDescBind *>(
"block", "The step block of conditional block operator"); "sub_block", "The step block of conditional block operator");
AddComment(R"DOC(Conditional block operator AddComment(R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the Run the sub-block if X is not empty. Params is the other inputs and Out is the
...@@ -117,7 +117,7 @@ class ConditionalBlockGradOp : public ConditionalOp { ...@@ -117,7 +117,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
auto &scopes = scope_var->Get<std::vector<framework::Scope *>>(); auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
framework::Scope &cur_scope = *scopes[0]; framework::Scope &cur_scope = *scopes[0];
auto *block = Attr<framework::BlockDescBind *>("block"); auto *block = Attr<framework::BlockDescBind *>("sub_block");
framework::Executor exec(dev_ctx); framework::Executor exec(dev_ctx);
exec.Run(*block->Program(), &cur_scope, block->ID(), false); exec.Run(*block->Program(), &cur_scope, block->ID(), false);
...@@ -181,7 +181,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker { ...@@ -181,7 +181,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
grad_op->SetInput("Scope", Output("Scope")); grad_op->SetInput("Scope", Output("Scope"));
grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X")); grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params")); grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params"));
grad_op->SetBlockAttr("block", *this->grad_block_[0]); grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
return std::unique_ptr<framework::OpDescBind>(grad_op); return std::unique_ptr<framework::OpDescBind>(grad_op);
} }
}; };
......
...@@ -25,7 +25,7 @@ constexpr char kOutputs[] = "outputs"; ...@@ -25,7 +25,7 @@ constexpr char kOutputs[] = "outputs";
constexpr char kStepScopes[] = "step_scopes"; constexpr char kStepScopes[] = "step_scopes";
constexpr char kExStates[] = "ex_states"; constexpr char kExStates[] = "ex_states";
constexpr char kStates[] = "states"; constexpr char kStates[] = "states";
constexpr char kStepBlock[] = "step_block"; constexpr char kStepBlock[] = "sub_block";
constexpr char kReverse[] = "reverse"; constexpr char kReverse[] = "reverse";
constexpr char kIsTrain[] = "is_train"; constexpr char kIsTrain[] = "is_train";
#define GRAD_SUFFIX "@GRAD" #define GRAD_SUFFIX "@GRAD"
......
...@@ -25,7 +25,7 @@ namespace operators { ...@@ -25,7 +25,7 @@ namespace operators {
using StepScopeVar = std::vector<framework::Scope *>; using StepScopeVar = std::vector<framework::Scope *>;
using LoDTensor = framework::LoDTensor; using LoDTensor = framework::LoDTensor;
constexpr char kStepBlock[] = "step_block"; constexpr char kStepBlock[] = "sub_block";
constexpr char kCondition[] = "Condition"; constexpr char kCondition[] = "Condition";
constexpr char kStepScopes[] = "StepScopes"; constexpr char kStepScopes[] = "StepScopes";
constexpr char kParameters[] = "X"; constexpr char kParameters[] = "X";
......
...@@ -1130,7 +1130,7 @@ class StaticRNN(object): ...@@ -1130,7 +1130,7 @@ class StaticRNN(object):
attrs={ attrs={
'ex_states': pre_memories, 'ex_states': pre_memories,
'states': memories, 'states': memories,
'step_block': rnn_block 'sub_block': rnn_block
}) })
...@@ -1207,7 +1207,7 @@ class While(object): ...@@ -1207,7 +1207,7 @@ class While(object):
}, },
outputs={'Out': out_vars, outputs={'Out': out_vars,
'StepScopes': [step_scope]}, 'StepScopes': [step_scope]},
attrs={'step_block': while_block}) attrs={'sub_block': while_block})
def lstm(x, def lstm(x,
...@@ -1671,7 +1671,7 @@ class ConditionalBlock(object): ...@@ -1671,7 +1671,7 @@ class ConditionalBlock(object):
}, },
outputs={'Out': out_list, outputs={'Out': out_list,
'Scope': [step_scope]}, 'Scope': [step_scope]},
attrs={'block': inside_block}) attrs={'sub_block': inside_block})
class IfElseBlockGuard(object): class IfElseBlockGuard(object):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册