未验证 提交 d463f8ee 编写于 作者: iSerendipity's avatar iSerendipity 提交者: GitHub

Revert "【Hackathon No.52】为 Paddle dist 算子实现 float16 数据类型支持 (#50915)" (#53527)

This reverts commit 9c406531.
上级 0d9a23b4
......@@ -98,11 +98,6 @@ PD_REGISTER_KERNEL(
dist_grad, CPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(dist_grad,
GPU,
ALL_LAYOUT,
phi::DistGradKernel,
phi::dtype::float16,
float,
double) {}
PD_REGISTER_KERNEL(
dist_grad, GPU, ALL_LAYOUT, phi::DistGradKernel, float, double) {}
#endif
......@@ -23,9 +23,6 @@ limitations under the License. */
#include <algorithm>
#include "paddle/phi/backends/gpu/gpu_device_function.h"
#include "paddle/phi/common/data_type.h"
namespace phi {
namespace funcs {
......@@ -173,7 +170,11 @@ struct KeyValuePair<half> {
template <typename T>
__inline__ __device__ T WarpReduceSum(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
val += phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask);
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val += __shfl_xor_sync(lane_mask, val, mask, warpSize);
#else
val += __shfl_xor(val, mask, warpSize);
#endif
return val;
}
......@@ -242,8 +243,11 @@ __inline__ __device__ T BlockReduceSumV2(T *val) {
template <typename T>
__inline__ __device__ T WarpReduceMax(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
val = std::max(
val, phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask));
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val = max(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
#else
val = max(val, __shfl_xor(val, mask, warpSize));
#endif
return val;
}
......@@ -261,8 +265,11 @@ __inline__ __device__ T WarpReduceMaxV2(T *val) {
template <typename T>
__inline__ __device__ T WarpReduceMin(T val, unsigned lane_mask) {
for (int mask = HALF_WARP; mask > 0; mask >>= 1)
val = std::min(
val, phi::backends::gpu::CudaShuffleXorSync(lane_mask, val, mask));
#if defined(PADDLE_WITH_CUDA) && (__CUDA_ARCH__ >= 350 && CUDA_VERSION >= 9000)
val = min(val, __shfl_xor_sync(lane_mask, val, mask, warpSize));
#else
val = min(val, __shfl_xor(val, mask, warpSize));
#endif
return val;
}
......@@ -303,7 +310,7 @@ __inline__ __device__ T BlockReduceMax(T val, unsigned mask) {
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (lane < block_span) ? shared[lane] : std::numeric_limits<T>::min();
val = (lane < block_span) ? shared[lane] : -1e10f;
val = WarpReduceMax(val, mask);
return val;
......@@ -351,7 +358,7 @@ __inline__ __device__ T BlockReduceMin(T val, unsigned mask) {
// align block_span to warpSize
int block_span = (blockDim.x + warpSize - 1) >> 5;
val = (lane < block_span) ? shared[lane] : std::numeric_limits<T>::max();
val = (lane < block_span) ? shared[lane] : 1e10f;
val = WarpReduceMin(val, mask);
return val;
......
......@@ -12,12 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <algorithm>
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/dist_kernel.h"
#include "paddle/phi/kernels/elementwise_subtract_kernel.h"
#include "paddle/phi/kernels/funcs/math_cuda_utils.h"
#include "paddle/phi/kernels/gpu/reduce.h"
......@@ -27,56 +24,47 @@ namespace phi {
#define FULL_MASK 0xffffffff
template <typename Tx, typename Ty = Tx>
template <typename T>
struct ZeroOrderFunctor {
HOSTDEVICE explicit inline ZeroOrderFunctor() {}
HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
return static_cast<Ty>(x != y);
public:
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>((x - y) != 0);
}
};
template <typename Tx, typename Ty = Tx>
template <typename T>
struct OtherOrderFunctor {
HOSTDEVICE explicit inline OtherOrderFunctor(const Ty& _p_order)
: p_order(_p_order) {}
HOSTDEVICE inline Ty operator()(const Tx& x, const Tx& y) const {
return static_cast<Ty>(
pow(abs(static_cast<Ty>(x) - static_cast<Ty>(y)), p_order));
explicit OtherOrderFunctor(const T& p_order) : p_order_(p_order) {}
__device__ T operator()(const T& x, const T& y) const {
return static_cast<T>(pow(abs(x - y), p_order_));
}
private:
Ty p_order;
T p_order_;
};
template <typename Tx, typename Ty = Tx>
template <typename T>
struct PowFunctor {
HOSTDEVICE explicit inline PowFunctor(const Ty& _p_order)
: p_order(_p_order) {}
HOSTDEVICE inline Tx operator()(const Tx x) const {
return static_cast<Tx>(pow(static_cast<Ty>(x), p_order));
explicit PowFunctor(const T& p_order) : p_order_(p_order) {}
HOSTDEVICE inline T operator()(const T x) const {
return static_cast<T>(pow(x, p_order_));
}
private:
Ty p_order;
T p_order_;
};
template <typename T, typename Functor>
__global__ void ReduceSumWithSubtract(
const T* x, const T* y, T* out, int64_t N, Functor func) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
MT sum_val(0.0);
T sum_val = 0;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
sum_val += static_cast<MT>(func(x[i], y[i]));
sum_val += func(x[i], y[i]);
}
__syncthreads();
sum_val = phi::funcs::BlockReduceSum<MT>(sum_val, FULL_MASK);
sum_val = phi::funcs::BlockReduceSum<T>(sum_val, FULL_MASK);
if (threadIdx.x == 0) {
out[blockIdx.x] = static_cast<T>(sum_val);
out[blockIdx.x] = sum_val;
}
}
......@@ -85,10 +73,10 @@ __global__ void ReduceMaxWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T max_val = std::numeric_limits<T>::min();
T max_val = -1e10f;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
max_val = std::max(max_val, abs(x[i] - y[i]));
max_val = max(max_val, abs(x[i] - y[i]));
}
__syncthreads();
......@@ -103,10 +91,10 @@ __global__ void ReduceMinWithSubtract(const T* x,
const T* y,
T* out,
int64_t N) {
T min_val = std::numeric_limits<T>::max();
T min_val = 1e10f;
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < N;
i += blockDim.x * gridDim.x) {
min_val = std::min(min_val, abs(x[i] - y[i]));
min_val = min(min_val, abs(x[i] - y[i]));
}
__syncthreads();
......@@ -122,7 +110,6 @@ void DistKernel(const Context& dev_ctx,
const DenseTensor& y,
float p,
DenseTensor* out) {
using MT = typename phi::dtype::MPTypeTrait<T>::Type;
DenseTensor intermediate;
const T* x_ptr = x.data<T>();
const T* y_ptr = y.data<T>();
......@@ -144,8 +131,9 @@ void DistKernel(const Context& dev_ctx,
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, ZeroOrderFunctor<T>());
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
} else if (p == INFINITY) {
ReduceMaxWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
......@@ -162,19 +150,19 @@ void DistKernel(const Context& dev_ctx,
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
} else {
MT p_order = static_cast<MT>(p);
T p_order = static_cast<T>(p);
ReduceSumWithSubtract<T>
<<<config.block_per_grid.x, config.thread_per_block.x, 0, stream>>>(
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T, MT>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<MT>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<MT>(), reduce_axis);
x_ptr, y_ptr, i_ptr, n, OtherOrderFunctor<T>(p_order));
phi::funcs::ReduceKernel<T, T, kps::AddFunctor, kps::IdentityFunctor<T>>(
dev_ctx, intermediate, out, kps::IdentityFunctor<T>(), reduce_axis);
const DenseTensor* tmp_norm = out;
std::vector<const DenseTensor*> ins = {tmp_norm};
std::vector<DenseTensor*> outs = {out};
MT p_order_ = static_cast<MT>(static_cast<MT>(1.) / p_order);
T p_order_ = static_cast<T>(1. / p_order);
phi::funcs::ElementwiseKernel<T>(
dev_ctx, ins, &outs, PowFunctor<T, MT>(p_order_));
dev_ctx, ins, &outs, PowFunctor<T>(p_order_));
}
} else {
......@@ -185,10 +173,4 @@ void DistKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(dist,
GPU,
ALL_LAYOUT,
phi::DistKernel,
phi::dtype::float16,
float,
double) {}
PD_REGISTER_KERNEL(dist, GPU, ALL_LAYOUT, phi::DistKernel, float, double) {}
......@@ -158,46 +158,6 @@ class TestDistOpCase5(TestDistOp):
self.p = 1.5
class TestDistFP16Op(OpTest):
def init_data_type(self):
self.data_type = 'float16'
class TestDistFP16OpCase1(TestDistFP16Op):
def init_case(self):
self.x_shape = (3, 5, 5, 6)
self.y_shape = (5, 5, 6)
self.p = 1.0
class TestDistFP16OpCase2(TestDistFP16Op):
def init_case(self):
self.x_shape = (10, 10)
self.y_shape = (4, 10, 10)
self.p = 2.0
class TestDistFP16OpCase3(TestDistFP16Op):
def init_case(self):
self.x_shape = (15, 10)
self.y_shape = (15, 10)
self.p = float("inf")
class TestDistFP16OpCase4(TestDistFP16Op):
def init_case(self):
self.x_shape = (2, 3, 4, 5, 8)
self.y_shape = (3, 1, 5, 8)
self.p = float("-inf")
class TestDistFP16OpCase5(TestDistFP16Op):
def init_case(self):
self.x_shape = (4, 1, 4, 8)
self.y_shape = (2, 2, 1, 4, 4, 8)
self.p = 1.5
class TestDistAPI(unittest.TestCase):
def init_data_type(self):
self.data_type = (
......
......@@ -675,8 +675,8 @@ def dist(x, y, p=2, name=None):
||z||_{p}=(\sum_{i=1}^{m}|z_i|^p)^{\\frac{1}{p}}
Args:
x (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float16, float32 or float64.
x (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
y (Tensor): 1-D to 6-D Tensor, its data type is float32 or float64.
p (float, optional): The norm to be computed, its data type is float32 or float64. Default: 2.
name (str, optional): The default value is `None`. Normally there is no need for
user to set this property. For more information, please refer to :ref:`api_guide_Name`.
......@@ -706,12 +706,8 @@ def dist(x, y, p=2, name=None):
if in_dygraph_mode():
return _C_ops.dist(x, y, p)
check_variable_and_dtype(
x, 'dtype', ['float16', 'float32', 'float64'], 'dist'
)
check_variable_and_dtype(
y, 'dtype', ['float16', 'float32', 'float64'], 'dist'
)
check_variable_and_dtype(x, 'dtype', ['float32', 'float64'], 'dist')
check_variable_and_dtype(y, 'dtype', ['float32', 'float64'], 'dist')
check_type(p, 'p', (float, int), 'dist')
helper = LayerHelper("dist", **locals())
out = helper.create_variable_for_type_inference(x.dtype)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册