未验证 提交 871af28d 编写于 作者: C cjt222 提交者: GitHub

add deformable psroi pooling (#17827)

* add deformable psroi pooling

* test=develop

* test=develop

* test=develop
modify format

* fix bug

* test=develop run ci

* test=develop
add API.spec

* add test_layers.py

* run ci again

* test=develop
run ci again

* run ci again

* test=develop
run ci again

* test=develop
run ci again

* test=develop
run ci again

* add space between two lines

* test=develop
add space between two lines

* test=develop
add space between lines

* test=develop
modify comment in nn.py

* test=develop
add space between two lines

* test=develop
add space between two lines

* update API.spec

* run ci again

* test=develop
run ci again

* rerun ci

* test=develop
rerun ci

* change input shape

* run ci

* test=develop
run ci

* modify format of nn.py

* test=develop

* test=develop

* test=develop
update API.spec

* test=develop
fix API doc

* modify API comment

* modift API comment

* test=develop
update API.spec

* test=develop
modify comment

* test=develop
modift comment

* test=develop
modift comment

* test=develop
update API.spec

* test=develop
modify comment

* test=develop
add inference in nn.py

* test=develop
update API.spec

* test=develop
resolve confict

* test=develop
update API.spec
上级 40885c22
......@@ -239,6 +239,7 @@ paddle.fluid.layers.where (ArgSpec(args=['condition'], varargs=None, keywords=No
paddle.fluid.layers.sign (ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None), ('document', 'ccf6bb7912afd2818d24bc45461e807a'))
paddle.fluid.layers.deformable_conv (ArgSpec(args=['input', 'offset', 'mask', 'num_filters', 'filter_size', 'stride', 'padding', 'dilation', 'groups', 'deformable_groups', 'im2col_step', 'param_attr', 'bias_attr', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None, None, None, None, None, None)), ('document', 'c896b66265a60bd3c5510f66e6e02919'))
paddle.fluid.layers.unfold (ArgSpec(args=['x', 'kernel_sizes', 'strides', 'paddings', 'dilations', 'name'], varargs=None, keywords=None, defaults=(1, 0, 1, None)), ('document', '3f884662ad443d9ecc2b3734b4f61ad6'))
paddle.fluid.layers.deformable_roi_pooling (ArgSpec(args=['input', 'rois', 'trans', 'no_trans', 'spatial_scale', 'group_size', 'pooled_height', 'pooled_width', 'part_size', 'sample_per_part', 'trans_std', 'position_sensitive', 'name'], varargs=None, keywords=None, defaults=(False, 1.0, [1, 1], 1, 1, None, 1, 0.1, False, None)), ('document', '65b8dbe13e00c4dc8224652f6ff89540'))
paddle.fluid.layers.data (ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True)), ('document', '6e19128b46936edf9f3fad77860a1da8'))
paddle.fluid.layers.open_files (ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None)), ('document', 'dce69a78638da8f7ad80b1fc00ed2029'))
paddle.fluid.layers.read_file (ArgSpec(args=['reader'], varargs=None, keywords=None, defaults=None), ('document', '32181f6037e387fb6e68a5beaafe33b6'))
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/deformable_psroi_pooling_op.h"
#include <iostream>
#include <memory>
#include <vector>
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
namespace operators {
class DeformablePSROIPoolOpMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("Input",
"(Tensor), "
"the input of Deformable PSROIPooling. "
"The shape of input tensor is [N,C,H,W]. Where N is batch size, "
"C is number of input channels, "
"H is height of the feature, and "
"W is the width of the feature.");
AddInput("ROIs",
"(LoDTensor), "
"ROIs (Regions of Interest) to pool over. "
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4) "
"given as [[x1, y1, x2, y2], ...]. "
"(x1, y1) is the top left coordinates, and "
"(x2, y2) is the bottom right coordinates.");
AddInput("Trans",
"(Tensor),"
"offset of features on ROIs while pooling. "
"The format is NCHW, where N is number of ROIs, "
"C is number of channels, which indicate the offset distance "
"in the x and y directions, "
"H is pooled height, and "
"W is pooled width.");
AddAttr<bool>("no_trans",
"(bool), "
"whether add offset to get new value or not while roi "
"pooling, which value is True or False");
AddAttr<float>("spatial_scale",
"(float), "
"ratio of input feature map height (or width) to "
"raw image height (or width). Equals the reciprocal "
"of total stride in convolutional layers.");
AddAttr<int>("output_dim",
"(int), "
"the number of output channels, which should be less than "
"input channels. Deformable roi_pooling requires "
"output_channels = input_channels, while deformable "
"psroi_pooling requires output_channels = input_channels "
"* pooled_height * pooled_width");
AddAttr<std::vector<int>>(
"group_size",
"(vector<int>), "
"the number of groups which input channels are divided."
"(eg.number of input channels is k1*k2*(C+1), which k1 and k2 "
"are group width and height and C+1 is number of output "
"chanels. eg.(4, 6), which 4 is height of group and 6 is "
"width of group");
AddAttr<int>("pooled_height",
"(int), "
"the pooled output height.");
AddAttr<int>("pooled_width",
"(int), "
"the pooled output width.");
AddAttr<std::vector<int>>(
"part_size",
"(vector<int>), "
"the height and width of offset, eg.(4, 6), which height is 4 "
" and width is 6");
AddAttr<int>("sample_per_part",
"(int), "
"the number of samples in each bin");
AddAttr<float>("trans_std",
"(float), "
"Coefficient of offset");
AddOutput("TopCount",
"(Tensor), "
"record the number of pixel in average pooling to in each bin. "
"The format is NCHW, where N is the number of ROIs, "
"C is the number of output channels, "
"H is the height of output, and "
"W is the width of output.");
AddOutput("Output",
"(Tensor), "
"the output of Deformable PSROIPooling. "
"The format is NCHW, where N is the number of ROIs, "
"C is the number of output channels, "
"H is the height of output, and "
"W is thewidth of output. ");
AddComment(R"DOC(
**DeformablePSROIPooling Operator**
DeformablePSROIPooling is a new method based Region of interest pooling
(also known as RoI pooling).
The operator has four steps:
1. Dividing each region proposal into equal-sized sections with
the pooled_width and pooled_height.
2. Add offset to pixel in ROI to get new location and the new value which are
computed directly through bilinear interpolation with four nearest pixel.
3. Sample several points to get average values in each bin.
4. Copying these average values to the output buffer.
DeformablePSROIPooling is part of Deformable Convolutional Networks,
please refer to https://arxiv.org/abs/1703.06211 for more details.
)DOC");
}
};
class DeformablePSROIPoolOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("Input"),
"Input(Input) of DeformablePSROIPoolOp"
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("ROIs"),
"Input(ROIs) of DeformablePSROIPoolOp "
"should not be null.");
PADDLE_ENFORCE(ctx->HasInput("Trans"),
"Input(Trans) of DeformablePSROIPoolOp "
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Output"),
"Output(Output) of DeformablePSROIPoolOp "
"should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("TopCount"),
"Output(TopCount) of DeformablePSROIPoolOp "
"should not be null.");
auto input_dims = ctx->GetInputDim("Input");
auto rois_dims = ctx->GetInputDim("ROIs");
auto trans_dims = ctx->GetInputDim("Trans");
PADDLE_ENFORCE(rois_dims.size() == 2,
"ROIs should be a 2-D LoDTensor of shape (num_rois, 4)"
"given as [[ x1, y1, x2, y2], ...].");
PADDLE_ENFORCE(trans_dims.size() == 4,
"The format of Input Trans is (N, 2, H, W).");
auto pooled_height = ctx->Attrs().Get<int>("pooled_height");
auto pooled_width = ctx->Attrs().Get<int>("pooled_width");
auto spatial_scale = ctx->Attrs().Get<float>("spatial_scale");
auto output_channels = ctx->Attrs().Get<int>("output_dim");
auto group_size = ctx->Attrs().Get<std::vector<int>>("group_size");
auto group_height = group_size[0];
auto group_width = group_size[1];
auto part_size = ctx->Attrs().Get<std::vector<int>>("part_size");
auto part_height = part_size[0];
auto part_width = part_size[1];
auto sample_per_part = ctx->Attrs().Get<int>("sample_per_part");
auto trans_std = ctx->Attrs().Get<float>("trans_std");
PADDLE_ENFORCE(trans_std >= 0.0f, "trans_std must greater than 0.0");
PADDLE_ENFORCE(input_dims[1] >= output_channels,
"input channels must greater than out_channels");
PADDLE_ENFORCE_GT(pooled_height, 0,
"The pooled height must greater than 0");
PADDLE_ENFORCE_GT(pooled_width, 0, "The pooled width must greater than 0");
PADDLE_ENFORCE_GT(spatial_scale, 0.0f,
"The spatial scale must greater than 0");
PADDLE_ENFORCE_EQ(group_size.size(), 2,
"The size of group_size should be 2.");
PADDLE_ENFORCE_GT(group_height, 0,
"The group_height in group_size must greater than 0");
PADDLE_ENFORCE_GT(group_width, 0,
"The group_width in group_size must greater than 0");
PADDLE_ENFORCE_EQ(part_size.size(), 2,
"The size of part_size should be 2.");
PADDLE_ENFORCE_GT(part_height, 0,
"The part_height in part_size must greater than 0");
PADDLE_ENFORCE_GT(part_width, 0,
"The part_width in part_size must greater than 0");
PADDLE_ENFORCE(part_height <= trans_dims[2],
"The height of trans must greater than part_height");
PADDLE_ENFORCE(part_width <= trans_dims[3],
"The width of trans must greater than part_width");
PADDLE_ENFORCE_GT(sample_per_part, 0,
"The sample_per_part must greater than 0");
auto out_dims = input_dims;
out_dims[0] = rois_dims[0];
out_dims[1] = output_channels;
out_dims[2] = pooled_height;
out_dims[3] = pooled_width;
ctx->SetOutputDim("Output", out_dims);
ctx->SetOutputDim("TopCount", out_dims);
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Input")->type(),
ctx.device_context());
}
};
class DeformablePSROIPoolGradOpDescMaker
: public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("deformable_psroi_pooling_grad");
op->SetInput("Input", Input("Input"));
op->SetInput("Trans", Input("Trans"));
op->SetInput("ROIs", Input("ROIs"));
op->SetInput("TopCount", Output("TopCount"));
op->SetInput(framework::GradVarName("Output"), OutputGrad("Output"));
op->SetOutput(framework::GradVarName("Input"), InputGrad("Input"));
op->SetOutput(framework::GradVarName("Trans"), InputGrad("Trans"));
op->SetAttrMap(Attrs());
return op;
}
};
class DeformablePSROIPoolGradOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Output")),
"The gradient of Output should not be null.");
if (ctx->HasOutput(framework::GradVarName("Input"))) {
ctx->SetOutputDim(framework::GradVarName("Input"),
ctx->GetInputDim("Input"));
}
if (ctx->HasOutput(framework::GradVarName("Trans"))) {
ctx->SetOutputDim(framework::GradVarName("Trans"),
ctx->GetInputDim("Trans"));
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("Trans")->type(),
ctx.device_context());
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
using CPU = paddle::platform::CPUDeviceContext;
REGISTER_OPERATOR(deformable_psroi_pooling, ops::DeformablePSROIPoolOp,
ops::DeformablePSROIPoolOpMaker,
ops::DeformablePSROIPoolGradOpDescMaker);
REGISTER_OPERATOR(deformable_psroi_pooling_grad,
ops::DeformablePSROIPoolGradOp);
REGISTER_OP_CPU_KERNEL(deformable_psroi_pooling,
ops::DeformablePSROIPoolCPUKernel<CPU, float>,
ops::DeformablePSROIPoolCPUKernel<CPU, double>);
REGISTER_OP_CPU_KERNEL(deformable_psroi_pooling_grad,
ops::DeformablePSROIPoolGradCPUKernel<CPU, float>,
ops::DeformablePSROIPoolGradCPUKernel<CPU, double>);
此差异已折叠。
此差异已折叠。
......@@ -204,6 +204,7 @@ __all__ = [
'sign',
'deformable_conv',
'unfold',
'deformable_roi_pooling',
]
kIgnoreIndex = -100
......@@ -12168,3 +12169,117 @@ def unfold(x, kernel_sizes, strides=1, paddings=0, dilations=1, name=None):
"dilations": dilations
})
return out
def deformable_roi_pooling(input,
rois,
trans,
no_trans=False,
spatial_scale=1.0,
group_size=[1, 1],
pooled_height=1,
pooled_width=1,
part_size=None,
sample_per_part=1,
trans_std=0.1,
position_sensitive=False,
name=None):
"""
Deformable PSROI Pooling Layer
Args:
input (Variable):The input of Deformable PSROIPooling.The shape of input tensor is
[N,C,H,W]. Where N is batch size,C is number of input channels,H
is height of the feature, and W is the width of the feature.
rois (Variable): ROIs (Regions of Interest) to pool over.It should be
a 2-D LoDTensor of shape (num_rois, 4), the lod level
is 1. Given as [[x1, y1, x2, y2], ...], (x1, y1) is
the top left coordinates, and (x2, y2) is the bottom
right coordinates.
trans (Variable): Offset of features on ROIs while pooling.The format is NCHW, where
N is number of ROIs, C is number of channels, which indicate the offset distance
in the x and y directions, H is pooled height, and W is pooled width.
no_trans (bool): Whether to add offset to get new value or not while roi pooling, which
value is True or False. Default: False.
spatial_scale (float): Ratio of input feature map height (or width) to raw image height (or width).
Equals the reciprocal of total stride in convolutional layers, Default: 1.0.
group_size (list|tuple): The number of groups which input channels are divided.(eg.number of input channels
is k1*k2*(C+1), which k1 and k2 are group width and height and C+1 is number of output
chanels. eg.(4, 6), which 4 is height of group and 6 is width of group. Default: [1, 1].
pooled_height (integer): The pooled output height. Default: 1.
pooled_width (integer): The pooled output width. Default: 1.
part_size (list|tuple): The height and width of offset, eg.(4, 6), which height is 4 and width is 6, Default:
if None, default value is [pooled_height, pooled_width].
sample_per_part (integer): The number of samples in each bin. Default: 1.
trans_std (float): Coefficient of offset. Default: 0.1.
position_sensitive (bool): Whether to choose deformable psroi pooling mode or not. Default: False.
name (str): Name of layer. Default: None.
Returns:
Variable: The tensor variable storing the deformable psroi pooling \
result.
Examples:
.. code-block:: python
input = fluid.layers.data(name="input",
shape=[2, 192, 64, 64],
dtype='float32',
append_batch_size=False)
rois = fluid.layers.data(name="rois",
shape=[4],
dtype='float32',
lod_level=1)
trans = fluid.layers.data(name="trans",
shape=[2, 384, 64, 64],
dtype='float32',
append_batch_size=False)
x = fluid.layers.nn.deformable_roi_pooling(input=input,
rois=rois,
trans=trans,
no_trans=False,
spatial_scale=1.0,
group_size=(1, 1),
pooled_height=8,
pooled_width=8,
part_size=(8, 8),
sample_per_part=4,
trans_std=0.1,
position_sensitive=False)
"""
input_channels = input.shape[1]
if position_sensitive == False:
output_channels = input_channels
else:
output_channels = input_channels / pooled_height / pooled_width
if part_size is None:
part_height = pooled_height
part_width = pooled_width
part_size = [part_height, part_width]
part_size = utils.convert_to_list(part_size, 2, 'part_size')
group_size = utils.convert_to_list(group_size, 2, 'group_size')
helper = LayerHelper('deformable_psroi_pooling', **locals())
dtype = helper.input_dtype()
output = helper.create_variable_for_type_inference(dtype)
top_count = helper.create_variable_for_type_inference(dtype='int32')
helper.append_op(
type="deformable_psroi_pooling",
inputs={"Input": input,
"ROIs": rois,
"Trans": trans},
outputs={"Output": output,
"TopCount": top_count},
attrs={
"no_trans": no_trans,
"spatial_scale": spatial_scale,
"output_dim": output_channels,
"group_size": group_size,
"pooled_height": pooled_height,
"pooled_width": pooled_width,
"part_size": part_size,
"sample_per_part": sample_per_part,
"trans_std": trans_std
})
return output
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
from op_test import OpTest
def set_input(input, rois, trans):
inputs = {'Input': input, "ROIs": rois, "Trans": trans}
return inputs
def set_attrs(no_trans, spatial_scale, output_channels, group_size,
pooled_height, pooled_width, part_size, sample_per_part,
trans_std):
attrs = {
'no_trans': no_trans,
'spatial_scale': spatial_scale,
'output_dim': output_channels,
'group_size': group_size,
'pooled_height': pooled_height,
'pooled_width': pooled_width,
'part_size': part_size,
'sample_per_part': sample_per_part,
'trans_std': trans_std
}
return attrs
def set_outputs(output, top_count):
outputs = {
'Output': output.astype('float32'),
'TopCount': top_count.astype('float32')
}
return outputs
class TestDeformablePSROIPoolOp(OpTest):
def set_data(self):
self.start_test1()
self.start_test2()
self.start_test3()
self.start_test4()
def start_test1(self):
self.init_test_case1()
self.make_rois()
self.calc_deformable_psroi_pooling()
inputs = self.input
rois = (self.rois[:, 1:5], self.rois_lod)
trans = self.trans
self.inputs = set_input(inputs, rois, trans)
no_trans = self.no_trans
spatial_scale = self.spatial_scale
output_channels = self.output_channels
group_size = self.group_size
pooled_height = self.pooled_height
pooled_width = self.pooled_width
part_size = self.part_size
sample_per_part = self.sample_per_part
trans_std = self.trans_std
self.attrs = set_attrs(no_trans, spatial_scale, output_channels,
group_size, pooled_height, pooled_width,
part_size, sample_per_part, trans_std)
output = self.out.astype('float32')
top_count = self.top_count.astype('float32')
self.outputs = set_outputs(output, top_count)
def start_test2(self):
self.init_test_case2()
self.make_rois()
self.calc_deformable_psroi_pooling()
inputs = self.input
rois = (self.rois[:, 1:5], self.rois_lod)
trans = self.trans
self.inputs = set_input(inputs, rois, trans)
no_trans = self.no_trans
spatial_scale = self.spatial_scale
output_channels = self.output_channels
group_size = self.group_size
pooled_height = self.pooled_height
pooled_width = self.pooled_width
part_size = self.part_size
sample_per_part = self.sample_per_part
trans_std = self.trans_std
self.attrs = set_attrs(no_trans, spatial_scale, output_channels,
group_size, pooled_height, pooled_width,
part_size, sample_per_part, trans_std)
output = self.out.astype('float32')
top_count = self.top_count.astype('float32')
self.outputs = set_outputs(output, top_count)
def start_test3(self):
self.init_test_case3()
self.make_rois()
self.calc_deformable_psroi_pooling()
inputs = self.input
rois = (self.rois[:, 1:5], self.rois_lod)
trans = self.trans
self.inputs = set_input(inputs, rois, trans)
no_trans = self.no_trans
spatial_scale = self.spatial_scale
output_channels = self.output_channels
group_size = self.group_size
pooled_height = self.pooled_height
pooled_width = self.pooled_width
part_size = self.part_size
sample_per_part = self.sample_per_part
trans_std = self.trans_std
self.attrs = set_attrs(no_trans, spatial_scale, output_channels,
group_size, pooled_height, pooled_width,
part_size, sample_per_part, trans_std)
output = self.out.astype('float32')
top_count = self.top_count.astype('float32')
self.outputs = set_outputs(output, top_count)
def start_test4(self):
self.init_test_case4()
self.make_rois()
self.calc_deformable_psroi_pooling()
inputs = self.input
rois = (self.rois[:, 1:5], self.rois_lod)
trans = self.trans
self.inputs = set_input(inputs, rois, trans)
no_trans = self.no_trans
spatial_scale = self.spatial_scale
output_channels = self.output_channels
group_size = self.group_size
pooled_height = self.pooled_height
pooled_width = self.pooled_width
part_size = self.part_size
sample_per_part = self.sample_per_part
trans_std = self.trans_std
self.attrs = set_attrs(no_trans, spatial_scale, output_channels,
group_size, pooled_height, pooled_width,
part_size, sample_per_part, trans_std)
output = self.out.astype('float32')
top_count = self.top_count.astype('float32')
self.outputs = set_outputs(output, top_count)
def init_test_case1(self):
self.batch_size = 3
self.channels = 3 * 2 * 2
self.height = 12
self.width = 12
self.input_dim = [
self.batch_size, self.channels, self.height, self.width
]
self.no_trans = False
self.spatial_scale = 1.0 / 4.0
self.output_channels = 12
self.group_size = [1, 1]
self.pooled_height = 4
self.pooled_width = 4
self.part_size = [4, 4]
self.sample_per_part = 2
self.trans_std = 0.1
self.input = np.random.random(self.input_dim).astype('float32')
def init_test_case2(self):
self.batch_size = 2
self.channels = 3 * 2 * 2
self.height = 12
self.width = 12
self.input_dim = [
self.batch_size, self.channels, self.height, self.width
]
self.no_trans = True
self.spatial_scale = 1.0 / 2.0
self.output_channels = 12
self.group_size = [1, 1]
self.pooled_height = 7
self.pooled_width = 7
self.part_size = [7, 7]
self.sample_per_part = 4
self.trans_std = 0.1
self.input = np.random.random(self.input_dim).astype('float32')
def init_test_case3(self):
self.batch_size = 2
self.channels = 3 * 2 * 2
self.height = 12
self.width = 12
self.input_dim = [
self.batch_size, self.channels, self.height, self.width
]
self.no_trans = False
self.spatial_scale = 1.0 / 4.0
self.output_channels = 12
self.group_size = [1, 1]
self.pooled_height = 3
self.pooled_width = 3
self.part_size = [3, 3]
self.sample_per_part = 3
self.trans_std = 0.2
self.input = np.random.random(self.input_dim).astype('float32')
def init_test_case4(self):
self.batch_size = 2
self.channels = 3 * 2 * 2
self.height = 12
self.width = 12
self.input_dim = [
self.batch_size, self.channels, self.height, self.width
]
self.no_trans = True
self.spatial_scale = 1.0 / 2.0
self.output_channels = 12
self.group_size = [1, 1]
self.pooled_height = 6
self.pooled_width = 2
self.part_size = [6, 6]
self.sample_per_part = 6
self.trans_std = 0.4
self.input = np.random.random(self.input_dim).astype('float32')
def make_rois(self):
rois = []
self.rois_lod = [[]]
for bno in range(self.batch_size):
self.rois_lod[0].append(bno + 1)
for i in range(bno + 1):
x_1 = np.random.random_integers(
0, self.width // self.spatial_scale - self.pooled_width)
y_1 = np.random.random_integers(
0, self.height // self.spatial_scale - self.pooled_height)
x_2 = np.random.random_integers(
x_1 + self.pooled_width, self.width // self.spatial_scale)
y_2 = np.random.random_integers(
y_1 + self.pooled_height, self.height // self.spatial_scale)
roi = [bno, x_1, y_1, x_2, y_2]
rois.append(roi)
self.rois_num = len(rois)
self.rois = np.array(rois).astype("float32")
def dmc_bilinear(self, data_im, p_h, p_w):
h_low = int(np.floor(p_h))
w_low = int(np.floor(p_w))
h_high = h_low + 1
w_high = w_low + 1
l_h = p_h - h_low
l_w = p_w - w_low
h_h = 1 - l_h
h_w = 1 - l_w
v_1 = 0
if h_low >= 0 and w_low >= 0:
v_1 = data_im[h_low, w_low]
v_2 = 0
if h_low >= 0 and w_high <= self.width - 1:
v_2 = data_im[h_low, w_high]
v_3 = 0
if h_high <= self.height - 1 and w_low >= 0:
v_3 = data_im[h_high, w_low]
v_4 = 0
if h_high <= self.height - 1 and w_high <= self.width - 1:
v_4 = data_im[h_high, w_high]
w_1, w_2, w_3, w_4 = h_h * h_w, h_h * l_w, l_h * h_w, l_h * l_w
val = w_1 * v_1 + w_2 * v_2 + w_3 * v_3 + w_4 * v_4
return val
def calc_deformable_psroi_pooling(self):
output_shape = (self.rois_num, self.output_channels, self.pooled_height,
self.pooled_width)
self.out = np.zeros(output_shape)
self.trans = np.random.rand(self.rois_num, 2, self.part_size[0],
self.part_size[1]).astype('float32')
self.top_count = np.random.random((output_shape)).astype('float32')
count = self.rois_num * self.output_channels * self.pooled_height * self.pooled_width
for index in range(count):
p_w = int(index % self.pooled_width)
p_h = int(index / self.pooled_width % self.pooled_height)
ctop = int(index / self.pooled_width / self.pooled_height %
self.output_channels)
n_out = int(index / self.pooled_width / self.pooled_height /
self.output_channels)
roi = self.rois[n_out]
roi_batch_id = int(roi[0])
roi_start_w = int(np.round(roi[1])) * self.spatial_scale - 0.5
roi_start_h = int(np.round(roi[2])) * self.spatial_scale - 0.5
roi_end_w = int(np.round(roi[3] + 1)) * self.spatial_scale - 0.5
roi_end_h = int(np.round(roi[4] + 1)) * self.spatial_scale - 0.5
roi_width = max(roi_end_w - roi_start_w, 0.1)
roi_height = max(roi_end_h - roi_start_h, 0.1)
bin_size_h = float(roi_height) / float(self.pooled_height)
bin_size_w = float(roi_width) / float(self.pooled_width)
sub_bin_size_h = bin_size_h / self.sample_per_part
sub_bin_size_w = bin_size_w / self.sample_per_part
part_h = int(np.floor(p_h) / self.pooled_height * self.part_size[0])
part_w = int(np.floor(p_w) / self.pooled_width * self.part_size[1])
if self.no_trans:
trans_x = 0
trans_y = 0
else:
trans_x = self.trans[n_out][0][part_h][part_w] * self.trans_std
trans_y = self.trans[n_out][1][part_h][part_w] * self.trans_std
wstart = p_w * bin_size_w + roi_start_w
wstart = wstart + trans_x * roi_width
hstart = p_h * bin_size_h + roi_start_h
hstart = hstart + trans_y * roi_height
sum = 0
num_sample = 0
g_w = np.floor(p_w * self.group_size[0] / self.pooled_height)
g_h = np.floor(p_h * self.group_size[1] / self.pooled_width)
g_w = min(max(g_w, 0), self.group_size[0] - 1)
g_h = min(max(g_h, 0), self.group_size[1] - 1)
input_i = self.input[roi_batch_id]
for i_w in range(self.sample_per_part):
for i_h in range(self.sample_per_part):
w_sample = wstart + i_w * sub_bin_size_w
h_sample = hstart + i_h * sub_bin_size_h
if w_sample < -0.5 or w_sample > self.width - 0.5 or \
h_sample < -0.5 or h_sample > self.height - 0.5:
continue
w_sample = min(max(w_sample, 0.), self.width - 1.)
h_sample = min(max(h_sample, 0.), self.height - 1.)
c_sample = int((ctop * self.group_size[0] + g_h) *
self.group_size[1] + g_w)
val = self.dmc_bilinear(input_i[c_sample], h_sample,
w_sample)
sum = sum + val
num_sample = num_sample + 1
if num_sample == 0:
self.out[n_out][ctop][p_h][p_w] = 0
else:
self.out[n_out][ctop][p_h][p_w] = sum / num_sample
self.top_count[n_out][ctop][p_h][p_w] = num_sample
def setUp(self):
self.op_type = "deformable_psroi_pooling"
self.set_data()
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(['Input'], 'Output')
if __name__ == '__main__':
unittest.main()
......@@ -1995,6 +1995,35 @@ class TestBook(LayerTest):
out = layers.unfold(x, [3, 3], 1, 1, 1)
return (out)
def test_deform_roi_pooling(self):
with program_guard(fluid.default_main_program(),
fluid.default_startup_program()):
input = layers.data(
name='input',
shape=[2, 3, 32, 32],
dtype='float32',
append_batch_size=False)
rois = layers.data(
name="rois", shape=[4], dtype='float32', lod_level=1)
trans = layers.data(
name="trans",
shape=[2, 3, 32, 32],
dtype='float32',
append_batch_size=False)
out = layers.deformable_roi_pooling(
input=input,
rois=rois,
trans=trans,
no_trans=False,
spatial_scale=1.0,
group_size=(1, 1),
pooled_height=8,
pooled_width=8,
part_size=(8, 8),
sample_per_part=4,
trans_std=0.1)
return (out)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册