Unverified commit 6beeafe7, authored by Chen Weihang, committed by GitHub

[CustomOp] Add more dispatch macros for users (#31058)

* add more dispatch macros

* add more dispatch macros

* add more tests

* revert unneeded change

* add timeout for test dispatch

* add float and complex test

* remove `and` macro
Parent d5323dab
@@ -18,6 +18,8 @@ limitations under the License. */
namespace paddle {
///////// Basic Macro ///////////
#define PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \
case enum_type: { \
using HINT = type; \
@@ -28,19 +30,139 @@ namespace paddle {
#define PD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \
PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__)
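// For reference (inferred; the tail of PD_PRIVATE_CASE_TYPE_USING_HINT is
// elided by the hunk above): with the default `data_t` hint, a case such as
//
//   PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__)
//
// expands to roughly
//
//   case ::paddle::DataType::FLOAT32: {
//     using data_t = float;
//     return __VA_ARGS__();
//   }
//
// so the functor handed to a dispatch macro can use `data_t` as the concrete
// C++ type selected at runtime.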
///////// Floating Dispatch Macro ///////////
#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
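// Illustrative usage (a sketch, not part of this commit): dispatching a
// templated CPU kernel on a tensor's runtime dtype; `scale_cpu_kernel` is a
// hypothetical helper named only for this example.
//
//   template <typename data_t>
//   void scale_cpu_kernel(const data_t* x, data_t* out, int64_t n) {
//     for (int64_t i = 0; i < n; ++i) {
//       out[i] = x[i] * static_cast<data_t>(2);
//     }
//   }
//
//   PD_DISPATCH_FLOATING_TYPES(x.type(), "scale_cpu_kernel", ([&] {
//     scale_cpu_kernel<data_t>(
//         x.data<data_t>(), out.mutable_data<data_t>(), x.size());
//   }));
//
// Any dtype other than FLOAT32/FLOAT64 falls through to the default branch
// and raises the std::runtime_error built above.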
///////// Integral Dispatch Macro ///////////
#define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Complex Dispatch Macro ///////////
#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
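// Note: inside the dispatched functor, `data_t` resolves to
// ::paddle::complex64 or ::paddle::complex128 (aliases of the platform
// complex types, declared in dtype.h below), so an element-wise kernel
// written against `data_t` also compiles for complex inputs provided the
// operations it performs are defined for those types.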
///////// Floating and Integral Dispatch Macro ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating and Complex Dispatch Macro ///////////
#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
///////// Floating, Integral and Complex Dispatch Macro ///////////
#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \
[&] { \
const auto& __dtype__ = TYPE; \
switch (__dtype__) { \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT32, float, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::FLOAT64, double, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT64, int64_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT8, int8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::UINT8, uint8_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT16, int16_t, \
__VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX64, \
::paddle::complex64, __VA_ARGS__) \
PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::COMPLEX128, \
::paddle::complex128, __VA_ARGS__) \
default: \
throw std::runtime_error("function " #NAME \
" not implemented for data type `" + \
::paddle::ToString(__dtype__) + "`"); \
} \
}()
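// The three combined macros above are simply unions of the single-category
// case lists; a dtype outside the listed cases still falls through to the
// default branch and raises the same std::runtime_error.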
// TODO(chenweihang): Add more macros in the future if needed
} // namespace paddle
@@ -14,27 +14,90 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
using float16 = paddle::platform::float16;
using bfloat16 = paddle::platform::bfloat16;
using complex64 = paddle::platform::complex64;
using complex128 = paddle::platform::complex128;
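// These aliases let the dispatch macros and user code spell the extension
// types as paddle::float16 / bfloat16 / complex64 / complex128 without
// reaching into paddle::platform directly.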
enum DataType {
BOOL,
INT8,
UINT8,
INT16,
INT32,
INT64,
FLOAT16,
BFLOAT16,
FLOAT32,
FLOAT64,
COMPLEX64,
COMPLEX128,
// TODO(JiabinYang) support more data types if needed.
};
inline std::string ToString(DataType dtype) {
switch (dtype) {
case DataType::BOOL:
return "bool";
case DataType::INT8:
return "int8_t";
case DataType::UINT8:
return "uint8_t";
case DataType::INT16:
return "int16_t";
case DataType::INT32:
return "int32_t";
case DataType::INT64:
return "int64_t";
case DataType::FLOAT16:
return "float16";
case DataType::BFLOAT16:
return "bfloat16";
case DataType::FLOAT32:
return "float";
case DataType::FLOAT64:
return "double";
case DataType::COMPLEX64:
return "complex64";
case DataType::COMPLEX128:
return "complex128";
default:
throw std::runtime_error("Unsupported paddle enum data type.");
}
}
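// For example, ToString(DataType::FLOAT32) returns "float"; this is the
// string the dispatch macros in dispatch.h splice into their
// "not implemented for data type `...`" error messages.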
#define PD_FOR_EACH_DATA_TYPE(_) \
_(bool, DataType::BOOL) \
_(int8_t, DataType::INT8) \
_(uint8_t, DataType::UINT8) \
_(int16_t, DataType::INT16) \
_(int, DataType::INT32) \
_(int64_t, DataType::INT64) \
_(float16, DataType::FLOAT16) \
_(bfloat16, DataType::BFLOAT16) \
_(float, DataType::FLOAT32) \
_(double, DataType::FLOAT64) \
_(complex64, DataType::COMPLEX64) \
_(complex128, DataType::COMPLEX128)
template <paddle::DataType T>
struct DataTypeToCPPType;
#define PD_SPECIALIZE_DataTypeToCPPType(cpp_type, data_type) \
template <> \
struct DataTypeToCPPType<data_type> { \
using type = cpp_type; \
};
PD_FOR_EACH_DATA_TYPE(PD_SPECIALIZE_DataTypeToCPPType)
#undef PD_SPECIALIZE_DataTypeToCPPType
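// Illustrative compile-time check (a sketch, not part of this commit;
// requires <type_traits>): the generated specializations map each enum value
// back to its C++ type.
//
//   static_assert(
//       std::is_same<DataTypeToCPPType<DataType::INT64>::type,
//                    int64_t>::value,
//       "DataTypeToCPPType<DataType::INT64> must yield int64_t");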
} // namespace paddle
@@ -300,10 +300,15 @@ void LoadCustomOperatorLib(const std::string& dso_name);
/////////////////////// Op register Macro /////////////////////////
#define PD_BUILD_OP_WITH_COUNTER(op_name, counter) \
static ::paddle::OpMetaInfoBuilder __op_meta_info_##counter##__ = \
::paddle::OpMetaInfoBuilder(op_name)
#define PD_BUILD_OP_INNER(op_name, counter) \
PD_BUILD_OP_WITH_COUNTER(op_name, counter)
#define PD_BUILD_OP(op_name) PD_BUILD_OP_INNER(op_name, __COUNTER__)
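// The PD_BUILD_OP_INNER indirection is the standard double-expansion trick:
// it forces __COUNTER__ to be evaluated to a number before ## pastes it into
// the variable name, so every PD_BUILD_OP use defines a uniquely named
// builder. For example (hypothetical op name), PD_BUILD_OP("relu") expands to
//
//   static ::paddle::OpMetaInfoBuilder __op_meta_info_0__ =
//       ::paddle::OpMetaInfoBuilder("relu")
//
// to which the user's .Inputs(...)/.Outputs(...) chain and trailing semicolon
// are then appended. Pasting ##__COUNTER__## in a single macro would instead
// produce the literal identifier __op_meta_info___COUNTER____ every time,
// colliding across uses.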
} // namespace paddle
///////////////////// C API ///////////////////
......
@@ -17,11 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/custom_tensor_utils.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/complex128.h"
#include "paddle/fluid/platform/complex64.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/transform.h"
namespace paddle {
......
@@ -30,6 +30,7 @@ endforeach()
set_tests_properties(test_jit_load PROPERTIES TIMEOUT 180)
set_tests_properties(test_setup_install PROPERTIES TIMEOUT 180)
set_tests_properties(test_setup_build PROPERTIES TIMEOUT 180)
set_tests_properties(test_dispatch PROPERTIES TIMEOUT 180)
set_tests_properties(test_simple_custom_op_setup PROPERTIES TIMEOUT 250)
set_tests_properties(test_simple_custom_op_jit PROPERTIES TIMEOUT 180)
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
template <typename data_t>
void assign_cpu_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel) {
for (int64_t i = 0; i < x_numel; ++i) {
out_data[i] = x_data[i];
}
}
std::vector<std::vector<int64_t>> InferShape(std::vector<int64_t> x_shape) {
return {x_shape};
}
std::vector<paddle::DataType> InferDType(paddle::DataType x_dtype) {
return {x_dtype};
}
std::vector<paddle::Tensor> DispatchTestInterger(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_INTEGRAL_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP("dispatch_test_integer")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestInterger))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
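// The builder chain above declares one input "X" and one output "Out", and
// binds the kernel, shape-inference and dtype-inference functions; after the
// extension is loaded, the op is callable from Python as
// dispatch_op.dispatch_test_integer (see the Python test below).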
std::vector<paddle::Tensor> DispatchTestComplex(const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP("dispatch_test_complex")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestComplex))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
std::vector<paddle::Tensor> DispatchTestFloatAndInteger(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP("dispatch_test_float_and_integer")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndInteger))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
std::vector<paddle::Tensor> DispatchTestFloatAndComplex(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP("dispatch_test_float_and_complex")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndComplex))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
std::vector<paddle::Tensor> DispatchTestFloatAndIntegerAndComplex(
const paddle::Tensor& x) {
auto out = paddle::Tensor(paddle::PlaceType::kCPU);
out.reshape(x.shape());
PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(
x.type(), "assign_cpu_kernel", ([&] {
assign_cpu_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(), x.size());
}));
return {out};
}
PD_BUILD_OP("dispatch_test_float_and_integer_and_complex")
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(DispatchTestFloatAndIntegerAndComplex))
.SetInferShapeFn(PD_INFER_SHAPE(InferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(InferDType));
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import paddle
import numpy as np
from paddle.utils.cpp_extension import load
from utils import paddle_includes, extra_compile_args
dispatch_op = load(
name='dispatch_op',
sources=['dispatch_test_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cflags=extra_compile_args) # add for Coverage CI
class TestJitDispatch(unittest.TestCase):
def setUp(self):
paddle.set_device('cpu')
def run_dispatch_test(self, func, dtype):
np_x = np.ones([2, 2]).astype(dtype)
x = paddle.to_tensor(np_x)
out = func(x)
np_x = x.numpy()
np_out = out.numpy()
self.assertTrue(dtype in str(np_out.dtype))
self.assertTrue(
np.array_equal(np_x, np_out),
"custom op x: {},\n custom op out: {}".format(np_x, np_out))
def test_dispatch_integer(self):
dtypes = ["int32", "int64", "int8", "uint8", "int16"]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_integer, dtype)
def test_dispatch_complex(self):
dtypes = ["complex64", "complex128"]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_complex, dtype)
def test_dispatch_float_and_integer(self):
dtypes = [
"float32", "float64", "int32", "int64", "int8", "uint8", "int16"
]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_float_and_integer,
dtype)
def test_dispatch_float_and_complex(self):
dtypes = ["float32", "float64", "complex64", "complex128"]
for dtype in dtypes:
self.run_dispatch_test(dispatch_op.dispatch_test_float_and_complex,
dtype)
def test_dispatch_float_and_integer_and_complex(self):
dtypes = [
"float32", "float64", "int32", "int64", "int8", "uint8", "int16",
"complex64", "complex128"
]
for dtype in dtypes:
self.run_dispatch_test(
dispatch_op.dispatch_test_float_and_integer_and_complex, dtype)
if __name__ == '__main__':
unittest.main()