Unverified · Commit 77ee8fb2 authored by kavyasrinet, committed by GitHub

Exposing Channel to be used as a Variable and integrating with Fluid (#8486)

* Adding set_capacity method support

* Adding Python for make_channel

* Updating notest_concurrency

* Write python for make_channel method

* Write python for make_channel method

* Fix make_channel and test

* Placeholder ops for channel send, recv and close

* Adding ToTypeIndex method to var_type.h

* Add var_type.h to channel

* Added POD_Type to the method

* Add CHANNEL to executor

* Updated get and set DataType to accommodate Channels

* Updating get and set to incorporate channels

* Adding CHANNEL as supported VarType in protobuf

* Removing unnecessary import

* Fixing VarDesc to adapt to Channel as VarType

* Add channel.h to executor

* Remove include from channel

* Updated var_type to support Channel as var type

* Adding get_channel to pybind

* Added ChannelHolder

* Adding make_channel as an op

* Adding ChannelHolder in channel

* Fixing typo

* Commenting out operators in concurrency

* Removing totypeid right now since we don't need it.

* Reverting python changes

* Fixing typo in framework.py

* Modify comments for ReaderHolder
Parent 88c22e9d
@@ -15,6 +15,8 @@ limitations under the License. */
#pragma once
#include <stddef.h>  // for size_t
#include <typeindex>
#include "paddle/fluid/platform/enforce.h"

namespace paddle {
namespace framework {
@@ -51,6 +53,77 @@ void CloseChannel(Channel<T>* ch) {
  ch->Close();
}
/*
 * The ChannelHolder class serves two main purposes:
 * 1. It acts as a unified wrapper for the different kinds of
 *    channels, i.e. Buffered and Unbuffered channels. This is
 *    similar to the ReaderHolder class.
 * 2. It also helps us in TypeHiding. This is similar to the
 *    Placeholder implementations in variable.h and tensor.h.
 */
class ChannelHolder {
 public:
  template <typename T>
  void Reset(size_t buffer_size) {
    holder_.reset(new PlaceholderImpl<T>(buffer_size));
  }

  template <typename T>
  bool Send(T* data) {
    if (!IsInitialized()) return false;
    PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
    // The static_cast is safe because we have ensured the types match.
    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
    return channel != nullptr ? channel->Send(data) : false;
  }

  template <typename T>
  bool Receive(T* data) {
    if (!IsInitialized()) return false;
    PADDLE_ENFORCE_EQ(holder_->Type(), std::type_index(typeid(T)));
    Channel<T>* channel = static_cast<Channel<T>*>(holder_->Ptr());
    return channel != nullptr ? channel->Receive(data) : false;
  }

  void close() {
    if (IsInitialized()) holder_->Close();
  }

  inline bool IsInitialized() const { return holder_ != nullptr; }

 private:
  /**
   * @note Placeholder hides type T, so it doesn't appear as a template
   *       parameter of ChannelHolder.
   */
  struct Placeholder {
    virtual ~Placeholder() {}
    virtual const std::type_index Type() const = 0;
    virtual void* Ptr() const = 0;
    virtual void Close() const = 0;
  };

  template <typename T>
  struct PlaceholderImpl : public Placeholder {
    explicit PlaceholderImpl(size_t buffer_size)
        : type_(std::type_index(typeid(T))) {
      channel_.reset(MakeChannel<T>(buffer_size));
    }

    virtual const std::type_index Type() const { return type_; }
    virtual void* Ptr() const { return static_cast<void*>(channel_.get()); }
    virtual void Close() const {
      if (channel_) channel_->Close();
    }

    std::unique_ptr<Channel<T>> channel_;
    const std::type_index type_;
  };

  // Pointer to a PlaceholderImpl object
  std::unique_ptr<Placeholder> holder_;
};
}  // namespace framework
}  // namespace paddle
...
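As a hedged illustration of the intended API (not part of this diff; the payload type, capacity, and function name below are assumed for the example), ChannelHolder can wrap a typed channel like so:

#include "paddle/fluid/framework/channel.h"

// Hypothetical usage sketch: wrap a buffered int channel, send one value,
// receive it back, then close. Capacity 10 keeps Send from blocking here.
void ChannelHolderDemo() {
  paddle::framework::ChannelHolder holder;
  holder.Reset<int>(10);     // creates a buffered Channel<int> internally
  int out = 42;
  if (holder.Send(&out)) {   // returns false if the holder is uninitialized
    int in = 0;
    holder.Receive(&in);     // retrieves the buffered 42
  }
  holder.close();            // closes the underlying channel
}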
@@ -17,6 +17,7 @@ limitations under the License. */
#include <set>

#include "gflags/gflags.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/lod_rank_table.h"
@@ -55,13 +56,15 @@ static void CreateTensor(Variable* var, proto::VarType::Type var_type) {
    var->GetMutable<platform::PlaceList>();
  } else if (var_type == proto::VarType::READER) {
    var->GetMutable<ReaderHolder>();
  } else if (var_type == proto::VarType::CHANNEL) {
    var->GetMutable<ChannelHolder>();
  } else if (var_type == proto::VarType::NCCL_COM) {
    // GetMutable will be called in ncclInit
  } else {
    PADDLE_THROW(
        "Variable type %d is not in "
        "[LOD_TENSOR, SELECTED_ROWS, FEED_MINIBATCH, FETCH_LIST, "
        "LOD_RANK_TABLE, PLACE_LIST, READER, CHANNEL, NCCL_COM]",
        var_type);
  }
}
...
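For context, a minimal sketch of what this executor change enables, assuming the usual Scope API; the variable name, element type, and capacity are illustrative:

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/scope.h"

// Hypothetical sketch: a CHANNEL-typed variable now materializes as a
// ChannelHolder, mirroring how READER variables get a ReaderHolder.
void MakeChannelVar(paddle::framework::Scope* scope) {
  auto* var = scope->Var("ch");  // "ch" is an assumed variable name
  auto* holder = var->GetMutable<paddle::framework::ChannelHolder>();
  holder->Reset<float>(4);       // element type and capacity are assumed
}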
@@ -88,7 +88,13 @@ std::vector<std::vector<int64_t>> VarDesc::GetShapes() const {
}

void VarDesc::SetDataType(proto::VarType::Type data_type) {
  switch (desc_.type().type()) {
    case proto::VarType::CHANNEL:
      mutable_channel_desc()->set_data_type(data_type);
      break;
    default:
      mutable_tensor_desc()->set_data_type(data_type);
  }
}

void VarDesc::SetDataTypes(
@@ -109,7 +115,13 @@ void VarDesc::SetDataTypes(
}

proto::VarType::Type VarDesc::GetDataType() const {
  switch (desc_.type().type()) {
    case proto::VarType::CHANNEL:
      return channel_desc().data_type();
    default:
      return tensor_desc().data_type();
  }
}

std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
@@ -122,6 +134,17 @@ std::vector<proto::VarType::Type> VarDesc::GetDataTypes() const {
  return res;
}

void VarDesc::SetCapacity(int64_t capacity) {
  switch (desc_.type().type()) {
    case proto::VarType::CHANNEL:
      desc_.mutable_type()->mutable_channel()->set_capacity(capacity);
      break;
    default:
      PADDLE_THROW(
          "Setting 'capacity' is not supported by the type of var %s.",
          this->Name());
  }
}

void VarDesc::SetLoDLevel(int32_t lod_level) {
  switch (desc_.type().type()) {
    case proto::VarType::LOD_TENSOR:
@@ -191,6 +214,19 @@ std::vector<int32_t> VarDesc::GetLoDLevels() const {
  }
}

const proto::VarType::ChannelDesc &VarDesc::channel_desc() const {
  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  switch (desc_.type().type()) {
    case proto::VarType::CHANNEL:
      return desc_.type().channel();
    default:
      PADDLE_THROW(
          "Getting 'channel_desc' is not supported by the type of var %s.",
          this->Name());
  }
}

const proto::VarType::TensorDesc &VarDesc::tensor_desc() const {
  PADDLE_ENFORCE(desc_.has_type(), "The var's type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
@@ -226,6 +262,20 @@ std::vector<proto::VarType::TensorDesc> VarDesc::tensor_descs() const {
  }
}

proto::VarType::ChannelDesc *VarDesc::mutable_channel_desc() {
  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
  switch (desc_.type().type()) {
    case proto::VarType::CHANNEL:
      return desc_.mutable_type()->mutable_channel();
    default:
      PADDLE_THROW(
          "Getting 'mutable_channel_desc' is not supported by the type of var "
          "%s.",
          this->Name());
  }
}

proto::VarType::TensorDesc *VarDesc::mutable_tensor_desc() {
  PADDLE_ENFORCE(desc_.has_type(), "The var type hasn't been set.");
  PADDLE_ENFORCE(desc_.type().has_type(), "The var type hasn't been set.");
...
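Taken together, these VarDesc changes let a channel variable carry both an element data type and a capacity. A minimal sketch, assuming VarType::CHANNEL and ChannelDesc are already defined in framework.proto as this PR does; the variable name and the FP32/16 values are illustrative:

#include "paddle/fluid/framework/var_desc.h"

// Hypothetical sketch of describing a channel variable.
void DescribeChannelVar() {
  paddle::framework::VarDesc desc("ch");
  desc.SetType(paddle::framework::proto::VarType::CHANNEL);
  desc.SetDataType(paddle::framework::proto::VarType::FP32);  // routed to channel_desc
  desc.SetCapacity(16);             // would throw for non-CHANNEL var types
  auto dtype = desc.GetDataType();  // reads channel_desc().data_type()
  (void)dtype;
}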
@@ -85,6 +85,8 @@ class VarDesc {
  void SetDataTypes(
      const std::vector<proto::VarType::Type> &multiple_data_type);

  void SetCapacity(int64_t capacity);

  proto::VarType::Type GetDataType() const;

  std::vector<proto::VarType::Type> GetDataTypes() const;
@@ -106,8 +108,10 @@ class VarDesc {
  void SetPersistable(bool persistable) { desc_.set_persistable(persistable); }

 private:
  const proto::VarType::ChannelDesc &channel_desc() const;
  const proto::VarType::TensorDesc &tensor_desc() const;
  std::vector<proto::VarType::TensorDesc> tensor_descs() const;
  proto::VarType::ChannelDesc *mutable_channel_desc();
  proto::VarType::TensorDesc *mutable_tensor_desc();
  std::vector<proto::VarType::TensorDesc *> mutable_tensor_descs();
...
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */

#pragma once
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
@@ -34,6 +35,8 @@ inline proto::VarType::Type ToVarType(std::type_index type) {
    return proto::VarType_Type_SELECTED_ROWS;
  } else if (type.hash_code() == typeid(ReaderHolder).hash_code()) {
    return proto::VarType_Type_READER;
  } else if (type.hash_code() == typeid(ChannelHolder).hash_code()) {
    return proto::VarType_Type_CHANNEL;
  } else {
    PADDLE_THROW("ToVarType:Unsupported type %s", type.name());
  }
@@ -57,6 +60,9 @@ inline void VisitVarType(const framework::Variable& var, Visitor visitor) {
    case proto::VarType_Type_READER:
      visitor(var.Get<ReaderHolder>());
      return;
    case proto::VarType_Type_CHANNEL:
      visitor(var.Get<ChannelHolder>());
      return;
    default:
      PADDLE_THROW("Not supported visit type, %d", ToVarType(var.Type()));
  }
...
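A hedged sketch of the visitor path this hunk extends; the visitor struct and function below are illustrative, not part of the commit:

#include "paddle/fluid/framework/var_type.h"

// Hypothetical visitor: VisitVarType now routes CHANNEL variables to the
// ChannelHolder overload, alongside the existing tensor/reader cases.
struct ChannelAwareVisitor {
  void operator()(const paddle::framework::ChannelHolder& holder) {
    // channel-specific handling goes here
  }
  template <typename T>
  void operator()(const T&) {
    // fallback for LoDTensor, SelectedRows, ReaderHolder, ...
  }
};

void Inspect(const paddle::framework::Variable& var) {
  paddle::framework::VisitVarType(var, ChannelAwareVisitor());
}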
@@ -216,6 +216,7 @@ void BindVarDsec(py::module &m) {
      .def("set_shapes", &VarDesc::SetShapes)
      .def("set_dtype", &VarDesc::SetDataType)
      .def("set_dtypes", &VarDesc::SetDataTypes)
      .def("set_capacity", &VarDesc::SetCapacity)
      .def("shape", &VarDesc::GetShape, py::return_value_policy::reference)
      .def("shapes", &VarDesc::GetShapes, py::return_value_policy::reference)
      .def("dtype", &VarDesc::GetDataType, py::return_value_policy::reference)
@@ -246,6 +247,7 @@ void BindVarDsec(py::module &m) {
      .value("STEP_SCOPES", proto::VarType::STEP_SCOPES)
      .value("LOD_RANK_TABLE", proto::VarType::LOD_RANK_TABLE)
      .value("LOD_TENSOR_ARRAY", proto::VarType::LOD_TENSOR_ARRAY)
      .value("CHANNEL", proto::VarType::CHANNEL)
      .value("PLACE_LIST", proto::VarType::PLACE_LIST)
      .value("READER", proto::VarType::READER)
      .value("NCCL_COM", proto::VarType::NCCL_COM);
...
@@ -17,6 +17,7 @@ limitations under the License. */
#include <mutex>  // for call_once
#include <unordered_map>
#include "paddle/fluid/framework/backward.h"
#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
...