Unverified commit 09189732, authored by Yu Yang, committed by GitHub

Rename XXDescBind --> XXDesc (#6797)

* Rename XXDescBind --> XXDesc

* Fix Compile
Parent 0295b000
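Since the diff below touches many files in the same mechanical way, here is a minimal before/after sketch of a call site, assembled from the test code in this commit (the `f` namespace alias and the operator names come from backward_test.cc; this is an illustrative fragment, not a complete program):

// Before this commit, the mutable wrappers around the protobuf messages
// carried a "Bind" suffix:
//   f::ProgramDescBind program;
//   f::BlockDescBind *block = program.MutableBlock(0);
//   f::OpDescBind *op = block->AppendOp();
// After the rename, the same API is reached through the shorter names:
f::ProgramDesc program;
f::BlockDesc *block = program.MutableBlock(0);
f::OpDesc *op = block->AppendOp();
op->SetType("rowwise_add");
op->SetInput("X", {"x"});
op->SetInput("b", {"b"});
op->SetOutput("Out", {"out"});

Only the class names change; method signatures and behavior are untouched.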
@@ -42,7 +42,7 @@ static std::unordered_set<std::string>& CtrlFlowOps() {
 static inline std::unique_ptr<OperatorBase> CreateGradOp(
     const OperatorBase& op, const std::unordered_set<std::string>& no_grad_set,
     std::unordered_map<std::string, std::string>* grad_to_var) {
-  OpDescBind op_desc;
+  OpDesc op_desc;
   op_desc.SetInputMap(op.Inputs());
   op_desc.SetOutputMap(op.Outputs());
   op_desc.SetType(op.Type());
@@ -53,7 +53,7 @@ static inline std::unique_ptr<OperatorBase> CreateGradOp(
   grad_ops.reserve(grad_descs.size());
   std::transform(grad_descs.begin(), grad_descs.end(),
                  std::back_inserter(grad_ops),
-                 [](const std::unique_ptr<OpDescBind>& grad_desc) {
+                 [](const std::unique_ptr<OpDesc>& grad_desc) {
                    return OpRegistry::CreateOp(*grad_desc);
                  });
   PADDLE_ENFORCE(!grad_ops.empty());
@@ -296,7 +296,7 @@ static std::string FwdName(const std::string& grad_name) {
 static void CreateGradVarInBlock(
     size_t grad_op_start_index,
     const std::unordered_map<std::string, std::string>& param_name_map,
-    BlockDescBind* block_desc,
+    BlockDesc* block_desc,
     std::unordered_map<std::string, GradVarInfo>* grad_var_record) {
   auto ops = block_desc->AllOps();
   for (size_t op_index = grad_op_start_index; op_index < ops.size();
@@ -350,12 +350,11 @@ static void CreateGradVarInBlock(
   }
 }
-std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
-    const OpDescBind* op_desc, std::unordered_set<std::string>* no_grad_vars,
+std::vector<std::unique_ptr<OpDesc>> MakeOpGrad(
+    const OpDesc* op_desc, std::unordered_set<std::string>* no_grad_vars,
     std::unordered_map<std::string, std::string>* grad_to_var,
-    const std::vector<BlockDescBind*>& grad_block =
-        std::vector<BlockDescBind*>()) {
-  std::vector<std::unique_ptr<OpDescBind>> grad_op_descs;
+    const std::vector<BlockDesc*>& grad_block = std::vector<BlockDesc*>()) {
+  std::vector<std::unique_ptr<OpDesc>> grad_op_descs;
   // All input gradients of forwarding operator do not need to calculate.
   const std::vector<std::string>& inputs = op_desc->InputArgumentNames();
   if (AllGradInSet(inputs, *no_grad_vars)) {
@@ -386,7 +385,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
           .Get(op_desc->Type())
           .GradOpMaker()(*op_desc, *no_grad_vars, grad_to_var, grad_block);
-  std::list<std::unique_ptr<OpDescBind>> pending_fill_zeros_ops;
+  std::list<std::unique_ptr<OpDesc>> pending_fill_zeros_ops;
   for (auto& desc : grad_op_descs) {
     for (const std::string& in_name : desc->InputArgumentNames()) {
       if (no_grad_vars->count(in_name)) {
@@ -394,8 +393,8 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
             0, in_name.size() - sizeof(kGradVarSuffix) / sizeof(char) + 1);
         std::string new_name = prefix + kZeroVarSuffix;
         desc->Rename(in_name, new_name);
-        std::unique_ptr<OpDescBind> fill_zeros_op(
-            new OpDescBind("fill_zeros_like", {{"X", {prefix}}},
-                           {{"Y", {new_name}}}, AttributeMap{}));
+        std::unique_ptr<OpDesc> fill_zeros_op(
+            new OpDesc("fill_zeros_like", {{"X", {prefix}}},
+                       {{"Y", {new_name}}}, AttributeMap{}));
         pending_fill_zeros_ops.push_back(std::move(fill_zeros_op));
       }
@@ -408,34 +407,33 @@ std::vector<std::unique_ptr<OpDescBind>> MakeOpGrad(
   return grad_op_descs;
 }
-static BlockDescBind* CreateStepBlock(
-    ProgramDescBind& program_desc,
-    std::unordered_set<std::string>* no_grad_vars,
+static BlockDesc* CreateStepBlock(
+    ProgramDesc& program_desc, std::unordered_set<std::string>* no_grad_vars,
     std::unordered_map<std::string, std::string>* grad_to_var,
     int step_block_idx);
-std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
-    ProgramDescBind& program_desc, int block_idx,
+std::vector<std::unique_ptr<OpDesc>> MakeBlockBackward(
+    ProgramDesc& program_desc, int block_idx,
     std::unordered_set<std::string>* no_grad_vars,
     std::unordered_map<std::string, std::string>* grad_to_var) {
   VLOG(5) << "MakeBlockBackward";
-  BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
-  std::vector<OpDescBind*> op_descs = cur_block->AllOps();
+  BlockDesc* cur_block = program_desc.MutableBlock(block_idx);
+  std::vector<OpDesc*> op_descs = cur_block->AllOps();
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
   size_t grad_desc_idx = 0;
-  std::vector<std::unique_ptr<OpDescBind>> backward_descs;
+  std::vector<std::unique_ptr<OpDesc>> backward_descs;
   for (auto it = op_descs.rbegin(); it != op_descs.rend(); ++it) {
     VLOG(5) << "Making backward " << (*it)->Type() << " op";
-    std::vector<std::unique_ptr<OpDescBind>> op_grads;
+    std::vector<std::unique_ptr<OpDesc>> op_grads;
     if ((*it)->Type() == "recurrent" || (*it)->Type() == "while") {
       int step_block_idx = (*it)->GetBlockAttr("sub_block");
-      BlockDescBind* backward_block = CreateStepBlock(
-          program_desc, no_grad_vars, grad_to_var, step_block_idx);
+      BlockDesc* backward_block = CreateStepBlock(program_desc, no_grad_vars,
+                                                  grad_to_var, step_block_idx);
       op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
     } else if ((*it)->Type() == "conditional_block") {
-      BlockDescBind* backward_block =
+      BlockDesc* backward_block =
          CreateStepBlock(program_desc, no_grad_vars, grad_to_var,
                          (*it)->GetBlockAttr("sub_block"));
       op_grads = MakeOpGrad(*it, no_grad_vars, grad_to_var, {backward_block});
@@ -463,14 +461,14 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
       }
       ++grad_desc_idx;
     }
-    std::transform(
-        op_grads.begin(), op_grads.end(), std::back_inserter(backward_descs),
-        [](std::unique_ptr<OpDescBind>& ptr) { return std::move(ptr); });
+    std::transform(op_grads.begin(), op_grads.end(),
+                   std::back_inserter(backward_descs),
+                   [](std::unique_ptr<OpDesc>& ptr) { return std::move(ptr); });
   }
   VLOG(5) << "Appending Sums";
   // Check whether some variables are written more than once
-  std::list<std::pair<size_t, std::unique_ptr<OpDescBind>>> pending_sum_ops;
+  std::list<std::pair<size_t, std::unique_ptr<OpDesc>>> pending_sum_ops;
   for (const auto& dup : dup_out_ops) {
     const std::string& out_name = dup.first;
     const std::vector<size_t> dup_op = dup.second;
@@ -486,16 +484,15 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
         sum_op_inputs.emplace_back(new_name);
         next_g_name = sum_op_inputs.back();
       }
-      std::unique_ptr<OpDescBind> sum_op(
-          new OpDescBind("sum", {{"X", sum_op_inputs}}, {{"Out", {out_name}}},
-                         AttributeMap{}));
+      std::unique_ptr<OpDesc> sum_op(new OpDesc("sum", {{"X", sum_op_inputs}},
+                                                {{"Out", {out_name}}},
+                                                AttributeMap{}));
       pending_sum_ops.push_back({dup_op.back(), std::move(sum_op)});
     }
   }
-  pending_sum_ops.sort(
-      [](const std::pair<size_t, std::unique_ptr<OpDescBind>>& a,
-         const std::pair<size_t, std::unique_ptr<OpDescBind>>& b) {
-        return a.first > b.first;
-      });
+  pending_sum_ops.sort([](const std::pair<size_t, std::unique_ptr<OpDesc>>& a,
+                          const std::pair<size_t, std::unique_ptr<OpDesc>>& b) {
+    return a.first > b.first;
+  });
   for (auto& p : pending_sum_ops) {
@@ -508,14 +505,13 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
   return backward_descs;
 }
-static BlockDescBind* CreateStepBlock(
-    ProgramDescBind& program_desc,
-    std::unordered_set<std::string>* no_grad_vars,
+static BlockDesc* CreateStepBlock(
+    ProgramDesc& program_desc, std::unordered_set<std::string>* no_grad_vars,
     std::unordered_map<std::string, std::string>* grad_to_var,
     int step_block_idx) {
   auto backward_block_op_descs = MakeBlockBackward(program_desc, step_block_idx,
                                                    no_grad_vars, grad_to_var);
-  BlockDescBind* backward_block =
+  BlockDesc* backward_block =
       program_desc.AppendBlock(*program_desc.MutableBlock(step_block_idx));
   for (auto& ptr : backward_block_op_descs) {
     backward_block->AppendAllocatedOp(move(ptr));
@@ -524,7 +520,7 @@ static BlockDescBind* CreateStepBlock(
 }
 ParamGradInfoMap AppendBackward(
-    ProgramDescBind& program_desc, const VarDescBind& target,
+    ProgramDesc& program_desc, const VarDesc& target,
     const std::unordered_set<std::string>& no_grad_vars) {
   std::unordered_set<std::string> no_grad_var_names;
   no_grad_var_names.reserve(no_grad_vars.size() + 1);
@@ -541,8 +537,8 @@ ParamGradInfoMap AppendBackward(
   PADDLE_ENFORCE(is_scalar, "target should be scalar");
   VLOG(3) << "backward from loss=" << target.Name()
           << " data_type=" << target.GetDataType();
-  std::unique_ptr<OpDescBind> fill_one_op(
-      new OpDescBind("fill_constant", {}, {{"Out", {fill_one_op_out}}},
+  std::unique_ptr<OpDesc> fill_one_op(
+      new OpDesc("fill_constant", {}, {{"Out", {fill_one_op_out}}},
                  {{"shape", std::vector<int>{1}},
                   {"value", static_cast<float>(1.0)},
                   {"dtype", target.GetDataType()}}));
...
@@ -49,7 +49,7 @@ using ParamGradInfoMap = std::unordered_map<std::string /*fwd_var_name*/,
                                             GradVarInfo /*grad_var_info*/>;
 ParamGradInfoMap AppendBackward(
-    ProgramDescBind& program_desc, const VarDescBind& target,
+    ProgramDesc& program_desc, const VarDesc& target,
     const std::unordered_set<std::string>& no_grad_vars);
 }  // namespace framework
...
@@ -58,13 +58,13 @@ class RowWiseAddGradMaker : public SingleGradOpDescMaker {
   using SingleGradOpDescMaker::SingleGradOpDescMaker;
  protected:
-  std::unique_ptr<OpDescBind> Apply() const override {
-    auto grad_op = new OpDescBind();
+  std::unique_ptr<OpDesc> Apply() const override {
+    auto grad_op = new OpDesc();
     grad_op->SetInput(GradVarName("Out"), OutputGrad("Out"));
     grad_op->SetOutput(GradVarName("X"), InputGrad("X"));
     grad_op->SetOutput(GradVarName("b"), InputGrad("b"));
     grad_op->SetType("rowwise_add_grad");
-    return std::unique_ptr<OpDescBind>(grad_op);
+    return std::unique_ptr<OpDesc>(grad_op);
   }
 };
@@ -190,11 +190,11 @@ class MinusGradOpDescMaker : public GradOpDescMakerBase {
  public:
   using GradOpDescMakerBase::GradOpDescMakerBase;
-  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
-    std::vector<std::unique_ptr<OpDescBind>> retv;
+  std::vector<std::unique_ptr<OpDesc>> operator()() const override {
+    std::vector<std::unique_ptr<OpDesc>> retv;
     auto x_g = InputGrad("X");
     if (!x_g.empty()) {
-      auto *op_desc = new OpDescBind();
+      auto *op_desc = new OpDesc();
       op_desc->SetType("scale");
       op_desc->SetInput("X", OutputGrad("Out"));
       op_desc->SetOutput("Out", x_g);
@@ -204,7 +204,7 @@ class MinusGradOpDescMaker : public GradOpDescMakerBase {
     auto y_g = InputGrad("Y");
     if (!y_g.empty()) {
-      auto *op_desc = new OpDescBind();
+      auto *op_desc = new OpDesc();
       op_desc->SetType("scale");
       op_desc->SetInput("X", OutputGrad("Out"));
       op_desc->SetOutput("Out", y_g);
@@ -505,25 +505,25 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
 }
 TEST(Backward, simple_single_op) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
-  f::OpDescBind *op = block->AppendOp();
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::OpDesc *op = block->AppendOp();
   op->SetType("rowwise_add");
   op->SetInput("X", {"x"});
   op->SetInput("b", {"b"});
   op->SetOutput("Out", {"out"});
-  auto target = f::VarDescBind("out");
+  auto target = f::VarDesc("out");
   target.SetShape({1});
   auto var_to_grad =
       AppendBackward(program, target, std::unordered_set<std::string>{});
   ASSERT_EQ(block->AllOps().size(), 3UL);
-  f::OpDescBind *fill_op = block->AllOps()[1];
+  f::OpDesc *fill_op = block->AllOps()[1];
   EXPECT_EQ(fill_op->Type(), "fill_constant");
-  f::OpDescBind *grad_op = block->AllOps()[2];
+  f::OpDesc *grad_op = block->AllOps()[2];
   EXPECT_EQ(grad_op->Type(), "rowwise_add_grad");
   ASSERT_EQ(grad_op->InputNames().size(), 1UL);
   ASSERT_EQ(grad_op->OutputNames().size(), 2UL);
@@ -543,16 +543,16 @@ TEST(Backward, simple_single_op) {
 }
 TEST(Backward, default_attribute) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
-  f::OpDescBind *op = block->AppendOp();
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::OpDesc *op = block->AppendOp();
   op->SetType("mul");
   op->SetInput("X", {"x"});
   op->SetInput("Y", {"y"});
   op->SetOutput("Out", {"out"});
   op->CheckAttrs();
-  auto target = f::VarDescBind("out");
+  auto target = f::VarDesc("out");
   target.SetShape({1});
   AppendBackward(program, target, std::unordered_set<std::string>{});
@@ -560,47 +560,47 @@ TEST(Backward, default_attribute) {
   EXPECT_EQ(boost::get<int>(op->GetAttr("x_num_col_dims")), 1);
   EXPECT_EQ(boost::get<int>(op->GetAttr("y_num_col_dims")), 1);
-  f::OpDescBind *fill_op = block->AllOps()[1];
+  f::OpDesc *fill_op = block->AllOps()[1];
   EXPECT_EQ(fill_op->Type(), "fill_constant");
-  f::OpDescBind *grad_op = block->AllOps()[2];
+  f::OpDesc *grad_op = block->AllOps()[2];
   ASSERT_EQ(grad_op->Type(), "mul_grad");
   EXPECT_EQ(boost::get<int>(grad_op->GetAttr("x_num_col_dims")), 1);
   EXPECT_EQ(boost::get<int>(grad_op->GetAttr("y_num_col_dims")), 1);
 }
 TEST(Backward, simple_mult_op) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
-  f::OpDescBind *op1 = block->AppendOp();
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::OpDesc *op1 = block->AppendOp();
   op1->SetType("rowwise_add");
   op1->SetInput("X", {"x1"});
   op1->SetInput("b", {"b1"});
   op1->SetOutput("Out", {"out1"});
-  f::OpDescBind *op2 = block->AppendOp();
+  f::OpDesc *op2 = block->AppendOp();
   op2->SetType("mul");
   op2->SetInput("X", {"out1"});
   op2->SetInput("Y", {"y2"});
   op2->SetOutput("Out", {"out2"});
-  f::OpDescBind *op3 = block->AppendOp();
+  f::OpDesc *op3 = block->AppendOp();
   op3->SetType("rowwise_add");
   op3->SetInput("X", {"out2"});
   op3->SetInput("b", {"b3"});
   op3->SetOutput("Out", {"out3"});
-  auto target = f::VarDescBind("out3");
+  auto target = f::VarDesc("out3");
   target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad =
       AppendBackward(program, target, std::unordered_set<std::string>{});
   ASSERT_EQ(block->AllOps().size(), 6UL + 1);
-  f::OpDescBind *fill_op = block->AllOps()[forward_len];
+  f::OpDesc *fill_op = block->AllOps()[forward_len];
   EXPECT_EQ(fill_op->Type(), "fill_constant");
-  f::OpDescBind *grad_op1 = block->AllOps()[6];
+  f::OpDesc *grad_op1 = block->AllOps()[6];
   EXPECT_EQ(grad_op1->Type(), "rowwise_add_grad");
   ASSERT_EQ(grad_op1->InputNames().size(), 1UL);
   ASSERT_EQ(grad_op1->OutputNames().size(), 2UL);
@@ -611,7 +611,7 @@ TEST(Backward, simple_mult_op) {
   EXPECT_EQ(grad_op1->Output(f::GradVarName("b")),
             std::vector<std::string>({f::GradVarName("b1")}));
-  f::OpDescBind *grad_op2 = block->AllOps()[5];
+  f::OpDesc *grad_op2 = block->AllOps()[5];
   EXPECT_EQ(grad_op2->Type(), "mul_grad");
   ASSERT_EQ(grad_op2->InputNames().size(), 4UL);
   ASSERT_EQ(grad_op2->OutputNames().size(), 2UL);
@@ -625,7 +625,7 @@ TEST(Backward, simple_mult_op) {
   EXPECT_EQ(grad_op2->Output(f::GradVarName("Y")),
             std::vector<std::string>({f::GradVarName("y2")}));
-  f::OpDescBind *grad_op3 = block->AllOps()[4];
+  f::OpDesc *grad_op3 = block->AllOps()[4];
   EXPECT_EQ(grad_op3->Type(), "rowwise_add_grad");
   ASSERT_EQ(grad_op3->InputNames().size(), 1UL);
   ASSERT_EQ(grad_op3->OutputNames().size(), 2UL);
@@ -655,42 +655,42 @@ TEST(Backward, simple_mult_op) {
 }
 TEST(Backward, intermedia_var_no_grad) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
-  f::OpDescBind *op1 = block->AppendOp();
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::OpDesc *op1 = block->AppendOp();
   op1->SetType("rowwise_add");
   op1->SetInput("X", {"x1"});
   op1->SetInput("b", {"b1"});
   op1->SetOutput("Out", {"out1"});
-  f::OpDescBind *op2 = block->AppendOp();
+  f::OpDesc *op2 = block->AppendOp();
   op2->SetType("mul");
   op2->SetInput("X", {"x2"});
   op2->SetInput("Y", {"y2"});
   op2->SetOutput("Out", {"out2"});
-  f::OpDescBind *op3 = block->AppendOp();
+  f::OpDesc *op3 = block->AppendOp();
   op3->SetType("rowwise_add");
   op3->SetInput("X", {"out2"});
   op3->SetInput("b", {"b3"});
   op3->SetOutput("Out", {"out3"});
-  f::OpDescBind *op4 = block->AppendOp();
+  f::OpDesc *op4 = block->AppendOp();
   op4->SetType("mul");
   op4->SetInput("X", {"out1"});
   op4->SetInput("Y", {"out3"});
   op4->SetOutput("Out", {"out4"});
-  auto target = f::VarDescBind("out4");
+  auto target = f::VarDesc("out4");
   target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad = AppendBackward(program, target, {"out3"});
   ASSERT_EQ(block->AllOps().size(), 7UL);
-  f::OpDescBind *fill_op = block->AllOps()[forward_len];
+  f::OpDesc *fill_op = block->AllOps()[forward_len];
   EXPECT_EQ(fill_op->Type(), "fill_constant");
-  f::OpDescBind *grad_op1 = block->AllOps()[6];
+  f::OpDesc *grad_op1 = block->AllOps()[6];
   EXPECT_EQ(grad_op1->Type(), "rowwise_add_grad");
   ASSERT_EQ(grad_op1->InputNames().size(), 1UL);
   ASSERT_EQ(grad_op1->OutputNames().size(), 2UL);
@@ -701,7 +701,7 @@ TEST(Backward, intermedia_var_no_grad) {
   EXPECT_EQ(grad_op1->Output(f::GradVarName("b")),
             std::vector<std::string>({f::GradVarName("b1")}));
-  f::OpDescBind *grad_op4 = block->AllOps()[5];
+  f::OpDesc *grad_op4 = block->AllOps()[5];
   EXPECT_EQ(grad_op4->Type(), "mul_grad");
   ASSERT_EQ(grad_op4->InputNames().size(), 4UL);
   ASSERT_EQ(grad_op4->OutputNames().size(), 2UL);
@@ -726,32 +726,32 @@ TEST(Backward, intermedia_var_no_grad) {
 }
 TEST(Backward, var_no_grad) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
-  f::OpDescBind *op1 = block->AppendOp();
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::OpDesc *op1 = block->AppendOp();
   op1->SetType("mult_in_out");
   op1->SetInput("X", {"x1"});
   op1->SetInput("H", {"h1"});
   op1->SetOutput("Y", {"y1"});
   op1->SetOutput("Z", {"z1"});
-  f::OpDescBind *op2 = block->AppendOp();
+  f::OpDesc *op2 = block->AppendOp();
   op2->SetType("mult_in_out");
   op2->SetInput("X", {"y1"});
   op2->SetInput("H", {"z1"});
   op2->SetOutput("Y", {"y2"});
   op2->SetOutput("Z", {"z2"});
-  auto target = f::VarDescBind("z2");
+  auto target = f::VarDesc("z2");
   target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad = AppendBackward(program, target, {"z1"});
   ASSERT_EQ(block->AllOps().size(), 6UL);
-  f::OpDescBind *fill_op = block->AllOps()[forward_len];
+  f::OpDesc *fill_op = block->AllOps()[forward_len];
   EXPECT_EQ(fill_op->Type(), "fill_constant");
-  f::OpDescBind *grad_op2 = block->AllOps()[3];
+  f::OpDesc *grad_op2 = block->AllOps()[3];
   ASSERT_EQ(grad_op2->Type(), "mult_in_out_grad");
   ASSERT_EQ(grad_op2->InputNames().size(), 6UL);
   ASSERT_EQ(grad_op2->OutputNames().size(), 2UL);
@@ -767,7 +767,7 @@ TEST(Backward, var_no_grad) {
             std::vector<std::string>({f::GradVarName("y1")}));
   EXPECT_EQ(grad_op2->Output(f::GradVarName("H")), std::vector<std::string>());
-  f::OpDescBind *fill_zero_op = block->AllOps()[4];
+  f::OpDesc *fill_zero_op = block->AllOps()[4];
   ASSERT_EQ(fill_zero_op->Type(), "fill_zeros_like");
   ASSERT_EQ(fill_zero_op->InputNames().size(), 1UL);
   ASSERT_EQ(fill_zero_op->OutputNames().size(), 1UL);
@@ -775,7 +775,7 @@ TEST(Backward, var_no_grad) {
   EXPECT_EQ(fill_zero_op->Output("Y"),
             std::vector<std::string>({std::string("z1") + f::kZeroVarSuffix}));
-  f::OpDescBind *grad_op1 = block->AllOps()[5];
+  f::OpDesc *grad_op1 = block->AllOps()[5];
   ASSERT_EQ(grad_op1->Type(), "mult_in_out_grad");
   ASSERT_EQ(grad_op1->InputNames().size(), 6UL);
   ASSERT_EQ(grad_op1->OutputNames().size(), 2UL);
@@ -803,37 +803,37 @@ TEST(Backward, var_no_grad) {
 }
 TEST(Backward, shared_var) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
-  f::OpDescBind *op1 = block->AppendOp();
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
+  f::OpDesc *op1 = block->AppendOp();
   op1->SetType("rowwise_add");
   op1->SetInput("X", {"x1"});
   op1->SetInput("b", {"b1"});
   op1->SetOutput("Out", {"out1"});
-  f::OpDescBind *op2 = block->AppendOp();
+  f::OpDesc *op2 = block->AppendOp();
   op2->SetType("mul");
   op2->SetInput("X", {"out1"});
   op2->SetInput("Y", {"y2"});
   op2->SetOutput("Out", {"out2"});
-  f::OpDescBind *op3 = block->AppendOp();
+  f::OpDesc *op3 = block->AppendOp();
   op3->SetType("rowwise_add");
   op3->SetInput("X", {"out1"});
   op3->SetInput("b", {"b3"});
   op3->SetOutput("Out", {"out3"});
-  auto target = f::VarDescBind("out3");
+  auto target = f::VarDesc("out3");
   target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad =
       AppendBackward(program, target, std::unordered_set<std::string>{});
   ASSERT_EQ(block->AllOps().size(), 8UL);
-  f::OpDescBind *fill_op = block->AllOps()[forward_len];
+  f::OpDesc *fill_op = block->AllOps()[forward_len];
   EXPECT_EQ(fill_op->Type(), "fill_constant");
-  f::OpDescBind *grad_op3 = block->AllOps()[4];
+  f::OpDesc *grad_op3 = block->AllOps()[4];
   ASSERT_EQ(grad_op3->Type(), "rowwise_add_grad");
   ASSERT_EQ(grad_op3->InputNames().size(), 1UL);
   ASSERT_EQ(grad_op3->OutputNames().size(), 2UL);
@@ -844,7 +844,7 @@ TEST(Backward, shared_var) {
   EXPECT_EQ(grad_op3->Output(f::GradVarName("b")),
             std::vector<std::string>({f::GradVarName("b3")}));
-  f::OpDescBind *grad_op4 = block->AllOps()[5];
+  f::OpDesc *grad_op4 = block->AllOps()[5];
   ASSERT_EQ(grad_op4->Type(), "mul_grad");
   ASSERT_EQ(grad_op4->InputNames().size(), 4UL);
   ASSERT_EQ(grad_op4->OutputNames().size(), 2UL);
@@ -858,7 +858,7 @@ TEST(Backward, shared_var) {
   EXPECT_EQ(grad_op4->Output(f::GradVarName("Y")),
             std::vector<std::string>({f::GradVarName("y2")}));
-  f::OpDescBind *sum_op = block->AllOps()[6];
+  f::OpDesc *sum_op = block->AllOps()[6];
   ASSERT_EQ(sum_op->Type(), "sum");
   ASSERT_EQ(sum_op->InputNames().size(), 1UL);
   ASSERT_EQ(sum_op->OutputNames().size(), 1UL);
@@ -868,7 +868,7 @@ TEST(Backward, shared_var) {
   EXPECT_EQ(sum_op->Output("Out"),
             std::vector<std::string>({f::GradVarName("out1")}));
-  f::OpDescBind *grad_op1 = block->AllOps()[7];
+  f::OpDesc *grad_op1 = block->AllOps()[7];
   ASSERT_EQ(grad_op1->Type(), "rowwise_add_grad");
   ASSERT_EQ(grad_op1->InputNames().size(), 1UL);
   ASSERT_EQ(grad_op1->OutputNames().size(), 2UL);
@@ -895,19 +895,19 @@ TEST(Backward, shared_var) {
 }
 TEST(Backward, half_backward) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);
   auto *op1 = block->AppendOp();
   op1->SetType("minus");
   op1->SetInput("X", {"a"});
   op1->SetInput("Y", {"b"});
   op1->SetOutput("Out", {"out"});
-  auto target = f::VarDescBind("out");
+  auto target = f::VarDesc("out");
   target.SetShape({1});
   size_t forward_len = block->AllOps().size();
   auto var_to_grad = AppendBackward(program, target, {"b"});
-  f::OpDescBind *fill_op = block->AllOps()[forward_len];
+  f::OpDesc *fill_op = block->AllOps()[forward_len];
   EXPECT_EQ(fill_op->Type(), "fill_constant");
   auto ops = block->AllOps();
   ASSERT_EQ(3UL, ops.size());
...
@@ -19,18 +19,18 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-VarDescBind *BlockDescBind::Var(const std::string &name) {
+VarDesc *BlockDesc::Var(const std::string &name) {
   auto it = vars_.find(name);
   if (it != vars_.end()) {
     return it->second.get();
   }
   need_update_ = true;
-  auto *var = new VarDescBind(name);
+  auto *var = new VarDesc(name);
   vars_[name].reset(var);
   return var;
 }
-VarDescBind *BlockDescBind::FindVar(const std::string &name) const {
+VarDesc *BlockDesc::FindVar(const std::string &name) const {
   auto it = vars_.find(name);
   if (it == vars_.end()) {
     return nullptr;
@@ -38,11 +38,11 @@ VarDescBind *BlockDescBind::FindVar(const std::string &name) const {
   return it->second.get();
 }
-bool BlockDescBind::HasVar(const std::string &name) const {
+bool BlockDesc::HasVar(const std::string &name) const {
   return vars_.find(name) != vars_.end();
 }
-VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const {
+VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
   if (name == kEmptyVarName) return nullptr;
   auto it = vars_.find(name);
@@ -53,53 +53,52 @@ VarDescBind *BlockDescBind::FindVarRecursive(const std::string &name) const {
   return it->second.get();
 }
-VarDescBind *BlockDescBind::FindRecursiveOrCreateVar(
-    const std::string &name_bytes) {
-  VarDescBind *res = FindVarRecursive(name_bytes);
+VarDesc *BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
+  VarDesc *res = FindVarRecursive(name_bytes);
   if (res == nullptr) {
     res = Var(name_bytes);
   }
   return res;
 }
-bool BlockDescBind::HasVarRecursive(const std::string &name) const {
+bool BlockDesc::HasVarRecursive(const std::string &name) const {
   return FindVarRecursive(name) != nullptr;
 }
-std::vector<VarDescBind *> BlockDescBind::AllVars() const {
-  std::vector<VarDescBind *> res;
+std::vector<VarDesc *> BlockDesc::AllVars() const {
+  std::vector<VarDesc *> res;
   for (const auto &p : vars_) {
     res.push_back(p.second.get());
   }
   return res;
 }
-OpDescBind *BlockDescBind::AppendOp() {
+OpDesc *BlockDesc::AppendOp() {
   need_update_ = true;
-  ops_.emplace_back(new OpDescBind());
+  ops_.emplace_back(new OpDesc());
   return ops_.back().get();
 }
-void BlockDescBind::AppendAllocatedOp(std::unique_ptr<OpDescBind> &&op_desc) {
+void BlockDesc::AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc) {
   need_update_ = true;
   ops_.emplace_back(std::move(op_desc));
 }
-OpDescBind *BlockDescBind::PrependOp() {
+OpDesc *BlockDesc::PrependOp() {
   need_update_ = true;
-  ops_.emplace_front(new OpDescBind());
+  ops_.emplace_front(new OpDesc());
   return ops_.front().get();
 }
-std::vector<OpDescBind *> BlockDescBind::AllOps() const {
-  std::vector<OpDescBind *> res;
+std::vector<OpDesc *> BlockDesc::AllOps() const {
+  std::vector<OpDesc *> res;
   for (const auto &op : ops_) {
     res.push_back(op.get());
   }
   return res;
 }
-void BlockDescBind::Flush() {
+void BlockDesc::Flush() {
   for (auto &op_desc : ops_) {
     op_desc->Flush();
   }
@@ -121,43 +120,43 @@ void BlockDescBind::Flush() {
   }
 }
-BlockDescBind *BlockDescBind::ParentBlock() const {
+BlockDesc *BlockDesc::ParentBlock() const {
   if (this->desc_->parent_idx() == kNoneBlockIndex) {
     return nullptr;
   }
   return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
 }
-proto::BlockDesc *BlockDescBind::Proto() {
+proto::BlockDesc *BlockDesc::Proto() {
   Flush();
   return desc_;
 }
-BlockDescBind::BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc)
+BlockDesc::BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc)
     : prog_(prog), desc_(desc), need_update_(false) {
   for (const proto::VarDesc &var_desc : desc_->vars()) {
-    vars_[var_desc.name()].reset(new VarDescBind(var_desc));
+    vars_[var_desc.name()].reset(new VarDesc(var_desc));
   }
   for (const proto::OpDesc &op_desc : desc_->ops()) {
-    ops_.emplace_back(new OpDescBind(op_desc, prog));
+    ops_.emplace_back(new OpDesc(op_desc, prog));
   }
 }
-BlockDescBind::BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
-                             ProgramDescBind *prog)
+BlockDesc::BlockDesc(const BlockDesc &other, proto::BlockDesc *desc,
+                     ProgramDesc *prog)
    : prog_(prog), desc_(desc) {
   need_update_ = true;
   for (auto &op : other.ops_) {
-    ops_.emplace_back(new OpDescBind(*op));
+    ops_.emplace_back(new OpDesc(*op));
   }
   for (auto &it : other.vars_) {
-    auto *var = new VarDescBind(*it.second);
+    auto *var = new VarDesc(*it.second);
     vars_[it.first].reset(var);
   }
 }
-void BlockDescBind::ClearPBOps() {
+void BlockDesc::ClearPBOps() {
   auto ops = this->desc_->mutable_ops();
   while (!ops->empty()) {
     // we do not own the OpDesc, so release the ownership.
@@ -165,7 +164,7 @@ void BlockDescBind::ClearPBOps() {
   }
 }
-void BlockDescBind::ClearPBVars() {
+void BlockDesc::ClearPBVars() {
   auto vars = this->desc_->mutable_vars();
   while (!vars->empty()) {
     // we do not own the VarDesc, so release the ownership.
...
@@ -28,20 +28,19 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-class ProgramDescBind;
+class ProgramDesc;
 // Each Protobuf Message, we provide a XXXBind class. In that class, we optimize
 // read/write speed. Only when we want the protobuf message, the local changes
 // will be synchronized (by `Sync` method).
-class BlockDescBind {
+class BlockDesc {
  public:
-  BlockDescBind(ProgramDescBind *prog, proto::BlockDesc *desc);
-  BlockDescBind(const BlockDescBind &other, proto::BlockDesc *desc,
-                ProgramDescBind *prog);
+  BlockDesc(ProgramDesc *prog, proto::BlockDesc *desc);
+  BlockDesc(const BlockDesc &other, proto::BlockDesc *desc, ProgramDesc *prog);
-  ~BlockDescBind() {
+  ~BlockDesc() {
     this->ClearPBVars();
     this->ClearPBOps();
   }
@@ -50,15 +49,15 @@ class BlockDescBind {
   int32_t Parent() const { return desc_->parent_idx(); }
-  VarDescBind *Var(const std::string &name_bytes);
-  VarDescBind *FindVar(const std::string &name_bytes) const;
+  VarDesc *Var(const std::string &name_bytes);
+  VarDesc *FindVar(const std::string &name_bytes) const;
   bool HasVar(const std::string &var_name) const;
-  VarDescBind *FindVarRecursive(const std::string &name_bytes) const;
-  VarDescBind *FindRecursiveOrCreateVar(const std::string &name_bytes);
+  VarDesc *FindVarRecursive(const std::string &name_bytes) const;
+  VarDesc *FindRecursiveOrCreateVar(const std::string &name_bytes);
   bool HasVarRecursive(const std::string &var_name) const;
@@ -70,41 +69,41 @@ class BlockDescBind {
     return var_names;
   }
-  std::vector<VarDescBind *> AllVars() const;
+  std::vector<VarDesc *> AllVars() const;
-  BlockDescBind *ParentBlock() const;
+  BlockDesc *ParentBlock() const;
-  OpDescBind *AppendOp();
-  void AppendAllocatedOp(std::unique_ptr<OpDescBind> &&op_desc);
-  OpDescBind *PrependOp();
-  std::vector<OpDescBind *> AllOps() const;
+  OpDesc *AppendOp();
+  void AppendAllocatedOp(std::unique_ptr<OpDesc> &&op_desc);
+  OpDesc *PrependOp();
+  std::vector<OpDesc *> AllOps() const;
   size_t OpSize() const { return ops_.size(); }
-  OpDescBind *Op(int idx) { return ops_.at(idx).get(); }
+  OpDesc *Op(int idx) { return ops_.at(idx).get(); }
   void Flush();
   proto::BlockDesc *Proto();
-  ProgramDescBind *Program() { return this->prog_; }
+  ProgramDesc *Program() { return this->prog_; }
  private:
   void ClearPBOps();
   void ClearPBVars();
  private:
-  ProgramDescBind *prog_;   // not_own
+  ProgramDesc *prog_;       // not_own
   proto::BlockDesc *desc_;  // not_own
   bool need_update_;
-  std::deque<std::unique_ptr<OpDescBind>> ops_;
-  std::unordered_map<std::string, std::unique_ptr<VarDescBind>> vars_;
+  std::deque<std::unique_ptr<OpDesc>> ops_;
+  std::unordered_map<std::string, std::unique_ptr<VarDesc>> vars_;
-  DISABLE_COPY_AND_ASSIGN(BlockDescBind);
+  DISABLE_COPY_AND_ASSIGN(BlockDesc);
 };
 }  // namespace framework
 }  // namespace paddle
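The comment kept in this header describes the design both the old and new classes share: edits are buffered in ordinary C++ containers and copied into the protobuf message only when the message is actually requested. A standalone sketch of that cache-then-sync pattern follows (hypothetical names and a stand-in struct instead of a real protobuf message; not the Paddle implementation):

#include <string>
#include <vector>

struct ProtoBlock {  // stand-in for the generated proto::BlockDesc message
  std::vector<std::string> var_names;
};

class LazyDesc {
 public:
  // Cheap in-memory mutation; only marks the cached proto as stale.
  void AddVar(const std::string& name) {
    var_names_.push_back(name);
    need_update_ = true;
  }

  // The proto is synchronized only here, mirroring Flush()/Proto() above.
  ProtoBlock* Proto() {
    if (need_update_) {
      proto_.var_names = var_names_;
      need_update_ = false;
    }
    return &proto_;
  }

 private:
  ProtoBlock proto_;
  std::vector<std::string> var_names_;
  bool need_update_{false};
};

Repeated mutations therefore cost vector operations rather than protobuf field writes, which is the read/write speed optimization the comment refers to.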
@@ -106,10 +106,10 @@ template <typename T>
 struct OpInfoFiller<T, kGradOpDescMaker> {
   void operator()(const char* op_type, OpInfo* info) const {
     info->grad_op_maker_ = [](
-        const OpDescBind& fwd_op,
+        const OpDesc& fwd_op,
         const std::unordered_set<std::string>& no_grad_set,
         std::unordered_map<std::string, std::string>* grad_to_var,
-        const std::vector<BlockDescBind*>& grad_block) {
+        const std::vector<BlockDesc*>& grad_block) {
       T maker(fwd_op, no_grad_set, grad_to_var, grad_block);
       return maker();
     };
@@ -119,7 +119,7 @@ struct OpInfoFiller<T, kGradOpDescMaker> {
 template <typename T>
 struct OpInfoFiller<T, kVarTypeInference> {
   void operator()(const char* op_type, OpInfo* info) const {
-    info->infer_var_type_ = [](const OpDescBind& fwd_op, BlockDescBind* block) {
+    info->infer_var_type_ = [](const OpDesc& fwd_op, BlockDesc* block) {
       T inference;
       inference(fwd_op, block);
     };
...
@@ -64,7 +64,7 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
   }
 }
-void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id,
+void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id,
                    bool create_local_scope) {
   // TODO(tonyyang-svail):
   //   - only runs on the first device (i.e. no interdevice communication)
...
@@ -114,7 +114,7 @@ class Executor {
    *  ProgramDesc
    *  Scope
    */
-  void Run(const ProgramDescBind&, Scope*, int, bool create_local_scope = true);
+  void Run(const ProgramDesc&, Scope*, int, bool create_local_scope = true);
  private:
   std::vector<const platform::DeviceContext*> device_contexts_;
...
@@ -25,18 +25,16 @@ namespace framework {
 class GradOpDescMakerBase {
  public:
   explicit GradOpDescMakerBase(
-      const OpDescBind& fwd_op,
-      const std::unordered_set<std::string>& no_grad_set,
+      const OpDesc& fwd_op, const std::unordered_set<std::string>& no_grad_set,
       std::unordered_map<std::string, std::string>* grad_to_var,
-      const std::vector<BlockDescBind*>& grad_block =
-          std::vector<BlockDescBind*>())
+      const std::vector<BlockDesc*>& grad_block = std::vector<BlockDesc*>())
       : fwd_op_(fwd_op),
         no_grad_set_(no_grad_set),
         grad_to_var_(grad_to_var),
         grad_block_(grad_block) {}
   virtual ~GradOpDescMakerBase() = default;
-  virtual std::vector<std::unique_ptr<OpDescBind>> operator()() const = 0;
+  virtual std::vector<std::unique_ptr<OpDesc>> operator()() const = 0;
  protected:
   std::vector<std::string> InputGrad(const std::string& name,
@@ -105,26 +103,26 @@ class GradOpDescMakerBase {
   std::string ForwardOpType() const { return this->fwd_op_.Type(); }
  private:
-  const OpDescBind& fwd_op_;
+  const OpDesc& fwd_op_;
   const std::unordered_set<std::string>& no_grad_set_;
   std::unordered_map<std::string, std::string>* grad_to_var_;
  protected:
-  std::vector<BlockDescBind*> grad_block_;
+  std::vector<BlockDesc*> grad_block_;
 };
 class SingleGradOpDescMaker : public GradOpDescMakerBase {
  public:
   using GradOpDescMakerBase::GradOpDescMakerBase;
-  std::vector<std::unique_ptr<OpDescBind>> operator()() const {
-    std::vector<std::unique_ptr<OpDescBind>> retv;
+  std::vector<std::unique_ptr<OpDesc>> operator()() const {
+    std::vector<std::unique_ptr<OpDesc>> retv;
     retv.emplace_back(this->Apply());
     return retv;
   }
  protected:
-  virtual std::unique_ptr<OpDescBind> Apply() const = 0;
+  virtual std::unique_ptr<OpDesc> Apply() const = 0;
 };
 template <bool DropEmptyIG = true>
@@ -133,8 +131,8 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
   using SingleGradOpDescMaker::SingleGradOpDescMaker;
  protected:
-  virtual std::unique_ptr<OpDescBind> Apply() const {
-    auto* grad = new OpDescBind();
+  virtual std::unique_ptr<OpDesc> Apply() const {
+    auto* grad = new OpDesc();
     grad->SetType(this->GradOpType());
     for (auto& input_param : this->InputNames()) {
@@ -150,7 +148,7 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
     grad->SetAttrMap(this->Attrs());
-    return std::unique_ptr<OpDescBind>(grad);
+    return std::unique_ptr<OpDesc>(grad);
   }
   virtual std::string GradOpType() const {
@@ -161,7 +159,7 @@ class DefaultGradOpDescMaker : public SingleGradOpDescMaker {
 class EmptyGradOpMaker : public GradOpDescMakerBase {
  public:
   using GradOpDescMakerBase::GradOpDescMakerBase;
-  std::vector<std::unique_ptr<OpDescBind>> operator()() const override {
+  std::vector<std::unique_ptr<OpDesc>> operator()() const override {
     return {};
   }
 };
...
@@ -25,12 +25,11 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-class OpDescBind;
-class BlockDescBind;
+class OpDesc;
+class BlockDesc;
 class CompileTimeInferShapeContext : public InferShapeContext {
  public:
-  CompileTimeInferShapeContext(const OpDescBind &op,
-                               const BlockDescBind &block);
+  CompileTimeInferShapeContext(const OpDesc &op, const BlockDesc &block);
   bool HasInput(const std::string &name) const override;
@@ -76,13 +75,12 @@ class CompileTimeInferShapeContext : public InferShapeContext {
   void SetDim(const std::string &name, const DDim &dim) override;
-  const OpDescBind &op_;
-  const BlockDescBind &block_;
+  const OpDesc &op_;
+  const BlockDesc &block_;
 };
-OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
-                       const VariableNameMap &outputs,
-                       const AttributeMap &attrs) {
+OpDesc::OpDesc(const std::string &type, const VariableNameMap &inputs,
+               const VariableNameMap &outputs, const AttributeMap &attrs) {
   desc_.set_type(type);
   inputs_ = inputs;
   outputs_ = outputs;
@@ -90,7 +88,7 @@ OpDescBind::OpDescBind(const std::string &type, const VariableNameMap &inputs,
   need_update_ = true;
 }
-OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
+OpDesc::OpDesc(const proto::OpDesc &desc, ProgramDesc *prog)
     : desc_(desc), need_update_(false) {
   // restore inputs_
   int input_size = desc_.inputs_size();
@@ -126,20 +124,19 @@ OpDescBind::OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog)
   }
 }
-proto::OpDesc *OpDescBind::Proto() {
+proto::OpDesc *OpDesc::Proto() {
   Flush();
   return &desc_;
 }
-const std::vector<std::string> &OpDescBind::Input(
-    const std::string &name) const {
+const std::vector<std::string> &OpDesc::Input(const std::string &name) const {
   auto it = inputs_.find(name);
   PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name,
                  Type());
   return it->second;
 }
-std::vector<std::string> OpDescBind::InputArgumentNames() const {
+std::vector<std::string> OpDesc::InputArgumentNames() const {
   std::vector<std::string> retv;
   for (auto &ipt : this->inputs_) {
     retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
@@ -147,21 +144,20 @@ std::vector<std::string> OpDescBind::InputArgumentNames() const {
   return retv;
 }
-void OpDescBind::SetInput(const std::string &param_name,
-                          const std::vector<std::string> &args) {
+void OpDesc::SetInput(const std::string &param_name,
+                      const std::vector<std::string> &args) {
   need_update_ = true;
   inputs_[param_name] = args;
 }
-const std::vector<std::string> &OpDescBind::Output(
-    const std::string &name) const {
+const std::vector<std::string> &OpDesc::Output(const std::string &name) const {
   auto it = outputs_.find(name);
   PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s",
                  name, Type());
   return it->second;
 }
-std::vector<std::string> OpDescBind::OutputArgumentNames() const {
+std::vector<std::string> OpDesc::OutputArgumentNames() const {
   std::vector<std::string> retv;
   for (auto &ipt : this->outputs_) {
     retv.insert(retv.end(), ipt.second.begin(), ipt.second.end());
@@ -169,19 +165,19 @@ std::vector<std::string> OpDescBind::OutputArgumentNames() const {
   return retv;
 }
-void OpDescBind::SetOutput(const std::string &param_name,
-                           const std::vector<std::string> &args) {
+void OpDesc::SetOutput(const std::string &param_name,
+                       const std::vector<std::string> &args) {
   need_update_ = true;
   this->outputs_[param_name] = args;
 }
-proto::AttrType OpDescBind::GetAttrType(const std::string &name) const {
+proto::AttrType OpDesc::GetAttrType(const std::string &name) const {
   auto it = attrs_.find(name);
   PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
   return static_cast<proto::AttrType>(it->second.which() - 1);
 }
-std::vector<std::string> OpDescBind::AttrNames() const {
+std::vector<std::string> OpDesc::AttrNames() const {
   std::vector<std::string> retv;
   retv.reserve(attrs_.size());
   for (auto &attr : attrs_) {
@@ -190,41 +186,39 @@ std::vector<std::string> OpDescBind::AttrNames() const {
   return retv;
 }
-void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
+void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
   this->attrs_[name] = v;
   need_update_ = true;
 }
-void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
+void OpDesc::SetBlockAttr(const std::string &name, BlockDesc &block) {
   this->attrs_[name] = &block;
   need_update_ = true;
 }
-void OpDescBind::SetAttrMap(
+void OpDesc::SetAttrMap(
     const std::unordered_map<std::string, Attribute> &attr_map) {
   attrs_ = attr_map;
   need_update_ = true;
 }
-Attribute OpDescBind::GetAttr(const std::string &name) const {
+Attribute OpDesc::GetAttr(const std::string &name) const {
   auto it = attrs_.find(name);
   PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
   return it->second;
 }
-int OpDescBind::GetBlockAttr(const std::string &name) const {
+int OpDesc::GetBlockAttr(const std::string &name) const {
   auto it = attrs_.find(name);
   PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
-  return boost::get<BlockDescBind *>(it->second)->ID();
+  return boost::get<BlockDesc *>(it->second)->ID();
 }
-const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
-    const {
+const std::unordered_map<std::string, Attribute> &OpDesc::GetAttrMap() const {
   return attrs_;
 }
-void OpDescBind::Rename(const std::string &old_name,
-                        const std::string &new_name) {
+void OpDesc::Rename(const std::string &old_name, const std::string &new_name) {
   for (auto &input : inputs_) {
     std::replace(input.second.begin(), input.second.end(), old_name, new_name);
   }
@@ -235,7 +229,7 @@ void OpDescBind::Rename(const std::string &old_name,
need_update_ = true; need_update_ = true;
} }
void OpDescBind::RenameOutput(const std::string &old_name, void OpDesc::RenameOutput(const std::string &old_name,
const std::string &new_name) { const std::string &new_name) {
for (auto &output : outputs_) { for (auto &output : outputs_) {
std::replace(output.second.begin(), output.second.end(), old_name, std::replace(output.second.begin(), output.second.end(), old_name,
...@@ -244,7 +238,7 @@ void OpDescBind::RenameOutput(const std::string &old_name, ...@@ -244,7 +238,7 @@ void OpDescBind::RenameOutput(const std::string &old_name,
need_update_ = true; need_update_ = true;
} }
void OpDescBind::RenameInput(const std::string &old_name, void OpDesc::RenameInput(const std::string &old_name,
const std::string &new_name) { const std::string &new_name) {
for (auto &input : inputs_) { for (auto &input : inputs_) {
std::replace(input.second.begin(), input.second.end(), old_name, new_name); std::replace(input.second.begin(), input.second.end(), old_name, new_name);
...@@ -278,7 +272,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> { ...@@ -278,7 +272,7 @@ struct SetAttrDescVisitor : public boost::static_visitor<void> {
void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); }
}; };
void OpDescBind::Flush() { void OpDesc::Flush() {
if (need_update_) { if (need_update_) {
this->desc_.mutable_inputs()->Clear(); this->desc_.mutable_inputs()->Clear();
for (auto &ipt : inputs_) { for (auto &ipt : inputs_) {
...@@ -330,7 +324,7 @@ static void InitInferShapeFuncs() { ...@@ -330,7 +324,7 @@ static void InitInferShapeFuncs() {
}); });
} }
void OpDescBind::CheckAttrs() { void OpDesc::CheckAttrs() {
PADDLE_ENFORCE(!Type().empty(), PADDLE_ENFORCE(!Type().empty(),
"CheckAttr() can not be called before type is setted."); "CheckAttr() can not be called before type is setted.");
auto *checker = OpInfoMap::Instance().Get(Type()).Checker(); auto *checker = OpInfoMap::Instance().Get(Type()).Checker();
...@@ -342,7 +336,7 @@ void OpDescBind::CheckAttrs() { ...@@ -342,7 +336,7 @@ void OpDescBind::CheckAttrs() {
checker->Check(attrs_); checker->Check(attrs_);
} }
void OpDescBind::InferShape(const BlockDescBind &block) const { void OpDesc::InferShape(const BlockDesc &block) const {
VLOG(3) << "CompileTime infer shape on " << Type(); VLOG(3) << "CompileTime infer shape on " << Type();
InitInferShapeFuncs(); InitInferShapeFuncs();
auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_; auto &infer_shape = OpInfoMap::Instance().Get(this->Type()).infer_shape_;
...@@ -365,7 +359,7 @@ void OpDescBind::InferShape(const BlockDescBind &block) const { ...@@ -365,7 +359,7 @@ void OpDescBind::InferShape(const BlockDescBind &block) const {
infer_shape(&ctx); infer_shape(&ctx);
} }
void OpDescBind::InferVarType(BlockDescBind *block) const { void OpDesc::InferVarType(BlockDesc *block) const {
auto &info = OpInfoMap::Instance().Get(this->Type()); auto &info = OpInfoMap::Instance().Get(this->Type());
if (info.infer_var_type_) { if (info.infer_var_type_) {
info.infer_var_type_(*this, block); info.infer_var_type_(*this, block);
...@@ -384,7 +378,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const { ...@@ -384,7 +378,7 @@ void OpDescBind::InferVarType(BlockDescBind *block) const {
} }
CompileTimeInferShapeContext::CompileTimeInferShapeContext( CompileTimeInferShapeContext::CompileTimeInferShapeContext(
const OpDescBind &op, const BlockDescBind &block) const OpDesc &op, const BlockDesc &block)
: op_(op), block_(block) {} : op_(op), block_(block) {}
bool CompileTimeInferShapeContext::HasInput(const std::string &name) const { bool CompileTimeInferShapeContext::HasInput(const std::string &name) const {
@@ -23,17 +23,17 @@ limitations under the License. */
namespace paddle {
namespace framework {

-class BlockDescBind;
-class ProgramDescBind;
+class BlockDesc;
+class ProgramDesc;

-class OpDescBind {
+class OpDesc {
 public:
-  OpDescBind() {}
+  OpDesc() {}

-  OpDescBind(const std::string &type, const VariableNameMap &inputs,
-             const VariableNameMap &outputs, const AttributeMap &attrs);
+  OpDesc(const std::string &type, const VariableNameMap &inputs,
+         const VariableNameMap &outputs, const AttributeMap &attrs);

-  OpDescBind(const proto::OpDesc &desc, ProgramDescBind *prog);
+  OpDesc(const proto::OpDesc &desc, ProgramDesc *prog);

  proto::OpDesc *Proto();
@@ -65,7 +65,7 @@ class OpDescBind {
  void SetAttr(const std::string &name, const Attribute &v);

-  void SetBlockAttr(const std::string &name, BlockDescBind &block);
+  void SetBlockAttr(const std::string &name, BlockDesc &block);

  Attribute GetAttr(const std::string &name) const;
@@ -107,9 +107,9 @@ class OpDescBind {
  void CheckAttrs();

-  void InferShape(const BlockDescBind &block) const;
+  void InferShape(const BlockDesc &block) const;

-  void InferVarType(BlockDescBind *block) const;
+  void InferVarType(BlockDesc *block) const;

  void MarkAsTarget() { desc_.set_is_target(true); }
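For orientation between the framework files, a minimal usage sketch of the renamed class (the operator type and argument names are illustrative; the member functions shown all appear in the header above):

// Sketch only: build a compile-time operator description under the new name.
// Assumes the paddle/framework headers above are included.
paddle::framework::OpDesc op;
op.SetType("scale");                  // must name a registered operator type
op.SetInput("X", {"x0"});             // parameter name -> argument variable names
op.SetOutput("Out", {"out0"});
op.SetAttr("scale", 2.0f);            // stored in the boost::variant Attribute
paddle::framework::proto::OpDesc *pb = op.Proto();  // Flush() syncs the protobuf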
@@ -47,7 +47,7 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
    const proto::OpDesc& op_desc) {
  VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
-             "used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
+             "used in unit tests. Use CreateOp(const OpDesc& op_desc) "
             "instead.";
  VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
  VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
@@ -59,7 +59,7 @@ std::unique_ptr<OperatorBase> OpRegistry::CreateOp(
  return CreateOp(op_desc.type(), inputs, outputs, attrs);
}

-std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDescBind& op_desc) {
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
  return CreateOp(op_desc.Type(), op_desc.Inputs(), op_desc.Outputs(),
                  op_desc.GetAttrMap());
}
@@ -79,7 +79,7 @@ class OpRegistry {
  static std::unique_ptr<OperatorBase> CreateOp(const proto::OpDesc& op_desc);

-  static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
+  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
};

template <typename PlaceType, bool at_end, size_t I, typename... KernelType>
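A hedged sketch of the renamed CreateOp overload in use (assumes "scale" is a registered operator and that the registry headers above are included):

// Sketch only: turn a compile-time description into a runnable operator.
paddle::framework::OpDesc desc("scale", {{"X", {"x0"}}}, {{"Out", {"out0"}}},
                               {{"scale", 2.0f}});
std::unique_ptr<paddle::framework::OperatorBase> runtime_op =
    paddle::framework::OpRegistry::CreateOp(desc);  // the overload renamed above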
...@@ -18,49 +18,49 @@ limitations under the License. */ ...@@ -18,49 +18,49 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { BlockDesc *ProgramDesc::AppendBlock(const BlockDesc &parent) {
auto *b = desc_.add_blocks(); auto *b = desc_.add_blocks();
b->set_parent_idx(parent.ID()); b->set_parent_idx(parent.ID());
b->set_idx(desc_.blocks_size() - 1); b->set_idx(desc_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b)); blocks_.emplace_back(new BlockDesc(this, b));
return blocks_.back().get(); return blocks_.back().get();
} }
proto::ProgramDesc *ProgramDescBind::Proto() { proto::ProgramDesc *ProgramDesc::Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Flush(); block->Flush();
} }
return &desc_; return &desc_;
} }
ProgramDescBind::ProgramDescBind() { ProgramDesc::ProgramDesc() {
auto *block = desc_.mutable_blocks()->Add(); auto *block = desc_.mutable_blocks()->Add();
block->set_idx(kRootBlockIndex); block->set_idx(kRootBlockIndex);
block->set_parent_idx(kNoneBlockIndex); block->set_parent_idx(kNoneBlockIndex);
blocks_.emplace_back(new BlockDescBind(this, block)); blocks_.emplace_back(new BlockDesc(this, block));
} }
ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) { ProgramDesc::ProgramDesc(const ProgramDesc &o) {
desc_ = o.desc_; desc_ = o.desc_;
for (int i = 0; i < desc_.blocks_size(); ++i) { for (int i = 0; i < desc_.blocks_size(); ++i) {
auto *block = desc_.mutable_blocks(i); auto *block = desc_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this)); blocks_.emplace_back(new BlockDesc(*o.blocks_[i], block, this));
} }
} }
ProgramDescBind::ProgramDescBind(const proto::ProgramDesc &desc) { ProgramDesc::ProgramDesc(const proto::ProgramDesc &desc) {
desc_ = desc; desc_ = desc;
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc)); blocks_.emplace_back(new BlockDesc(this, &block_desc));
} }
} }
ProgramDescBind::ProgramDescBind(const std::string &binary_str) { ProgramDesc::ProgramDesc(const std::string &binary_str) {
PADDLE_ENFORCE(desc_.ParseFromString(binary_str), PADDLE_ENFORCE(desc_.ParseFromString(binary_str),
"Fail to parse program_desc from binary string."); "Fail to parse program_desc from binary string.");
for (auto &block_desc : *desc_.mutable_blocks()) { for (auto &block_desc : *desc_.mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block_desc)); blocks_.emplace_back(new BlockDesc(this, &block_desc));
} }
} }
@@ -23,23 +23,23 @@ limitations under the License. */
namespace paddle {
namespace framework {

-class BlockDescBind;
+class BlockDesc;

-class ProgramDescBind {
+class ProgramDesc {
 public:
-  ProgramDescBind();
+  ProgramDesc();

-  explicit ProgramDescBind(const proto::ProgramDesc &desc);
+  explicit ProgramDesc(const proto::ProgramDesc &desc);

-  ProgramDescBind(const ProgramDescBind &o);
+  ProgramDesc(const ProgramDesc &o);

-  explicit ProgramDescBind(const std::string &binary_str);
+  explicit ProgramDesc(const std::string &binary_str);

-  BlockDescBind *AppendBlock(const BlockDescBind &parent);
+  BlockDesc *AppendBlock(const BlockDesc &parent);

-  BlockDescBind *MutableBlock(size_t idx) { return blocks_[idx].get(); }
+  BlockDesc *MutableBlock(size_t idx) { return blocks_[idx].get(); }

-  const BlockDescBind &Block(size_t idx) const { return *blocks_[idx]; }
+  const BlockDesc &Block(size_t idx) const { return *blocks_[idx]; }

  size_t Size() const { return blocks_.size(); }
@@ -48,7 +48,7 @@ class ProgramDescBind {
 private:
  proto::ProgramDesc desc_;

-  std::vector<std::unique_ptr<BlockDescBind>> blocks_;
+  std::vector<std::unique_ptr<BlockDesc>> blocks_;
};
}  // namespace framework
}  // namespace paddle
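A short round-trip sketch of the renamed ProgramDesc API declared above, mirroring the tests that follow:

// Sketch only: every ProgramDesc starts with a root block at index 0.
paddle::framework::ProgramDesc program;
paddle::framework::BlockDesc *root = program.MutableBlock(0);
paddle::framework::BlockDesc *step = program.AppendBlock(*root);  // nested block
std::string binary;
program.Proto()->SerializeToString(&binary);      // Proto() Flush()es every block
paddle::framework::ProgramDesc restored(binary);  // re-parses and rebuilds blocks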
@@ -19,7 +19,7 @@
namespace paddle {
namespace framework {

TEST(ProgramDesc, copy_ctor) {
-  ProgramDescBind program;
+  ProgramDesc program;
  auto* global_block = program.MutableBlock(0);
  auto* x = global_block->Var("X");
  x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
@@ -42,12 +42,12 @@ TEST(ProgramDesc, copy_ctor) {
  out->SetType(proto::VarDesc_VarType_LOD_TENSOR);
  op->SetOutput("Y", {out->Name()});

-  ProgramDescBind program_copy(program);
+  ProgramDesc program_copy(program);

  auto* global_block_copy = program_copy.MutableBlock(0);
  ASSERT_NE(global_block, global_block_copy);

-  auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
+  auto assert_same_var = [&](const std::string& name, VarDesc* var_before) {
    ASSERT_TRUE(global_block_copy->HasVar(name));
    auto* copy = global_block_copy->Var(name);
    ASSERT_NE(copy, var_before);
@@ -81,7 +81,7 @@ TEST(ProgramDesc, copy_ctor) {
}

TEST(ProgramDescBind, serialize_and_deserialize) {
-  ProgramDescBind program_origin;
+  ProgramDesc program_origin;
  auto* global_block = program_origin.MutableBlock(0);
  auto* x = global_block->Var("X");
  x->SetType(proto::VarDesc_VarType_LOD_TENSOR);
@@ -107,11 +107,11 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
  std::string binary_str;
  program_origin.Proto()->SerializeToString(&binary_str);

-  ProgramDescBind program_restored(binary_str);
+  ProgramDesc program_restored(binary_str);

  auto* global_block_restored = program_restored.MutableBlock(0);
  ASSERT_NE(global_block, global_block_restored);

-  auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
+  auto assert_same_var = [&](const std::string& name, VarDesc* var_before) {
    ASSERT_TRUE(global_block_restored->HasVar(name));
    auto* restored = global_block_restored->Var(name);
    ASSERT_NE(restored, var_before);
@@ -29,7 +29,7 @@ namespace ops = paddle::operators;

void AddOp(const std::string &type, const f::VariableNameMap &inputs,
           const f::VariableNameMap &outputs, f::AttributeMap attrs,
-           paddle::framework::BlockDescBind *block) {
+           paddle::framework::BlockDesc *block) {
  // insert output
  for (auto kv : outputs) {
    for (auto v : kv.second) {
@@ -51,8 +51,8 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
}

TEST(Prune, one_operator) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{},
        block);
@@ -69,8 +69,8 @@ TEST(Prune, one_operator) {
}

TEST(Prune, forward) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, f::AttributeMap{},
        block);
@@ -92,8 +92,8 @@ TEST(Prune, forward) {
}

TEST(Prune, multi_input_op) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, f::AttributeMap{},
        block);
@@ -113,8 +113,8 @@ TEST(Prune, multi_input_op) {
}

TEST(Prune, multi_output_op) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
        f::AttributeMap{}, block);
@@ -132,8 +132,8 @@ TEST(Prune, multi_output_op) {
}

TEST(Prune, multi_target) {
-  f::ProgramDescBind program;
-  f::BlockDescBind *block = program.MutableBlock(0);
+  f::ProgramDesc program;
+  f::BlockDesc *block = program.MutableBlock(0);

  AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}},
        f::AttributeMap{}, block);
@@ -25,11 +25,9 @@
namespace paddle {
namespace framework {

class OperatorBase;
-class OpDescBind;
-class BlockDescBind;
-class BlockDesc;
+class OpDesc;
class InferShapeContext;
-class BlockDescBind;
+class BlockDesc;

using VariableNameMap = std::map<std::string, std::vector<std::string>>;
@@ -37,7 +35,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
using Attribute =
    boost::variant<boost::blank, int, float, std::string, std::vector<int>,
                   std::vector<float>, std::vector<std::string>, bool,
-                   std::vector<bool>, BlockDescBind*>;
+                   std::vector<bool>, BlockDesc*>;

using AttributeMap = std::unordered_map<std::string, Attribute>;
@@ -45,13 +43,13 @@ using OpCreator = std::function<OperatorBase*(
    const std::string& /*type*/, const VariableNameMap& /*inputs*/,
    const VariableNameMap& /*outputs*/, const AttributeMap& /*attrs*/)>;

-using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDescBind>>(
-    const OpDescBind&, const std::unordered_set<std::string>& /*no_grad_set*/,
+using GradOpMakerFN = std::function<std::vector<std::unique_ptr<OpDesc>>(
+    const OpDesc&, const std::unordered_set<std::string>& /*no_grad_set*/,
    std::unordered_map<std::string, std::string>* /*grad_to_var*/,
-    const std::vector<BlockDescBind*>& grad_block)>;
+    const std::vector<BlockDesc*>& grad_block)>;

-using InferVarTypeFN = std::function<void(const OpDescBind& /*op_desc*/,
-                                          BlockDescBind* /*block*/)>;
+using InferVarTypeFN =
+    std::function<void(const OpDesc& /*op_desc*/, BlockDesc* /*block*/)>;

using InferShapeFN = std::function<void(InferShapeContext*)>;
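Because the typedefs above now traffic in OpDesc and BlockDesc, a conforming InferVarTypeFN can be written directly as a lambda. A sketch (the "Out" parameter name is illustrative):

// Sketch only: mark every "Out" argument of an op as a LoD tensor.
paddle::framework::InferVarTypeFN infer_var_type =
    [](const paddle::framework::OpDesc &op_desc,
       paddle::framework::BlockDesc *block) {
      for (auto &name : op_desc.Output("Out")) {
        block->Var(name)->SetType(
            paddle::framework::proto::VarDesc::LOD_TENSOR);
      }
    };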
@@ -18,29 +18,27 @@ limitations under the License. */
namespace paddle {
namespace framework {

-proto::VarDesc::VarType VarDescBind::GetType() const { return desc_.type(); }
+proto::VarDesc::VarType VarDesc::GetType() const { return desc_.type(); }

-void VarDescBind::SetType(proto::VarDesc::VarType type) {
-  desc_.set_type(type);
-}
+void VarDesc::SetType(proto::VarDesc::VarType type) { desc_.set_type(type); }

-void VarDescBind::SetShape(const std::vector<int64_t> &dims) {
+void VarDesc::SetShape(const std::vector<int64_t> &dims) {
  VectorToRepeated(dims, mutable_tensor_desc()->mutable_dims());
}

-void VarDescBind::SetDataType(proto::DataType data_type) {
+void VarDesc::SetDataType(proto::DataType data_type) {
  mutable_tensor_desc()->set_data_type(data_type);
}

-std::vector<int64_t> VarDescBind::Shape() const {
+std::vector<int64_t> VarDesc::Shape() const {
  return RepeatedToVector(tensor_desc().dims());
}

-proto::DataType VarDescBind::GetDataType() const {
+proto::DataType VarDesc::GetDataType() const {
  return tensor_desc().data_type();
}

-void VarDescBind::SetLoDLevel(int32_t lod_level) {
+void VarDesc::SetLoDLevel(int32_t lod_level) {
  switch (desc_.type()) {
    case proto::VarDesc::LOD_TENSOR:
      desc_.mutable_lod_tensor()->set_lod_level(lod_level);
@@ -54,7 +52,7 @@ void VarDescBind::SetLoDLevel(int32_t lod_level) {
  }
}

-int32_t VarDescBind::GetLodLevel() const {
+int32_t VarDesc::GetLodLevel() const {
  switch (desc_.type()) {
    case proto::VarDesc::LOD_TENSOR:
      return desc_.lod_tensor().lod_level();
@@ -66,7 +64,7 @@ int32_t VarDescBind::GetLodLevel() const {
  }
}

-const proto::TensorDesc &VarDescBind::tensor_desc() const {
+const proto::TensorDesc &VarDesc::tensor_desc() const {
  PADDLE_ENFORCE(desc_.has_type(), "invoke TensorDesc must after set type");
  switch (desc_.type()) {
    case proto::VarDesc::SELECTED_ROWS:
@@ -80,7 +78,7 @@ const proto::TensorDesc &VarDescBind::tensor_desc() const {
  }
}

-proto::TensorDesc *VarDescBind::mutable_tensor_desc() {
+proto::TensorDesc *VarDesc::mutable_tensor_desc() {
  PADDLE_ENFORCE(desc_.has_type(),
                 "invoke MutableTensorDesc must after set type");
  switch (desc_.type()) {
@@ -53,14 +53,14 @@ inline void VectorToRepeated(const std::vector<bool> &vec,
  }
}

-class VarDescBind {
+class VarDesc {
 public:
-  explicit VarDescBind(const std::string &name) {
+  explicit VarDesc(const std::string &name) {
    desc_.set_name(name);
    desc_.set_type(proto::VarDesc::LOD_TENSOR);
  }

-  explicit VarDescBind(const proto::VarDesc &desc) : desc_(desc) {}
+  explicit VarDesc(const proto::VarDesc &desc) : desc_(desc) {}

  proto::VarDesc *Proto() { return &desc_; }
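A brief sketch of the renamed VarDesc accessors (FP32 is assumed here to be a member of proto::DataType; the variable name and dims are illustrative):

// Sketch only: a freshly constructed VarDesc defaults to LOD_TENSOR.
paddle::framework::VarDesc var("W");
var.SetShape({64, 128});
var.SetDataType(paddle::framework::proto::DataType::FP32);  // assumed enum value
var.SetLoDLevel(1);
std::vector<int64_t> dims = var.Shape();  // {64, 128}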
@@ -21,8 +21,7 @@ namespace framework {
class VarTypeInference {
 public:
  virtual ~VarTypeInference() {}
-  virtual void operator()(const OpDescBind& op_desc,
-                          BlockDescBind* block) const = 0;
+  virtual void operator()(const OpDesc& op_desc, BlockDesc* block) const = 0;
};
}  // namespace framework
@@ -33,8 +33,7 @@ class SumOpMaker : public OpProtoAndCheckerMaker {
class SumOpVarTypeInference : public VarTypeInference {
 public:
-  void operator()(const OpDescBind &op_desc,
-                  BlockDescBind *block) const override {
+  void operator()(const OpDesc &op_desc, BlockDesc *block) const override {
    auto &inputs = op_desc.Input("X");
    auto default_var_type = proto::VarDesc::SELECTED_ROWS;
@@ -62,7 +61,7 @@ namespace paddle {
namespace framework {

TEST(InferVarType, sum_op) {
-  ProgramDescBind prog;
+  ProgramDesc prog;
  auto *op = prog.MutableBlock(0)->AppendOp();
  op->SetType("sum");
  op->SetInput("X", {"test_a", "test_b", "test_c"});
@@ -85,7 +84,7 @@ TEST(InferVarType, sum_op) {
}

TEST(InferVarType, sum_op_without_infer_var_type) {
-  ProgramDescBind prog;
+  ProgramDesc prog;
  auto *op = prog.MutableBlock(0)->AppendOp();
  op->SetType("sum_without_infer_var_type");
  op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
@@ -149,14 +149,14 @@ class ArrayToLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
    grad_op->SetType("lod_tensor_to_array");
    grad_op->SetInput("X", OutputGrad("Out"));
    grad_op->SetInput("RankTable", Input("RankTable"));
    grad_op->SetOutput("Out", InputGrad("X"));
    grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
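The operator-side edits below all follow the pattern just shown: Apply() now allocates and returns framework::OpDesc instead of framework::OpDescBind. As a generic, hedged sketch (MyOpGradMaker and "my_op_grad" are illustrative names, not part of the commit):

class MyOpGradMaker : public paddle::framework::SingleGradOpDescMaker {
 public:
  using paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
  std::unique_ptr<paddle::framework::OpDesc> Apply() const override {
    auto *grad_op = new paddle::framework::OpDesc();  // was OpDescBind
    grad_op->SetType("my_op_grad");
    grad_op->SetInput("X", Input("X"));
    grad_op->SetInput(paddle::framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(paddle::framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetAttrMap(Attrs());
    return std::unique_ptr<paddle::framework::OpDesc>(grad_op);
  }
};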
@@ -121,12 +121,12 @@ class AssignGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *op = new framework::OpDesc();
    op->SetType("assign");
    op->SetInput("X", OutputGrad("Out"));
    op->SetOutput("Out", InputGrad("X"));
-    return std::unique_ptr<framework::OpDescBind>(op);
+    return std::unique_ptr<framework::OpDesc>(op);
  }
};
@@ -119,8 +119,8 @@ class BeamSearchDecodeInferShape : public framework::InferShapeBase {
class BeamSearchDecodeInferVarType : public framework::VarTypeInference {
 public:
-  void operator()(const framework::OpDescBind& op_desc,
-                  framework::BlockDescBind* block) const override {
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
    for (auto& o : op_desc.Output("SentenceIds")) {
      block->Var(o)->SetType(framework::proto::VarDesc::LOD_TENSOR);
    }
@@ -52,14 +52,14 @@ class CastOpGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto grad = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto grad = new framework::OpDesc();
    grad->SetType("cast");
    grad->SetInput("X", OutputGrad("Out"));
    grad->SetOutput("Out", InputGrad("X"));
    grad->SetAttr("out_dtype", GetAttr("in_dtype"));
    grad->SetAttr("in_dtype", GetAttr("out_dtype"));
-    return std::unique_ptr<framework::OpDescBind>(grad);
+    return std::unique_ptr<framework::OpDesc>(grad);
  }
};
@@ -65,7 +65,7 @@ class ConditionalBlockOp : public ConditionalOp {
      scopes->front() = &scope.NewScope();
      auto &cur_scope = *scopes->front();

-      auto *block = Attr<framework::BlockDescBind *>("sub_block");
+      auto *block = Attr<framework::BlockDesc *>("sub_block");
      framework::Executor exec(dev_ctx);
      exec.Run(*block->Program(), &cur_scope, block->ID(), false);
    }
@@ -86,7 +86,7 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
              "(std::vector<Scope*>) The step scope of conditional block. To "
              "unify the conditional block, rnn and while op, the type of "
              "scope is std::vector<Scope*>");
-    AddAttr<framework::BlockDescBind *>(
+    AddAttr<framework::BlockDesc *>(
        "sub_block", "The step block of conditional block operator");
    AddComment(R"DOC(Conditional block operator
@@ -116,7 +116,7 @@ class ConditionalBlockGradOp : public ConditionalOp {
    auto &scopes = scope_var->Get<std::vector<framework::Scope *>>();
    framework::Scope &cur_scope = *scopes[0];

-    auto *block = Attr<framework::BlockDescBind *>("sub_block");
+    auto *block = Attr<framework::BlockDesc *>("sub_block");
    framework::Executor exec(dev_ctx);
    exec.Run(*block->Program(), &cur_scope, block->ID(), false);
@@ -170,8 +170,8 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto grad_op = new framework::OpDesc();
    grad_op->SetType("conditional_block_grad");
    grad_op->SetInput("X", Input("X"));
    grad_op->SetInput("Params", Input("Params"));
@@ -181,7 +181,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    grad_op->SetOutput(framework::GradVarName("Params"), InputGrad("Params"));
    grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
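Block-valued attributes such as "sub_block" keep the same round trip after the rename: SetBlockAttr stores a BlockDesc* in the attribute variant, and GetBlockAttr recovers the block's index. A sketch:

// Sketch only: attach a sub-block to an op and read back its index.
paddle::framework::ProgramDesc program;
paddle::framework::BlockDesc *sub_block = program.AppendBlock(program.Block(0));
paddle::framework::OpDesc op;
op.SetBlockAttr("sub_block", *sub_block);
int block_id = op.GetBlockAttr("sub_block");  // == sub_block->ID()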
@@ -93,13 +93,13 @@ class IncrementGradOpMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
    grad_op->SetType("increment");
    grad_op->SetInput("X", Output("Out"));
    grad_op->SetOutput("Out", Input("X"));
    grad_op->SetAttr("step", -boost::get<float>(GetAttr("step")));
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
@@ -63,8 +63,8 @@ class LoDRankTableInferShape : public framework::InferShapeBase {
class LoDRankTableInferVarType : public framework::VarTypeInference {
 public:
-  void operator()(const framework::OpDescBind &op_desc,
-                  framework::BlockDescBind *block) const override {
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
    for (auto &o : op_desc.Output("Out")) {
      block->FindRecursiveOrCreateVar(o)->SetType(
          framework::proto::VarDesc::LOD_RANK_TABLE);
@@ -127,8 +127,8 @@ class LoDTensorToArrayInferShape : public framework::InferShapeBase {
class LoDTensorToArrayInferVarType : public framework::VarTypeInference {
 public:
-  void operator()(const framework::OpDescBind &op_desc,
-                  framework::BlockDescBind *block) const override {
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
    for (auto &out_var : op_desc.Output("Out")) {
      block->Var(out_var)->SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
    }
@@ -140,14 +140,14 @@ class LoDTensorToArrayGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
    grad_op->SetType("array_to_lod_tensor");
    grad_op->SetInput("X", OutputGrad("Out"));
    grad_op->SetInput("RankTable", Input("RankTable"));
    grad_op->SetOutput("Out", InputGrad("X"));
    grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
@@ -108,8 +108,8 @@ class LookupTableOpGrad : public framework::OperatorWithKernel {
class LookupTableOpGradVarTypeInference : public framework::VarTypeInference {
 public:
-  void operator()(const framework::OpDescBind& op_desc,
-                  framework::BlockDescBind* block) const override {
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
    auto out_var_name = op_desc.Output(framework::GradVarName("W")).front();
    auto attr = op_desc.GetAttr("is_sparse");
    bool is_sparse = boost::get<bool>(attr);
@@ -60,13 +60,13 @@ class MeanGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto* grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* grad_op = new framework::OpDesc();
    grad_op->SetType("mean_grad");
    grad_op->SetInput("X", Input("X"));
    grad_op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    grad_op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
@@ -161,15 +161,15 @@ class MergeLoDTensorGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
    grad_op->SetType("split_lod_tensor");
    grad_op->SetInput("X", OutputGrad("Out"));
    grad_op->SetInput("Mask", Input("Mask"));
    grad_op->SetOutput("OutTrue", InputGrad("InTrue"));
    grad_op->SetOutput("OutFalse", InputGrad("InFalse"));
    grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
@@ -70,12 +70,11 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
 public:
  using framework::GradOpDescMakerBase::GradOpDescMakerBase;

-  std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
-      const override {
-    std::vector<std::unique_ptr<framework::OpDescBind>> ops;
+  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
+    std::vector<std::unique_ptr<framework::OpDesc>> ops;
    auto x_g = InputGrad("X");
    if (!x_g.empty()) {
-      auto *x_g_op = new framework::OpDescBind();
+      auto *x_g_op = new framework::OpDesc();
      x_g_op->SetType("scale");
      x_g_op->SetInput("X", OutputGrad("Out"));
      x_g_op->SetOutput("Out", x_g);
@@ -85,7 +84,7 @@ class MinusGradMaker : public framework::GradOpDescMakerBase {
    auto y_g = InputGrad("Y");
    if (!y_g.empty()) {
-      auto *y_g_op = new framework::OpDescBind();
+      auto *y_g_op = new framework::OpDesc();
      y_g_op->SetType("scale");
      y_g_op->SetInput("X", OutputGrad("Out"));
      y_g_op->SetOutput("Out", y_g);
@@ -65,7 +65,7 @@ class NCCLTester : public ::testing::Test {
  }

  void NCCLInitOp() {
-    std::unique_ptr<f::OpDescBind> op1(new f::OpDescBind);
+    std::unique_ptr<f::OpDesc> op1(new f::OpDesc);

    op1->SetType("ncclInit");
    op1->SetOutput("Communicator", {"comm"});
@@ -81,10 +81,9 @@ class NCCLTester : public ::testing::Test {
  }

  template <class T>
-  void PerThreadProgram(int gpu_id, const f::OpDescBind &op_desc,
-                        f::Scope *scope) {
+  void PerThreadProgram(int gpu_id, const f::OpDesc &op_desc, f::Scope *scope) {
    std::unique_lock<std::mutex> lk(mu);
-    const f::OpDescBind *op1 = &op_desc;
+    const f::OpDesc *op1 = &op_desc;

    p::GPUPlace place(gpu_id);
    auto &ctx = dev_ctxs.at(gpu_id);
@@ -125,7 +124,7 @@ class NCCLTester : public ::testing::Test {

// ncclInitOp with desc
TEST(NCCL, ncclInitOp) {
-  std::unique_ptr<f::OpDescBind> op_desc(new f::OpDescBind);
+  std::unique_ptr<f::OpDesc> op_desc(new f::OpDesc);

  op_desc->SetType("ncclInit");
  op_desc->SetOutput("Communicator", {"x1"});
@@ -145,7 +144,7 @@ TEST(NCCL, ncclInitOp) {

// ncclAllReduceOp with desc
TEST_F(NCCLTester, ncclAllReduceOp) {
-  std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
+  std::unique_ptr<f::OpDesc> op2(new f::OpDesc);

  op2->SetType("ncclAllReduce");
  op2->SetInput("X", {"st"});
  op2->SetInput("Communicator", {"comm"});
@@ -192,7 +191,7 @@ TEST_F(NCCLTester, ncclAllReduceOp) {

// ncclReduceOp with desc
TEST_F(NCCLTester, ncclReduceOp) {
-  std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
+  std::unique_ptr<f::OpDesc> op2(new f::OpDesc);
  const int kRoot = 0;

  op2->SetType("ncclReduce");
  op2->SetInput("X", {"st"});
@@ -240,7 +239,7 @@ TEST_F(NCCLTester, ncclReduceOp) {

// ncclBcastOp with desc
TEST_F(NCCLTester, ncclBcastOp) {
-  std::unique_ptr<f::OpDescBind> op2(new f::OpDescBind);
+  std::unique_ptr<f::OpDesc> op2(new f::OpDesc);
  const int kRoot = 5;

  op2->SetType("ncclBcast");
  op2->SetInput("X", {"st"});
@@ -116,14 +116,14 @@ class PadOpGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto* bind = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* bind = new framework::OpDesc();
    bind->SetInput("X", Input("X"));
    bind->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    bind->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    bind->SetAttrMap(Attrs());
    bind->SetType("pad_grad");
-    return std::unique_ptr<framework::OpDescBind>(bind);
+    return std::unique_ptr<framework::OpDesc>(bind);
  }
};
@@ -234,7 +234,7 @@ class RecurrentOp : public RecurrentBase {
    auto reverse = Attr<bool>(kReverse);

    framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
+    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
    auto *program = block->Program();

    for (size_t i = 0; i < seq_len; ++i) {
@@ -317,7 +317,7 @@ class RecurrentGradOp : public RecurrentBase {
    auto reverse = Attr<bool>(kReverse);

    framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
+    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
    auto *program = block->Program();

    for (size_t step_id = 0; step_id < seq_len; ++step_id) {
@@ -522,8 +522,7 @@ The ex-state means the state value in the ex-timestep or the previous time step
        string::Sprintf(
            "The state variable names. [%s, %s, %s] must be the same order",
            kExStates, kStates, kInitStateGrads));
-    AddAttr<framework::BlockDescBind *>(kStepBlock,
-                                        "The step block inside RNN");
+    AddAttr<framework::BlockDesc *>(kStepBlock, "The step block inside RNN");
    AddAttr<bool>(kReverse, R"DOC(Calculate RNN reversely or not.
By default reverse=False
@@ -565,8 +564,8 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  virtual std::unique_ptr<framework::OpDescBind> Apply() const {
-    auto *grad = new framework::OpDescBind();
+  virtual std::unique_ptr<framework::OpDesc> Apply() const {
+    auto *grad = new framework::OpDesc();
    grad->SetType("recurrent_grad");
    for (auto &input_param : this->InputNames()) {
      grad->SetInput(input_param, this->Input(input_param));
@@ -588,7 +587,7 @@ class RecurrentGradOpDescMaker : public framework::SingleGradOpDescMaker {
    grad->SetAttrMap(this->Attrs());
    grad->SetBlockAttr(kStepBlock, *grad_block_[0]);

-    return std::unique_ptr<framework::OpDescBind>(grad);
+    return std::unique_ptr<framework::OpDesc>(grad);
  }
};
@@ -58,13 +58,13 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
    grad_op->SetType("scale");
    grad_op->SetInput("X", OutputGrad("Out"));
    grad_op->SetOutput("Out", InputGrad("X"));
    grad_op->SetAttr("scale", GetAttr("scale"));
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
@@ -136,14 +136,14 @@ class ShrinkRNNGradOpMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *op = new framework::OpDesc();
    op->SetType("shrink_rnn_memory_grad");
    op->SetInput("X", Input("X"));
    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
    op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(op);
+    return std::unique_ptr<framework::OpDesc>(op);
  }
};
@@ -50,13 +50,13 @@ class SignGradMaker : public framework::SingleGradOpDescMaker {
 public:
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
    grad_op->SetType("scale");
    grad_op->SetInput("X", OutputGrad("Out"));
    grad_op->SetOutput("Out", InputGrad("X"));
    grad_op->SetAttr("scale", 0.0f);
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
@@ -173,8 +173,8 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

 protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto* grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto* grad_op = new framework::OpDesc();
    grad_op->SetType("softmax_with_cross_entropy_grad");
    grad_op->SetInput("Label", Input("Label"));
    grad_op->SetInput("Softmax", Output("Softmax"));
@@ -183,7 +183,7 @@ class SoftmaxGradMaker : public framework::SingleGradOpDescMaker {
    grad_op->SetInput(framework::GradVarName("Loss"), OutputGrad("Loss"));
    grad_op->SetOutput(framework::GradVarName("Logits"), InputGrad("Logits"));
    grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
......
@@ -163,8 +163,8 @@ class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker {
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
  protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
     grad_op->SetType("merge_lod_tensor");
     grad_op->SetInput("InTrue", OutputGrad("OutTrue"));
     grad_op->SetInput("InFalse", OutputGrad("OutFalse"));
@@ -172,7 +172,7 @@ class SplitLoDTensorArrayGradMaker : public framework::SingleGradOpDescMaker {
     grad_op->SetInput("X", Input("X"));
     grad_op->SetOutput("Out", InputGrad("X"));
     grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
   }
 };
...
@@ -108,13 +108,13 @@ class SplitGradMaker : public framework::SingleGradOpDescMaker {
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
  protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto op = new framework::OpDesc();
     op->SetType("concat");
     op->SetInput("X", OutputGrad("Out"));
     op->SetOutput("Out", InputGrad("X"));
     op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(op);
+    return std::unique_ptr<framework::OpDesc>(op);
   }
 };
...
@@ -115,8 +115,8 @@ the LoD information with the first input.
 class SumOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDescBind& op_desc,
-                  framework::BlockDescBind* block) const override {
+  void operator()(const framework::OpDesc& op_desc,
+                  framework::BlockDesc* block) const override {
     auto& inputs = op_desc.Input("X");
     auto var_type = framework::proto::VarDesc::SELECTED_ROWS;
@@ -169,20 +169,19 @@ class SumGradMaker : public framework::GradOpDescMakerBase {
  public:
   using framework::GradOpDescMakerBase::GradOpDescMakerBase;
-  std::vector<std::unique_ptr<framework::OpDescBind>> operator()()
-      const override {
+  std::vector<std::unique_ptr<framework::OpDesc>> operator()() const override {
     auto x_grads = InputGrad("X");
-    std::vector<std::unique_ptr<framework::OpDescBind>> grad_ops;
+    std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
     grad_ops.reserve(x_grads.size());
     auto og = OutputGrad("Out");
     std::transform(x_grads.begin(), x_grads.end(), std::back_inserter(grad_ops),
                    [&og](const std::string& x_grad) {
-                     auto* grad_op = new framework::OpDescBind();
+                     auto* grad_op = new framework::OpDesc();
                      grad_op->SetType("scale");
                      grad_op->SetInput("X", og);
                      grad_op->SetOutput("Out", {x_grad});
                      grad_op->SetAttr("scale", 1.0f);
-                     return std::unique_ptr<framework::OpDescBind>(grad_op);
+                     return std::unique_ptr<framework::OpDesc>(grad_op);
                    });
     return grad_ops;
   }
...
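SumGradMaker above is the multi-op variant: it derives from GradOpDescMakerBase and returns one OpDesc per input gradient rather than a single op. A sketch of the expansion its std::transform performs, unrolled for two hypothetical inputs x0 and x1 (variable names illustrative):

std::vector<std::unique_ptr<framework::OpDesc>> grad_ops;
for (const std::string x_grad : {"x0@GRAD", "x1@GRAD"}) {
  auto *g = new framework::OpDesc();
  g->SetType("scale");
  g->SetInput("X", {"Out@GRAD"});  // the single upstream gradient...
  g->SetOutput("Out", {x_grad});   // ...fans out unchanged to every addend
  g->SetAttr("scale", 1.0f);
  grad_ops.emplace_back(g);
}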
@@ -96,8 +96,8 @@ class WriteToArrayInferShape : public framework::InferShapeBase {
 class WriteToArrayInferVarType : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDescBind &op_desc,
-                  framework::BlockDescBind *block) const override {
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
     auto x_name = op_desc.Input("X")[0];
     auto out_name = op_desc.Output("Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
@@ -175,14 +175,14 @@ class WriteToArrayGradMaker : public framework::SingleGradOpDescMaker {
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
  protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
     grad_op->SetType("read_from_array");
     grad_op->SetInput("I", Input("I"));
     grad_op->SetInput("X", OutputGrad("Out"));
     grad_op->SetOutput("Out", InputGrad("X"));
     grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
   }
 };
@@ -191,14 +191,14 @@ class ReadFromArrayGradMaker : public framework::SingleGradOpDescMaker {
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
  protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad_op = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad_op = new framework::OpDesc();
     grad_op->SetType("write_to_array");
     grad_op->SetInput("I", Input("I"));
     grad_op->SetInput("X", OutputGrad("Out"));
     grad_op->SetOutput("Out", InputGrad("X"));
     grad_op->SetAttrMap(Attrs());
-    return std::unique_ptr<framework::OpDescBind>(grad_op);
+    return std::unique_ptr<framework::OpDesc>(grad_op);
   }
 };
...
@@ -46,7 +46,7 @@ class WhileOp : public framework::OperatorBase {
     PADDLE_ENFORCE_EQ(cond.dims(), paddle::framework::make_ddim({1}));
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
+    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
     auto *program = block->Program();
     auto step_scopes =
@@ -82,7 +82,7 @@ class WhileOpMaker : public framework::OpProtoAndCheckerMaker {
               "(StepScopeVar) A vector of local scope, which size equals the "
               "step number of While Op. The i'th scope storages temporary "
               "variables generated in the i'th step.");
-    AddAttr<framework::BlockDescBind *>(kStepBlock,
-                                        "The step block inside WhileOp");
+    AddAttr<framework::BlockDesc *>(kStepBlock,
+                                    "The step block inside WhileOp");
     AddComment(R"DOC(
 )DOC");
@@ -99,7 +99,7 @@ class WhileGradOp : public framework::OperatorBase {
   void Run(const framework::Scope &scope,
            const platform::DeviceContext &dev_ctx) const override {
     framework::Executor executor(dev_ctx);
-    auto *block = Attr<framework::BlockDescBind *>(kStepBlock);
+    auto *block = Attr<framework::BlockDesc *>(kStepBlock);
     auto *program = block->Program();
     auto *step_scopes =
@@ -209,8 +209,8 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
   using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
  protected:
-  std::unique_ptr<framework::OpDescBind> Apply() const override {
-    auto *grad = new framework::OpDescBind();
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    auto *grad = new framework::OpDesc();
     grad->SetType("while_grad");
     grad->SetInput(kParameters, Input(kParameters));
@@ -279,14 +279,14 @@ class WhileGradOpDescMaker : public framework::SingleGradOpDescMaker {
     // while operator could be renamed.
     grad->SetAttr("original_output_grad", extra_inputs_list);
-    return std::unique_ptr<framework::OpDescBind>(grad);
+    return std::unique_ptr<framework::OpDesc>(grad);
   }
 };
 class WhileGradOpVarTypeInference : public framework::VarTypeInference {
  public:
-  void operator()(const framework::OpDescBind &op_desc,
-                  framework::BlockDescBind *block) const override {
+  void operator()(const framework::OpDesc &op_desc,
+                  framework::BlockDesc *block) const override {
     auto p_names = op_desc.Input(kParameters);
     auto pg_names = op_desc.Output(framework::GradVarName(kParameters));
...
@@ -108,21 +108,21 @@ static py::bytes SerializeMessage(T &self) {
 // Bind Methods
 void BindProgramDesc(py::module &m) {
-  py::class_<ProgramDescBind>(m, "ProgramDesc", "")
+  py::class_<ProgramDesc>(m, "ProgramDesc", "")
       .def(py::init<>())
       .def("__init__",
-           [](ProgramDescBind &self, const ProgramDescBind &other) {
-             new (&self) ProgramDescBind(other);
+           [](ProgramDesc &self, const ProgramDesc &other) {
+             new (&self) ProgramDesc(other);
            })
       .def("__init__",
-           [](ProgramDescBind &self, const py::bytes &binary_str) {
+           [](ProgramDesc &self, const py::bytes &binary_str) {
              std::string str(binary_str);
-             new (&self) ProgramDescBind(str);
+             new (&self) ProgramDesc(str);
            })
-      .def("append_block", &ProgramDescBind::AppendBlock,
+      .def("append_block", &ProgramDesc::AppendBlock,
            py::return_value_policy::reference)
       .def("append_backward",
-           [](ProgramDescBind &program_desc, const VarDescBind &target,
+           [](ProgramDesc &program_desc, const VarDesc &target,
               const std::unordered_set<std::string> &no_grad_vars) {
              ParamGradInfoMap param_grad_map =
                  AppendBackward(program_desc, target, no_grad_vars);
@@ -138,12 +138,12 @@ void BindProgramDesc(py::module &m) {
              }
              return retv;
            })
-      .def("block", &ProgramDescBind::MutableBlock,
+      .def("block", &ProgramDesc::MutableBlock,
           py::return_value_policy::reference)
-      .def("num_blocks", &ProgramDescBind::Size)
-      .def("serialize_to_string", SerializeMessage<ProgramDescBind>)
+      .def("num_blocks", &ProgramDesc::Size)
+      .def("serialize_to_string", SerializeMessage<ProgramDesc>)
       .def("parse_from_string",
-           [](ProgramDescBind &program_desc, const std::string &data) {
+           [](ProgramDesc &program_desc, const std::string &data) {
              proto::ProgramDesc *desc = program_desc.Proto();
              PADDLE_ENFORCE(desc->ParseFromString(data),
                             "Fail to parse ProgramDesc from string. This could "
@@ -152,35 +152,34 @@ void BindProgramDesc(py::module &m) {
 }
 void BindBlockDesc(py::module &m) {
-  py::class_<BlockDescBind>(m, "BlockDesc", "")
-      .def_property_readonly("id", &BlockDescBind::ID)
-      .def_property_readonly("parent", &BlockDescBind::Parent)
-      .def("append_op", &BlockDescBind::AppendOp,
+  py::class_<BlockDesc>(m, "BlockDesc", "")
+      .def_property_readonly("id", &BlockDesc::ID)
+      .def_property_readonly("parent", &BlockDesc::Parent)
+      .def("append_op", &BlockDesc::AppendOp,
           py::return_value_policy::reference)
-      .def("prepend_op", &BlockDescBind::PrependOp,
+      .def("prepend_op", &BlockDesc::PrependOp,
           py::return_value_policy::reference)
       .def("var",
-           [](BlockDescBind &self, py::bytes byte_name) {
+           [](BlockDesc &self, py::bytes byte_name) {
              std::string name = byte_name;
              return self.Var(name);
            },
           py::return_value_policy::reference)
       .def("has_var",
-           [](BlockDescBind &self, py::bytes byte_name) {
+           [](BlockDesc &self, py::bytes byte_name) {
              std::string name = byte_name;
              return self.HasVar(name);
            })
       .def("find_var",
-           [](BlockDescBind &self, py::bytes byte_name) {
+           [](BlockDesc &self, py::bytes byte_name) {
              std::string name = byte_name;
              return self.FindVar(name);
            },
           py::return_value_policy::reference)
-      .def("all_vars", &BlockDescBind::AllVars,
-           py::return_value_policy::reference)
-      .def("op_size", &BlockDescBind::OpSize)
-      .def("op", &BlockDescBind::Op, py::return_value_policy::reference)
-      .def("serialize_to_string", SerializeMessage<BlockDescBind>);
+      .def("all_vars", &BlockDesc::AllVars, py::return_value_policy::reference)
+      .def("op_size", &BlockDesc::OpSize)
+      .def("op", &BlockDesc::Op, py::return_value_policy::reference)
+      .def("serialize_to_string", SerializeMessage<BlockDesc>);
 }
 void BindVarDsec(py::module &m) {
@@ -193,25 +192,25 @@ void BindVarDsec(py::module &m) {
       .value("FP32", proto::DataType::FP32)
       .value("FP64", proto::DataType::FP64);
-  py::class_<VarDescBind> var_desc(m, "VarDesc", "");
+  py::class_<VarDesc> var_desc(m, "VarDesc", "");
   var_desc
       .def("name",
-           [](const VarDescBind &self) {
+           [](const VarDesc &self) {
              py::bytes name = self.Name();
              return name;
            },
           py::return_value_policy::reference)
-      .def("set_shape", &VarDescBind::SetShape)
-      .def("set_dtype", &VarDescBind::SetDataType)
-      .def("shape", &VarDescBind::Shape, py::return_value_policy::reference)
-      .def("dtype", &VarDescBind::GetDataType)
-      .def("lod_level", &VarDescBind::GetLodLevel)
-      .def("set_lod_level", &VarDescBind::SetLoDLevel)
-      .def("type", &VarDescBind::GetType)
-      .def("set_type", &VarDescBind::SetType)
-      .def("serialize_to_string", SerializeMessage<VarDescBind>)
-      .def("persistable", &VarDescBind::Persistable)
-      .def("set_persistable", &VarDescBind::SetPersistable);
+      .def("set_shape", &VarDesc::SetShape)
+      .def("set_dtype", &VarDesc::SetDataType)
+      .def("shape", &VarDesc::Shape, py::return_value_policy::reference)
+      .def("dtype", &VarDesc::GetDataType)
+      .def("lod_level", &VarDesc::GetLodLevel)
+      .def("set_lod_level", &VarDesc::SetLoDLevel)
+      .def("type", &VarDesc::GetType)
+      .def("set_type", &VarDesc::SetType)
+      .def("serialize_to_string", SerializeMessage<VarDesc>)
+      .def("persistable", &VarDesc::Persistable)
+      .def("set_persistable", &VarDesc::SetPersistable);
   py::enum_<proto::VarDesc::VarType>(var_desc, "VarType", "")
       .value("LOD_TENSOR", proto::VarDesc::LOD_TENSOR)
@@ -235,26 +234,26 @@ void BindOpDesc(py::module &m) {
       .value("BOOLS", proto::AttrType::BOOLEANS)
       .value("BLOCK", proto::AttrType::BLOCK);
-  py::class_<OpDescBind> op_desc(m, "OpDesc", "");
-  op_desc.def("type", &OpDescBind::Type)
-      .def("set_type", &OpDescBind::SetType)
-      .def("input", &OpDescBind::Input)
-      .def("input_names", &OpDescBind::InputNames)
-      .def("set_input", &OpDescBind::SetInput)
-      .def("output", &OpDescBind::Output)
-      .def("output_names", &OpDescBind::OutputNames)
-      .def("set_output", &OpDescBind::SetOutput)
-      .def("has_attr", &OpDescBind::HasAttr)
-      .def("attr_type", &OpDescBind::GetAttrType)
-      .def("attr_names", &OpDescBind::AttrNames)
-      .def("set_attr", &OpDescBind::SetAttr)
-      .def("attr", &OpDescBind::GetAttr)
-      .def("set_block_attr", &OpDescBind::SetBlockAttr)
-      .def("block_attr", &OpDescBind::GetBlockAttr)
-      .def("check_attrs", &OpDescBind::CheckAttrs)
-      .def("infer_shape", &OpDescBind::InferShape)
-      .def("infer_var_type", &OpDescBind::InferVarType)
-      .def("serialize_to_string", SerializeMessage<OpDescBind>);
+  py::class_<OpDesc> op_desc(m, "OpDesc", "");
+  op_desc.def("type", &OpDesc::Type)
+      .def("set_type", &OpDesc::SetType)
+      .def("input", &OpDesc::Input)
+      .def("input_names", &OpDesc::InputNames)
+      .def("set_input", &OpDesc::SetInput)
+      .def("output", &OpDesc::Output)
+      .def("output_names", &OpDesc::OutputNames)
+      .def("set_output", &OpDesc::SetOutput)
+      .def("has_attr", &OpDesc::HasAttr)
+      .def("attr_type", &OpDesc::GetAttrType)
+      .def("attr_names", &OpDesc::AttrNames)
+      .def("set_attr", &OpDesc::SetAttr)
+      .def("attr", &OpDesc::GetAttr)
+      .def("set_block_attr", &OpDesc::SetBlockAttr)
+      .def("block_attr", &OpDesc::GetBlockAttr)
+      .def("check_attrs", &OpDesc::CheckAttrs)
+      .def("infer_shape", &OpDesc::InferShape)
+      .def("infer_var_type", &OpDesc::InferVarType)
+      .def("serialize_to_string", SerializeMessage<OpDesc>);
 }
 }  // namespace pybind
...
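Note that the Python-visible names ("ProgramDesc", "BlockDesc", "OpDesc") are unchanged; only the C++ types behind the bindings are renamed. A hedged sketch of driving the renamed classes directly from C++, using only calls exercised by the bindings above (the header path and the implicit root block 0 are assumptions, not verified against this tree):

#include "paddle/framework/program_desc.h"  // assumed header location

void BuildToyProgram() {
  paddle::framework::ProgramDesc program;    // formerly ProgramDescBind
  paddle::framework::BlockDesc *root =       // formerly BlockDescBind
      program.MutableBlock(0);               // bound as "block" above
  paddle::framework::OpDesc *op = root->AppendOp();  // bound as "append_op"
  op->SetType("scale");
  op->SetInput("X", {"x"});
  op->SetOutput("Out", {"y"});
  op->SetAttr("scale", 2.0f);
}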
@@ -266,36 +266,36 @@ All parameter, weight, gradient are variables in Paddle.
         return ret_values;
       });
   m.def("get_grad_op_descs",
-        [](const OpDescBind &op_desc,
+        [](const OpDesc &op_desc,
            const std::unordered_set<std::string> &no_grad_set,
            std::unordered_map<std::string, std::string> &grad_to_var,
-           const std::vector<BlockDescBind *> &grad_sub_block) {
-          std::vector<std::unique_ptr<OpDescBind>> grad_op_descs =
+           const std::vector<BlockDesc *> &grad_sub_block) {
+          std::vector<std::unique_ptr<OpDesc>> grad_op_descs =
               framework::OpInfoMap::Instance()
                   .Get(op_desc.Type())
                   .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
                                  grad_sub_block);
-          std::vector<OpDescBind *> grad_op_desc_ptrs(grad_op_descs.size());
+          std::vector<OpDesc *> grad_op_desc_ptrs(grad_op_descs.size());
           std::transform(
               grad_op_descs.begin(), grad_op_descs.end(),
               grad_op_desc_ptrs.begin(),
-              [](std::unique_ptr<OpDescBind> &p) { return p.release(); });
+              [](std::unique_ptr<OpDesc> &p) { return p.release(); });
           return grad_op_desc_ptrs;
         });
-  m.def("prune", [](const ProgramDescBind &origin,
+  m.def("prune", [](const ProgramDesc &origin,
                     const std::vector<std::array<size_t, 2>> &targets) {
-    ProgramDescBind prog_with_targets(origin);
+    ProgramDesc prog_with_targets(origin);
     for (const auto &t : targets) {
       prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
     }
     proto::ProgramDesc pruned_desc;
     Prune(*prog_with_targets.Proto(), &pruned_desc);
-    return new ProgramDescBind(pruned_desc);
+    return new ProgramDesc(pruned_desc);
   });
-  m.def("inference_optimize", [](ProgramDescBind &origin) {
+  m.def("inference_optimize", [](ProgramDesc &origin) {
     proto::ProgramDesc pruned_desc;
     InferenceOptimize(*(origin.Proto()), &pruned_desc);
-    return new ProgramDescBind(pruned_desc);
+    return new ProgramDesc(pruned_desc);
   });
   m.def_submodule(
       "var_names",
...
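For reference, the prune binding above interprets each target as a {block index, op index} pair. A sketch of the equivalent C++ call sequence for a single target (the namespace qualification of Prune is an assumption):

paddle::framework::ProgramDesc *PruneToFirstOp(
    const paddle::framework::ProgramDesc &origin) {
  paddle::framework::ProgramDesc prog_with_targets(origin);  // copy, then mark
  prog_with_targets.MutableBlock(0)->Op(0)->MarkAsTarget();  // target = {block 0, op 0}
  paddle::framework::proto::ProgramDesc pruned_desc;
  paddle::framework::Prune(*prog_with_targets.Proto(), &pruned_desc);
  return new paddle::framework::ProgramDesc(pruned_desc);  // caller owns, as in the binding
}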