提交 f6029709 编写于 作者: M mindspore-ci-bot 提交者: Gitee

!323 Gpu Concat support 4 inputs

Merge pull request !323 from chenweifeng/concat
......@@ -19,15 +19,13 @@
namespace mindspore {
namespace kernel {
MS_REG_GPU_KERNEL_ONE(
Concat,
KernelAttr().AddInputAttr(kNumberTypeFloat32).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat32).AddOutputAttr(kNumberTypeFloat32),
ConcatV2GpuFwdKernel, float)
MS_REG_GPU_KERNEL_ONE(Concat,
KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ConcatV2GpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(
Concat, KernelAttr().AddInputAttr(kNumberTypeInt32).AddInputAttr(kNumberTypeInt32).AddOutputAttr(kNumberTypeInt32),
ConcatV2GpuFwdKernel, int)
MS_REG_GPU_KERNEL_ONE(
Concat,
KernelAttr().AddInputAttr(kNumberTypeFloat16).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
Concat, KernelAttr().AddAllSameAttr(true).AddInputAttr(kNumberTypeFloat16).AddOutputAttr(kNumberTypeFloat16),
ConcatV2GpuFwdKernel, half)
} // namespace kernel
} // namespace mindspore
......@@ -27,7 +27,7 @@ namespace kernel {
template <typename T>
class ConcatV2GpuFwdKernel : public GpuKernel {
public:
ConcatV2GpuFwdKernel() : axis_(0), input0_size_(0), input1_size_(0), output_size_(0), workspace_size_(0) {}
ConcatV2GpuFwdKernel() : axis_(0), output_size_(0) {}
~ConcatV2GpuFwdKernel() override = default;
const std::vector<size_t> &GetInputSizeList() const override { return input_size_list_; }
const std::vector<size_t> &GetOutputSizeList() const override { return output_size_list_; }
......@@ -35,12 +35,32 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
bool Launch(const std::vector<AddressPtr> &inputs, const std::vector<AddressPtr> &,
const std::vector<AddressPtr> &outputs, uintptr_t stream_ptr) override {
T *input_0 = GetDeviceAddress<T>(inputs, 0);
T *input_1 = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
if (inputs.size() == 2) {
T *input_0 = GetDeviceAddress<T>(inputs, 0);
T *input_1 = GetDeviceAddress<T>(inputs, 1);
T *output = GetDeviceAddress<T>(outputs, 0);
ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
if (inputs.size() == 3) {
T *input_0 = GetDeviceAddress<T>(inputs, 0);
T *input_1 = GetDeviceAddress<T>(inputs, 1);
T *input_2 = GetDeviceAddress<T>(inputs, 2);
T *output = GetDeviceAddress<T>(outputs, 0);
ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], input_0, input_1, input_2, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
CalConcatV2(output_size_ / sizeof(T), w_[0], w_[1], input_0, input_1, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
if (inputs.size() == 4) {
T *input_0 = GetDeviceAddress<T>(inputs, 0);
T *input_1 = GetDeviceAddress<T>(inputs, 1);
T *input_2 = GetDeviceAddress<T>(inputs, 2);
T *input_3 = GetDeviceAddress<T>(inputs, 3);
T *output = GetDeviceAddress<T>(outputs, 0);
ConcatKernel(output_size_ / sizeof(T), w_[0], w_[1], w_[2], w_[3], input_0, input_1, input_2, input_3, output,
reinterpret_cast<cudaStream_t>(stream_ptr));
}
return true;
}
bool Init(const CNodePtr &kernel_node) override {
......@@ -48,44 +68,44 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
return false;
}
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
input0_size_ = sizeof(T);
for (size_t i = 0; i < input_shape.size(); i++) {
input0_size_ *= input_shape[i];
}
auto input_shape1 = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 1);
input1_size_ = sizeof(T);
for (size_t i = 0; i < input_shape1.size(); i++) {
input1_size_ *= input_shape1[i];
}
output_size_ = input0_size_ + input1_size_;
axis_ = GetAttr<int>(kernel_node, "axis");
if (axis_ < 0) {
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, 0);
axis_ += SizeToInt(input_shape.size());
}
w_[0] = 1;
w_[1] = 1;
for (size_t i = IntToSize(axis_); i < input_shape.size(); i++) {
w_[0] *= SizeToInt(input_shape[i]);
w_[1] *= SizeToInt(input_shape1[i]);
auto input_num = AnfAlgo::GetInputTensorNum(kernel_node);
for (size_t i = 0; i < input_num; i++) {
auto input_size = sizeof(T);
auto input_shape = AnfAlgo::GetPrevNodeOutputInferShape(kernel_node, i);
for (size_t j = 0; j < input_shape.size(); j++) {
input_size *= SizeToInt(input_shape[j]);
if (j >= IntToSize(axis_)) {
w_[i] *= SizeToInt(input_shape[j]);
}
input_size_list_.push_back(input_size);
}
}
auto output_shape = AnfAlgo::GetOutputInferShape(kernel_node, 0);
output_size_ = sizeof(T);
for (size_t i = 0; i < output_shape.size(); i++) {
output_size_ *= output_shape[i];
}
output_size_list_.push_back(output_size_);
InitSizeLists();
return true;
}
protected:
void InitSizeLists() override {
input_size_list_.push_back(input0_size_);
input_size_list_.push_back(input1_size_);
output_size_list_.push_back(output_size_);
}
void InitSizeLists() override {}
private:
bool CheckParam(const CNodePtr &kernel_node) {
size_t input_num = AnfAlgo::GetInputTensorNum(kernel_node);
if (input_num != 2) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs 2 inputs.";
if (input_num < 2 || input_num > 4) {
MS_LOG(ERROR) << "Input number is " << input_num << ", but ConcatV2GpuFwdKernel needs inputs between 2 and 4.";
return false;
}
size_t output_num = AnfAlgo::GetOutputTensorNum(kernel_node);
......@@ -95,16 +115,12 @@ class ConcatV2GpuFwdKernel : public GpuKernel {
}
return true;
}
int w_[2] = {1};
int w_[4] = {1, 1, 1, 1};
int axis_;
size_t output_size_;
std::vector<size_t> input_size_list_;
std::vector<size_t> output_size_list_;
std::vector<size_t> workspace_size_list_;
size_t input0_size_;
size_t input1_size_;
size_t output_size_;
size_t workspace_size_;
};
} // namespace kernel
} // namespace mindspore
......
......@@ -19,7 +19,7 @@
#include <cuda_runtime.h>
#include "kernel/gpu/cuda_impl/concatv2_impl.cuh"
template <typename T>
__global__ void ConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
__global__ void Concat(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int n = pos / (w1 + w2);
int m = pos % (w1 + w2);
......@@ -29,16 +29,80 @@ __global__ void ConcatV2(const size_t size, const int w1, const int w2, const T*
}
template <typename T>
void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
__global__ void Concat(const size_t size, const int w1, const int w2, const int w3,
const T* input_1, const T* input_2, const T* input_3, T* output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int n = pos / (w1 + w2 + w3);
int m = pos % (w1 + w2 + w3);
output[pos] = m < w1 ? input_1[n * w1 + m] :
m < w1 + w2 ? input_2[n * w2 + m - w1] :
input_3[n * w3 + m - w1 - w2];
}
return;
}
template <typename T>
__global__ void Concat(const size_t size, const int w1, const int w2, const int w3, const int w4,
const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output) {
for (size_t pos = blockIdx.x * blockDim.x + threadIdx.x; pos < (size); pos += blockDim.x * gridDim.x) {
int n = pos / (w1 + w2 + w3 + w4);
int m = pos % (w1 + w2 + w3 + w4);
output[pos] = m < w1 ? input_1[n * w1 + m] :
m < w1 + w2 ? input_2[n * w2 + m - w1]:
m < w1 + w2 + w3 ? input_3[n * w3 + m - w1 - w2]:
input_4[n * w4 + m - w1 - w2 - w3];
}
return;
}
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
cudaStream_t cuda_stream) {
ConcatV2<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, input_1, input_2, output);
return;
}
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
const T* input_1, const T* input_2, const T* input_3, T* output,
cudaStream_t cuda_stream) {
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, input_1, input_2, input_3, output);
return;
}
template void CalConcatV2(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2,
float* output, cudaStream_t cuda_stream);
template void CalConcatV2(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2,
int* output, cudaStream_t cuda_stream);
template void CalConcatV2(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
half* output, cudaStream_t cuda_stream);
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
cudaStream_t cuda_stream) {
Concat<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, w1, w2, w3, w4, input_1,
input_2, input_3, input_4, output);
return;
}
template void ConcatKernel(const size_t size, const int w1, const int w2, const float* input_1, const float* input_2,
float* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int* input_1, const int* input_2,
int* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const half* input_1, const half* input_2,
half* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
const float* input_1, const float* input_2, const float* input_3,
float* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
const int* input_1, const int* input_2, const int* input_3,
int* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
const half* input_1, const half* input_2, const half* input_3,
half* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
const float* input_1, const float* input_2, const float* input_3, const float* input_4,
float* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
const int* input_1, const int* input_2, const int* input_3, const int* input_4,
int* output, cudaStream_t cuda_stream);
template void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
const half* input_1, const half* input_2, const half* input_3, const half* input_4,
half* output, cudaStream_t cuda_stream);
......@@ -19,7 +19,13 @@
#include "device/gpu/cuda_common.h"
template <typename T>
void CalConcatV2(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
cudaStream_t cuda_stream);
void ConcatKernel(const size_t size, const int w1, const int w2, const T* input_1, const T* input_2, T* output,
cudaStream_t cuda_stream);
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3,
const T* input_1, const T* input_2, const T* input_3, T* output, cudaStream_t cuda_stream);
template <typename T>
void ConcatKernel(const size_t size, const int w1, const int w2, const int w3, const int w4,
const T* input_1, const T* input_2, const T* input_3, const T* input_4, T* output,
cudaStream_t cuda_stream);
#endif // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_CONCATV2IMPL_H_
......@@ -113,3 +113,62 @@ def test_axis21():
[2., 3., 3., 4., 5.]]
assert (output.asnumpy() == expect).all()
print(output)
class Concat3INet(nn.Cell):
def __init__(self):
super(Concat3INet, self).__init__()
self.cat = P.Concat(axis=1)
def construct(self, x1, x2, x3):
return self.cat((x1, x2, x3))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_3i():
cat = Concat3INet()
x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32)
x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32)
x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32)
output_np = np.concatenate((x1_np, x2_np, x3_np), axis=1)
x1_ms = Tensor(x1_np)
x2_ms = Tensor(x2_np)
x3_ms = Tensor(x3_np)
output_ms = cat(x1_ms, x2_ms, x3_ms)
error = np.ones(shape=output_np.shape) * 10e-6
diff = output_ms.asnumpy() - output_np
assert np.all(diff < error)
class Concat4INet(nn.Cell):
def __init__(self):
super(Concat4INet, self).__init__()
self.cat = P.Concat(axis=1)
def construct(self, x1, x2, x3, x4):
return self.cat((x1, x2, x3, x4))
@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_concat_4i():
cat = Concat4INet()
x1_np = np.random.randn(32, 4, 224, 224).astype(np.float32)
x2_np = np.random.randn(32, 8, 224, 224).astype(np.float32)
x3_np = np.random.randn(32, 10, 224, 224).astype(np.float32)
x4_np = np.random.randn(32, 5, 224, 224).astype(np.float32)
output_np = np.concatenate((x1_np, x2_np, x3_np, x4_np), axis=1)
x1_ms = Tensor(x1_np)
x2_ms = Tensor(x2_np)
x3_ms = Tensor(x3_np)
x4_ms = Tensor(x4_np)
output_ms = cat(x1_ms, x2_ms, x3_ms, x4_ms)
error = np.ones(shape=output_np.shape) * 10e-6
diff = output_ms.asnumpy() - output_np
assert np.all(diff < error)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册