diff --git a/paddle/framework/CMakeLists.txt b/paddle/framework/CMakeLists.txt index 5b0c18cc6c69f683d12ac6fa47ce1b8c7d1fc038..4aaa43d79612111856dd4dfc954ca2bfd8f4fa63 100644 --- a/paddle/framework/CMakeLists.txt +++ b/paddle/framework/CMakeLists.txt @@ -19,6 +19,7 @@ cc_test(scope_test SRCS scope_test.cc DEPS scope) proto_library(framework_proto SRCS framework.proto) cc_library(attribute SRCS attribute.cc DEPS framework_proto) +cc_library(proto_desc SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc DEPS attribute) cc_library(op_proto_maker SRCS op_proto_maker.cc DEPS framework_proto attribute) cc_test(op_proto_maker_test SRCS op_proto_maker_test.cc DEPS op_proto_maker) cc_library(op_info SRCS op_info.cc DEPS attribute framework_proto) diff --git a/paddle/framework/block_desc.cc b/paddle/framework/block_desc.cc new file mode 100644 index 0000000000000000000000000000000000000000..9570aedfdda332b797a8f348e0f6cf81bb2aee2f --- /dev/null +++ b/paddle/framework/block_desc.cc @@ -0,0 +1,89 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/block_desc.h" +#include "paddle/framework/program_desc.h" + +namespace paddle { +namespace framework { + +VarDescBind *BlockDescBind::NewVar(const std::string &name) { + need_update_ = true; + auto it = vars_.find(name); + PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name); + auto var = new VarDescBind(name); + vars_[name].reset(var); + return var; +} + +VarDescBind *BlockDescBind::Var(const std::string &name) const { + auto it = vars_.find(name); + PADDLE_ENFORCE(it != vars_.end(), + "Can not find variable %s in current block.", name); + return it->second.get(); +} + +std::vector BlockDescBind::AllVars() const { + std::vector res; + for (const auto &p : vars_) { + res.push_back(p.second.get()); + } + return res; +} + +OpDescBind *BlockDescBind::AppendOp() { + need_update_ = true; + ops_.emplace_back(new OpDescBind()); + return ops_.back().get(); +} + +OpDescBind *BlockDescBind::PrependOp() { + need_update_ = true; + ops_.emplace_front(new OpDescBind()); + return ops_.front().get(); +} + +std::vector BlockDescBind::AllOps() const { + std::vector res; + for (const auto &op : ops_) { + res.push_back(op.get()); + } + return res; +} + +void BlockDescBind::Sync() { + if (need_update_) { + auto &op_field = *this->desc_->mutable_ops(); + op_field.Clear(); + op_field.Reserve(static_cast(ops_.size())); + for (auto &op_desc : ops_) { + op_field.AddAllocated(op_desc->Proto()); + } + need_update_ = false; + } +} + +BlockDescBind *BlockDescBind::ParentBlock() const { + if (this->desc_->parent_idx() == -1) { + return nullptr; + } + return prog_->Block(static_cast(this->desc_->parent_idx())); +} + +void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { + BlockDesc *desc = block.RawPtr(); + this->attrs_[name] = desc; +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/block_desc.h b/paddle/framework/block_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..1a1135bab44cd27bb7d784c3b486188aa40635e4 --- /dev/null +++ b/paddle/framework/block_desc.h @@ -0,0 +1,71 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include +#include "paddle/framework/op_desc.h" +#include "paddle/framework/var_desc.h" + +namespace paddle { +namespace framework { + +class ProgramDescBind; + +// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize +// read/write speed. Only when we want the protobuf message, the local changes +// will be synchronized (by `Sync` method). + +class BlockDescBind { + public: + BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) + : prog_(prog), desc_(desc), need_update_(false) {} + + BlockDescBind(const BlockDescBind &o) = delete; + BlockDescBind &operator=(const BlockDescBind &o) = delete; + + int32_t ID() const { return desc_->idx(); } + + int32_t Parent() const { return desc_->parent_idx(); } + + VarDescBind *NewVar(const std::string &name_bytes); + + VarDescBind *Var(const std::string &name_bytes) const; + + std::vector AllVars() const; + + BlockDescBind *ParentBlock() const; + + OpDescBind *AppendOp(); + + OpDescBind *PrependOp(); + + std::vector AllOps() const; + + void Sync(); + + BlockDesc *RawPtr() { return desc_; } + + private: + ProgramDescBind *prog_; // not_own + BlockDesc *desc_; // not_own + bool need_update_; + + std::deque> ops_; + std::unordered_map> vars_; +}; +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_desc.cc b/paddle/framework/op_desc.cc new file mode 100644 index 0000000000000000000000000000000000000000..99b5a9c37700adce56f9a83af3792ef113a873ff --- /dev/null +++ b/paddle/framework/op_desc.cc @@ -0,0 +1,133 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/op_desc.h" +#include "paddle/framework/block_desc.h" + +namespace paddle { +namespace framework { + +OpDesc *OpDescBind::Proto() { + Sync(); + return &op_desc_; +} + +const std::vector &OpDescBind::Input( + const std::string &name) const { + auto it = inputs_.find(name); + PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", name, + Type()); + return it->second; +} + +std::vector OpDescBind::InputNames() const { + std::vector retv; + retv.reserve(this->inputs_.size()); + for (auto &ipt : this->inputs_) { + retv.push_back(ipt.first); + } + return retv; +} + +void OpDescBind::SetInput(const std::string ¶m_name, + const std::vector &args) { + need_update_ = true; + inputs_[param_name] = args; +} + +const std::vector &OpDescBind::Output( + const std::string &name) const { + auto it = outputs_.find(name); + PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s", + name, Type()); + return it->second; +} + +std::vector OpDescBind::OutputNames() const { + std::vector retv; + retv.reserve(this->outputs_.size()); + for (auto &ipt : this->outputs_) { + retv.push_back(ipt.first); + } + return retv; +} + +void OpDescBind::SetOutput(const std::string ¶m_name, + const std::vector &args) { + need_update_ = true; + this->outputs_[param_name] = args; +} + +AttrType OpDescBind::GetAttrType(const std::string &name) const { + auto it = attrs_.find(name); + PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); + return static_cast(it->second.which() - 1); +} + +std::vector OpDescBind::AttrNames() const { + std::vector retv; + retv.reserve(attrs_.size()); + for (auto &attr : attrs_) { + retv.push_back(attr.first); + } + return retv; +} + +void OpDescBind::SetAttr(const std::string &name, const Attribute &v) { + this->attrs_[name] = v; + need_update_ = true; +} + +Attribute OpDescBind::GetAttr(const std::string &name) const { + auto it = attrs_.find(name); + PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); + return it->second; +} + +int OpDescBind::GetBlockAttr(const std::string &name) const { + auto it = attrs_.find(name); + PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); + return boost::get(it->second)->idx(); +} + +void OpDescBind::Sync() { + if (need_update_) { + this->op_desc_.mutable_inputs()->Clear(); + for (auto &ipt : inputs_) { + auto *input = op_desc_.add_inputs(); + input->set_parameter(ipt.first); + VectorToRepeated(ipt.second, input->mutable_arguments()); + } + + this->op_desc_.mutable_outputs()->Clear(); + for (auto &opt : outputs_) { + auto *output = op_desc_.add_outputs(); + output->set_parameter(opt.first); + VectorToRepeated(opt.second, output->mutable_arguments()); + } + + this->op_desc_.mutable_attrs()->Clear(); + for (auto &attr : attrs_) { + auto *attr_desc = op_desc_.add_attrs(); + attr_desc->set_name(attr.first); + attr_desc->set_type( + static_cast(attr.second.which() - 1)); + boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second); + } + + need_update_ = false; + } +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/op_desc.h b/paddle/framework/op_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..ffc8ac61abfb74e4716f10c457d0fbc18b2e2ab8 --- /dev/null +++ b/paddle/framework/op_desc.h @@ -0,0 +1,106 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include +#include "paddle/framework/attribute.h" +#include "paddle/framework/var_desc.h" + +namespace paddle { +namespace framework { + +class BlockDescBind; + +class OpDescBind { + public: + OpDesc *Proto(); + + std::string Type() const { return op_desc_.type(); } + + void SetType(const std::string &type) { op_desc_.set_type(type); } + + const std::vector &Input(const std::string &name) const; + + std::vector InputNames() const; + + void SetInput(const std::string ¶m_name, + const std::vector &args); + + const std::vector &Output(const std::string &name) const; + + std::vector OutputNames() const; + + void SetOutput(const std::string ¶m_name, + const std::vector &args); + + std::string DebugString() { return this->Proto()->DebugString(); } + + bool HasAttr(const std::string &name) const { + return attrs_.find(name) != attrs_.end(); + } + + AttrType GetAttrType(const std::string &name) const; + + std::vector AttrNames() const; + + void SetAttr(const std::string &name, const Attribute &v); + + void SetBlockAttr(const std::string &name, BlockDescBind &block); + + Attribute GetAttr(const std::string &name) const; + + int GetBlockAttr(const std::string &name) const; + + private: + struct SetAttrDescVisitor : public boost::static_visitor { + explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} + mutable OpDesc::Attr *attr_; + void operator()(int v) const { attr_->set_i(v); } + void operator()(float v) const { attr_->set_f(v); } + void operator()(const std::string &v) const { attr_->set_s(v); } + void operator()(bool b) const { attr_->set_b(b); } + + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_ints()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_floats()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_strings()); + } + void operator()(const std::vector &v) const { + VectorToRepeated(v, attr_->mutable_bools()); + } + void operator()(BlockDesc *desc) const { + attr_->set_block_idx(desc->idx()); + } + void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } + }; + + void Sync(); + + OpDesc op_desc_; + std::unordered_map> inputs_; + std::unordered_map> outputs_; + std::unordered_map attrs_; + + // need_update_ indicate there some local changes not be synchronized. If + // local changes should be synchronized, need_update_ should be set to true. + bool need_update_{false}; +}; +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/program_desc.cc b/paddle/framework/program_desc.cc new file mode 100644 index 0000000000000000000000000000000000000000..e89f9a46d587b6378aa3be92306c5680093e1926 --- /dev/null +++ b/paddle/framework/program_desc.cc @@ -0,0 +1,60 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/program_desc.h" +#include "paddle/framework/block_desc.h" + +namespace paddle { +namespace framework { + +using ProgDescMap = + std::unordered_map>; +static ProgDescMap *g_bind_map = nullptr; + +ProgramDescBind &ProgramDescBind::Instance(ProgramDesc *prog) { + if (g_bind_map == nullptr) { + g_bind_map = new ProgDescMap(); + } + auto &map = *g_bind_map; + auto &ptr = map[prog]; + + if (ptr == nullptr) { + ptr.reset(new ProgramDescBind(prog)); + } + return *ptr; +} + +BlockDescBind *ProgramDescBind::AppendBlock(const BlockDescBind &parent) { + auto *b = prog_->add_blocks(); + b->set_parent_idx(parent.ID()); + b->set_idx(prog_->blocks_size() - 1); + blocks_.emplace_back(new BlockDescBind(this, b)); + return blocks_.back().get(); +} + +ProgramDesc *ProgramDescBind::Proto() { + for (auto &block : blocks_) { + block->Sync(); + } + return prog_; +} + +ProgramDescBind::ProgramDescBind(ProgramDesc *prog) { + prog_ = prog; + for (auto &block : *prog->mutable_blocks()) { + blocks_.emplace_back(new BlockDescBind(this, &block)); + } +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/program_desc.h b/paddle/framework/program_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..06ffcd4b15078f62ea8b7a3714e73de799530785 --- /dev/null +++ b/paddle/framework/program_desc.h @@ -0,0 +1,51 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/framework/framework.pb.h" + +namespace paddle { +namespace framework { + +class BlockDescBind; + +class ProgramDescBind { + public: + static ProgramDescBind &Instance(ProgramDesc *prog); + + ProgramDescBind(const ProgramDescBind &o) = delete; + ProgramDescBind &operator=(const ProgramDescBind &o) = delete; + + BlockDescBind *AppendBlock(const BlockDescBind &parent); + + BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } + + std::string DebugString() { return Proto()->DebugString(); } + + size_t Size() const { return blocks_.size(); } + + ProgramDesc *Proto(); + + private: + explicit ProgramDescBind(ProgramDesc *prog); + + // Not owned + ProgramDesc *prog_; + + std::vector> blocks_; +}; +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/var_desc.cc b/paddle/framework/var_desc.cc new file mode 100644 index 0000000000000000000000000000000000000000..13b9c5f3cdf98e6d22f4217fa1cf9a48910a78d8 --- /dev/null +++ b/paddle/framework/var_desc.cc @@ -0,0 +1,36 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/framework/var_desc.h" + +namespace paddle { +namespace framework { + +void VarDescBind::SetShape(const std::vector &dims) { + VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); +} + +void VarDescBind::SetDataType(DataType data_type) { + desc_.mutable_lod_tensor()->set_data_type(data_type); +} + +std::vector VarDescBind::Shape() const { + return RepeatedToVector(desc_.lod_tensor().dims()); +} + +DataType VarDescBind::GetDataType() const { + return desc_.lod_tensor().data_type(); +} +} // namespace framework +} // namespace paddle diff --git a/paddle/framework/var_desc.h b/paddle/framework/var_desc.h new file mode 100644 index 0000000000000000000000000000000000000000..4763bf09d004539ab24e4aad3bf429667f1fcc73 --- /dev/null +++ b/paddle/framework/var_desc.h @@ -0,0 +1,73 @@ +/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + +http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include +#include "paddle/framework/framework.pb.h" + +namespace paddle { +namespace framework { + +// convert between std::vector and protobuf repeated. +template +inline std::vector RepeatedToVector( + const google::protobuf::RepeatedField &repeated_field) { + std::vector ret; + ret.reserve(repeated_field.size()); + std::copy(repeated_field.begin(), repeated_field.end(), + std::back_inserter(ret)); + return ret; +} + +template +inline void VectorToRepeated(const std::vector &vec, + RepeatedField *repeated_field) { + repeated_field->Reserve(vec.size()); + for (const auto &elem : vec) { + *repeated_field->Add() = elem; + } +} + +// Specialize vector. +template +inline void VectorToRepeated(const std::vector &vec, + RepeatedField *repeated_field) { + repeated_field->Reserve(vec.size()); + for (auto elem : vec) { + *repeated_field->Add() = elem; + } +} + +class VarDescBind { + public: + explicit VarDescBind(const std::string &name) { desc_.set_name(name); } + + VarDesc *Proto() { return &desc_; } + + std::string Name() const { return desc_.name(); } + + void SetShape(const std::vector &dims); + + void SetDataType(DataType data_type); + + std::vector Shape() const; + + DataType GetDataType() const; + + private: + VarDesc desc_; +}; +} // namespace framework +} // namespace paddle diff --git a/paddle/pybind/CMakeLists.txt b/paddle/pybind/CMakeLists.txt index 326cc4a75bd5cc29f79de88a3e0802d17c812ecd..18ecbd1aa34c82d63ae7f8ec1bd8f81b35eee30b 100644 --- a/paddle/pybind/CMakeLists.txt +++ b/paddle/pybind/CMakeLists.txt @@ -1,6 +1,6 @@ if(WITH_PYTHON) cc_library(paddle_pybind SHARED SRCS pybind.cc exception.cc protobuf.cc - DEPS pybind python backward + DEPS pybind python backward proto_desc ${GLOB_OP_LIB}) endif(WITH_PYTHON) diff --git a/paddle/pybind/protobuf.cc b/paddle/pybind/protobuf.cc index 1a29621bdf13030c8781dab4acccca08d7250dbe..218821b35bb6947181fedc56e002ad0285f6307d 100644 --- a/paddle/pybind/protobuf.cc +++ b/paddle/pybind/protobuf.cc @@ -15,7 +15,10 @@ limitations under the License. */ #include "paddle/pybind/protobuf.h" #include #include -#include "paddle/framework/attribute.h" +#include "paddle/framework/block_desc.h" +#include "paddle/framework/op_desc.h" +#include "paddle/framework/program_desc.h" +#include "paddle/framework/var_desc.h" // Cast boost::variant for PyBind. // Copy from @@ -93,383 +96,6 @@ namespace pybind { using namespace paddle::framework; // NOLINT -// convert between std::vector and protobuf repeated. -template -inline std::vector RepeatedToVector( - const google::protobuf::RepeatedField &repeated_field) { - std::vector ret; - ret.reserve(repeated_field.size()); - std::copy(repeated_field.begin(), repeated_field.end(), - std::back_inserter(ret)); - return ret; -} - -template -inline void VectorToRepeated(const std::vector &vec, - RepeatedField *repeated_field) { - repeated_field->Reserve(vec.size()); - for (const auto &elem : vec) { - *repeated_field->Add() = elem; - } -} - -// Specialize vector. -template -inline void VectorToRepeated(const std::vector &vec, - RepeatedField *repeated_field) { - repeated_field->Reserve(vec.size()); - for (auto elem : vec) { - *repeated_field->Add() = elem; - } -} - -class ProgramDescBind; -class OpDescBind; -class BlockDescBind; -class VarDescBind; - -// Each Protobuf Message, we provide a XXXBind class. In that class, we optimize -// read/write speed. Only when we want the protobuf message, the local changes -// will be synchronized (by `Sync` method). -class VarDescBind { - public: - explicit VarDescBind(const std::string &name) { desc_.set_name(name); } - - VarDesc *Proto() { return &desc_; } - - py::bytes Name() const { return desc_.name(); } - - void SetShape(const std::vector &dims) { - VectorToRepeated(dims, desc_.mutable_lod_tensor()->mutable_dims()); - } - - void SetDataType(framework::DataType data_type) { - desc_.mutable_lod_tensor()->set_data_type(data_type); - } - - std::vector Shape() const { - return RepeatedToVector(desc_.lod_tensor().dims()); - } - - framework::DataType DataType() const { - return desc_.lod_tensor().data_type(); - } - - private: - VarDesc desc_; -}; - -class OpDescBind { - public: - OpDesc *Proto() { - Sync(); - return &op_desc_; - } - - std::string Type() const { return op_desc_.type(); } - - void SetType(const std::string &type) { op_desc_.set_type(type); } - - const std::vector &Input(const std::string &name) const { - auto it = inputs_.find(name); - PADDLE_ENFORCE(it != inputs_.end(), "Input %s cannot be found in Op %s", - name, Type()); - return it->second; - } - - std::vector InputNames() const { - std::vector retv; - retv.reserve(this->inputs_.size()); - for (auto &ipt : this->inputs_) { - retv.push_back(ipt.first); - } - return retv; - } - - void SetInput(const std::string ¶m_name, - const std::vector &args) { - need_update_ = true; - inputs_[param_name] = args; - } - - const std::vector &Output(const std::string &name) const { - auto it = outputs_.find(name); - PADDLE_ENFORCE(it != outputs_.end(), "Output %s cannot be found in Op %s", - name, Type()); - return it->second; - } - - std::vector OutputNames() const { - std::vector retv; - retv.reserve(this->outputs_.size()); - for (auto &ipt : this->outputs_) { - retv.push_back(ipt.first); - } - return retv; - } - - void SetOutput(const std::string ¶m_name, - const std::vector &args) { - need_update_ = true; - this->outputs_[param_name] = args; - } - - std::string DebugString() { return this->Proto()->DebugString(); } - - bool HasAttr(const std::string &name) const { - return attrs_.find(name) != attrs_.end(); - } - - framework::AttrType GetAttrType(const std::string &name) const { - auto it = attrs_.find(name); - PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); - return static_cast(it->second.which() - 1); - } - - std::vector AttrNames() const { - std::vector retv; - retv.reserve(attrs_.size()); - for (auto &attr : attrs_) { - retv.push_back(attr.first); - } - return retv; - } - - void SetAttr(const std::string &name, const Attribute &v) { - this->attrs_[name] = v; - need_update_ = true; - } - - void SetBlockAttr(const std::string &name, BlockDescBind &block); - - Attribute GetAttr(const std::string &name) const { - auto it = attrs_.find(name); - PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); - return it->second; - } - - int GetBlockAttr(const std::string &name) const { - auto it = attrs_.find(name); - PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name); - return boost::get(it->second)->idx(); - } - - private: - struct SetAttrDescVisitor : public boost::static_visitor { - explicit SetAttrDescVisitor(OpDesc::Attr *attr) : attr_(attr) {} - mutable OpDesc::Attr *attr_; - void operator()(int v) const { attr_->set_i(v); } - void operator()(float v) const { attr_->set_f(v); } - void operator()(const std::string &v) const { attr_->set_s(v); } - void operator()(bool b) const { attr_->set_b(b); } - - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_ints()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_floats()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_strings()); - } - void operator()(const std::vector &v) const { - VectorToRepeated(v, attr_->mutable_bools()); - } - void operator()(BlockDesc *desc) const { - attr_->set_block_idx(desc->idx()); - } - void operator()(boost::blank) const { PADDLE_THROW("Unexpected branch"); } - }; - - void Sync() { - if (need_update_) { - this->op_desc_.mutable_inputs()->Clear(); - for (auto &ipt : inputs_) { - auto *input = op_desc_.add_inputs(); - input->set_parameter(ipt.first); - VectorToRepeated(ipt.second, input->mutable_arguments()); - } - - this->op_desc_.mutable_outputs()->Clear(); - for (auto &opt : outputs_) { - auto *output = op_desc_.add_outputs(); - output->set_parameter(opt.first); - VectorToRepeated(opt.second, output->mutable_arguments()); - } - - this->op_desc_.mutable_attrs()->Clear(); - for (auto &attr : attrs_) { - auto *attr_desc = op_desc_.add_attrs(); - attr_desc->set_name(attr.first); - attr_desc->set_type( - static_cast(attr.second.which() - 1)); - boost::apply_visitor(SetAttrDescVisitor(attr_desc), attr.second); - } - - need_update_ = false; - } - } - - OpDesc op_desc_; - std::unordered_map> inputs_; - std::unordered_map> outputs_; - std::unordered_map attrs_; - - // need_update_ indicate there some local changes not be synchronized. If - // local changes should be synchronized, need_update_ should be set to true. - bool need_update_{false}; -}; - -class BlockDescBind { - public: - BlockDescBind(ProgramDescBind *prog, BlockDesc *desc) - : prog_(prog), desc_(desc), need_update_(false) {} - - BlockDescBind(const BlockDescBind &o) = delete; - BlockDescBind &operator=(const BlockDescBind &o) = delete; - - int32_t ID() const { return desc_->idx(); } - - int32_t Parent() const { return desc_->parent_idx(); } - - VarDescBind *NewVar(py::bytes name_bytes) { - std::string name = name_bytes; - need_update_ = true; - auto it = vars_.find(name); - PADDLE_ENFORCE(it == vars_.end(), "Duplicated variable %s", name); - auto var = new VarDescBind(name); - vars_[name].reset(var); - return var; - } - - VarDescBind *Var(py::bytes name_bytes) const { - std::string name = name_bytes; - auto it = vars_.find(name); - PADDLE_ENFORCE(it != vars_.end(), - "Can not find variable %s in current block.", name); - return it->second.get(); - } - - std::vector AllVars() const { - std::vector res; - for (const auto &p : vars_) { - res.push_back(p.second.get()); - } - return res; - } - - BlockDescBind *ParentBlock() const; - - OpDescBind *AppendOp() { - need_update_ = true; - ops_.emplace_back(new OpDescBind()); - return ops_.back().get(); - } - - OpDescBind *PrependOp() { - need_update_ = true; - ops_.emplace_front(new OpDescBind()); - return ops_.front().get(); - } - - std::vector AllOps() const { - std::vector res; - for (const auto &op : ops_) { - res.push_back(op.get()); - } - return res; - } - - void Sync() { - if (need_update_) { - auto &op_field = *this->desc_->mutable_ops(); - op_field.Clear(); - op_field.Reserve(static_cast(ops_.size())); - for (auto &op_desc : ops_) { - op_field.AddAllocated(op_desc->Proto()); - } - need_update_ = false; - } - } - - BlockDesc *RawPtr() { return desc_; } - - private: - ProgramDescBind *prog_; // not_own - BlockDesc *desc_; // not_own - bool need_update_; - - std::deque> ops_; - std::unordered_map> vars_; -}; - -using ProgDescMap = - std::unordered_map>; -static ProgDescMap *g_bind_map = nullptr; - -class ProgramDescBind { - public: - static ProgramDescBind &Instance(ProgramDesc *prog) { - if (g_bind_map == nullptr) { - g_bind_map = new ProgDescMap(); - } - auto &map = *g_bind_map; - auto &ptr = map[prog]; - - if (ptr == nullptr) { - ptr.reset(new ProgramDescBind(prog)); - } - return *ptr; - } - ProgramDescBind(const ProgramDescBind &o) = delete; - ProgramDescBind &operator=(const ProgramDescBind &o) = delete; - - BlockDescBind *AppendBlock(const BlockDescBind &parent) { - auto *b = prog_->add_blocks(); - b->set_parent_idx(parent.ID()); - b->set_idx(prog_->blocks_size() - 1); - blocks_.emplace_back(new BlockDescBind(this, b)); - return blocks_.back().get(); - } - - BlockDescBind *Block(size_t idx) { return blocks_[idx].get(); } - - std::string DebugString() { return Proto()->DebugString(); } - - size_t Size() const { return blocks_.size(); } - - ProgramDesc *Proto() { - for (auto &block : blocks_) { - block->Sync(); - } - return prog_; - } - - private: - explicit ProgramDescBind(ProgramDesc *prog) : prog_(prog) { - for (auto &block : *prog->mutable_blocks()) { - blocks_.emplace_back(new BlockDescBind(this, &block)); - } - } - - // Not owned - ProgramDesc *prog_; - - std::vector> blocks_; -}; - -BlockDescBind *BlockDescBind::ParentBlock() const { - if (this->desc_->parent_idx() == -1) { - return nullptr; - } - return prog_->Block(static_cast(this->desc_->parent_idx())); -} - -void OpDescBind::SetBlockAttr(const std::string &name, BlockDescBind &block) { - BlockDesc *desc = block.RawPtr(); - this->attrs_[name] = desc; -} - // Bind Methods void BindProgramDesc(py::module &m) { py::class_(m, "ProgramDesc", "") @@ -503,9 +129,18 @@ void BindBlockDesc(py::module &m) { py::return_value_policy::reference) .def("prepend_op", &BlockDescBind::PrependOp, py::return_value_policy::reference) - .def("new_var", &BlockDescBind::NewVar, + .def("new_var", + [](BlockDescBind &self, py::bytes byte_name) { + std::string name = byte_name; + return self.NewVar(name); + }, + py::return_value_policy::reference) + .def("var", + [](BlockDescBind &self, py::bytes byte_name) { + std::string name = byte_name; + return self.Var(name); + }, py::return_value_policy::reference) - .def("var", &BlockDescBind::Var, py::return_value_policy::reference) .def("all_vars", &BlockDescBind::AllVars, py::return_value_policy::reference) .def("all_ops", &BlockDescBind::AllOps, @@ -513,7 +148,7 @@ void BindBlockDesc(py::module &m) { } void BindVarDsec(py::module &m) { - py::enum_(m, "DataType", "") + py::enum_(m, "DataType", "") .value("BOOL", DataType::BOOL) .value("INT16", DataType::INT16) .value("INT32", DataType::INT32) @@ -523,15 +158,20 @@ void BindVarDsec(py::module &m) { .value("FP64", DataType::FP64); py::class_(m, "VarDesc", "") - .def("name", &VarDescBind::Name, py::return_value_policy::reference) + .def("name", + [](const VarDescBind &self) { + py::bytes name = self.Name(); + return name; + }, + py::return_value_policy::reference) .def("set_shape", &VarDescBind::SetShape) .def("set_data_type", &VarDescBind::SetDataType) .def("shape", &VarDescBind::Shape, py::return_value_policy::reference) - .def("data_type", &VarDescBind::DataType); + .def("data_type", &VarDescBind::GetDataType); } void BindOpDesc(py::module &m) { - py::enum_(m, "AttrType", "") + py::enum_(m, "AttrType", "") .value("INT", AttrType::INT) .value("INTS", AttrType::INTS) .value("FLOAT", AttrType::FLOAT) diff --git a/paddle/pybind/protobuf.h b/paddle/pybind/protobuf.h index 2721c128d1290ee0b1246d877d9e5ea9c4ae24ec..089183accc08c3c486a7ae78ccfe060853ec54f5 100644 --- a/paddle/pybind/protobuf.h +++ b/paddle/pybind/protobuf.h @@ -17,7 +17,6 @@ limitations under the License. */ #include #include #include -#include "paddle/framework/op_registry.h" #include "pybind11/numpy.h" #include "pybind11/pybind11.h" #include "pybind11/stl.h"