Commit 4e25fec7 authored by mindspore-ci-bot, committed by Gitee

!324 GPU Slice kernel performance improvement

Merge pull request !324 from chenweifeng/slice
@@ -41,8 +41,9 @@ class SliceGpuFwdKernel : public GpuKernel {
       CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output,
                       reinterpret_cast<cudaStream_t>(stream_ptr));
     } else {
-      CalSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, output,
-               reinterpret_cast<cudaStream_t>(stream_ptr));
+      Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0],
+                    input_shape_[1], input_shape_[2], input_shape_[3], input, output,
+                    reinterpret_cast<cudaStream_t>(stream_ptr));
     }
     return true;
   }
......
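In the hunk above, Launch stops going through the CalSlice host helper and instead unpacks its cached begin_, size_, and input_shape_ vectors into the twelve scalar arguments of a single Slice4DKernel launch. A minimal sketch of that argument mapping, with illustrative values taken from the Python test at the end of this diff (the surrounding input, output, and stream_ptr variables are assumed to exist as in the kernel's Launch method):

// Illustrative only: how the cached members map onto the new call, assuming
// begin_, size_, and input_shape_ each hold exactly four ints (NCHW layout).
std::vector<int> begin_ = {0, 11, 0, 0};             // slice start per dimension  -> s1..s4
std::vector<int> size_ = {32, 7, 224, 224};          // slice length per dimension -> l1..l4
std::vector<int> input_shape_ = {32, 24, 224, 224};  // full 4-D input shape       -> d1..d4
Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3],
              size_[0], size_[1], size_[2], size_[3],
              input_shape_[0], input_shape_[1], input_shape_[2], input_shape_[3],
              input, output, reinterpret_cast<cudaStream_t>(stream_ptr));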
@@ -21,11 +21,22 @@
 #include "kernel/gpu/cuda_impl/slice_impl.cuh"

 template <typename T>
-__global__ void Slice(const T* input, int p, int start, int length, T* output) {
-  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) {
-    output[p + pos] = input[start + pos];
+__global__ void Slice4D(const int s1, const int s2, const int s3, const int s4,
+                        const int l1, const int l2, const int l3, const int l4,
+                        const int d1, const int d2, const int d3, const int d4,
+                        const T *input, T *output) {
+  for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) {
+    int i = pos / (l2 * l3 * l4) % l1;
+    int j = pos / (l3 * l4) % l2;
+    int k = pos / l4 % l3;
+    int o = pos % l4;
+    int offset = (i + s1) * (d2 * d3 * d4) +
+                 (j + s2) * (d3 * d4) +
+                 (k + s3) * d4 +
+                 (o + s4);
+    output[pos] = input[offset];
   }
   return;
 }
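Each thread of Slice4D above handles one element of the output: it decodes its flat index pos into the four coordinates of the (l1, l2, l3, l4) output block, shifts them by the slice begin (s1..s4), and re-linearizes them against the row-major input strides d2*d3*d4, d3*d4, d4, and 1 (d1 only bounds the outermost coordinate and never enters the offset). A host-side sketch of the same index arithmetic, useful as a reference when checking the kernel but not part of the patch:

// Reference CPU version of the kernel's index math (illustrative sketch only).
// Copies the (l1, l2, l3, l4) block that starts at (s1, s2, s3, s4) out of a
// row-major input of shape (d1, d2, d3, d4).
template <typename T>
void Slice4DReference(int s1, int s2, int s3, int s4, int l1, int l2, int l3, int l4,
                      int d1, int d2, int d3, int d4, const T *input, T *output) {
  for (int pos = 0; pos < l1 * l2 * l3 * l4; ++pos) {
    int i = pos / (l2 * l3 * l4) % l1;  // output coordinate along dimension 1
    int j = pos / (l3 * l4) % l2;       // output coordinate along dimension 2
    int k = pos / l4 % l3;              // output coordinate along dimension 3
    int o = pos % l4;                   // output coordinate along dimension 4
    // Shift by the slice begin and re-linearize with the input strides.
    output[pos] = input[(i + s1) * (d2 * d3 * d4) + (j + s2) * (d3 * d4) + (k + s3) * d4 + (o + s4)];
  }
}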
 template <typename T>
 __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) {
@@ -64,22 +75,12 @@ void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaSt
   return;
 }
 template <typename T>
-void CalSlice(const size_t input_size, const T* input, const std::vector<int> in_shape, const std::vector<int> begin,
-              const std::vector<int> size, T* output, cudaStream_t cuda_stream) {
-  int block = in_shape[1] * in_shape[2] * in_shape[3];
-  int map = in_shape[2] * in_shape[3];
-  int w = in_shape[3];
-  int length = size[3];
-  int p = 0;
-  for (int i = begin[0]; i < size[0] + begin[0]; i++) {
-    for (int j = begin[1]; j < size[1] + begin[1]; j++) {
-      for (int k = begin[2]; k < size[2] + begin[2]; k++) {
-        Slice<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, p, i * block + j * map + k * w + begin[3],
-                                                                       length, output);
-        p = p + size[3];
-      }
-    }
-  }
+void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                   const int l1, const int l2, const int l3, const int l4,
+                   const int d1, const int d2, const int d3, const int d4,
+                   const T *input, T *output, cudaStream_t stream) {
+  Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4,
+                                                                     d1, d2, d3, d4, input, output);
 }
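This wrapper is where the speedup comes from. The removed CalSlice above walked begin/size with three nested host-side loops and issued a separate Slice launch of only size[3] contiguous elements for every (i, j, k) row, i.e. size[0] * size[1] * size[2] kernel launches per slice; Slice4DKernel issues one launch whose grid covers the whole l1 * l2 * l3 * l4 output (GET_BLOCKS and GET_THREADS are launch-configuration helpers, presumably from the device/gpu/cuda_common.h included by the header further down). For the 32 x 7 x 224 x 224 output used in the test at the end of this diff, that collapses 32 * 7 * 224 = 50,176 launches into a single one.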
 template <typename T>
 void CalSliceGrad(const size_t input_size, const T* dy, const std::vector<int> in_shape, const std::vector<int> begin,
@@ -147,9 +148,10 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector
 }

 template void FillDeviceArray<float>(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream);
-template void CalSlice<float>(const size_t input_size, const float* input, const std::vector<int> in_shape,
-                              const std::vector<int> begin, const std::vector<int> size, float* output,
-                              cudaStream_t cuda_stream);
+template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                            const int l1, const int l2, const int l3, const int l4,
+                            const int d1, const int d2, const int d3, const int d4,
+                            const float *input, float *output, cudaStream_t stream);
 template void CalSliceGrad<float>(const size_t input_size, const float* dy, const std::vector<int> in_shape,
                                   const std::vector<int> begin, const std::vector<int> size, float* output,
                                   cudaStream_t cuda_stream);
@@ -160,9 +162,10 @@ template void CalStridedSliceGrad<float>(const size_t input_size, const float* d
                                           const std::vector<int> begin, const std::vector<int> end,
                                           const std::vector<int> strides, float* dx, cudaStream_t cuda_stream);
 template void FillDeviceArray<half>(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream);
-template void CalSlice<half>(const size_t input_size, const half* input, const std::vector<int> in_shape,
-                             const std::vector<int> begin, const std::vector<int> size, half* output,
-                             cudaStream_t cuda_stream);
+template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                            const int l1, const int l2, const int l3, const int l4,
+                            const int d1, const int d2, const int d3, const int d4,
+                            const half *input, half *output, cudaStream_t stream);
 template void CalSliceGrad<half>(const size_t input_size, const half* dy, const std::vector<int> in_shape,
                                  const std::vector<int> begin, const std::vector<int> size, half* output,
                                  cudaStream_t cuda_stream);
@@ -173,9 +176,10 @@ template void CalStridedSliceGrad<half>(const size_t input_size, const half* dy,
                                         const std::vector<int> begin, const std::vector<int> end,
                                         const std::vector<int> strides, half* dx, cudaStream_t cuda_stream);
 template void FillDeviceArray<int>(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream);
-template void CalSlice<int>(const size_t input_size, const int* input, const std::vector<int> in_shape,
-                            const std::vector<int> begin, const std::vector<int> size, int* output,
-                            cudaStream_t cuda_stream);
+template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                            const int l1, const int l2, const int l3, const int l4,
+                            const int d1, const int d2, const int d3, const int d4,
+                            const int *input, int *output, cudaStream_t stream);
 template void CalSliceGrad<int>(const size_t input_size, const int* dy, const std::vector<int> in_shape,
                                 const std::vector<int> begin, const std::vector<int> size, int* output,
                                 cudaStream_t cuda_stream);
......
@@ -21,9 +21,12 @@
 #include <vector>
 #include "device/gpu/cuda_common.h"

 template <typename T>
-void CalSlice(const size_t input_size, const T* input, const std::vector<int> in_shape, const std::vector<int> begin,
-              const std::vector<int> size, T* output, cudaStream_t cuda_stream);
+void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
+                   const int l1, const int l2, const int l3, const int l4,
+                   const int d1, const int d2, const int d3, const int d4,
+                   const T *input, T *output, cudaStream_t stream);
 template <typename T>
 void CalSliceGrad(const size_t input_size, const T* input, const std::vector<int> in_shape,
                   const std::vector<int> begin, const std::vector<int> size, T* output, cudaStream_t cuda_stream);
......
@@ -43,3 +43,22 @@ def test_slice():
     slice = Slice()
     output = slice(x)
     assert (output.asnumpy() == expect).all()
+
+
+class SliceNet(nn.Cell):
+    def __init__(self):
+        super(SliceNet, self).__init__()
+        self.slice = P.Slice()
+
+    def construct(self, x):
+        return self.slice(x, (0, 11, 0, 0), (32, 7, 224, 224))
+
+
+def test_slice_4d():
+    x_np = np.random.randn(32, 24, 224, 224).astype(np.float32)
+    output_np = x_np[:, 11:18, :, :]
+
+    x_ms = Tensor(x_np)
+    net = SliceNet()
+    output_ms = net(x_ms)
+    assert (output_ms.asnumpy() == output_np).all()
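For reference, the begin/size pair (0, 11, 0, 0) / (32, 7, 224, 224) passed to P.Slice selects exactly the elements of x_np[:, 11:18, :, :], so the final assert compares the GPU result element-for-element against the NumPy slice.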