提交 6ebe132c 编写于 作者: W wilfChen

broadcast refactor

上级 4499d126
......@@ -15,110 +15,216 @@
*/
#include <vector>
#include <iostream>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
template <typename T, typename S>
// Basic function
template <typename T>
struct GreaterFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; }
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs > rhs ? true : false; }
};
template <typename T, typename S>
template <typename T>
struct LessFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
__device__ __host__ __forceinline__ bool operator()(const T &lhs, const T &rhs) { return lhs < rhs ? true : false; }
};
template <typename T, typename S>
template <typename T>
struct MinimumFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs < rhs ? lhs : rhs; }
};
template <typename T, typename S>
template <typename T>
struct MaximumFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return lhs > rhs ? lhs : rhs; }
};
template <typename T, typename S>
template <typename T>
struct PowerFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return pow(lhs, rhs); }
};
template <>
struct PowerFunc<half, half> {
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
struct PowerFunc<half> {
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
return __float2half(pow(__half2float(lhs), __half2float(rhs)));
}
};
template <typename T, typename S>
template <>
struct PowerFunc<half2> {
__device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
float2 base = __half22float2(lhs);
float2 index = __half22float2(rhs);
base.x = pow(base.x, index.x);
base.y = pow(base.y, index.y);
return __float22half2_rn(base);
}
};
template <typename T>
struct RealDivFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
};
template <typename T, typename S>
template <typename T>
struct DivFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs / rhs); }
};
template <typename T, typename S>
template <typename T>
struct MulFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs * rhs); }
};
template <typename T, typename S>
template <typename T>
struct SubFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs - rhs); }
};
template <typename T, typename S>
template <typename T>
struct AddFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) { return (lhs + rhs); }
};
template <typename T, typename S>
// convert to float to fix accuracy issue
template <typename T>
struct FloorDivFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
return floor(static_cast<float>(lhs) / static_cast<float>(rhs));
__device__ __host__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
return floorf(static_cast<float>(lhs) / static_cast<float>(rhs));
}
};
template <>
struct FloorDivFunc<half, half> {
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
return __float2half(floor(__half2float(lhs) / __half2float(rhs)));
struct FloorDivFunc<half> {
__device__ __host__ __forceinline__ half operator()(const half &lhs, const half &rhs) {
return floorf(__half2float(lhs) / __half2float(rhs));
}
};
template <>
struct FloorDivFunc<half, bool> {
// invalid branch
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
struct FloorDivFunc<half2> {
__device__ __host__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
float2 l = __half22float2(lhs);
float2 r = __half22float2(rhs);
l.x = floorf(l.x / r.x);
l.y = floorf(l.y / r.y);
return __float22half2_rn(l);
}
};
template <typename T, typename S>
template <typename T>
struct AbsGradFunc {
__device__ __forceinline__ S operator()(const T &lhs, const T &rhs) {
__device__ __forceinline__ T operator()(const T &lhs, const T &rhs) {
T zero = 0.0;
return lhs < zero ? -rhs : rhs;
}
};
template <>
struct PowerFunc<half, bool> {
// invalid branch
__device__ __forceinline__ half operator()(const half &lhs, const half &rhs) { return false; }
struct AbsGradFunc<half2> {
__device__ __forceinline__ half2 operator()(const half2 &lhs, const half2 &rhs) {
half2 zero(0.0, 0.0);
return lhs < zero ? -rhs : rhs;
}
};
// Element-wise Comparation
template <typename T, typename Func>
__global__ void ElewiseCmpKernel(const int nums, const T *x0, const T *x1, bool *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
y[pos] = Func()(x0[pos], x1[pos]);
}
}
template <typename T>
void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
switch (op) {
case BROADCAST_TYPE_GREATER:
return ElewiseCmpKernel<T, GreaterFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_LESS:
return ElewiseCmpKernel<T, LessFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
default:
break;
}
}
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, bool *y,
cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, bool *y,
cudaStream_t stream);
template void ElewiseCmp(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, bool *y,
cudaStream_t stream);
// Element-wise ArithMetic
template <typename T, typename Func>
__global__ void ElewiseArithKernel(const int nums, const T *x0, const T *x1, T *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
y[pos] = Func()(x0[pos], x1[pos]);
}
}
template <typename T>
void ElewiseArithKernel(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
switch (op) {
case BROADCAST_TYPE_MINIMUM:
return ElewiseArithKernel<T, MinimumFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_MAXIMUM:
return ElewiseArithKernel<T, MaximumFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_POWER:
return ElewiseArithKernel<T, PowerFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_REALDIV:
return ElewiseArithKernel<T, RealDivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_MUL:
return ElewiseArithKernel<T, MulFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_SUB:
return ElewiseArithKernel<T, SubFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_ADD:
return ElewiseArithKernel<T, AddFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_FLOORDIV:
return ElewiseArithKernel<T, FloorDivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_ABSGRAD:
return ElewiseArithKernel<T, AbsGradFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
case BROADCAST_TYPE_DIV:
return ElewiseArithKernel<T, DivFunc<T>><<<(nums + 255) / 256, 256, 0, stream>>>(nums, x0, x1, y);
default:
break;
}
}
template <typename T>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
return ElewiseArithKernel(nums, op, x0, x1, y, stream);
}
template <>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
cudaStream_t stream) {
if (nums % 2 == 0) {
ElewiseArithKernel<half2>(nums / 2, op, reinterpret_cast<const half2 *>(x0), reinterpret_cast<const half2 *>(x1),
reinterpret_cast<half2 *>(y), stream);
} else {
return ElewiseArithKernel(nums, op, x0, x1, y, stream);
}
}
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const float *x0, const float *x1, float *y,
cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const half *x0, const half *x1, half *y,
cudaStream_t stream);
template void ElewiseArith(const int &nums, enum BroadcastOpType op, const int *x0, const int *x1, int *y,
cudaStream_t stream);
// Broadcast comparation
__device__ __forceinline__ int Index(const int &index, const int &dim) { return dim == 1 ? 0 : index; }
template <typename T, typename S, typename Func>
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
const int &l4, const int &l5, const int &l6, const int &r0,
const int &r1, const int &r2, const int &r3, const int &r4,
const int &r5, const int &r6, const int &d0, const int &d1,
const int &d2, const int &d3, const int &d4, const int &d5,
const int &d6, const T *input0, const T *input1, S *output) {
template <typename T, typename Func>
__global__ void BroadcastCmpKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
const int d4, const int d5, const int d6, const T *x0, const T *x1, bool *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
pos += blockDim.x * gridDim.x) {
int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
......@@ -143,115 +249,152 @@ __device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1,
r_index += Index(m, r4) * r5 * r6;
r_index += Index(n, r5) * r6;
r_index += Index(o, r6);
output[pos] = Func()(input0[l_index], input1[r_index]);
y[pos] = Func()(x0[l_index], x1[r_index]);
}
}
template <typename T, typename S>
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
const T *input1, S *output) {
template <typename T>
void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream) {
int size = 1;
for (auto d : y_dims) {
size *= d;
}
switch (op) {
case BROADCAST_TYPE_GREATER:
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
return BroadcastCmpKernel<T, GreaterFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_LESS:
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MINIMUM:
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MAXIMUM:
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_POWER:
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_REALDIV:
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MUL:
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_SUB:
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_ADD:
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_FLOORDIV:
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_ABSGRAD:
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_DIV:
return BroadcastOperator<T, S, DivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
return BroadcastCmpKernel<T, LessFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
default:
break;
}
}
template <typename T, typename S>
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
S *output, cudaStream_t stream) {
int size = 1;
for (auto d : output_shape) {
size *= d;
}
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(
lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3], lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4], rhs_shape[5], rhs_shape[6], output_shape[0],
output_shape[1], output_shape[2], output_shape[3], output_shape[4], output_shape[5], output_shape[6], op, input0,
input1, output);
}
template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
bool *y, cudaStream_t stream);
template void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
bool *y, cudaStream_t stream);
template <typename T, typename S, typename Func>
__device__ __forceinline__ void NoBroadcastOperator(const int &nums, const T *input0, const T *input1, S *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < nums; pos += blockDim.x * gridDim.x) {
output[pos] = Func()(input0[pos], input1[pos]);
// Broadcast Arithmetic
template <typename T, typename Func>
__global__ void BroadcastArithKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
const int d4, const int d5, const int d6, const T *x0, const T *x1, T *y) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
pos += blockDim.x * gridDim.x) {
int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
int k = pos / (d3 * d4 * d5 * d6) % d2;
int l = pos / (d4 * d5 * d6) % d3;
int m = pos / (d5 * d6) % d4;
int n = pos / d6 % d5;
int o = pos % d6;
int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
l_index += Index(l, l3) * l4 * l5 * l6;
l_index += Index(m, l4) * l5 * l6;
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
r_index += Index(l, r3) * r4 * r5 * r6;
r_index += Index(m, r4) * r5 * r6;
r_index += Index(n, r5) * r6;
r_index += Index(o, r6);
y[pos] = Func()(x0[l_index], x1[r_index]);
}
}
template <typename T, typename S>
__global__ void NoBroadcastKernel(const int nums, enum BroadcastOpType op, const T *input0, const T *input1,
S *output) {
template <typename T>
void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream) {
int size = 1;
for (auto d : y_dims) {
size *= d;
}
switch (op) {
case BROADCAST_TYPE_GREATER:
return NoBroadcastOperator<T, S, GreaterFunc<T, bool>>(nums, input0, input1, output);
case BROADCAST_TYPE_LESS:
return NoBroadcastOperator<T, S, LessFunc<T, bool>>(nums, input0, input1, output);
case BROADCAST_TYPE_MINIMUM:
return NoBroadcastOperator<T, S, MinimumFunc<T, S>>(nums, input0, input1, output);
case BROADCAST_TYPE_MAXIMUM:
return NoBroadcastOperator<T, S, MaximumFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, MaximumFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_MINIMUM:
return BroadcastArithKernel<T, MinimumFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_POWER:
return NoBroadcastOperator<T, S, PowerFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, PowerFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_REALDIV:
return NoBroadcastOperator<T, S, RealDivFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, RealDivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_MUL:
return NoBroadcastOperator<T, S, MulFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, MulFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_SUB:
return NoBroadcastOperator<T, S, SubFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, SubFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_ADD:
return NoBroadcastOperator<T, S, AddFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, AddFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_FLOORDIV:
return NoBroadcastOperator<T, S, FloorDivFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, FloorDivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_ABSGRAD:
return NoBroadcastOperator<T, S, AbsGradFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, AbsGradFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
case BROADCAST_TYPE_DIV:
return NoBroadcastOperator<T, S, DivFunc<T, S>>(nums, input0, input1, output);
return BroadcastArithKernel<T, DivFunc<T>><<<(size + 255) / 256, 256, 0, stream>>>(
x0_dims[0], x0_dims[1], x0_dims[2], x0_dims[3], x0_dims[4], x0_dims[5], x0_dims[6], x1_dims[0], x1_dims[1],
x1_dims[2], x1_dims[3], x1_dims[4], x1_dims[5], x1_dims[6], y_dims[0], y_dims[1], y_dims[2], y_dims[3],
y_dims[4], y_dims[5], y_dims[6], x0, x1, y);
default:
break;
}
}
template <typename T, typename S>
void NoBroadcast(const int &nums, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
cudaStream_t stream) {
NoBroadcastKernel<<<GET_BLOCKS(nums), GET_THREADS, 0, stream>>>(nums, op, input0, input1, output);
}
template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const float *x0, const float *x1,
float *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const half *x0, const half *x1,
half *y, cudaStream_t stream);
template void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims,
const std::vector<int> &y_dims, enum BroadcastOpType op, const int *x0, const int *x1,
int *y, cudaStream_t stream);
// BroadcastTo
template <typename T>
__global__ void BroadcastToKernel(const int i0, const int i1, const int i2, const int i3, const int o0, const int o1,
const int o2, const int o3, const T *input_addr, T *output_addr) {
......@@ -274,36 +417,6 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
output_addr);
}
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
const float *input1, bool *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
const float *input1, float *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
const half *input1, bool *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
const half *input1, half *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
const int *input1, int *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
const int *input1, bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
float *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const half *input0, const half *input1,
half *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, int *output,
cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
cudaStream_t stream);
template void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
const int &o2, const int &o3, const float *input_addr, float *output_addr,
cudaStream_t stream);
......
......@@ -36,17 +36,21 @@ enum BroadcastOpType {
BROADCAST_TYPE_INVALID = 0xffffffff,
};
template <typename T, typename S>
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
S *output, cudaStream_t stream);
template <typename T>
void ElewiseCmp(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
template <typename T>
void ElewiseArith(const int &nums, enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
template <typename T, typename S>
void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
cudaStream_t stream);
template <typename T>
void BroadcastCmp(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, bool *y, cudaStream_t stream);
template <typename T>
void BroadcastArith(const std::vector<int> &x0_dims, const std::vector<int> &x1_dims, const std::vector<int> &y_dims,
enum BroadcastOpType op, const T *x0, const T *x1, T *y, cudaStream_t stream);
template <typename T>
void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, const int &o0, const int &o1,
const int &o2, const int &o3, const T *input_addr, T *output_addr, cudaStream_t stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
......@@ -58,7 +58,7 @@ class AddNGpuFwdKernel : public GpuKernel {
for (size_t i = 0; i < IntToSize(num_input_); i++) {
T *input_addr = GetDeviceAddress<T>(inputs, i);
if (cudnn_data_type_ == CUDNN_DATA_INT32) {
NoBroadcast(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
ElewiseArith(outputs[0]->size / sizeof(T), BROADCAST_TYPE_ADD, input_addr, output_addr, output_addr,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnAddTensor(cudnn_handle_, &alpha, input_descriptor_, input_addr,
......
......@@ -19,119 +19,119 @@
namespace mindspore {
namespace kernel {
// fp32
MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_ONE(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
BroadcastOpGpuKernel, float, float)
BroadcastOpGpuKernel, float)
// fp16
MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_ONE(
Greater,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, half, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, half, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Maximum,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Minimum,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Pow, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
RealDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Sub, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
TensorAdd,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
FloorDiv,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
AbsGrad,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, half)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
BroadcastOpGpuKernel, half, half)
BroadcastOpGpuKernel, half)
// int32
MS_REG_GPU_KERNEL_TWO(
MS_REG_GPU_KERNEL_ONE(
Less, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeBool),
BroadcastOpGpuKernel, int, bool)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
TensorAdd, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Minimum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Maximum, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Mul, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
FloorDiv, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
AbsGrad, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
MS_REG_GPU_KERNEL_TWO(
BroadcastOpGpuKernel, int)
MS_REG_GPU_KERNEL_ONE(
Div, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
BroadcastOpGpuKernel, int, int)
BroadcastOpGpuKernel, int)
} // namespace kernel
} // namespace mindspore
......@@ -28,11 +28,16 @@
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
template <typename T, typename S>
template <typename T>
class BroadcastOpGpuKernel : public GpuKernel {
public:
BroadcastOpGpuKernel()
: op_type_(BROADCAST_TYPE_INVALID), need_broadcast_(false), input1_num_(1), input2_num_(1), output_num_(1) {}
: op_type_(BROADCAST_TYPE_INVALID),
need_broadcast_(false),
is_comp_op_(false),
input1_num_(1),
input2_num_(1),
output_num_(1) {}
~BroadcastOpGpuKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
......@@ -43,13 +48,23 @@ class BroadcastOpGpuKernel : public GpuKernel {
const std::vector<AddressPtr> &outputs, void *stream_ptr) override {
T *lhs = GetDeviceAddress<T>(inputs, 0);
T *rhs = GetDeviceAddress<T>(inputs, 1);
S *output = GetDeviceAddress<S>(outputs, 0);
if (is_comp_op_) {
bool *output = GetDeviceAddress<bool>(outputs, 0);
if (need_broadcast_) {
Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
BroadcastCmp(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
ElewiseCmp(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
}
} else {
T *output = GetDeviceAddress<T>(outputs, 0);
if (need_broadcast_) {
BroadcastArith(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
ElewiseArith(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
}
}
return true;
......@@ -91,26 +106,42 @@ class BroadcastOpGpuKernel : public GpuKernel {
void InitSizeLists() override {
input_size_list_.push_back(input1_num_ * sizeof(T));
input_size_list_.push_back(input2_num_ * sizeof(T));
output_size_list_.push_back(output_num_ * sizeof(S));
auto unit_size = is_comp_op_ ? sizeof(bool) : sizeof(T);
output_size_list_.push_back(output_num_ * unit_size);
}
private:
void GetOpType(const CNodePtr &kernel_node) {
std::string kernel_name = AnfAlgo::GetCNodeName(kernel_node);
static std::map<std::string, BroadcastOpType> kBroadcastTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER}, {"Less", BROADCAST_TYPE_LESS}, {"Maximum", BROADCAST_TYPE_MAXIMUM},
{"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER}, {"RealDiv", BROADCAST_TYPE_REALDIV},
{"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB}, {"TensorAdd", BROADCAST_TYPE_ADD},
{"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD}, {"Div", BROADCAST_TYPE_DIV},
static std::map<std::string, BroadcastOpType> kBroadcastCmpTypeMap = {
{"Greater", BROADCAST_TYPE_GREATER},
{"Less", BROADCAST_TYPE_LESS},
};
auto iter = kBroadcastTypeMap.find(kernel_name);
if (iter == kBroadcastTypeMap.end()) {
MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
} else {
auto iter = kBroadcastCmpTypeMap.find(kernel_name);
if (iter != kBroadcastCmpTypeMap.end()) {
op_type_ = iter->second;
is_comp_op_ = true;
return;
}
static std::map<std::string, BroadcastOpType> kBroadcastArithmetricTypeMap = {
{"Maximum", BROADCAST_TYPE_MAXIMUM}, {"Minimum", BROADCAST_TYPE_MINIMUM}, {"Pow", BROADCAST_TYPE_POWER},
{"RealDiv", BROADCAST_TYPE_REALDIV}, {"Mul", BROADCAST_TYPE_MUL}, {"Sub", BROADCAST_TYPE_SUB},
{"TensorAdd", BROADCAST_TYPE_ADD}, {"FloorDiv", BROADCAST_TYPE_FLOORDIV}, {"AbsGrad", BROADCAST_TYPE_ABSGRAD},
{"Div", BROADCAST_TYPE_DIV},
};
iter = kBroadcastArithmetricTypeMap.find(kernel_name);
if (iter != kBroadcastArithmetricTypeMap.end()) {
op_type_ = iter->second;
is_comp_op_ = false;
return;
}
MS_LOG(EXCEPTION) << "operation " << kernel_name << " is not supported.";
}
bool IsBroadcast(const std::vector<size_t> &lhs, const std::vector<size_t> &rhs) {
......@@ -127,6 +158,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
BroadcastOpType op_type_;
bool need_broadcast_;
bool is_comp_op_;
int input1_num_;
int input2_num_;
int output_num_;
......@@ -137,7 +169,7 @@ class BroadcastOpGpuKernel : public GpuKernel {
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
};
}; // namespace kernel
} // namespace kernel
} // namespace mindspore
......
......@@ -160,3 +160,45 @@ def test_broadcast_diff_dims():
output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_broadcast_fp16():
context.set_context(mode=context.GRAPH_MODE, device_target='GPU')
x1_np = np.random.rand(3, 1, 5, 1).astype(np.float16)
x2_np = np.random.rand(1, 4, 1, 6).astype(np.float16)
output_ms = P.Minimum()(Tensor(x1_np), Tensor(x2_np))
output_np = np.minimum(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Maximum()(Tensor(x1_np), Tensor(x2_np))
output_np = np.maximum(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Greater()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np > x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Less()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np < x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Pow()(Tensor(x1_np), Tensor(x2_np))
output_np = np.power(x1_np, x2_np)
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.RealDiv()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np / x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Mul()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np * x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
output_ms = P.Sub()(Tensor(x1_np), Tensor(x2_np))
output_np = x1_np - x2_np
assert np.allclose(output_ms.asnumpy(), output_np)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册