diff --git a/paddle/fluid/API.spec b/paddle/fluid/API.spec index 5c4aa6158eed9a5e967e65a6590342ce884bb10f..3ac9fe31b4f3a4e23b181a6005ec00077f562e38 100644 --- a/paddle/fluid/API.spec +++ b/paddle/fluid/API.spec @@ -174,7 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None)) paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,)) -paddle.fluid.layers.reorg ArgSpec(args=['x', 'stride', 'name'], varargs=None, keywords=None, defaults=(None,)) +paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'stride', 'name'], varargs=None, keywords=None, defaults=(None,)) paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None)) paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)) paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)) diff --git a/paddle/fluid/operators/reorg_op.cc b/paddle/fluid/operators/space_to_depth_op.cc similarity index 62% rename from paddle/fluid/operators/reorg_op.cc rename to paddle/fluid/operators/space_to_depth_op.cc index 757761ab51fbef6471faed85c548355bfe2e83ad..a9a266a3f77be26541e76e55c47eeff40949f00a 100644 --- a/paddle/fluid/operators/reorg_op.cc +++ b/paddle/fluid/operators/space_to_depth_op.cc @@ -12,44 +12,44 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/reorg_op.h" +#include "paddle/fluid/operators/space_to_depth_op.h" #include #include namespace paddle { namespace operators { -class ReorgOp : public framework::OperatorWithKernel { +class SpaceToDepthOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override { PADDLE_ENFORCE(ctx->HasInput("X"), - "Input(X) of reorgOp should not be null."); + "Input(X) of SpaceToDepthOp should not be null."); PADDLE_ENFORCE(ctx->HasOutput("Out"), - "Output(Out) of reorgOp should not be null."); + "Output(Out) of SpaceToDepthOp should not be null."); auto x_dims = ctx->GetInputDim("X"); PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor"); auto stride = ctx->Attrs().Get("stride"); - PADDLE_ENFORCE_GT(stride, 0, "The stride should be Greater than 0"); + PADDLE_ENFORCE_GT(stride, 1, "The stride should be Greater than 1"); PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0"); - PADDLE_ENFORCE_EQ( - x_dims[1] % (stride * stride), 0, - "input channel should be dvisible of the square of reorg stride"); - PADDLE_ENFORCE_EQ( - x_dims[2] % (stride), 0, - "input Height should be dvisible of the square of reorg stride"); - PADDLE_ENFORCE_EQ( - x_dims[3] % (stride), 0, - "input Width should be dvisible of the square of reorg stride"); + PADDLE_ENFORCE_EQ(x_dims[1] % (stride * stride), 0, + "input channel should be divisible of the square of " + "SpaceToDepthOp stride"); + PADDLE_ENFORCE_EQ(x_dims[2] % (stride), 0, + "input Height should be divisible of the square of " + "SpaceToDepthOp stride"); + PADDLE_ENFORCE_EQ(x_dims[3] % (stride), 0, + "input Width should be divisible of the square of " + "SpaceToDepthOp stride"); - VLOG(3) << "reorg operator x.shape=" << x_dims << "Attribute stride" - << stride << std::endl; + VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims + << "Attribute stride" << stride << std::endl; std::vector output_shape(4, 0); // [B,C,H,W] output_shape[0] = x_dims[0]; @@ -69,19 +69,21 @@ class ReorgOp : public framework::OperatorWithKernel { } }; -class ReorgOpMaker : public framework::OpProtoAndCheckerMaker { +class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker { public: void Make() override { AddInput("X", - "(Tensor). The input should be a 4D tensor B * C * W * H of reorg " + "(Tensor). The input should be a 4D tensor B * C * W * H of " + "SpaceToDepthOp " "operator."); AddOutput("Out", "(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of " - "reorg operator."); - AddAttr("stride", - "(int64_t, default 1) stride used to do reorgnization.") - .SetDefault(1) - .EqualGreaterThan(1); + "SpaceToDepthOp operator."); + AddAttr( + "stride", + "(int64_t, default 2) stride used to do change Space To Depth.") + .SetDefault(2) + .GreaterThan(1); AddComment(R"DOC( reorg operator used in Yolo v2. The equation is: C2 = C1/stride * stride, W2 = W1 ∗ stride + offset % stride, H2 = H1 ∗ stride + offset / stride, @@ -98,7 +100,7 @@ class ReorgOpMaker : public framework::OpProtoAndCheckerMaker { } }; -class ReorgGradOp : public framework::OperatorWithKernel { +class SpaceToDepthGradOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -114,14 +116,16 @@ class ReorgGradOp : public framework::OperatorWithKernel { namespace ops = paddle::operators; -REGISTER_OPERATOR(reorg, ops::ReorgOp, ops::ReorgOpMaker, +REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(reorg_grad, ops::ReorgGradOp); +REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp); REGISTER_OP_CPU_KERNEL( - reorg, ops::ReorgKernel, - ops::ReorgKernel, - ops::ReorgKernel); + space_to_depth, + ops::SpaceToDepthKernel, + ops::SpaceToDepthKernel, + ops::SpaceToDepthKernel); REGISTER_OP_CPU_KERNEL( - reorg_grad, ops::ReorgGradKernel, - ops::ReorgGradKernel, - ops::ReorgGradKernel); + space_to_depth_grad, + ops::SpaceToDepthGradKernel, + ops::SpaceToDepthGradKernel, + ops::SpaceToDepthGradKernel); diff --git a/paddle/fluid/operators/reorg_op.cu b/paddle/fluid/operators/space_to_depth_op.cu similarity index 57% rename from paddle/fluid/operators/reorg_op.cu rename to paddle/fluid/operators/space_to_depth_op.cu index de1c7d7468e105afc4f350e036a3e0eabc37a72c..38d0a662733222386b8ecd68d064f3d1abe56c3b 100644 --- a/paddle/fluid/operators/reorg_op.cu +++ b/paddle/fluid/operators/space_to_depth_op.cu @@ -12,18 +12,19 @@ // See the License for the specific language governing permissions and // limitations under the License. -#include "paddle/fluid/operators/reorg_op.h" +#include "paddle/fluid/operators/space_to_depth_op.h" namespace plat = paddle::platform; namespace ops = paddle::operators; REGISTER_OP_CUDA_KERNEL( - reorg, ops::ReorgKernel, - ops::ReorgKernel, - ops::ReorgKernel); + space_to_depth, + ops::SpaceToDepthKernel, + ops::SpaceToDepthKernel, + ops::SpaceToDepthKernel); REGISTER_OP_CUDA_KERNEL( - reorg_grad, - ops::ReorgGradKernel, - ops::ReorgGradKernel, - ops::ReorgGradKernel); + space_to_depth_grad, + ops::SpaceToDepthGradKernel, + ops::SpaceToDepthGradKernel, + ops::SpaceToDepthGradKernel); diff --git a/paddle/fluid/operators/reorg_op.h b/paddle/fluid/operators/space_to_depth_op.h similarity index 79% rename from paddle/fluid/operators/reorg_op.h rename to paddle/fluid/operators/space_to_depth_op.h index 108437b4d8f895f5951b4d560307548672f7e9d1..a236c1d5b7a8702654c1f823b25846a8999c1fe5 100644 --- a/paddle/fluid/operators/reorg_op.h +++ b/paddle/fluid/operators/space_to_depth_op.h @@ -11,9 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#ifndef PADDLE_FLUID_OPERATORS_REORG_OP_H_ -#define PADDLE_FLUID_OPERATORS_REORG_OP_H_ -#endif // PADDLE_FLUID_OPERATORS_REORG_OP_H_ +#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_ +#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_ +#endif // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/platform/for_range.h" @@ -22,10 +22,11 @@ namespace paddle { namespace operators { template -class reorg_cpu { +class space_to_depth_compute { public: - HOSTDEVICE reorg_cpu(const T *x, int64_t w, int64_t h, int64_t c, - int64_t batch, int64_t stride, int64_t forward, T *out) + HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c, + int64_t batch, int64_t stride, + int64_t forward, T *out) : x_(x), w_(w), h_(h), @@ -62,7 +63,7 @@ class reorg_cpu { }; template -class ReorgKernel : public framework::OpKernel { +class SpaceToDepthKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *out = context.Output("Out"); @@ -82,16 +83,16 @@ class ReorgKernel : public framework::OpKernel { auto *x_data = x->data(); auto *out_data = out->data(); - paddle::operators::reorg_cpu reorg(x_data, W, H, C, B, stride, 1, - out_data); - for_range(reorg); + paddle::operators::space_to_depth_compute computer(x_data, W, H, C, B, + stride, 1, out_data); + for_range(computer); out->Resize(out_dims); } }; template -class ReorgGradKernel : public framework::OpKernel { +class SpaceToDepthGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &context) const override { auto *d_out = @@ -114,9 +115,9 @@ class ReorgGradKernel : public framework::OpKernel { auto *dx_data = d_x->data(); auto *dout_data = d_out->data(); - paddle::operators::reorg_cpu reorg(dout_data, W, H, C, B, stride, 0, - dx_data); - for_range(reorg); + paddle::operators::space_to_depth_compute computer(dout_data, W, H, C, B, + stride, 0, dx_data); + for_range(computer); d_x->Resize(in_dims); } diff --git a/python/paddle/fluid/layers/nn.py b/python/paddle/fluid/layers/nn.py index e7f343508a736635a48367060b78e4ff25f34db2..6688c0e99fb12f6a1867d6345a72233d68e8d0d4 100644 --- a/python/paddle/fluid/layers/nn.py +++ b/python/paddle/fluid/layers/nn.py @@ -154,7 +154,7 @@ __all__ = [ 'mul', 'sigmoid_cross_entropy_with_logits', 'maxout', - 'reorg', + 'space_to_depth', 'affine_channel', ] @@ -7456,25 +7456,26 @@ def maxout(x, groups, name=None): return out -def reorg(x, stride, name=None): +def space_to_depth(x, stride, name=None): """ - Gives a stride to reorg the input tensor - - Here are some example: - - input is 4D LoDtensor with shape [batch, channel, height, width] and has an attrs stride = 2 - - reorg will do some math work to reorder the elements of input according to stride to construt - put with shape [batch, channel * stride * stride, height/stride, width/stride] - - reorg is used to reorgnization the output of pre_layer and change the tensor to fit the shape + Gives a stride to space_to_depth the input LoDtensor + + Rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the + input LoDtensor where values from the height and width dimensions are moved to the channel dimension. + The attr stride indicates the input block size. + + space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according + to stride to construct output with shape [batch, channel * stride * stride, height/stride, width/stride]: + + space_to_depth is used to This operation is useful for resizing the activations between convolutions + (but keeping all data) Args: - x(variable): The input tensor. - stride(variable): The stride to reorg + x(variable): The input LoDtensor. + stride(variable): The stride to space_to_depth Returns: - Variable: The output tensor. + Variable: The output LoDtensor. Raises: TypeError: stride type must be a long. @@ -7484,11 +7485,11 @@ def reorg(x, stride, name=None): data = fluid.layers.data( name='data', shape=[1, 4, 2, 2], dtype='float32') - reorged = fluid.layers.reorged( + space_to_depthed = fluid.layers.space_to_depth( x=data, stride=2) """ - helper = LayerHelper("reorg", **locals()) + helper = LayerHelper("space_to_depth", **locals()) if not (isinstance(stride, int)): raise ValueError("stride must be a python Int") @@ -7501,7 +7502,7 @@ def reorg(x, stride, name=None): name=name, dtype=x.dtype, persistable=False) helper.append_op( - type="reorg", + type="space_to_depth", inputs={"X": x}, attrs={"stride": stride}, outputs={"Out": out}) diff --git a/python/paddle/fluid/tests/unittests/test_layers.py b/python/paddle/fluid/tests/unittests/test_layers.py index 92c60da71546d10155eede83ed19afab9c678beb..9dd733a54d7868e6d22308166b04066670c9fade 100644 --- a/python/paddle/fluid/tests/unittests/test_layers.py +++ b/python/paddle/fluid/tests/unittests/test_layers.py @@ -248,7 +248,7 @@ class TestBook(unittest.TestCase): self.assertIsNotNone(layers.softmax(hid)) print(str(program)) - def test_reorg(self): + def test_space_to_depth(self): program = Program() with program_guard(program): data = layers.data( @@ -256,7 +256,7 @@ class TestBook(unittest.TestCase): shape=[32, 9, 6, 6], append_batch_size=False, dtype='float32') - self.assertIsNotNone(layers.reorg(data, 3)) + self.assertIsNotNone(layers.space_to_depth(data, 3)) print(str(program)) def test_sequence_unsqueeze(self): diff --git a/python/paddle/fluid/tests/unittests/test_reorg_op.py b/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py similarity index 67% rename from python/paddle/fluid/tests/unittests/test_reorg_op.py rename to python/paddle/fluid/tests/unittests/test_space_to_depth_op.py index a3afabe7afec50e94b3a23b38969aa27fabcf14c..36c8cd11199cde734cadb00643a95312e827140f 100644 --- a/python/paddle/fluid/tests/unittests/test_reorg_op.py +++ b/python/paddle/fluid/tests/unittests/test_space_to_depth_op.py @@ -19,7 +19,7 @@ import paddle.fluid as fluid from op_test import OpTest -class TestReorgOp(OpTest): +class TestSpaceToDepthOp(OpTest): @staticmethod def helper(in_, width, height, channel, batch, stride, forward, out_): channel_out = channel // (stride * stride) @@ -43,7 +43,7 @@ class TestReorgOp(OpTest): def setUp(self): self.init_data() - self.op_type = "reorg" + self.op_type = "space_to_depth" self.inputs = {"X": self.x} self.helper(self.x_1d, self.x.shape[3], self.x.shape[2], self.x.shape[1], self.x.shape[0], self.stride, self.forward, @@ -75,7 +75,35 @@ class TestReorgOp(OpTest): self.check_grad_with_place(place, ['X'], 'Out') -class TestReorgOp2(TestReorgOp): +class TestSpaceToDepthOpBasic(TestSpaceToDepthOp): + def init_data(self): + self.ori_shape = (32, 8, 6, 6) + self.infered_shape = (32, 32, 3, 3) + self.one_d_len = 32 * 32 * 3 * 3 + + self.stride = 2 + self.x = np.random.random(self.ori_shape).astype('float32') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float32') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + +class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp): + def init_data(self): + self.ori_shape = (32, 8, 6, 6) + self.infered_shape = (32, 32, 3, 3) + self.one_d_len = 32 * 32 * 3 * 3 + + self.stride = 2 + self.x = np.random.random(self.ori_shape).astype('float64') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float64') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + +class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp): def init_data(self): self.ori_shape = (32, 9, 6, 6) self.infered_shape = (32, 81, 2, 2) @@ -89,5 +117,19 @@ class TestReorgOp2(TestReorgOp): self.forward = 1 +class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp): + def init_data(self): + self.ori_shape = (32, 9, 9, 6) + self.infered_shape = (32, 81, 3, 2) + self.one_d_len = 32 * 81 * 3 * 2 + + self.stride = 3 + self.x = np.random.random(self.ori_shape).astype('float32') + self.x_1d = np.reshape(self.x, self.one_d_len) + self.out = np.zeros(self.infered_shape).astype('float32') + self.out_1d = np.reshape(self.out, self.one_d_len) + self.forward = 1 + + if __name__ == '__main__': unittest.main()