提交 e747623e 编写于 作者: Y Yu Yang 提交者: GitHub

Change ProgramDesc not a global variable (#4879)

* Change ProgramDesc not a global variable

* Polish code style

* Correct implement BlockDesc destructor

* Unify program as parameter name
上级 efd009a0
...@@ -19,19 +19,7 @@ limitations under the License. */ ...@@ -19,19 +19,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
static ProgramDesc* g_program_desc = nullptr; Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
ProgramDesc& GetProgramDesc() {
if (g_program_desc == nullptr) {
g_program_desc = new ProgramDesc();
auto root_block = g_program_desc->mutable_blocks()->Add();
root_block->set_idx(0);
root_block->set_parent_idx(-1);
}
return *g_program_desc;
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
switch (attr_desc.type()) { switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: { case framework::AttrType::BOOLEAN: {
return attr_desc.b(); return attr_desc.b();
...@@ -74,7 +62,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { ...@@ -74,7 +62,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
return val; return val;
} }
case framework::AttrType::BLOCK: { case framework::AttrType::BLOCK: {
return GetProgramDesc().mutable_blocks(attr_desc.block_idx()); PADDLE_ENFORCE(program != nullptr,
"Need to specify ProgramDesc when get a block attr");
return program->mutable_blocks(attr_desc.block_idx());
} }
} }
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
......
...@@ -26,16 +26,13 @@ limitations under the License. */ ...@@ -26,16 +26,13 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
ProgramDesc& GetProgramDesc();
template <typename T> template <typename T>
inline AttrType AttrTypeID() { inline AttrType AttrTypeID() {
Attribute tmp = T(); Attribute tmp = T();
return static_cast<AttrType>(tmp.which() - 1); return static_cast<AttrType>(tmp.which() - 1);
} }
Attribute GetAttrValue(const OpDesc::Attr& attr_desc); Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc);
class AttrReader { class AttrReader {
public: public:
......
...@@ -495,19 +495,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { ...@@ -495,19 +495,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL); EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL);
} }
// =================================== //
f::ProgramDesc *GetNewProgramDesc() {
auto *program_desc = new f::ProgramDesc();
auto *root_block = program_desc->add_blocks();
root_block->set_idx(0);
root_block->set_parent_idx(-1);
return program_desc;
}
TEST(Backward, simple_single_op) { TEST(Backward, simple_single_op) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op = block->AppendOp(); f::OpDescBind *op = block->AppendOp();
...@@ -543,8 +532,7 @@ TEST(Backward, simple_single_op) { ...@@ -543,8 +532,7 @@ TEST(Backward, simple_single_op) {
} }
TEST(Backward, default_attribute) { TEST(Backward, default_attribute) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op = block->AppendOp(); f::OpDescBind *op = block->AppendOp();
op->SetType("mul"); op->SetType("mul");
...@@ -570,8 +558,7 @@ TEST(Backward, default_attribute) { ...@@ -570,8 +558,7 @@ TEST(Backward, default_attribute) {
} }
TEST(Backward, simple_mult_op) { TEST(Backward, simple_mult_op) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
...@@ -654,8 +641,7 @@ TEST(Backward, simple_mult_op) { ...@@ -654,8 +641,7 @@ TEST(Backward, simple_mult_op) {
} }
TEST(Backward, intermedia_var_no_grad) { TEST(Backward, intermedia_var_no_grad) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
...@@ -725,8 +711,7 @@ TEST(Backward, intermedia_var_no_grad) { ...@@ -725,8 +711,7 @@ TEST(Backward, intermedia_var_no_grad) {
} }
TEST(Backward, var_no_grad) { TEST(Backward, var_no_grad) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("mult_in_out"); op1->SetType("mult_in_out");
...@@ -802,8 +787,7 @@ TEST(Backward, var_no_grad) { ...@@ -802,8 +787,7 @@ TEST(Backward, var_no_grad) {
} }
TEST(Backward, shared_var) { TEST(Backward, shared_var) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp(); f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add"); op1->SetType("rowwise_add");
...@@ -893,8 +877,7 @@ TEST(Backward, shared_var) { ...@@ -893,8 +877,7 @@ TEST(Backward, shared_var) {
} }
TEST(Backward, half_backward) { TEST(Backward, half_backward) {
f::ProgramDesc *program_desc = GetNewProgramDesc(); f::ProgramDescBind program;
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::BlockDescBind *block = program.Block(0); f::BlockDescBind *block = program.Block(0);
auto *op1 = block->AppendOp(); auto *op1 = block->AppendOp();
op1->SetType("minus"); op1->SetType("minus");
......
...@@ -75,7 +75,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { ...@@ -75,7 +75,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
} }
for (auto& op_desc : block.ops()) { for (auto& op_desc : block.ops()) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(
op_desc, const_cast<ProgramDesc*>(&pdesc));
op->Run(local_scope, *device); op->Run(local_scope, *device);
} }
......
...@@ -43,12 +43,13 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap( ...@@ -43,12 +43,13 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
return ret_val; return ret_val;
} }
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) { std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc,
ProgramDesc* program) {
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs; AttributeMap attrs;
for (auto& attr : op_desc.attrs()) { for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr); attrs[attr.name()] = GetAttrValue(attr, program);
} }
return CreateOp(op_desc.type(), inputs, outputs, attrs); return CreateOp(op_desc.type(), inputs, outputs, attrs);
......
...@@ -74,7 +74,8 @@ class OpRegistry { ...@@ -74,7 +74,8 @@ class OpRegistry {
const VariableNameMap& outputs, const VariableNameMap& outputs,
AttributeMap attrs); AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc,
ProgramDesc* program);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc); static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
}; };
......
...@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) { ...@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale); attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
...@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) { ...@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "larger_than check fail"; std::string msg = "larger_than check fail";
...@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) { ...@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized()); ASSERT_TRUE(op_desc.IsInitialized());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::Scope scope; paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
...@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set // attr 'test_attr' is not set
bool caught = false; bool caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "Attribute 'test_attr' is required!"; std::string msg = "Attribute 'test_attr' is required!";
...@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3); attr->set_i(3);
caught = false; caught = false;
try { try {
paddle::framework::OpRegistry::CreateOp(op_desc); paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) { } catch (paddle::platform::EnforceNotMet err) {
caught = true; caught = true;
std::string msg = "'test_attr' must be even!"; std::string msg = "'test_attr' must be even!";
...@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) { ...@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr"); attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT); attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4); attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::platform::CPUDeviceContext dev_ctx; paddle::platform::CPUDeviceContext dev_ctx;
paddle::framework::Scope scope; paddle::framework::Scope scope;
op->Run(scope, dev_ctx); op->Run(scope, dev_ctx);
......
...@@ -83,7 +83,7 @@ TEST(OperatorBase, all) { ...@@ -83,7 +83,7 @@ TEST(OperatorBase, all) {
paddle::platform::CPUDeviceContext device_context; paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope; paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
scope.Var("OUT1"); scope.Var("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0); ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context); op->Run(scope, device_context);
...@@ -208,7 +208,7 @@ TEST(OpKernel, all) { ...@@ -208,7 +208,7 @@ TEST(OpKernel, all) {
paddle::platform::CPUDeviceContext cpu_device_context; paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope; paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
...@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) { ...@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) {
scope.Var("y0")->GetMutable<Tensor>(); scope.Var("y0")->GetMutable<Tensor>();
scope.Var("y1")->GetMutable<Tensor>(); scope.Var("y1")->GetMutable<Tensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc); auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
op->Run(scope, cpu_device_context); op->Run(scope, cpu_device_context);
} }
......
...@@ -18,27 +18,10 @@ limitations under the License. */ ...@@ -18,27 +18,10 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace framework { namespace framework {
using ProgDescMap =
std::unordered_map<ProgramDesc *, std::unique_ptr<ProgramDescBind>>;
static ProgDescMap *g_bind_map = nullptr;
ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) {
if (g_bind_map == nullptr) {
g_bind_map = new ProgDescMap();
}
auto &map = *g_bind_map;
auto &ptr = map[prog];
if (ptr == nullptr) {
ptr.reset(new ProgramDescBind(prog));
}
return *ptr;
}
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
auto *b = prog_->add_blocks(); auto *b = prog_.add_blocks();
b->set_parent_idx(parent.ID()); b->set_parent_idx(parent.ID());
b->set_idx(prog_->blocks_size() - 1); b->set_idx(prog_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b)); blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get(); return blocks_.back().get();
} }
...@@ -47,14 +30,14 @@ ProgramDesc *ProgramDescBind::Proto() { ...@@ -47,14 +30,14 @@ ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) { for (auto &block : blocks_) {
block->Flush(); block->Flush();
} }
return prog_; return &prog_;
} }
ProgramDescBind::ProgramDescBind(ProgramDesc *prog) { ProgramDescBind::ProgramDescBind() {
prog_ = prog; auto *block = prog_.mutable_blocks()->Add();
for (auto &block : *prog->mutable_blocks()) { block->set_idx(0);
blocks_.emplace_back(new BlockDescBind(this, &block)); block->set_parent_idx(-1);
} blocks_.emplace_back(new BlockDescBind(this, block));
} }
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -26,7 +26,7 @@ class BlockDescBind; ...@@ -26,7 +26,7 @@ class BlockDescBind;
class ProgramDescBind { class ProgramDescBind {
public: public:
static ProgramDescBind &Instance(ProgramDesc *prog); ProgramDescBind();
BlockDescBind *AppendBlock(const BlockDescBind &parent); BlockDescBind *AppendBlock(const BlockDescBind &parent);
...@@ -37,10 +37,7 @@ class ProgramDescBind { ...@@ -37,10 +37,7 @@ class ProgramDescBind {
ProgramDesc *Proto(); ProgramDesc *Proto();
private: private:
explicit ProgramDescBind(ProgramDesc *prog); ProgramDesc prog_;
// Not owned
ProgramDesc *prog_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_; std::vector<std::unique_ptr<BlockDescBind>> blocks_;
......
...@@ -62,7 +62,7 @@ namespace paddle { ...@@ -62,7 +62,7 @@ namespace paddle {
namespace framework { namespace framework {
TEST(InferVarType, sum_op) { TEST(InferVarType, sum_op) {
auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp(); auto *op = prog.Block(0)->AppendOp();
op->SetType("sum"); op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"}); op->SetInput("X", {"test_a", "test_b", "test_c"});
...@@ -83,7 +83,7 @@ TEST(InferVarType, sum_op) { ...@@ -83,7 +83,7 @@ TEST(InferVarType, sum_op) {
} }
TEST(InferVarType, sum_op_without_infer_var_type) { TEST(InferVarType, sum_op_without_infer_var_type) {
auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp(); auto *op = prog.Block(0)->AppendOp();
op->SetType("sum_without_infer_var_type"); op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
......
...@@ -51,7 +51,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test { ...@@ -51,7 +51,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
CreateGlobalVariables(); CreateGlobalVariables();
auto op_desc = CreateOpDesc(); auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc); op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
dop = dynamic_cast<DynamicRecurrentOp*>(op.get()); dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
InitCacheManually(); InitCacheManually();
InitStepNet(); InitStepNet();
......
...@@ -100,21 +100,7 @@ using namespace paddle::framework; // NOLINT ...@@ -100,21 +100,7 @@ using namespace paddle::framework; // NOLINT
// Bind Methods // Bind Methods
void BindProgramDesc(py::module &m) { void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "") py::class_<ProgramDescBind>(m, "ProgramDesc", "")
.def_static("instance", .def(py::init<>())
[]() -> ProgramDescBind * {
return &ProgramDescBind::Instance(&GetProgramDesc());
},
py::return_value_policy::reference)
.def_static("__create_program_desc__",
[]() -> ProgramDescBind * {
// Only used for unit-test
auto *prog_desc = new ProgramDesc;
auto *block = prog_desc->mutable_blocks()->Add();
block->set_idx(0);
block->set_parent_idx(-1);
return &ProgramDescBind::Instance(prog_desc);
},
py::return_value_policy::reference)
.def("append_block", &ProgramDescBind::AppendBlock, .def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("append_backward", .def("append_backward",
......
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/backward.h" #include "paddle/framework/backward.h"
#include "paddle/framework/executor.h" #include "paddle/framework/executor.h"
#include "paddle/framework/feed_fetch_method.h" #include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h" #include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h" #include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor_array.h" #include "paddle/framework/tensor_array.h"
...@@ -259,7 +260,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -259,7 +260,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
return OpRegistry::CreateOp(desc); return OpRegistry::CreateOp(desc, nullptr);
}) })
.def("backward", .def("backward",
[](const OperatorBase &forwardOp, [](const OperatorBase &forwardOp,
...@@ -363,7 +364,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -363,7 +364,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc); auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::RecurrentOp *>(rnn_op.release()); return static_cast<operators::RecurrentOp *>(rnn_op.release());
}) })
.def("set_stepnet", [](operators::RecurrentOp &self, .def("set_stepnet", [](operators::RecurrentOp &self,
...@@ -381,7 +382,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -381,7 +382,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc); auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::DynamicRecurrentOp *>( return static_cast<operators::DynamicRecurrentOp *>(
rnn_op.release()); rnn_op.release());
}) })
...@@ -408,7 +409,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -408,7 +409,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(), PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s", "User OpDesc is not initialized, reason %s",
desc.InitializationErrorString()); desc.InitializationErrorString());
auto cond_op = OpRegistry::CreateOp(desc); auto cond_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::CondOp *>(cond_op.release()); return static_cast<operators::CondOp *>(cond_op.release());
}) })
.def("set_truenet", .def("set_truenet",
......
...@@ -384,10 +384,8 @@ class Program(object): ...@@ -384,10 +384,8 @@ class Program(object):
cls._instance = cls() cls._instance = cls()
return cls._instance return cls._instance
def __init__(self, desc=None): def __init__(self):
if desc is None: self.desc = core.ProgramDesc()
desc = core.ProgramDesc.instance()
self.desc = desc
self.blocks = [Block(self, 0)] self.blocks = [Block(self, 0)]
self.current_block_idx = 0 self.current_block_idx = 0
......
...@@ -5,7 +5,7 @@ import paddle.v2.framework.core as core ...@@ -5,7 +5,7 @@ import paddle.v2.framework.core as core
class TestInferShape(unittest.TestCase): class TestInferShape(unittest.TestCase):
def test_sum_op(self): def test_sum_op(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
...@@ -33,7 +33,7 @@ class TestInferShape(unittest.TestCase): ...@@ -33,7 +33,7 @@ class TestInferShape(unittest.TestCase):
self.assertEqual(out.shape(), shape) self.assertEqual(out.shape(), shape)
def test_mul_op(self): def test_mul_op(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
......
...@@ -6,8 +6,7 @@ import unittest ...@@ -6,8 +6,7 @@ import unittest
class TestBook(unittest.TestCase): class TestBook(unittest.TestCase):
def test_fit_a_line(self): def test_fit_a_line(self):
pd = core.ProgramDesc.__create_program_desc__() program = Program()
program = Program(desc=pd)
x = data_layer( x = data_layer(
name='x', shape=[13], data_type='float32', program=program) name='x', shape=[13], data_type='float32', program=program)
y_predict = fc_layer(input=x, size=1, act=None, program=program) y_predict = fc_layer(input=x, size=1, act=None, program=program)
...@@ -21,8 +20,7 @@ class TestBook(unittest.TestCase): ...@@ -21,8 +20,7 @@ class TestBook(unittest.TestCase):
print str(program) print str(program)
def test_recognize_digits_mlp(self): def test_recognize_digits_mlp(self):
pd = core.ProgramDesc.__create_program_desc__() program = Program()
program = Program(desc=pd)
# Change g_program, so the rest layers use `g_program` # Change g_program, so the rest layers use `g_program`
images = data_layer( images = data_layer(
......
...@@ -4,7 +4,7 @@ import paddle.v2.framework.core as core ...@@ -4,7 +4,7 @@ import paddle.v2.framework.core as core
class TestOpDesc(unittest.TestCase): class TestOpDesc(unittest.TestCase):
def test_op_desc(self): def test_op_desc(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
...@@ -64,16 +64,16 @@ class TestOpDesc(unittest.TestCase): ...@@ -64,16 +64,16 @@ class TestOpDesc(unittest.TestCase):
class TestProgramDesc(unittest.TestCase): class TestProgramDesc(unittest.TestCase):
def test_instance(self): def test_instance(self):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
del program_desc del program_desc
program_desc = core.ProgramDesc.instance() program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc)
self.assertIsNotNone(program_desc.block(0)) self.assertIsNotNone(program_desc.block(0))
del program_desc del program_desc
def test_append_block(self): def test_append_block(self):
prog_desc = core.ProgramDesc.__create_program_desc__() prog_desc = core.ProgramDesc()
self.assertIsNotNone(prog_desc) self.assertIsNotNone(prog_desc)
block_root = prog_desc.block(0) block_root = prog_desc.block(0)
self.assertIsNotNone(block_root) self.assertIsNotNone(block_root)
...@@ -91,7 +91,7 @@ class TestProgramDesc(unittest.TestCase): ...@@ -91,7 +91,7 @@ class TestProgramDesc(unittest.TestCase):
class TestVarDesc(unittest.TestCase): class TestVarDesc(unittest.TestCase):
def test_shape(self): def test_shape(self):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_var') var = block.var('my_var')
var.set_type(core.VarDesc.VarType.SELECTED_ROWS) var.set_type(core.VarDesc.VarType.SELECTED_ROWS)
...@@ -102,7 +102,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -102,7 +102,7 @@ class TestVarDesc(unittest.TestCase):
self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type()) self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type())
def test_data_type(self): def test_data_type(self):
program_desc = core.ProgramDesc.__create_program_desc__() program_desc = core.ProgramDesc()
block = program_desc.block(0) block = program_desc.block(0)
var = block.var('my_var') var = block.var('my_var')
var.set_type(core.VarDesc.VarType.LOD_TENSOR) var.set_type(core.VarDesc.VarType.LOD_TENSOR)
...@@ -113,7 +113,7 @@ class TestVarDesc(unittest.TestCase): ...@@ -113,7 +113,7 @@ class TestVarDesc(unittest.TestCase):
class TestBlockDesc(unittest.TestCase): class TestBlockDesc(unittest.TestCase):
def test_add_var(self): def test_add_var(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
...@@ -121,12 +121,12 @@ class TestBlockDesc(unittest.TestCase): ...@@ -121,12 +121,12 @@ class TestBlockDesc(unittest.TestCase):
var2 = block.var("var2") var2 = block.var("var2")
var3 = block.var("var3") var3 = block.var("var3")
all_vars = block.all_vars() all_vars = block.all_vars()
self.assertEqual(set(all_vars), set([var1, var2, var3])) self.assertEqual(set(all_vars), {var1, var2, var3})
var2_re = block.find_var("var2") var2_re = block.find_var("var2")
self.assertEqual(var2_re, var2) self.assertEqual(var2_re, var2)
def test_add_op(self): def test_add_op(self):
prog = core.ProgramDesc.__create_program_desc__() prog = core.ProgramDesc()
self.assertIsNotNone(prog) self.assertIsNotNone(prog)
block = prog.block(0) block = prog.block(0)
self.assertIsNotNone(block) self.assertIsNotNone(block)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册