Commit fb405ee6 authored by linqingke

broadcast, slice, scatter_nd ops optimization.

Parent 645f11fa
......@@ -182,30 +182,59 @@ class ArrayReduceGpuKernel : public GpuKernel {
void InferInAndOutDesc(const std::vector<size_t> &input_shape, const std::vector<size_t> &output_shape) {
std::vector<int> inputA;
std::vector<size_t> outputC_shape = output_shape;
ShapeNdTo4d(input_shape, &inputA);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_, inputA[0],
inputA[1], inputA[2], inputA[3]),
"cudnnSetTensor4dDescriptor failed");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &inputA);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(inputA_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
inputA[0], inputA[1], inputA[2], inputA[3]),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, inputA_descriptor_, data_type_);
for (auto dim : input_shape) {
inputA.emplace_back(SizeToInt(dim));
}
}
if (axis_[0] == -1) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
"cudnnSetTensor4dDescriptor failed");
if (inputA[0] == 1 && inputA[1] == 1 && inputA[2] == 1 && inputA[3] == 1) {
all_match_ = true;
outputC_shape.resize(input_shape.size(), 1);
if (outputC_shape.size() <= split_dim) {
CHECK_CUDNN_RET_WITH_EXCEPT(
cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_, 1, 1, 1, 1),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
}
for (auto dim : inputA) {
if (dim != 1) {
return;
}
}
all_match_ = true;
return;
}
std::vector<int> outputC;
if (!keep_dims_) {
for (auto i : axis_) {
(void)(outputC_shape.insert(outputC_shape.begin() + i, 1));
}
}
std::vector<int> outputC;
ShapeNdTo4d(outputC_shape, &outputC);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
outputC[0], outputC[1], outputC[2], outputC[3]),
"cudnnSetTensor4dDescriptor failed");
if (outputC_shape.size() <= split_dim) {
ShapeNdTo4d(outputC_shape, &outputC);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(outputC_descriptor_, CUDNN_TENSOR_NCHW, data_type_,
outputC[0], outputC[1], outputC[2], outputC[3]),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(outputC_shape, outputC_descriptor_, data_type_);
for (auto dim : outputC_shape) {
outputC.emplace_back(SizeToInt(dim));
}
}
if (inputA == outputC) {
all_match_ = true;
}
......
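The pattern above (use the plain 4-D NCHW descriptor for rank <= 4, fall back to an Nd descriptor otherwise) recurs in the activation kernels later in this commit. A minimal standalone sketch of that dispatch, with PadTo4d standing in for ShapeNdTo4d (whose exact padding convention is not shown in this hunk and is assumed here to fill missing leading dimensions with 1):

// Hedged sketch, not library code: choose between the Tensor4d and TensorNd paths.
#include <cstddef>
#include <vector>

std::vector<int> PadTo4d(const std::vector<size_t> &shape) {
  std::vector<int> out(4, 1);
  const size_t offset = 4 - shape.size();           // assumes shape.size() <= 4
  for (size_t i = 0; i < shape.size(); ++i) {
    out[offset + i] = static_cast<int>(shape[i]);   // missing leading dims stay 1
  }
  return out;
}

bool UseNdDescriptor(const std::vector<size_t> &shape) {
  const size_t split_dim = 4;   // shapes up to 4D keep the cudnnSetTensor4dDescriptor path
  return shape.size() > split_dim;
}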
......@@ -69,6 +69,10 @@ class ScatterNdGpuFwdKernel : public GpuKernel {
memcpy_flag_ = true;
}
CHECK_CUDA_RET_WITH_EXCEPT(
cudaMemsetAsync(output, static_cast<T>(0.0), output_size_, reinterpret_cast<cudaStream_t>(stream_ptr)),
"cudaMemSet failed in ScatterNdGpuFwdKernel::Launch.");
const size_t input_size = input_size_ / sizeof(T);
const size_t output_size = output_size_ / sizeof(T);
......
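The added cudaMemsetAsync matters because ScatterNd only writes the positions named by the indices; every other output element must already be zero. A hedged CPU restatement of that contract (hypothetical helper, not the kernel itself):

// Hedged sketch: output starts zeroed, only the addressed positions receive values.
#include <algorithm>
#include <vector>

void ScatterNdReference(const std::vector<int> &indices,    // flattened output offsets
                        const std::vector<float> &updates,  // one value per index here
                        std::vector<float> *output) {
  std::fill(output->begin(), output->end(), 0.0f);  // counterpart of the cudaMemsetAsync above
  for (size_t i = 0; i < indices.size(); ++i) {
    (*output)[indices[i]] += updates[i];  // duplicates accumulate here; the GPU kernel's
                                          // duplicate behaviour is not shown in this hunk
  }
}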
......@@ -14,6 +14,7 @@
* limitations under the License.
*/
#include <vector>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"
#include "runtime/device/gpu/cuda_common.h"
......@@ -107,69 +108,97 @@ __device__ __forceinline__ int Index(const int &index, const int &dim) { return
template <typename T, typename S, typename Func>
__device__ __forceinline__ void BroadcastOperator(const int &l0, const int &l1, const int &l2, const int &l3,
const int &r0, const int &r1, const int &r2, const int &r3,
const int &d0, const int &d1, const int &d2, const int &d3,
const T *input0, const T *input1, S *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3; pos += blockDim.x * gridDim.x) {
int i = pos / (d1 * d2 * d3) % d0;
int j = pos / (d2 * d3) % d1;
int k = pos / d3 % d2;
int l = pos % d3;
const int &l4, const int &l5, const int &l6, const int &r0,
const int &r1, const int &r2, const int &r3, const int &r4,
const int &r5, const int &r6, const int &d0, const int &d1,
const int &d2, const int &d3, const int &d4, const int &d5,
const int &d6, const T *input0, const T *input1, S *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < d0 * d1 * d2 * d3 * d4 * d5 * d6;
pos += blockDim.x * gridDim.x) {
int i = pos / (d1 * d2 * d3 * d4 * d5 * d6) % d0;
int j = pos / (d2 * d3 * d4 * d5 * d6) % d1;
int k = pos / (d3 * d4 * d5 * d6) % d2;
int l = pos / (d4 * d5 * d6) % d3;
int m = pos / (d5 * d6) % d4;
int n = pos / d6 % d5;
int o = pos % d6;
int l_index = Index(i, l0) * l1 * l2 * l3 + Index(j, l1) * l2 * l3 + Index(k, l2) * l3 + Index(l, l3);
int r_index = Index(i, r0) * r1 * r2 * r3 + Index(j, r1) * r2 * r3 + Index(k, r2) * r3 + Index(l, r3);
int l_index = Index(i, l0) * l1 * l2 * l3 * l4 * l5 * l6;
l_index += Index(j, l1) * l2 * l3 * l4 * l5 * l6;
l_index += Index(k, l2) * l3 * l4 * l5 * l6;
l_index += Index(l, l3) * l4 * l5 * l6;
l_index += Index(m, l4) * l5 * l6;
l_index += Index(n, l5) * l6;
l_index += Index(o, l6);
int r_index = Index(i, r0) * r1 * r2 * r3 * r4 * r5 * r6;
r_index += Index(j, r1) * r2 * r3 * r4 * r5 * r6;
r_index += Index(k, r2) * r3 * r4 * r5 * r6;
r_index += Index(l, r3) * r4 * r5 * r6;
r_index += Index(m, r4) * r5 * r6;
r_index += Index(n, r5) * r6;
r_index += Index(o, r6);
output[pos] = Func()(input0[l_index], input1[r_index]);
}
}
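A hedged host-side restatement of the 7-D index mapping above: each output coordinate is passed through Index (which returns 0 for broadcast dimensions of extent 1) and re-linearised with that operand's extents, which is exactly what the unrolled stride sums compute.

// Hedged CPU check of the kernel's index arithmetic for one output position.
// shape[] holds an operand's 7 extents; coord[] the output coordinates (i..o).
int BroadcastOffset(const int coord[7], const int shape[7]) {
  int offset = 0;
  for (int d = 0; d < 7; ++d) {
    int idx = (shape[d] == 1) ? 0 : coord[d];  // mirrors Index(index, dim)
    offset = offset * shape[d] + idx;          // Horner form of the stride sums above
  }
  return offset;
}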
template <typename T, typename S>
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int r0, const int r1,
const int r2, const int r3, const int d0, const int d1, const int d2, const int d3,
enum BroadcastOpType op, const T *input0, const T *input1, S *output) {
__global__ void BroadcastKernel(const int l0, const int l1, const int l2, const int l3, const int l4, const int l5,
const int l6, const int r0, const int r1, const int r2, const int r3, const int r4,
const int r5, const int r6, const int d0, const int d1, const int d2, const int d3,
const int d4, const int d5, const int d6, enum BroadcastOpType op, const T *input0,
const T *input1, S *output) {
switch (op) {
case BROADCAST_TYPE_GREATER:
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, GreaterFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_LESS:
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, LessFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MINIMUM:
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, MinimumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MAXIMUM:
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, MaximumFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_POWER:
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, PowerFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_REALDIV:
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, RealDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_MUL:
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, MulFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_SUB:
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, SubFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_ADD:
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, AddFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1, d2,
d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_FLOORDIV:
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, FloorDivFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
case BROADCAST_TYPE_ABSGRAD:
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, input0, input1,
output);
return BroadcastOperator<T, S, AbsGradFunc<T, S>>(l0, l1, l2, l3, l4, l5, l6, r0, r1, r2, r3, r4, r5, r6, d0, d1,
d2, d3, d4, d5, d6, input0, input1, output);
}
}
template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
const T *input0, const T *input1, S *output, cudaStream_t stream) {
int size = d0 * d1 * d2 * d3;
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(l0, l1, l2, l3, r0, r1, r2, r3, d0, d1, d2, d3, op,
input0, input1, output);
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
S *output, cudaStream_t stream) {
int size = 1;
for (auto d : output_shape) {
size *= d;
}
BroadcastKernel<<<GET_BLOCKS(size), GET_THREADS, 0, stream>>>(lhs_shape[0], lhs_shape[1], lhs_shape[2], lhs_shape[3],
lhs_shape[4], lhs_shape[5], lhs_shape[6], rhs_shape[0],
rhs_shape[1], rhs_shape[2], rhs_shape[3], rhs_shape[4],
rhs_shape[5], rhs_shape[6], output_shape[0],
output_shape[1], output_shape[2], output_shape[3],
output_shape[4], output_shape[5], output_shape[6],
op, input0, input1, output);
}
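A hedged usage sketch of the new vector-based entry point (not from the commit): the wrapper indexes elements 0 through 6 of each shape, so callers are expected to pass exactly seven extents per operand, padding unused dimensions with 1, as BroadcastOpGpuKernel does further down with resize(MAX_DIMS, 1). Error checking is omitted for brevity.

// Hypothetical call site: broadcast [2,1,4] against [2,3,4] in float.
#include <vector>
#include <cuda_runtime.h>
#include "backend/kernel_compiler/gpu/cuda_impl/broadcast_impl.cuh"

void LaunchAddBroadcast(const float *d_lhs, const float *d_rhs, float *d_out, cudaStream_t stream) {
  std::vector<int> lhs_shape = {2, 1, 4, 1, 1, 1, 1};  // [2,1,4] padded with trailing 1s
  std::vector<int> rhs_shape = {2, 3, 4, 1, 1, 1, 1};  // [2,3,4]
  std::vector<int> out_shape = {2, 3, 4, 1, 1, 1, 1};  // broadcasted result shape
  Broadcast(lhs_shape, rhs_shape, out_shape, BROADCAST_TYPE_ADD, d_lhs, d_rhs, d_out, stream);
}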
template <typename T, typename S, typename Func>
......@@ -236,30 +265,24 @@ void BroadcastTo(const int &i0, const int &i1, const int &i2, const int &i3, con
output_addr);
}
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const float *input0, const float *input1, bool *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const float *input0, const float *input1, float *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const half *input0, const half *input1, bool *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const half *input0, const half *input1, half *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const int *input0, const int *input1, int *output,
cudaStream_t stream);
template void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1,
const int &r2, const int &r3, const int &d0, const int &d1, const int &d2, const int &d3,
enum BroadcastOpType op, const int *input0, const int *input1, bool *output,
cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
const float *input1, bool *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const float *input0,
const float *input1, float *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
const half *input1, bool *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const half *input0,
const half *input1, half *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
const int *input1, int *output, cudaStream_t stream);
template void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const int *input0,
const int *input1, bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
bool *output, cudaStream_t stream);
template void NoBroadcast(const int &nums, enum BroadcastOpType op, const float *input0, const float *input1,
......
......@@ -17,6 +17,7 @@
#ifndef MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
#define MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_BROADCAST_H_
#include <vector>
#include "runtime/device/gpu/cuda_common.h"
enum BroadcastOpType {
......@@ -35,9 +36,9 @@ enum BroadcastOpType {
};
template <typename T, typename S>
void Broadcast(const int &l0, const int &l1, const int &l2, const int &l3, const int &r0, const int &r1, const int &r2,
const int &r3, const int &d0, const int &d1, const int &d2, const int &d3, enum BroadcastOpType op,
const T *input0, const T *input1, S *output, cudaStream_t stream);
void Broadcast(const std::vector<int> &lhs_shape, const std::vector<int> &rhs_shape,
const std::vector<int> &output_shape, enum BroadcastOpType op, const T *input0, const T *input1,
S *output, cudaStream_t stream);
template <typename T, typename S>
void NoBroadcast(const int &size, enum BroadcastOpType op, const T *input0, const T *input1, S *output,
......
......@@ -25,10 +25,10 @@ __global__ void CheckValidKernel(const size_t size, const T *box, const T *img_m
const size_t right_y = i * 4 + 3;
S valid_flag = false;
valid_flag |= !(box[left_x] >= 0.f);
valid_flag |= !(box[left_y] >= 0.f);
valid_flag |= !(img_metas[0] * img_metas[2] - 1.f >= box[right_x]);
valid_flag |= !(img_metas[1] * img_metas[2] - 1.f >= box[right_y]);
valid_flag |= !(box[left_x] >= static_cast<T>(0.0));
valid_flag |= !(box[left_y] >= static_cast<T>(0.0));
valid_flag |= !(img_metas[1] * img_metas[2] - static_cast<T>(1.0) >= box[right_x]);
valid_flag |= !(img_metas[0] * img_metas[2] - static_cast<T>(1.0) >= box[right_y]);
valid[i] = !valid_flag;
}
......@@ -43,3 +43,5 @@ void CheckValid(const size_t &size, const T *box, const T *img_metas, S *valid,
template void CheckValid(const size_t &size, const float *box, const float *img_metas, bool *valid,
cudaStream_t cuda_stream);
template void CheckValid(const size_t &size, const half *box, const half *img_metas, bool *valid,
cudaStream_t cuda_stream);
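The swapped img_metas indices above fix which image extent bounds which coordinate. A hedged scalar restatement of the per-box test, assuming img_metas = {height, width, ratio}, which is what the corrected comparison implies:

// Hedged CPU sketch: a box (x1, y1, x2, y2) is valid when it lies inside
// [0, width * ratio - 1] x [0, height * ratio - 1].
bool BoxIsValid(float x1, float y1, float x2, float y2,
                float height, float width, float ratio) {
  bool invalid = false;
  invalid |= !(x1 >= 0.0f);
  invalid |= !(y1 >= 0.0f);
  invalid |= !(width * ratio - 1.0f >= x2);   // img_metas[1] * img_metas[2] bounds x
  invalid |= !(height * ratio - 1.0f >= y2);  // img_metas[0] * img_metas[2] bounds y
  return !invalid;
}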
......@@ -16,27 +16,26 @@
#include "backend/kernel_compiler/gpu/cuda_impl/iou_impl.cuh"
template <typename T>
__device__ T CoordinateMax(const T a, const T b) {
__device__ float CoordinateMax(const float a, const float b) {
return (a > b ? a : b);
}
template <typename T>
__device__ T CoordinateMin(const T a, const T b) {
__device__ float CoordinateMin(const float a, const float b) {
return (a < b ? a : b);
}
template <typename T>
__global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *iou_results, const size_t mode,
const size_t input_len_0) {
T location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
T overlaps_coordinate[IOU_DIMENSION];
const T epsilon = 1e-10;
float location_coordinate[IOU_LOCATION_NUM][IOU_DIMENSION];
float overlaps_coordinate[IOU_DIMENSION];
const float epsilon = 1e-10;
const float offset = 1.0;
for (size_t i = blockIdx.x * blockDim.x + threadIdx.x; i < size; i += gridDim.x * blockDim.x) {
for (size_t j = 0; j < IOU_DIMENSION; j++) {
location_coordinate[0][j] = box1[(i % input_len_0) * IOU_DIMENSION + j];
location_coordinate[1][j] = box2[(i / input_len_0) * IOU_DIMENSION + j];
location_coordinate[0][j] = static_cast<float>(box1[(i % input_len_0) * IOU_DIMENSION + j]);
location_coordinate[1][j] = static_cast<float>(box2[(i / input_len_0) * IOU_DIMENSION + j]);
}
overlaps_coordinate[0] = CoordinateMax(location_coordinate[0][0], location_coordinate[1][0]);
......@@ -44,18 +43,18 @@ __global__ void IOUKernel(const size_t size, const T *box1, const T *box2, T *io
overlaps_coordinate[2] = CoordinateMin(location_coordinate[0][2], location_coordinate[1][2]);
overlaps_coordinate[3] = CoordinateMin(location_coordinate[0][3], location_coordinate[1][3]);
T overlaps_w = CoordinateMax(0.f, overlaps_coordinate[2] - overlaps_coordinate[0] + 1);
T overlaps_h = CoordinateMax(0.f, overlaps_coordinate[3] - overlaps_coordinate[1] + 1);
T overlaps = overlaps_w * overlaps_h;
float overlaps_w = CoordinateMax(0.0, overlaps_coordinate[2] - overlaps_coordinate[0] + offset);
float overlaps_h = CoordinateMax(0.0, overlaps_coordinate[3] - overlaps_coordinate[1] + offset);
float overlaps = overlaps_w * overlaps_h;
T area1 = (location_coordinate[0][2] - location_coordinate[0][0] + 1) * (location_coordinate[0][3] -
location_coordinate[0][1] + 1);
T area2 = (location_coordinate[1][2] - location_coordinate[1][0] + 1) * (location_coordinate[1][3] -
location_coordinate[1][1] + 1);
float area1 = (location_coordinate[0][2] - location_coordinate[0][0] + offset) * (location_coordinate[0][3] -
location_coordinate[0][1] + offset);
float area2 = (location_coordinate[1][2] - location_coordinate[1][0] + offset) * (location_coordinate[1][3] -
location_coordinate[1][1] + offset);
if (mode == 0) {
iou_results[i] = overlaps / (area1 + area2 - overlaps + epsilon);
iou_results[i] = static_cast<T>(overlaps / (area1 + area2 - overlaps + epsilon));
} else {
iou_results[i] = overlaps / (area2 + epsilon);
iou_results[i] = static_cast<T>(overlaps / (area2 + epsilon));
}
}
......@@ -70,3 +69,5 @@ void IOU(const size_t &size, const T *box1, const T *box2, T *iou_results, const
template void IOU(const size_t &size, const float *box1, const float *box2, float *iou_results, const size_t &mode,
const size_t &input_len_0, cudaStream_t cuda_stream);
template void IOU(const size_t &size, const half *box1, const half *box2, half *iou_results, const size_t &mode,
const size_t &input_len_0, cudaStream_t cuda_stream);
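For reference, a hedged scalar version of the per-pair computation above, done in float as the kernel now does internally (mode 0 is intersection-over-union, otherwise intersection over the second box's area):

// Hedged sketch of the IOU math. Boxes are (x1, y1, x2, y2); the +1 offset and
// 1e-10 epsilon match the device code above.
float ComputeIou(const float box1[4], const float box2[4], size_t mode) {
  const float epsilon = 1e-10f;
  const float offset = 1.0f;
  float ix1 = box1[0] > box2[0] ? box1[0] : box2[0];
  float iy1 = box1[1] > box2[1] ? box1[1] : box2[1];
  float ix2 = box1[2] < box2[2] ? box1[2] : box2[2];
  float iy2 = box1[3] < box2[3] ? box1[3] : box2[3];
  float ow = ix2 - ix1 + offset;  if (ow < 0.0f) ow = 0.0f;
  float oh = iy2 - iy1 + offset;  if (oh < 0.0f) oh = 0.0f;
  float overlaps = ow * oh;
  float area1 = (box1[2] - box1[0] + offset) * (box1[3] - box1[1] + offset);
  float area2 = (box2[2] - box2[0] + offset) * (box2[3] - box2[1] + offset);
  return mode == 0 ? overlaps / (area1 + area2 - overlaps + epsilon)
                   : overlaps / (area2 + epsilon);
}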
......@@ -84,6 +84,40 @@ class GpuKernel : public KernelMod {
}
}
// set the tensor descriptor for cudnn/cublas
void CudnnSetTensorNdDescriptor(const std::vector<size_t> &shape, cudnnTensorDescriptor_t descriptor,
cudnnDataType_t data_type) {
if (shape.size() < 3) {
MS_EXCEPTION(ValueError) << "cudnnSetTensorNdDescriptor does not support " << shape.size() << "D.";
}
const int nbDims = shape.size();
int *dim = new (std::nothrow) int[nbDims];
if (dim == nullptr) {
MS_LOG(EXCEPTION) << "malloc dim failed.";
}
int *stride = new (std::nothrow) int[nbDims];
if (stride == nullptr) {
MS_LOG(EXCEPTION) << "malloc stride failed.";
}
for (int i = 0; i < nbDims; i++) {
dim[i] = SizeToInt(shape[i]);
stride[i] = 1;
}
for (int i = nbDims - 2; i >= 0; i--) {
stride[i] = stride[i + 1] * SizeToInt(shape[i + 1]);
}
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensorNdDescriptor(descriptor, data_type, nbDims, dim, stride),
"cudnnSetTensorNdDescriptor failed");
delete[] dim;
dim = nullptr;
delete[] stride;
stride = nullptr;
}
// choose the suitable datatype for cudnn/cublas
inline cudnnDataType_t GetCudnnDataType(const std::string &Type) {
auto type = kCudnnDtypeMap.find(Type);
......
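A worked example of the packed-stride computation performed by CudnnSetTensorNdDescriptor above (a standalone sketch, not library code): the innermost stride is 1 and each outer stride is the product of the extents to its right.

// For a 5-D shape {2, 3, 4, 5, 6} this yields strides {360, 120, 30, 6, 1}.
#include <cstddef>
#include <vector>

std::vector<int> PackedStrides(const std::vector<size_t> &shape) {
  std::vector<int> stride(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    stride[i] = stride[i + 1] * static_cast<int>(shape[i + 1]);
  }
  return stride;
}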
......@@ -27,6 +27,7 @@
#include "backend/kernel_compiler/gpu/kernel_constants.h"
namespace mindspore {
namespace kernel {
constexpr int MAX_DIMS = 7;
template <typename T, typename S>
class BroadcastOpGpuKernel : public GpuKernel {
public:
......@@ -45,9 +46,8 @@ class BroadcastOpGpuKernel : public GpuKernel {
S *output = GetDeviceAddress<S>(outputs, 0);
if (need_broadcast_) {
Broadcast(lhs_shape_[0], lhs_shape_[1], lhs_shape_[2], lhs_shape_[3], rhs_shape_[0], rhs_shape_[1], rhs_shape_[2],
rhs_shape_[3], output_shape_[0], output_shape_[1], output_shape_[2], output_shape_[3], op_type_, lhs,
rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
Broadcast(lhs_shape_, rhs_shape_, output_shape_, op_type_, lhs, rhs, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} else {
NoBroadcast(output_num_, op_type_, lhs, rhs, output, reinterpret_cast<cudaStream_t>(stream_ptr));
}
......@@ -60,10 +60,13 @@ class BroadcastOpGpuKernel : public GpuKernel {
auto shape2 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
auto shape3 = AnfAlgo::GetOutputInferShape(kernel_node, 0);
need_broadcast_ = IsBroadcast(shape1, shape2);
if (need_broadcast_ && shape1.size() > 4) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 4";
if (need_broadcast_ && shape1.size() > 7) {
MS_LOG(EXCEPTION) << "Broadcast operation not support dim greater than 7";
}
lhs_shape_.resize(MAX_DIMS, 1);
rhs_shape_.resize(MAX_DIMS, 1);
output_shape_.resize(MAX_DIMS, 1);
for (size_t i = 0; i < shape3.size(); i++) {
output_shape_[i] = shape3[i];
output_num_ *= shape3[i];
......@@ -127,9 +130,9 @@ class BroadcastOpGpuKernel : public GpuKernel {
int input1_num_;
int input2_num_;
int output_num_;
int lhs_shape_[4] = {1, 1, 1, 1};
int rhs_shape_[4] = {1, 1, 1, 1};
int output_shape_[4] = {1, 1, 1, 1};
std::vector<int> lhs_shape_;
std::vector<int> rhs_shape_;
std::vector<int> output_shape_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
......
......@@ -83,12 +83,19 @@ class ActivationGpuFwdKernel : public GpuKernel {
return true;
}
std::vector<int> shape;
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_NOT_PROPAGATE_NAN, 0.0),
"cudnnSetActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"cudnnSetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
}
InitSizeLists();
return true;
}
......
......@@ -90,12 +90,18 @@ class ActivationGradGpuKernel : public GpuKernel {
return true;
}
std::vector<int> shape;
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetActivationDescriptor(activation_desc_, mode_, CUDNN_PROPAGATE_NAN, 0.0),
"SetActivationDescriptor failed");
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"SetTensor4dDescriptor failed");
const int split_dim = 4;
if (input_shape.size() <= split_dim) {
ShapeNdTo4d(input_shape, &shape);
CHECK_CUDNN_RET_WITH_EXCEPT(cudnnSetTensor4dDescriptor(data_descriptor_, CUDNN_TENSOR_NCHW, cudnn_data_type_,
shape[0], shape[1], shape[2], shape[3]),
"SetTensor4dDescriptor failed");
} else {
CudnnSetTensorNdDescriptor(input_shape, data_descriptor_, cudnn_data_type_);
}
InitSizeLists();
return true;
......
......@@ -22,5 +22,9 @@ MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, float, bool)
MS_REG_GPU_KERNEL_TWO(
CheckValid,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeBool),
CheckValidGpuKernel, half, bool)
} // namespace kernel
} // namespace mindspore
......@@ -21,5 +21,8 @@ namespace kernel {
MS_REG_GPU_KERNEL_ONE(
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
IOUGpuKernel, float)
MS_REG_GPU_KERNEL_ONE(
IOU, KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
IOUGpuKernel, half)
} // namespace kernel
} // namespace mindspore
......@@ -37,8 +37,8 @@ def test_floor_div():
y0_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
x1_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
y1_np = np.random.randint(1, 5, (2, 1, 4, 4)).astype(np.float32)
x2_np = np.random.randint(1, 5, (2, 1, 1, 4)).astype(np.float32)
y2_np = np.random.randint(1, 5, (2, 3, 4, 4)).astype(np.float32)
x2_np = np.random.randint(1, 5, (2, 1, 1, 4, 9)).astype(np.float32)
y2_np = np.random.randint(1, 5, (2, 3, 4, 4, 9)).astype(np.float32)
x3_np = np.random.randint(1, 5, 1).astype(np.float32)
y3_np = np.random.randint(1, 5, 1).astype(np.float32)
x4_np = np.array(768).astype(np.float32)
......
......@@ -70,7 +70,7 @@ x11 = np.random.rand(1, 1, 1, 1).astype(np.float32)
axis11 = (0, 1, 2, 3)
keep_dims11 = False
x12 = np.random.rand(2, 3, 4, 4).astype(np.float32)
x12 = np.random.rand(2, 3, 4, 4, 5, 6).astype(np.float32)
axis12 = -2
keep_dims12 = False
......