// Used by user facing 'Pad' API.
//
// Constant-pads a tensor laid out as [batch][channel][height][width] along its channel, height and
// width dimensions. Each thread starts at its global thread index and advances through the flattened
// output by the total number of launched threads, so any 1-D launch configuration covers all `size`
// output elements.
//
//   size                number of elements in `output`
//   input               unpadded source, indexed as [num][channels_orig][old_height][old_width]
//   num                 batch count of `input`; the kernel has no batch padding parameters, so output
//                       positions whose batch index is >= num are filled with the pad value rather than
//                       read from beyond the end of `input`
//   channels_orig       channel count of `input`
//   pad_channel_before  channels of padding inserted in front of the original channels
//   pad_channel_after   channels of padding appended behind the original channels
//   old_height/width    spatial extent of `input`
//   padded_height/width spatial extent of `output`
//   pad_top, pad_left   rows / columns of padding in front of the original data
//   pad_value           fill value, converted to T once per thread
//   output              padded destination holding `size` elements
template <typename T>
__global__ void PadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
                           const int pad_channel_before, const int pad_channel_after, const int old_height,
                           const int old_width, const int padded_height, const int padded_width, const int pad_top,
                           const int pad_left, float pad_value, T *output) {
  const T pad_value_template = static_cast<T>(pad_value);
  // Loop-invariant: number of channels once the channel padding has been added.
  const int channels_new = channels_orig + pad_channel_after + pad_channel_before;
  // Widen to size_t before multiplying so neither the start index nor the stride can wrap at 2^32.
  const size_t stride = static_cast<size_t>(blockDim.x) * gridDim.x;

  for (size_t pos = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x; pos < size; pos += stride) {
    // A "block" is one height x width plane of the output; total blocks = batch * channels_new.
    const size_t block_num = (pos / padded_width) / padded_height;
    const int padded_w = static_cast<int>(pos % padded_width);                    // x coordinate referred to by 'pos'
    const int padded_h = static_cast<int>((pos / padded_width) % padded_height);  // y coordinate referred to by 'pos'
    const int channel_num = static_cast<int>(block_num % channels_new);           // current output channel
    const size_t batch_item = block_num / channels_new;                           // current item in batch

    // Coordinates of this element in the unpadded input; out of range means the element is padding.
    const int src_h = padded_h - pad_top;
    const int src_w = padded_w - pad_left;
    const int src_c = channel_num - pad_channel_before;

    const bool is_padding = src_h < 0 || src_w < 0 || src_h >= old_height || src_w >= old_width || src_c < 0 ||
                            src_c >= channels_orig || batch_item >= static_cast<size_t>(num);
    if (is_padding) {
      output[pos] = pad_value_template;
    } else {
      // Equivalent plane in the input: drop the padding planes inserted by channel padding.
      const size_t equiv_block_num = batch_item * channels_orig + src_c;
      // size_t arithmetic keeps the offset valid for inputs with more than 2^31 elements.
      output[pos] = input[(equiv_block_num * old_height + src_h) * old_width + src_w];
    }
  }
  return;
}

// Host-side launcher for PadGeneral: enqueues the kernel on `cuda_stream` and returns without
// synchronizing. All arguments are forwarded to the kernel unchanged.
template <typename T>
void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
                   const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width,
                   const int padded_height, const int padded_width, const int pad_top, const int pad_left,
                   float pad_value, T *output, cudaStream_t cuda_stream) {
  // NOTE(review): launch configuration assumes the project's GET_BLOCKS / GET_THREADS helpers are in
  // scope through the included headers - confirm against the sibling launchers in this file.
  PadGeneral<<<GET_BLOCKS(size), GET_THREADS, 0, cuda_stream>>>(size, input, num, channels_orig, pad_channel_before,
                                                                pad_channel_after, old_height, old_width, padded_height,
                                                                padded_width, pad_top, pad_left, pad_value, output);
  return;
}
// Explicit instantiations of the CalPadGeneral launcher, one per element type (float and half), so the
// template's definition can stay in this translation unit.
// No template argument list is written after the name: T is deduced from the types of the `input` and
// `output` pointer parameters.
template void CalPadGeneral(const size_t size, const float *input, const int num, const int channels_orig,
                            const int pad_channel_before, const int pad_channel_after, const int old_height,
                            const int old_width, const int padded_height, const int padded_width,
                            const int pad_top, const int pad_left, float pad_value, float *output,
                            cudaStream_t cuda_stream);
