未验证 提交 7353e9e9 编写于 作者: Y Yiqun Liu 提交者: GitHub

Fix compiling on XPU related to MPTypeTrait. (#54924)

* Fix compiling on XPU related to MPTypeTrait.

* Unify the use of MPTypeTrait.

* Fix compiling error.
上级 d9fa8fde
......@@ -12,10 +12,10 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_impl.cu.h"
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
namespace paddle {
......@@ -23,7 +23,7 @@ namespace operators {
template <typename T>
struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
......@@ -44,7 +44,7 @@ struct CudaSoftReluFunctor : public BaseActivationFunctor<T> {
template <typename T>
struct CudaSoftReluGradFunctor : public BaseActivationFunctor<T> {
using MPType = typename details::MPTypeTrait<T>::Type;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
MPType one = static_cast<MPType>(1.0f);
float threshold;
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
namespace details {
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<platform::float16> {
public:
using Type = float;
};
template <>
class MPTypeTrait<platform::bfloat16> {
public:
using Type = float;
};
} // namespace details
} // namespace operators
} // namespace paddle
......@@ -19,12 +19,12 @@ limitations under the License. */
#include <curand_kernel.h>
#include "paddle/fluid/memory/memory.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/quant_dequant_kernel.h"
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "paddle/phi/kernels/funcs/layer_norm_impl.cu.h"
......
......@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/operators/fused/fused_dropout_common.h"
#include "paddle/phi/common/amp_type_traits.h"
namespace paddle {
namespace operators {
......@@ -45,8 +46,8 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
OutType *dst,
MaskType *mask,
const bool is_test,
typename details::MPTypeTrait<T>::Type *mean_val,
typename details::MPTypeTrait<T>::Type *var_val,
typename phi::dtype::MPTypeTrait<T>::Type *mean_val,
typename phi::dtype::MPTypeTrait<T>::Type *var_val,
Functor act_func,
const float quant_last_in_scale = 1.0,
const float *dequant_out_scale_data = nullptr,
......@@ -61,7 +62,7 @@ __forceinline__ __device__ void FusedResidualDropoutBiasOneThread(
using StoreOutType = phi::AlignedVector<OutType, VecSize>;
using MaskStoreT = phi::AlignedVector<MaskType, VecSize>;
using U = typename details::MPTypeTrait<T>::Type;
using U = typename phi::dtype::MPTypeTrait<T>::Type;
LoadInType src_vec;
LoadT residual_vec;
......
......@@ -11,11 +11,12 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thrust/random.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/kernels/funcs/index_impl.cu.h"
......@@ -37,7 +38,7 @@ struct GaussianGenerator {
__host__ __device__ T operator()(const unsigned int n) const {
thrust::minstd_rand rng;
rng.seed(seed_);
using MT = typename details::MPTypeTrait<T>::Type;
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
thrust::normal_distribution<MT> dist(static_cast<MT>(mean_),
static_cast<MT>(std_));
unsigned int new_n = n + offset_;
......
......@@ -11,15 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <thrust/device_vector.h>
#include <thrust/host_vector.h>
#include <thrust/random.h>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/generator.h"
#include "paddle/phi/core/hostdevice.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
......
......@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/operators/optimizers/lars_momentum_op.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/kernels/funcs/aligned_vector.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
......@@ -34,7 +34,7 @@ namespace paddle {
namespace operators {
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;
__device__ __forceinline__ float Sqrt(float x) { return sqrtf(x); }
__device__ __forceinline__ double Sqrt(double x) { return sqrt(x); }
......
......@@ -14,9 +14,9 @@ limitations under the License. */
#include <algorithm>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/optimizers/sgd_op.h"
#include "paddle/phi/backends/gpu/gpu_primitives.h"
#include "paddle/phi/common/amp_type_traits.h"
namespace paddle {
namespace operators {
......@@ -77,7 +77,7 @@ class SGDOpKernel<phi::GPUContext, T> : public framework::OpKernel<T> {
ctx.InputNames("Param").front(),
paddle::framework::ToTypeName(param_var->Type())));
using MPDType = typename details::MPTypeTrait<T>::Type;
using MPDType = typename phi::dtype::MPTypeTrait<T>::Type;
auto* param = ctx.Input<phi::DenseTensor>("Param");
auto* param_out = ctx.Output<phi::DenseTensor>("ParamOut");
......
......@@ -21,9 +21,9 @@
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/platform/for_range.h"
#include "paddle/phi/common/amp_type_traits.h"
#ifdef __NVCC__
#include "cub/cub.cuh"
......@@ -37,7 +37,7 @@ namespace paddle {
namespace operators {
template <typename T>
using MultiPrecisionType = typename details::MPTypeTrait<T>::Type;
using MultiPrecisionType = typename phi::dtype::MPTypeTrait<T>::Type;
enum class RegularizationType {
kNONE = 0,
......
......@@ -832,7 +832,7 @@ class ReduceCudaGradKernel : public framework::OpKernel<T> {
pt_out_dtype = d_out->dtype();
}
using MPType = typename kps::details::MPTypeTrait<T>::Type;
using MPType = typename phi::dtype::MPTypeTrait<T>::Type;
phi::ReduceGrad<TransformOp<T, MPType>>(dev_ctx,
pt_d_out.get(),
pt_d_x.get(),
......
......@@ -265,7 +265,7 @@ static void CheckNumericsCpuImpl(const T* value_ptr,
} else if (std::isinf(value)) {
thread_num_inf[tid] += 1;
}
if (value == 0) {
if (value == static_cast<MT>(0)) {
thread_num_zero[tid] += 1;
}
}
......
......@@ -13,7 +13,8 @@
// limitations under the License.
#pragma once
#include "paddle/phi/common/float16.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "xpu/kernel/cluster_header.h"
#include "xpu/kernel/debug.h"
#include "xpu/kernel/math.h"
......@@ -27,18 +28,6 @@ namespace details {
// kLocalMode: thread reduce, each thread gets an output;
enum ReduceMode { kGlobalMode, kLocalMode };
template <typename T>
class MPTypeTrait {
public:
using Type = T;
};
template <>
class MPTypeTrait<phi::dtype::float16> {
public:
using Type = float;
};
static inline __device__ void sync_all() {
__asm__ __volatile__(
"sync_local\t\n"
......
......@@ -17,8 +17,8 @@ limitations under the License. */
#include <random>
#include <vector>
#include "paddle/fluid/operators/amp/fp16_type_traits.h"
#include "paddle/fluid/operators/fused/fused_dropout_act_bias.h"
#include "paddle/phi/common/amp_type_traits.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/functors.h"
#include "test/cpp/fluid/fused/fused_dropout_test.h"
......@@ -30,7 +30,6 @@ PD_DECLARE_KERNEL(dropout_grad, GPU, ALL_LAYOUT);
namespace framework = paddle::framework;
namespace platform = paddle::platform;
namespace details = paddle::operators::details;
/**
* @brief the unittest of fused_dropout_act_bias
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册