提交 9a74c448 编写于 作者: J JiabinYang

test=develop

上级 6e361542
...@@ -31,31 +31,31 @@ class SpaceToDepthOp : public framework::OperatorWithKernel { ...@@ -31,31 +31,31 @@ class SpaceToDepthOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor"); PADDLE_ENFORCE_EQ(x_dims.size(), 4, "input should be a 4D tensor");
auto stride = ctx->Attrs().Get<int64_t>("stride"); auto blocksize = ctx->Attrs().Get<int64_t>("blocksize");
PADDLE_ENFORCE_GT(stride, 1, "The stride should be Greater than 1"); PADDLE_ENFORCE_GT(blocksize, 1, "The blocksize should be Greater than 1");
PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[1], 0, "input channel should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[2], 0, "input Height should be Greater than 0");
PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0"); PADDLE_ENFORCE_GT(x_dims[3], 0, "input Width should be Greater than 0");
PADDLE_ENFORCE_EQ(x_dims[1] % (stride * stride), 0, PADDLE_ENFORCE_EQ(x_dims[1] % (blocksize * blocksize), 0,
"input channel should be divisible of the square of " "input channel should be divisible of the square of "
"SpaceToDepthOp stride"); "SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[2] % (stride), 0, PADDLE_ENFORCE_EQ(x_dims[2] % (blocksize), 0,
"input Height should be divisible of the square of " "input Height should be divisible of the square of "
"SpaceToDepthOp stride"); "SpaceToDepthOp blocksize");
PADDLE_ENFORCE_EQ(x_dims[3] % (stride), 0, PADDLE_ENFORCE_EQ(x_dims[3] % (blocksize), 0,
"input Width should be divisible of the square of " "input Width should be divisible of the square of "
"SpaceToDepthOp stride"); "SpaceToDepthOp blocksize");
VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims VLOG(3) << "SpaceToDepthOp operator x.shape=" << x_dims
<< "Attribute stride" << stride << std::endl; << "Attribute blocksize" << blocksize << std::endl;
std::vector<int64_t> output_shape(4, 0); // [B,C,H,W] std::vector<int64_t> output_shape(4, 0); // [B,C,H,W]
output_shape[0] = x_dims[0]; output_shape[0] = x_dims[0];
output_shape[1] = x_dims[1] * stride * stride; output_shape[1] = x_dims[1] * blocksize * blocksize;
output_shape[2] = x_dims[2] / stride; output_shape[2] = x_dims[2] / blocksize;
output_shape[3] = x_dims[3] / stride; output_shape[3] = x_dims[3] / blocksize;
auto out_dims = framework::make_ddim(output_shape); auto out_dims = framework::make_ddim(output_shape);
...@@ -80,20 +80,20 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -80,20 +80,20 @@ class SpaceToDepthOpMaker : public framework::OpProtoAndCheckerMaker {
"(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of " "(Tensor), The output should be a 4D tensor B * C2 * W2 * H2 of "
"SpaceToDepthOp operator."); "SpaceToDepthOp operator.");
AddAttr<int64_t>( AddAttr<int64_t>(
"stride", "blocksize",
"(int64_t, default 2) stride used to do change Space To Depth.") "(int64_t, default 2) blocksize used to do change Space To Depth.")
.SetDefault(2) .SetDefault(2)
.GreaterThan(1); .GreaterThan(1);
AddComment(R"DOC( AddComment(R"DOC(
reorg operator used in Yolo v2. reorg operator used in Yolo v2.
The equation is: C2 = C1/stride * stride, W2 = W1 ∗ stride + offset % stride, H2 = H1 ∗ stride + offset / stride, The equation is: C2 = C1/blocksize * blocksize, W2 = W1 ∗ blocksize + offset % blocksize, H2 = H1 ∗ blocksize + offset / blocksize,
Reshape Input(X) into the shape according to Attr(stride). The Reshape Input(X) into the shape according to Attr(blocksize). The
data in Input(X) are unchanged. data in Input(X) are unchanged.
Examples: Examples:
1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the stride is 2, the reorg operator will transform Input(X) 1. Given a 4-D tensor Input(X) with a shape [128, 2048, 26, 26], and the blocksize is 2, the reorg operator will transform Input(X)
into a 4-D tensor with shape [128, 2048, 13, 13] and leaving Input(X)'s data unchanged. into a 4-D tensor with shape [128, 2048, 13, 13] and leaving Input(X)'s data unchanged.
)DOC"); )DOC");
......
...@@ -25,19 +25,19 @@ template <typename T> ...@@ -25,19 +25,19 @@ template <typename T>
class space_to_depth_compute { class space_to_depth_compute {
public: public:
HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c, HOSTDEVICE space_to_depth_compute(const T *x, int64_t w, int64_t h, int64_t c,
int64_t batch, int64_t stride, int64_t batch, int64_t blocksize,
int64_t forward, T *out) int64_t forward, T *out)
: x_(x), : x_(x),
w_(w), w_(w),
h_(h), h_(h),
c_(c), c_(c),
batch_(batch), batch_(batch),
stride_(stride), blocksize_(blocksize),
forward_(forward), forward_(forward),
out_(out) {} out_(out) {}
HOSTDEVICE void operator()(int64_t in_index) { HOSTDEVICE void operator()(int64_t in_index) {
int64_t out_c = c_ / (stride_ * stride_); int64_t out_c = c_ / (blocksize_ * blocksize_);
// calculate each dim position with index of tensor // calculate each dim position with index of tensor
int64_t b = in_index / (c_ * h_ * w_); int64_t b = in_index / (c_ * h_ * w_);
int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_); int64_t k = (in_index % (c_ * h_ * w_)) / (h_ * w_);
...@@ -46,10 +46,10 @@ class space_to_depth_compute { ...@@ -46,10 +46,10 @@ class space_to_depth_compute {
int64_t c2 = k % out_c; int64_t c2 = k % out_c;
int64_t offset = k / out_c; int64_t offset = k / out_c;
int64_t w2 = i * stride_ + offset % stride_; int64_t w2 = i * blocksize_ + offset % blocksize_;
int64_t h2 = j * stride_ + offset / stride_; int64_t h2 = j * blocksize_ + offset / blocksize_;
int64_t out_index = int64_t out_index =
w2 + w_ * stride_ * (h2 + h_ * stride_ * (c2 + out_c * b)); w2 + w_ * blocksize_ * (h2 + h_ * blocksize_ * (c2 + out_c * b));
if (forward_) if (forward_)
out_[out_index] = x_[in_index]; out_[out_index] = x_[in_index];
else else
...@@ -58,7 +58,7 @@ class space_to_depth_compute { ...@@ -58,7 +58,7 @@ class space_to_depth_compute {
private: private:
const T *x_; const T *x_;
int64_t w_, h_, c_, batch_, stride_, forward_; int64_t w_, h_, c_, batch_, blocksize_, forward_;
T *out_; T *out_;
}; };
...@@ -68,7 +68,7 @@ class SpaceToDepthKernel : public framework::OpKernel<T> { ...@@ -68,7 +68,7 @@ class SpaceToDepthKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &context) const override { void Compute(const framework::ExecutionContext &context) const override {
auto *out = context.Output<framework::LoDTensor>("Out"); auto *out = context.Output<framework::LoDTensor>("Out");
auto *x = context.Input<framework::LoDTensor>("X"); auto *x = context.Input<framework::LoDTensor>("X");
auto stride = context.Attr<int64_t>("stride"); auto blocksize = context.Attr<int64_t>("blocksize");
auto in_dims = x->dims(); auto in_dims = x->dims();
out->mutable_data(context.GetPlace(), x->type()); out->mutable_data(context.GetPlace(), x->type());
...@@ -83,8 +83,8 @@ class SpaceToDepthKernel : public framework::OpKernel<T> { ...@@ -83,8 +83,8 @@ class SpaceToDepthKernel : public framework::OpKernel<T> {
auto *x_data = x->data<T>(); auto *x_data = x->data<T>();
auto *out_data = out->data<T>(); auto *out_data = out->data<T>();
paddle::operators::space_to_depth_compute<T> computer(x_data, W, H, C, B, paddle::operators::space_to_depth_compute<T> computer(
stride, 1, out_data); x_data, W, H, C, B, blocksize, 1, out_data);
for_range(computer); for_range(computer);
out->Resize(out_dims); out->Resize(out_dims);
...@@ -99,7 +99,7 @@ class SpaceToDepthGradKernel : public framework::OpKernel<T> { ...@@ -99,7 +99,7 @@ class SpaceToDepthGradKernel : public framework::OpKernel<T> {
context.Input<framework::LoDTensor>(framework::GradVarName("Out")); context.Input<framework::LoDTensor>(framework::GradVarName("Out"));
auto *d_x = auto *d_x =
context.Output<framework::LoDTensor>(framework::GradVarName("X")); context.Output<framework::LoDTensor>(framework::GradVarName("X"));
auto stride = context.Attr<int64_t>("stride"); auto blocksize = context.Attr<int64_t>("blocksize");
auto in_dims = d_x->dims(); auto in_dims = d_x->dims();
d_x->mutable_data(context.GetPlace(), d_out->type()); d_x->mutable_data(context.GetPlace(), d_out->type());
...@@ -115,8 +115,8 @@ class SpaceToDepthGradKernel : public framework::OpKernel<T> { ...@@ -115,8 +115,8 @@ class SpaceToDepthGradKernel : public framework::OpKernel<T> {
auto *dx_data = d_x->data<T>(); auto *dx_data = d_x->data<T>();
auto *dout_data = d_out->data<T>(); auto *dout_data = d_out->data<T>();
paddle::operators::space_to_depth_compute<T> computer(dout_data, W, H, C, B, paddle::operators::space_to_depth_compute<T> computer(
stride, 0, dx_data); dout_data, W, H, C, B, blocksize, 0, dx_data);
for_range(computer); for_range(computer);
d_x->Resize(in_dims); d_x->Resize(in_dims);
......
...@@ -7485,29 +7485,29 @@ def maxout(x, groups, name=None): ...@@ -7485,29 +7485,29 @@ def maxout(x, groups, name=None):
return out return out
def space_to_depth(x, stride, name=None): def space_to_depth(x, blocksize, name=None):
""" """
Gives a stride to space_to_depth the input LoDtensor Gives a blocksize to space_to_depth the input LoDtensor with Layout: [batch, channel, height, width]
Rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the This op rearranges blocks of spatial data, into depth. More specifically, this op outputs a copy of the
input LoDtensor where values from the height and width dimensions are moved to the channel dimension. input LoDtensor where values from the height and width dimensions are moved to the channel dimension.
The attr stride indicates the input block size. The attr blocksize indicates the input block size.
space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according space_to_depth will reorgnize the elements of input with shape[batch, channel, height, width] according
to stride to construct output with shape [batch, channel * stride * stride, height/stride, width/stride]: to blocksize to construct output with shape [batch, channel * blocksize * blocksize, height/blocksize, width/blocksize]:
space_to_depth is used to This operation is useful for resizing the activations between convolutions space_to_depth is used to This operation is useful for resizing the activations between convolutions
(but keeping all data) (but keeping all data)
Args: Args:
x(variable): The input LoDtensor. x(variable): The input LoDtensor.
stride(variable): The stride to select the element on each feature map blocksize(variable): The blocksize to select the element on each feature map
Returns: Returns:
Variable: The output LoDtensor. Variable: The output LoDtensor.
Raises: Raises:
TypeError: stride type must be a long. TypeError: blocksize type must be a long.
Examples: Examples:
.. code-block:: python .. code-block:: python
...@@ -7515,13 +7515,13 @@ def space_to_depth(x, stride, name=None): ...@@ -7515,13 +7515,13 @@ def space_to_depth(x, stride, name=None):
data = fluid.layers.data( data = fluid.layers.data(
name='data', shape=[1, 4, 2, 2], dtype='float32') name='data', shape=[1, 4, 2, 2], dtype='float32')
space_to_depthed = fluid.layers.space_to_depth( space_to_depthed = fluid.layers.space_to_depth(
x=data, stride=2) x=data, blocksize=2)
""" """
helper = LayerHelper("space_to_depth", **locals()) helper = LayerHelper("space_to_depth", **locals())
if not (isinstance(stride, int)): if not (isinstance(blocksize, int)):
raise ValueError("stride must be a python Int") raise ValueError("blocksize must be a python Int")
if name is None: if name is None:
out = helper.create_variable_for_type_inference( out = helper.create_variable_for_type_inference(
...@@ -7533,7 +7533,7 @@ def space_to_depth(x, stride, name=None): ...@@ -7533,7 +7533,7 @@ def space_to_depth(x, stride, name=None):
helper.append_op( helper.append_op(
type="space_to_depth", type="space_to_depth",
inputs={"X": x}, inputs={"X": x},
attrs={"stride": stride}, attrs={"blocksize": blocksize},
outputs={"Out": out}) outputs={"Out": out})
return out return out
......
...@@ -21,8 +21,8 @@ from op_test import OpTest ...@@ -21,8 +21,8 @@ from op_test import OpTest
class TestSpaceToDepthOp(OpTest): class TestSpaceToDepthOp(OpTest):
@staticmethod @staticmethod
def helper(in_, width, height, channel, batch, stride, forward, out_): def helper(in_, width, height, channel, batch, blocksize, forward, out_):
channel_out = channel // (stride * stride) channel_out = channel // (blocksize * blocksize)
for b in range(batch): for b in range(batch):
for k in range(channel): for k in range(channel):
for j in range(height): for j in range(height):
...@@ -30,10 +30,10 @@ class TestSpaceToDepthOp(OpTest): ...@@ -30,10 +30,10 @@ class TestSpaceToDepthOp(OpTest):
in_index = i + width * (j + height * (k + channel * b)) in_index = i + width * (j + height * (k + channel * b))
channel2 = k % channel_out channel2 = k % channel_out
offset = k // channel_out offset = k // channel_out
width2 = i * stride + offset % stride width2 = i * blocksize + offset % blocksize
height2 = j * stride + offset // stride height2 = j * blocksize + offset // blocksize
out_index = width2 + width * stride * ( out_index = width2 + width * blocksize * (
height2 + height * stride * height2 + height * blocksize *
(channel2 + channel_out * b)) (channel2 + channel_out * b))
if forward: if forward:
out_[out_index] = in_[in_index] out_[out_index] = in_[in_index]
...@@ -46,10 +46,10 @@ class TestSpaceToDepthOp(OpTest): ...@@ -46,10 +46,10 @@ class TestSpaceToDepthOp(OpTest):
self.op_type = "space_to_depth" self.op_type = "space_to_depth"
self.inputs = {"X": self.x} self.inputs = {"X": self.x}
self.helper(self.x_1d, self.x.shape[3], self.x.shape[2], self.helper(self.x_1d, self.x.shape[3], self.x.shape[2],
self.x.shape[1], self.x.shape[0], self.stride, self.forward, self.x.shape[1], self.x.shape[0], self.blocksize,
self.out_1d) self.forward, self.out_1d)
self.out = np.reshape(self.out_1d, self.infered_shape) self.out = np.reshape(self.out_1d, self.infered_shape)
self.attrs = {"stride": self.stride} self.attrs = {"blocksize": self.blocksize}
self.outputs = {"Out": self.out} self.outputs = {"Out": self.out}
def init_data(self): def init_data(self):
...@@ -57,7 +57,7 @@ class TestSpaceToDepthOp(OpTest): ...@@ -57,7 +57,7 @@ class TestSpaceToDepthOp(OpTest):
self.infered_shape = (32, 48, 3, 3) self.infered_shape = (32, 48, 3, 3)
self.one_d_len = 32 * 48 * 3 * 3 self.one_d_len = 32 * 48 * 3 * 3
self.stride = 2 self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float32') self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len) self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32') self.out = np.zeros(self.infered_shape).astype('float32')
...@@ -81,7 +81,7 @@ class TestSpaceToDepthOpBasic(TestSpaceToDepthOp): ...@@ -81,7 +81,7 @@ class TestSpaceToDepthOpBasic(TestSpaceToDepthOp):
self.infered_shape = (32, 32, 3, 3) self.infered_shape = (32, 32, 3, 3)
self.one_d_len = 32 * 32 * 3 * 3 self.one_d_len = 32 * 32 * 3 * 3
self.stride = 2 self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float32') self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len) self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32') self.out = np.zeros(self.infered_shape).astype('float32')
...@@ -95,7 +95,7 @@ class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp): ...@@ -95,7 +95,7 @@ class TestSpaceToDepthOpDoubleBasic(TestSpaceToDepthOp):
self.infered_shape = (32, 32, 3, 3) self.infered_shape = (32, 32, 3, 3)
self.one_d_len = 32 * 32 * 3 * 3 self.one_d_len = 32 * 32 * 3 * 3
self.stride = 2 self.blocksize = 2
self.x = np.random.random(self.ori_shape).astype('float64') self.x = np.random.random(self.ori_shape).astype('float64')
self.x_1d = np.reshape(self.x, self.one_d_len) self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float64') self.out = np.zeros(self.infered_shape).astype('float64')
...@@ -109,7 +109,7 @@ class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp): ...@@ -109,7 +109,7 @@ class TestSpaceToDepthOpWithStride3(TestSpaceToDepthOp):
self.infered_shape = (32, 81, 2, 2) self.infered_shape = (32, 81, 2, 2)
self.one_d_len = 32 * 81 * 2 * 2 self.one_d_len = 32 * 81 * 2 * 2
self.stride = 3 self.blocksize = 3
self.x = np.random.random(self.ori_shape).astype('float32') self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len) self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32') self.out = np.zeros(self.infered_shape).astype('float32')
...@@ -123,7 +123,7 @@ class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp): ...@@ -123,7 +123,7 @@ class TestSpaceToDepthOpWithNotSquare(TestSpaceToDepthOp):
self.infered_shape = (32, 81, 3, 2) self.infered_shape = (32, 81, 3, 2)
self.one_d_len = 32 * 81 * 3 * 2 self.one_d_len = 32 * 81 * 3 * 2
self.stride = 3 self.blocksize = 3
self.x = np.random.random(self.ori_shape).astype('float32') self.x = np.random.random(self.ori_shape).astype('float32')
self.x_1d = np.reshape(self.x, self.one_d_len) self.x_1d = np.reshape(self.x, self.one_d_len)
self.out = np.zeros(self.infered_shape).astype('float32') self.out = np.zeros(self.infered_shape).astype('float32')
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册