未验证 提交 1631836f 编写于 作者: W Wang Xin 提交者: GitHub

[PHI decoupling] remove framework/data_type.h from phi (#47776)

* remove framework/data_type.h from phi

* fix CI fail: map proto::VarType to phi::DataType

* refactor code to add more detailed comments
上级 7e914386
...@@ -14,6 +14,7 @@ limitations under the License. */ ...@@ -14,6 +14,7 @@ limitations under the License. */
#pragma once #pragma once
#include <iostream> #include <iostream>
#include <map>
#include <string> #include <string>
#include <typeindex> #include <typeindex>
...@@ -23,6 +24,23 @@ limitations under the License. */ ...@@ -23,6 +24,23 @@ limitations under the License. */
namespace phi { namespace phi {
// Here we can't depend on the fluid proto::VarType, so we use the dtype enum
// value directly. See also `assign_value_sig.cc`.
// proto::VarType::INT16 -> 1 -> phi::DataType::INT16
// proto::VarType::INT32 -> 2 -> phi::DataType::INT32
// proto::VarType::INT64 -> 3 -> phi::DataType::INT64
// proto::VarType::FP16 -> 4 -> phi::DataType::FLOAT16
// proto::VarType::FP32 -> 5 -> phi::DataType::FLOAT32
// proto::VarType::FP64 -> 6 -> phi::DataType::FLOAT64
// proto::VarType::UINT8 -> 20 -> phi::DataType::UINT8
static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
{2, phi::DataType::INT32},
{3, phi::DataType::INT64},
{4, phi::DataType::FLOAT16},
{5, phi::DataType::FLOAT32},
{6, phi::DataType::FLOAT64},
{20, phi::DataType::UINT8}};
#define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \ #define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
callback(cpp_type, data_type); callback(cpp_type, data_type);
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
...@@ -141,15 +142,14 @@ void ArgMinMaxKernel(const Context& dev_ctx, ...@@ -141,15 +142,14 @@ void ArgMinMaxKernel(const Context& dev_ctx,
int dtype, int dtype,
DenseTensor* out) { DenseTensor* out) {
if (dtype < 0) { if (dtype < 0) {
paddle::framework::VisitDataTypeTiny( phi::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>( phi::DataType::INT64,
paddle::framework::proto::VarType::INT64),
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>( VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out)); dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return; return;
} }
paddle::framework::VisitDataTypeTiny( phi::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype), var_type_map[dtype],
VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>( VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out)); dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
} }
......
...@@ -23,7 +23,7 @@ ...@@ -23,7 +23,7 @@
#include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/for_range.h"
// NOTE(@xiongkun): use of IsComplex<> // NOTE(@xiongkun): use of IsComplex<>
#include "paddle/fluid/framework/data_type.h" #include "paddle/phi/core/utils/data_type.h"
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
...@@ -51,7 +51,7 @@ void CumprodGradKernel(const Context& dev_ctx, ...@@ -51,7 +51,7 @@ void CumprodGradKernel(const Context& dev_ctx,
const T* out_data_deal; const T* out_data_deal;
Allocator::AllocationPtr x_conj; Allocator::AllocationPtr x_conj;
Allocator::AllocationPtr out_conj; Allocator::AllocationPtr out_conj;
if (paddle::framework::IsComplex<T>::value) { if (phi::IsComplexType(x.dtype())) {
x_conj = const_cast<Allocator&>(dev_ctx.GetAllocator()) x_conj = const_cast<Allocator&>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T)); .Allocate(numel * sizeof(T));
auto* x_data_conj = reinterpret_cast<T*>(x_conj->ptr()); auto* x_data_conj = reinterpret_cast<T*>(x_conj->ptr());
......
...@@ -18,8 +18,7 @@ ...@@ -18,8 +18,7 @@
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/fluid/framework/data_type.h"
namespace phi { namespace phi {
...@@ -33,8 +32,8 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, ...@@ -33,8 +32,8 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
DenseTensor* out, DenseTensor* out,
DenseTensor* index, DenseTensor* index,
DenseTensor* counts) { DenseTensor* counts) {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype); auto data_type = var_type_map[dtype];
if (data_type == paddle::framework::proto::VarType::INT32) { if (data_type == phi::DataType::INT32) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
x.numel(), x.numel(),
INT_MAX, INT_MAX,
...@@ -46,13 +45,13 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, ...@@ -46,13 +45,13 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
} }
if (axis.empty()) { if (axis.empty()) {
paddle::framework::VisitDataTypeTiny( phi::VisitDataTypeTiny(
data_type, data_type,
UniqueConsecutiveFlattenedTensorFunctor<Context, T>( UniqueConsecutiveFlattenedTensorFunctor<Context, T>(
dev_ctx, x, out, return_inverse, return_counts, index, counts)); dev_ctx, x, out, return_inverse, return_counts, index, counts));
} else { } else {
int valid_axis = axis[0]; int valid_axis = axis[0];
paddle::framework::VisitDataTypeTiny( phi::VisitDataTypeTiny(
data_type, data_type,
UniqueConsecutiveDimFunctor<Context, T>(dev_ctx, UniqueConsecutiveDimFunctor<Context, T>(dev_ctx,
x, x,
......
...@@ -26,10 +26,10 @@ limitations under the License. */ ...@@ -26,10 +26,10 @@ limitations under the License. */
#include <utility> #include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/eigen/common.h" #include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h" #include "paddle/phi/kernels/funcs/math_function_impl.h"
#include "unsupported/Eigen/CXX11/Tensor" #include "unsupported/Eigen/CXX11/Tensor"
......
...@@ -14,12 +14,12 @@ limitations under the License. */ ...@@ -14,12 +14,12 @@ limitations under the License. */
#include <algorithm> #include <algorithm>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/memory/malloc.h" #include "paddle/fluid/memory/malloc.h"
#include "paddle/fluid/memory/memcpy.h" #include "paddle/fluid/memory/memcpy.h"
#include "paddle/fluid/platform/bfloat16.h" #include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h" #include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/blas/blas.h" #include "paddle/phi/kernels/funcs/blas/blas.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
#include "paddle/phi/kernels/funcs/math_function_impl.h" #include "paddle/phi/kernels/funcs/math_function_impl.h"
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/framework/data_type.h" #include "paddle/phi/common/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h" #include "paddle/phi/kernels/funcs/math_function.h"
namespace phi { namespace phi {
......
...@@ -28,9 +28,8 @@ namespace cub = hipcub; ...@@ -28,9 +28,8 @@ namespace cub = hipcub;
#endif #endif
#include <limits> #include <limits>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi { namespace phi {
namespace { // NOLINT namespace { // NOLINT
...@@ -209,15 +208,14 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx, ...@@ -209,15 +208,14 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
int dtype, int dtype,
DenseTensor* out) { DenseTensor* out) {
if (dtype < 0) { if (dtype < 0) {
paddle::framework::VisitDataTypeTiny( phi::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>( phi::DataType::INT64,
paddle::framework::proto::VarType::INT64),
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>( VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out)); dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
return; return;
} }
paddle::framework::VisitDataTypeTiny( phi::VisitDataTypeTiny(
static_cast<paddle::framework::proto::VarType::Type>(dtype), var_type_map[dtype],
VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>( VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out)); dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
} }
......
...@@ -24,7 +24,7 @@ ...@@ -24,7 +24,7 @@
#include "paddle/phi/kernels/funcs/elementwise_functor.h" #include "paddle/phi/kernels/funcs/elementwise_functor.h"
#include "paddle/phi/kernels/funcs/for_range.h" #include "paddle/phi/kernels/funcs/for_range.h"
// NOTE(@xiongkun): use of IsComplex<> // NOTE(@xiongkun): use of IsComplex<>
#include "paddle/fluid/framework/data_type.h" #include "paddle/phi/core/utils/data_type.h"
namespace phi { namespace phi {
...@@ -152,7 +152,7 @@ void CumprodGradKernel(const Context &dev_ctx, ...@@ -152,7 +152,7 @@ void CumprodGradKernel(const Context &dev_ctx,
const T *y_data_deal; const T *y_data_deal;
Allocator::AllocationPtr x_conj; Allocator::AllocationPtr x_conj;
Allocator::AllocationPtr y_conj; Allocator::AllocationPtr y_conj;
if (paddle::framework::IsComplex<T>::value) { if (phi::IsComplexType(x.dtype())) {
x_conj = const_cast<Allocator &>(dev_ctx.GetAllocator()) x_conj = const_cast<Allocator &>(dev_ctx.GetAllocator())
.Allocate(numel * sizeof(T)); .Allocate(numel * sizeof(T));
auto *x_data_conj = reinterpret_cast<T *>(x_conj->ptr()); auto *x_data_conj = reinterpret_cast<T *>(x_conj->ptr());
......
...@@ -21,8 +21,6 @@ ...@@ -21,8 +21,6 @@
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/fluid/framework/data_type.h"
namespace phi { namespace phi {
template <typename T, typename Context> template <typename T, typename Context>
...@@ -35,8 +33,8 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, ...@@ -35,8 +33,8 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
DenseTensor* out, DenseTensor* out,
DenseTensor* index, DenseTensor* index,
DenseTensor* counts) { DenseTensor* counts) {
auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype); auto data_type = var_type_map[dtype];
if (data_type == paddle::framework::proto::VarType::INT32) { if (data_type == phi::DataType::INT32) {
PADDLE_ENFORCE_LE( PADDLE_ENFORCE_LE(
x.numel() + 1, x.numel() + 1,
INT_MAX, INT_MAX,
...@@ -49,14 +47,14 @@ void UniqueConsecutiveKernel(const Context& dev_ctx, ...@@ -49,14 +47,14 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
// if 'axis' is not required, flatten the Tensor. // if 'axis' is not required, flatten the Tensor.
if (axis.empty()) { if (axis.empty()) {
paddle::framework::VisitDataTypeTiny( phi::VisitDataTypeTiny(
data_type, data_type,
UniqueConsecutiveFlattenedCUDAFunctor<Context, T>( UniqueConsecutiveFlattenedCUDAFunctor<Context, T>(
dev_ctx, x, out, return_inverse, return_counts, index, counts)); dev_ctx, x, out, return_inverse, return_counts, index, counts));
} else { } else {
// 'axis' is required. // 'axis' is required.
int valid_axis = axis[0]; int valid_axis = axis[0];
paddle::framework::VisitDataTypeTiny( phi::VisitDataTypeTiny(
data_type, data_type,
UniqueConsecutiveDimsCUDAFunctor<Context, T>(dev_ctx, UniqueConsecutiveDimsCUDAFunctor<Context, T>(dev_ctx,
x, x,
......
...@@ -16,7 +16,6 @@ ...@@ -16,7 +16,6 @@
#include <cmath> #include <cmath>
#include <string> #include <string>
#include "paddle/fluid/framework/data_type.h"
#include "paddle/phi/backends/cpu/cpu_context.h" #include "paddle/phi/backends/cpu/cpu_context.h"
#include "paddle/phi/backends/gpu/gpu_context.h" #include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/common/data_type.h" #include "paddle/phi/common/data_type.h"
......
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/phi/backends/xpu/xpu_context.h" #include "paddle/phi/backends/xpu/xpu_context.h"
#include "paddle/phi/core/ddim.h" #include "paddle/phi/core/ddim.h"
#include "paddle/phi/core/kernel_registry.h" #include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
namespace phi { namespace phi {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册