diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
index 9f336f2dab663e8624ec95103ea2c27f80d74b1e..6c2c685601b9412acf76e7603376ef7623b877ba 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cu
@@ -15,110 +15,216 @@
  */
 
 #include <vector>
+#include <iostream>
 #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
 #include "runtime/device/gpu/cuda_common.h"
 
-template <typename T, typename S>
+// Basic function
+template <typename T>
 struct GreaterFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; }
+  __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; }
 };
 
-template <typename T, typename S>
+template <typename T>
 struct LessFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
+  __device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
 };
 
-template <typename T, typename S>
+template <typename T>
 struct MinimumFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
};
 
-template <typename T, typename S>
+template <typename T>
 struct MaximumFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
 };
 
-template <typename T, typename S>
+template <typename T>
 struct PowerFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
 };
 
 template <>
-struct PowerFunc<half, half> {
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
+struct PowerFunc<half> {
+  __device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
     return __float2half(pow(__half2float(lhs), __half2float(rhs)));
   }
 };
 
-template <typename T, typename S>
+template <>
+struct PowerFunc<half2> {
+  __device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
+    float2 base = __half22float2(lhs);
+    float2 index = __half22float2(rhs);
+    base.x = pow(base.x, index.x);
+    base.y = pow(base.y, index.y);
+    return __float22half2_rn(base);
+  }
+};
+
+template <typename T>
 struct RealDivFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
 };
 
-template <typename T, typename S>
+template <typename T>
 struct DivFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
 };
 
-template <typename T, typename S>
+template <typename T>
 struct MulFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
 };
 
-template <typename T, typename S>
+template <typename T>
 struct SubFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
 };
 
-template <typename T, typename S>
+template <typename T>
 struct AddFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
 };
 
-template <typename T, typename S>
+// convert to float to fix accuracy issue
+template <typename T>
 struct FloorDivFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
-    return floor(static_cast<float>(lhs) / static_cast<float>(rhs));
+  __device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
+    return floorf(static_cast<float>(lhs) / static_cast<float>(rhs));
   }
 };
 
 template <>
-struct FloorDivFunc<half, half> {
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
-    return __float2half(floor(__half2float(lhs) / __half2float(rhs)));
+struct FloorDivFunc<half> {
+  __device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
+    return floorf(__half2float(lhs) / __half2float(rhs));
   }
 };
 
 template <>
-struct FloorDivFunc<half, bool> {
-  // invalid branch
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
+struct FloorDivFunc<half2> {
+  __device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
+    float2 l = __half22float2(lhs);
+    float2 r = __half22float2(rhs);
+    l.x = floorf(l.x / r.x);
+    l.y = floorf(l.y / r.y);
+    return __float22half2_rn(l);
+  }
 };
 
-template <typename T, typename S>
+template <typename T>
 struct AbsGradFunc {
-  __device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
+  __device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
     T zero = 0.0;
     return lhs < zero ? -rhs : rhs;
   }
 };
 
 template <>
-struct PowerFunc<half, bool> {
-  // invalid branch
-  __device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
+struct AbsGradFunc<half2> {
+  __device__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
+    half2 zero(0.0, 0.0);
+    return lhs < zero ? -rhs : rhs;
+  }
 };
 
+// Element-wise Comparison
+template <typename T, typename Func>
+__global__ void ElewiseCmpKernel(const int nums, const T *x0, const T *x1, bool *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
+    y[pos] = Func()(x0[pos], x1[pos]);
+  }
+}
+
+template <typename T>
+void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
+  switch (op) {
+    case BROADCAST_TYPE_GREATER:
+      return ElewiseCmpKernel<T, GreaterFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_LESS:
+      return ElewiseCmpKernel<T, LessFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    default:
+      break;
+  }
+}
+
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, bool *y,
+                         cudaStream_t stream);
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, bool *y,
+                         cudaStream_t stream);
+template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, bool *y,
+                         cudaStream_t stream);
+
+// Element-wise Arithmetic
+template <typename T, typename Func>
+__global__ void ElewiseArithKernel(const int nums, const T *x0, const T *x1, T *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
+    y[pos] = Func()(x0[pos], x1[pos]);
+  }
+}
+
+template <typename T>
+void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
+  switch (op) {
+    case BROADCAST_TYPE_MINIMUM:
+      return ElewiseArithKernel<T, MinimumFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_MAXIMUM:
+      return ElewiseArithKernel<T, MaximumFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_POWER:
+      return ElewiseArithKernel<T, PowerFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_REALDIV:
+      return ElewiseArithKernel<T, RealDivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_MUL:
+      return ElewiseArithKernel<T, MulFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_SUB:
+      return ElewiseArithKernel<T, SubFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_ADD:
+      return ElewiseArithKernel<T, AddFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_FLOORDIV:
+      return ElewiseArithKernel<T, FloorDivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_ABSGRAD:
+      return ElewiseArithKernel<T, AbsGradFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    case BROADCAST_TYPE_DIV:
+      return ElewiseArithKernel<T, DivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
+    default:
+      break;
+  }
+}
+
+template <typename T>
+void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
+  return ElewiseArithKernel(nums, op, x0, x1, y, stream);
+}
+
+template <>
+void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
+                  cudaStream_t stream) {
+  if (nums % 2 == 0) {
+    ElewiseArithKernel<half2>(nums / 2, op, reinterpret_cast<const half2 *>(x0), reinterpret_cast<const half2 *>(x1),
+                              reinterpret_cast<half2 *>(y), stream);
+  } else {
+    return ElewiseArithKernel(nums, op, x0, x1, y, stream);
+  }
+}
+
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, float *y,
+                           cudaStream_t stream);
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
+                           cudaStream_t stream);
+template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, int *y,
+                           cudaStream_t stream);
+
+// Broadcast comparison
 __device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
 
-template <typename T, typename S, typename Func>
-__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
-                                                  const int &l4, const int &l5, const int &l6, const int &r0,
-                                                  const int &r1, const int &r2, const int &r3, const int &r4,
-                                                  const int &r5, const int &r6, const int &d0, const int &d1,
-                                                  const int &d2, const int &d3, const int &d4, const int &d5,
-                                                  const int &d6, const T *input0, const T *input1, S *output) {
+template <typename T, typename Func>
+__global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
+                                   const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
+                                   const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
+                                   const int d4, const int d5, const int d6, const T *x0, const T *x1, bool *y) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
        pos += blockDim.x * gridDim.x) {
     int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
@@ -143,115 +249,152 @@ __device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1,
     r_index += Index(m, r4) * r5 * r6;
     r_index += Index(n, r5) * r6;
     r_index += Index(o, r6);
-    output[pos] = Func()(input0[l_index], input1[r_index]);
+    y[pos] = Func()(x0[l_index], x1[r_index]);
   }
 }
 
-template <typename T, typename S>
-__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
-                                const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
-                                const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
-                                const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
-                                const T *input1, S *output) {
+template <typename T>
+void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
+                  enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
+  int size = 1;
+  for (auto d : y_dims) {
+    size *= d;
+  }
+
   switch (op) {
     case BROADCAST_TYPE_GREATER:
-      return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
-                                                        d2, d3, d4, d5, d6, input0, input1, output);
+      return BroadcastCmpKernel<T, GreaterFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_LESS:
-      return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
-                                                     d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_MINIMUM:
-      return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
-                                                        d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_MAXIMUM:
-      return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
-                                                        d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_POWER:
-      return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
-                                                      d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_REALDIV:
-      return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
-                                                        d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_MUL:
-      return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
-                                                    d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_SUB:
-      return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
-                                                    d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_ADD:
-      return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
-                                                    d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_FLOORDIV:
-      return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
-                                                         d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_ABSGRAD:
-      return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
-                                                        d2, d3, d4, d5, d6, input0, input1, output);
-    case BROADCAST_TYPE_DIV:
-      return BroadcastOperator<T, S, DivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
-                                                    d3, d4, d5, d6, input0, input1, output);
+      return BroadcastCmpKernel<T, LessFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    default:
+      break;
   }
 }
 
-template <typename T, typename S>
-void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-               const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
-               S *output, cudaStream_t stream) {
-  int size = 1;
-  for (auto d : output_shape) {
-    size *= d;
-  }
-  BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(
-    lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3], lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
-    rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4], rhs_shape[5], rhs_shape[6], output_shape[0],
-    output_shape[1], output_shape[2], output_shape[3], output_shape[4], output_shape[5], output_shape[6], op, input0,
-    input1, output);
-}
+template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                           const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
+                           bool *y, cudaStream_t stream);
+template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                           const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
+                           bool *y, cudaStream_t stream);
+template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                           const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
+                           bool *y, cudaStream_t stream);
 
-template <typename T, typename S, typename Func>
-__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
-    output[pos] = Func()(input0[pos], input1[pos]);
+// Broadcast Arithmetic
+template <typename T, typename Func>
+__global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
+                                     const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
+                                     const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
+                                     const int d4, const int d5, const int d6, const T *x0, const T *x1, T *y) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
+       pos += blockDim.x * gridDim.x) {
+    int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
+    int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
+    int k = pos / (d3 * d4 * d5 * d6) % d2;
+    int l = pos / (d4 * d5 * d6) % d3;
+    int m = pos / (d5 * d6) % d4;
+    int n = pos / d6 % d5;
+    int o = pos % d6;
+
+    int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
+    l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
+    l_index += Index(k, l2) * l3 * l4 * l5 * l6;
+    l_index += Index(l, l3) * l4 * l5 * l6;
+    l_index += Index(m, l4) * l5 * l6;
+    l_index += Index(n, l5) * l6;
+    l_index += Index(o, l6);
+    int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
+    r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
+    r_index += Index(k, r2) * r3 * r4 * r5 * r6;
+    r_index += Index(l, r3) * r4 * r5 * r6;
+    r_index += Index(m, r4) * r5 * r6;
+    r_index += Index(n, r5) * r6;
+    r_index += Index(o, r6);
+    y[pos] = Func()(x0[l_index], x1[r_index]);
   }
 }
 
-template <typename T, typename S>
-__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1,
-                                  S *output) {
+template <typename T>
+void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
+                    enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
+  int size = 1;
+  for (auto d : y_dims) {
+    size *= d;
+  }
   switch (op) {
-    case BROADCAST_TYPE_GREATER:
-      return NoBroadcastOperator<T, S, GreaterFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_LESS:
-      return NoBroadcastOperator<T, S, LessFunc<T, S>>(nums, input0, input1, output);
-    case BROADCAST_TYPE_MINIMUM:
-      return NoBroadcastOperator<T, S, MinimumFunc<T, S>>(nums, input0, input1, output);
     case BROADCAST_TYPE_MAXIMUM:
-      return NoBroadcastOperator<T, S, MaximumFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, MaximumFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    case BROADCAST_TYPE_MINIMUM:
+      return BroadcastArithKernel<T, MinimumFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_POWER:
-      return NoBroadcastOperator<T, S, PowerFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, PowerFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_REALDIV:
-      return NoBroadcastOperator<T, S, RealDivFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, RealDivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_MUL:
-      return NoBroadcastOperator<T, S, MulFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, MulFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_SUB:
-      return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, SubFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_ADD:
-      return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, AddFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_FLOORDIV:
-      return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, FloorDivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_ABSGRAD:
-      return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, AbsGradFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
     case BROADCAST_TYPE_DIV:
-      return NoBroadcastOperator<T, S, DivFunc<T, S>>(nums, input0, input1, output);
+      return BroadcastArithKernel<T, DivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
+        x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
+        x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
+        y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
+    default:
+      break;
   }
 }
 
-template <typename T, typename S>
-void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
-                 cudaStream_t stream) {
-  NoBroadcastKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, input0, input1, output);
-}
+template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                             const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
+                             float *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                             const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
+                             half *y, cudaStream_t stream);
+template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
+                             const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
+                             int *y, cudaStream_t stream);
 
+// BroadcastTo
 template <typename T>
 __global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, const int o1,
                                   const int o2, const int o3, const T *input_addr, T *output_addr) {
@@ -274,36 +417,6 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
                                                                  output_addr);
 }
 
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
-                        const float *input1, bool *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
-                        const float *input1, float *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
-                        const half *input1, bool *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
-                        const half *input1, half *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
-                        const int *input1, int *output, cudaStream_t stream);
-template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-                        const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
-                        const int *input1, bool *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
-                          bool *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
-                          float *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
-                          bool *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
-                          half *output, cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, int *output,
-                          cudaStream_t stream);
-template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
-                          cudaStream_t stream);
 template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
                           const int &o2, const int &o3, const float *input_addr, float *output_addr,
                           cudaStream_t stream);
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
index c00c12cadb76f7880a6e45953b268003508ec98d..9f0a5ba984142f0087ddd45e0ea0e7c796b80b2c 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh
@@ -36,17 +36,21 @@ enum BroadcastOpType {
   BROADCAST_TYPE_INVALID = 0xffffffff,
 };
 
-template <typename T, typename S>
-void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
-               const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
-               S *output, cudaStream_t stream);
+template <typename T>
+void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
+
+template <typename T>
+void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
 
-template <typename T, typename S>
-void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
-                 cudaStream_t stream);
+template <typename T>
+void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
+                  enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
+
+template <typename T>
+void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
+                    enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
 
 template <typename T>
 void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
                  const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream);
-
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h
index e7471d8cb151b59f227b8d60b4c52cc99e619087..04002b2fc0b3d96fd89061d764302503a2aac9b5 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/addn_gpu_kernel.h
@@ -58,8 +58,8 @@ class AddNGpuFwdKernel : public GpuKernel {
     for (size_t i = 0; i < IntToSize(num_input_); i++) {
       T *input_addr = GetDeviceAddress<T>(inputs, i);
       if (cudnn_data_type_ == CUDNN_DATA_INT32) {
-        NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
-                    reinterpret_cast<cudaStream_t>(stream_ptr));
+        ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
       } else {
         CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
                                                    &(i > 0 ? alpha : beta), input_descriptor_, output_addr),
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
index 2fcbc9ea6d26ac0717a66a25847158a665b16c86..6109f72e8e6d132f1f26d5efeec0486bcae600ab 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.cc
@@ -19,119 +19,119 @@ namespace mindspore {
 namespace kernel {
 // fp32
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Greater,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, float, bool)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   Less,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, float, bool)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   Maximum,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   Minimum,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   Pow,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   RealDiv,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   Mul,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   Sub,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   TensorAdd,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   FloorDiv,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   AbsGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, float)
+MS_REG_GPU_KERNEL_ONE(
   Div,
   KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
-  BroadcastOpGpuKernel, float, float)
+  BroadcastOpGpuKernel, float)
 
 // fp16
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Greater,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, half, bool)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   Less,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, half, bool)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   Maximum,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   Minimum,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   Pow,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   RealDiv,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   Mul,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   Sub,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   TensorAdd,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   FloorDiv,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   AbsGrad,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, half)
+MS_REG_GPU_KERNEL_ONE(
   Div,
   KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
-  BroadcastOpGpuKernel, half, half)
+  BroadcastOpGpuKernel, half)
 
 // int32
-MS_REG_GPU_KERNEL_TWO(
+MS_REG_GPU_KERNEL_ONE(
   Less,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
-  BroadcastOpGpuKernel, int, bool)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
   TensorAdd,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
   Minimum,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
   Maximum,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
   Mul,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
   FloorDiv,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
   AbsGrad,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
-MS_REG_GPU_KERNEL_TWO(
+  BroadcastOpGpuKernel, int)
+MS_REG_GPU_KERNEL_ONE(
   Div,
   KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
-  BroadcastOpGpuKernel, int, int)
+  BroadcastOpGpuKernel, int)
 }  // namespace kernel
 }  // namespace mindspore
diff --git a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
index cde22769ea0ec5251db0a295de0cc521f6e193c6..b739969a3f23c43994001a4160a36ce48bdd05fc 100644
--- a/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
+++ b/mindspore/ccsrc/backend/kernel_compiler/gpu/math/broadcast_gpu_kernel.h
@@ -28,11 +28,16 @@ namespace mindspore {
 namespace kernel {
 constexpr int MAX_DIMS = 7;
-template <typename T, typename S>
+template <typename T>
 class BroadcastOpGpuKernel : public GpuKernel {
  public:
   BroadcastOpGpuKernel()
-      : op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {}
+      : op_type_(BROADCAST_TYPE_INVALID),
+        need_broadcast_(false),
+        is_comp_op_(false),
+        input1_num_(1),
+        input2_num_(1),
+        output_num_(1) {}
   ~BroadcastOpGpuKernel() override = default;
 
   const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
@@ -43,13 +48,23 @@ class BroadcastOpGpuKernel : public GpuKernel {
               const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
     T *lhs = GetDeviceAddress<T>(inputs, 0);
     T *rhs = GetDeviceAddress<T>(inputs, 1);
 
-    S *output = GetDeviceAddress<S>(outputs, 0);
-    if (need_broadcast_) {
-      Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
-                reinterpret_cast<cudaStream_t>(stream_ptr));
+    if (is_comp_op_) {
+      bool *output = GetDeviceAddress<bool>(outputs, 0);
+      if (need_broadcast_) {
+        BroadcastCmp(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
+                     reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        ElewiseCmp(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
     } else {
-      NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+      T *output = GetDeviceAddress<T>(outputs, 0);
+      if (need_broadcast_) {
+        BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
+                       reinterpret_cast<cudaStream_t>(stream_ptr));
+      } else {
+        ElewiseArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
+      }
     }
 
     return true;
@@ -91,26 +106,42 @@ class BroadcastOpGpuKernel : public GpuKernel {
   void InitSizeLists() override {
     input_size_list_.push_back(input1_num_ * sizeof(T));
     input_size_list_.push_back(input2_num_ * sizeof(T));
-    output_size_list_.push_back(output_num_ * sizeof(S));
+
+    auto unit_size = is_comp_op_ ? sizeof(bool) : sizeof(T);
+    output_size_list_.push_back(output_num_ * unit_size);
   }
 
  private:
   void GetOpType(const CNodePtr &kernel_node) {
     std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
-    static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
-      {"Greater", BROADCAST_TYPE_GREATER},   {"Less", BROADCAST_TYPE_LESS},       {"Maximum", BROADCAST_TYPE_MAXIMUM},
-      {"Minimum", BROADCAST_TYPE_MINIMUM},   {"Pow", BROADCAST_TYPE_POWER},       {"RealDiv", BROADCAST_TYPE_REALDIV},
-      {"Mul", BROADCAST_TYPE_MUL},           {"Sub", BROADCAST_TYPE_SUB},         {"TensorAdd", BROADCAST_TYPE_ADD},
-      {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD}, {"Div", BROADCAST_TYPE_DIV},
+
+    static std::map<std::string, BroadcastOpType> kBroadcastCmpTypeMap = {
+      {"Greater", BROADCAST_TYPE_GREATER},
+      {"Less", BROADCAST_TYPE_LESS},
     };
-    auto iter = kBroadcastTypeMap.find(kernel_name);
-    if (iter == kBroadcastTypeMap.end()) {
-      MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
-    } else {
+
+    auto iter = kBroadcastCmpTypeMap.find(kernel_name);
+    if (iter != kBroadcastCmpTypeMap.end()) {
+      op_type_ = iter->second;
+      is_comp_op_ = true;
+      return;
+    }
+
+    static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
+      {"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM},   {"Pow", BROADCAST_TYPE_POWER},
+      {"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL},           {"Sub", BROADCAST_TYPE_SUB},
+      {"TensorAdd", BROADCAST_TYPE_ADD},   {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
+      {"Div", BROADCAST_TYPE_DIV},
+    };
+
+    iter = kBroadcastArithmetricTypeMap.find(kernel_name);
+    if (iter != kBroadcastArithmetricTypeMap.end()) {
       op_type_ = iter->second;
+      is_comp_op_ = false;
+      return;
     }
+
+    MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
   }
 
   bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
@@ -127,6 +158,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
 
   BroadcastOpType op_type_;
   bool need_broadcast_;
+  bool is_comp_op_;
   int input1_num_;
   int input2_num_;
   int output_num_;
@@ -137,7 +169,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
   std::vector<size_t> input_size_list_;
   std::vector<size_t> output_size_list_;
   std::vector<size_t> workspace_size_list_;
-};
+};  // namespace kernel
 }  // namespace kernel
 }  // namespace mindspore
 
diff --git a/tests/st/ops/gpu/test_broadcast_op.py b/tests/st/ops/gpu/test_broadcast_op.py
index 202517729a32b0fa1c21564ab85cb8d7a9f2c2e6..53b3fd14c964d35f73fbacaf9f7dc6906c3348b3 100644
--- a/tests/st/ops/gpu/test_broadcast_op.py
+++ b/tests/st/ops/gpu/test_broadcast_op.py
@@ -160,3 +160,45 @@ def test_broadcast_diff_dims():
     output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
     output_np = x1_np - x2_np
     assert np.allclose(output_ms.asnumpy(), output_np)
+
+
+@pytest.mark.level0
+@pytest.mark.platform_x86_gpu_training
+@pytest.mark.env_onecard
+def test_broadcast_fp16():
+    context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
+
+    x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16)
+    x2_np = np.random.rand(1, 4, 1, 6).astype(np.float16)
+
+    output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.minimum(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.maximum(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np > x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np < x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
+    output_np = np.power(x1_np, x2_np)
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.RealDiv()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np / x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Mul()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np * x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
+
+    output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
+    output_np = x1_np - x2_np
+    assert np.allclose(output_ms.asnumpy(), output_np)
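
Usage sketch (editor's illustration, not part of the patch). The refactor splits the old Broadcast/NoBroadcast entry points by output type: comparison ops (ElewiseCmp/BroadcastCmp) always produce bool, arithmetic ops (ElewiseArith/BroadcastArith) produce T, which is what makes the old "invalid branch" functor specializations unnecessary. The sketch below shows how a caller might drive the new entry points directly; the buffer names, shapes, and stream handling are assumptions for the example, not MindSpore code. Shapes are padded with leading ones to MAX_DIMS (7) before dispatch, and the half overload of ElewiseArith repacks the buffers as half2 whenever the element count is even, halving the thread count.

    // Hypothetical host-side caller; assumes device buffers are already allocated and filled.
    #include <vector>
    #include <cuda_runtime.h>
    #include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"

    void ExampleElewiseAdd(const float *d_x0, const float *d_x1, float *d_y, int nums, cudaStream_t stream) {
      // Same-shape operands: flat element-wise path, one thread per element,
      // launched by the dispatcher as (nums + 255) / 256 blocks of 256 threads.
      ElewiseArith(nums, BROADCAST_TYPE_ADD, d_x0, d_x1, d_y, stream);
    }

    void ExampleBroadcastGreater(const float *d_x0, const float *d_x1, bool *d_y, cudaStream_t stream) {
      // (3, 1, 5, 1) vs (1, 4, 1, 6) -> (3, 4, 5, 6), padded on the left to 7 dims;
      // a dimension of 1 broadcasts (Index() maps it to offset 0).
      std::vector<int> x0_dims = {1, 1, 1, 3, 1, 5, 1};
      std::vector<int> x1_dims = {1, 1, 1, 1, 4, 1, 6};
      std::vector<int> y_dims = {1, 1, 1, 3, 4, 5, 6};
      // d_y must hold 3 * 4 * 5 * 6 = 360 bools.
      BroadcastCmp(x0_dims, x1_dims, y_dims, BROADCAST_TYPE_GREATER, d_x0, d_x1, d_y, stream);
    }

Splitting comparison from arithmetic also lets the kernel class drop its second template parameter: the output element type is now implied by the op (bool for comparisons, T otherwise), which is why the registrations move from MS_REG_GPU_KERNEL_TWO to MS_REG_GPU_KERNEL_ONE.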