Unverified commit 1363ddb6, authored by Yu Yang, committed by GitHub

Feature/executor use program bind (#5196)

* Init commit

* Make executor use ProgramDescBind

* Change Attribute from BlockDesc to BlockDescBind

* Since we will get the program desc in RNN, BlockDesc alone is not
  enough.
Parent: ee11f006
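In short, operator creation no longer needs a `ProgramDesc*` side channel: block attributes are stored as in-memory `BlockDescBind*` pointers resolved through `ProgramDescBind`. A minimal sketch of the resulting API shape, using only calls visible in the hunks below (the attribute name "sub_block" is illustrative, not from this commit):

```cpp
#include "paddle/framework/op_registry.h"
#include "paddle/framework/program_desc.h"

using namespace paddle::framework;

void Sketch() {
  ProgramDescBind program;
  BlockDescBind* root = program.MutableBlock(0);  // mutation now goes through MutableBlock()
  OpDescBind* op = root->AppendOp();
  op->SetType("sum");

  // Block attributes are stored as BlockDescBind* directly...
  op->SetBlockAttr("sub_block", *program.MutableBlock(0));
  // ...and read back as a block index via BlockDescBind::ID().
  int block_id = op->GetBlockAttr("sub_block");
  (void)block_id;
}
```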
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace framework {
-Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
+Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
   switch (attr_desc.type()) {
     case framework::AttrType::BOOLEAN: {
       return attr_desc.b();
@@ -61,13 +61,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
       }
       return val;
     }
-    case framework::AttrType::BLOCK: {
-      PADDLE_ENFORCE(program != nullptr,
-                     "Need to specify ProgramDesc when get a block attr");
-      return program->mutable_blocks(attr_desc.block_idx());
-    }
-  }
-  PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
+    default:
+      PADDLE_THROW("Unsupport attr type %d", attr_desc.type());
+  }
   return boost::blank();
 }
......
@@ -32,7 +32,7 @@ inline AttrType AttrTypeID() {
   return static_cast<AttrType>(tmp.which() - 1);
 }
-Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc);
+Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
 class AttrReader {
  public:
......
@@ -368,7 +368,7 @@ std::vector<std::unique_ptr<OpDescBind>> MakeBlockBackward(
     ProgramDescBind& program_desc, int block_idx,
     std::unordered_set<std::string>* no_grad_vars,
     std::unordered_map<std::string, std::string>* grad_to_var) {
-  BlockDescBind* cur_block = program_desc.Block(block_idx);
+  BlockDescBind* cur_block = program_desc.MutableBlock(block_idx);
   std::vector<OpDescBind*> op_descs = cur_block->AllOps();
   std::unordered_map<std::string, std::vector<size_t>> dup_out_ops;
   size_t grad_desc_idx = 0;
@@ -443,7 +443,7 @@ ParamGradInfoMap AppendBackward(
   }
   const int root_block_idx = 0;
-  auto root_block = program_desc.Block(root_block_idx);
+  auto root_block = program_desc.MutableBlock(root_block_idx);
   // insert fill one op for target
   // TODO(qiao) add some check to the target.
@@ -492,7 +492,7 @@ ParamGradInfoMap AppendBackward(
   CreateGradVarInBlock(forward_op_num, grad_to_var, root_block, &retv);
   for (size_t block_index = forward_block_num;
        block_index < program_desc.Size(); ++block_index) {
-    CreateGradVarInBlock(0, grad_to_var, program_desc.Block(block_index),
+    CreateGradVarInBlock(0, grad_to_var, program_desc.MutableBlock(block_index),
                          &retv);
   }
   return retv;
......
@@ -499,7 +499,7 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
 TEST(Backward, simple_single_op) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   f::OpDescBind *op = block->AppendOp();
   op->SetType("rowwise_add");
@@ -535,7 +535,7 @@ TEST(Backward, simple_single_op) {
 TEST(Backward, default_attribute) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   f::OpDescBind *op = block->AppendOp();
   op->SetType("mul");
   op->SetInput("X", {"x"});
@@ -561,7 +561,7 @@ TEST(Backward, default_attribute) {
 TEST(Backward, simple_mult_op) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   f::OpDescBind *op1 = block->AppendOp();
   op1->SetType("rowwise_add");
   op1->SetInput("X", {"x1"});
@@ -644,7 +644,7 @@ TEST(Backward, simple_mult_op) {
 TEST(Backward, intermedia_var_no_grad) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   f::OpDescBind *op1 = block->AppendOp();
   op1->SetType("rowwise_add");
   op1->SetInput("X", {"x1"});
@@ -714,7 +714,7 @@ TEST(Backward, intermedia_var_no_grad) {
 TEST(Backward, var_no_grad) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   f::OpDescBind *op1 = block->AppendOp();
   op1->SetType("mult_in_out");
   op1->SetInput("X", {"x1"});
@@ -790,7 +790,7 @@ TEST(Backward, var_no_grad) {
 TEST(Backward, shared_var) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   f::OpDescBind *op1 = block->AppendOp();
   op1->SetType("rowwise_add");
   op1->SetInput("X", {"x1"});
@@ -880,7 +880,7 @@ TEST(Backward, shared_var) {
 TEST(Backward, half_backward) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   auto *op1 = block->AppendOp();
   op1->SetType("minus");
   op1->SetInput("X", {"a"});
......
@@ -113,7 +113,7 @@ BlockDescBind *BlockDescBind::ParentBlock() const {
   if (this->desc_->parent_idx() == kNoneBlockIndex) {
     return nullptr;
   }
-  return prog_->Block(static_cast<size_t>(this->desc_->parent_idx()));
+  return prog_->MutableBlock(static_cast<size_t>(this->desc_->parent_idx()));
 }
 BlockDesc *BlockDescBind::Proto() {
......
@@ -73,33 +73,32 @@ static void CreateTensor(Variable* var, VarDesc::VarType var_type) {
   }
 }
-void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
+void Executor::Run(const ProgramDescBind& pdesc, Scope* scope, int block_id) {
   // TODO(tonyyang-svail):
   //   - only runs on the first device (i.e. no interdevice communication)
   //   - will change to use multiple blocks for RNN op and Cond Op
-  PADDLE_ENFORCE_GT(pdesc.blocks_size(), block_id);
-  auto& block = pdesc.blocks(block_id);
+  PADDLE_ENFORCE_LT(block_id, pdesc.Size());
+  auto& block = pdesc.Block(block_id);
   auto& device = device_contexts_[0];
   Scope& local_scope = scope->NewScope();
-  for (auto& var : block.vars()) {
-    if (var.persistable()) {
-      auto* ptr = scope->Var(var.name());
-      CreateTensor(ptr, var.type());
-      VLOG(3) << "Create Variable " << var.name()
+  for (auto& var : block.AllVars()) {
+    if (var->Persistable()) {
+      auto* ptr = scope->Var(var->Name());
+      CreateTensor(ptr, var->GetType());
+      VLOG(3) << "Create Variable " << var->Name()
              << " global, which pointer is " << ptr;
     } else {
-      auto* ptr = local_scope.Var(var.name());
-      CreateTensor(ptr, var.type());
-      VLOG(3) << "Create Variable " << var.name()
+      auto* ptr = local_scope.Var(var->Name());
+      CreateTensor(ptr, var->GetType());
+      VLOG(3) << "Create Variable " << var->Name()
              << " locally, which pointer is " << ptr;
     }
   }
-  for (auto& op_desc : block.ops()) {
-    auto op = paddle::framework::OpRegistry::CreateOp(
-        op_desc, const_cast<ProgramDesc*>(&pdesc));
+  for (auto& op_desc : block.AllOps()) {
+    auto op = paddle::framework::OpRegistry::CreateOp(*op_desc);
     op->Run(local_scope, *device);
   }
......
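A hedged usage sketch of the new `Executor::Run` signature above; only the signatures visible in this diff are assumed, and the place/scope setup is illustrative:

```cpp
#include "paddle/framework/executor.h"
#include "paddle/framework/program_desc.h"
#include "paddle/framework/scope.h"
#include "paddle/platform/place.h"

void RunRootBlock(paddle::framework::ProgramDescBind& program) {
  // The executor is constructed over a list of places, as bound in pybind
  // further below; CPU-only here for illustration.
  std::vector<paddle::platform::Place> places{paddle::platform::CPUPlace()};
  paddle::framework::Executor executor(places);

  paddle::framework::Scope scope;
  // Persistable variables land in `scope`; everything else goes into a
  // per-run child scope, as the loop above shows.
  executor.Run(program, &scope, 0 /* block_id: root block */);
}
```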
@@ -14,8 +14,8 @@ limitations under the License. */
 #pragma once
-#include "paddle/framework/framework.pb.h"
 #include "paddle/framework/op_info.h"
+#include "paddle/framework/program_desc.h"
 #include "paddle/framework/scope.h"
 #include "paddle/framework/tensor.h"
@@ -34,7 +34,7 @@ class Executor {
   * ProgramDesc
   * Scope
   */
-  void Run(const ProgramDesc&, Scope*, int);
+  void Run(const ProgramDescBind&, Scope*, int);
  private:
  std::vector<platform::DeviceContext*> device_contexts_;
......
@@ -114,7 +114,12 @@ OpDescBind::OpDescBind(const OpDesc &desc, ProgramDescBind *prog)
   // restore attrs_
   for (const OpDesc::Attr &attr : desc_.attrs()) {
     std::string attr_name = attr.name();
-    attrs_[attr_name] = GetAttrValue(attr, prog->Proto());
+    if (attr.type() != AttrType::BLOCK) {
+      attrs_[attr_name] = GetAttrValue(attr);
+    } else {
+      auto bid = attr.block_idx();
+      attrs_[attr_name] = prog->MutableBlock(bid);
+    }
   }
 }
@@ -188,8 +193,7 @@ void OpDescBind::SetAttr(const std::string &name, const Attribute &v) {
 }
 void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) {
-  BlockDesc *desc = block.Proto();
-  this->attrs_[name] = desc;
+  this->attrs_[name] = &block;
   need_update_ = true;
 }
@@ -208,7 +212,7 @@ Attribute OpDescBind::GetAttr(const std::string &name) const {
 int OpDescBind::GetBlockAttr(const std::string &name) const {
   auto it = attrs_.find(name);
   PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
-  return boost::get<BlockDesc *>(it->second)->idx();
+  return boost::get<BlockDescBind *>(it->second)->ID();
 }
 const std::unordered_map<std::string, Attribute> &OpDescBind::GetAttrMap()
......
@@ -43,13 +43,15 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
   return ret_val;
 }
-std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc,
-                                                   ProgramDesc* program) {
+std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
+  VLOG(1) << "CreateOp directly from OpDesc is deprecated. It should only be"
+             "used in unit tests. Use CreateOp(const OpDescBind& op_desc) "
+             "instead.";
   VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
   VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
   AttributeMap attrs;
   for (auto& attr : op_desc.attrs()) {
-    attrs[attr.name()] = GetAttrValue(attr, program);
+    attrs[attr.name()] = GetAttrValue(attr);
   }
   return CreateOp(op_desc.type(), inputs, outputs, attrs);
......
@@ -77,8 +77,7 @@ class OpRegistry {
                                                 const VariableNameMap& outputs,
                                                 AttributeMap attrs);
-  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc,
-                                                ProgramDesc* program);
+  static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
   static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
 };
......
@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) {
   attr->set_type(paddle::framework::AttrType::FLOAT);
   attr->set_f(scale);
-  auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::framework::Scope scope;
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) {
   bool caught = false;
   try {
-    paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+    paddle::framework::OpRegistry::CreateOp(op_desc);
   } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "larger_than check fail";
@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) {
   ASSERT_TRUE(op_desc.IsInitialized());
-  auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::framework::Scope scope;
   paddle::platform::CPUDeviceContext dev_ctx;
   op->Run(scope, dev_ctx);
@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) {
   // attr 'test_attr' is not set
   bool caught = false;
   try {
-    paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+    paddle::framework::OpRegistry::CreateOp(op_desc);
   } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "Attribute 'test_attr' is required!";
@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) {
   attr->set_i(3);
   caught = false;
   try {
-    paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+    paddle::framework::OpRegistry::CreateOp(op_desc);
   } catch (paddle::platform::EnforceNotMet err) {
     caught = true;
     std::string msg = "'test_attr' must be even!";
@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) {
   attr->set_name("test_attr");
   attr->set_type(paddle::framework::AttrType::INT);
   attr->set_i(4);
-  auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   paddle::platform::CPUDeviceContext dev_ctx;
   paddle::framework::Scope scope;
   op->Run(scope, dev_ctx);
......
@@ -83,7 +83,7 @@ TEST(OperatorBase, all) {
   paddle::platform::CPUDeviceContext device_context;
   paddle::framework::Scope scope;
-  auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   scope.Var("OUT1");
   ASSERT_EQ(paddle::framework::op_run_num, 0);
   op->Run(scope, device_context);
@@ -208,7 +208,7 @@ TEST(OpKernel, all) {
   paddle::platform::CPUDeviceContext cpu_device_context;
   paddle::framework::Scope scope;
-  auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
   op->Run(scope, cpu_device_context);
   ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) {
   scope.Var("y0")->GetMutable<LoDTensor>();
   scope.Var("y1")->GetMutable<LoDTensor>();
-  auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
   op->Run(scope, cpu_device_context);
 }
......
@@ -37,7 +37,9 @@ class ProgramDescBind {
   BlockDescBind *AppendBlock(const BlockDescBind &parent);
-  BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
+  BlockDescBind *MutableBlock(size_t idx) { return blocks_[idx].get(); }
+
+  const BlockDescBind &Block(size_t idx) const { return *blocks_[idx]; }
   size_t Size() const { return blocks_.size(); }
......
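The split above gives a const read path (`Block`) and an explicitly named mutable path (`MutableBlock`). A small sketch of how callers differ, assuming only the declarations shown in this diff (the op type is illustrative):

```cpp
void Inspect(const paddle::framework::ProgramDescBind& program) {
  // Read-only access goes through the const overload, which returns a reference.
  const auto& root = program.Block(0);
  size_t num_ops = root.AllOps().size();
  (void)num_ops;
}

void Mutate(paddle::framework::ProgramDescBind& program) {
  // Mutation is now spelled out at the call site.
  auto* root = program.MutableBlock(0);
  root->AppendOp()->SetType("sum");
}
```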
@@ -20,7 +20,7 @@ namespace paddle {
 namespace framework {
 TEST(ProgramDesc, copy_ctor) {
   ProgramDescBind program;
-  auto* global_block = program.Block(0);
+  auto* global_block = program.MutableBlock(0);
   auto* x = global_block->Var("X");
   x->SetType(VarDesc_VarType_LOD_TENSOR);
   x->SetLoDLevel(0);
@@ -44,7 +44,7 @@ TEST(ProgramDesc, copy_ctor) {
   ProgramDescBind program_copy(program);
-  auto* global_block_copy = program_copy.Block(0);
+  auto* global_block_copy = program_copy.MutableBlock(0);
   ASSERT_NE(global_block, global_block_copy);
   auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
@@ -82,7 +82,7 @@ TEST(ProgramDesc, copy_ctor) {
 TEST(ProgramDescBind, serialize_and_deserialize) {
   ProgramDescBind program_origin;
-  auto* global_block = program_origin.Block(0);
+  auto* global_block = program_origin.MutableBlock(0);
   auto* x = global_block->Var("X");
   x->SetType(VarDesc_VarType_LOD_TENSOR);
   x->SetLoDLevel(0);
@@ -108,7 +108,7 @@ TEST(ProgramDescBind, serialize_and_deserialize) {
   program_origin.Proto()->SerializeToString(&binary_str);
   ProgramDescBind program_restored(binary_str);
-  auto* global_block_restored = program_restored.Block(0);
+  auto* global_block_restored = program_restored.MutableBlock(0);
   ASSERT_NE(global_block, global_block_restored);
   auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
......
@@ -52,7 +52,7 @@ void AddOp(const std::string &type, const f::VariableNameMap &inputs,
 TEST(Prune, one_operator) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
@@ -69,7 +69,7 @@ TEST(Prune, one_operator) {
 TEST(Prune, forward) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   AddOp("one_one", {{"input", {"a"}}}, {{"output", {"b"}}}, {}, block);
   AddOp("one_one", {{"input", {"b"}}}, {{"output", {"c"}}}, {}, block);
@@ -88,7 +88,7 @@ TEST(Prune, forward) {
 TEST(Prune, multi_input_op) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   AddOp("one_one", {{"input", {"a0"}}}, {{"output", {"b0"}}}, {}, block);
   AddOp("one_one", {{"input", {"a1"}}}, {{"output", {"b1"}}}, {}, block);
@@ -106,7 +106,7 @@ TEST(Prune, multi_input_op) {
 TEST(Prune, multi_output_op) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
   AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
@@ -122,7 +122,7 @@ TEST(Prune, multi_output_op) {
 TEST(Prune, multi_target) {
   f::ProgramDescBind program;
-  f::BlockDescBind *block = program.Block(0);
+  f::BlockDescBind *block = program.MutableBlock(0);
   AddOp("one_two", {{"input", {"a"}}}, {{"output", {"b", "c"}}}, {}, block);
   AddOp("one_one", {{"input", {"b"}}}, {{"output", {"b1"}}}, {}, block);
......
@@ -36,7 +36,7 @@ using VariableNameMap = std::map<std::string, std::vector<std::string>>;
 using Attribute =
     boost::variant<boost::blank, int, float, std::string, std::vector<int>,
                    std::vector<float>, std::vector<std::string>, bool,
-                   std::vector<bool>, BlockDesc*>;
+                   std::vector<bool>, BlockDescBind*>;
 using AttributeMap = std::unordered_map<std::string, Attribute>;
......
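Since `Attribute` now carries `BlockDescBind*`, a `boost::get` on the variant yields the in-memory block rather than a `BlockDesc` proto. A sketch, assuming only the alias above and `BlockDescBind::ID()` from the `OpDescBind::GetBlockAttr` hunk earlier:

```cpp
#include <boost/variant.hpp>

int BlockIdOf(const paddle::framework::Attribute& attr) {
  // Throws boost::bad_get if the attribute does not actually hold a block.
  return boost::get<paddle::framework::BlockDescBind*>(attr)->ID();
}
```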
@@ -63,41 +63,43 @@ namespace framework {
 TEST(InferVarType, sum_op) {
   ProgramDescBind prog;
-  auto *op = prog.Block(0)->AppendOp();
+  auto *op = prog.MutableBlock(0)->AppendOp();
   op->SetType("sum");
   op->SetInput("X", {"test_a", "test_b", "test_c"});
   op->SetOutput("Out", {"test_out"});
-  prog.Block(0)->Var("test_a")->SetType(VarDesc::SELECTED_ROWS);
-  prog.Block(0)->Var("test_b")->SetType(VarDesc::SELECTED_ROWS);
-  prog.Block(0)->Var("test_c")->SetType(VarDesc::SELECTED_ROWS);
-  prog.Block(0)->Var("test_out");
+  prog.MutableBlock(0)->Var("test_a")->SetType(VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_c")->SetType(VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test_out");
-  op->InferVarType(prog.Block(0));
+  op->InferVarType(prog.MutableBlock(0));
-  ASSERT_EQ(VarDesc::SELECTED_ROWS, prog.Block(0)->Var("test_out")->GetType());
+  ASSERT_EQ(VarDesc::SELECTED_ROWS,
+            prog.MutableBlock(0)->Var("test_out")->GetType());
-  prog.Block(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
-  op->InferVarType(prog.Block(0));
-  ASSERT_EQ(VarDesc::LOD_TENSOR, prog.Block(0)->Var("test_out")->GetType());
+  prog.MutableBlock(0)->Var("test_b")->SetType(VarDesc::LOD_TENSOR);
+  op->InferVarType(prog.MutableBlock(0));
+  ASSERT_EQ(VarDesc::LOD_TENSOR,
+            prog.MutableBlock(0)->Var("test_out")->GetType());
 }
 TEST(InferVarType, sum_op_without_infer_var_type) {
   ProgramDescBind prog;
-  auto *op = prog.Block(0)->AppendOp();
+  auto *op = prog.MutableBlock(0)->AppendOp();
   op->SetType("sum_without_infer_var_type");
   op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
   op->SetOutput("Out", {"test2_out"});
-  prog.Block(0)->Var("test2_a")->SetType(VarDesc::SELECTED_ROWS);
-  prog.Block(0)->Var("test2_b")->SetType(VarDesc::SELECTED_ROWS);
-  prog.Block(0)->Var("test2_c")->SetType(VarDesc::SELECTED_ROWS);
-  prog.Block(0)->Var("test2_out");
+  prog.MutableBlock(0)->Var("test2_a")->SetType(VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_b")->SetType(VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_c")->SetType(VarDesc::SELECTED_ROWS);
+  prog.MutableBlock(0)->Var("test2_out");
-  op->InferVarType(prog.Block(0));
+  op->InferVarType(prog.MutableBlock(0));
   ASSERT_EQ(VarDesc_VarType_LOD_TENSOR,
-            prog.Block(0)->Var("test2_out")->GetType());
+            prog.MutableBlock(0)->Var("test2_out")->GetType());
 }
 }  // namespace framework
......
@@ -51,7 +51,7 @@ class RNNAlgorithmTestHelper : public ::testing::Test {
     CreateGlobalVariables();
     auto op_desc = CreateOpDesc();
-    op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
+    op = paddle::framework::OpRegistry::CreateOp(op_desc);
     dop = &(dynamic_cast<DynamicRecurrentOp*>(op.get())->rnn);
     InitCacheManually();
     InitStepNet();
......
@@ -129,7 +129,8 @@ void BindProgramDesc(py::module &m) {
         }
         return retv;
       })
-      .def("block", &ProgramDescBind::Block, py::return_value_policy::reference)
+      .def("block", &ProgramDescBind::MutableBlock,
+           py::return_value_policy::reference)
      .def("num_blocks", &ProgramDescBind::Size)
      .def("serialize_to_string",
           [](ProgramDescBind &program_desc) -> py::bytes {
......
@@ -275,7 +275,7 @@ All parameter, weight, gradient are variables in Paddle.
            const std::vector<std::array<size_t, 2>> &targets) {
           ProgramDescBind prog_with_targets(origin);
           for (const auto &t : targets) {
-            prog_with_targets.Block(t[0])->Op(t[1])->MarkAsTarget();
+            prog_with_targets.MutableBlock(t[0])->Op(t[1])->MarkAsTarget();
           }
           ProgramDesc pruned_desc;
           Prune(*prog_with_targets.Proto(), &pruned_desc);
@@ -335,7 +335,7 @@ All parameter, weight, gradient are variables in Paddle.
            PADDLE_ENFORCE(desc.IsInitialized(),
                           "User OpDesc is not initialized, reason %s",
                           desc.InitializationErrorString());
-           return OpRegistry::CreateOp(desc, nullptr);
+           return OpRegistry::CreateOp(desc);
          })
      .def("backward",
           [](const OperatorBase &forwardOp,
@@ -439,7 +439,7 @@ All parameter, weight, gradient are variables in Paddle.
            PADDLE_ENFORCE(desc.IsInitialized(),
                           "User OpDesc is not initialized, reason %s",
                           desc.InitializationErrorString());
-           auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
+           auto rnn_op = OpRegistry::CreateOp(desc);
            return static_cast<operators::RecurrentOp *>(rnn_op.release());
          })
      .def("set_stepnet", [](operators::RecurrentOp &self,
@@ -457,7 +457,7 @@ All parameter, weight, gradient are variables in Paddle.
            PADDLE_ENFORCE(desc.IsInitialized(),
                           "User OpDesc is not initialized, reason %s",
                           desc.InitializationErrorString());
-           auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
+           auto rnn_op = OpRegistry::CreateOp(desc);
            return static_cast<operators::DynamicRecurrentOp *>(
                rnn_op.release());
          })
@@ -484,7 +484,7 @@ All parameter, weight, gradient are variables in Paddle.
            PADDLE_ENFORCE(desc.IsInitialized(),
                           "User OpDesc is not initialized, reason %s",
                           desc.InitializationErrorString());
-           auto cond_op = OpRegistry::CreateOp(desc, nullptr);
+           auto cond_op = OpRegistry::CreateOp(desc);
            return static_cast<operators::CondOp *>(cond_op.release());
          })
      .def("set_truenet",
@@ -498,10 +498,7 @@ All parameter, weight, gradient are variables in Paddle.
   py::class_<framework::Executor>(m, "Executor")
       .def(py::init<std::vector<platform::Place> &>())
-      .def("run", [](Executor &self, ProgramDescBind *program_bind,
-                     Scope *scope, int block_id) {
-        self.Run(*program_bind->Proto(), scope, block_id);
-      });
+      .def("run", &Executor::Run);
   m.def("unique_integer", UniqueIntegerGenerator);
   m.def("init_gflags", InitGflags);
......