diff --git a/paddle/fluid/framework/channel.h b/paddle/fluid/framework/channel.h
index 5acf4fb39bbeb6bd45d215c962f10f0333578c02..8ca1f2aa47b9d9ecc4da3f8d0d917cc07644bf08 100644
--- a/paddle/fluid/framework/channel.h
+++ b/paddle/fluid/framework/channel.h
@@ -15,6 +15,9 @@ limitations under the License. */
 #pragma once
 
 #include <stddef.h>  // for size_t
+#include <memory>     // for std::unique_ptr
+#include <typeindex>  // for std::type_index
+#include "paddle/fluid/platform/enforce.h"
 
 namespace paddle {
 namespace framework {
@@ -51,6 +53,77 @@ template <typename T>
 void CloseChannel(Channel<T>* ch) {
   ch->Close();
 }
 
+/*
+ * The ChannelHolder class serves two main purposes:
+ * 1. It acts as a unified wrapper for the different kinds of
+ *    channels, i.e. Buffered and Unbuffered channels. This is
+ *    similar to the ReaderHolder class.
+ * 2. It also helps us in TypeHiding. This is similar to the
+ *    PlaceHolder implementations in variable.h and tensor.h.
+ */
+class ChannelHolder {
+ public:
+  template <typename T>
+  void Reset(size_t buffer_size) {
+    holder_.reset(new PlaceholderImpl<T>(buffer_size));
+  }
+
+  template <typename T>
+  bool Send(T* data) {
+    if (!IsInitialized()) return false;
+    PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
+    // Static cast should be safe because we have ensured that types are same
+    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
+    return channel != nullptr ? channel->Send(data) : false;
+  }
+
+  template <typename T>
+  bool Receive(T* data) {
+    if (!IsInitialized()) return false;
+    PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
+    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
+    return channel != nullptr ? channel->Receive(data) : false;
+  }
+
+  void close() {
+    if (IsInitialized()) holder_->Close();
+  }
+
+  inline bool IsInitialized() const { return holder_ != nullptr; }
+
+ private:
+  /**
+   * @note Placeholder hides type T, so it doesn't appear as a template
+   *       parameter of ChannelHolder.
+   */
+  struct Placeholder {
+    virtual ~Placeholder() {}
+    virtual const std::type_index Type() const = 0;
+    virtual void* Ptr() const = 0;
+    virtual void Close() = 0;
+  };
+
+  template <typename T>
+  struct PlaceholderImpl : public Placeholder {
+    explicit PlaceholderImpl(size_t buffer_size)
+        : type_(std::type_index(typeid(T))) {
+      channel_.reset(MakeChannel<T>(buffer_size));
+    }
+
+    virtual const std::type_index Type() const { return type_; }
+    virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
+    virtual void Close() {
+      if (channel_) channel_->Close();
+    }
+
+    std::unique_ptr<Channel<T>> channel_;
+    const std::type_index type_;
+  };
+
+  // Pointer to a PlaceholderImpl object
+  std::unique_ptr<Placeholder> holder_;
+};
+
 }  // namespace framework
 }  // namespace paddle
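The class above erases the channel's element type T behind the non-template Placeholder interface and re-checks it on every access: Send and Receive compare the caller's typeid against the std::type_index captured at Reset time before downcasting. A minimal usage sketch, not part of the patch, assuming the Buffered/UnBuffered implementations that MakeChannel already selects between:

    #include "paddle/fluid/framework/channel.h"

    int main() {
      paddle::framework::ChannelHolder holder;
      holder.Reset<int>(2);  // backs the holder with a buffered Channel<int>, capacity 2

      int in = 42;
      bool sent = holder.Send(&in);  // ok: typeid(int) matches the stored type_index

      int out = 0;
      bool received = holder.Receive(&out);  // out == 42 on success

      // double wrong = 0.0;
      // holder.Send(&wrong);  // would trip PADDLE_ENFORCE_EQ: type mismatch

      holder.close();
      return (sent && received && out == 42) ? 0 : 1;
    }
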
diff --git a/paddle/fluid/framework/executor.cc b/paddle/fluid/framework/executor.cc
index 6fd19e804afd674f292ffe2112988bf9d166f12a..0d2691e8115ad6de46dcd4fcd5b7fd79ed60ecb9 100644
--- a/paddle/fluid/framework/executor.cc
+++ b/paddle/fluid/framework/executor.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <set>
 #include "gflags/gflags.h"
+#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/feed_fetch_type.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
@@ -55,13 +56,15 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
     var->GetMutable<platform::PlaceList>();
   } else if (var_type == proto::VarType::READER) {
     var->GetMutable<ReaderHolder>();
+  } else if (var_type == proto::VarType::CHANNEL) {
+    var->GetMutable<ChannelHolder>();
   } else if (var_type == proto::VarType::NCCL_COM) {
     // GetMutable will be called in ncclInit
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
         "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
-        "LOD_RANK_TABLE, PLACE_LIST, READER, NCCL_COM]",
+        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
         var_type);
   }
 }
diff --git a/paddle/fluid/framework/var_desc.cc b/paddle/fluid/framework/var_desc.cc
index 7e3f002b53351ba5892aaa50482b21a83db94069..1aa0ae0f7c1946d91736ab61236a65a45c203fe3 100644
--- a/paddle/fluid/framework/var_desc.cc
+++ b/paddle/fluid/framework/var_desc.cc
@@ -88,7 +88,13 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
 }
 
 void VarDesc::SetDataType(proto::VarType::Type data_type) {
-  mutable_tensor_desc()->set_data_type(data_type);
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      mutable_channel_desc()->set_data_type(data_type);
+      break;
+    default:
+      mutable_tensor_desc()->set_data_type(data_type);
+  }
 }
 
 void VarDesc::SetDataTypes(
@@ -109,7 +115,13 @@ void VarDesc::SetDataTypes(
 }
 
 proto::VarType::Type VarDesc::GetDataType() const {
-  return tensor_desc().data_type();
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      return channel_desc().data_type();
+      break;
+    default:
+      return tensor_desc().data_type();
+  }
 }
 
 std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
@@ -122,6 +134,17 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
   return res;
 }
 
+void VarDesc::SetCapacity(int64_t capacity) {
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
+      break;
+    default:
+      PADDLE_THROW("Setting 'capacity' is not supported by the type of var %s.",
+                   this->Name());
+  }
+}
+
 void VarDesc::SetLoDLevel(int32_t lod_level) {
   switch (desc_.type().type()) {
     case proto::VarType::LOD_TENSOR:
@@ -191,6 +214,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
   }
 }
 
+const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
+  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      return desc_.type().channel();
+    default:
+      PADDLE_THROW(
+          "Getting 'channel_desc' is not supported by the type of var %s.",
+          this->Name());
+  }
+}
+
 const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
   PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
   PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
@@ -226,6 +262,20 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
   }
 }
 
+proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
+  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
+  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
+  switch (desc_.type().type()) {
+    case proto::VarType::CHANNEL:
+      return desc_.mutable_type()->mutable_channel();
+    default:
+      PADDLE_THROW(
+          "Getting 'mutable_channel_desc' is not supported by the type of var "
+          "%s.",
+          this->Name());
+  }
+}
+
 proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
   PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
   PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
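Note how every accessor now dispatches on desc_.type().type(): a CHANNEL variable stores its data type in the ChannelDesc sub-message rather than the TensorDesc one, and SetCapacity is legal only for CHANNEL. A sketch of the intended call order, not part of the patch; it assumes VarDesc::SetType from the existing var_desc.h:

    #include "paddle/fluid/framework/var_desc.h"

    void BuildChannelVar() {
      using namespace paddle::framework;
      VarDesc ch("ch");
      ch.SetType(proto::VarType::CHANNEL);   // must come first: the setters switch on it
      ch.SetDataType(proto::VarType::FP32);  // routed to mutable_channel_desc()
      ch.SetCapacity(64);                    // any non-CHANNEL type would PADDLE_THROW
    }
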
diff --git a/paddle/fluid/framework/var_desc.h b/paddle/fluid/framework/var_desc.h
index 19b8d890c1f8a80134744d0180a82c46cc6be6c2..f62415fda67a506763494886eb499fbb09c5caa6 100644
--- a/paddle/fluid/framework/var_desc.h
+++ b/paddle/fluid/framework/var_desc.h
@@ -85,6 +85,8 @@ class VarDesc {
   void SetDataTypes(
       const std::vector<proto::VarType::Type> &multiple_data_type);
 
+  void SetCapacity(int64_t capacity);
+
   proto::VarType::Type GetDataType() const;
 
   std::vector<proto::VarType::Type> GetDataTypes() const;
@@ -106,8 +108,10 @@ class VarDesc {
   void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }
 
  private:
+  const proto::VarType::ChannelDesc &channel_desc() const;
   const proto::VarType::TensorDesc &tensor_desc() const;
   std::vector<proto::VarType::TensorDesc> tensor_descs() const;
+  proto::VarType::ChannelDesc *mutable_channel_desc();
   proto::VarType::TensorDesc *mutable_tensor_desc();
   std::vector<proto::VarType::TensorDesc> mutable_tensor_descs();
 
diff --git a/paddle/fluid/framework/var_type.h b/paddle/fluid/framework/var_type.h
index 960ebff9d7d8a522cf37c6c413e4caa1655ea86e..2b646d78f0b23ec3e065c891826856c2341d4ac1 100644
--- a/paddle/fluid/framework/var_type.h
+++ b/paddle/fluid/framework/var_type.h
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #pragma once
+#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/framework.pb.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -34,6 +35,8 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
     return proto::VarType_Type_SELECTED_ROWS;
   } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
     return proto::VarType_Type_READER;
+  } else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
+    return proto::VarType_Type_CHANNEL;
   } else {
     PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
   }
@@ -57,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
     case proto::VarType_Type_READER:
       visitor(var.Get<ReaderHolder>());
       return;
+    case proto::VarType_Type_CHANNEL:
+      visitor(var.Get<ChannelHolder>());
+      return;
     default:
       PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
   }
diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc
index 4e04151c6ad26f37bde6bd0058505c767ef2d7f1..1a9d7c421b741187390e0ea3d837e8ef1cce70e8 100644
--- a/paddle/fluid/pybind/protobuf.cc
+++ b/paddle/fluid/pybind/protobuf.cc
@@ -216,6 +216,7 @@ void BindVarDsec(py::module &m) {
       .def("set_shapes", &VarDesc::SetShapes)
       .def("set_dtype", &VarDesc::SetDataType)
       .def("set_dtypes", &VarDesc::SetDataTypes)
+      .def("set_capacity", &VarDesc::SetCapacity)
       .def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
       .def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
      .def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
@@ -246,6 +247,7 @@ void BindVarDsec(py::module &m) {
       .value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
       .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
       .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
+      .value("CHANNEL", proto::VarType::CHANNEL)
       .value("PLACE_LIST", proto::VarType::PLACE_LIST)
       .value("READER", proto::VarType::READER)
       .value("NCCL_COM", proto::VarType::NCCL_COM);
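With the two var_type.h additions, ChannelHolder participates in the same runtime dispatch as the other variable kinds: ToVarType maps the held typeid to proto::VarType_Type_CHANNEL, and VisitVarType forwards the concrete ChannelHolder to the visitor. A sketch of a visitor that exercises the new case, not part of the patch; the names are illustrative:

    #include "paddle/fluid/framework/var_type.h"

    struct ChannelProbe {
      bool* is_channel;
      void operator()(const paddle::framework::ChannelHolder&) { *is_channel = true; }
      template <typename T>
      void operator()(const T&) { *is_channel = false; }  // all other var kinds
    };

    bool IsChannelVar(const paddle::framework::Variable& var) {
      bool is_channel = false;
      paddle::framework::VisitVarType(var, ChannelProbe{&is_channel});
      return is_channel;
    }
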
diff --git a/paddle/fluid/pybind/pybind.cc b/paddle/fluid/pybind/pybind.cc
index 56c1a935d98c4faf808088944f2a3e0808f2ca46..abe2b114492007ec19f2fcdb09aa173c88badbf5 100644
--- a/paddle/fluid/pybind/pybind.cc
+++ b/paddle/fluid/pybind/pybind.cc
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <mutex>  // for call_once
 #include <unordered_map>
 #include "paddle/fluid/framework/backward.h"
+#include "paddle/fluid/framework/channel.h"
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/framework.pb.h"
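End to end, these pieces let a program desc declare a CHANNEL variable (with set_capacity exposed through the bindings) and have the executor materialize it: CreateTensor's new branch calls GetMutable<ChannelHolder>() on the scope variable. A sketch of that runtime effect, not part of the patch; it assumes Scope from paddle/fluid/framework/scope.h:

    #include "paddle/fluid/framework/channel.h"
    #include "paddle/fluid/framework/scope.h"

    void MaterializeChannelVar() {
      using namespace paddle::framework;
      Scope scope;
      Variable* var = scope.Var("ch");  // the executor performs the same lookup
      ChannelHolder* holder = var->GetMutable<ChannelHolder>();  // as in CreateTensor
      holder->Reset<float>(4);  // illustrative capacity; in practice an op would size it
    }
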