diff --git a/paddle/fluid/framework/parallel_executor.cc b/paddle/fluid/framework/parallel_executor.cc index fce1bf47244316bd099d227789afcb4cea20a469..991a0c8238cff60a51fb9a753a713d139064ff58 100644 --- a/paddle/fluid/framework/parallel_executor.cc +++ b/paddle/fluid/framework/parallel_executor.cc @@ -20,6 +20,7 @@ limitations under the License. */ #include "op_registry.h" #include "paddle/fluid/framework/feed_fetch_type.h" #include "paddle/fluid/operators/math/concat.h" +#include "paddle/fluid/platform/nccl_helper.h" namespace paddle { namespace framework { @@ -299,19 +300,6 @@ class ParallelExecutorPrivate { std::unique_ptr exception_; }; -// TODO(yy): Move this function somewhere -ncclDataType_t ToNCCLDataType(std::type_index type) { - if (type == typeid(float)) { // NOLINT - return ncclFloat; - } else if (type == typeid(double)) { // NOLINT - return ncclDouble; - } else if (type == typeid(int)) { // NOLINT - return ncclInt; - } else { - PADDLE_THROW("Not supported"); - } -} - static std::mutex g_nccl_mtx_; struct NCCLAllReduceOpHandle : public OpHandle { @@ -356,7 +344,7 @@ struct NCCLAllReduceOpHandle : public OpHandle { } if (dtype == -1) { - dtype = ToNCCLDataType(lod_tensor.type()); + dtype = platform::ToNCCLDataType(lod_tensor.type()); } if (numel == 0) { @@ -629,7 +617,7 @@ void ParallelExecutor::BCastParamsToGPUs( if (var_desc->GetType() == proto::VarType::LOD_TENSOR) { auto &main_tensor = main_scope->FindVar(var_desc->Name())->Get(); - ncclDataType_t data_type = ToNCCLDataType(main_tensor.type()); + ncclDataType_t data_type = platform::ToNCCLDataType(main_tensor.type()); auto &dims = main_tensor.dims(); size_t numel = main_tensor.numel(); diff --git a/paddle/fluid/platform/nccl_helper.h b/paddle/fluid/platform/nccl_helper.h new file mode 100644 index 0000000000000000000000000000000000000000..e20f99bc6bc30298dc0ab6bb37adb6b855b6b75e --- /dev/null +++ b/paddle/fluid/platform/nccl_helper.h @@ -0,0 +1,37 @@ +// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include "paddle/fluid/platform/dynload/nccl.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { + +inline ncclDataType_t ToNCCLDataType(std::type_index type) { + if (type == typeid(float)) { // NOLINT + return ncclFloat; + } else if (type == typeid(double)) { // NOLINT + return ncclDouble; + } else if (type == typeid(int)) { // NOLINT + return ncclInt; + } else { + PADDLE_THROW("Not supported"); + } +} + +} // namespace platform +} // namespace paddle