提交 aea5ccca 编写于 作者: Y Yang Yang

revise typo

上级 973aec2c
...@@ -22,7 +22,7 @@ constexpr char kInputs[] = "inputs"; ...@@ -22,7 +22,7 @@ constexpr char kInputs[] = "inputs";
constexpr char kInitialStates[] = "initial_states"; constexpr char kInitialStates[] = "initial_states";
constexpr char kParameters[] = "parameters"; constexpr char kParameters[] = "parameters";
constexpr char kOutputs[] = "outputs"; constexpr char kOutputs[] = "outputs";
constexpr char kParallelScopes[] = "step_scopes"; constexpr char kStepScopes[] = "step_scopes";
constexpr char kExStates[] = "ex_states"; constexpr char kExStates[] = "ex_states";
constexpr char kStates[] = "states"; constexpr char kStates[] = "states";
constexpr char kStepBlock[] = "sub_block"; constexpr char kStepBlock[] = "sub_block";
...@@ -234,7 +234,7 @@ class RecurrentOp : public RecurrentBase { ...@@ -234,7 +234,7 @@ class RecurrentOp : public RecurrentBase {
auto reverse = Attr<bool>(kReverse); auto reverse = Attr<bool>(kReverse);
framework::Executor executor(dev_ctx); framework::Executor executor(dev_ctx);
auto *block = Attr<framework::BlockDescBind *>(kParallelBlock); auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
for (size_t i = 0; i < seq_len; ++i) { for (size_t i = 0; i < seq_len; ++i) {
...@@ -295,7 +295,7 @@ class RecurrentOp : public RecurrentBase { ...@@ -295,7 +295,7 @@ class RecurrentOp : public RecurrentBase {
private: private:
StepScopes CreateStepScopes(const framework::Scope &scope, StepScopes CreateStepScopes(const framework::Scope &scope,
size_t seq_len) const { size_t seq_len) const {
auto *var = scope.FindVar(Output(kParallelScopes)); auto *var = scope.FindVar(Output(kStepScopes));
PADDLE_ENFORCE(var != nullptr); PADDLE_ENFORCE(var != nullptr);
return StepScopes(scope, var->GetMutable<StepScopeVar>(), return StepScopes(scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len); Attr<bool>(kIsTrain), seq_len);
...@@ -317,7 +317,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -317,7 +317,7 @@ class RecurrentGradOp : public RecurrentBase {
auto reverse = Attr<bool>(kReverse); auto reverse = Attr<bool>(kReverse);
framework::Executor executor(dev_ctx); framework::Executor executor(dev_ctx);
auto *block = Attr<framework::BlockDescBind *>(kParallelBlock); auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
for (size_t step_id = 0; step_id < seq_len; ++step_id) { for (size_t step_id = 0; step_id < seq_len; ++step_id) {
...@@ -465,7 +465,7 @@ class RecurrentGradOp : public RecurrentBase { ...@@ -465,7 +465,7 @@ class RecurrentGradOp : public RecurrentBase {
private: private:
StepScopes CreateStepScopes(const framework::Scope &scope, StepScopes CreateStepScopes(const framework::Scope &scope,
size_t seq_len) const { size_t seq_len) const {
auto *var = scope.FindVar(Input(kParallelScopes)); auto *var = scope.FindVar(Input(kStepScopes));
PADDLE_ENFORCE(var != nullptr); PADDLE_ENFORCE(var != nullptr);
return StepScopes(scope, var->GetMutable<StepScopeVar>(), return StepScopes(scope, var->GetMutable<StepScopeVar>(),
Attr<bool>(kIsTrain), seq_len, true /*is_backward*/); Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
...@@ -510,7 +510,7 @@ class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -510,7 +510,7 @@ class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
AddOutput(kOutputs, AddOutput(kOutputs,
"The output sequence of RNN. The sequence length must be same.") "The output sequence of RNN. The sequence length must be same.")
.AsDuplicable(); .AsDuplicable();
AddOutput(kParallelScopes, AddOutput(kStepScopes,
"StepScopes contain all local variables in each time step."); "StepScopes contain all local variables in each time step.");
AddAttr<std::vector<std::string>>(kExStates, AddAttr<std::vector<std::string>>(kExStates,
string::Sprintf( string::Sprintf(
...@@ -523,7 +523,7 @@ The ex-state means the state value in the ex-timestep or the previous time step ...@@ -523,7 +523,7 @@ The ex-state means the state value in the ex-timestep or the previous time step
string::Sprintf( string::Sprintf(
"The state variable names. [%s, %s, %s] must be the same order", "The state variable names. [%s, %s, %s] must be the same order",
kExStates, kStates, kInitStateGrads)); kExStates, kStates, kInitStateGrads));
AddAttr<framework::BlockDescBind *>(kParallelBlock, AddAttr<framework::BlockDescBind *>(kStepBlock,
"The step block inside RNN"); "The step block inside RNN");
AddAttr<bool>(kReverse, R"DOC(Calculate RNN reversely or not. AddAttr<bool>(kReverse, R"DOC(Calculate RNN reversely or not.
By default reverse=False By default reverse=False
...@@ -576,7 +576,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -576,7 +576,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
} }
for (auto &output_param : this->OutputNames()) { for (auto &output_param : this->OutputNames()) {
if (output_param == kParallelScopes) { if (output_param == kStepScopes) {
grad->SetInput(output_param, this->Output(output_param)); grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param), grad->SetInput(framework::GradVarName(output_param),
this->Output(output_param)); this->Output(output_param));
...@@ -587,7 +587,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -587,7 +587,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
} }
} }
grad->SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kParallelBlock, *grad_block_[0]); grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
return std::unique_ptr<framework::OpDescBind>(grad); return std::unique_ptr<framework::OpDescBind>(grad);
} }
......
...@@ -27,7 +27,7 @@ using LoDTensor = framework::LoDTensor; ...@@ -27,7 +27,7 @@ using LoDTensor = framework::LoDTensor;
constexpr char kStepBlock[] = "sub_block"; constexpr char kStepBlock[] = "sub_block";
constexpr char kCondition[] = "Condition"; constexpr char kCondition[] = "Condition";
constexpr char kParallelScopes[] = "StepScopes"; constexpr char kStepScopes[] = "StepScopes";
constexpr char kParameters[] = "X"; constexpr char kParameters[] = "X";
constexpr char kParamGrads[] = "X@GRAD"; constexpr char kParamGrads[] = "X@GRAD";
constexpr char kOutputs[] = "Out"; constexpr char kOutputs[] = "Out";
...@@ -46,11 +46,11 @@ class WhileOp : public framework::OperatorBase { ...@@ -46,11 +46,11 @@ class WhileOp : public framework::OperatorBase {
PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1})); PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
framework::Executor executor(dev_ctx); framework::Executor executor(dev_ctx);
auto *block = Attr<framework::BlockDescBind *>(kParallelBlock); auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto step_scopes = auto step_scopes =
scope.FindVar(Output(kParallelScopes))->GetMutable<StepScopeVar>(); scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
while (cond.data<bool>()[0]) { while (cond.data<bool>()[0]) {
auto &current_scope = scope.NewScope(); auto &current_scope = scope.NewScope();
...@@ -78,11 +78,11 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -78,11 +78,11 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
"A set of variables, which will be assigned with values " "A set of variables, which will be assigned with values "
"generated by the operators inside the block of While Op.") "generated by the operators inside the block of While Op.")
.AsDuplicable(); .AsDuplicable();
AddOutput(kParallelScopes, AddOutput(kStepScopes,
"(StepScopeVar) A vector of local scope, which size equals the " "(StepScopeVar) A vector of local scope, which size equals the "
"step number of While Op. The i'th scope storages temporary " "step number of While Op. The i'th scope storages temporary "
"variables generated in the i'th step."); "variables generated in the i'th step.");
AddAttr<framework::BlockDescBind *>(kParallelBlock, AddAttr<framework::BlockDescBind *>(kStepBlock,
"The step block inside WhileOp"); "The step block inside WhileOp");
AddComment(R"DOC( AddComment(R"DOC(
)DOC"); )DOC");
...@@ -99,11 +99,11 @@ class WhileGradOp : public framework::OperatorBase { ...@@ -99,11 +99,11 @@ class WhileGradOp : public framework::OperatorBase {
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::DeviceContext &dev_ctx) const override {
framework::Executor executor(dev_ctx); framework::Executor executor(dev_ctx);
auto *block = Attr<framework::BlockDescBind *>(kParallelBlock); auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
auto *program = block->Program(); auto *program = block->Program();
auto *step_scopes = auto *step_scopes =
scope.FindVar(Input(kParallelScopes))->GetMutable<StepScopeVar>(); scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
auto outside_og_names = Inputs(framework::GradVarName(kOutputs)); auto outside_og_names = Inputs(framework::GradVarName(kOutputs));
auto inside_og_names = auto inside_og_names =
...@@ -272,9 +272,9 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker { ...@@ -272,9 +272,9 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
std::copy(extra_inputs.begin(), extra_inputs.end(), std::copy(extra_inputs.begin(), extra_inputs.end(),
extra_inputs_list.begin()); extra_inputs_list.begin());
grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list); grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
grad->SetInput(kParallelScopes, Output(kParallelScopes)); grad->SetInput(kStepScopes, Output(kStepScopes));
grad->SetAttrMap(this->Attrs()); grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kParallelBlock, *grad_block_[0]); grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
// record the original output gradient names, since the gradient name of // record the original output gradient names, since the gradient name of
// while operator could be renamed. // while operator could be renamed.
grad->SetAttr("original_output_grad", extra_inputs_list); grad->SetAttr("original_output_grad", extra_inputs_list);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册