Commit 9d2c77e6 authored by Yang Yang

parallel_do skeleton pass compile

Parent 61ec0b95
@@ -185,6 +185,7 @@ set(DEPS_OPS
     cond_op
     cross_entropy_op
     recurrent_op
+    parallel_do_op
     softmax_with_cross_entropy_op
     softmax_op
     sequence_softmax_op
@@ -256,6 +257,7 @@ op_library(lstm_op DEPS sequence2batch lstm_compute)
 op_library(conv_transpose_op DEPS vol2col)
 op_library(gru_op DEPS sequence2batch gru_compute)
 op_library(recurrent_op SRCS recurrent_op.cc DEPS executor)
+op_library(parallel_do_op SRCS parallel_do_op.cc DEPS executor)
 # FIXME(typhoonzero): save/load depends lodtensor serialization functions
 op_library(save_op DEPS lod_tensor)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <vector>
#include "paddle/framework/executor.h"
#include "paddle/framework/op_registry.h"
#include "paddle/framework/operator.h"
namespace paddle {
namespace operators {
constexpr char kInputs[] = "inputs";
constexpr char kParameters[] = "parameters";
constexpr char kPlaces[] = "places";
constexpr char kParallelBlock[] = "parallel_block";
constexpr char kOutputs[] = "outputs";
constexpr char kParallelScopes[] = "parallel_scopes";
// #define GRAD_SUFFIX "@GRAD"
// constexpr char kInputGrads[] = "inputs" GRAD_SUFFIX;
// constexpr char kOutputGrads[] = "outputs" GRAD_SUFFIX;
// constexpr char kParamGrads[] = "parameters" GRAD_SUFFIX;
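// Interface summary (editorial comment, not in the original commit): the
// forward op consumes kInputs, kParameters and kPlaces, produces kOutputs
// plus the kParallelScopes bookkeeping output, and carries the sub-block to
// run as the kParallelBlock attribute. The commented-out GRAD_SUFFIX
// constants above are reserved for the backward pass.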
using ParallelScopeVar = std::vector<framework::Scope *>;
using OperatorBase = framework::OperatorBase;
class ParallelDoOp : public OperatorBase {
public:
ParallelDoOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
    // Skeleton only: the eventual plan is to create one sub-scope per
    // device in kPlaces, copy the kParameters variables into each, and
    // run kParallelBlock there. Left empty so the skeleton compiles.
}
};
class ParallelDoGradOp : public OperatorBase {
public:
ParallelDoGradOp(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
    // Skeleton: the backward pass is intentionally not implemented yet.
  }
};
class ParallelDoOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
ParallelDoOpProtoMaker(framework::OpProto *proto,
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddInput(kInputs, "").AsDuplicable();
AddInput(kParameters, "").AsDuplicable();
AddInput(kPlaces, "");
AddOutput(kOutputs, "").AsDuplicable();
AddOutput(kParallelScopes, "");
AddAttr<framework::BlockDescBind *>(kParallelBlock, "");
AddComment(R"DOC(
ParallelDo Operator.
)DOC");
}
};
class ParallelDoGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
virtual std::unique_ptr<framework::OpDescBind> Apply() const {
PADDLE_THROW("Not Implemented");
auto *grad = new framework::OpDescBind();
grad->SetType("parallel_do_grad");
for (auto &input_param : this->InputNames()) {
grad->SetInput(input_param, this->Input(input_param));
grad->SetOutput(framework::GradVarName(input_param),
this->InputGrad(input_param));
}
for (auto &output_param : this->OutputNames()) {
if (output_param == kParallelScopes) {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->Output(output_param));
} else {
grad->SetInput(output_param, this->Output(output_param));
grad->SetInput(framework::GradVarName(output_param),
this->OutputGrad(output_param));
}
}
grad->SetAttrMap(this->Attrs());
grad->SetBlockAttr(kParallelBlock, *grad_block_[0]);
return std::unique_ptr<framework::OpDescBind>(grad);
}
};
class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *ctx) const override {
PADDLE_THROW("Not Implemented");
// std::vector<std::string> input{kInputs};
// std::vector<std::string> output{kOutputs};
// for (auto &s : input) {
// PADDLE_ENFORCE(ctx->HasInputs(s));
// PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(s)),
// "Cannot find the gradient variable %s",
// framework::GradVarName(s));
// }
// for (auto &s : output) {
// PADDLE_ENFORCE(ctx->HasInputs(s));
// }
// for (auto &s : input) {
// ctx->SetOutputsDim(framework::GradVarName(s), ctx->GetInputsDim(s));
// }
// if (ctx->HasInputs(kParameters)) {
// PADDLE_ENFORCE(ctx->HasOutputs(framework::GradVarName(kParameters)));
// ctx->SetOutputsDim(framework::GradVarName(kParameters),
// ctx->GetInputsDim(kParameters));
// }
}
};
} // namespace operators
} // namespace paddle
REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp,
paddle::operators::ParallelDoOpProtoMaker,
paddle::operators::ParallelDoGradOpDescMaker);
REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp,
paddle::operators::ParallelDoGradOpShapeInference);
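
For context, a minimal sketch of what the empty ParallelDoOp::Run might grow into, following the Executor pattern that recurrent_op and while_op use in the diffs below. The sub-scope bookkeeping, the placeholder parameter loop, and the shape of the Executor::Run call are assumptions borrowed from while_op, not code from this commit:

// Hypothetical sketch, not part of this commit.
void Run(const framework::Scope &scope,
         const platform::DeviceContext &dev_ctx) const override {
  auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
  auto *program = block->Program();

  // Publish the sub-scopes through the kParallelScopes output so the
  // gradient op can find them later (assumed handshake).
  auto *scopes =
      scope.FindVar(Output(kParallelScopes))->GetMutable<ParallelScopeVar>();
  auto &sub_scope = scope.NewScope();
  scopes->push_back(&sub_scope);

  // "copy parameters": create the parameter variables in the sub-scope;
  // the actual tensor copy/broadcast per place is still to be designed.
  for (auto &param : Inputs(kParameters)) {
    sub_scope.Var(param);
  }

  framework::Executor executor(dev_ctx);
  executor.Run(*program, &sub_scope, block->ID(),
               false /*create_local_scope*/);
}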
@@ -22,10 +22,10 @@ constexpr char kInputs[] = "inputs";
 constexpr char kInitialStates[] = "initial_states";
 constexpr char kParameters[] = "parameters";
 constexpr char kOutputs[] = "outputs";
-constexpr char kStepScopes[] = "step_scopes";
+constexpr char kParallelScopes[] = "step_scopes";
 constexpr char kExStates[] = "ex_states";
 constexpr char kStates[] = "states";
-constexpr char kStepBlock[] = "step_block";
+constexpr char kParallelBlock[] = "step_block";
 constexpr char kReverse[] = "reverse";
 constexpr char kIsTrain[] = "is_train";
 #define GRAD_SUFFIX "@GRAD"
@@ -234,7 +234,7 @@ class RecurrentOp : public RecurrentBase {
     auto reverse = Attr<bool>(kReverse);
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
+    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
     auto *program = block->Program();
     for (size_t i = 0; i < seq_len; ++i) {
@@ -295,7 +295,7 @@ class RecurrentOp : public RecurrentBase {
  private:
   StepScopes CreateStepScopes(const framework::Scope &scope,
                               size_t seq_len) const {
-    auto *var = scope.FindVar(Output(kStepScopes));
+    auto *var = scope.FindVar(Output(kParallelScopes));
     PADDLE_ENFORCE(var != nullptr);
     return StepScopes(scope, var->GetMutable<StepScopeVar>(),
                       Attr<bool>(kIsTrain), seq_len);
@@ -317,7 +317,7 @@ class RecurrentGradOp : public RecurrentBase {
     auto reverse = Attr<bool>(kReverse);
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
+    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
     auto *program = block->Program();
     for (size_t step_id = 0; step_id < seq_len; ++step_id) {
@@ -465,7 +465,7 @@ class RecurrentGradOp : public RecurrentBase {
  private:
   StepScopes CreateStepScopes(const framework::Scope &scope,
                               size_t seq_len) const {
-    auto *var = scope.FindVar(Input(kStepScopes));
+    auto *var = scope.FindVar(Input(kParallelScopes));
     PADDLE_ENFORCE(var != nullptr);
     return StepScopes(scope, var->GetMutable<StepScopeVar>(),
                       Attr<bool>(kIsTrain), seq_len, true /*is_backward*/);
@@ -510,7 +510,7 @@ class RecurrentOpProtoMaker : public framework::OpProtoAndCheckerMaker {
     AddOutput(kOutputs,
               "The output sequence of RNN. The sequence length must be same.")
         .AsDuplicable();
-    AddOutput(kStepScopes,
+    AddOutput(kParallelScopes,
               "StepScopes contain all local variables in each time step.");
     AddAttr<std::vector<std::string>>(kExStates,
                                       string::Sprintf(
@@ -523,7 +523,7 @@ The ex-state means the state value in the ex-timestep or the previous time step
         string::Sprintf(
             "The state variable names. [%s, %s, %s] must be the same order",
             kExStates, kStates, kInitStateGrads));
-    AddAttr<framework::BlockDescBind *>(kStepBlock,
+    AddAttr<framework::BlockDescBind *>(kParallelBlock,
                                         "The step block inside RNN");
     AddAttr<bool>(kReverse, R"DOC(Calculate RNN reversely or not.
 By default reverse=False
@@ -576,7 +576,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
     }
     for (auto &output_param : this->OutputNames()) {
-      if (output_param == kStepScopes) {
+      if (output_param == kParallelScopes) {
         grad->SetInput(output_param, this->Output(output_param));
         grad->SetInput(framework::GradVarName(output_param),
                        this->Output(output_param));
@@ -587,7 +587,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
       }
     }
     grad->SetAttrMap(this->Attrs());
-    grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
+    grad->SetBlockAttr(kParallelBlock, *grad_block_[0]);
     return std::unique_ptr<framework::OpDescBind>(grad);
   }
......
@@ -25,9 +25,9 @@ namespace operators {
 using StepScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
-constexpr char kStepBlock[] = "step_block";
+constexpr char kParallelBlock[] = "step_block";
 constexpr char kCondition[] = "Condition";
-constexpr char kStepScopes[] = "StepScopes";
+constexpr char kParallelScopes[] = "StepScopes";
 constexpr char kParameters[] = "X";
 constexpr char kParamGrads[] = "X@GRAD";
 constexpr char kOutputs[] = "Out";
@@ -46,11 +46,11 @@ class WhileOp : public framework::OperatorBase {
     PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
+    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
     auto *program = block->Program();
     auto step_scopes =
-        scope.FindVar(Output(kStepScopes))->GetMutable<StepScopeVar>();
+        scope.FindVar(Output(kParallelScopes))->GetMutable<StepScopeVar>();
     while (cond.data<bool>()[0]) {
       auto &current_scope = scope.NewScope();
@@ -78,11 +78,11 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
               "A set of variables, which will be assigned with values "
               "generated by the operators inside the block of While Op.")
         .AsDuplicable();
-    AddOutput(kStepScopes,
+    AddOutput(kParallelScopes,
              "(StepScopeVar) A vector of local scope, which size equals the "
              "step number of While Op. The i'th scope storages temporary "
              "variables generated in the i'th step.");
-    AddAttr<framework::BlockDescBind *>(kStepBlock,
+    AddAttr<framework::BlockDescBind *>(kParallelBlock,
                                         "The step block inside WhileOp");
     AddComment(R"DOC(
 )DOC");
@@ -99,11 +99,11 @@ class WhileGradOp : public framework::OperatorBase {
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
+    auto *block = Attr<framework::BlockDescBind *>(kParallelBlock);
     auto *program = block->Program();
     auto *step_scopes =
-        scope.FindVar(Input(kStepScopes))->GetMutable<StepScopeVar>();
+        scope.FindVar(Input(kParallelScopes))->GetMutable<StepScopeVar>();
     auto outside_og_names = Inputs(framework::GradVarName(kOutputs));
     auto inside_og_names =
@@ -272,9 +272,9 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     std::copy(extra_inputs.begin(), extra_inputs.end(),
               extra_inputs_list.begin());
     grad->SetInput(framework::GradVarName(kOutputs), extra_inputs_list);
-    grad->SetInput(kStepScopes, Output(kStepScopes));
+    grad->SetInput(kParallelScopes, Output(kParallelScopes));
     grad->SetAttrMap(this->Attrs());
-    grad->SetBlockAttr(kStepBlock, *grad_block_[0]);
+    grad->SetBlockAttr(kParallelBlock, *grad_block_[0]);
     // record the original output gradient names, since the gradient name of
     // while operator could be renamed.
     grad->SetAttr("original_output_grad", extra_inputs_list);
......
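
Taken together, the diffs above converge on one handshake, sketched below under the assumption of the APIs visible in recurrent_op.cc and while_op.cc; this is the pattern, not code from this commit. The forward op publishes a vector of Scope pointers through its scopes output, and the grad op reads the same vector back as an input:

// Illustrative sketch of the shared scope-vector handshake.
using StepScopeVar = std::vector<framework::Scope *>;

// Forward op: create a scope per step/place and publish the vector.
auto *fwd_scopes =
    scope.FindVar(Output(kParallelScopes))->GetMutable<StepScopeVar>();
fwd_scopes->push_back(&scope.NewScope());

// Grad op: the same variable arrives as an input, so the backward pass
// can replay the block inside the scopes the forward pass left behind.
auto *bwd_scopes =
    scope.FindVar(Input(kParallelScopes))->GetMutable<StepScopeVar>();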