未验证 提交 b03a44e0 编写于 作者: X Xin Pan 提交者: GitHub

Merge pull request #14026 from JiabinYang/add_reorg_op

Add reorg op
......@@ -174,6 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'blocksize', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_grid ArgSpec(args=['theta', 'out_shape', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.sequence_reverse ArgSpec(args=['x', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/space_to_depth_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
class SpaceToDepthOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of SpaceToDepthOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of SpaceToDepthOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor");
auto blocksize = ctx->Attrs().Get<int64_t>("blocksize");
PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be Greater than 1");
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
"input channel should be divisible of the square of "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
"input Height should be divisible of the square of "
"SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
"input Width should be divisible of the square of "
"SpaceToDepthOp blocksize");
VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims
<< "Attribute blocksize" << blocksize << std::endl;
std::vector<int64_t> output_shape(4, 0); // [B,C,H,W]
output_shape[0] = x_dims[0];
output_shape[1] = x_dims[1] * blocksize * blocksize;
output_shape[2] = x_dims[2] / blocksize;
output_shape[3] = x_dims[3] / blocksize;
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
};
class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor). The input should be a 4D tensor B * C * W * H of "
"SpaceToDepthOp "
"operator.");
AddOutput("Out",
"(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of "
"SpaceToDepthOp operator.");
AddAttr<int64_t>(
"blocksize",
"(int64_t, default 2) blocksize used to do change Space To Depth.")
.SetDefault(2)
.GreaterThan(1);
AddComment(R"DOC(
reorg operator used in Yolo v2.
The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize,
Reshape Input(X) into the shape according to Attr(blocksize). The
data in Input(X) are unchanged.
Examples:
1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the blocksize is 2, the reorg operator will transform Input(X)
into a 4-D tensor with shape [128, 2048, 13, 13] and leaving Input(X)'s data unchanged.
)DOC");
}
};
class SpaceToDepthGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp);
REGISTER_OP_CPU_KERNEL(
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, double>,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
space_to_depth_grad,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/space_to_depth_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, double>,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
space_to_depth_grad,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#endif // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
class space_to_depth_compute {
public:
HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c,
int64_t batch, int64_t blocksize,
int64_t forward, T *out)
: x_(x),
w_(w),
h_(h),
c_(c),
batch_(batch),
blocksize_(blocksize),
forward_(forward),
out_(out) {}
HOSTDEVICE void operator()(int64_t in_index) {
int64_t out_c = c_ / (blocksize_ * blocksize_);
// calculate each dim position with index of tensor
int64_t b = in_index / (c_ * h_ * w_);
int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_);
int64_t j = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) / w_;
int64_t i = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) % w_;
int64_t c2 = k % out_c;
int64_t offset = k / out_c;
int64_t w2 = i * blocksize_ + offset % blocksize_;
int64_t h2 = j * blocksize_ + offset / blocksize_;
int64_t out_index =
w2 + w_ * blocksize_ * (h2 + h_ * blocksize_ * (c2 + out_c * b));
if (forward_)
out_[out_index] = x_[in_index];
else
out_[in_index] = x_[out_index];
}
private:
const T *x_;
int64_t w_, h_, c_, batch_, blocksize_, forward_;
T *out_;
};
template <typename DeviceContext, typename T>
class SpaceToDepthKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<framework::LoDTensor>("Out");
auto *x = context.Input<framework::LoDTensor>("X");
auto blocksize = context.Attr<int64_t>("blocksize");
auto in_dims = x->dims();
out->mutable_data(context.GetPlace(), x->type());
auto out_dims = out->dims();
auto B = in_dims[0];
auto C = in_dims[1];
auto H = in_dims[2];
auto W = in_dims[3];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(x->numel()));
auto *x_data = x->data<T>();
auto *out_data = out->data<T>();
paddle::operators::space_to_depth_compute<T> computer(
x_data, W, H, C, B, blocksize, 1, out_data);
for_range(computer);
out->Resize(out_dims);
}
};
template <typename DeviceContext, typename T>
class SpaceToDepthGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto blocksize = context.Attr<int64_t>("blocksize");
auto in_dims = d_x->dims();
d_x->mutable_data(context.GetPlace(), d_out->type());
auto B = in_dims[0];
auto C = in_dims[1];
auto H = in_dims[2];
auto W = in_dims[3];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(d_x->numel()));
auto *dx_data = d_x->data<T>();
auto *dout_data = d_out->data<T>();
paddle::operators::space_to_depth_compute<T> computer(
dout_data, W, H, C, B, blocksize, 0, dx_data);
for_range(computer);
d_x->Resize(in_dims);
}
};
} // namespace operators
} // namespace paddle
......@@ -154,6 +154,7 @@ __all__ = [
'mul',
'sigmoid_cross_entropy_with_logits',
'maxout',
'space_to_depth',
'affine_grid',
'sequence_reverse',
'affine_channel',
......@@ -7674,6 +7675,66 @@ def maxout(x, groups, name=None):
return out
def space_to_depth(x, blocksize, name=None):
"""
Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width]
This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the
input LoDtensor where values from the height and width dimensions are moved to the channel dimension.
The attr blocksize indicates the input block size.
space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according
to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]:
space_to_depth is used to This operation is useful for resizing the activations between convolutions
(but keeping all data)
- Non-overlapping blocks of size block_size x block size are rearranged into depth at each location.
- The depth of the output tensor is block_size * block_size * input channel
- The Y, X coordinates within each block of the input become the high order component of the output channel index
- channel should be divisible by square of blocksize
- height, width should be divsible by blocksize
Args:
x(variable): The input LoDtensor.
blocksize(variable): The blocksize to select the element on each feature map should be > 2
Returns:
Variable: The output LoDtensor.
Raises:
TypeError: blocksize type must be a long.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[1, 4, 2, 2], dtype='float32')
space_to_depthed = fluid.layers.space_to_depth(
x=data, blocksize=2)
"""
helper = LayerHelper("space_to_depth", **locals())
if not (isinstance(blocksize, int)):
raise ValueError("blocksize must be a python Int")
if name is None:
out = helper.create_variable_for_type_inference(
dtype=x.dtype) #fix create
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="space_to_depth",
inputs={"X": x},
attrs={"blocksize": blocksize},
outputs={"Out": out})
return out
@templatedoc()
def sequence_reverse(x, name=None):
"""
......
......@@ -108,6 +108,8 @@ class OpDescCreationMethod(object):
new_attr.i = user_defined_attr
elif attr.type == framework_pb2.FLOAT:
new_attr.f = user_defined_attr
elif attr.type == framework_pb2.LONG:
new_attr.l = user_defined_attr
elif attr.type == framework_pb2.STRING:
new_attr.s = user_defined_attr
elif attr.type == framework_pb2.BOOLEAN:
......
......@@ -248,6 +248,17 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(layers.softmax(hid))
print(str(program))
def test_space_to_depth(self):
program = Program()
with program_guard(program):
data = layers.data(
name='data',
shape=[32, 9, 6, 6],
append_batch_size=False,
dtype='float32')
self.assertIsNotNone(layers.space_to_depth(data, 3))
print(str(program))
def test_sequence_unsqueeze(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest
class TestSpaceToDepthOp(OpTest):
@staticmethod
def helper(in_, width, height, channel, batch, blocksize, forward, out_):
channel_out = channel // (blocksize * blocksize)
for b in range(batch):
for k in range(channel):
for j in range(height):
for i in range(width):
in_index = i + width * (j + height * (k + channel * b))
channel2 = k % channel_out
offset = k // channel_out
width2 = i * blocksize + offset % blocksize
height2 = j * blocksize + offset // blocksize
out_index = width2 + width * blocksize * (
height2 + height * blocksize *
(channel2 + channel_out * b))
if forward:
out_[out_index] = in_[in_index]
else:
out_[in_index] = in_[out_index]
def setUp(self):
self.init_data()
self.op_type = "space_to_depth"
self.inputs = {"X": self.x}
self.helper(self.x_1d, self.x.shape[3], self.x.shape[2],
self.x.shape[1], self.x.shape[0], self.blocksize,
self.forward, self.out_1d)
self.out = np.reshape(self.out_1d, self.infered_shape)
self.attrs = {"blocksize": self.blocksize}
self.outputs = {"Out": self.out}
def init_data(self):
self.ori_shape = (32, 12, 6, 6)
self.infered_shape = (32, 48, 3, 3)
self.one_d_len = 32 * 48 * 3 * 3
self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
def test_check_output(self):
place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.core.CPUPlace()
self.check_output_with_place(place, 1e-5, None, False)
def test_check_grad(self):
place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.core.CPUPlace()
self.check_grad_with_place(place, ['X'], 'Out')
class TestSpaceToDepthOpBasic(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 8, 6, 6)
self.infered_shape = (32, 32, 3, 3)
self.one_d_len = 32 * 32 * 3 * 3
self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 8, 6, 6)
self.infered_shape = (32, 32, 3, 3)
self.one_d_len = 32 * 32 * 3 * 3
self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float64')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float64')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 9, 6, 6)
self.infered_shape = (32, 81, 2, 2)
self.one_d_len = 32 * 81 * 2 * 2
self.blocksize = 3
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 9, 9, 6)
self.infered_shape = (32, 81, 3, 2)
self.one_d_len = 32 * 81 * 3 * 2
self.blocksize = 3
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册