Commit e747623e authored by Yu Yang, committed by GitHub

Change ProgramDesc so it is no longer a global variable (#4879)

* Change ProgramDesc so it is no longer a global variable

* Polish code style

* Correctly implement the BlockDesc destructor

* Unify `program` as the parameter name
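In short, a ProgramDescBind is no longer fetched from a process-wide singleton keyed by a raw ProgramDesc*; it is an ordinary object that owns its ProgramDesc and builds the root block in its constructor. A minimal sketch of the new usage, pieced together from the test hunks below (assuming the `f` alias for `paddle::framework` used in those tests):

    // New usage per this change: a ProgramDescBind is just a value.
    // Its default constructor adds block 0 (idx = 0, parent_idx = -1),
    // and the underlying ProgramDesc proto is owned by the object itself.
    f::ProgramDescBind program;
    f::BlockDescBind *block = program.Block(0);
    f::OpDescBind *op = block->AppendOp();  // same pattern as the updated tests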
Parent efd009a0
@@ -19,19 +19,7 @@ limitations under the License. */
namespace paddle {
namespace framework {
static ProgramDesc* g_program_desc = nullptr;
ProgramDesc& GetProgramDesc() {
if (g_program_desc == nullptr) {
g_program_desc = new ProgramDesc();
auto root_block = g_program_desc->mutable_blocks()->Add();
root_block->set_idx(0);
root_block->set_parent_idx(-1);
}
return *g_program_desc;
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) {
switch (attr_desc.type()) {
case framework::AttrType::BOOLEAN: {
return attr_desc.b();
@@ -74,7 +62,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) {
return val;
}
case framework::AttrType::BLOCK: {
return GetProgramDesc().mutable_blocks(attr_desc.block_idx());
PADDLE_ENFORCE(program != nullptr,
"Need to specify ProgramDesc when get a block attr");
return program->mutable_blocks(attr_desc.block_idx());
}
}
PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !");
@@ -26,16 +26,13 @@ limitations under the License. */
namespace paddle {
namespace framework {
ProgramDesc& GetProgramDesc();
template <typename T>
inline AttrType AttrTypeID() {
Attribute tmp = T();
return static_cast<AttrType>(tmp.which() - 1);
}
Attribute GetAttrValue(const OpDesc::Attr& attr_desc);
Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc);
class AttrReader {
public:
@@ -495,19 +495,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) {
EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL);
}
// =================================== //
f::ProgramDesc *GetNewProgramDesc() {
auto *program_desc = new f::ProgramDesc();
auto *root_block = program_desc->add_blocks();
root_block->set_idx(0);
root_block->set_parent_idx(-1);
return program_desc;
}
TEST(Backward, simple_single_op) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op = block->AppendOp();
@@ -543,8 +532,7 @@ TEST(Backward, simple_single_op) {
}
TEST(Backward, default_attribute) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op = block->AppendOp();
op->SetType("mul");
@@ -570,8 +558,7 @@ TEST(Backward, default_attribute) {
}
TEST(Backward, simple_mult_op) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add");
@@ -654,8 +641,7 @@ TEST(Backward, simple_mult_op) {
}
TEST(Backward, intermedia_var_no_grad) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add");
@@ -725,8 +711,7 @@ TEST(Backward, intermedia_var_no_grad) {
}
TEST(Backward, var_no_grad) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp();
op1->SetType("mult_in_out");
@@ -802,8 +787,7 @@ TEST(Backward, var_no_grad) {
}
TEST(Backward, shared_var) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
f::OpDescBind *op1 = block->AppendOp();
op1->SetType("rowwise_add");
@@ -893,8 +877,7 @@ TEST(Backward, shared_var) {
}
TEST(Backward, half_backward) {
f::ProgramDesc *program_desc = GetNewProgramDesc();
f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc);
f::ProgramDescBind program;
f::BlockDescBind *block = program.Block(0);
auto *op1 = block->AppendOp();
op1->SetType("minus");
@@ -75,7 +75,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) {
}
for (auto& op_desc : block.ops()) {
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(
op_desc, const_cast<ProgramDesc*>(&pdesc));
op->Run(local_scope, *device);
}
@@ -43,12 +43,13 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap(
return ret_val;
}
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc) {
std::unique_ptr<OperatorBase> OpRegistry::CreateOp(const OpDesc& op_desc,
ProgramDesc* program) {
VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs());
VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs());
AttributeMap attrs;
for (auto& attr : op_desc.attrs()) {
attrs[attr.name()] = GetAttrValue(attr);
attrs[attr.name()] = GetAttrValue(attr, program);
}
return CreateOp(op_desc.type(), inputs, outputs, attrs);
@@ -74,7 +74,8 @@ class OpRegistry {
const VariableNameMap& outputs,
AttributeMap attrs);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc);
static std::unique_ptr<OperatorBase> CreateOp(const OpDesc& op_desc,
ProgramDesc* program);
static std::unique_ptr<OperatorBase> CreateOp(const OpDescBind& op_desc);
};
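The practical effect of the extra parameter, roughly: OpRegistry::CreateOp forwards a ProgramDesc* into GetAttrValue, which dereferences it only for AttrType::BLOCK attributes (and enforces that it is non-null in that case). Callers that create ops without block attributes can simply pass nullptr, as the updated tests below do. A hedged usage sketch, assuming `op_desc` is an already populated framework::OpDesc:

    // Ops whose attributes contain no blocks do not need a program.
    auto plain_op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);

    // Ops carrying a BLOCK attribute must be given the program that owns the
    // block; Proto() flushes the bound blocks and returns the ProgramDesc*.
    paddle::framework::ProgramDescBind program;
    auto block_op =
        paddle::framework::OpRegistry::CreateOp(op_desc, program.Proto());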
@@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) {
attr->set_type(paddle::framework::AttrType::FLOAT);
attr->set_f(scale);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
@@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) {
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "larger_than check fail";
@@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) {
ASSERT_TRUE(op_desc.IsInitialized());
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::framework::Scope scope;
paddle::platform::CPUDeviceContext dev_ctx;
op->Run(scope, dev_ctx);
@@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) {
// attr 'test_attr' is not set
bool caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "Attribute 'test_attr' is required!";
@@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_i(3);
caught = false;
try {
paddle::framework::OpRegistry::CreateOp(op_desc);
paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
} catch (paddle::platform::EnforceNotMet err) {
caught = true;
std::string msg = "'test_attr' must be even!";
@@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) {
attr->set_name("test_attr");
attr->set_type(paddle::framework::AttrType::INT);
attr->set_i(4);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
paddle::platform::CPUDeviceContext dev_ctx;
paddle::framework::Scope scope;
op->Run(scope, dev_ctx);
@@ -83,7 +83,7 @@ TEST(OperatorBase, all) {
paddle::platform::CPUDeviceContext device_context;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
scope.Var("OUT1");
ASSERT_EQ(paddle::framework::op_run_num, 0);
op->Run(scope, device_context);
@@ -208,7 +208,7 @@ TEST(OpKernel, all) {
paddle::platform::CPUDeviceContext cpu_device_context;
paddle::framework::Scope scope;
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0);
op->Run(scope, cpu_device_context);
ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1);
@@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) {
scope.Var("y0")->GetMutable<Tensor>();
scope.Var("y1")->GetMutable<Tensor>();
auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
op->Run(scope, cpu_device_context);
}
@@ -18,27 +18,10 @@ limitations under the License. */
namespace paddle {
namespace framework {
using ProgDescMap =
std::unordered_map<ProgramDesc *, std::unique_ptr<ProgramDescBind>>;
static ProgDescMap *g_bind_map = nullptr;
ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) {
if (g_bind_map == nullptr) {
g_bind_map = new ProgDescMap();
}
auto &map = *g_bind_map;
auto &ptr = map[prog];
if (ptr == nullptr) {
ptr.reset(new ProgramDescBind(prog));
}
return *ptr;
}
BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) {
auto *b = prog_->add_blocks();
auto *b = prog_.add_blocks();
b->set_parent_idx(parent.ID());
b->set_idx(prog_->blocks_size() - 1);
b->set_idx(prog_.blocks_size() - 1);
blocks_.emplace_back(new BlockDescBind(this, b));
return blocks_.back().get();
}
@@ -47,14 +30,14 @@ ProgramDesc *ProgramDescBind::Proto() {
for (auto &block : blocks_) {
block->Flush();
}
return prog_;
return &prog_;
}
ProgramDescBind::ProgramDescBind(ProgramDesc *prog) {
prog_ = prog;
for (auto &block : *prog->mutable_blocks()) {
blocks_.emplace_back(new BlockDescBind(this, &block));
}
ProgramDescBind::ProgramDescBind() {
auto *block = prog_.mutable_blocks()->Add();
block->set_idx(0);
block->set_parent_idx(-1);
blocks_.emplace_back(new BlockDescBind(this, block));
}
} // namespace framework
} // namespace paddle
@@ -26,7 +26,7 @@ class BlockDescBind;
class ProgramDescBind {
public:
static ProgramDescBind &Instance(ProgramDesc *prog);
ProgramDescBind();
BlockDescBind *AppendBlock(const BlockDescBind &parent);
@@ -37,10 +37,7 @@ class ProgramDescBind {
ProgramDesc *Proto();
private:
explicit ProgramDescBind(ProgramDesc *prog);
// Not owned
ProgramDesc *prog_;
ProgramDesc prog_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_;
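The header change above makes ownership explicit: prog_ is now held by value, so there is no "not owned" raw pointer and no global ProgDescMap to clean up. A small lifetime sketch, assuming Proto() behaves as in the .cc hunk above (flush every block, then return &prog_):

    {
      paddle::framework::ProgramDescBind program;   // root block created here
      paddle::framework::ProgramDesc *proto = program.Proto();
      // `proto` points into `program`; use it only while `program` is alive.
    }  // proto and its blocks are destroyed together with `program`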
@@ -62,7 +62,7 @@ namespace paddle {
namespace framework {
TEST(InferVarType, sum_op) {
auto &prog = ProgramDescBind::Instance(&GetProgramDesc());
ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp();
op->SetType("sum");
op->SetInput("X", {"test_a", "test_b", "test_c"});
@@ -83,7 +83,7 @@ TEST(InferVarType, sum_op) {
}
TEST(InferVarType, sum_op_without_infer_var_type) {
auto &prog = ProgramDescBind::Instance(&GetProgramDesc());
ProgramDescBind prog;
auto *op = prog.Block(0)->AppendOp();
op->SetType("sum_without_infer_var_type");
op->SetInput("X", {"test2_a", "test2_b", "test2_c"});
@@ -51,7 +51,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test {
CreateGlobalVariables();
auto op_desc = CreateOpDesc();
op = paddle::framework::OpRegistry::CreateOp(op_desc);
op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr);
dop = dynamic_cast<DynamicRecurrentOp*>(op.get());
InitCacheManually();
InitStepNet();
@@ -100,21 +100,7 @@ using namespace paddle::framework; // NOLINT
// Bind Methods
void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "")
.def_static("instance",
[]() -> ProgramDescBind * {
return &ProgramDescBind::Instance(&GetProgramDesc());
},
py::return_value_policy::reference)
.def_static("__create_program_desc__",
[]() -> ProgramDescBind * {
// Only used for unit-test
auto *prog_desc = new ProgramDesc;
auto *block = prog_desc->mutable_blocks()->Add();
block->set_idx(0);
block->set_parent_idx(-1);
return &ProgramDescBind::Instance(prog_desc);
},
py::return_value_policy::reference)
.def(py::init<>())
.def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference)
.def("append_backward",
@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/framework/backward.h"
#include "paddle/framework/executor.h"
#include "paddle/framework/feed_fetch_method.h"
#include "paddle/framework/framework.pb.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/framework/selected_rows.h"
#include "paddle/framework/tensor_array.h"
@@ -259,7 +260,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
return OpRegistry::CreateOp(desc);
return OpRegistry::CreateOp(desc, nullptr);
})
.def("backward",
[](const OperatorBase &forwardOp,
@@ -363,7 +364,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::RecurrentOp *>(rnn_op.release());
})
.def("set_stepnet", [](operators::RecurrentOp &self,
@@ -381,7 +382,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto rnn_op = OpRegistry::CreateOp(desc);
auto rnn_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::DynamicRecurrentOp *>(
rnn_op.release());
})
@@ -408,7 +409,7 @@ All parameter, weight, gradient are variables in Paddle.
PADDLE_ENFORCE(desc.IsInitialized(),
"User OpDesc is not initialized, reason %s",
desc.InitializationErrorString());
auto cond_op = OpRegistry::CreateOp(desc);
auto cond_op = OpRegistry::CreateOp(desc, nullptr);
return static_cast<operators::CondOp *>(cond_op.release());
})
.def("set_truenet",
@@ -384,10 +384,8 @@ class Program(object):
cls._instance = cls()
return cls._instance
def __init__(self, desc=None):
if desc is None:
desc = core.ProgramDesc.instance()
self.desc = desc
def __init__(self):
self.desc = core.ProgramDesc()
self.blocks = [Block(self, 0)]
self.current_block_idx = 0
@@ -5,7 +5,7 @@ import paddle.v2.framework.core as core
class TestInferShape(unittest.TestCase):
def test_sum_op(self):
prog = core.ProgramDesc.__create_program_desc__()
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
@@ -33,7 +33,7 @@ class TestInferShape(unittest.TestCase):
self.assertEqual(out.shape(), shape)
def test_mul_op(self):
prog = core.ProgramDesc.__create_program_desc__()
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
@@ -6,8 +6,7 @@ import unittest
class TestBook(unittest.TestCase):
def test_fit_a_line(self):
pd = core.ProgramDesc.__create_program_desc__()
program = Program(desc=pd)
program = Program()
x = data_layer(
name='x', shape=[13], data_type='float32', program=program)
y_predict = fc_layer(input=x, size=1, act=None, program=program)
@@ -21,8 +20,7 @@ class TestBook(unittest.TestCase):
print str(program)
def test_recognize_digits_mlp(self):
pd = core.ProgramDesc.__create_program_desc__()
program = Program(desc=pd)
program = Program()
# Change g_program, so the rest layers use `g_program`
images = data_layer(
@@ -4,7 +4,7 @@ import paddle.v2.framework.core as core
class TestOpDesc(unittest.TestCase):
def test_op_desc(self):
prog = core.ProgramDesc.__create_program_desc__()
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
......@@ -64,16 +64,16 @@ class TestOpDesc(unittest.TestCase):
class TestProgramDesc(unittest.TestCase):
def test_instance(self):
program_desc = core.ProgramDesc.__create_program_desc__()
program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc)
del program_desc
program_desc = core.ProgramDesc.instance()
program_desc = core.ProgramDesc()
self.assertIsNotNone(program_desc)
self.assertIsNotNone(program_desc.block(0))
del program_desc
def test_append_block(self):
prog_desc = core.ProgramDesc.__create_program_desc__()
prog_desc = core.ProgramDesc()
self.assertIsNotNone(prog_desc)
block_root = prog_desc.block(0)
self.assertIsNotNone(block_root)
@@ -91,7 +91,7 @@ class TestProgramDesc(unittest.TestCase):
class TestVarDesc(unittest.TestCase):
def test_shape(self):
program_desc = core.ProgramDesc.__create_program_desc__()
program_desc = core.ProgramDesc()
block = program_desc.block(0)
var = block.var('my_var')
var.set_type(core.VarDesc.VarType.SELECTED_ROWS)
@@ -102,7 +102,7 @@ class TestVarDesc(unittest.TestCase):
self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type())
def test_data_type(self):
program_desc = core.ProgramDesc.__create_program_desc__()
program_desc = core.ProgramDesc()
block = program_desc.block(0)
var = block.var('my_var')
var.set_type(core.VarDesc.VarType.LOD_TENSOR)
@@ -113,7 +113,7 @@ class TestVarDesc(unittest.TestCase):
class TestBlockDesc(unittest.TestCase):
def test_add_var(self):
prog = core.ProgramDesc.__create_program_desc__()
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)
@@ -121,12 +121,12 @@ class TestBlockDesc(unittest.TestCase):
var2 = block.var("var2")
var3 = block.var("var3")
all_vars = block.all_vars()
self.assertEqual(set(all_vars), set([var1, var2, var3]))
self.assertEqual(set(all_vars), {var1, var2, var3})
var2_re = block.find_var("var2")
self.assertEqual(var2_re, var2)
def test_add_op(self):
prog = core.ProgramDesc.__create_program_desc__()
prog = core.ProgramDesc()
self.assertIsNotNone(prog)
block = prog.block(0)
self.assertIsNotNone(block)