Unverified commit eefe5feb authored by Leo Chen, committed by GitHub

[pten] fit get all register op kernels (#39288)

* upgrade _get_all_register_op_kernels

* add ut

* support xpu/npu

* fix device id

* enhance TransToFluidPlace

* fix compile
Parent 92253f11
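This change upgrades the Python binding `_get_all_register_op_kernels` to take a `lib` argument ('pten', 'fluid', or 'all', defaulting to 'all') and to return a dict mapping each op name to a list of kernel-type strings. A minimal usage sketch, based on the docstring and the unit test added in this diff (the exact kernel-type strings depend on the build configuration):

from paddle.fluid import core

# kernels registered through the pten library only
pten_kernels = core._get_all_register_op_kernels('pten')

# kernels registered through the fluid library only
fluid_kernels = core._get_all_register_op_kernels('fluid')

# kernels from both libraries, merged; 'all' is the default
all_kernels = core._get_all_register_op_kernels()

# each value is a list of kernel-type strings for that op
print(all_kernels['sign'])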
......@@ -535,9 +535,9 @@ class OperatorWithKernel : public OperatorBase {
bool SupportGPU() const override {
auto pten_kernels = pten::KernelFactory::Instance().SelectKernelMap(
pten::TransToPtenKernelName(type_));
auto has_pten_kernel = std::any_of(
pten_kernels.begin(), pten_kernels.end(),
[](pten::KernelFactory::KernelKeyMap::const_reference kern_pair) {
auto has_pten_kernel =
std::any_of(pten_kernels.begin(), pten_kernels.end(),
[](pten::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == pten::Backend::GPU;
});
if (has_pten_kernel) {
......
......@@ -60,7 +60,8 @@ OpKernelType TransPtenKernelKeyToOpKernelType(
const pten::KernelKey& kernel_key) {
proto::VarType::Type data_type =
pten::TransToProtoVarType(kernel_key.dtype());
platform::Place place = pten::TransToFluidPlace(kernel_key.backend());
// no need to set current device id here
platform::Place place = pten::TransToFluidPlace(kernel_key.backend(), false);
DataLayout data_layout = kernel_key.layout();
LibraryType library_type = LibraryType::kPlain;
if (kernel_key.backend() == pten::Backend::MKLDNN) {
......
......@@ -75,6 +75,7 @@ limitations under the License. */
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/cuda_streams_py.h"
#include "paddle/pten/core/compat/convert_utils.h"
#include "paddle/pten/core/lod_utils.h"
#ifndef PADDLE_ON_INFERENCE
#include "paddle/fluid/pybind/eager.h"
......@@ -715,21 +716,61 @@ PYBIND11_MODULE(core_noavx, m) {
m.def("_get_use_default_grad_op_desc_maker_ops",
[] { return OpInfoMap::Instance().GetUseDefaultGradOpDescMakerOps(); });
m.def("_get_all_register_op_kernels", [] {
auto &all_kernels = paddle::framework::OperatorWithKernel::AllOpKernels();
std::unordered_map<std::string, std::vector<std::string>> all_kernels_info;
m.def(
"_get_all_register_op_kernels",
[](const std::string &lib) {
std::unordered_map<std::string, std::vector<std::string>>
all_kernels_info;
if (lib == "fluid" || lib == "all") {
auto &all_kernels =
paddle::framework::OperatorWithKernel::AllOpKernels();
for (auto &kernel_pair : all_kernels) {
auto op_type = kernel_pair.first;
std::vector<std::string> kernel_types;
for (auto &info_pair : kernel_pair.second) {
paddle::framework::OpKernelType kernel_type = info_pair.first;
kernel_types.push_back(
kernel_types.emplace_back(
paddle::framework::KernelTypeToString(kernel_type));
}
all_kernels_info.emplace(op_type, kernel_types);
}
}
if (lib == "pten" || lib == "all") {
auto pten_kernels = pten::KernelFactory::Instance().kernels();
for (auto &kernel_pair : pten_kernels) {
auto op_type = pten::TransToFluidOpName(kernel_pair.first);
std::vector<std::string> kernel_types;
for (auto &info_pair : kernel_pair.second) {
framework::OpKernelType kernel_type =
framework::TransPtenKernelKeyToOpKernelType(info_pair.first);
auto kernel_type_str = framework::KernelTypeToString(kernel_type);
if (all_kernels_info.count(op_type)) {
if (std::find(all_kernels_info[op_type].begin(),
all_kernels_info[op_type].end(),
kernel_type_str) ==
all_kernels_info[op_type].end()) {
all_kernels_info[op_type].emplace_back(kernel_type_str);
}
} else {
kernel_types.emplace_back(kernel_type_str);
}
}
if (!kernel_types.empty()) {
all_kernels_info.emplace(op_type, kernel_types);
}
}
}
return all_kernels_info;
});
},
py::arg("lib") = "all",
R"DOC(
Return the registered kernels in paddle.
Args:
lib[string]: the library, one of 'pten', 'fluid' or 'all'.
)DOC");
// NOTE(zjl): ctest would load environment variables at the beginning even
// though we have not `import paddle.fluid as fluid`. So we add this API
......@@ -973,7 +1014,8 @@ PYBIND11_MODULE(core_noavx, m) {
PADDLE_ENFORCE_EQ(
CheckLoD(new_offset_lod, -1), true,
platform::errors::InvalidArgument(
"The provided recursive_sequence_lengths info is invalid, "
"The provided recursive_sequence_lengths info is "
"invalid, "
"the LoD converted by recursive_sequence_lengths is %s",
new_lod));
new (&instance) framework::Tensor(new_offset_lod);
......@@ -1035,7 +1077,8 @@ PYBIND11_MODULE(core_noavx, m) {
PADDLE_ENFORCE_EQ(
CheckLoD(new_offset_lod, vectorize(self.dims()).front()), true,
platform::errors::InvalidArgument(
"The provided recursive_sequence_lengths info is invalid, "
"The provided recursive_sequence_lengths info is "
"invalid, "
"the LoD converted by recursive_sequence_lengths is "
"%s",
new_lod));
......@@ -2110,17 +2153,17 @@ All parameter, weight, gradient are variables in Paddle.
.def("__str__", string::to_string<const platform::Place &>);
py::class_<OperatorBase>(m, "Operator")
.def_static(
"create",
.def_static("create",
[](py::bytes protobin) {
proto::OpDesc desc;
PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin), true,
PADDLE_ENFORCE_EQ(desc.ParsePartialFromString(protobin),
true,
platform::errors::InvalidArgument(
"Cannot parse user input to OpDesc"));
PADDLE_ENFORCE_EQ(
desc.IsInitialized(), true,
PADDLE_ENFORCE_EQ(desc.IsInitialized(), true,
platform::errors::InvalidArgument(
"The provided OpDesc is not initialized, the reason is: %s",
"The provided OpDesc is not "
"initialized, the reason is: %s",
desc.InitializationErrorString()));
return OpRegistry::CreateOp(desc);
})
......@@ -2705,8 +2748,8 @@ All parameter, weight, gradient are variables in Paddle.
m.def("register_pass", [](const std::string &pass_type, py::object callable) {
PADDLE_ENFORCE_EQ(
framework::ir::PassRegistry::Instance().Has(pass_type), false,
platform::errors::AlreadyExists(
"Pass '%s' is registered more than once. Please use another name.",
platform::errors::AlreadyExists("Pass '%s' is registered more than "
"once. Please use another name.",
pass_type));
callable.inc_ref();
framework::ir::PassRegistry::Instance().Insert(pass_type, [pass_type,
......
......@@ -3,6 +3,10 @@ if(WITH_GPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
elseif(WITH_ROCM)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place gpu_info)
elseif(WITH_XPU)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place xpu_info)
elseif(WITH_ASCEND_CL)
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place npu_info)
else()
cc_library(convert_utils SRCS convert_utils.cc DEPS data_type place)
endif()
......
......@@ -18,6 +18,8 @@ limitations under the License. */
// See Note [ Why still include the fluid headers? ]
#include "paddle/fluid/platform/device/gpu/gpu_info.h"
#include "paddle/fluid/platform/device/npu/npu_info.h"
#include "paddle/fluid/platform/device/xpu/xpu_info.h"
namespace pten {
......@@ -66,15 +68,18 @@ paddle::experimental::DataType TransToPtenDataType(
}
}
paddle::platform::Place TransToFluidPlace(const Backend& backend) {
// TODO(chenweihang): add other trans cases later
paddle::platform::Place TransToFluidPlace(const Backend& backend,
bool set_device_id) {
// NOTE(zhiqiu): GetCurrentDeviceId does not always succeed, and the device id
// is not always needed, so a set_device_id parameter is added here.
switch (backend) {
case pten::Backend::CPU:
return paddle::platform::CPUPlace();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
case pten::Backend::GPU:
return paddle::platform::CUDAPlace(
paddle::platform::GetCurrentDeviceId());
set_device_id ? paddle::platform::GetCurrentDeviceId() : 0);
#endif
#ifdef PADDLE_WITH_MKLDNN
case pten::Backend::MKLDNN:
......@@ -83,7 +88,17 @@ paddle::platform::Place TransToFluidPlace(const Backend& backend) {
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
case pten::Backend::CUDNN:
return paddle::platform::CUDAPlace(
paddle::platform::GetCurrentDeviceId());
set_device_id ? paddle::platform::GetCurrentDeviceId() : 0);
#endif
#if defined(PADDLE_WITH_XPU)
case pten::Backend::XPU:
return paddle::platform::XPUPlace(
set_device_id ? paddle::platform::GetXPUCurrentDeviceId() : 0);
#endif
#if defined(PADDLE_WITH_ASCEND_CL)
case pten::Backend::NPU:
return paddle::platform::NPUPlace(
set_device_id ? paddle::platform::GetCurrentNPUDeviceId() : 0);
#endif
default:
PADDLE_THROW(paddle::platform::errors::Unimplemented(
......@@ -228,4 +243,16 @@ const std::string& TransToPtenKernelName(const std::string& fluid_op_name) {
return fluid_op_name;
}
const std::string& TransToFluidOpName(const std::string& pten_kernel_name) {
auto it = std::find_if(kernel_alias_name_map.begin(),
kernel_alias_name_map.end(),
[&pten_kernel_name](const auto& pair) {
return pair.second == pten_kernel_name;
});
if (it != kernel_alias_name_map.end()) {
return it->first;
}
return pten_kernel_name;
}
} // namespace pten
......@@ -28,12 +28,14 @@ limitations under the License. */
namespace pten {
const std::string& TransToPtenKernelName(const std::string& fluid_op_name);
const std::string& TransToFluidOpName(const std::string& pten_kernel_name);
Backend TransToPtenBackend(const pten::Place& place);
DataType TransToPtenDataType(
const paddle::framework::proto::VarType::Type& dtype);
pten::Place TransToFluidPlace(const Backend& backend);
paddle::platform::Place TransToFluidPlace(const Backend& backend,
bool set_device_id = true);
paddle::framework::proto::VarType::Type TransToProtoVarType(
const DataType& dtype);
......
......@@ -50,11 +50,11 @@ Kernel KernelFactory::SelectKernel(const std::string& kernel_name,
return kernel_iter->second;
}
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>
KernelFactory::SelectKernelMap(const std::string& kernel_name) const {
KernelKeyMap KernelFactory::SelectKernelMap(
const std::string& kernel_name) const {
auto iter = kernels_.find(kernel_name);
if (iter == kernels_.end()) {
return paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>();
return KernelKeyMap();
}
return iter->second;
}
......
......@@ -196,6 +196,10 @@ class Kernel {
KernelArgsDef args_def_;
};
using KernelKeyMap = paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;
/**
* Note: Each computation needs a basic kernel map named by kernel_name.
* For example, for the scale op, KernelMap contains a `scale` kernel map,
......@@ -204,11 +208,6 @@ class Kernel {
*/
class KernelFactory {
public:
using KernelKeyMap =
paddle::flat_hash_map<KernelKey, Kernel, KernelKey::Hash>;
using KernelNameMap = paddle::flat_hash_map<std::string, KernelKeyMap>;
static KernelFactory& Instance();
KernelNameMap& kernels() { return kernels_; }
......
......@@ -136,6 +136,13 @@ struct KernelRegistrar {
for (size_t dtype = static_cast<size_t>(DataType::BOOL);
dtype != static_cast<size_t>(DataType::NUM_DATA_TYPES);
dtype++) {
// NOTE(zhiqiu): these types are skipped because fluid has no kernels
// registered for them.
if (dtype == static_cast<size_t>(DataType::UINT32) ||
dtype == static_cast<size_t>(DataType::UINT64) ||
dtype == static_cast<size_t>(DataType::UINT16)) {
continue;
}
ConstructKernel(kernel_name_cstr,
backend,
layout,
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
from paddle.fluid import core
from paddle import compat as cpt
class TestGetAllRegisteredOpKernels(unittest.TestCase):
# the reshape kernel is registered in fluid but not in pten
def test_pten_kernels(self):
self.assertTrue(core._get_all_register_op_kernels('pten')['sign'])
with self.assertRaises(KeyError):
core._get_all_register_op_kernels('pten')['reshape']
# the sign kernel has been removed from fluid and added to pten
def test_fluid_kernels(self):
self.assertTrue(core._get_all_register_op_kernels('fluid')['reshape'])
with self.assertRaises(KeyError):
core._get_all_register_op_kernels('fluid')['sign']
def test_all_kernels(self):
self.assertTrue(core._get_all_register_op_kernels('all')['reshape'])
self.assertTrue(core._get_all_register_op_kernels('all')['sign'])
self.assertTrue(core._get_all_register_op_kernels()['reshape'])
self.assertTrue(core._get_all_register_op_kernels()['sign'])
if __name__ == '__main__':
unittest.main()