提交 70351de1 编写于 作者: J JiabinYang

test=develop

上级 5428cb99
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reorg_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
class ReorgOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of reorgOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of reorgOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor");
auto stride = ctx->Attrs().Get<int64_t>("stride");
PADDLE_ENFORCE_GT(stride, 0, "The stride should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
PADDLE_ENFORCE_EQ(
x_dims[1] % (stride * stride), 0,
"input channel should be dvisible of the square of reorg stride");
PADDLE_ENFORCE_EQ(
x_dims[2] % (stride), 0,
"input Height should be dvisible of the square of reorg stride");
PADDLE_ENFORCE_EQ(
x_dims[3] % (stride), 0,
"input Width should be dvisible of the square of reorg stride");
VLOG(3) << "reorg operator x.shape=" << x_dims << "Attribute stride"
<< stride << std::endl;
std::vector<int64_t> output_shape(4, 0); // [B,C,H,W]
output_shape[0] = x_dims[0];
output_shape[1] = x_dims[1] * stride * stride;
output_shape[2] = x_dims[2] / stride;
output_shape[3] = x_dims[3] / stride;
auto out_dims = framework::make_ddim(output_shape);
ctx->SetOutputDim("Out", out_dims);
if (x_dims[0] == out_dims[0]) {
// Only pass LoD when the first dimension of output and Input(X)
// are the same.
ctx->ShareLoD("X", /*->*/ "Out");
}
}
};
class ReorgOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor). The input should be a 4D tensor B * C * W * H of reorg "
"operator.");
AddOutput("Out",
"(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of "
"reorg operator.");
AddAttr<int64_t>("stride",
"(int64_t, default 1) stride used to do reorgnization.")
.SetDefault(1)
.EqualGreaterThan(1);
AddComment(R"DOC(
reorg operator used in Yolo v2.
The equation is: C2 = C1/stride * stride, W2 = W1 ∗ stride + offset % stride, H2 = H1 ∗ stride + offset / stride,
Reshape Input(X) into the shape according to Attr(stride). The
data in Input(X) are unchanged.
Examples:
1. Given a 3-D tensor Input(X) with a shape [2048, 26, 26], and the stride is 2, the reorg operator will transform Input(X)
into a 3-D tensor with shape [2048, 13, 13] and leaving Input(X)'s data unchanged.
)DOC");
}
};
class ReorgGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) shouldn't be null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
"Input(Out@GRAD) shouldn't be null.");
ctx->SetOutputDim(framework::GradVarName("X"), ctx->GetInputDim("X"));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(reorg, ops::ReorgOp, ops::ReorgOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reorg_grad, ops::ReorgGradOp);
REGISTER_OP_CPU_KERNEL(
reorg, ops::ReorgKernel<paddle::platform::CPUDeviceContext, float>,
ops::ReorgKernel<paddle::platform::CPUDeviceContext, double>,
ops::ReorgKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
reorg_grad, ops::ReorgGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ReorgGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ReorgGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reorg_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
reorg, ops::ReorgKernel<paddle::platform::CUDADeviceContext, float>,
ops::ReorgKernel<paddle::platform::CUDADeviceContext, double>,
ops::ReorgKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
reorg_grad,
ops::ReorgGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ReorgGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ReorgGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_REORG_OP_H_
#define PADDLE_FLUID_OPERATORS_REORG_OP_H_
#endif // PADDLE_FLUID_OPERATORS_REORG_OP_H_
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
class reorg_cpu {
public:
HOSTDEVICE reorg_cpu(const T *x, int64_t w, int64_t h, int64_t c,
int64_t batch, int64_t stride, int64_t forward, T *out)
: x_(x),
w_(w),
h_(h),
c_(c),
batch_(batch),
stride_(stride),
forward_(forward),
out_(out) {}
HOSTDEVICE void operator()(int64_t in_index) {
int64_t out_c = c_ / (stride_ * stride_);
// calculate each dim position with index of tensor
int64_t b = in_index / (c_ * h_ * w_);
int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_);
int64_t j = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) / w_;
int64_t i = ((in_index % (c_ * h_ * w_)) % (h_ * w_)) % w_;
int64_t c2 = k % out_c;
int64_t offset = k / out_c;
int64_t w2 = i * stride_ + offset % stride_;
int64_t h2 = j * stride_ + offset / stride_;
int64_t out_index =
w2 + w_ * stride_ * (h2 + h_ * stride_ * (c2 + out_c * b));
if (forward_)
out_[out_index] = x_[in_index];
else
out_[in_index] = x_[out_index];
}
private:
const T *x_;
int64_t w_, h_, c_, batch_, stride_, forward_;
T *out_;
};
template <typename DeviceContext, typename T>
class ReorgKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<framework::LoDTensor>("Out");
auto *x = context.Input<framework::LoDTensor>("X");
auto stride = context.Attr<int64_t>("stride");
auto in_dims = x->dims();
out->mutable_data(context.GetPlace(), x->type());
auto out_dims = out->dims();
auto B = in_dims[0];
auto C = in_dims[1];
auto H = in_dims[2];
auto W = in_dims[3];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(x->numel()));
auto *x_data = x->data<T>();
auto *out_data = out->data<T>();
paddle::operators::reorg_cpu<T> reorg(x_data, W, H, C, B, stride, 1,
out_data);
for_range(reorg);
out->Resize(out_dims);
}
};
template <typename DeviceContext, typename T>
class ReorgGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *d_out =
context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto stride = context.Attr<int64_t>("stride");
auto in_dims = d_x->dims();
d_x->mutable_data(context.GetPlace(), d_out->type());
auto B = in_dims[0];
auto C = in_dims[1];
auto H = in_dims[2];
auto W = in_dims[3];
platform::ForRange<DeviceContext> for_range(
context.template device_context<DeviceContext>(),
static_cast<size_t>(d_x->numel()));
auto *dx_data = d_x->data<T>();
auto *dout_data = d_out->data<T>();
paddle::operators::reorg_cpu<T> reorg(dout_data, W, H, C, B, stride, 0,
dx_data);
for_range(reorg);
d_x->Resize(in_dims);
}
};
} // namespace operators
} // namespace paddle
......@@ -150,6 +150,7 @@ __all__ = [
'mul',
'sigmoid_cross_entropy_with_logits',
'maxout',
'reorg',
]
......@@ -7084,3 +7085,54 @@ def maxout(x, groups, name=None):
attrs={"groups": groups},
outputs={"Out": out})
return out
def reorg(x, stride, name=None):
"""
Gives a stride to reorg the input tensor
Here are some example:
input is 4D LoDtensor with shape [batch, channel, height, width] and has an attrs stride = 2
reorg will do some math work to reorder the elements of input according to stride to construt
put with shape [batch, channel * stride * stride, height/stride, width/stride]
reorg is used to reorgnization the output of pre_layer and change the tensor to fit the shape
Args:
x(variable): The input tensor.
stride(variable): The stride to reorg
Returns:
Variable: The output tensor.
Raises:
TypeError: stride type must be a long.
Examples:
.. code-block:: python
data = fluid.layers.data(
name='data', shape=[1, 4, 2, 2], dtype='float32')
reorged = fluid.layers.reorged(
x=data, stride=2)
"""
if not (isinstance(stride, long)):
raise ValueError("stride must be a python long")
helper = LayerHelper("reorg", **locals())
if name is None:
out = helper.create_tmp_variable(dtype=x.dtype)
else:
out = helper.create_variable(
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="reorg",
inputs={"X": x},
attrs={"stride": stride},
outputs={"Out": out})
return out
......@@ -108,6 +108,8 @@ class OpDescCreationMethod(object):
new_attr.i = user_defined_attr
elif attr.type == framework_pb2.FLOAT:
new_attr.f = user_defined_attr
elif attr.type == framework_pb2.LONG:
new_attr.l = user_defined_attr
elif attr.type == framework_pb2.STRING:
new_attr.s = user_defined_attr
elif attr.type == framework_pb2.BOOLEAN:
......
......@@ -240,6 +240,17 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(layers.softmax(hid))
print(str(program))
def test_reorg(self):
program = Program()
with program_guard(program):
data = layers.data(
name="data",
shape=[32, 9, 6, 6],
append_batch_size=False,
dtype='float32')
self.assertIsNotNone(layers.reorg(data, long(3)))
print(str(program))
def test_sequence_unsqueeze(self):
program = Program()
with program_guard(program):
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid as fluid
from op_test import OpTest
class TestReorgOp(OpTest):
@staticmethod
def helper(in_, width, height, channel, batch, stride, forward, out_):
channel_out = channel / (stride * stride)
for b in range(batch):
for k in range(channel):
for j in range(height):
for i in range(width):
in_index = i + width * (j + height * (k + channel * b))
channel2 = k % channel_out
offset = k / channel_out
width2 = i * stride + offset % stride
height2 = j * stride + offset / stride
out_index = width2 + width * stride * (
height2 + height * stride *
(channel2 + channel_out * b))
if forward:
out_[out_index] = in_[in_index]
else:
out_[in_index] = in_[out_index]
def setUp(self):
self.init_data()
self.op_type = "reorg"
self.inputs = {"X": self.x}
self.helper(self.x_1d, self.x.shape[3], self.x.shape[2],
self.x.shape[1], self.x.shape[0], self.stride, self.forward,
self.out_1d)
self.out = np.reshape(self.out_1d, self.infered_shape)
self.attrs = {"stride": long(self.stride)}
self.outputs = {"Out": self.out}
def init_data(self):
self.ori_shape = (32, 12, 6, 6)
self.infered_shape = (32, 48, 3, 3)
self.one_d_len = 32 * 48 * 3 * 3
self.stride = 2
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
def test_check_output(self):
place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.core.CPUPlace()
self.check_output_with_place(place, 1e-5, None, False)
def test_check_grad(self):
place = fluid.core.CUDAPlace(0) if fluid.core.is_compiled_with_cuda(
) else fluid.core.CPUPlace()
self.check_grad_with_place(place, ['X'], 'Out')
class TestReorgOp2(TestReorgOp):
def init_data(self):
self.ori_shape = (32, 9, 6, 6)
self.infered_shape = (32, 81, 2, 2)
self.one_d_len = 32 * 81 * 2 * 2
self.stride = 3
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册