From e747623e8639ee43a8dd2b33d04f6110a1182de3 Mon Sep 17 00:00:00 2001 From: Yu Yang Date: Wed, 18 Oct 2017 13:14:41 -0700 Subject: [PATCH] Change ProgramDesc not a global variable (#4879) * Change ProgramDesc not a global variable * Polish code style * Correct implement BlockDesc destructor * Unify program as parameter name --- paddle/framework/attribute.cc | 18 +++------- paddle/framework/attribute.h | 5 +-- paddle/framework/backward_test.cc | 31 ++++------------- paddle/framework/executor.cc | 3 +- paddle/framework/op_registry.cc | 5 +-- paddle/framework/op_registry.h | 3 +- paddle/framework/op_registry_test.cc | 12 +++---- paddle/framework/operator_test.cc | 6 ++-- paddle/framework/program_desc.cc | 33 +++++-------------- paddle/framework/program_desc.h | 7 ++-- paddle/framework/var_type_inference_test.cc | 4 +-- paddle/operators/dynamic_recurrent_op_test.cc | 2 +- paddle/pybind/protobuf.cc | 16 +-------- paddle/pybind/pybind.cc | 9 ++--- python/paddle/v2/framework/framework.py | 6 ++-- .../v2/framework/tests/test_infer_shape.py | 4 +-- .../paddle/v2/framework/tests/test_layers.py | 6 ++-- .../v2/framework/tests/test_protobuf_descs.py | 18 +++++----- 18 files changed, 62 insertions(+), 126 deletions(-) diff --git a/paddle/framework/attribute.cc b/paddle/framework/attribute.cc index d6a2975aaa4..29fe352ca45 100644 --- a/paddle/framework/attribute.cc +++ b/paddle/framework/attribute.cc @@ -19,19 +19,7 @@ limitations under the License. */ namespace paddle { namespace framework { -static ProgramDesc* g_program_desc = nullptr; - -ProgramDesc& GetProgramDesc() { - if (g_program_desc == nullptr) { - g_program_desc = new ProgramDesc(); - auto root_block = g_program_desc->mutable_blocks()->Add(); - root_block->set_idx(0); - root_block->set_parent_idx(-1); - } - return *g_program_desc; -} - -Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { +Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* program) { switch (attr_desc.type()) { case framework::AttrType::BOOLEAN: { return attr_desc.b(); @@ -74,7 +62,9 @@ Attribute GetAttrValue(const OpDesc::Attr& attr_desc) { return val; } case framework::AttrType::BLOCK: { - return GetProgramDesc().mutable_blocks(attr_desc.block_idx()); + PADDLE_ENFORCE(program != nullptr, + "Need to specify ProgramDesc when get a block attr"); + return program->mutable_blocks(attr_desc.block_idx()); } } PADDLE_ENFORCE(false, "Unknown OpDesc::AttrDesc::type !"); diff --git a/paddle/framework/attribute.h b/paddle/framework/attribute.h index 8a7a949346e..9744662b8f7 100644 --- a/paddle/framework/attribute.h +++ b/paddle/framework/attribute.h @@ -26,16 +26,13 @@ limitations under the License. */ namespace paddle { namespace framework { - -ProgramDesc& GetProgramDesc(); - template inline AttrType AttrTypeID() { Attribute tmp = T(); return static_cast(tmp.which() - 1); } -Attribute GetAttrValue(const OpDesc::Attr& attr_desc); +Attribute GetAttrValue(const OpDesc::Attr& attr_desc, ProgramDesc* desc); class AttrReader { public: diff --git a/paddle/framework/backward_test.cc b/paddle/framework/backward_test.cc index 0c35a157bcf..10301f7e394 100644 --- a/paddle/framework/backward_test.cc +++ b/paddle/framework/backward_test.cc @@ -495,19 +495,8 @@ TEST(Backward, linear_net_intermediate_variable_has_no_grad) { EXPECT_EQ(bwd_net->ops_[2]->Outputs(all).size(), 0UL); } -// =================================== // - -f::ProgramDesc *GetNewProgramDesc() { - auto *program_desc = new f::ProgramDesc(); - auto *root_block = program_desc->add_blocks(); - root_block->set_idx(0); - root_block->set_parent_idx(-1); - return program_desc; -} - TEST(Backward, simple_single_op) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op = block->AppendOp(); @@ -543,8 +532,7 @@ TEST(Backward, simple_single_op) { } TEST(Backward, default_attribute) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op = block->AppendOp(); op->SetType("mul"); @@ -570,8 +558,7 @@ TEST(Backward, default_attribute) { } TEST(Backward, simple_mult_op) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op1 = block->AppendOp(); op1->SetType("rowwise_add"); @@ -654,8 +641,7 @@ TEST(Backward, simple_mult_op) { } TEST(Backward, intermedia_var_no_grad) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op1 = block->AppendOp(); op1->SetType("rowwise_add"); @@ -725,8 +711,7 @@ TEST(Backward, intermedia_var_no_grad) { } TEST(Backward, var_no_grad) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op1 = block->AppendOp(); op1->SetType("mult_in_out"); @@ -802,8 +787,7 @@ TEST(Backward, var_no_grad) { } TEST(Backward, shared_var) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); f::OpDescBind *op1 = block->AppendOp(); op1->SetType("rowwise_add"); @@ -893,8 +877,7 @@ TEST(Backward, shared_var) { } TEST(Backward, half_backward) { - f::ProgramDesc *program_desc = GetNewProgramDesc(); - f::ProgramDescBind &program = f::ProgramDescBind::Instance(program_desc); + f::ProgramDescBind program; f::BlockDescBind *block = program.Block(0); auto *op1 = block->AppendOp(); op1->SetType("minus"); diff --git a/paddle/framework/executor.cc b/paddle/framework/executor.cc index b3b85b58657..00caa6e1d53 100644 --- a/paddle/framework/executor.cc +++ b/paddle/framework/executor.cc @@ -75,7 +75,8 @@ void Executor::Run(const ProgramDesc& pdesc, Scope* scope, int block_id) { } for (auto& op_desc : block.ops()) { - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp( + op_desc, const_cast(&pdesc)); op->Run(local_scope, *device); } diff --git a/paddle/framework/op_registry.cc b/paddle/framework/op_registry.cc index 504afbd5dba..c2f2438edf6 100644 --- a/paddle/framework/op_registry.cc +++ b/paddle/framework/op_registry.cc @@ -43,12 +43,13 @@ static VariableNameMap ConvertOpDescVarsToVarNameMap( return ret_val; } -std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc) { +std::unique_ptr OpRegistry::CreateOp(const OpDesc& op_desc, + ProgramDesc* program) { VariableNameMap inputs = ConvertOpDescVarsToVarNameMap(op_desc.inputs()); VariableNameMap outputs = ConvertOpDescVarsToVarNameMap(op_desc.outputs()); AttributeMap attrs; for (auto& attr : op_desc.attrs()) { - attrs[attr.name()] = GetAttrValue(attr); + attrs[attr.name()] = GetAttrValue(attr, program); } return CreateOp(op_desc.type(), inputs, outputs, attrs); diff --git a/paddle/framework/op_registry.h b/paddle/framework/op_registry.h index dfca46b789f..d25b4abccb1 100644 --- a/paddle/framework/op_registry.h +++ b/paddle/framework/op_registry.h @@ -74,7 +74,8 @@ class OpRegistry { const VariableNameMap& outputs, AttributeMap attrs); - static std::unique_ptr CreateOp(const OpDesc& op_desc); + static std::unique_ptr CreateOp(const OpDesc& op_desc, + ProgramDesc* program); static std::unique_ptr CreateOp(const OpDescBind& op_desc); }; diff --git a/paddle/framework/op_registry_test.cc b/paddle/framework/op_registry_test.cc index b860fe6cac7..6289125d7c7 100644 --- a/paddle/framework/op_registry_test.cc +++ b/paddle/framework/op_registry_test.cc @@ -74,7 +74,7 @@ TEST(OpRegistry, CreateOp) { attr->set_type(paddle::framework::AttrType::FLOAT); attr->set_f(scale); - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); @@ -95,7 +95,7 @@ TEST(OpRegistry, IllegalAttr) { bool caught = false; try { - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "larger_than check fail"; @@ -115,7 +115,7 @@ TEST(OpRegistry, DefaultValue) { ASSERT_TRUE(op_desc.IsInitialized()); - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::framework::Scope scope; paddle::platform::CPUDeviceContext dev_ctx; op->Run(scope, dev_ctx); @@ -131,7 +131,7 @@ TEST(OpRegistry, CustomChecker) { // attr 'test_attr' is not set bool caught = false; try { - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "Attribute 'test_attr' is required!"; @@ -149,7 +149,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_i(3); caught = false; try { - paddle::framework::OpRegistry::CreateOp(op_desc); + paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); } catch (paddle::platform::EnforceNotMet err) { caught = true; std::string msg = "'test_attr' must be even!"; @@ -166,7 +166,7 @@ TEST(OpRegistry, CustomChecker) { attr->set_name("test_attr"); attr->set_type(paddle::framework::AttrType::INT); attr->set_i(4); - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); paddle::platform::CPUDeviceContext dev_ctx; paddle::framework::Scope scope; op->Run(scope, dev_ctx); diff --git a/paddle/framework/operator_test.cc b/paddle/framework/operator_test.cc index d7890ac8d0a..c358f1a2b6e 100644 --- a/paddle/framework/operator_test.cc +++ b/paddle/framework/operator_test.cc @@ -83,7 +83,7 @@ TEST(OperatorBase, all) { paddle::platform::CPUDeviceContext device_context; paddle::framework::Scope scope; - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); scope.Var("OUT1"); ASSERT_EQ(paddle::framework::op_run_num, 0); op->Run(scope, device_context); @@ -208,7 +208,7 @@ TEST(OpKernel, all) { paddle::platform::CPUDeviceContext cpu_device_context; paddle::framework::Scope scope; - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 0); op->Run(scope, cpu_device_context); ASSERT_EQ(paddle::framework::cpu_kernel_run_num, 1); @@ -244,7 +244,7 @@ TEST(OpKernel, multi_inputs) { scope.Var("y0")->GetMutable(); scope.Var("y1")->GetMutable(); - auto op = paddle::framework::OpRegistry::CreateOp(op_desc); + auto op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); op->Run(scope, cpu_device_context); } diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc index fcb72928842..df846f115a5 100644 --- a/paddle/framework/program_desc.cc +++ b/paddle/framework/program_desc.cc @@ -18,27 +18,10 @@ limitations under the License. */ namespace paddle { namespace framework { -using ProgDescMap = - std::unordered_map>; -static ProgDescMap *g_bind_map = nullptr; - -ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) { - if (g_bind_map == nullptr) { - g_bind_map = new ProgDescMap(); - } - auto &map = *g_bind_map; - auto &ptr = map[prog]; - - if (ptr == nullptr) { - ptr.reset(new ProgramDescBind(prog)); - } - return *ptr; -} - BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { - auto *b = prog_->add_blocks(); + auto *b = prog_.add_blocks(); b->set_parent_idx(parent.ID()); - b->set_idx(prog_->blocks_size() - 1); + b->set_idx(prog_.blocks_size() - 1); blocks_.emplace_back(new BlockDescBind(this, b)); return blocks_.back().get(); } @@ -47,14 +30,14 @@ ProgramDesc *ProgramDescBind::Proto() { for (auto &block : blocks_) { block->Flush(); } - return prog_; + return &prog_; } -ProgramDescBind::ProgramDescBind(ProgramDesc *prog) { - prog_ = prog; - for (auto &block : *prog->mutable_blocks()) { - blocks_.emplace_back(new BlockDescBind(this, &block)); - } +ProgramDescBind::ProgramDescBind() { + auto *block = prog_.mutable_blocks()->Add(); + block->set_idx(0); + block->set_parent_idx(-1); + blocks_.emplace_back(new BlockDescBind(this, block)); } } // namespace framework } // namespace paddle diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h index f29b1c54e71..514b62654df 100644 --- a/paddle/framework/program_desc.h +++ b/paddle/framework/program_desc.h @@ -26,7 +26,7 @@ class BlockDescBind; class ProgramDescBind { public: - static ProgramDescBind &Instance(ProgramDesc *prog); + ProgramDescBind(); BlockDescBind *AppendBlock(const BlockDescBind &parent); @@ -37,10 +37,7 @@ class ProgramDescBind { ProgramDesc *Proto(); private: - explicit ProgramDescBind(ProgramDesc *prog); - - // Not owned - ProgramDesc *prog_; + ProgramDesc prog_; std::vector> blocks_; diff --git a/paddle/framework/var_type_inference_test.cc b/paddle/framework/var_type_inference_test.cc index 87399208e92..918de1fd055 100644 --- a/paddle/framework/var_type_inference_test.cc +++ b/paddle/framework/var_type_inference_test.cc @@ -62,7 +62,7 @@ namespace paddle { namespace framework { TEST(InferVarType, sum_op) { - auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); + ProgramDescBind prog; auto *op = prog.Block(0)->AppendOp(); op->SetType("sum"); op->SetInput("X", {"test_a", "test_b", "test_c"}); @@ -83,7 +83,7 @@ TEST(InferVarType, sum_op) { } TEST(InferVarType, sum_op_without_infer_var_type) { - auto &prog = ProgramDescBind::Instance(&GetProgramDesc()); + ProgramDescBind prog; auto *op = prog.Block(0)->AppendOp(); op->SetType("sum_without_infer_var_type"); op->SetInput("X", {"test2_a", "test2_b", "test2_c"}); diff --git a/paddle/operators/dynamic_recurrent_op_test.cc b/paddle/operators/dynamic_recurrent_op_test.cc index 83a5ba36d9a..36f405568d7 100644 --- a/paddle/operators/dynamic_recurrent_op_test.cc +++ b/paddle/operators/dynamic_recurrent_op_test.cc @@ -51,7 +51,7 @@ class DynamicRecurrentOpTestHelper : public ::testing::Test { CreateGlobalVariables(); auto op_desc = CreateOpDesc(); - op = paddle::framework::OpRegistry::CreateOp(op_desc); + op = paddle::framework::OpRegistry::CreateOp(op_desc, nullptr); dop = dynamic_cast(op.get()); InitCacheManually(); InitStepNet(); diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 82aae72ba93..fbdd673295b 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -100,21 +100,7 @@ using namespace paddle::framework; // NOLINT // Bind Methods void BindProgramDesc(py::module &m) { py::class_(m, "ProgramDesc", "") - .def_static("instance", - []() -> ProgramDescBind * { - return &ProgramDescBind::Instance(&GetProgramDesc()); - }, - py::return_value_policy::reference) - .def_static("__create_program_desc__", - []() -> ProgramDescBind * { - // Only used for unit-test - auto *prog_desc = new ProgramDesc; - auto *block = prog_desc->mutable_blocks()->Add(); - block->set_idx(0); - block->set_parent_idx(-1); - return &ProgramDescBind::Instance(prog_desc); - }, - py::return_value_policy::reference) + .def(py::init<>()) .def("append_block", &ProgramDescBind::AppendBlock, py::return_value_policy::reference) .def("append_backward", diff --git a/paddle/pybind/pybind.cc b/paddle/pybind/pybind.cc index fcae92ad99f..9eb1bf4a16e 100644 --- a/paddle/pybind/pybind.cc +++ b/paddle/pybind/pybind.cc @@ -17,6 +17,7 @@ limitations under the License. */ #include "paddle/framework/backward.h" #include "paddle/framework/executor.h" #include "paddle/framework/feed_fetch_method.h" +#include "paddle/framework/framework.pb.h" #include "paddle/framework/lod_tensor.h" #include "paddle/framework/selected_rows.h" #include "paddle/framework/tensor_array.h" @@ -259,7 +260,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - return OpRegistry::CreateOp(desc); + return OpRegistry::CreateOp(desc, nullptr); }) .def("backward", [](const OperatorBase &forwardOp, @@ -363,7 +364,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - auto rnn_op = OpRegistry::CreateOp(desc); + auto rnn_op = OpRegistry::CreateOp(desc, nullptr); return static_cast(rnn_op.release()); }) .def("set_stepnet", [](operators::RecurrentOp &self, @@ -381,7 +382,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - auto rnn_op = OpRegistry::CreateOp(desc); + auto rnn_op = OpRegistry::CreateOp(desc, nullptr); return static_cast( rnn_op.release()); }) @@ -408,7 +409,7 @@ All parameter, weight, gradient are variables in Paddle. PADDLE_ENFORCE(desc.IsInitialized(), "User OpDesc is not initialized, reason %s", desc.InitializationErrorString()); - auto cond_op = OpRegistry::CreateOp(desc); + auto cond_op = OpRegistry::CreateOp(desc, nullptr); return static_cast(cond_op.release()); }) .def("set_truenet", diff --git a/python/paddle/v2/framework/framework.py b/python/paddle/v2/framework/framework.py index e16bc72447f..93e2218eab9 100644 --- a/python/paddle/v2/framework/framework.py +++ b/python/paddle/v2/framework/framework.py @@ -384,10 +384,8 @@ class Program(object): cls._instance = cls() return cls._instance - def __init__(self, desc=None): - if desc is None: - desc = core.ProgramDesc.instance() - self.desc = desc + def __init__(self): + self.desc = core.ProgramDesc() self.blocks = [Block(self, 0)] self.current_block_idx = 0 diff --git a/python/paddle/v2/framework/tests/test_infer_shape.py b/python/paddle/v2/framework/tests/test_infer_shape.py index 19bb45acef9..5cfb9e6687f 100644 --- a/python/paddle/v2/framework/tests/test_infer_shape.py +++ b/python/paddle/v2/framework/tests/test_infer_shape.py @@ -5,7 +5,7 @@ import paddle.v2.framework.core as core class TestInferShape(unittest.TestCase): def test_sum_op(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) @@ -33,7 +33,7 @@ class TestInferShape(unittest.TestCase): self.assertEqual(out.shape(), shape) def test_mul_op(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) diff --git a/python/paddle/v2/framework/tests/test_layers.py b/python/paddle/v2/framework/tests/test_layers.py index ce20371cfba..2ffadf7371f 100644 --- a/python/paddle/v2/framework/tests/test_layers.py +++ b/python/paddle/v2/framework/tests/test_layers.py @@ -6,8 +6,7 @@ import unittest class TestBook(unittest.TestCase): def test_fit_a_line(self): - pd = core.ProgramDesc.__create_program_desc__() - program = Program(desc=pd) + program = Program() x = data_layer( name='x', shape=[13], data_type='float32', program=program) y_predict = fc_layer(input=x, size=1, act=None, program=program) @@ -21,8 +20,7 @@ class TestBook(unittest.TestCase): print str(program) def test_recognize_digits_mlp(self): - pd = core.ProgramDesc.__create_program_desc__() - program = Program(desc=pd) + program = Program() # Change g_program, so the rest layers use `g_program` images = data_layer( diff --git a/python/paddle/v2/framework/tests/test_protobuf_descs.py b/python/paddle/v2/framework/tests/test_protobuf_descs.py index c775b1a398d..6ed8edf91c7 100644 --- a/python/paddle/v2/framework/tests/test_protobuf_descs.py +++ b/python/paddle/v2/framework/tests/test_protobuf_descs.py @@ -4,7 +4,7 @@ import paddle.v2.framework.core as core class TestOpDesc(unittest.TestCase): def test_op_desc(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) @@ -64,16 +64,16 @@ class TestOpDesc(unittest.TestCase): class TestProgramDesc(unittest.TestCase): def test_instance(self): - program_desc = core.ProgramDesc.__create_program_desc__() + program_desc = core.ProgramDesc() self.assertIsNotNone(program_desc) del program_desc - program_desc = core.ProgramDesc.instance() + program_desc = core.ProgramDesc() self.assertIsNotNone(program_desc) self.assertIsNotNone(program_desc.block(0)) del program_desc def test_append_block(self): - prog_desc = core.ProgramDesc.__create_program_desc__() + prog_desc = core.ProgramDesc() self.assertIsNotNone(prog_desc) block_root = prog_desc.block(0) self.assertIsNotNone(block_root) @@ -91,7 +91,7 @@ class TestProgramDesc(unittest.TestCase): class TestVarDesc(unittest.TestCase): def test_shape(self): - program_desc = core.ProgramDesc.__create_program_desc__() + program_desc = core.ProgramDesc() block = program_desc.block(0) var = block.var('my_var') var.set_type(core.VarDesc.VarType.SELECTED_ROWS) @@ -102,7 +102,7 @@ class TestVarDesc(unittest.TestCase): self.assertEqual(core.VarDesc.VarType.SELECTED_ROWS, var.type()) def test_data_type(self): - program_desc = core.ProgramDesc.__create_program_desc__() + program_desc = core.ProgramDesc() block = program_desc.block(0) var = block.var('my_var') var.set_type(core.VarDesc.VarType.LOD_TENSOR) @@ -113,7 +113,7 @@ class TestVarDesc(unittest.TestCase): class TestBlockDesc(unittest.TestCase): def test_add_var(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) @@ -121,12 +121,12 @@ class TestBlockDesc(unittest.TestCase): var2 = block.var("var2") var3 = block.var("var3") all_vars = block.all_vars() - self.assertEqual(set(all_vars), set([var1, var2, var3])) + self.assertEqual(set(all_vars), {var1, var2, var3}) var2_re = block.find_var("var2") self.assertEqual(var2_re, var2) def test_add_op(self): - prog = core.ProgramDesc.__create_program_desc__() + prog = core.ProgramDesc() self.assertIsNotNone(prog) block = prog.block(0) self.assertIsNotNone(block) -- GitLab