From 602aa433222d2958b00851e15d2c52ec23a09bb2 Mon Sep 17 00:00:00 2001 From: typhoonzero Date: Thu, 10 May 2018 11:14:28 +0800 Subject: [PATCH] cast data type --- .../operators/detail/sendrecvop_utils.cc | 31 +++---------------- 1 file changed, 5 insertions(+), 26 deletions(-) diff --git a/paddle/fluid/operators/detail/sendrecvop_utils.cc b/paddle/fluid/operators/detail/sendrecvop_utils.cc index e77c38f59a8..1a8a1af20fa 100644 --- a/paddle/fluid/operators/detail/sendrecvop_utils.cc +++ b/paddle/fluid/operators/detail/sendrecvop_utils.cc @@ -31,35 +31,14 @@ namespace detail { using VarMsg = sendrecv::VariableMessage; -VarMsg::Type DataTypeToEnum(std::type_index type) { - if (typeid(platform::float16).hash_code() == type.hash_code()) { - return VarMsg::FP16; - } else if (typeid(const float).hash_code() == type.hash_code()) { - // CPPLint complains Using C-style cast. Use static_cast() instead - // One fix to this is to replace float with const float because - // typeid(T) == typeid(const T) - // http://en.cppreference.com/w/cpp/language/typeid - return VarMsg::FP32; - } else if (typeid(const double).hash_code() == type.hash_code()) { - return VarMsg::FP64; - } else if (typeid(const int).hash_code() == type.hash_code()) { - return VarMsg::INT32; - } else if (typeid(const int64_t).hash_code() == type.hash_code()) { - return VarMsg::INT64; - } else if (typeid(const bool).hash_code() == type.hash_code()) { - return VarMsg::BOOL; - } else { - PADDLE_THROW("Not supported"); - } -} - void GetTensorPayload(framework::Variable* var, const platform::DeviceContext& ctx, VarMsg* request, void** payload, size_t* payload_size) { auto tensor = var->Get(); - // FIXME(wuyi): data types in send_recv.proto is not synced with + // FIXME(wuyi): data types in send_recv.proto is copied from // framework.proto - request->set_data_type(DataTypeToEnum(tensor.type())); + request->set_data_type( + static_cast(framework::ToDataType(tensor.type()))); for (auto& dim : framework::vectorize(tensor.dims())) { request->add_dims(dim); } @@ -96,7 +75,8 @@ void GetSelectedRowsPayload(framework::Variable* var, const platform::DeviceContext& ctx, VarMsg* request, void** payload, size_t* payload_size) { auto* slr = var->GetMutable(); - request->set_data_type(DataTypeToEnum(slr->value().type())); + request->set_data_type( + static_cast(framework::ToDataType(slr->value().type()))); request->set_lod_level(0); request->set_slr_height(slr->height()); @@ -170,7 +150,6 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, ProtoEncodeHelper e(static_cast(buf), 1024); e.WriteRawBytes(std::string(header.data(), header.size())); e.WriteVarlengthBeginning(VarMsg::kSerializedFieldNumber, payload_size); - // steal reference of tensor data ::grpc::Slice slices[4]; // metadata, tensor, rows meta, rows int num_slices = 2; // only SelectedRows have rows buffer -- GitLab