Unverified · Commit eefe5feb · Authored by: Leo Chen · Committed by: GitHub

[pten] fit get all register op kernels (#39288)

* upgrade _get_all_register_op_kernels

* add ut

* support xpu/npu

* fix device id

* enhance TransToFluidPlace

* fix compile
Parent 92253f11
@@ -535,9 +535,9 @@ class OperatorWithKernel : public OperatorBase {
   bool SupportGPU() const override {
     auto pten_kernels = pten::KernelFactory::Instance().SelectKernelMap(
         pten::TransToPtenKernelName(type_));
-    auto has_pten_kernel = std::any_of(
-        pten_kernels.begin(), pten_kernels.end(),
-        [](pten::KernelFactory::KernelKeyMap::const_reference kern_pair) {
+    auto has_pten_kernel =
+        std::any_of(pten_kernels.begin(), pten_kernels.end(),
+                    [](pten::KernelKeyMap::const_reference kern_pair) {
           return kern_pair.first.backend() == pten::Backend::GPU;
         });
     if (has_pten_kernel) {
......
@@ -60,7 +60,8 @@ OpKernelType TransPtenKernelKeyToOpKernelType(
     const pten::KernelKey& kernel_key) {
   proto::VarType::Type data_type =
       pten::TransToProtoVarType(kernel_key.dtype());
-  platform::Place place = pten::TransToFluidPlace(kernel_key.backend());
+  // no need to set current device id here
+  platform::Place place = pten::TransToFluidPlace(kernel_key.backend(), false);
   DataLayout data_layout = kernel_key.layout();
   LibraryType library_type = LibraryType::kPlain;
   if (kernel_key.backend() == pten::Backend::MKLDNN) {
......
@@ -75,6 +75,7 @@ limitations under the License. */
 #include "paddle/fluid/platform/place.h"
 #include "paddle/fluid/platform/profiler.h"
 #include "paddle/fluid/pybind/cuda_streams_py.h"
+#include "paddle/pten/core/compat/convert_utils.h"
 #include "paddle/pten/core/lod_utils.h"
 #ifndef PADDLE_ON_INFERENCE
 #include "paddle/fluid/pybind/eager.h"
@@ -715,21 +716,61 @@ PYBIND11_MODULE(core_noavx, m) {
   m.def("_get_use_default_grad_op_desc_maker_ops",
         [] { return OpInfoMap::Instance().GetUseDefaultGradOpDescMakerOps(); });
-  m.def("_get_all_register_op_kernels", [] {
-    auto &all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
-    std::unordered_map<std::string, std::vector<std::string>> all_kernels_info;
-    for (auto &kernel_pair : all_kernels) {
-      auto op_type = kernel_pair.first;
-      std::vector<std::string> kernel_types;
-      for (auto &info_pair : kernel_pair.second) {
-        paddle::framework::OpKernelType kernel_type = info_pair.first;
-        kernel_types.push_back(
-            paddle::framework::KernelTypeToString(kernel_type));
-      }
-      all_kernels_info.emplace(op_type, kernel_types);
-    }
-    return all_kernels_info;
-  });
+  m.def(
+      "_get_all_register_op_kernels",
+      [](const std::string &lib) {
+        std::unordered_map<std::string, std::vector<std::string>>
+            all_kernels_info;
+        if (lib == "fluid" || lib == "all") {
+          auto &all_kernels =
+              paddle::framework::OperatorWithKernel::AllOpKernels();
+          for (auto &kernel_pair : all_kernels) {
+            auto op_type = kernel_pair.first;
+            std::vector<std::string> kernel_types;
+            for (auto &info_pair : kernel_pair.second) {
+              paddle::framework::OpKernelType kernel_type = info_pair.first;
+              kernel_types.emplace_back(
+                  paddle::framework::KernelTypeToString(kernel_type));
+            }
+            all_kernels_info.emplace(op_type, kernel_types);
+          }
+        }
+        if (lib == "pten" || lib == "all") {
+          auto pten_kernels = pten::KernelFactory::Instance().kernels();
+          for (auto &kernel_pair : pten_kernels) {
+            auto op_type = pten::TransToFluidOpName(kernel_pair.first);
+            std::vector<std::string> kernel_types;
+            for (auto &info_pair : kernel_pair.second) {
+              framework::OpKernelType kernel_type =
+                  framework::TransPtenKernelKeyToOpKernelType(info_pair.first);
+              auto kernel_type_str = framework::KernelTypeToString(kernel_type);
+              if (all_kernels_info.count(op_type)) {
+                if (std::find(all_kernels_info[op_type].begin(),
+                              all_kernels_info[op_type].end(),
+                              kernel_type_str) ==
+                    all_kernels_info[op_type].end()) {
+                  all_kernels_info[op_type].emplace_back(kernel_type_str);
+                }
+              } else {
+                kernel_types.emplace_back(kernel_type_str);
+              }
+            }
+            if (!kernel_types.empty()) {
+              all_kernels_info.emplace(op_type, kernel_types);
+            }
+          }
+        }
+        return all_kernels_info;
+      },
+      py::arg("lib") = "all",
+      R"DOC(
+           Return the registered kernels in paddle.
+
+           Args:
+               lib (string): the library; one of 'pten', 'fluid', or 'all'.
+           )DOC");

   // NOTE(zjl): ctest would load environment variables at the beginning even
   // though we have not `import paddle.fluid as fluid`. So we add this API
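
In the binding above, the fluid branch collects each op's kernel-type strings first; the pten branch then merges its kernels into the same map after translating kernel names back to fluid op names, skipping strings the fluid registry already reported. A rough Python sketch of that merge, using hypothetical stand-in dicts (not part of the patch):

# Rough sketch of the merging done by _get_all_register_op_kernels above.
# fluid_kernels / pten_kernels are hypothetical stand-ins for the two
# registries, already keyed by fluid op name with stringified kernel types.
def merge_registered_kernels(fluid_kernels, pten_kernels, lib="all"):
    all_kernels_info = {}
    if lib in ("fluid", "all"):
        for op_type, kernel_types in fluid_kernels.items():
            all_kernels_info[op_type] = list(kernel_types)
    if lib in ("pten", "all"):
        for op_type, kernel_types in pten_kernels.items():
            merged = all_kernels_info.setdefault(op_type, [])
            for kernel_type in kernel_types:
                if kernel_type not in merged:  # de-duplicate against fluid entries
                    merged.append(kernel_type)
    return all_kernels_info

# Example: 'reshape' only exists in fluid, 'sign' only in pten.
print(merge_registered_kernels({"reshape": ["CPU"]}, {"sign": ["CPU"]}))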
@@ -973,7 +1014,8 @@ PYBIND11_MODULE(core_noavx, m) {
           PADDLE_ENFORCE_EQ(
               CheckLoD(new_offset_lod, -1), true,
               platform::errors::InvalidArgument(
-                  "The provided recursive_sequence_lengths info is invalid, "
+                  "The provided recursive_sequence_lengths info is "
+                  "invalid, "
                   "the LoD converted by recursive_sequence_lengths is %s",
                   new_lod));
           new (&instance) framework::Tensor(new_offset_lod);
@@ -1035,7 +1077,8 @@ PYBIND11_MODULE(core_noavx, m) {
           PADDLE_ENFORCE_EQ(
               CheckLoD(new_offset_lod, vectorize(self.dims()).front()), true,
               platform::errors::InvalidArgument(
-                  "The provided recursive_sequence_lengths info is invalid, "
+                  "The provided recursive_sequence_lengths info is "
+                  "invalid, "
                   "the LoD converted by recursive_sequence_lengths is "
                   "%s",
                   new_lod));
@@ -2110,17 +2153,17 @@ All parameter, weight, gradient are variables in Paddle.
       .def("__str__", string::to_string<const platform::Place &>);

   py::class_<OperatorBase>(m, "Operator")
-      .def_static(
-          "create",
+      .def_static("create",
                   [](py::bytes protobin) {
                     proto::OpDesc desc;
-                    PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin), true,
+                    PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin),
+                                      true,
                                       platform::errors::InvalidArgument(
                                           "Cannot parse user input to OpDesc"));
-                    PADDLE_ENFORCE_EQ(
-                        desc.IsInitialized(), true,
+                    PADDLE_ENFORCE_EQ(desc.IsInitialized(), true,
                                       platform::errors::InvalidArgument(
-                                          "The provided OpDesc is not initialized, the reason is: %s",
+                                          "The provided OpDesc is not "
+                                          "initialized, the reason is: %s",
                                           desc.InitializationErrorString()));
                     return OpRegistry::CreateOp(desc);
                   })
@@ -2705,8 +2748,8 @@ All parameter, weight, gradient are variables in Paddle.
   m.def("register_pass", [](const std::string &pass_type, py::object callable) {
     PADDLE_ENFORCE_EQ(
         framework::ir::PassRegistry::Instance().Has(pass_type), false,
-        platform::errors::AlreadyExists(
-            "Pass '%s' is registered more than once. Please use another name.",
+        platform::errors::AlreadyExists("Pass '%s' is registered more than "
+                                        "once. Please use another name.",
                                         pass_type));
     callable.inc_ref();
     framework::ir::PassRegistry::Instance().Insert(pass_type, [pass_type,
......
@@ -3,6 +3,10 @@ if(WITH_GPU)
   cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
 elseif(WITH_ROCM)
   cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
+elseif(WITH_XPU)
+  cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place xpu_info)
+elseif(WITH_ASCEND_CL)
+  cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place npu_info)
 else()
   cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place)
 endif()
......
@@ -18,6 +18,8 @@ limitations under the License. */
 // See Note [ Why still include the fluid headers? ]
 #include "paddle/fluid/platform/device/gpu/gpu_info.h"
+#include "paddle/fluid/platform/device/npu/npu_info.h"
+#include "paddle/fluid/platform/device/xpu/xpu_info.h"

 namespace pten {
@@ -66,15 +68,18 @@ paddle::experimental::DataType TransToPtenDataType(
   }
 }

-paddle::platform::Place TransToFluidPlace(const Backend& backend) {
-  // TODO(chenweihang): add other trans cases later
+paddle::platform::Place TransToFluidPlace(const Backend& backend,
+                                          bool set_device_id) {
+  // NOTE(zhiqiu): GetCurrentDeviceId does not always succeed, and the device
+  // id is not always needed, so a set_device_id parameter is added here.
   switch (backend) {
     case pten::Backend::CPU:
       return paddle::platform::CPUPlace();
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     case pten::Backend::GPU:
       return paddle::platform::CUDAPlace(
-          paddle::platform::GetCurrentDeviceId());
+          set_device_id ? paddle::platform::GetCurrentDeviceId() : 0);
 #endif
 #ifdef PADDLE_WITH_MKLDNN
     case pten::Backend::MKLDNN:
@@ -83,7 +88,17 @@ paddle::platform::Place TransToFluidPlace(const Backend& backend) {
 #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
     case pten::Backend::CUDNN:
       return paddle::platform::CUDAPlace(
-          paddle::platform::GetCurrentDeviceId());
+          set_device_id ? paddle::platform::GetCurrentDeviceId() : 0);
+#endif
+#if defined(PADDLE_WITH_XPU)
+    case pten::Backend::XPU:
+      return paddle::platform::XPUPlace(
+          set_device_id ? paddle::platform::GetXPUCurrentDeviceId() : 0);
+#endif
+#if defined(PADDLE_WITH_ASCEND_CL)
+    case pten::Backend::NPU:
+      return paddle::platform::NPUPlace(
+          set_device_id ? paddle::platform::GetCurrentNPUDeviceId() : 0);
 #endif
     default:
       PADDLE_THROW(paddle::platform::errors::Unimplemented(
@@ -228,4 +243,16 @@ const std::string& TransToPtenKernelName(const std::string& fluid_op_name) {
   return fluid_op_name;
 }

+const std::string& TransToFluidOpName(const std::string& pten_kernel_name) {
+  auto it = std::find_if(kernel_alias_name_map.begin(),
+                         kernel_alias_name_map.end(),
+                         [&pten_kernel_name](const auto& pair) {
+                           return pair.second == pten_kernel_name;
+                         });
+  if (it != kernel_alias_name_map.end()) {
+    return it->first;
+  }
+  return pten_kernel_name;
+}
+
 }  // namespace pten
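
The TransToFluidOpName added above is the inverse of TransToPtenKernelName: it does a linear reverse lookup in kernel_alias_name_map (fluid op name mapped to pten kernel name) and falls back to returning the name unchanged. A minimal Python sketch of the same lookup, with a hypothetical map entry:

# Hypothetical alias map; the real kernel_alias_name_map lives in C++ and its
# contents are not reproduced here.
kernel_alias_name_map = {"fill_any_like": "full_like"}

def trans_to_fluid_op_name(pten_kernel_name):
    # Reverse lookup: find the fluid op whose pten alias matches; otherwise the
    # kernel name is assumed to equal the fluid op name and is returned as-is.
    for fluid_name, pten_name in kernel_alias_name_map.items():
        if pten_name == pten_kernel_name:
            return fluid_name
    return pten_kernel_name

print(trans_to_fluid_op_name("full_like"))  # -> "fill_any_like" (hypothetical alias)
print(trans_to_fluid_op_name("sign"))       # -> "sign" (no alias, unchanged)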
@@ -28,12 +28,14 @@ limitations under the License. */
 namespace pten {

 const std::string& TransToPtenKernelName(const std::string& fluid_op_name);
+const std::string& TransToFluidOpName(const std::string& pten_kernel_name);

 Backend TransToPtenBackend(const pten::Place& place);
 DataType TransToPtenDataType(
     const paddle::framework::proto::VarType::Type& dtype);

-pten::Place TransToFluidPlace(const Backend& backend);
+paddle::platform::Place TransToFluidPlace(const Backend& backend,
+                                          bool set_device_id = true);

 paddle::framework::proto::VarType::Type TransToProtoVarType(
     const DataType& dtype);
......
@@ -50,11 +50,11 @@ Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
   return kernel_iter->second;
 }

-paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>
-KernelFactory::SelectKernelMap(const std::string& kernel_name) const {
+KernelKeyMap KernelFactory::SelectKernelMap(
+    const std::string& kernel_name) const {
   auto iter = kernels_.find(kernel_name);
   if (iter == kernels_.end()) {
-    return paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>();
+    return KernelKeyMap();
   }
   return iter->second;
 }
......
@@ -196,6 +196,10 @@ class Kernel {
   KernelArgsDef args_def_;
 };

+using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
+
+using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;
+
 /**
  * Note: Each Computation need a basic kernel map that named by kernel_name.
  * Such as for scale op, KernelMap contains a `scale` kernel map,
@@ -204,11 +208,6 @@ class Kernel {
  */
 class KernelFactory {
  public:
-  using KernelKeyMap =
-      paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
-  using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;
-
   static KernelFactory& Instance();

   KernelNameMap& kernels() { return kernels_; }
......
@@ -136,6 +136,13 @@ struct KernelRegistrar {
     for (size_t dtype = static_cast<size_t>(DataType::BOOL);
          dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
          dtype++) {
+      // NOTE(zhiqiu): skip these types because fluid has no kernels registered
+      // for them.
+      if (dtype == static_cast<size_t>(DataType::UINT32) ||
+          dtype == static_cast<size_t>(DataType::UINT64) ||
+          dtype == static_cast<size_t>(DataType::UINT16)) {
+        continue;
+      }
       ConstructKernel(kernel_name_cstr,
                       backend,
                       layout,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function

import unittest

from paddle.fluid import core
from paddle import compat as cpt


class TestGetAllRegisteredOpKernels(unittest.TestCase):
    # reshape kernel is in fluid while not in pten
    def test_pten_kernels(self):
        self.assertTrue(core._get_all_register_op_kernels('pten')['sign'])
        with self.assertRaises(KeyError):
            core._get_all_register_op_kernels('pten')['reshape']

    # sign kernel is removed from fluid and added into pten
    def test_fluid_kernels(self):
        self.assertTrue(core._get_all_register_op_kernels('fluid')['reshape'])
        with self.assertRaises(KeyError):
            core._get_all_register_op_kernels('fluid')['sign']

    def test_all_kernels(self):
        self.assertTrue(core._get_all_register_op_kernels('all')['reshape'])
        self.assertTrue(core._get_all_register_op_kernels('all')['sign'])

        self.assertTrue(core._get_all_register_op_kernels()['reshape'])
        self.assertTrue(core._get_all_register_op_kernels()['sign'])


if __name__ == '__main__':
    unittest.main()
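
For reference, a minimal usage sketch of the binding exercised by the test above (assuming a Paddle build that includes this change):

from paddle.fluid import core

# lib defaults to 'all'; 'pten' and 'fluid' restrict the result to one registry.
all_kernels = core._get_all_register_op_kernels()
pten_kernels = core._get_all_register_op_kernels('pten')
fluid_kernels = core._get_all_register_op_kernels('fluid')

# Each entry maps an op name to a list of kernel-type strings.
print(len(all_kernels))
print(pten_kernels.get('sign', []))
print(fluid_kernels.get('reshape', []))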