提交 9cad409f 编写于 作者: J JiabinYang

test=develop

上级 bd064c0f
......@@ -174,7 +174,7 @@ paddle.fluid.layers.mean ArgSpec(args=['x', 'name'], varargs=None, keywords=None
paddle.fluid.layers.mul ArgSpec(args=['x', 'y', 'x_num_col_dims', 'y_num_col_dims', 'name'], varargs=None, keywords=None, defaults=(1, 1, None))
paddle.fluid.layers.sigmoid_cross_entropy_with_logits ArgSpec(args=['x', 'label', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.maxout ArgSpec(args=['x', 'groups', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.reorg ArgSpec(args=['x', 'stride', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.space_to_depth ArgSpec(args=['x', 'stride', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.affine_channel ArgSpec(args=['x', 'scale', 'bias', 'data_layout', 'name'], varargs=None, keywords=None, defaults=(None, None, 'NCHW', None))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
......
......@@ -12,44 +12,44 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/reorg_op.h"
#include "paddle/fluid/operators/space_to_depth_op.h"
#include <string>
#include <vector>
namespace paddle {
namespace operators {
class ReorgOp : public framework::OperatorWithKernel {
class SpaceToDepthOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"),
"Input(X) of reorgOp should not be null.");
"Input(X) of SpaceToDepthOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of reorgOp should not be null.");
"Output(Out) of SpaceToDepthOp should not be null.");
auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor");
auto stride = ctx->Attrs().Get<int64_t>("stride");
PADDLE_ENFORCE_GT(stride, 0, "The stride should be Greater than 0");
PADDLE_ENFORCE_GT(stride, 1, "The stride should be Greater than 1");
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
PADDLE_ENFORCE_EQ(
x_dims[1] % (stride * stride), 0,
"input channel should be dvisible of the square of reorg stride");
PADDLE_ENFORCE_EQ(
x_dims[2] % (stride), 0,
"input Height should be dvisible of the square of reorg stride");
PADDLE_ENFORCE_EQ(
x_dims[3] % (stride), 0,
"input Width should be dvisible of the square of reorg stride");
PADDLE_ENFORCE_EQ(x_dims[1] % (stride * stride), 0,
"input channel should be divisible of the square of "
"SpaceToDepthOp stride");
PADDLE_ENFORCE_EQ(x_dims[2] % (stride), 0,
"input Height should be divisible of the square of "
"SpaceToDepthOp stride");
PADDLE_ENFORCE_EQ(x_dims[3] % (stride), 0,
"input Width should be divisible of the square of "
"SpaceToDepthOp stride");
VLOG(3) << "reorg operator x.shape=" << x_dims << "Attribute stride"
<< stride << std::endl;
VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims
<< "Attribute stride" << stride << std::endl;
std::vector<int64_t> output_shape(4, 0); // [B,C,H,W]
output_shape[0] = x_dims[0];
......@@ -69,19 +69,21 @@ class ReorgOp : public framework::OperatorWithKernel {
}
};
class ReorgOpMaker : public framework::OpProtoAndCheckerMaker {
class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X",
"(Tensor). The input should be a 4D tensor B * C * W * H of reorg "
"(Tensor). The input should be a 4D tensor B * C * W * H of "
"SpaceToDepthOp "
"operator.");
AddOutput("Out",
"(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of "
"reorg operator.");
AddAttr<int64_t>("stride",
"(int64_t, default 1) stride used to do reorgnization.")
.SetDefault(1)
.EqualGreaterThan(1);
"SpaceToDepthOp operator.");
AddAttr<int64_t>(
"stride",
"(int64_t, default 2) stride used to do change Space To Depth.")
.SetDefault(2)
.GreaterThan(1);
AddComment(R"DOC(
reorg operator used in Yolo v2.
The equation is: C2 = C1/stride * stride, W2 = W1 ∗ stride + offset % stride, H2 = H1 ∗ stride + offset / stride,
......@@ -98,7 +100,7 @@ class ReorgOpMaker : public framework::OpProtoAndCheckerMaker {
}
};
class ReorgGradOp : public framework::OperatorWithKernel {
class SpaceToDepthGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
......@@ -114,14 +116,16 @@ class ReorgGradOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
REGISTER_OPERATOR(reorg, ops::ReorgOp, ops::ReorgOpMaker,
REGISTER_OPERATOR(space_to_depth, ops::SpaceToDepthOp, ops::SpaceToDepthOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(reorg_grad, ops::ReorgGradOp);
REGISTER_OPERATOR(space_to_depth_grad, ops::SpaceToDepthGradOp);
REGISTER_OP_CPU_KERNEL(
reorg, ops::ReorgKernel<paddle::platform::CPUDeviceContext, float>,
ops::ReorgKernel<paddle::platform::CPUDeviceContext, double>,
ops::ReorgKernel<paddle::platform::CPUDeviceContext, int64_t>);
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, double>,
ops::SpaceToDepthKernel<paddle::platform::CPUDeviceContext, int64_t>);
REGISTER_OP_CPU_KERNEL(
reorg_grad, ops::ReorgGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::ReorgGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::ReorgGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
space_to_depth_grad,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, double>,
ops::SpaceToDepthGradKernel<paddle::platform::CPUDeviceContext, int64_t>);
......@@ -12,18 +12,19 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/reorg_op.h"
#include "paddle/fluid/operators/space_to_depth_op.h"
namespace plat = paddle::platform;
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(
reorg, ops::ReorgKernel<paddle::platform::CUDADeviceContext, float>,
ops::ReorgKernel<paddle::platform::CUDADeviceContext, double>,
ops::ReorgKernel<paddle::platform::CUDADeviceContext, int64_t>);
space_to_depth,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, double>,
ops::SpaceToDepthKernel<paddle::platform::CUDADeviceContext, int64_t>);
REGISTER_OP_CUDA_KERNEL(
reorg_grad,
ops::ReorgGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::ReorgGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::ReorgGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
space_to_depth_grad,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, float>,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, double>,
ops::SpaceToDepthGradKernel<paddle::platform::CUDADeviceContext, int64_t>);
......@@ -11,9 +11,9 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef PADDLE_FLUID_OPERATORS_REORG_OP_H_
#define PADDLE_FLUID_OPERATORS_REORG_OP_H_
#endif // PADDLE_FLUID_OPERATORS_REORG_OP_H_
#ifndef PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#define PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#endif // PADDLE_FLUID_OPERATORS_SPACE_TO_DEPTH_OP_H_
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
......@@ -22,10 +22,11 @@ namespace paddle {
namespace operators {
template <typename T>
class reorg_cpu {
class space_to_depth_compute {
public:
HOSTDEVICE reorg_cpu(const T *x, int64_t w, int64_t h, int64_t c,
int64_t batch, int64_t stride, int64_t forward, T *out)
HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c,
int64_t batch, int64_t stride,
int64_t forward, T *out)
: x_(x),
w_(w),
h_(h),
......@@ -62,7 +63,7 @@ class reorg_cpu {
};
template <typename DeviceContext, typename T>
class ReorgKernel : public framework::OpKernel<T> {
class SpaceToDepthKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<framework::LoDTensor>("Out");
......@@ -82,16 +83,16 @@ class ReorgKernel : public framework::OpKernel<T> {
auto *x_data = x->data<T>();
auto *out_data = out->data<T>();
paddle::operators::reorg_cpu<T> reorg(x_data, W, H, C, B, stride, 1,
out_data);
for_range(reorg);
paddle::operators::space_to_depth_compute<T> computer(x_data, W, H, C, B,
stride, 1, out_data);
for_range(computer);
out->Resize(out_dims);
}
};
template <typename DeviceContext, typename T>
class ReorgGradKernel : public framework::OpKernel<T> {
class SpaceToDepthGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &context) const override {
auto *d_out =
......@@ -114,9 +115,9 @@ class ReorgGradKernel : public framework::OpKernel<T> {
auto *dx_data = d_x->data<T>();
auto *dout_data = d_out->data<T>();
paddle::operators::reorg_cpu<T> reorg(dout_data, W, H, C, B, stride, 0,
dx_data);
for_range(reorg);
paddle::operators::space_to_depth_compute<T> computer(dout_data, W, H, C, B,
stride, 0, dx_data);
for_range(computer);
d_x->Resize(in_dims);
}
......
......@@ -154,7 +154,7 @@ __all__ = [
'mul',
'sigmoid_cross_entropy_with_logits',
'maxout',
'reorg',
'space_to_depth',
'affine_channel',
]
......@@ -7456,25 +7456,26 @@ def maxout(x, groups, name=None):
return out
def reorg(x, stride, name=None):
def space_to_depth(x, stride, name=None):
"""
Gives a stride to reorg the input tensor
Here are some example:
input is 4D LoDtensor with shape [batch, channel, height, width] and has an attrs stride = 2
reorg will do some math work to reorder the elements of input according to stride to construt
put with shape [batch, channel * stride * stride, height/stride, width/stride]
reorg is used to reorgnization the output of pre_layer and change the tensor to fit the shape
Gives a stride to space_to_depth the input LoDtensor
Rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the
input LoDtensor where values from the height and width dimensions are moved to the channel dimension.
The attr stride indicates the input block size.
space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according
to stride to construct output with shape [batch, channel * stride * stride, height/stride, width/stride]:
space_to_depth is used to This operation is useful for resizing the activations between convolutions
(but keeping all data)
Args:
x(variable): The input tensor.
stride(variable): The stride to reorg
x(variable): The input LoDtensor.
stride(variable): The stride to space_to_depth
Returns:
Variable: The output tensor.
Variable: The output LoDtensor.
Raises:
TypeError: stride type must be a long.
......@@ -7484,11 +7485,11 @@ def reorg(x, stride, name=None):
data = fluid.layers.data(
name='data', shape=[1, 4, 2, 2], dtype='float32')
reorged = fluid.layers.reorged(
space_to_depthed = fluid.layers.space_to_depth(
x=data, stride=2)
"""
helper = LayerHelper("reorg", **locals())
helper = LayerHelper("space_to_depth", **locals())
if not (isinstance(stride, int)):
raise ValueError("stride must be a python Int")
......@@ -7501,7 +7502,7 @@ def reorg(x, stride, name=None):
name=name, dtype=x.dtype, persistable=False)
helper.append_op(
type="reorg",
type="space_to_depth",
inputs={"X": x},
attrs={"stride": stride},
outputs={"Out": out})
......
......@@ -248,7 +248,7 @@ class TestBook(unittest.TestCase):
self.assertIsNotNone(layers.softmax(hid))
print(str(program))
def test_reorg(self):
def test_space_to_depth(self):
program = Program()
with program_guard(program):
data = layers.data(
......@@ -256,7 +256,7 @@ class TestBook(unittest.TestCase):
shape=[32, 9, 6, 6],
append_batch_size=False,
dtype='float32')
self.assertIsNotNone(layers.reorg(data, 3))
self.assertIsNotNone(layers.space_to_depth(data, 3))
print(str(program))
def test_sequence_unsqueeze(self):
......
......@@ -19,7 +19,7 @@ import paddle.fluid as fluid
from op_test import OpTest
class TestReorgOp(OpTest):
class TestSpaceToDepthOp(OpTest):
@staticmethod
def helper(in_, width, height, channel, batch, stride, forward, out_):
channel_out = channel // (stride * stride)
......@@ -43,7 +43,7 @@ class TestReorgOp(OpTest):
def setUp(self):
self.init_data()
self.op_type = "reorg"
self.op_type = "space_to_depth"
self.inputs = {"X": self.x}
self.helper(self.x_1d, self.x.shape[3], self.x.shape[2],
self.x.shape[1], self.x.shape[0], self.stride, self.forward,
......@@ -75,7 +75,35 @@ class TestReorgOp(OpTest):
self.check_grad_with_place(place, ['X'], 'Out')
class TestReorgOp2(TestReorgOp):
class TestSpaceToDepthOpBasic(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 8, 6, 6)
self.infered_shape = (32, 32, 3, 3)
self.one_d_len = 32 * 32 * 3 * 3
self.stride = 2
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 8, 6, 6)
self.infered_shape = (32, 32, 3, 3)
self.one_d_len = 32 * 32 * 3 * 3
self.stride = 2
self.x = np.random.random(self.ori_shape).astype('float64')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float64')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 9, 6, 6)
self.infered_shape = (32, 81, 2, 2)
......@@ -89,5 +117,19 @@ class TestReorgOp2(TestReorgOp):
self.forward = 1
class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp):
def init_data(self):
self.ori_shape = (32, 9, 9, 6)
self.infered_shape = (32, 81, 3, 2)
self.one_d_len = 32 * 81 * 3 * 2
self.stride = 3
self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32')
self.out_1d = np.reshape(self.out, self.one_d_len)
self.forward = 1
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册