Unverified commit dd8ffe1e, authored by Yu Yang, committed by GitHub

Merge pull request #7131 from reyoung/feature/tiny_enhance_of_while_op

Tiny enhance of while_op
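
This commit renames the while_op input/output key constants kParameters and kParamGrads to kX and kXGRAD, matching the slot strings "X" and "X@GRAD" they hold, and marks the file-scope string constants static constexpr. A minimal sketch of the linkage detail below, under the assumption that these constants sit at namespace scope in a single .cc file (illustrative code, not from the Paddle tree):

    #include <cstdio>

    // At namespace scope a constexpr array is already const and therefore has
    // internal linkage in C++11; `static` just makes that explicit, so each
    // translation unit keeps its own private copy of the key names.
    static constexpr char kX[] = "X";
    static constexpr char kOutputs[] = "Out";

    int main() {
      std::printf("input slot: %s, output slot: %s\n", kX, kOutputs);
      return 0;
    }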
@@ -25,12 +25,12 @@ namespace operators {
 using StepScopeVar = std::vector<framework::Scope *>;
 using LoDTensor = framework::LoDTensor;
 
-constexpr char kStepBlock[] = "sub_block";
-constexpr char kCondition[] = "Condition";
-constexpr char kStepScopes[] = "StepScopes";
-constexpr char kParameters[] = "X";
-constexpr char kParamGrads[] = "X@GRAD";
-constexpr char kOutputs[] = "Out";
+static constexpr char kStepBlock[] = "sub_block";
+static constexpr char kCondition[] = "Condition";
+static constexpr char kStepScopes[] = "StepScopes";
+static constexpr char kX[] = "X";
+static constexpr char kXGRAD[] = "X@GRAD";
+static constexpr char kOutputs[] = "Out";
 
 class WhileOp : public framework::OperatorBase {
  public:
@@ -67,7 +67,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   WhileOpMaker(OpProto *proto, OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
-    AddInput(kParameters,
+    AddInput(kX,
              "A set of variables, which are required by operators inside the "
              "block of While Op.")
         .AsDuplicable();
@@ -158,8 +158,8 @@ class WhileGradOp : public framework::OperatorBase {
       executor.Run(*program, *cur_scope_iter, block->ID(), false);
 
-      auto &pg_names = Outputs(kParamGrads);
-      auto &p_names = Inputs(kParameters);
+      auto &pg_names = Outputs(kXGRAD);
+      auto &p_names = Inputs(kX);
       PADDLE_ENFORCE_EQ(pg_names.size(), p_names.size());
 
       for (size_t param_id = 0; param_id < pg_names.size(); ++param_id) {
         if (pg_names[param_id] == framework::kEmptyVarName) {
@@ -213,11 +213,11 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
   std::unique_ptr<framework::OpDesc> Apply() const override {
     auto *grad = new framework::OpDesc();
     grad->SetType("while_grad");
-    grad->SetInput(kParameters, Input(kParameters));
+    grad->SetInput(kX, Input(kX));
 
     // Not all IGs will be generated by the inner gradient operators of the
     // while op. Ignore IGs that are not generated by the inner block.
-    auto igs = InputGrad(kParameters, /*do not drop empty gradient*/ false);
+    auto igs = InputGrad(kX, /*do not drop empty gradient*/ false);
     std::unordered_set<std::string> all_outs;
     for (size_t i = 0; i < grad_block_[0]->OpSize(); ++i) {
       for (auto &oname : grad_block_[0]->Op(i)->OutputArgumentNames()) {
@@ -231,7 +231,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
       }
     }
-    grad->SetOutput(framework::GradVarName(kParameters), igs);
+    grad->SetOutput(framework::GradVarName(kX), igs);
 
     grad->SetInput(kOutputs, Output(kOutputs));
@@ -240,7 +240,7 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     std::unordered_set<std::string> block_ins;
     auto *fwd_block = this->grad_block_[0]->ParentBlock();
     {
-      for (auto &p : Input(kParameters)) {
+      for (auto &p : Input(kX)) {
         block_ins.insert(p);
       }
       for (auto &o : Output(kOutputs)) {
@@ -288,8 +288,8 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
  public:
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
-    auto p_names = op_desc.Input(kParameters);
-    auto pg_names = op_desc.Output(framework::GradVarName(kParameters));
+    auto p_names = op_desc.Input(kX);
+    auto pg_names = op_desc.Output(framework::GradVarName(kX));
 
     for (size_t i = 0; i < p_names.size(); ++i) {
       auto &p_var = detail::Ref(block->FindVarRecursive(p_names[i]));
@@ -307,21 +307,21 @@ class WhileGradOpVarTypeInference : public framework::VarTypeInference {
 class WhileGradOpShapeInference : public framework::InferShapeBase {
  public:
   void operator()(framework::InferShapeContext *ctx) const override {
-    ctx->HasInputs(kParameters);
-    ctx->HasOutputs(framework::GradVarName(kParameters));
+    ctx->HasInputs(kX);
+    ctx->HasOutputs(framework::GradVarName(kX));
     ctx->HasInputs(kOutputs);
     ctx->HasInputs(framework::GradVarName(kOutputs));
 
-    auto p_names = ctx->Inputs(kParameters);
-    auto pg_names = ctx->Outputs(kParamGrads);
-    auto var_types = ctx->GetInputsVarType(kParameters);
+    auto p_names = ctx->Inputs(kX);
+    auto pg_names = ctx->Outputs(kXGRAD);
+    auto var_types = ctx->GetInputsVarType(kX);
     std::vector<std::string> names_to_set;
     std::vector<framework::DDim> dims_to_set;
     for (size_t i = 0; i < p_names.size(); ++i) {
       if (pg_names[i] == framework::kEmptyVarName) {
         continue;
       }
-      auto dims = ctx->GetInputsElementDim(kParameters, i);
+      auto dims = ctx->GetInputsElementDim(kX, i);
       if (var_types[i] == framework::proto::VarDesc::LOD_TENSOR) {
         names_to_set.push_back(pg_names[i]);
         dims_to_set.push_back(dims);
......
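
For context on the WhileGradOpDescMaker hunks above: Apply() requests a gradient name for every forward input via InputGrad(kX, false), gathers every output argument name produced by the operators of the gradient block into all_outs, and blanks out any requested gradient the block never writes. A standalone sketch of that filtering step, with made-up variable names and a stand-in constant (assumptions for illustration, not the Paddle API):

    #include <iostream>
    #include <string>
    #include <unordered_set>
    #include <vector>

    // Stand-in for framework::kEmptyVarName.
    static constexpr char kEmptyVarName[] = "@EMPTY@";

    int main() {
      // Gradient names requested for the while op's inputs.
      std::vector<std::string> igs = {"w@GRAD", "x@GRAD", "mem@GRAD"};

      // Output argument names actually produced inside the gradient block.
      std::unordered_set<std::string> all_outs = {"w@GRAD", "mem@GRAD"};

      // Drop IGs the inner block does not generate; downstream passes such as
      // WhileGradOpShapeInference skip entries equal to kEmptyVarName.
      for (auto &ig : igs) {
        if (all_outs.count(ig) == 0) {
          ig = kEmptyVarName;
        }
      }

      for (const auto &ig : igs) {
        std::cout << ig << "\n";  // w@GRAD, @EMPTY@, mem@GRAD
      }
      return 0;
    }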