diff --git a/paddle/phi/api/ext/dispatch.h b/paddle/phi/api/ext/dispatch.h index 6b6d0ae7fe7230263454d0bf08da40e4a793549b..aa9cd0f53a4c642b97292f77567f2b6706857f00 100644 --- a/paddle/phi/api/ext/dispatch.h +++ b/paddle/phi/api/ext/dispatch.h @@ -14,327 +14,57 @@ limitations under the License. */ #pragma once -#include "paddle/phi/api/ext/exception.h" -#include "paddle/phi/common/data_type.h" +#include "paddle/phi/core/visit_type.h" namespace paddle { -///////// Basic Marco /////////// - -#define PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ - case enum_type: { \ - using HINT = type; \ - __VA_ARGS__(); \ - break; \ - } - -#define PD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ - PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__) +// Note: Keep this file only for compatibility with custom operators ///////// Floating Dispatch Marco /////////// -#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \ + PD_VISIT_FLOATING_TYPES(TYPE, NAME, __VA_ARGS__) -#define PD_DISPATCH_FLOATING_AND_HALF_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT16, paddle::float16, __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_FLOATING_AND_HALF_TYPES(TYPE, NAME, ...) \ + PD_VISIT_FLOATING_AND_HALF_TYPES(TYPE, NAME, __VA_ARGS__) ///////// Integral Dispatch Marco /////////// -#define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_INTEGRAL_TYPES(TYPE, NAME, ...) \ + PD_VISIT_INTEGRAL_TYPES(TYPE, NAME, __VA_ARGS__) ///////// Complex Dispatch Marco /////////// -#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, \ - __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_COMPLEX_TYPES(TYPE, NAME, ...) 
\ + PD_VISIT_COMPLEX_TYPES(TYPE, NAME, __VA_ARGS__) ///////// Floating and Integral Dispatch Marco /////////// -#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) \ + PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, __VA_ARGS__) ///////// Floating and Complex Dispatch Marco /////////// -#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, \ - __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + PD_VISIT_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, __VA_ARGS__) ///////// Floating and Complex and other type Dispatch Marco /////////// -#define PD_DISPATCH_FLOATING_AND_COMPLEX_AND_1_TYPES( \ - SPECIFIED_TYPE, TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, \ - SPECIFIED_TYPE, \ - ::paddle::experimental::DataTypeToCppType::type, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, \ - __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_FLOATING_AND_COMPLEX_AND_1_TYPE( \ + SPECIFIED_TYPE, TYPE, NAME, ...) \ + PD_VISIT_FLOATING_AND_COMPLEX_AND_1_TYPE( \ + SPECIFIED_TYPE, TYPE, NAME, __VA_ARGS__) ///////// Floating and Complex and 2 other type Dispatch Marco /////////// -#define PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES( \ - SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, ...) 
\ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, \ - SPECIFIED_TYPE1, \ - ::paddle::experimental::DataTypeToCppType::type, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, \ - SPECIFIED_TYPE2, \ - ::paddle::experimental::DataTypeToCppType::type, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, \ - __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_FLOATING_AND_COMPLEX_AND_2_TYPES( \ + SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, ...) \ + PD_VISIT_FLOATING_AND_COMPLEX_AND_2_TYPES( \ + SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, __VA_ARGS__) ///////// Floating, Integral and Complex Dispatch Marco /////////// -#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, \ - __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() - -// TODO(chenweihang): Add more Marcos in the future if needed - -#define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) 
\ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::BOOL, bool, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT8, int8_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::UINT8, uint8_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT16, int16_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT32, int32_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT64, int64_t, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::phi::DataType::BFLOAT16, \ - paddle::experimental::bfloat16, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::phi::DataType::FLOAT16, \ - paddle::experimental::float16, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::phi::DataType::FLOAT64, double, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::phi::DataType::COMPLEX64, \ - paddle::experimental::complex64, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::phi::DataType::COMPLEX128, \ - paddle::experimental::complex128, \ - __VA_ARGS__) \ - default: \ - PADDLE_THROW(phi::errors::InvalidArgument( \ - "Invalid enum data type `%d`.", static_cast(__dtype__))); \ - } \ - }() - -#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( \ - SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...) \ - [&] { \ - const auto& __dtype__ = TYPE; \ - switch (__dtype__) { \ - PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::BOOL, bool, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX64, \ - ::paddle::complex64, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE(NAME, \ - ::paddle::DataType::COMPLEX128, \ - ::paddle::complex128, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, \ - SPECIFIED_TYPE1, \ - ::paddle::experimental::DataTypeToCppType::type, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, \ - SPECIFIED_TYPE2, \ - ::paddle::experimental::DataTypeToCppType::type, \ - __VA_ARGS__) \ - PD_PRIVATE_CASE_TYPE( \ - NAME, \ - SPECIFIED_TYPE3, \ - ::paddle::experimental::DataTypeToCppType::type, \ - __VA_ARGS__) \ - default: \ - PD_THROW("function " #NAME " is not implemented for data type `", \ - __dtype__, \ - "`"); \ - } \ - }() +#define PD_DISPATCH_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + PD_VISIT_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, __VA_ARGS__) } // namespace paddle diff --git a/paddle/phi/api/lib/data_transform.cc b/paddle/phi/api/lib/data_transform.cc index 82d2e741e9de852823726f91a6f2d7370c8d0b0e..d4e92ded324da5ed5fead35ce22d5b31a84985b0 100644 --- a/paddle/phi/api/lib/data_transform.cc +++ b/paddle/phi/api/lib/data_transform.cc @@ -14,7 +14,6 @@ limitations under the License. */ #include "paddle/phi/api/lib/data_transform.h" -#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/api/lib/kernel_dispatch.h" #include "paddle/phi/api/lib/utils/storage.h" #include "paddle/phi/backends/all_context.h" diff --git a/paddle/phi/core/visit_type.h b/paddle/phi/core/visit_type.h new file mode 100644 index 0000000000000000000000000000000000000000..bd972c8ceedc78f87bfe2cc3e1c51c3a0732cb45 --- /dev/null +++ b/paddle/phi/core/visit_type.h @@ -0,0 +1,338 @@ +/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+ +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/phi/api/ext/exception.h" +#include "paddle/phi/common/data_type.h" + +namespace phi { + +///////// Basic Marco /////////// + +#define PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, HINT, ...) \ + case enum_type: { \ + using HINT = type; \ + __VA_ARGS__(); \ + break; \ + } + +#define PD_PRIVATE_CASE_TYPE(NAME, enum_type, type, ...) \ + PD_PRIVATE_CASE_TYPE_USING_HINT(NAME, enum_type, type, data_t, __VA_ARGS__) + +///////// Floating Dispatch Marco /////////// + +#define PD_VISIT_FLOATING_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +#define PD_VISIT_FLOATING_AND_HALF_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT16, paddle::float16, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +///////// Integral Dispatch Marco /////////// + +#define PD_VISIT_INTEGRAL_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +///////// Complex Dispatch Marco /////////// + +#define PD_VISIT_COMPLEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +///////// Floating and Integral Dispatch Marco /////////// + +#define PD_VISIT_FLOATING_AND_INTEGRAL_TYPES(TYPE, NAME, ...) 
\ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +///////// Floating and Complex Dispatch Marco /////////// + +#define PD_VISIT_FLOATING_AND_COMPLEX_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +///////// Floating and Complex and other type Dispatch Marco /////////// + +#define PD_VISIT_FLOATING_AND_COMPLEX_AND_1_TYPE( \ + SPECIFIED_TYPE, TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +///////// Floating and Complex and 2 other type Dispatch Marco /////////// + +#define PD_VISIT_FLOATING_AND_COMPLEX_AND_2_TYPES( \ + SPECIFIED_TYPE1, SPECIFIED_TYPE2, TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE1, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE2, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +///////// Floating, Integral and Complex Dispatch Marco /////////// + +#define PD_VISIT_FLOATING_AND_INTEGRAL_AND_COMPLEX_TYPES(TYPE, NAME, ...) 
\ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::INT32, int, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT64, int64_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT8, int8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::UINT8, uint8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::INT16, int16_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +#define PD_VISIT_ALL_TYPES(TYPE, NAME, ...) \ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::BOOL, bool, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT8, int8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::UINT8, uint8_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT16, int16_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT32, int32_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::INT64, int64_t, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::phi::DataType::BFLOAT16, \ + paddle::experimental::bfloat16, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::phi::DataType::FLOAT16, \ + paddle::experimental::float16, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, ::phi::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::phi::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::phi::DataType::COMPLEX64, \ + paddle::experimental::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::phi::DataType::COMPLEX128, \ + paddle::experimental::complex128, \ + __VA_ARGS__) \ + default: \ + PADDLE_THROW(phi::errors::InvalidArgument( \ + "Invalid enum data type `%d`.", static_cast(__dtype__))); \ + } \ + }() + +#define PD_VISIT_BOOL_AND_FLOATING_AND_COMPLEX_AND_3_TYPES( \ + SPECIFIED_TYPE1, SPECIFIED_TYPE2, SPECIFIED_TYPE3, TYPE, NAME, ...) 
\ + [&] { \ + const auto& __dtype__ = TYPE; \ + switch (__dtype__) { \ + PD_PRIVATE_CASE_TYPE(NAME, ::paddle::DataType::BOOL, bool, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT32, float, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, ::paddle::DataType::FLOAT64, double, __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX64, \ + ::paddle::complex64, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE(NAME, \ + ::paddle::DataType::COMPLEX128, \ + ::paddle::complex128, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE1, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE2, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + PD_PRIVATE_CASE_TYPE( \ + NAME, \ + SPECIFIED_TYPE3, \ + ::paddle::experimental::DataTypeToCppType::type, \ + __VA_ARGS__) \ + default: \ + PD_THROW("function " #NAME " is not implemented for data type `", \ + __dtype__, \ + "`"); \ + } \ + }() + +} // namespace phi diff --git a/paddle/phi/kernels/cpu/cast_grad_kernel.cc b/paddle/phi/kernels/cpu/cast_grad_kernel.cc index c294c743bd4cf2ee05fd7f26349409e919b2d7a8..79f53cbce1a4a7079f0b0492a1cf0c7cef65fe28 100644 --- a/paddle/phi/kernels/cpu/cast_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/cast_grad_kernel.cc @@ -13,7 +13,9 @@ // limitations under the License. #include "paddle/phi/kernels/cast_grad_kernel.h" + #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/cpu/cast_impl.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/cast_impl.h b/paddle/phi/kernels/cpu/cast_impl.h index d39ef24e7beb1d5e8665e41bae806c5ddc31a6e8..9648b584243f5b2aa65a5eee7e4fbeb7292f0284 100644 --- a/paddle/phi/kernels/cpu/cast_impl.h +++ b/paddle/phi/kernels/cpu/cast_impl.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once -#include "paddle/phi/api/ext/dispatch.h" + #include "paddle/phi/backends/cpu/cpu_context.h" // See Note [ Why still include the fluid headers? ] diff --git a/paddle/phi/kernels/cpu/cast_kernel.cc b/paddle/phi/kernels/cpu/cast_kernel.cc index b53c94eb4cae262ef46ea6edfe3fd9872e7cd1de..2132f0d5ae86cc6bf127f9fa4e30797f686e5f99 100644 --- a/paddle/phi/kernels/cpu/cast_kernel.cc +++ b/paddle/phi/kernels/cpu/cast_kernel.cc @@ -16,6 +16,7 @@ #include "paddle/phi/kernels/cpu/cast_impl.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" namespace phi { diff --git a/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc b/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc index d4a632b5e6ece09030a7071ee7919cb43a2015df..021fdac225330814e6bd1d25b9f5a061b9b75207 100644 --- a/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc +++ b/paddle/phi/kernels/cpu/cross_entropy_grad_kernel.cc @@ -16,13 +16,11 @@ limitations under the License. 
*/ #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/eigen/common.h" -// TODO(chenweihang): move dispatch.h into phi/core -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { template @@ -200,7 +198,7 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, axis, logits_grad); } else { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( dtype, "CrossEntropyWithSoftmaxGradCPUKernel", ([&] { CrossEntropyWithSoftmaxGradCPUKernel(dev_ctx, label, diff --git a/paddle/phi/kernels/cpu/elementwise_kernel.cc b/paddle/phi/kernels/cpu/elementwise_kernel.cc index 4ca41de7bb64a9f444e4b57713ccb46adbf96e4e..a91ca1ee3244bdbdde2c9c248317e40d45b3dc17 100644 --- a/paddle/phi/kernels/cpu/elementwise_kernel.cc +++ b/paddle/phi/kernels/cpu/elementwise_kernel.cc @@ -13,7 +13,6 @@ // limitations under the License. #include "paddle/phi/kernels/cpu/elementwise.h" -#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/complex.h" diff --git a/paddle/phi/kernels/cpu/reduce.h b/paddle/phi/kernels/cpu/reduce.h index af67bdf5d624f33fd4ec06db425ec8312b490642..06a458832d19f8ad0e6a9d545a61a344be903650 100644 --- a/paddle/phi/kernels/cpu/reduce.h +++ b/paddle/phi/kernels/cpu/reduce.h @@ -16,8 +16,8 @@ #include -#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/backends/cpu/cpu_context.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/cast_kernel.h" #include "paddle/phi/api/lib/utils/storage.h" diff --git a/paddle/phi/kernels/cpu/transpose_kernel.cc b/paddle/phi/kernels/cpu/transpose_kernel.cc index 5dc4866e1efc33bcd9a680dfe2eb2804e28a7588..a2f5aa2a29795f51aabfbe438247917d91697818 100644 --- a/paddle/phi/kernels/cpu/transpose_kernel.cc +++ b/paddle/phi/kernels/cpu/transpose_kernel.cc @@ -13,8 +13,9 @@ // limitations under the License. #include "paddle/phi/kernels/transpose_kernel.h" + #include -#include "paddle/phi/api/ext/dispatch.h" + #include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/core/kernel_registry.h" diff --git a/paddle/phi/kernels/funcs/reduce_function.h b/paddle/phi/kernels/funcs/reduce_function.h index 4eb6ba0310886e6cc3907dbd1691b11e6deaae8f..b414dfc5d6e849b86f584af1f36b7d39243a9f7b 100644 --- a/paddle/phi/kernels/funcs/reduce_function.h +++ b/paddle/phi/kernels/funcs/reduce_function.h @@ -35,7 +35,6 @@ namespace cub = hipcub; #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h" -#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/core/dense_tensor.h" diff --git a/paddle/phi/kernels/gpu/cast_grad_kernel.cu b/paddle/phi/kernels/gpu/cast_grad_kernel.cu index 1c1d8cf2c06d44913cfc079a20d12bcd0f4396ab..f4b610301583c515f72a29f367de2d9525cd10ad 100644 --- a/paddle/phi/kernels/gpu/cast_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cast_grad_kernel.cu @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. 
-#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cast_grad_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/gpu/cast_impl.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/cast_impl.h b/paddle/phi/kernels/gpu/cast_impl.h index 8f6351e675cfa48bfa1e9d0af79a992839bcbd75..f73d396572541ba4388c5a3421bf37638938271c 100644 --- a/paddle/phi/kernels/gpu/cast_impl.h +++ b/paddle/phi/kernels/gpu/cast_impl.h @@ -13,7 +13,7 @@ // limitations under the License. #pragma once -#include "paddle/phi/api/ext/dispatch.h" + #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/kernels/funcs/elementwise_base.h" diff --git a/paddle/phi/kernels/gpu/cast_kernel.cu b/paddle/phi/kernels/gpu/cast_kernel.cu index 40a84648e4b163baa0c7c2593eb05efcb9494d1c..a879dc3bafd746450433f5a98281f5d8b35b801e 100644 --- a/paddle/phi/kernels/gpu/cast_kernel.cu +++ b/paddle/phi/kernels/gpu/cast_kernel.cu @@ -12,8 +12,10 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/cast_kernel.h" + +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/gpu/cast_impl.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu index 215b94c52b3950f68fcc084ad1942a612e79352b..c66daf4fe64e6107a646b6aa8bd4f656a93dba1e 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_grad_kernel.cu @@ -24,15 +24,13 @@ namespace cub = hipcub; #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" -// TODO(chenweihang): move dispatch.h into phi/core -#include "paddle/phi/api/ext/dispatch.h" - #include "paddle/fluid/operators/math/cross_entropy.h" #include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" @@ -267,7 +265,7 @@ void CrossEntropyWithSoftmaxGradKernel(const Context& dev_ctx, axis, logits_grad); } else { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( dtype, "CrossEntropyWithSoftmaxGradGPUKernel", ([&] { CrossEntropyWithSoftmaxGradGPUKernel(dev_ctx, label, diff --git a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu index 055706cffd41e50693cc20682f75a46b7f439d04..1908c78060483aa2e101d7ff1f9075a6e88b7124 100644 --- a/paddle/phi/kernels/gpu/cross_entropy_kernel.cu +++ b/paddle/phi/kernels/gpu/cross_entropy_kernel.cu @@ -24,15 +24,13 @@ namespace cub = hipcub; #include "paddle/phi/common/amp_type_traits.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/axis_utils.h" #include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/gpudnn/softmax_gpudnn.h" -// TODO(chenweihang): move dispatch.h into phi/core -#include "paddle/phi/api/ext/dispatch.h" - #include "paddle/fluid/operators/math/cross_entropy.h" 
#include "paddle/fluid/operators/math/softmax.h" #include "paddle/fluid/platform/device/gpu/gpu_device_function.h" @@ -1529,19 +1527,19 @@ void CrossEntropyWithSoftmaxKernel(const Context& dev_ctx, softmax, loss); } else { - PD_DISPATCH_INTEGRAL_TYPES( - dtype, "CrossEntropyWithSoftmaxCUDAKernel", ([&] { - CrossEntropyWithSoftmaxCUDAKernel(dev_ctx, - logits, - label, - soft_label, - use_softmax, - numeric_stable_mode, - ignore_index, - axis, - softmax, - loss); - })); + PD_VISIT_INTEGRAL_TYPES(dtype, "CrossEntropyWithSoftmaxCUDAKernel", ([&] { + CrossEntropyWithSoftmaxCUDAKernel( + dev_ctx, + logits, + label, + soft_label, + use_softmax, + numeric_stable_mode, + ignore_index, + axis, + softmax, + loss); + })); } } diff --git a/paddle/phi/kernels/gpu/reduce.h b/paddle/phi/kernels/gpu/reduce.h index a54669c6e9d42c31d5d0a1ad6ff3763eddb918e3..6fb81edd6bf47543677169ad69b3bba422218304 100644 --- a/paddle/phi/kernels/gpu/reduce.h +++ b/paddle/phi/kernels/gpu/reduce.h @@ -18,6 +18,7 @@ #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ defined(PADDLE_WITH_XPU_KP) +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/reduce_function.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/reduce_grad.h b/paddle/phi/kernels/gpu/reduce_grad.h index 1e39a08e9cbaf2b5246a7b7a969b6d16dada5f1d..e1f7419fb7a0173fd573b5bed49687d164540778 100644 --- a/paddle/phi/kernels/gpu/reduce_grad.h +++ b/paddle/phi/kernels/gpu/reduce_grad.h @@ -23,7 +23,7 @@ #include #include -#include "paddle/phi/api/ext/dispatch.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/broadcast_function.h" namespace phi { diff --git a/paddle/phi/kernels/gpu/transpose_kernel.cu b/paddle/phi/kernels/gpu/transpose_kernel.cu index 9ea2af292ccf161653b5674bd231aed584b84632..203f10e4ddd47230d7264999cd915abe47968acb 100644 --- a/paddle/phi/kernels/gpu/transpose_kernel.cu +++ b/paddle/phi/kernels/gpu/transpose_kernel.cu @@ -14,7 +14,6 @@ #include -#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/kernels/transpose_kernel.h" diff --git a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc index 80693c90d1e7ff1994c1a39c4a82f647cd201eab..216685f0f719184c40dd7abe321e08efb665ca62 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_grad_kernel.cc @@ -13,13 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/sparse/cpu/convolution.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -191,7 +190,7 @@ void Conv3dGradKernel(const Context& dev_ctx, const bool subm, SparseCooTensor* x_grad, DenseTensor* kernel_grad) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "Conv3dGradCPUKernel", ([&] { Conv3dGradCPUKernel(dev_ctx, x, diff --git a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc index a1c8cf014c7fb99366a753e24ae147fabb6d9b7b..c920f3c46128737425614ced21331215fd244e08 100644 --- a/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/convolution_kernel.cc @@ -15,10 +15,9 @@ limitations under the License. */ #include "paddle/phi/kernels/sparse/cpu/convolution.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/blas/blas.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -159,7 +158,7 @@ void Conv3dKernel(const Context& dev_ctx, const bool subm, SparseCooTensor* out, DenseTensor* rulebook) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "Conv3dCPUKernel", ([&] { Conv3dCPUKernel(dev_ctx, x, diff --git a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc index a07a7fb2ecf4435722b4b97a3c9b1e90b7ea2207..c10a240c684302d6d76d86935816794185f1c5e6 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_mask_kernel.cc @@ -16,13 +16,12 @@ limitations under the License. */ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/common_shape.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -78,7 +77,7 @@ void SparseMaskKernel(const Context& dev_ctx, const DenseTensor& x, const SparseCooTensor& mask, SparseCooTensor* out) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( mask.non_zero_indices().dtype(), "SparseMaskCPUKernel", ([&] { SparseMaskCPUKernel(dev_ctx, x, mask, out); })); @@ -145,7 +144,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& mask_indices, DenseTensor* out) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "SparseMaskHelperCPUKernel", ([&] { SparseMaskHelperCPUKernel(dev_ctx, x, mask_indices, out); })); diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc index 30221975e7756c18c4327e2a60cc6568ad96bdd9..78b6354f44f9e8b763e5e6e70c57b29ccfcbeb1a 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_grad_kernel.cc @@ -14,13 +14,12 @@ limitations under the License. 
*/ #include "paddle/phi/kernels/sparse/sparse_pool_grad_kernel.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -82,7 +81,7 @@ void MaxPoolGradKernel(const Context& dev_ctx, const SparseCooTensor& out_grad, const std::vector& kernel_sizes, SparseCooTensor* x_grad) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "MaxPoolGradCPUKernel", ([&] { MaxPoolGradCPUKernel( dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); diff --git a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc b/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc index ed6e0200587e877386a7453a81597b5b9c9971fb..28211a1cda34735bbb292dce4d915803af95af2a 100644 --- a/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc +++ b/paddle/phi/kernels/sparse/cpu/sparse_pool_kernel.cc @@ -15,12 +15,11 @@ limitations under the License. */ #include "paddle/phi/kernels/sparse/sparse_pool_kernel.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" #include "paddle/phi/kernels/sparse/cpu/convolution.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -106,7 +105,7 @@ void MaxPoolKernel(const Context& dev_ctx, const std::vector& strides, SparseCooTensor* out, DenseTensor* rulebook) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "MaxPoolCPUKernel", ([&] { MaxPoolCPUKernel(dev_ctx, x, diff --git a/paddle/phi/kernels/sparse/gpu/convolution.cu.h b/paddle/phi/kernels/sparse/gpu/convolution.cu.h index 5662a4fac71c56a88af9a289bea4d3bf92eb9102..1bceb767b670857fabc2577b161085355af43131 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution.cu.h +++ b/paddle/phi/kernels/sparse/gpu/convolution.cu.h @@ -338,7 +338,7 @@ int ProductRuleBook(const Context& dev_ctx, SparseCooTensor* out, std::vector* h_counter, std::vector* h_offsets) { - // TODO(zhangkaihuo): use PD_DISPATCH_INTEGRAL_TYPES for secondary dispatch + // TODO(zhangkaihuo): use PD_VISIT_INTEGRAL_TYPES for secondary dispatch auto indices_dtype = paddle::experimental::CppTypeToDataType::Type(); const int64_t non_zero_num = x.nnz(); const auto& non_zero_indices = x.non_zero_indices(); diff --git a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu index 2b61be72896462e70638be9abba710687ac1e4ad..6c37f759923c33e96a3e164ec2ac0a704d670ee3 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_grad_kernel.cu @@ -18,14 +18,13 @@ limitations under the License. 
*/ #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/sparse/convolution_grad_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -249,7 +248,7 @@ void Conv3dGradKernel(const Context& dev_ctx, const bool subm, SparseCooTensor* x_grad, DenseTensor* kernel_grad) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "Conv3dGradGPUKernel", ([&] { Conv3dGradGPUKernel(dev_ctx, x, diff --git a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu index 2d212eadffac146ae71f3cde07da20f9858770fd..83f19ce5785df4ba11513e9fe2d0220505ca0f6d 100644 --- a/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/convolution_kernel.cu @@ -15,12 +15,11 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/sparse/convolution_kernel.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -177,7 +176,7 @@ void Conv3dKernel(const Context& dev_ctx, const bool subm, SparseCooTensor* out, DenseTensor* rulebook) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "Conv3dGPUKernel", ([&] { Conv3dGPUKernel(dev_ctx, x, diff --git a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu index 96ab56697b9b0ed393d79334ac3d248643cc7491..dff1cc2318f132e1ecc41d73c01346b57ce70b9d 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_mask_kernel.cu @@ -19,14 +19,13 @@ limitations under the License. 
*/ #include "paddle/phi/core/ddim.h" #include "paddle/phi/core/enforce.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/sparse/common_shape.h" #include "paddle/phi/kernels/sparse/sparse_mask_kernel.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -118,7 +117,7 @@ void SparseMaskKernel(const Context& dev_ctx, const DenseTensor& x, const SparseCooTensor& mask, SparseCooTensor* out) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( mask.non_zero_indices().dtype(), "SparseMaskGPUKernel", ([&] { SparseMaskGPUKernel(dev_ctx, x, mask, out); })); @@ -265,7 +264,7 @@ void SparseMaskHelperKernel(const Context& dev_ctx, const SparseCooTensor& x, const DenseTensor& mask_indices, DenseTensor* out) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "SparseMaskHelperGPUKernel", ([&] { SparseMaskHelperGPUKernel(dev_ctx, x, mask_indices, out); })); diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu index 8657e7319d8ca518eb998ec270aed917faf09e3b..bd862a44afeebba0e2853d2221c6c9c55290fbbd 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_grad_kernel.cu @@ -18,14 +18,13 @@ limitations under the License. */ #include "paddle/phi/backends/gpu/gpu_info.h" #include "paddle/phi/backends/gpu/gpu_launch_config.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/copy_kernel.h" #include "paddle/phi/kernels/empty_kernel.h" #include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -129,7 +128,7 @@ void MaxPoolGradKernel(const Context& dev_ctx, const SparseCooTensor& out_grad, const std::vector& kernel_sizes, SparseCooTensor* x_grad) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "MaxPoolGradGPUKernel", ([&] { MaxPoolGradGPUKernel( dev_ctx, x, rulebook, out, out_grad, kernel_sizes, x_grad); diff --git a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu b/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu index a59cd3c7a5a78054014863c8ec66efb0a903761c..b76b61f83bfc93424a3eab832647eebe184b02dd 100644 --- a/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu +++ b/paddle/phi/kernels/sparse/gpu/sparse_pool_kernel.cu @@ -16,12 +16,11 @@ limitations under the License. 
*/ #include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/tensor_meta.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/pooling.h" #include "paddle/phi/kernels/funcs/sparse/convolution.h" #include "paddle/phi/kernels/sparse/gpu/convolution.cu.h" -#include "paddle/phi/api/ext/dispatch.h" - namespace phi { namespace sparse { @@ -136,7 +135,7 @@ void MaxPoolKernel(const Context& dev_ctx, const std::vector& strides, SparseCooTensor* out, DenseTensor* rulebook) { - PD_DISPATCH_INTEGRAL_TYPES( + PD_VISIT_INTEGRAL_TYPES( x.non_zero_indices().dtype(), "MaxPoolGPUKernel", ([&] { MaxPoolGPUKernel(dev_ctx, x, diff --git a/paddle/phi/kernels/transfer_layout_kernel.cc b/paddle/phi/kernels/transfer_layout_kernel.cc index 60df877355b8268efafddfdc2b452617cdadf9df..f7ecf379fdfa9e0b491f20a4b97b8311fe66e376 100644 --- a/paddle/phi/kernels/transfer_layout_kernel.cc +++ b/paddle/phi/kernels/transfer_layout_kernel.cc @@ -14,9 +14,9 @@ limitations under the License. */ #include "paddle/phi/kernels/transfer_layout_kernel.h" -#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/backends/all_context.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" #include "paddle/phi/kernels/funcs/math_function.h" namespace phi { diff --git a/paddle/phi/kernels/xpu/full_kernel.cc b/paddle/phi/kernels/xpu/full_kernel.cc index 6668ae39cbdbe8d0f9385306143b2ba94ea59fb0..978bdb5129c04e53f2628aac7df6c1b4386ed0a6 100644 --- a/paddle/phi/kernels/xpu/full_kernel.cc +++ b/paddle/phi/kernels/xpu/full_kernel.cc @@ -14,13 +14,13 @@ #include "paddle/phi/kernels/full_kernel.h" -#include "paddle/phi/api/ext/dispatch.h" #include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/common/bfloat16.h" #include "paddle/phi/common/complex.h" #include "paddle/phi/common/float16.h" #include "paddle/phi/common/scalar.h" #include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/visit_type.h" // See Note [ Why still include the fluid headers? ] #include "paddle/fluid/memory/memcpy.h"