Unverified commit 7353e9e9, authored by Yiqun Liu, committed by GitHub

Fix compiling on XPU related to MPTypeTrait. (#54924)

* Fix compiling on XPU related to MPTypeTrait.

* Unify the use of MPTypeTrait.

* Fix compiling error.
Parent d9fa8fde
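For context: MPTypeTrait maps a storage type T to the wider type used for intermediate arithmetic. As the removed fluid header below shows, float16 and bfloat16 map to float while every other type maps to itself; this commit points all call sites at the phi::dtype::MPTypeTrait equivalent. A minimal, self-contained sketch of the pattern (Half and WidenedSum are illustrative stand-ins, not PaddlePaddle code):

#include <iostream>

// Toy stand-in for phi::dtype::float16; the real type packs 16 bits.
struct Half {
  float value;
  explicit Half(float f) : value(f) {}
  explicit operator float() const { return value; }
};

// Default: arithmetic stays in T (float, double, ...).
template <typename T>
struct MPTypeTrait {
  using Type = T;
};

// Specialization: half-precision inputs compute in float.
template <>
struct MPTypeTrait<Half> {
  using Type = float;
};

// Sum n values of T, accumulating in the widened MPType so no per-step
// rounding to 16 bits occurs.
template <typename T>
typename MPTypeTrait<T>::Type WidenedSum(const T* x, int n) {
  using MPType = typename MPTypeTrait<T>::Type;
  MPType acc = static_cast<MPType>(0);
  for (int i = 0; i < n; ++i) {
    acc += static_cast<MPType>(x[i]);
  }
  return acc;
}

int main() {
  Half xs[3] = {Half(0.5f), Half(1.5f), Half(2.0f)};
  std::cout << WidenedSum(xs, 3) << "\n";  // 4, accumulated as float
}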
@@ -12,10 +12,10 @@ limitations under the License. */

 #pragma once
 #include "paddle/fluid/operators/activation_op.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
 #include "paddle/fluid/platform/bfloat16.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/activation_functor.h"

 namespace paddle {
@@ -23,7 +23,7 @@ namespace operators {

 template <typename T>
 struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
-  using MPType = typename details::MPTypeTrait<T>::Type;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   MPType one = static_cast<MPType>(1.0f);
   float threshold;
@@ -44,7 +44,7 @@ struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {

 template <typename T>
 struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
-  using MPType = typename details::MPTypeTrait<T>::Type;
+  using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
   MPType one = static_cast<MPType>(1.0f);
   float threshold;
......
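A note on why these functors route through MPType: with T = float16, evaluating log(1 + exp(x)) directly in half precision overflows exp for moderate x and loses bits near zero, so the functor widens once, computes, and narrows once. A host-side sketch of that cast-compute-cast shape, assuming soft_relu's usual clipped-softplus formula (SoftReluHost is illustrative, not the CUDA functor):

#include <algorithm>
#include <cmath>
#include <cstdio>

// Widen to MPType, do the transcendental math there, narrow at the end.
template <typename T, typename MPType>
T SoftReluHost(T x, MPType threshold) {
  MPType mx = static_cast<MPType>(x);
  MPType t = std::min(std::max(mx, -threshold), threshold);  // clip
  return static_cast<T>(std::log(static_cast<MPType>(1) + std::exp(t)));
}

int main() {
  // With T = float the widening is a no-op; with a half type the exp/log
  // run in float and only the final result is rounded back to 16 bits.
  std::printf("%f\n", SoftReluHost<float, float>(2.0f, 40.0f));  // ~2.126928
}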
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#pragma once

#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"

namespace paddle {
namespace operators {
namespace details {

template <typename T>
class MPTypeTrait {
 public:
  using Type = T;
};

template <>
class MPTypeTrait<platform::float16> {
 public:
  using Type = float;
};

template <>
class MPTypeTrait<platform::bfloat16> {
 public:
  using Type = float;
};

}  // namespace details
}  // namespace operators
}  // namespace paddle
@@ -19,12 +19,12 @@ limitations under the License. */

 #include <curand_kernel.h>

 #include "paddle/fluid/memory/memory.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
 #include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
 #include "paddle/fluid/platform/device_context.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/phi/backends/gpu/gpu_device_function.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/functors.h"
 #include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
......
@@ -15,6 +15,7 @@ limitations under the License. */

 #pragma once

 #include "paddle/fluid/operators/fused/fused_dropout_common.h"
+#include "paddle/phi/common/amp_type_traits.h"

 namespace paddle {
 namespace operators {
@@ -45,8 +46,8 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
     OutType *dst,
     MaskType *mask,
     const bool is_test,
-    typename details::MPTypeTrait<T>::Type *mean_val,
-    typename details::MPTypeTrait<T>::Type *var_val,
+    typename phi::dtype::MPTypeTrait<T>::Type *mean_val,
+    typename phi::dtype::MPTypeTrait<T>::Type *var_val,
     Functor act_func,
     const float quant_last_in_scale = 1.0,
     const float *dequant_out_scale_data = nullptr,
@@ -61,7 +62,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
   using StoreOutType = phi::AlignedVector<OutType, VecSize>;
   using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
-  using U = typename details::MPTypeTrait<T>::Type;
+  using U = typename phi::dtype::MPTypeTrait<T>::Type;

   LoadInType src_vec;
   LoadT residual_vec;
......
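The mean_val and var_val parameters above are pointers to MPType, i.e. float when T is a 16-bit type, so the per-thread statistics feeding the fused layer norm accumulate in full precision. A toy host-side illustration (MeanVarAccum and HalfLike are illustrative stand-ins, not the fused kernel):

#include <cstdio>

struct HalfLike {
  float v;  // toy stand-in for a 16-bit float
};

// Keep the running sums in float (the MPType) even though inputs are half.
void MeanVarAccum(const HalfLike* x, int n, float* mean_val, float* var_val) {
  float sum = 0.f, sq = 0.f;
  for (int i = 0; i < n; ++i) {
    float xi = x[i].v;  // widen each element once
    sum += xi;
    sq += xi * xi;
  }
  *mean_val = sum / n;
  *var_val = sq / n - (*mean_val) * (*mean_val);
}

int main() {
  HalfLike x[4] = {{1.f}, {2.f}, {3.f}, {4.f}};
  float mean = 0.f, var = 0.f;
  MeanVarAccum(x, 4, &mean, &var);
  std::printf("mean=%f var=%f\n", mean, var);  // mean=2.500000 var=1.250000
}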
@@ -11,11 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

 #include <thrust/random.h>

 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/generator.h"
 #include "paddle/phi/kernels/funcs/index_impl.cu.h"
@@ -37,7 +38,7 @@ struct GaussianGenerator {
   __host__ __device__ T operator()(const unsigned int n) const {
     thrust::minstd_rand rng;
     rng.seed(seed_);
-    using MT = typename details::MPTypeTrait<T>::Type;
+    using MT = typename phi::dtype::MPTypeTrait<T>::Type;
     thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
                                          static_cast<MT>(std_));
     unsigned int new_n = n + offset_;
......
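GaussianGenerator draws in MT and only casts to T at the end: normal-distribution templates (thrust's above, std's below) are only specified for built-in floating-point types such as float and double, so a class-type half cannot parameterize the distribution directly. A host-side analogue of the same pattern (SampleGaussian and HalfLike are illustrative):

#include <iostream>
#include <random>

struct HalfLike {
  float v;  // toy stand-in for a 16-bit float
  explicit HalfLike(float f) : v(f) {}
};

// Draw in the widened type MT, narrow to T only at the end.
template <typename T, typename MT>
T SampleGaussian(std::minstd_rand& rng, MT mean, MT stddev) {
  std::normal_distribution<MT> dist(mean, stddev);  // MT must be float/double
  return static_cast<T>(dist(rng));
}

int main() {
  std::minstd_rand rng(42);  // fixed seed, like GaussianGenerator's seed_
  HalfLike h = SampleGaussian<HalfLike, float>(rng, 0.0f, 1.0f);
  std::cout << h.v << "\n";  // one standard-normal sample, rounded to T
}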
@@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

 #pragma once

 #include <thrust/device_vector.h>
 #include <thrust/host_vector.h>
 #include <thrust/random.h>

 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/phi/backends/gpu/gpu_launch_config.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/generator.h"
 #include "paddle/phi/core/hostdevice.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
......
@@ -14,7 +14,7 @@ limitations under the License. */

 #include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/kernels/funcs/aligned_vector.h"
 #include "paddle/phi/kernels/funcs/math_cuda_utils.h"
@@ -34,7 +34,7 @@ namespace paddle {
 namespace operators {

 template <typename T>
-using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
+using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;

 __device__ __forceinline__ float Sqrt(float x) { return sqrtf(x); }
 __device__ __forceinline__ double Sqrt(double x) { return sqrt(x); }
......
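MultiPrecisionType<T> reflects the master-weight pattern these optimizer kernels support: when parameters are fp16, the update is applied to a float master copy and the fp16 weights are re-rounded once per step, so small lr * grad deltas are not lost to half-precision rounding. A toy version of that update, offered as background rather than a description of the LARS kernel itself (MasterSgdStep and HalfLike are illustrative):

#include <cstdio>
#include <vector>

struct HalfLike {
  float v;  // toy stand-in for a 16-bit float
};

// Update the float master copy, then round back to fp16 once per step.
void MasterSgdStep(std::vector<HalfLike>& param,
                   std::vector<float>& master,
                   const std::vector<HalfLike>& grad,
                   float lr) {
  for (size_t i = 0; i < param.size(); ++i) {
    master[i] -= lr * grad[i].v;  // full-precision update
    param[i].v = master[i];       // single rounding point
  }
}

int main() {
  std::vector<HalfLike> param{{1.0f}};
  std::vector<float> master{1.0f};
  std::vector<HalfLike> grad{{0.5f}};
  MasterSgdStep(param, master, grad, 0.1f);
  std::printf("%f\n", param[0].v);  // 0.950000
}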
@@ -14,9 +14,9 @@ limitations under the License. */

 #include <algorithm>

-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/optimizers/sgd_op.h"
 #include "paddle/phi/backends/gpu/gpu_primitives.h"
+#include "paddle/phi/common/amp_type_traits.h"

 namespace paddle {
 namespace operators {
@@ -77,7 +77,7 @@ class SGDOpKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
                             ctx.InputNames("Param").front(),
                             paddle::framework::ToTypeName(param_var->Type())));

-    using MPDType = typename details::MPTypeTrait<T>::Type;
+    using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
     auto* param = ctx.Input<phi::DenseTensor>("Param");
     auto* param_out = ctx.Output<phi::DenseTensor>("ParamOut");
......
@@ -21,9 +21,9 @@

 #include "paddle/fluid/framework/convert_utils.h"
 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/platform/for_range.h"
+#include "paddle/phi/common/amp_type_traits.h"

 #ifdef __NVCC__
 #include "cub/cub.cuh"
@@ -37,7 +37,7 @@ namespace paddle {
 namespace operators {

 template <typename T>
-using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
+using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;

 enum class RegularizationType {
   kNONE = 0,
......
@@ -832,7 +832,7 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
       pt_out_dtype = d_out->dtype();
     }

-    using MPType = typename kps::details::MPTypeTrait<T>::Type;
+    using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
     phi::ReduceGrad<TransformOp<T, MPType>>(dev_ctx,
                                             pt_d_out.get(),
                                             pt_d_x.get(),
......
@@ -265,7 +265,7 @@ static void CheckNumericsCpuImpl(const T* value_ptr,
     } else if (std::isinf(value)) {
       thread_num_inf[tid] += 1;
     }
-    if (value == 0) {
+    if (value == static_cast<MT>(0)) {
       thread_num_zero[tid] += 1;
     }
   }
......
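The cast in the zero check matters when MT is a class type: value == 0 relies on an implicit int-to-MT conversion that can be absent or ambiguous for a half type (plausibly the kind of XPU compile error this commit fixes), while value == static_cast<MT>(0) names the conversion explicitly. A minimal reproduction with a toy type (Fp16 is illustrative):

struct Fp16 {
  float v;
  explicit Fp16(float f) : v(f) {}
  bool operator==(const Fp16& o) const { return v == o.v; }
};

int main() {
  Fp16 value(0.0f);
  // bool bad = (value == 0);  // ill-formed here: no Fp16 == int overload,
  //                           // and the Fp16(float) constructor is explicit
  bool ok = (value == static_cast<Fp16>(0));  // conversion spelled out
  return ok ? 0 : 1;
}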
@@ -13,7 +13,8 @@
 // limitations under the License.

 #pragma once
-#include "paddle/phi/common/float16.h"
+
+#include "paddle/phi/common/amp_type_traits.h"
 #include "xpu/kernel/cluster_header.h"
 #include "xpu/kernel/debug.h"
 #include "xpu/kernel/math.h"
@@ -27,18 +28,6 @@ namespace details {
 // kLocalMode: thread reduce, each thread gets an output;
 enum ReduceMode { kGlobalMode, kLocalMode };

-template <typename T>
-class MPTypeTrait {
- public:
-  using Type = T;
-};
-
-template <>
-class MPTypeTrait<phi::dtype::float16> {
- public:
-  using Type = float;
-};
-
 static inline __device__ void sync_all() {
   __asm__ __volatile__(
       "sync_local\t\n"
......
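Deleting the local kps::details::MPTypeTrait leaves a single shared definition; note the XPU copy above specialized only float16, while the removed fluid header also covered bfloat16. A sketch of the consistency the unification buys, assuming phi::dtype::MPTypeTrait specializes both 16-bit types (UnifiedTrait is a stand-in for it):

#include <type_traits>

struct float16 {};
struct bfloat16 {};

template <typename T> struct UnifiedTrait { using Type = T; };
template <> struct UnifiedTrait<float16> { using Type = float; };
template <> struct UnifiedTrait<bfloat16> { using Type = float; };

static_assert(std::is_same<UnifiedTrait<float16>::Type, float>::value,
              "fp16 widens to float on every backend");
static_assert(std::is_same<UnifiedTrait<bfloat16>::Type, float>::value,
              "bf16 widens too; the deleted XPU copy lacked this case");
static_assert(std::is_same<UnifiedTrait<double>::Type, double>::value,
              "wider types pass through unchanged");

int main() { return 0; }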
@@ -17,8 +17,8 @@ limitations under the License. */
 #include <random>
 #include <vector>

-#include "paddle/fluid/operators/amp/fp16_type_traits.h"
 #include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
+#include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/core/kernel_registry.h"
 #include "paddle/phi/kernels/funcs/functors.h"
 #include "test/cpp/fluid/fused/fused_dropout_test.h"
@@ -30,7 +30,6 @@ PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT);
 namespace framework = paddle::framework;
 namespace platform = paddle::platform;
-namespace details = paddle::operators::details;

 /**
  * @brief the unittest of fused_dropout_act_bias
......