template void CalPadGeneral(const size_t size, const half *input, const int num, const int channels_orig,
                            const int pad_channel_before, const int pad_channel_after, const int old_height,
                            const int old_width, const int padded_height, const int padded_width,
                            const int pad_top, const int pad_left, float pad_value, half *output,
                            cudaStream_t cuda_stream);
// Launches the general constant-padding kernel on `cuda_stream`.
//
//   size                 number of elements in `output`
//   input / output       source and destination buffers of element type T
//   num                  batch count of `input`
//   channels_orig        channel count of `input`
//   pad_channel_before   channels of padding placed in front of the original channels
//   pad_channel_after    channels of padding placed behind the original channels
//   old_height/width     spatial extent of `input`
//   padded_height/width  spatial extent of `output`
//   pad_top, pad_left    rows / columns of padding placed in front of the original data
//   pad_value            fill value for padded positions
//
// NOTE(review): buffers are presumably device memory and the call presumably returns before the kernel
// finishes - confirm against the definition in pad_impl.cu.
template <typename T>
void CalPadGeneral(const size_t size, const T *input, const int num, const int channels_orig,
                   const int pad_channel_before, const int pad_channel_after, const int old_height, const int old_width,
                   const int padded_height, const int padded_width, const int pad_top, const int pad_left,
                   float pad_value, T *output, cudaStream_t cuda_stream);
#endif  // MINDSPORE_CCSRC_KERNEL_GPU_CUDA_IMPL_PADIMPL_H_
2).') + if len(paddings) > 4: + raise ValueError('Only padding up to 4 dims is supported') if mode == "CONSTANT": self.pad = P.Pad(self.paddings) else: diff --git a/tests/st/ops/gpu/test_pad.py b/tests/st/ops/gpu/test_pad.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d9cf6d9c4777bcd7b8cadc3a569ec673324bba --- /dev/null +++ b/tests/st/ops/gpu/test_pad.py @@ -0,0 +1,204 @@ +# Copyright 2020 Huawei Technologies Co., Ltd +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
# ============================================================================

import pytest
import numpy as np

import mindspore
import mindspore.nn as nn
import mindspore.context as context

