Unverified · Commit 1631836f authored by Wang Xin, committed by GitHub

[PHI decoupling] remove framework/data_type.h from phi (#47776)

* remove framework/data_type.h from phi

* fix CI failure: map proto::VarType to phi::DataType

* refactor code to add more detailed comments
Parent 7e914386
@@ -14,6 +14,7 @@ limitations under the License. */
 #pragma once
 
 #include <iostream>
+#include <map>
 #include <string>
 #include <typeindex>
@@ -23,6 +24,23 @@ limitations under the License. */
 namespace phi {
 
+// Here we can't depend on the fluid proto::VarType, so we use the dtype enum
+// value directly. See also `assign_value_sig.cc`.
+// proto::VarType::INT16 ->  1 -> phi::DataType::INT16
+// proto::VarType::INT32 ->  2 -> phi::DataType::INT32
+// proto::VarType::INT64 ->  3 -> phi::DataType::INT64
+// proto::VarType::FP16  ->  4 -> phi::DataType::FLOAT16
+// proto::VarType::FP32  ->  5 -> phi::DataType::FLOAT32
+// proto::VarType::FP64  ->  6 -> phi::DataType::FLOAT64
+// proto::VarType::UINT8 -> 20 -> phi::DataType::UINT8
+static std::map<int, phi::DataType> var_type_map{{1, phi::DataType::INT16},
+                                                 {2, phi::DataType::INT32},
+                                                 {3, phi::DataType::INT64},
+                                                 {4, phi::DataType::FLOAT16},
+                                                 {5, phi::DataType::FLOAT32},
+                                                 {6, phi::DataType::FLOAT64},
+                                                 {20, phi::DataType::UINT8}};
+
 #define _PhiForEachDataTypeHelper_(callback, cpp_type, data_type) \
   callback(cpp_type, data_type);
...
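For reference, a minimal sketch of how this mapping is meant to be consumed. The map contents and enum values come from the hunk above; the lookup helper `VarTypeIntToDataType` and its out-of-range handling are illustrative assumptions, not part of the patch:

```cpp
#include <map>

#include "paddle/phi/common/data_type.h"

// Hypothetical helper (not in this patch): translate the integer value of
// fluid's proto::VarType enum into phi::DataType without including any
// fluid proto header, mirroring the var_type_map added above.
static const std::map<int, phi::DataType> kVarTypeMap{
    {1, phi::DataType::INT16},   {2, phi::DataType::INT32},
    {3, phi::DataType::INT64},   {4, phi::DataType::FLOAT16},
    {5, phi::DataType::FLOAT32}, {6, phi::DataType::FLOAT64},
    {20, phi::DataType::UINT8}};

inline phi::DataType VarTypeIntToDataType(int var_type) {
  auto it = kVarTypeMap.find(var_type);
  // find() rather than operator[] so an unmapped enum value is reported
  // instead of silently inserting a defaulted entry into the map.
  return it == kVarTypeMap.end() ? phi::DataType::UNDEFINED : it->second;
}
```

Note that the patch itself uses `var_type_map[dtype]` directly, which relies on callers passing a valid enum value.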
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
@@ -141,15 +142,14 @@ void ArgMinMaxKernel(const Context& dev_ctx,
                      int dtype,
                      DenseTensor* out) {
   if (dtype < 0) {
-    paddle::framework::VisitDataTypeTiny(
-        static_cast<paddle::framework::proto::VarType::Type>(
-            paddle::framework::proto::VarType::INT64),
+    phi::VisitDataTypeTiny(
+        phi::DataType::INT64,
         VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
             dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
     return;
   }
-  paddle::framework::VisitDataTypeTiny(
-      static_cast<paddle::framework::proto::VarType::Type>(dtype),
+  phi::VisitDataTypeTiny(
+      var_type_map[dtype],
       VisitDataArgMinMaxFunctor<Context, T, EnumArgMinMaxValue>(
           dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
 }
...
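The definition of `phi::VisitDataTypeTiny` is outside this diff. As a rough sketch of the dispatch pattern it implements (an assumption based on its fluid counterpart, not the real implementation): switch on the runtime phi::DataType and instantiate the visitor's templated `apply<T>()` for the matching C++ type.

```cpp
#include <cstdint>
#include <stdexcept>

#include "paddle/phi/common/data_type.h"

// Sketch only: dispatch a runtime phi::DataType to a compile-time type and
// invoke the visitor with it. The real visitor covers more dtypes and uses
// Paddle's own error machinery.
template <typename Visitor>
void VisitDataTypeTinySketch(phi::DataType dtype, Visitor visitor) {
  switch (dtype) {
    case phi::DataType::INT32:
      visitor.template apply<int32_t>();  // functor body runs with T = int32_t
      break;
    case phi::DataType::INT64:
      visitor.template apply<int64_t>();  // functor body runs with T = int64_t
      break;
    default:
      throw std::runtime_error("unsupported dtype in tiny visitor sketch");
  }
}
```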
@@ -23,7 +23,7 @@
 #include "paddle/phi/kernels/funcs/for_range.h"
 // NOTE(@xiongkun): use of IsComplex<>
-#include "paddle/fluid/framework/data_type.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
 
 template <typename T, typename Context>
@@ -51,7 +51,7 @@ void CumprodGradKernel(const Context& dev_ctx,
   const T* out_data_deal;
   Allocator::AllocationPtr x_conj;
   Allocator::AllocationPtr out_conj;
-  if (paddle::framework::IsComplex<T>::value) {
+  if (phi::IsComplexType(x.dtype())) {
     x_conj = const_cast<Allocator&>(dev_ctx.GetAllocator())
                  .Allocate(numel * sizeof(T));
     auto* x_data_conj = reinterpret_cast<T*>(x_conj->ptr());
...
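One subtlety in this hunk: the removed `paddle::framework::IsComplex<T>::value` was a compile-time trait on the kernel's template parameter, whereas `phi::IsComplexType(x.dtype())` inspects the tensor's runtime dtype. For a correctly registered kernel the two agree, since T matches the tensor's dtype. A compile-time equivalent expressed purely with phi types might look like the following sketch (hypothetical `IsComplexT` trait, not part of the patch):

```cpp
#include <type_traits>

#include "paddle/phi/common/complex.h"

// Hypothetical compile-time replacement for fluid's IsComplex<T>::value,
// written against phi's own complex types.
template <typename T>
struct IsComplexT
    : std::integral_constant<
          bool,
          std::is_same<T, phi::dtype::complex<float>>::value ||
              std::is_same<T, phi::dtype::complex<double>>::value> {};

static_assert(IsComplexT<phi::dtype::complex<float>>::value,
              "complex<float> is complex");
static_assert(!IsComplexT<float>::value, "float is not complex");
```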
@@ -18,8 +18,7 @@
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/core/errors.h"
 #include "paddle/phi/core/kernel_registry.h"
-#include "paddle/fluid/framework/data_type.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
@@ -33,8 +32,8 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
                              DenseTensor* out,
                              DenseTensor* index,
                              DenseTensor* counts) {
-  auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype);
-  if (data_type == paddle::framework::proto::VarType::INT32) {
+  auto data_type = var_type_map[dtype];
+  if (data_type == phi::DataType::INT32) {
     PADDLE_ENFORCE_LE(
         x.numel(),
         INT_MAX,
@@ -46,13 +45,13 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
   }
   if (axis.empty()) {
-    paddle::framework::VisitDataTypeTiny(
+    phi::VisitDataTypeTiny(
         data_type,
         UniqueConsecutiveFlattenedTensorFunctor<Context, T>(
             dev_ctx, x, out, return_inverse, return_counts, index, counts));
   } else {
     int valid_axis = axis[0];
-    paddle::framework::VisitDataTypeTiny(
+    phi::VisitDataTypeTiny(
         data_type,
         UniqueConsecutiveDimFunctor<Context, T>(dev_ctx,
                                                 x,
...
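The `dtype` attribute of unique_consecutive arrives as a plain int carrying fluid's proto::VarType enum value (e.g. 2 for INT32), which is why the kernel can resolve it with a single `var_type_map[dtype]` lookup before dispatching. A minimal standalone illustration (the real map lives in paddle/phi/core/utils/data_type.h; it is redeclared here only so the example is self-contained):

```cpp
#include <cassert>
#include <map>

#include "paddle/phi/common/data_type.h"

int main() {
  // Subset of the map added in this patch, redeclared for the example.
  std::map<int, phi::DataType> var_type_map{{2, phi::DataType::INT32},
                                            {3, phi::DataType::INT64}};
  int dtype = 2;  // attribute value: proto::VarType::INT32
  assert(var_type_map[dtype] == phi::DataType::INT32);
  return 0;
}
```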
@@ -26,10 +26,10 @@ limitations under the License. */
 #include <utility>
 #include <vector>
 
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/kernels/funcs/eigen/common.h"
 #include "paddle/phi/kernels/funcs/math_function_impl.h"
 #include "unsupported/Eigen/CXX11/Tensor"
...
@@ -14,12 +14,12 @@ limitations under the License. */
 #include <algorithm>
 #include <vector>
 
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/memory/memcpy.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/kernels/funcs/blas/blas.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 #include "paddle/phi/kernels/funcs/math_function_impl.h"
...
@@ -16,7 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <vector>
 
-#include "paddle/fluid/framework/data_type.h"
+#include "paddle/phi/common/data_type.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
 
 namespace phi {
...
@@ -28,9 +28,8 @@ namespace cub = hipcub;
 #endif
 #include <limits>
 
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/core/ddim.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
 namespace {  // NOLINT
@@ -209,15 +208,14 @@ void ArgMinMaxOpCUDAKernel(const Context& dev_ctx,
                            int dtype,
                            DenseTensor* out) {
   if (dtype < 0) {
-    paddle::framework::VisitDataTypeTiny(
-        static_cast<paddle::framework::proto::VarType::Type>(
-            paddle::framework::proto::VarType::INT64),
+    phi::VisitDataTypeTiny(
+        phi::DataType::INT64,
         VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
             dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
     return;
   }
-  paddle::framework::VisitDataTypeTiny(
-      static_cast<paddle::framework::proto::VarType::Type>(dtype),
+  phi::VisitDataTypeTiny(
+      var_type_map[dtype],
       VisitDataCudaArgMinMaxFunctor<Context, T, Reducer>(
           dev_ctx, x, axis.to<int64_t>(), keepdims, flatten, out));
 }
...
@@ -24,7 +24,7 @@
 #include "paddle/phi/kernels/funcs/elementwise_functor.h"
 #include "paddle/phi/kernels/funcs/for_range.h"
 // NOTE(@xiongkun): use of IsComplex<>
-#include "paddle/fluid/framework/data_type.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
@@ -152,7 +152,7 @@ void CumprodGradKernel(const Context &dev_ctx,
   const T *y_data_deal;
   Allocator::AllocationPtr x_conj;
   Allocator::AllocationPtr y_conj;
-  if (paddle::framework::IsComplex<T>::value) {
+  if (phi::IsComplexType(x.dtype())) {
     x_conj = const_cast<Allocator &>(dev_ctx.GetAllocator())
                  .Allocate(numel * sizeof(T));
     auto *x_data_conj = reinterpret_cast<T *>(x_conj->ptr());
...
@@ -21,8 +21,6 @@
 #include "paddle/phi/core/errors.h"
 #include "paddle/phi/core/kernel_registry.h"
 
-#include "paddle/fluid/framework/data_type.h"
-
 namespace phi {
 
 template <typename T, typename Context>
@@ -35,8 +33,8 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
                              DenseTensor* out,
                              DenseTensor* index,
                              DenseTensor* counts) {
-  auto data_type = static_cast<paddle::framework::proto::VarType::Type>(dtype);
-  if (data_type == paddle::framework::proto::VarType::INT32) {
+  auto data_type = var_type_map[dtype];
+  if (data_type == phi::DataType::INT32) {
     PADDLE_ENFORCE_LE(
         x.numel() + 1,
         INT_MAX,
@@ -49,14 +47,14 @@ void UniqueConsecutiveKernel(const Context& dev_ctx,
   // if 'axis' is not required, flatten the Tensor.
   if (axis.empty()) {
-    paddle::framework::VisitDataTypeTiny(
+    phi::VisitDataTypeTiny(
         data_type,
         UniqueConsecutiveFlattenedCUDAFunctor<Context, T>(
             dev_ctx, x, out, return_inverse, return_counts, index, counts));
   } else {
     // 'axis' is required.
     int valid_axis = axis[0];
-    paddle::framework::VisitDataTypeTiny(
+    phi::VisitDataTypeTiny(
         data_type,
         UniqueConsecutiveDimsCUDAFunctor<Context, T>(dev_ctx,
                                                      x,
...
@@ -16,7 +16,6 @@
 #include <cmath>
 #include <string>
 
-#include "paddle/fluid/framework/data_type.h"
 #include "paddle/phi/backends/cpu/cpu_context.h"
 #include "paddle/phi/backends/gpu/gpu_context.h"
 #include "paddle/phi/common/data_type.h"
...
@@ -17,6 +17,7 @@
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/core/ddim.h"
 #include "paddle/phi/core/kernel_registry.h"
+#include "paddle/phi/core/utils/data_type.h"
 
 namespace phi {
...