Commit aea5ccca authored by Yang Yang

revise typo

Parent 973aec2c
@@ -22,7 +22,7 @@ constexpr char kInputs[] = "inputs";
 constexpr char kInitialStates[] = "initial_states";
 constexpr char kParameters[] = "parameters";
 constexpr char kOutputs[] = "outputs";
-constexpr char kParallelScopes[] = "step_scopes";
+constexpr char kStepScopes[] = "step_scopes";
 constexpr char kExStates[] = "ex_states";
 constexpr char kStates[] = "states";
 constexpr char kStepBlock[] = "sub_block";
@@ -234,7 +234,7 @@ class RecurrentOp : public RecurrentBase {
     auto reverse = Attr<bool>(kReverse);
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
+    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
     auto *program = block->Program();
     for (size_t i = 0; i < seq_len; ++i) {
@@ -295,7 +295,7 @@ class RecurrentOp : public RecurrentBase {
  private:
   StepScopes CreateStepScopes(const framework::Scope &scope,
                               size_t seq_len) const {
-    auto *var = scope.FindVar(Output(kParallelScopes));
+    auto *var = scope.FindVar(Output(kStepScopes));
     PADDLE_ENFORCE(var != nullptr);
     return StepScopes(scope, var->GetMutable<StepScopeVar>(),
                       Attr<bool>(kIsTrain), seq_len);
@@ -317,7 +317,7 @@ class RecurrentGradOp : public RecurrentBase {
     auto reverse = Attr<bool>(kReverse);
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
+    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
     auto *program = block->Program();
     for (size_t step_id = 0; step_id < seq_len; ++step_id) {
@@ -465,7 +465,7 @@ class RecurrentGradOp : public RecurrentBase {
  private:
   StepScopes CreateStepScopes(const framework::Scope &scope,
                               size_t seq_len) const {
-    auto *var = scope.FindVar(Input(kParallelScopes));
+    auto *var = scope.FindVar(Input(kStepScopes));
     PADDLE_ENFORCE(var != nullptr);
     return StepScopes(scope, var->GetMutable<StepScopeVar>(),
                       Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
@@ -510,7 +510,7 @@ class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput(kOutputs,
               "The output sequence of RNN. The sequence length must be same.")
         .AsDuplicable();
-    AddOutput(kParallelScopes,
+    AddOutput(kStepScopes,
               "StepScopes contain all local variables in each time step.");
     AddAttr<std::vector<std::string>>(kExStates,
                                       string::Sprintf(
@@ -523,7 +523,7 @@ The ex-state means the state value in the ex-timestep or the previous time step
         string::Sprintf(
             "The state variable names. [%s, %s, %s] must be the same order",
             kExStates, kStates, kInitStateGrads));
-    AddAttr<framework::BlockDescBind *>(kParallelBlock,
+    AddAttr<framework::BlockDescBind *>(kStepBlock,
                                         "The step block inside RNN");
     AddAttr<bool>(kReverse, R"DOC(Calculate RNN reversely or not.
 By default reverse=False
@@ -576,7 +576,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
     }
     for (auto &output_param : this->OutputNames()) {
-      if (output_param == kParallelScopes) {
+      if (output_param == kStepScopes) {
         grad->SetInput(output_param, this->Output(output_param));
         grad->SetInput(framework::GradVarName(output_param),
                        this->Output(output_param));
@@ -587,7 +587,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
       }
     }
     grad->SetAttrMap(this->Attrs());
-    grad->SetBlockAttr(kParallelBlock, *grad_block_[0]);
+    grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
     return std::unique_ptr<framework::OpDescBind>(grad);
   }
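Both renamed constants are string keys shared between an operator's interface declaration (the `AddOutput`/`AddAttr` calls in the proto maker) and its execution path (`Output(kStepScopes)`, `Attr<...>(kStepBlock)` inside `Run`). A minimal standalone sketch of that shared-key pattern, with a hypothetical simplified `Scope` standing in for the framework types:

```cpp
// A minimal sketch (not Paddle's real API) of the pattern the rename
// preserves: one constexpr key is used both where the variable is registered
// and where Run() looks it up, so the two sites cannot drift apart.
#include <cassert>
#include <map>
#include <string>

constexpr char kStepScopes[] = "step_scopes";  // single source of truth

// Hypothetical stand-in for framework::Scope: a flat name -> value map.
struct Scope {
  std::map<std::string, int> vars;
  int *FindVar(const std::string &name) {
    auto it = vars.find(name);
    return it == vars.end() ? nullptr : &it->second;
  }
};

int main() {
  Scope scope;
  scope.vars[kStepScopes] = 42;           // "declaration" side writes the key
  int *var = scope.FindVar(kStepScopes);  // "Run" side reads the same constant
  assert(var != nullptr);                 // mirrors PADDLE_ENFORCE(var != nullptr)
  return *var == 42 ? 0 : 1;
}
```

Note that the commit changes only the C++ identifier names; the runtime string values (`"step_scopes"`, `"StepScopes"`, `"sub_block"`) are unchanged, so anything that refers to these inputs, outputs, or attributes by string is unaffected.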
@@ -27,7 +27,7 @@ using LoDTensor = framework::LoDTensor;
 constexpr char kStepBlock[] = "sub_block";
 constexpr char kCondition[] = "Condition";
-constexpr char kParallelScopes[] = "StepScopes";
+constexpr char kStepScopes[] = "StepScopes";
 constexpr char kParameters[] = "X";
 constexpr char kParamGrads[] = "X@GRAD";
 constexpr char kOutputs[] = "Out";
@@ -46,11 +46,11 @@ class WhileOp : public framework::OperatorBase {
     PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
+    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
     auto *program = block->Program();
     auto step_scopes =
-        scope.FindVar(Output(kParallelScopes))->GetMutable<StepScopeVar>();
+        scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
     while (cond.data<bool>()[0]) {
       auto &current_scope = scope.NewScope();
@@ -78,11 +78,11 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
               "A set of variables, which will be assigned with values "
               "generated by the operators inside the block of While Op.")
         .AsDuplicable();
-    AddOutput(kParallelScopes,
+    AddOutput(kStepScopes,
               "(StepScopeVar) A vector of local scope, which size equals the "
              "step number of While Op. The i'th scope storages temporary "
              "variables generated in the i'th step.");
-    AddAttr<framework::BlockDescBind *>(kParallelBlock,
+    AddAttr<framework::BlockDescBind *>(kStepBlock,
                                         "The step block inside WhileOp");
     AddComment(R"DOC(
 )DOC");
@@ -99,11 +99,11 @@ class WhileGradOp : public framework::OperatorBase {
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
+    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
     auto *program = block->Program();
     auto *step_scopes =
-        scope.FindVar(Input(kParallelScopes))->GetMutable<StepScopeVar>();
+        scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
     auto outside_og_names = Inputs(framework::GradVarName(kOutputs));
     auto inside_og_names =
@@ -272,9 +272,9 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     std::copy(extra_inputs.begin(), extra_inputs.end(),
               extra_inputs_list.begin());
     grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
-    grad->SetInput(kParallelScopes, Output(kParallelScopes));
+    grad->SetInput(kStepScopes, Output(kStepScopes));
     grad->SetAttrMap(this->Attrs());
-    grad->SetBlockAttr(kParallelBlock, *grad_block_[0]);
+    grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
     // record the original output gradient names, since the gradient name of
     // while operator could be renamed.
     grad->SetAttr("original_output_grad", extra_inputs_list);
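As context for the `StepScopes` output that both operators now name consistently: `WhileOp::Run` above creates a fresh child scope per iteration (`scope.NewScope()`) and records it in the step-scope variable, so `WhileGradOp` can later walk the same scopes when computing gradients. A minimal sketch of that bookkeeping, with a hypothetical simplified `Scope` in place of `framework::Scope` and a plain vector in place of `StepScopeVar`:

```cpp
// A minimal sketch of per-step scope bookkeeping: each loop iteration gets a
// fresh child scope, and the scopes are retained so a backward pass can
// revisit the temporaries produced at each step.
#include <iostream>
#include <memory>
#include <vector>

struct Scope {
  std::vector<std::unique_ptr<Scope>> kids;
  Scope &NewScope() {
    kids.push_back(std::make_unique<Scope>());
    return *kids.back();
  }
};

int main() {
  Scope scope;
  std::vector<Scope *> step_scopes;        // plays the role of StepScopeVar
  bool cond = true;
  int step = 0;
  while (cond) {                           // mirrors `while (cond.data<bool>()[0])`
    Scope &current_scope = scope.NewScope();
    step_scopes.push_back(&current_scope); // saved for the backward pass
    cond = (++step < 3);                   // toy condition; the real op re-reads
  }                                        // a condition tensor every iteration
  std::cout << "recorded " << step_scopes.size() << " step scopes\n";  // 3
  return 0;
}
```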