Commit 0c5202cb authored by Yang Yu

Tiny enhancement of while_op

Parent 90a33ddd
@@ -25,12 +25,12 @@ namespace operators {
 using StepScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
 
-constexpr char kStepBlock[] = "sub_block";
-constexpr char kCondition[] = "Condition";
-constexpr char kStepScopes[] = "StepScopes";
-constexpr char kParameters[] = "X";
-constexpr char kParamGrads[] = "X@GRAD";
-constexpr char kOutputs[] = "Out";
+static constexpr char kStepBlock[] = "sub_block";
+static constexpr char kCondition[] = "Condition";
+static constexpr char kStepScopes[] = "StepScopes";
+static constexpr char kX[] = "X";
+static constexpr char kXGRAD[] = "X@GRAD";
+static constexpr char kOutputs[] = "Out";
 
 class WhileOp : public framework::OperatorBase {
  public:
@@ -67,7 +67,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput(kParameters,
+    AddInput(kX,
             "A set of variables, which are required by operators inside the "
             "block of While Op.")
        .AsDuplicable();
@@ -158,8 +158,8 @@ class WhileGradOp : public framework::OperatorBase {
       executor.Run(*program, *cur_scope_iter, block->ID(), false);
 
-      auto &pg_names = Outputs(kParamGrads);
-      auto &p_names = Inputs(kParameters);
+      auto &pg_names = Outputs(kXGRAD);
+      auto &p_names = Inputs(kX);
       PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
       for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
         if (pg_names[param_id] == framework::kEmptyVarName) {
@@ -213,11 +213,11 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
   std::unique_ptr<framework::OpDesc> Apply() const override {
     auto *grad = new framework::OpDesc();
     grad->SetType("while_grad");
-    grad->SetInput(kParameters, Input(kParameters));
+    grad->SetInput(kX, Input(kX));
 
     // Not all of IGs will be generated by inner gradient operators of while op.
     // Ignore IGs that is not generated by the inside block.
-    auto igs = InputGrad(kParameters, /*do not drop empty gradient*/ false);
+    auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
     std::unordered_set<std::string> all_outs;
     for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
       for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) {
@@ -231,7 +231,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
       }
     }
-    grad->SetOutput(framework::GradVarName(kParameters), igs);
+    grad->SetOutput(framework::GradVarName(kX), igs);
 
     grad->SetInput(kOutputs, Output(kOutputs));
@@ -240,7 +240,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     std::unordered_set<std::string> block_ins;
     auto *fwd_block = this->grad_block_[0]->ParentBlock();
     {
-      for (auto &p : Input(kParameters)) {
+      for (auto &p : Input(kX)) {
         block_ins.insert(p);
       }
       for (auto &o : Output(kOutputs)) {
@@ -288,8 +288,8 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
-    auto p_names = op_desc.Input(kParameters);
-    auto pg_names = op_desc.Output(framework::GradVarName(kParameters));
+    auto p_names = op_desc.Input(kX);
+    auto pg_names = op_desc.Output(framework::GradVarName(kX));
 
     for (size_t i = 0; i < p_names.size(); ++i) {
       auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
@@ -307,21 +307,21 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
 class WhileGradOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *ctx) const override {
-    ctx->HasInputs(kParameters);
-    ctx->HasOutputs(framework::GradVarName(kParameters));
+    ctx->HasInputs(kX);
+    ctx->HasOutputs(framework::GradVarName(kX));
     ctx->HasInputs(kOutputs);
     ctx->HasInputs(framework::GradVarName(kOutputs));
 
-    auto p_names = ctx->Inputs(kParameters);
-    auto pg_names = ctx->Outputs(kParamGrads);
-    auto var_types = ctx->GetInputsVarType(kParameters);
+    auto p_names = ctx->Inputs(kX);
+    auto pg_names = ctx->Outputs(kXGRAD);
+    auto var_types = ctx->GetInputsVarType(kX);
     std::vector<std::string> names_to_set;
     std::vector<framework::DDim> dims_to_set;
     for (size_t i = 0; i < p_names.size(); ++i) {
       if (pg_names[i] == framework::kEmptyVarName) {
         continue;
       }
-      auto dims = ctx->GetInputsElementDim(kParameters, i);
+      auto dims = ctx->GetInputsElementDim(kX, i);
       if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
         names_to_set.push_back(pg_names[i]);
         dims_to_set.push_back(dims);
...
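Two details of the patch may be worth spelling out, sketched in the standalone example below (my own illustration, not code from the repository; GradVarName here is a simplified stand-in for Paddle's framework::GradVarName). First, the renamed constants encode Paddle's convention that a gradient variable name is the forward name plus the "@GRAD" suffix, which is why kXGRAD holds "X@GRAD" and the grad maker can write framework::GradVarName(kX). Second, since a namespace-scope constexpr variable already implies const and hence internal linkage, the added static presumably makes that linkage explicit rather than changing behavior.

#include <iostream>
#include <string>

// Simplified stand-in for framework::GradVarName; the real helper appends
// Paddle's gradient suffix, which the kXGRAD constant spells out as "@GRAD".
std::string GradVarName(const std::string &name) { return name + "@GRAD"; }

// As in the patch: at namespace scope, constexpr already implies const and
// therefore internal linkage, so `static` documents the linkage explicitly.
static constexpr char kX[] = "X";
static constexpr char kXGRAD[] = "X@GRAD";

int main() {
  // GradVarName(kX) reproduces the literal held by kXGRAD.
  std::cout << std::boolalpha << (GradVarName(kX) == kXGRAD) << "\n";  // true
  return 0;
}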