Commit b750e3e1 authored by zhaoting

Fix GPU Split and Concat memory allocation bug

Parent 5c0962ac
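The fix targets an integer-overflow hazard: element counts were accumulated in 32-bit `int` (narrowing each dimension through `SizeToInt` and widening the result back with `IntToSize`), so the byte sizes pushed into the allocation lists could wrap for tensors of 2^31 elements or more. The hunks below switch the accumulators, the member fields that hold them, and the kernels' `size` parameters to `size_t`. A minimal illustration of the failure mode (the 65536 x 32768 shape is hypothetical, not from the patch):

```cpp
#include <climits>
#include <cstdio>

int main() {
  // Hypothetical shape: 65536 x 32768 = 2^31 elements.
  long long elems = 65536LL * 32768LL;
  std::printf("elements = %lld, INT_MAX = %d\n", elems, INT_MAX);
  // elems > INT_MAX, so accumulating this product in `int` (the old code
  // path) is signed overflow and the resulting byte size is garbage.
  // Accumulating in size_t (the patched path) stays exact.
  size_t bytes = static_cast<size_t>(elems) * sizeof(float);
  std::printf("bytes required = %zu\n", bytes);  // 8589934592 (8 GiB)
  return 0;
}
```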
@@ -74,12 +74,12 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
     inputs_host_ = std::make_unique<T *[]>(input_num_);
     len_axis_ = std::make_unique<int[]>(input_num_);
     for (int i = 0; i < input_num_; i++) {
-      int input_size = 1;
+      size_t input_size = 1;
       auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
       for (size_t j = 0; j < input_shape.size(); j++) {
-        input_size *= SizeToInt(input_shape[j]);
+        input_size *= input_shape[j];
       }
-      input_size_list_.push_back(IntToSize(input_size * sizeof(T)));
+      input_size_list_.push_back(input_size * sizeof(T));
       len_axis_[i] = SizeToInt(input_shape[axis_]);
     }
     workspace_size_list_.push_back(sizeof(T *) * input_num_);
@@ -97,7 +97,7 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
         all_size_before_axis_ *= output_shape[i];
       }
     }
-    output_size_list_.push_back(IntToSize(output_size_ * sizeof(T)));
+    output_size_list_.push_back(output_size_ * sizeof(T));
     InitSizeLists();
     return true;
@@ -117,7 +117,7 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
   }
   int axis_;
   int input_num_;
-  int output_size_;
+  size_t output_size_;
   int all_size_before_axis_;
   int all_size_axis_;
   std::unique_ptr<T *[]> inputs_host_;
...
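Read together, the ConcatV2 hunks keep the per-input element count in `size_t` end to end instead of narrowing each dimension to `int`. A self-contained sketch of the patched accumulation, where the `shapes` argument stands in for the values returned by `AnfAlgo::GetPrevNodeOutputInferShape` (not reproduced here):

```cpp
#include <cstddef>
#include <vector>

// Per-input byte sizes as the patched code computes them: the running
// product and the final byte count never pass through int.
template <typename T>
std::vector<size_t> CollectInputSizes(const std::vector<std::vector<size_t>>& shapes) {
  std::vector<size_t> input_size_list;
  for (const auto& shape : shapes) {
    size_t input_size = 1;  // patched: size_t, was int
    for (size_t dim : shape) {
      input_size *= dim;    // patched: no SizeToInt narrowing
    }
    input_size_list.push_back(input_size * sizeof(T));
  }
  return input_size_list;
}
```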
@@ -83,7 +83,7 @@ class SplitGpuFwdKernel : public GpuKernel {
         all_size_before_axis_ *= input_shape[i];
       }
     }
-    input_size_list_.push_back(IntToSize(input_size_ * sizeof(T)));
+    input_size_list_.push_back(input_size_ * sizeof(T));
     axis_step_ = input_shape[axis_] / output_num_;
     for (int i = 0; i < output_num_; i++) {
@@ -138,7 +138,7 @@ class SplitGpuFwdKernel : public GpuKernel {
   }
   int axis_;
   int output_num_;
-  int input_size_;
+  size_t input_size_;
   int axis_step_;
   int all_size_before_axis_;
   int all_size_axis_;
...
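The first Split hunk also shows how the slice count each output receives along the split axis is derived: `axis_step_ = input_shape[axis_] / output_num_`. A trivial host-side check with hypothetical numbers:

```cpp
#include <cassert>

int main() {
  // Hypothetical: split axis of length 8, 4 outputs -> 2 slices per output.
  const int axis_len = 8;
  const int output_num = 4;
  const int axis_step = axis_len / output_num;
  assert(axis_step == 2);
  return 0;
}
```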
@@ -19,7 +19,7 @@
 #include <cuda_runtime.h>
 #include "backend/kernel_compiler/gpu/cuda_impl/concatv2_impl.cuh"
 template <typename T>
-__global__ void Concat(const int size, const int input_num,
+__global__ void Concat(const size_t size, const int input_num,
                        const int all_size_before_axis, const int all_size_axis,
                        int* len_axis, T** inputs, T* output) {
   for (int pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
@@ -45,7 +45,7 @@ __global__ void Concat(const int size, const int input_num,
 }
 template <typename T>
-void ConcatKernel(const int size, const int input_num,
+void ConcatKernel(const size_t size, const int input_num,
                   const int all_size_before_axis, const int all_size_axis,
                   int* len_axis, T** inputs, T* output,
                   cudaStream_t cuda_stream) {
@@ -55,15 +55,15 @@ void ConcatKernel(const int size, const int input_num,
   return;
 }
-template void ConcatKernel(const int size, const int input_num,
+template void ConcatKernel(const size_t size, const int input_num,
                            const int all_size_before_axis, const int all_size_axis,
                            int* len_axis, float** inputs, float* output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const int size, const int input_num,
+template void ConcatKernel(const size_t size, const int input_num,
                            const int all_size_before_axis, const int all_size_axis,
                            int* len_axis, int** inputs, int* output,
                            cudaStream_t cuda_stream);
-template void ConcatKernel(const int size, const int input_num,
+template void ConcatKernel(const size_t size, const int input_num,
                            const int all_size_before_axis, const int all_size_axis,
                            int* len_axis, half** inputs, half* output,
                            cudaStream_t cuda_stream);
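Both kernels iterate with the standard CUDA grid-stride loop, so a single launch of `GET_BLOCKS(size)` x `GET_THREADS` threads covers any `size`. Below is a generic standalone version of the pattern with an overflow-safe counter; note that in the hunk above Concat still steps an `int pos` against the now-`size_t` bound, while Split (next file) already uses `size_t`:

```cpp
#include <cuda_runtime.h>

// Generic grid-stride loop over `size` elements; the Scale body is a
// stand-in. Casting to size_t before multiplying keeps the start index
// and the stride exact even when size exceeds INT_MAX.
__global__ void Scale(const size_t size, const float alpha, float* data) {
  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       pos < size; pos += static_cast<size_t>(blockDim.x) * gridDim.x) {
    data[pos] *= alpha;
  }
}
```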
@@ -19,7 +19,7 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void ConcatKernel(const int size, const int input_num,
+void ConcatKernel(const size_t size, const int input_num,
                   const int all_size_before_axis, const int all_size_axis,
                   int* len_axis, T** inputs, T* output,
                   cudaStream_t cuda_stream);
...
@@ -19,7 +19,7 @@
 #include <cuda_runtime.h>
 #include "backend/kernel_compiler/gpu/cuda_impl/split_impl.cuh"
 template <typename T>
-__global__ void Split(const int size, const int axis_step, const int all_size_before_axis,
+__global__ void Split(const size_t size, const int axis_step, const int all_size_before_axis,
                       const int all_size_axis, const T* input, T** outputs) {
   for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < size; pos += blockDim.x * gridDim.x) {
     int num = pos % all_size_before_axis / all_size_axis;
@@ -32,19 +32,19 @@ __global__ void Split(const int size, const int axis_step, const int all_size_before_axis,
 }
 template <typename T>
-void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
                  const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream) {
   Split<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, axis_step, all_size_before_axis,
                                                            all_size_axis, input, outputs);
   return;
 }
-template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
                           const int all_size_axis, const float* input, float** outputs,
                           cudaStream_t cuda_stream);
-template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
                           const int all_size_axis, const int* input, int** outputs,
                           cudaStream_t cuda_stream);
-template void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+template void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
                           const int all_size_axis, const half* input, half** outputs,
                           cudaStream_t cuda_stream);
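The one body line visible in Split, `int num = pos % all_size_before_axis / all_size_axis;`, recovers the coordinate along the split axis from a flat element index, assuming (as the names suggest) `all_size_before_axis` is the element count from the split axis through the last dimension and `all_size_axis` the count after the axis. A host-side check on a hypothetical {2, 4, 3} tensor split on axis 1:

```cpp
#include <cassert>

int main() {
  const int dims[] = {2, 4, 3};                        // hypothetical shape
  const int all_size_before_axis = dims[1] * dims[2];  // 12: axis dim onward
  const int all_size_axis = dims[2];                   //  3: dims after axis
  // Flat index of element (i0=1, i1=2, i2=1): 1*12 + 2*3 + 1 = 19.
  const int pos = 19;
  const int num = pos % all_size_before_axis / all_size_axis;  // kernel's line
  assert(num == 2);  // recovers the axis-1 coordinate i1
  return 0;
}
```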
@@ -19,6 +19,6 @@
 #include "runtime/device/gpu/cuda_common.h"
 template <typename T>
-void SplitKernel(const int size, const int axis_step, const int all_size_before_axis,
+void SplitKernel(const size_t size, const int axis_step, const int all_size_before_axis,
                  const int all_size_axis, const T* input, T** outputs, cudaStream_t cuda_stream);
 #endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_SPLIT_H_