未验证 提交 e10aa80f 编写于 作者: W whs 提交者: GitHub

Add pad2d op. (#12950)

* Add pad2d op.

* Add unitest and python api.

* Fix cuda op kernel.

* Fix python api.

* Fix python api.

* Update API.spec.

* Fix python api
上级 c709a04a
......@@ -170,6 +170,7 @@ paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], vara
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'maxlen', 'dtype', 'name'], varargs=None, keywords=None, defaults=(None, 'int64', None))
paddle.fluid.layers.stack ArgSpec(args=['x', 'axis'], varargs=None, keywords=None, defaults=(0,))
paddle.fluid.layers.pad2d ArgSpec(args=['input', 'paddings', 'mode', 'pad_value', 'data_format', 'name'], varargs=None, keywords=None, defaults=([0, 0, 0, 0], 'constant', 0.0, 'NCHW', None))
paddle.fluid.layers.unstack ArgSpec(args=['x', 'axis', 'num'], varargs=None, keywords=None, defaults=(0, None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
......
此差异已折叠。
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/platform/cuda_primitives.h"
#include "paddle/fluid/platform/gpu_info.h"
namespace paddle {
namespace operators {
using platform::PADDLE_CUDA_NUM_THREADS;
#define CUDA_1D_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
using framework::Tensor;
template <typename T>
__global__ void Pad2DConstNCHW(const int nthreads, const T* in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left, T value,
T* out_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int nc = index / out_width;
const int out_w = index % out_width;
const int out_h = nc % out_height;
nc /= out_height;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
out_data[index] =
(in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width)
? value
: in_data[(nc * in_height + in_h) * in_width + in_w];
}
}
template <typename T>
__global__ void Pad2DConstNHWC(const int nthreads, const T* in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left, T value,
T* out_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n = index / channels;
const int c = index % channels;
const int out_w = n % out_width;
n /= out_width;
const int out_h = n % out_height;
n /= out_height;
const int in_h = out_h - pad_top;
const int in_w = out_w - pad_left;
out_data[index] =
(in_h < 0 || in_w < 0 || in_h >= in_height || in_w >= in_width)
? value
: in_data[((n * in_height + in_h) * in_width + in_w) * channels +
c];
}
}
template <typename T>
__global__ void Pad2DReflectNCHW(const int nthreads, const T* in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
T* out_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int nc = index / out_width;
const int out_w = index % out_width;
const int out_h = nc % out_height;
nc /= out_height;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
in_h = max(in_h, -in_h); // reflect by 0
in_h = min(in_h, 2 * in_height - in_h - 2); // reflect by in_height
in_w = max(in_w, -in_w); // reflect by 0
in_w = min(in_w, 2 * in_width - in_w - 2); // reflect by in_width
out_data[index] = in_data[(nc * in_height + in_h) * in_width + in_w];
}
}
template <typename T>
__global__ void Pad2DReflectNHWC(const int nthreads, const T* in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
T* out_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n = index / channels;
const int c = index % channels;
const int out_w = n % out_width;
n /= out_width;
const int out_h = n % out_height;
n /= out_height;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
in_h = max(in_h, -in_h);
in_h = min(in_h, 2 * in_height - in_h - 2);
in_w = max(in_w, -in_w);
in_w = min(in_w, 2 * in_width - in_w - 2);
out_data[index] =
in_data[((n * in_height + in_h) * in_width + in_w) * channels + c];
}
}
template <typename T>
__global__ void Pad2DEdgeNCHW(const int nthreads, const T* in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
T* out_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int nc = index / out_width;
const int out_w = index % out_width;
const int out_h = nc % out_height;
nc /= out_height;
int in_h = min(in_height - 1, max(out_h - pad_top, 0));
int in_w = min(in_width - 1, max(out_w - pad_left, 0));
out_data[index] = in_data[(nc * in_height + in_h) * in_width + in_w];
}
}
template <typename T>
__global__ void Pad2DEdgeNHWC(const int nthreads, const T* in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
T* out_data) {
CUDA_1D_KERNEL_LOOP(index, nthreads) {
int n = index / channels;
const int c = index % channels;
const int out_w = n % out_width;
n /= out_width;
const int out_h = n % out_height;
n /= out_height;
int in_h = min(in_height - 1, max(out_h - pad_top, 0));
int in_w = min(in_width - 1, max(out_w - pad_left, 0));
out_data[index] =
in_data[((n * in_height + in_h) * in_width + in_w) * channels + c];
}
}
template <typename T>
__global__ void Pad2DGradConstNCHW(const int in_size, T* d_in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
const T* d_out_data) {
CUDA_1D_KERNEL_LOOP(in_index, in_size) {
int nc = in_index / in_width;
const int out_w = in_index % in_width + pad_left;
const int out_h = nc % in_height + pad_top;
nc /= in_height;
d_in_data[in_index] =
d_out_data[(nc * out_height + out_h) * out_width + out_w];
}
}
template <typename T>
__global__ void Pad2DGradConstNHWC(const int in_size, T* d_in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
const T* d_out_data) {
CUDA_1D_KERNEL_LOOP(in_index, in_size) {
int n = in_index / channels;
const int c = in_index % channels;
const int out_w = n % in_width + pad_left;
n /= in_width;
const int out_h = n % in_height + pad_top;
n /= in_height;
d_in_data[in_index] =
d_out_data[((n * out_height + out_h) * out_width + out_w) * channels +
c];
}
}
template <typename T>
__global__ void Pad2DGradReflectNCHW(const int out_size, T* d_in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
const T* d_out_data) {
CUDA_1D_KERNEL_LOOP(out_index, out_size) {
int nc = out_index / out_width;
const int out_w = out_index % out_width;
const int out_h = nc % out_height;
nc /= out_height;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
in_h = max(in_h, -in_h);
in_w = max(in_w, -in_w);
in_h = min(in_h, 2 * in_height - in_h - 2);
in_w = min(in_w, 2 * in_width - in_w - 2);
atomicAdd(&d_in_data[(nc * in_height + in_h) * in_width + in_w],
d_out_data[out_index]);
}
}
template <typename T>
__global__ void Pad2DGradReflectNHWC(const int out_size, T* d_in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
const T* d_out_data) {
CUDA_1D_KERNEL_LOOP(out_index, out_size) {
const int c = out_index % channels;
int n = out_index / channels;
const int out_w = n % out_width;
n /= out_width;
const int out_h = n % out_height;
n /= out_height;
int in_h = out_h - pad_top;
int in_w = out_w - pad_left;
in_h = max(in_h, -in_h);
in_w = max(in_w, -in_w);
in_h = min(in_h, in_height * 2 - in_h - 2);
in_w = min(in_w, in_width * 2 - in_w - 2);
atomicAdd(
&d_in_data[((n * in_height + in_h) * in_width + in_w) * channels + c],
d_out_data[out_index]);
}
}
template <typename T>
__global__ void Pad2DGradEdgeNCHW(const int out_size, T* d_in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
const T* d_out_data) {
CUDA_1D_KERNEL_LOOP(out_index, out_size) {
int nc = out_index / out_width;
const int out_w = out_index % out_width;
const int out_h = nc % out_height;
nc /= out_height;
const int in_h = min(in_height - 1, max(out_h - pad_top, 0));
const int in_w = min(in_width - 1, max(out_w - pad_left, 0));
atomicAdd(&d_in_data[(nc * in_height + in_h) * in_width + in_w],
d_out_data[out_index]);
}
}
template <typename T>
__global__ void Pad2DGradEdgeNHWC(const int out_size, T* d_in_data,
const int num, const int channels,
const int in_height, const int in_width,
const int out_height, const int out_width,
const int pad_top, const int pad_left,
const T* d_out_data) {
CUDA_1D_KERNEL_LOOP(out_index, out_size) {
const int c = out_index % channels;
int n = out_index / channels;
const int out_w = n % out_width;
n /= out_width;
const int out_h = n % out_height;
n /= out_height;
const int in_h = min(in_height - 1, max(out_h - pad_top, 0));
const int in_w = min(in_width - 1, max(out_w - pad_left, 0));
atomicAdd(
&d_in_data[((n * in_height + in_h) * in_width + in_w) * channels + c],
d_out_data[out_index]);
}
}
template <typename T>
class Pad2dCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto pads = context.Attr<std::vector<int>>("paddings");
auto mode = context.Attr<std::string>("mode");
auto data_format = context.Attr<std::string>("data_format");
T value = context.Attr<T>("pad_value");
auto* x = context.Input<Tensor>("X");
auto* out = context.Output<Tensor>("Out");
auto in_dims = x->dims();
auto out_dims = out->dims();
const T* in_data = x->data<T>();
T* out_data = out->mutable_data<T>(context.GetPlace());
const int pad_top = pads[0];
const int pad_left = pads[2];
const int num = in_dims[0];
auto stream = context.cuda_device_context().stream();
int block = PADDLE_CUDA_NUM_THREADS;
const int out_size = out->numel();
int grid = (out_size + block - 1) / block;
if (data_format == "NCHW") {
const int channels = in_dims[1];
const int in_height = in_dims[2];
const int in_width = in_dims[3];
const int out_height = out_dims[2];
const int out_width = out_dims[3];
if (mode == "reflect") {
Pad2DReflectNCHW<T><<<grid, block, 0, stream>>>(
out_size, in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, out_data);
} else if (mode == "edge") {
Pad2DEdgeNCHW<T><<<grid, block, 0, stream>>>(
out_size, in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, out_data);
} else {
Pad2DConstNCHW<T><<<grid, block, 0, stream>>>(
out_size, in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, value, out_data);
}
} else {
const int channels = in_dims[3];
const int in_height = in_dims[1];
const int in_width = in_dims[2];
const int out_height = out_dims[1];
const int out_width = out_dims[2];
if (mode == "reflect") {
Pad2DReflectNHWC<T><<<grid, block, 0, stream>>>(
out_size, in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, out_data);
} else if (mode == "edge") {
Pad2DEdgeNHWC<T><<<grid, block, 0, stream>>>(
out_size, in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, out_data);
} else {
Pad2DConstNHWC<T><<<grid, block, 0, stream>>>(
out_size, in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, value, out_data);
}
}
}
};
template <typename T>
class Pad2dGradCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto pads = context.Attr<std::vector<int>>("paddings");
auto mode = context.Attr<std::string>("mode");
auto data_format = context.Attr<std::string>("data_format");
auto* d_out = context.Input<Tensor>(framework::GradVarName("Out"));
auto* d_in = context.Output<Tensor>(framework::GradVarName("X"));
auto d_in_dims = d_in->dims();
auto d_out_dims = d_out->dims();
const T* d_out_data = d_out->data<T>();
T* d_in_data = d_in->mutable_data<T>(context.GetPlace());
math::SetConstant<platform::CUDADeviceContext, T> set_zero;
set_zero(context.template device_context<platform::CUDADeviceContext>(),
d_in, static_cast<T>(0));
const int pad_top = pads[0];
const int pad_left = pads[2];
const int num = d_in_dims[0];
auto stream = context.cuda_device_context().stream();
int block = PADDLE_CUDA_NUM_THREADS;
const int out_size = d_out->numel();
const int in_size = d_in->numel();
int grid = (out_size + block - 1) / block;
if (data_format == "NCHW") {
const int channels = d_in_dims[1];
const int in_height = d_in_dims[2];
const int in_width = d_in_dims[3];
const int out_height = d_out_dims[2];
const int out_width = d_out_dims[3];
if (mode == "reflect") {
Pad2DGradReflectNCHW<T><<<grid, block, 0, stream>>>(
out_size, d_in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, d_out_data);
} else if (mode == "edge") {
Pad2DGradEdgeNCHW<T><<<grid, block, 0, stream>>>(
out_size, d_in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, d_out_data);
} else {
grid = (in_size + block - 1) / block;
Pad2DGradConstNCHW<T><<<grid, block, 0, stream>>>(
in_size, d_in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, d_out_data);
}
} else {
const int channels = d_in_dims[3];
const int in_height = d_in_dims[1];
const int in_width = d_in_dims[2];
const int out_height = d_out_dims[1];
const int out_width = d_out_dims[2];
if (mode == "reflect") {
Pad2DGradReflectNHWC<T><<<grid, block, 0, stream>>>(
out_size, d_in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, d_out_data);
} else if (mode == "edge") {
Pad2DGradEdgeNHWC<T><<<grid, block, 0, stream>>>(
out_size, d_in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, d_out_data);
} else {
grid = (in_size + block - 1) / block;
Pad2DGradConstNHWC<T><<<grid, block, 0, stream>>>(
in_size, d_in_data, num, channels, in_height, in_width, out_height,
out_width, pad_top, pad_left, d_out_data);
}
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(pad2d, ops::Pad2dCUDAKernel<float>);
REGISTER_OP_CUDA_KERNEL(pad2d_grad, ops::Pad2dGradCUDAKernel<float>);
......@@ -109,6 +109,7 @@ __all__ = [
'flatten',
'sequence_mask',
'stack',
'pad2d',
'unstack',
]
......@@ -5614,6 +5615,94 @@ def rank_loss(label, left, right, name=None):
return out
def pad2d(input,
paddings=[0, 0, 0, 0],
mode='constant',
pad_value=0.0,
data_format="NCHW",
name=None):
"""
Pad 2-d images accordding to 'paddings' and 'mode'.
If mode is 'reflect', paddings[0] and paddings[1] must be no greater
than height-1. And the width dimension has the same condition.
Example:
Given that X is a channel of image from input:
X = [[1, 2, 3],
[4, 5, 6]]
Case 0:
paddings = [0, 1, 2, 3],
mode = 'constant'
pad_value = 0
Out = [[0, 0, 1, 2, 3, 0, 0, 0]
[0, 0, 4, 5, 6, 0, 0, 0]
[0, 0, 0, 0, 0, 0, 0, 0]]
Case 1:
paddings = [0, 1, 2, 1],
mode = 'reflect'
Out = [[3, 2, 1, 2, 3, 2]
[6, 5, 4, 5, 6, 5]
[3, 2, 1, 2, 3, 2]]
Case 2:
paddings = [0, 1, 2, 1],
mode = 'edge'
Out = [[1, 1, 1, 2, 3, 3]
[4, 4, 4, 5, 6, 6]
[4, 4, 4, 5, 6, 6]]
Args:
input (Variable): The input image with [N, C, H, W] format or [N, H, W, C] format.
paddings (tuple|list): The padding size. If padding is a tuple, it must
contain four integers, (padding_top, padding_bottom, padding_left, padding_right).
Default: padding = [0, 0, 0, 0].
mode (str): Three modes: constant(default), reflect, edge. Default: constant
pad_value (float32): The value to fill the padded areas in constant mode. Default: 0
data_format (str): An optional string from: "NHWC", "NCHW". Specify the data format of
the input data.
Default: "NCHW"
name (str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The tensor variable padded accordding to paddings and mode.
Examples:
.. code-block:: python
data = fluid.layers.data(name='data', shape=[3, 32, 32], dtype='float32')
result = fluid.layers.pad2d(input=data, padding=[1,2,3,4], mode='reflect')
"""
helper = LayerHelper('pad2d', **locals())
dtype = helper.input_dtype(input_param_name='input')
out = helper.create_tmp_variable(dtype)
helper.append_op(
type='pad2d',
inputs={'X': input},
outputs={"Out": out},
attrs={
'paddings': paddings,
'mode': mode,
'pad_value': pad_value,
'data_frmat': data_format
})
return out
def prelu(x, mode, param_attr=None, name=None):
"""
Equation:
......@@ -5628,8 +5717,8 @@ def prelu(x, mode, param_attr=None, name=None):
all: all elements share same weight
channel:elements in a channel share same weight
element:each element has a weight
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
name(str|None): A name for this layer(optional). If set None, the layer
will be named automatically.
Returns:
Variable: The output tensor with the same shape as input.
......
......@@ -521,6 +521,20 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(out)
print(str(program))
def test_pad2d(self):
program = Program()
with program_guard(program):
input = layers.data(
name="input", shape=[3, 100, 100], dtype="float32")
out = layers.pad2d(
input,
paddings=[1, 2, 3, 4],
mode='reflect',
data_format='NCHW',
name="shape")
self.assertIsNotNone(out)
print(str(program))
def test_prelu(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from op_test import OpTest
class TestPad2dOp(OpTest):
def setUp(self):
self.pad_value = 0.0
self.initTestCase()
self.op_type = "pad2d"
self.inputs = {'X': np.random.random(self.shape).astype("float32"), }
self.attrs = {}
self.attrs['paddings'] = np.array(self.paddings).flatten()
self.attrs['pad_value'] = self.pad_value
self.attrs['mode'] = self.mode
self.attrs['data_format'] = self.data_format
if self.data_format == "NCHW":
paddings = [(0, 0), (0, 0), (self.paddings[0], self.paddings[1]),
(self.paddings[2], self.paddings[3])]
else:
paddings = [(0, 0), (self.paddings[0], self.paddings[1]),
(self.paddings[2], self.paddings[3]), (0, 0)]
if self.mode == "constant":
out = np.pad(self.inputs['X'],
paddings,
mode=self.mode,
constant_values=self.pad_value)
else:
out = np.pad(self.inputs['X'], paddings, mode=self.mode)
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out', max_relative_error=0.006)
def initTestCase(self):
self.shape = (2, 3, 4, 4)
self.paddings = [0, 1, 2, 3]
self.mode = "constant"
self.data_format = "NCHW"
self.pad_value = 0.0
class TestCase1(TestPad2dOp):
def initTestCase(self):
self.shape = (2, 3, 4, 4)
self.paddings = [0, 1, 2, 3]
self.mode = "reflect"
self.data_format = "NCHW"
class TestCase2(TestPad2dOp):
def initTestCase(self):
self.shape = (2, 3, 4, 4)
self.paddings = [0, 1, 2, 3]
self.mode = "edge"
self.data_format = "NCHW"
class TestCase3(TestPad2dOp):
def initTestCase(self):
self.shape = (2, 4, 4, 2)
self.paddings = [0, 1, 2, 3]
self.mode = "reflect"
self.data_format = "NHWC"
class TestCase4(TestPad2dOp):
def initTestCase(self):
self.shape = (2, 4, 4, 2)
self.paddings = [0, 1, 2, 3]
self.mode = "edge"
self.data_format = "NHWC"
class TestCase5(TestPad2dOp):
def initTestCase(self):
self.shape = (2, 4, 4, 2)
self.paddings = [0, 1, 2, 3]
self.mode = "constant"
self.pad_value = 1.2
self.data_format = "NHWC"
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册