提交 602aa433 编写于 作者: T typhoonzero

cast data type

上级 a2de156d
...@@ -31,35 +31,14 @@ namespace detail { ...@@ -31,35 +31,14 @@ namespace detail {
using VarMsg = sendrecv::VariableMessage; using VarMsg = sendrecv::VariableMessage;
VarMsg::Type DataTypeToEnum(std::type_index type) {
if (typeid(platform::float16).hash_code() == type.hash_code()) {
return VarMsg::FP16;
} else if (typeid(const float).hash_code() == type.hash_code()) {
// CPPLint complains Using C-style cast. Use static_cast<float>() instead
// One fix to this is to replace float with const float because
// typeid(T) == typeid(const T)
// http://en.cppreference.com/w/cpp/language/typeid
return VarMsg::FP32;
} else if (typeid(const double).hash_code() == type.hash_code()) {
return VarMsg::FP64;
} else if (typeid(const int).hash_code() == type.hash_code()) {
return VarMsg::INT32;
} else if (typeid(const int64_t).hash_code() == type.hash_code()) {
return VarMsg::INT64;
} else if (typeid(const bool).hash_code() == type.hash_code()) {
return VarMsg::BOOL;
} else {
PADDLE_THROW("Not supported");
}
}
void GetTensorPayload(framework::Variable* var, void GetTensorPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request, const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) { void** payload, size_t* payload_size) {
auto tensor = var->Get<framework::LoDTensor>(); auto tensor = var->Get<framework::LoDTensor>();
// FIXME(wuyi): data types in send_recv.proto is not synced with // FIXME(wuyi): data types in send_recv.proto is copied from
// framework.proto // framework.proto
request->set_data_type(DataTypeToEnum(tensor.type())); request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(tensor.type())));
for (auto& dim : framework::vectorize(tensor.dims())) { for (auto& dim : framework::vectorize(tensor.dims())) {
request->add_dims(dim); request->add_dims(dim);
} }
...@@ -96,7 +75,8 @@ void GetSelectedRowsPayload(framework::Variable* var, ...@@ -96,7 +75,8 @@ void GetSelectedRowsPayload(framework::Variable* var,
const platform::DeviceContext& ctx, VarMsg* request, const platform::DeviceContext& ctx, VarMsg* request,
void** payload, size_t* payload_size) { void** payload, size_t* payload_size) {
auto* slr = var->GetMutable<framework::SelectedRows>(); auto* slr = var->GetMutable<framework::SelectedRows>();
request->set_data_type(DataTypeToEnum(slr->value().type())); request->set_data_type(
static_cast<VarMsg::Type>(framework::ToDataType(slr->value().type())));
request->set_lod_level(0); request->set_lod_level(0);
request->set_slr_height(slr->height()); request->set_slr_height(slr->height());
...@@ -170,7 +150,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ...@@ -170,7 +150,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var,
ProtoEncodeHelper e(static_cast<char*>(buf), 1024); ProtoEncodeHelper e(static_cast<char*>(buf), 1024);
e.WriteRawBytes(std::string(header.data(), header.size())); e.WriteRawBytes(std::string(header.data(), header.size()));
e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size);
// steal reference of tensor data // steal reference of tensor data
::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows
int num_slices = 2; // only SelectedRows have rows buffer int num_slices = 2; // only SelectedRows have rows buffer
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册