提交 57e68e57 编写于 作者: S sweetsky0901

modify for code review by qingqing 2nd

上级 f9c2a5c3
......@@ -29,19 +29,19 @@ __global__ void KernelUnpool2dMax(const int nthreads,
T* output_data,
const int output_height,
const int output_width) {
int bsize = input_height * input_width * channels;
int csize = input_height * input_width;
int out_bsize = output_height * output_width * channels;
int out_csize = output_height * output_width;
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / bsize;
int boffset = i % bsize;
int cidx = boffset / csize;
int out_offset = bidx * out_bsize + cidx * out_csize;
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ASSERT(out_index < (output_height * output_width));
PADDLE_ASSERT(out_index < out_c_stride);
output_data[out_offset + out_index] = input_data[i];
}
}
......@@ -57,19 +57,19 @@ __global__ void KernelUnpool2dMaxGrad(const int nthreads,
const int output_height,
const int output_width,
T* input_grad) {
int bsize = input_height * input_width * channels;
int csize = input_height * input_width;
int out_bsize = output_height * output_width * channels;
int out_csize = output_height * output_width;
int in_n_stride = input_height * input_width * channels;
int in_c_stride = input_height * input_width;
int out_n_stride = output_height * output_width * channels;
int out_c_stride = output_height * output_width;
int index = blockIdx.x * blockDim.x + threadIdx.x;
int offset = blockDim.x * gridDim.x;
for (int i = index; i < nthreads; i += offset) {
int bidx = i / bsize;
int boffset = i % bsize;
int cidx = boffset / csize;
int out_offset = bidx * out_bsize + cidx * out_csize;
int bidx = i / in_n_stride;
int boffset = i % in_n_stride;
int cidx = boffset / in_c_stride;
int out_offset = bidx * out_n_stride + cidx * out_c_stride;
int out_index = indices_data[i];
PADDLE_ASSERT(out_index < (output_height * output_width));
PADDLE_ASSERT(out_index < out_c_stride);
input_grad[i] = output_grad[out_offset + out_index];
}
}
......@@ -93,10 +93,8 @@ class Unpool2dMaxFunctor<platform::GPUPlace, T, T2> {
const T2 * indices_data = indices.data<T2>();
T* output_data = output->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
int threads = 1024;
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool2dMax<
T, T2><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
......@@ -129,10 +127,8 @@ class Unpool2dMaxGradFunctor<platform::GPUPlace, T, T2> {
const T* output_grad_data = output_grad.data<T>();
T* input_grad_data = input_grad->mutable_data<T>(context.GetPlace());
int nthreads = batch_size * output_channels * input_height * input_width;
int blocks = (nthreads + 1024 - 1) / 1024;
dim3 threads(1024, 1);
dim3 grid(blocks, 1);
int threads = 1024;
int grid = (input.numel() + threads - 1) / threads;
KernelUnpool2dMaxGrad<
T, T2><<<grid, threads, 0,
reinterpret_cast<const platform::CUDADeviceContext&>(context)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/unpool_op.h"
namespace paddle {
......@@ -25,7 +25,7 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor) The input tensor of unpool operator. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
AddInput("Y",
AddInput("Indices",
"(Tensor) The input tensor of the indices given out by MaxPool2d. "
"The format of input tensor is NCHW. Where N is batch size, C is the "
"number of channels, H and W is the height and width of feature.");
......@@ -50,12 +50,10 @@ class Unpool2dOpMaker : public framework::OpProtoAndCheckerMaker {
"(string), unpooling type, can be \"max\" for max-unpooling ")
.InEnum({"max"});
AddComment(R"DOC(
"input: the input Tensor to invert
indices: the indices given out by MaxPool2d
ksize – Size of the max pooling window.
stride – Stride of the max pooling window.
"It is set to kernel_size by default.
padding – Padding that was added to the input"
"Paper: http://www.matthewzeiler.com/wp-content/uploads/2017
/07/iccv2011.pdf
PyTorch: http://pytorch.org/docs/master/nn.html?highlight=unpool#
torch.nn.MaxUnpool2d"
)DOC");
}
};
......@@ -79,27 +77,20 @@ public:
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) of UnpoolOp"
PADDLE_ENFORCE(ctx->HasInput("Indices"), "Input(Indices) of UnpoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of UnpoolOp should not be null.");
auto in_x_dims = ctx->GetInputDim("X");
auto in_y_dims = ctx->GetInputDim("Y");
auto in_y_dims = ctx->GetInputDim("Indices");
std::string unpooling_type =
ctx->Attrs().Get<std::string>("unpooling_type");
std::vector<int> ksize = ctx->Attrs().Get<std::vector<int>>("ksize");
std::vector<int> strides = ctx->Attrs().Get<std::vector<int>>("strides");
std::vector<int> paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
PADDLE_ENFORCE(in_x_dims.size() == 4,
"Unpooling intput must be of 4-dimensional.");
for (int i = 0; i < 4; ++i) {
PADDLE_ENFORCE(in_x_dims[i] == in_y_dims[i],
"X size must be eq Y size!");
}
PADDLE_ENFORCE_EQ(in_x_dims, in_y_dims);
std::vector<int64_t> output_shape({in_x_dims[0], in_x_dims[1]});
for (size_t i = 0; i < ksize.size(); ++i) {
output_shape.push_back(
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/operators/unpool_op.h"
......
......@@ -2,7 +2,7 @@
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
Indicesou may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
......@@ -26,7 +26,7 @@ class UnpoolKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
auto * out = context.Output<framework::Tensor>("Out");
std::string unpooling_type = context.Attr<std::string>("unpooling_type");
std::vector<int> ksize = context.Attr<std::vector<int>>("ksize");
......@@ -47,7 +47,7 @@ class UnpoolGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Y");
const framework::Tensor* in_y = context.Input<framework::Tensor>("Indices");
const framework::Tensor* out = context.Input<framework::Tensor>("Out");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
......
......@@ -5,16 +5,16 @@ from op_test import OpTest
def unpool2dmax_forward_naive(input, indices, ksize, strides, paddings):
s0, s1, s2, s3 = input.shape
out_H=(s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0]
out_W=(s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1]
out = np.zeros((s0, s1, out_H, out_W))
out_hsize = (s2 - 1) * strides[0] - 2 * paddings[0] + ksize[0]
out_wsize = (s2 - 1) * strides[1] - 2 * paddings[1] + ksize[1]
out = np.zeros((s0, s1, out_hsize, out_wsize))
for nidx in xrange(s0):
for cidx in xrange(s1):
for h in xrange(s2):
for w in xrange(s3):
index = indices[nidx, cidx, h, w]
hidx = (index - index % out_W) / out_W
widx = index % out_W
hidx = (index - index % out_wsize) / out_wsize
widx = index % out_wsize
out[nidx, cidx, int(hidx), int(widx)] = \
input[nidx, cidx, h, w]
......@@ -26,34 +26,34 @@ class TestUnpoolOp(OpTest):
self.op_type = "unpool"
self.init_test_case()
pre_input = np.random.random(self.shape).astype("float32")
N, C, H, W = pre_input.shape
H_out = (H - self.ksize[0] + 2 * self.paddings[0]) / \
nsize, csize, hsize, wsize = pre_input.shape
hsize_out = (hsize - self.ksize[0] + 2 * self.paddings[0]) / \
self.strides[0] + 1
W_out = (W - self.ksize[1] + 2 * self.paddings[1]) / \
wsize_out = (wsize - self.ksize[1] + 2 * self.paddings[1]) / \
self.strides[1] + 1
input = np.zeros((N, C, H_out, W_out))
indices = np.zeros((N, C, H_out, W_out))
for i in xrange(H_out):
for j in xrange(W_out):
input = np.zeros((nsize, csize, hsize_out, wsize_out))
indices = np.zeros((nsize, csize, hsize_out, wsize_out))
for i in xrange(hsize_out):
for j in xrange(wsize_out):
r_start = np.max((i * self.strides[0] - self.paddings[0], 0))
r_end = np.min((i * self.strides[0] + self.ksize[0] - \
self.paddings[0], H))
self.paddings[0], hsize))
c_start = np.max((j * self.strides[1] - self.paddings[1], 0))
c_end = np.min((j * self.strides[1] + self.ksize[1] - \
self.paddings[1], W))
for nidx in xrange(N):
for cidx in xrange(C):
self.paddings[1], wsize))
for nidx in xrange(nsize):
for cidx in xrange(csize):
x_masked = pre_input[nidx, cidx, r_start:r_end, \
c_start:c_end]
input[nidx, cidx, i, j] = x_masked.max()
arg = x_masked.argmax()
indices[nidx, cidx, i, j] = \
(r_start + arg / self.ksize[1]) * W + \
(r_start + arg / self.ksize[1]) * wsize + \
c_start + arg % self.ksize[1]
output = self.Unpool2d_forward_naive(input, indices, self.ksize, \
self.strides, self.paddings).astype("float32")
self.inputs = {'X': input.astype('float32'),
'Y': indices.astype('int32')}
'Indices': indices.astype('int32')}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册