diff --git a/paddle/fluid/framework/details/reduce_op_handle.cc b/paddle/fluid/framework/details/reduce_op_handle.cc index 85d8abc91004b6a98a1a4220cd5425ad2baa6fa1..7a5f7de57ef20b4b909894ff8d742a65ea05874d 100644 --- a/paddle/fluid/framework/details/reduce_op_handle.cc +++ b/paddle/fluid/framework/details/reduce_op_handle.cc @@ -218,18 +218,18 @@ void ReduceOpHandle::RunImpl() { } #if defined PADDLE_WITH_CUDA && defined PADDLE_WITH_DISTRIBUTE - if (framework::IsType(in_selected_rows[0]->value().type())) { + if (in_selected_rows[0]->value().type() == + framework::proto::VarType::FP32) { GatherSelectedRows( in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p, out_var->GetMutable()); - } else if (framework::IsType( - in_selected_rows[0]->value().type())) { + } else if (in_selected_rows[0]->value().type() == + framework::proto::VarType::FP64) { GatherSelectedRows( in_selected_rows, in_places, dev_ctxes_, out_var_handle, t_out_p, out_var->GetMutable()); } else { - PADDLE_ENFORCE(false, - "only support double or float when gahter SelectedRows"); + PADDLE_THROW("only support double or float when gather SelectedRows"); } #endif }); diff --git a/paddle/fluid/operators/distributed/grpc_serde.cc b/paddle/fluid/operators/distributed/grpc_serde.cc index 31fac2133cf159719474207407c52bb96e80e131..94bf0a113be62613e081b7bf05763e48a0a6124c 100644 --- a/paddle/fluid/operators/distributed/grpc_serde.cc +++ b/paddle/fluid/operators/distributed/grpc_serde.cc @@ -122,8 +122,7 @@ void SerializeToByteBuffer(const std::string& name, framework::Variable* var, if (var->IsType()) { auto* slr = var->GetMutable(); ProtoEncodeHelper e2(static_cast(buf), 128); - size_t rows_memory_size = - slr->rows().size() * framework::SizeOfType(typeid(int64_t)); + size_t rows_memory_size = slr->rows().size() * sizeof(int64_t); e2.WriteVarlengthBeginning(VarMsg::kRowsFieldNumber, rows_memory_size); slices[2] = ::grpc::Slice(e2.size()); memcpy(const_cast(slices[2].begin()), e2.data(), e2.size()); diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.cc b/paddle/fluid/operators/distributed/sendrecvop_utils.cc index 6ba883ba01f73bd941e7b1b58afb777e2588764b..5fd42e884ab4704a1bf63ae7e2b12d7c11d585e8 100644 --- a/paddle/fluid/operators/distributed/sendrecvop_utils.cc +++ b/paddle/fluid/operators/distributed/sendrecvop_utils.cc @@ -61,8 +61,7 @@ TensorPayload GetTensorPayload(framework::Variable* var, auto tensor = var->Get(); // FIXME(wuyi): data types in send_recv.proto is copied from // framework.proto - request->set_data_type( - static_cast(framework::ToDataType(tensor.type()))); + request->set_data_type(static_cast(tensor.type())); for (auto& dim : framework::vectorize(tensor.dims())) { request->add_dims(dim); } @@ -83,8 +82,7 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var, const platform::DeviceContext& ctx, VarMsg* request) { auto* slr = var->GetMutable(); - request->set_data_type( - static_cast(framework::ToDataType(slr->value().type()))); + request->set_data_type(static_cast(slr->value().type())); request->set_lod_level(0); request->set_slr_height(slr->height()); diff --git a/paddle/fluid/operators/distributed/sendrecvop_utils.h b/paddle/fluid/operators/distributed/sendrecvop_utils.h index 523e56fe3e4eaf9d4faa8f0bf91d0b6075e0b0e5..710c83916601756aa81d98c2a829be348ebd2daa 100644 --- a/paddle/fluid/operators/distributed/sendrecvop_utils.h +++ b/paddle/fluid/operators/distributed/sendrecvop_utils.h @@ -58,18 +58,19 @@ TensorPayload GetSelectedRowsPayload(framework::Variable* var, const platform::DeviceContext& ctx, VarMsg* request); -inline std::type_index ToTypeIndex(sendrecv::VariableMessage::Type type) { +inline framework::proto::VarType::Type ToVarType( + sendrecv::VariableMessage::Type type) { switch (type) { case sendrecv::VariableMessage::FP32: - return typeid(float); // NOLINT + return framework::proto::VarType::FP32; // NOLINT case sendrecv::VariableMessage::FP64: - return typeid(double); // NOLINT + return framework::proto::VarType::FP64; // NOLINT case sendrecv::VariableMessage::INT32: - return typeid(int); // NOLINT + return framework::proto::VarType::INT32; // NOLINT case sendrecv::VariableMessage::INT64: - return typeid(int64_t); // NOLINT + return framework::proto::VarType::INT64; // NOLINT case sendrecv::VariableMessage::BOOL: - return typeid(bool); // NOLINT + return framework::proto::VarType::BOOL; // NOLINT default: PADDLE_THROW("Not support type %d", type); } diff --git a/paddle/fluid/operators/distributed/variable_response.cc b/paddle/fluid/operators/distributed/variable_response.cc index 5b2be04e6a1656f79a8e2b3a6fa19c847e81b5cb..921c96b5839a37a4024b17b9ff71442f31b5b8f0 100644 --- a/paddle/fluid/operators/distributed/variable_response.cc +++ b/paddle/fluid/operators/distributed/variable_response.cc @@ -114,7 +114,7 @@ bool VariableResponse::CopyLodTensorData( tensor->set_lod(lod); void* tensor_data = - tensor->mutable_data(ctx.GetPlace(), ToTypeIndex(meta_.data_type())); + tensor->mutable_data(ctx.GetPlace(), ToVarType(meta_.data_type())); VLOG(6) << "Tensor.memory_size = " << tensor->memory_size() << ", Buffer Size = " << length; @@ -139,13 +139,13 @@ bool VariableResponse::CopySelectRowsTensorData( slr->set_height(meta_.slr_height()); auto* tensor = slr->mutable_value(); tensor->Resize(dims); - PADDLE_ENFORCE_EQ(static_cast(tensor->numel()), - length / framework::SizeOfType( - paddle::operators::distributed::ToTypeIndex( - meta_.data_type()))); + PADDLE_ENFORCE_EQ( + static_cast(tensor->numel()), + length / framework::SizeOfType(paddle::operators::distributed::ToVarType( + meta_.data_type()))); void* tensor_data = tensor->mutable_data( ctx.GetPlace(), - paddle::operators::distributed::ToTypeIndex(meta_.data_type())); + paddle::operators::distributed::ToVarType(meta_.data_type())); if (!ReadRaw(input, ctx, tensor->place(), tensor_data, length)) { return false; @@ -159,8 +159,7 @@ bool VariableResponse::CopySelectRowsData( const platform::DeviceContext& ctx, int length) { auto* slr = GetVar()->GetMutable(); slr->mutable_rows()->clear(); - slr->mutable_rows()->resize(length / - framework::SizeOfType(typeid(int64_t))); // int64 + slr->mutable_rows()->resize(length / sizeof(int64_t)); // int64 int64_t* rows_data = slr->mutable_rows()->data(); // copy rows CPU data, GPU data will be copied lazily. diff --git a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc index 252a63cb605f65e8572281a05e884fb8b020a820..da0185b8c492eeb694902b46c871c44cd060d438 100644 --- a/paddle/fluid/operators/distributed_ops/merge_ids_op.cc +++ b/paddle/fluid/operators/distributed_ops/merge_ids_op.cc @@ -108,9 +108,7 @@ class MergeIdsOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.MultiInput("X").front()->type()), - ctx.GetPlace()); + ctx.MultiInput("X").front()->type(), ctx.GetPlace()); } }; diff --git a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc index 98b0af7688b928f21573247b327bee1d22a73f17..7e16e6ff66b603634aa7cd26f71a4f2d3159c4e4 100644 --- a/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc +++ b/paddle/fluid/operators/distributed_ops/ref_by_trainer_id_op.cc @@ -42,9 +42,7 @@ class RefByTrainerIdOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::ToDataType( - ctx.MultiInput("X")[0]->type()), - ctx.GetPlace()); + ctx.MultiInput("X")[0]->type(), ctx.GetPlace()); } };