提交 47f773dd 编写于 作者: Y Yu Yang 提交者: GitHub

Copy Constructor for ProgramDesc (#4895)

* Implement FC layer with helper

* Update LayerHelper

* Add debug string for Python ProtoBuf

and Rename `Sync` to `Flush`

* Add check of ProtoBuf initialization

* Layer wrapper for FC

* Fix unittest

* Fix CI

* Add code generator

* AttributeChecker Better error log and speicalize bool

Since lots of types can be cast to bool

* Complete mlp, fit_a_line

* Implementation of simple conv_2d layer

* Fix bugs

* Change ProgramDesc not a global variable

* Polish code style

* Stash

* Correct implement BlockDesc destructor

* Correct implement BlockDesc destructor

* Unify program as parameter name

* Fix bugs

* Add unittest

* Fix unit test error

* Remove unused functions

* Add clone for Python Program

* Compare OpDescBind directly
上级 6a03a4d9
...@@ -20,6 +20,7 @@ proto_library(framework_proto SRCS framework.proto) ...@@ -20,6 +20,7 @@ proto_library(framework_proto SRCS framework.proto)
cc_library(attribute SRCS attribute.cc DEPS framework_proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto)
cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info) cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute ddim op_info)
cc_test(program_desc_test SRCS program_desc_test.cc DEPS proto_desc)
cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute)
cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker)
cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto)
......
...@@ -107,6 +107,19 @@ BlockDesc *BlockDescBind::Proto() { ...@@ -107,6 +107,19 @@ BlockDesc *BlockDescBind::Proto() {
Flush(); Flush();
return desc_; return desc_;
} }
BlockDescBind::BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog)
: prog_(prog), desc_(desc) {
need_update_ = true;
for (auto &op : other.ops_) {
ops_.emplace_back(new OpDescBind(*op));
}
for (auto &it : other.vars_) {
auto *var = new VarDescBind(*it.second);
vars_[it.first].reset(var);
}
}
void BlockDescBind::ClearPBOps() { void BlockDescBind::ClearPBOps() {
auto ops = this->desc_->mutable_ops(); auto ops = this->desc_->mutable_ops();
......
...@@ -16,8 +16,10 @@ limitations under the License. */ ...@@ -16,8 +16,10 @@ limitations under the License. */
#include <deque> #include <deque>
#include <memory> #include <memory>
#include <set>
#include <unordered_map> #include <unordered_map>
#include <vector> #include <vector>
#include "paddle/framework/op_desc.h" #include "paddle/framework/op_desc.h"
#include "paddle/framework/var_desc.h" #include "paddle/framework/var_desc.h"
#include "paddle/platform/macros.h" #include "paddle/platform/macros.h"
...@@ -36,6 +38,9 @@ class BlockDescBind { ...@@ -36,6 +38,9 @@ class BlockDescBind {
BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) BlockDescBind(ProgramDescBind *prog, BlockDesc *desc)
: prog_(prog), desc_(desc), need_update_(false) {} : prog_(prog), desc_(desc), need_update_(false) {}
BlockDescBind(const BlockDescBind &other, BlockDesc *desc,
ProgramDescBind *prog);
~BlockDescBind() { ~BlockDescBind() {
this->ClearPBVars(); this->ClearPBVars();
this->ClearPBOps(); this->ClearPBOps();
...@@ -51,6 +56,14 @@ class BlockDescBind { ...@@ -51,6 +56,14 @@ class BlockDescBind {
bool HasVar(const std::string &var_name) const; bool HasVar(const std::string &var_name) const;
std::set<std::string> LocalVarNames() const {
std::set<std::string> var_names;
for (auto &var : vars_) {
var_names.insert(var.first);
}
return var_names;
}
std::vector<VarDescBind *> AllVars() const; std::vector<VarDescBind *> AllVars() const;
BlockDescBind *ParentBlock() const; BlockDescBind *ParentBlock() const;
......
...@@ -39,5 +39,14 @@ ProgramDescBind::ProgramDescBind() { ...@@ -39,5 +39,14 @@ ProgramDescBind::ProgramDescBind() {
block->set_parent_idx(-1); block->set_parent_idx(-1);
blocks_.emplace_back(new BlockDescBind(this, block)); blocks_.emplace_back(new BlockDescBind(this, block));
} }
ProgramDescBind::ProgramDescBind(const ProgramDescBind &o) {
prog_ = o.prog_;
for (int i = 0; i < prog_.blocks_size(); ++i) {
auto *block = prog_.mutable_blocks(i);
blocks_.emplace_back(new BlockDescBind(*o.blocks_[i], block, this));
}
}
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
...@@ -28,6 +28,8 @@ class ProgramDescBind { ...@@ -28,6 +28,8 @@ class ProgramDescBind {
public: public:
ProgramDescBind(); ProgramDescBind();
ProgramDescBind(const ProgramDescBind &o);
BlockDescBind *AppendBlock(const BlockDescBind &parent); BlockDescBind *AppendBlock(const BlockDescBind &parent);
BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); }
...@@ -40,8 +42,6 @@ class ProgramDescBind { ...@@ -40,8 +42,6 @@ class ProgramDescBind {
ProgramDesc prog_; ProgramDesc prog_;
std::vector<std::unique_ptr<BlockDescBind>> blocks_; std::vector<std::unique_ptr<BlockDescBind>> blocks_;
DISABLE_COPY_AND_ASSIGN(ProgramDescBind);
}; };
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/program_desc.h"
#include "gtest/gtest.h"
#include "paddle/framework/block_desc.h"
namespace paddle {
namespace framework {
TEST(ProgramDesc, copy_ctor) {
ProgramDescBind program;
auto* global_block = program.Block(0);
auto* x = global_block->Var("X");
x->SetType(VarDesc_VarType_LOD_TENSOR);
x->SetLoDLevel(0);
x->SetDataType(FP32);
x->SetShape({1000, 784});
auto* y = global_block->Var("Y");
y->SetType(VarDesc_VarType_LOD_TENSOR);
y->SetLoDLevel(0);
y->SetDataType(FP32);
y->SetShape({784, 100});
auto* op = global_block->AppendOp();
op->SetType("mul");
op->SetInput("X", {x->Name()});
op->SetInput("Y", {y->Name()});
auto* out = global_block->Var("Out");
out->SetType(VarDesc_VarType_LOD_TENSOR);
op->SetOutput("Y", {out->Name()});
ProgramDescBind program_copy(program);
auto* global_block_copy = program_copy.Block(0);
ASSERT_NE(global_block, global_block_copy);
auto assert_same_var = [&](const std::string& name, VarDescBind* var_before) {
ASSERT_TRUE(global_block_copy->HasVar(name));
auto* copy = global_block_copy->Var(name);
ASSERT_NE(copy, var_before);
ASSERT_EQ(copy->Name(), var_before->Name());
ASSERT_EQ(copy->GetType(), var_before->GetType());
ASSERT_EQ(copy->Shape(), var_before->Shape());
ASSERT_EQ(copy->Proto()->SerializeAsString(),
var_before->Proto()->SerializeAsString());
};
ASSERT_EQ(global_block->LocalVarNames(), global_block_copy->LocalVarNames());
ASSERT_EQ(3, global_block_copy->LocalVarNames().size());
assert_same_var("X", x);
assert_same_var("Y", y);
assert_same_var("Out", out);
for (size_t i = 0; i < global_block->OpSize(); ++i) {
auto op_origin = global_block->Op(i);
auto op_copy = global_block->Op(i);
ASSERT_EQ(op_origin->Type(), op_copy->Type());
ASSERT_EQ(op_origin->Inputs(), op_copy->Inputs());
ASSERT_EQ(op_origin->Outputs(), op_copy->Outputs());
ASSERT_EQ(op_copy->Proto()->SerializeAsString(),
op_origin->Proto()->SerializeAsString());
}
// Not check block's protostr are same it because the order of vars could be
// different and it is correct.
}
} // namespace framework
} // namespace paddle
\ No newline at end of file
...@@ -101,6 +101,10 @@ using namespace paddle::framework; // NOLINT ...@@ -101,6 +101,10 @@ using namespace paddle::framework; // NOLINT
void BindProgramDesc(py::module &m) { void BindProgramDesc(py::module &m) {
py::class_<ProgramDescBind>(m, "ProgramDesc", "") py::class_<ProgramDescBind>(m, "ProgramDesc", "")
.def(py::init<>()) .def(py::init<>())
.def("__init__",
[](ProgramDescBind &self, const ProgramDescBind &other) {
new (&self) ProgramDescBind(other);
})
.def("append_block", &ProgramDescBind::AppendBlock, .def("append_block", &ProgramDescBind::AppendBlock,
py::return_value_policy::reference) py::return_value_policy::reference)
.def("append_backward", .def("append_backward",
......
...@@ -364,6 +364,7 @@ class Block(object): ...@@ -364,6 +364,7 @@ class Block(object):
for op_idx in range(0, self.desc.op_size()): for op_idx in range(0, self.desc.op_size()):
ops_in_cpp.append(self.desc.op(op_idx)) ops_in_cpp.append(self.desc.op(op_idx))
if len(self.ops) != 0:
first_op_in_python = self.ops[0].desc first_op_in_python = self.ops[0].desc
last_op_in_python = self.ops[len(self.ops) - 1].desc last_op_in_python = self.ops[len(self.ops) - 1].desc
start_index = None start_index = None
...@@ -376,6 +377,9 @@ class Block(object): ...@@ -376,6 +377,9 @@ class Block(object):
assert start_index is not None assert start_index is not None
assert end_index is not None assert end_index is not None
assert start_index <= end_index assert start_index <= end_index
else:
start_index = 0
end_index = -1
# sync ops append to the head of cpp_ops # sync ops append to the head of cpp_ops
for index in range((start_index - 1 - 1), -1, -1): for index in range((start_index - 1 - 1), -1, -1):
...@@ -413,7 +417,15 @@ class Program(object): ...@@ -413,7 +417,15 @@ class Program(object):
proto = framework_pb2.ProgramDesc.FromString(str(protostr)) proto = framework_pb2.ProgramDesc.FromString(str(protostr))
return proto.__str__() return proto.__str__()
__repr__ = __str__ def clone(self):
p = Program()
p.desc = core.ProgramDesc(self.desc)
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
p.sync_with_cpp()
return p
def __repr__(self):
return str(self)
def global_block(self): def global_block(self):
return self.blocks[0] return self.blocks[0]
......
...@@ -34,6 +34,24 @@ class TestProgram(unittest.TestCase): ...@@ -34,6 +34,24 @@ class TestProgram(unittest.TestCase):
self.assertEqual(1, b.idx) self.assertEqual(1, b.idx)
self.assertEqual(0, b.parent_idx) self.assertEqual(0, b.parent_idx)
def test_program_clone(self):
prog = Program()
x = prog.global_block().create_var(
name='X', shape=[1000, 784], dtype='float32')
y = prog.global_block().create_var(
name='Y', shape=[784, 100], dtype='float32')
out = prog.global_block().create_var(name='Out', dtype='float32')
prog.global_block().append_op(
type="mul", inputs={'X': [x],
'Y': [y]}, outputs={'Out': [out]})
# FIXME(yuyang18): We manual compare the output string, since the order
# of variable could be changed.
print prog
print prog.clone()
def test_append_backward(self): def test_append_backward(self):
prog = Program.instance() prog = Program.instance()
block = prog.global_block() block = prog.global_block()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册