提交 4e25fec7 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!324 Gpu Slice kernel performance improve

Merge pull request !324 from chenweifeng/slice
...@@ -41,8 +41,9 @@ class SliceGpuFwdKernel : public GpuKernel { ...@@ -41,8 +41,9 @@ class SliceGpuFwdKernel : public GpuKernel {
CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output, CalStridedSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, strides_, output,
reinterpret_cast<cudaStream_t>(stream_ptr)); reinterpret_cast<cudaStream_t>(stream_ptr));
} else { } else {
CalSlice(output_size_ / sizeof(T), input, input_shape_, begin_, size_, output, Slice4DKernel(begin_[0], begin_[1], begin_[2], begin_[3], size_[0], size_[1], size_[2], size_[3], input_shape_[0],
reinterpret_cast<cudaStream_t>(stream_ptr)); input_shape_[1], input_shape_[2], input_shape_[3], input, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
} }
return true; return true;
} }
......
...@@ -21,11 +21,22 @@ ...@@ -21,11 +21,22 @@
#include "kernel/gpu/cuda_impl/slice_impl.cuh" #include "kernel/gpu/cuda_impl/slice_impl.cuh"
template <typename T> template <typename T>
__global__ void Slice(const T* input, int p, int start, int length, T* output) { __global__ void Slice4D(const int s1, const int s2, const int s3, const int s4,
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (length); pos += blockDim.x * gridDim.x) { const int l1, const int l2, const int l3, const int l4,
output[p + pos] = input[start + pos]; const int d1, const int d2, const int d3, const int d4,
const T *input, T *output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (l1 * l2 * l3 * l4); pos += blockDim.x * gridDim.x) {
int i = pos / (l2 * l3 * l4) % l1;
int j = pos / (l3 * l4) % l2;
int k = pos / l4 % l3;
int o = pos % l4;
int offset = (i + s1) * (d2 * d3 * d4) +
(j + s2) * (d3 * d4) +
(k + s3) * d4 +
(o + s4);
output[pos] = input[offset];
} }
return;
} }
template <typename T> template <typename T>
__global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) { __global__ void SliceGrad(const T* dy, int p, int start, int length, T* output) {
...@@ -64,22 +75,12 @@ void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaSt ...@@ -64,22 +75,12 @@ void FillDeviceArray(const size_t input_size, T* addr, const float value, cudaSt
return; return;
} }
template <typename T> template <typename T>
void CalSlice(const size_t input_size, const T* input, const std::vector<int> in_shape, const std::vector<int> begin, void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
const std::vector<int> size, T* output, cudaStream_t cuda_stream) { const int l1, const int l2, const int l3, const int l4,
int block = in_shape[1] * in_shape[2] * in_shape[3]; const int d1, const int d2, const int d3, const int d4,
int map = in_shape[2] * in_shape[3]; const T *input, T *output, cudaStream_t stream) {
int w = in_shape[3]; Slice4D<<<GET_BLOCKS(l1 * l2 * l3 * l4), GET_THREADS, 0, stream>>>(s1, s2, s3, s4, l1, l2, l3, l4,
int length = size[3]; d1, d2, d3, d4, input, output);
int p = 0;
for (int i = begin[0]; i < size[0] + begin[0]; i++) {
for (int j = begin[1]; j < size[1] + begin[1]; j++) {
for (int k = begin[2]; k < size[2] + begin[2]; k++) {
Slice<<<GET_BLOCKS(input_size), GET_THREADS, 0, cuda_stream>>>(input, p, i * block + j * map + k * w + begin[3],
length, output);
p = p + size[3];
}
}
}
} }
template <typename T> template <typename T>
void CalSliceGrad(const size_t input_size, const T* dy, const std::vector<int> in_shape, const std::vector<int> begin, void CalSliceGrad(const size_t input_size, const T* dy, const std::vector<int> in_shape, const std::vector<int> begin,
...@@ -147,9 +148,10 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector ...@@ -147,9 +148,10 @@ void CalStridedSliceGrad(const size_t input_size, const T* dy, const std::vector
} }
template void FillDeviceArray<float>(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream); template void FillDeviceArray<float>(const size_t input_size, float* addr, const float value, cudaStream_t cuda_stream);
template void CalSlice<float>(const size_t input_size, const float* input, const std::vector<int> in_shape, template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
const std::vector<int> begin, const std::vector<int> size, float* output, const int l1, const int l2, const int l3, const int l4,
cudaStream_t cuda_stream); const int d1, const int d2, const int d3, const int d4,
const float *input, float *output, cudaStream_t stream);
template void CalSliceGrad<float>(const size_t input_size, const float* dy, const std::vector<int> in_shape, template void CalSliceGrad<float>(const size_t input_size, const float* dy, const std::vector<int> in_shape,
const std::vector<int> begin, const std::vector<int> size, float* output, const std::vector<int> begin, const std::vector<int> size, float* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
...@@ -160,9 +162,10 @@ template void CalStridedSliceGrad<float>(const size_t input_size, const float* d ...@@ -160,9 +162,10 @@ template void CalStridedSliceGrad<float>(const size_t input_size, const float* d
const std::vector<int> begin, const std::vector<int> end, const std::vector<int> begin, const std::vector<int> end,
const std::vector<int> strides, float* dx, cudaStream_t cuda_stream); const std::vector<int> strides, float* dx, cudaStream_t cuda_stream);
template void FillDeviceArray<half>(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream); template void FillDeviceArray<half>(const size_t input_size, half* addr, const float value, cudaStream_t cuda_stream);
template void CalSlice<half>(const size_t input_size, const half* input, const std::vector<int> in_shape, template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
const std::vector<int> begin, const std::vector<int> size, half* output, const int l1, const int l2, const int l3, const int l4,
cudaStream_t cuda_stream); const int d1, const int d2, const int d3, const int d4,
const half *input, half *output, cudaStream_t stream);
template void CalSliceGrad<half>(const size_t input_size, const half* dy, const std::vector<int> in_shape, template void CalSliceGrad<half>(const size_t input_size, const half* dy, const std::vector<int> in_shape,
const std::vector<int> begin, const std::vector<int> size, half* output, const std::vector<int> begin, const std::vector<int> size, half* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
...@@ -173,9 +176,10 @@ template void CalStridedSliceGrad<half>(const size_t input_size, const half* dy, ...@@ -173,9 +176,10 @@ template void CalStridedSliceGrad<half>(const size_t input_size, const half* dy,
const std::vector<int> begin, const std::vector<int> end, const std::vector<int> begin, const std::vector<int> end,
const std::vector<int> strides, half* dx, cudaStream_t cuda_stream); const std::vector<int> strides, half* dx, cudaStream_t cuda_stream);
template void FillDeviceArray<int>(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream); template void FillDeviceArray<int>(const size_t input_size, int* addr, const float value, cudaStream_t cuda_stream);
template void CalSlice<int>(const size_t input_size, const int* input, const std::vector<int> in_shape, template void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
const std::vector<int> begin, const std::vector<int> size, int* output, const int l1, const int l2, const int l3, const int l4,
cudaStream_t cuda_stream); const int d1, const int d2, const int d3, const int d4,
const int *input, int *output, cudaStream_t stream);
template void CalSliceGrad<int>(const size_t input_size, const int* dy, const std::vector<int> in_shape, template void CalSliceGrad<int>(const size_t input_size, const int* dy, const std::vector<int> in_shape,
const std::vector<int> begin, const std::vector<int> size, int* output, const std::vector<int> begin, const std::vector<int> size, int* output,
cudaStream_t cuda_stream); cudaStream_t cuda_stream);
......
...@@ -21,9 +21,12 @@ ...@@ -21,9 +21,12 @@
#include <vector> #include <vector>
#include "device/gpu/cuda_common.h" #include "device/gpu/cuda_common.h"
template <typename T> template <typename T>
void CalSlice(const size_t input_size, const T* input, const std::vector<int> in_shape, const std::vector<int> begin, void Slice4DKernel(const int s1, const int s2, const int s3, const int s4,
const std::vector<int> size, T* output, cudaStream_t cuda_stream); const int l1, const int l2, const int l3, const int l4,
const int d1, const int d2, const int d3, const int d4,
const T *input, T *output, cudaStream_t stream);
template <typename T> template <typename T>
void CalSliceGrad(const size_t input_size, const T* input, const std::vector<int> in_shape, void CalSliceGrad(const size_t input_size, const T* input, const std::vector<int> in_shape,
const std::vector<int> begin, const std::vector<int> size, T* output, cudaStream_t cuda_stream); const std::vector<int> begin, const std::vector<int> size, T* output, cudaStream_t cuda_stream);
......
...@@ -43,3 +43,22 @@ def test_slice(): ...@@ -43,3 +43,22 @@ def test_slice():
slice = Slice() slice = Slice()
output = slice(x) output = slice(x)
assert (output.asnumpy() == expect).all() assert (output.asnumpy() == expect).all()
class SliceNet(nn.Cell):
def __init__(self):
super(SliceNet, self).__init__()
self.slice = P.Slice()
def construct(self, x):
return self.slice(x, (0, 11, 0, 0), (32, 7, 224, 224))
def test_slice_4d():
x_np = np.random.randn(32, 24, 224, 224).astype(np.float32)
output_np = x_np[:, 11:18, :, :]
x_ms = Tensor(x_np)
net = SliceNet()
output_ms = net(x_ms)
assert (output_ms.asnumpy() == output_np).all()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册