from mindspore import Tensor
from mindspore.ops.composite import GradOperation


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pad_basic():
    """A 2x2 array padded by one element on every side gains a border of zeros."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    source = np.array([[1, 2], [3, 4]]).astype(np.float32)
    expected = np.array([[0, 0, 0, 0],
                         [0, 1, 2, 0],
                         [0, 3, 4, 0],
                         [0, 0, 0, 0]]).astype(np.float32)

    pad = nn.Pad(mode='CONSTANT', paddings=((1, 1), (1, 1)))
    actual = pad(Tensor(source, dtype=mindspore.float32)).asnumpy()

    np.testing.assert_array_equal(actual, expected)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pad_row():
    """Row padding produces the expected output shape and keeps the original data in place."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    data_2d = np.random.rand(40, 40).astype(np.float32)
    data_4d = np.random.randn(3, 10, 30, 30).astype(np.float32)

    pad_2d = nn.Pad(mode='CONSTANT', paddings=((2, 3), (0, 0)))
    pad_4d = nn.Pad(mode='CONSTANT', paddings=((0, 0), (0, 0), (3, 0), (0, 0)))

    out_2d = pad_2d(Tensor(np.array(data_2d), dtype=mindspore.float32)).asnumpy()
    out_4d = pad_4d(Tensor(np.array(data_4d), dtype=mindspore.float32)).asnumpy()

    # shapes grow only along the padded row dimension
    assert out_2d.shape == (45, 40)
    assert out_4d.shape == (3, 10, 33, 30)

    # the unpadded window of each output must equal the input
    np.testing.assert_equal(out_2d[2:-3, :], data_2d)
    np.testing.assert_equal(out_4d[:, :, 3:, :], data_4d)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pad_column():
    """Column padding produces the expected output shape and keeps the original data in place."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    data_2d = np.random.randn(40, 40).astype(np.float32)
    data_4d = np.random.randn(3, 10, 30, 30).astype(np.float32)

    pad_2d = nn.Pad(mode='CONSTANT', paddings=((0, 0), (3, 3)))
    pad_4d = nn.Pad(mode='CONSTANT', paddings=((0, 0), (0, 0), (0, 0), (6, 1)))

    out_2d = pad_2d(Tensor(np.array(data_2d), dtype=mindspore.float32)).asnumpy()
    out_4d = pad_4d(Tensor(np.array(data_4d), dtype=mindspore.float32)).asnumpy()

    # shapes grow only along the padded column dimension
    assert out_2d.shape == (40, 46)
    assert out_4d.shape == (3, 10, 30, 37)

    # the unpadded window of each output must equal the input
    np.testing.assert_equal(out_2d[:, 3:-3], data_2d)
    np.testing.assert_equal(out_4d[:, :, :, 6:-1], data_4d)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pad_3d_pad():
    """Padding channel, row and column dimensions together keeps the original data in place."""
    context.set_context(mode=context.PYNATIVE_MODE, device_target="GPU")

    data = np.random.randn(5, 3, 30, 30).astype(np.float32)
    # batch dimension untouched; channel, row and column dimensions all padded
    pad = nn.Pad(mode='CONSTANT', paddings=((0, 0), (2, 1), (0, 1), (0, 2)))

    padded = pad(Tensor(np.array(data), dtype=mindspore.float32)).asnumpy()

    assert padded.shape == (5, 6, 31, 32)
    np.testing.assert_equal(data, padded[:, 2:-1, :-1, :-2])


# For testing backprop
class Grad(nn.Cell):
    """Cell that evaluates the gradient of `network` for a given input and output gradient."""

    def __init__(self, network):
        super(Grad, self).__init__()
        self.network = network
        self.grad = GradOperation(get_all=True, sens_param=True)

    def construct(self, input_, output_grad):
        grad_fn = self.grad(self.network)
        return grad_fn(input_, output_grad)

class Net(nn.Cell):
    """Minimal network holding a single constant-mode Pad layer, used by the backprop test."""

    def __init__(self):
        super(Net, self).__init__()
        # batch dimension untouched; channel, row and column dimensions padded
        paddings = ((0, 0), (4, 3), (1, 1), (0, 2))
        self.pad = nn.Pad(mode="CONSTANT", paddings=paddings)

    def construct(self, x):
        return self.pad(x)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pad_3d_backprop():
    """The gradient of a padded tensor is the output gradient cropped back to the input window."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    inputs = Tensor(np.random.randn(5, 3, 30, 30).astype(np.float32), dtype=mindspore.float32)

    # output gradient has the padded shape; cropping the padding away gives the expected dx
    dy = np.random.randn(5, 10, 32, 32).astype(np.float32)
    expected_dx = dy[:, 4:-3, 1:-1, :-2]

    grads = Grad(Net())(inputs, Tensor(dy))
    dx = grads[0].asnumpy()

    np.testing.assert_array_equal(dx, expected_dx)


@pytest.mark.level0
@pytest.mark.platform_x86_gpu_training
@pytest.mark.env_onecard
def test_pad_error_cases():
    """Common erroneous inputs must be rejected with the expected exception type."""
    context.set_context(mode=context.GRAPH_MODE, device_target="GPU")

    # Cases 1-3: the layer can be constructed, but applying it must raise ValueError.
    rejected_at_call = (
        (((0, 0), (-1, -1)), (3, 3)),   # 1 - negative padding values
        (((0, 0), (1, 0)), (3,)),       # 2 - 2-D paddings applied to a 1-D tensor
        (((0, 0), (1, 0)), (1, 3, 3)),  # 3 - 2-D paddings applied to a 3-D tensor
    )
    for paddings, shape in rejected_at_call:
        pad = nn.Pad(paddings=paddings, mode="CONSTANT")
        tensor = Tensor(np.random.randn(*shape), dtype=mindspore.float32)
        with pytest.raises(ValueError):
            pad(tensor)

    # Case 4: a flat pair instead of a tuple of pairs must be rejected at construction.
    with pytest.raises(TypeError):
        _ = nn.Pad(paddings=(0, 2), mode="CONSTANT")

    # Case 5: paddings for more than four dimensions must be rejected at construction.
    too_many_dims = ((0, 0), (0, 0), (0, 0), (0, 0), (1, 0))
    with pytest.raises(ValueError):
        _ = nn.Pad(paddings=too_many_dims, mode="CONSTANT")