提交 5ece5c96 编写于 作者: Q qijun

add python wrap for sppLayer

上级 b282caf4
...@@ -46,6 +46,12 @@ conv_operator ...@@ -46,6 +46,12 @@ conv_operator
:members: conv_operator :members: conv_operator
:noindex: :noindex:
conv_projection
-------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: conv_projection
:noindex:
conv_shift_layer conv_shift_layer
------------------ ------------------
.. automodule:: paddle.trainer_config_helpers.layers .. automodule:: paddle.trainer_config_helpers.layers
...@@ -73,6 +79,12 @@ img_pool_layer ...@@ -73,6 +79,12 @@ img_pool_layer
:members: img_pool_layer :members: img_pool_layer
:noindex: :noindex:
spp_layer
--------------
.. automodule:: paddle.trainer_config_helpers.layers
:members: spp_layer
:noindex:
maxout_layer maxout_layer
------------ ------------
.. automodule:: paddle.trainer_config_helpers.layers .. automodule:: paddle.trainer_config_helpers.layers
......
...@@ -16,7 +16,7 @@ limitations under the License. */ ...@@ -16,7 +16,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
REGISTER_PROJECTION_CREATE_FUNC(pool2, &PoolProjection::create); REGISTER_PROJECTION_CREATE_FUNC(pool, &PoolProjection::create);
PoolProjection* PoolProjection::create(const ProjectionConfig& config, PoolProjection* PoolProjection::create(const ProjectionConfig& config,
ParameterPtr parameter, bool useGpu) { ParameterPtr parameter, bool useGpu) {
......
...@@ -24,7 +24,7 @@ ProjectionConfig SpatialPyramidPoolLayer::getConfig(size_t imgSizeW, ...@@ -24,7 +24,7 @@ ProjectionConfig SpatialPyramidPoolLayer::getConfig(size_t imgSizeW,
size_t pyramidLevel, size_t pyramidLevel,
std::string& poolType) { std::string& poolType) {
ProjectionConfig config; ProjectionConfig config;
config.set_type("pool2"); config.set_type("pool");
PoolConfig* conf = config.mutable_pool_conf(); PoolConfig* conf = config.mutable_pool_conf();
conf->set_channels(channels); conf->set_channels(channels);
conf->set_img_size(imgSizeW); conf->set_img_size(imgSizeW);
...@@ -93,7 +93,7 @@ bool SpatialPyramidPoolLayer::init(const LayerMap& layerMap, ...@@ -93,7 +93,7 @@ bool SpatialPyramidPoolLayer::init(const LayerMap& layerMap,
startCol = endCol; startCol = endCol;
projInput_.emplace_back(Argument()); projInput_.emplace_back(Argument());
} }
outputSize_ = endCol; CHECK_EQ(endCol, getSize());
return true; return true;
} }
...@@ -101,7 +101,7 @@ void SpatialPyramidPoolLayer::forward(PassType passType) { ...@@ -101,7 +101,7 @@ void SpatialPyramidPoolLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
int batchSize = getInput(0).getBatchSize(); int batchSize = getInput(0).getBatchSize();
resetOutput(batchSize, outputSize_); resetOutput(batchSize, getSize());
for (size_t i = 0; i < pyramidHeight_; i++) { for (size_t i = 0; i < pyramidHeight_; i++) {
size_t startCol = projCol_[i].first; size_t startCol = projCol_[i].first;
size_t endCol = projCol_[i].second; size_t endCol = projCol_[i].second;
......
...@@ -27,7 +27,6 @@ protected: ...@@ -27,7 +27,6 @@ protected:
size_t imgSizeW_; size_t imgSizeW_;
size_t imgSizeH_; size_t imgSizeH_;
size_t pyramidHeight_; size_t pyramidHeight_;
size_t outputSize_;
std::string poolType_; std::string poolType_;
std::vector<std::unique_ptr<PoolProjection>> poolProjections_; std::vector<std::unique_ptr<PoolProjection>> poolProjections_;
......
...@@ -931,6 +931,8 @@ void testSppLayer(const string& poolType, const int pyramidHeight, bool trans, ...@@ -931,6 +931,8 @@ void testSppLayer(const string& poolType, const int pyramidHeight, bool trans,
sppConfig->set_channels(16); sppConfig->set_channels(16);
sppConfig->set_img_size(10); sppConfig->set_img_size(10);
sppConfig->set_img_size_y(20); sppConfig->set_img_size_y(20);
int outputSize = (std::pow(4, sppConfig->pyramid_height()) - 1) / (4 - 1);
config.layerConfig.set_size(outputSize * sppConfig->channels());
testLayerGrad(config, "spp", 100, trans, useGpu); testLayerGrad(config, "spp", 100, trans, useGpu);
} }
......
...@@ -1510,18 +1510,19 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH, ...@@ -1510,18 +1510,19 @@ void CpuMatrix::maxPoolForward(Matrix& inputMat, size_t imgSizeH,
CHECK(inHeight * inWidth == inputMat.getWidth() / channels); CHECK(inHeight * inWidth == inputMat.getWidth() / channels);
CHECK_EQ(num, this->getHeight()); CHECK_EQ(num, this->getHeight());
CHECK_EQ(channels * outputH * outputW, this->getWidth()); CHECK_EQ(channels * outputH * outputW, this->getWidth());
size_t outStride = getStride();
/* initialize the data_ */ /* initialize the data_ */
for (size_t i = 0; i < height_; i++) { for (size_t i = 0; i < height_; i++) {
for (size_t j = 0; j < width_; j++) { for (size_t j = 0; j < width_; j++) {
outData[i * getStride() + j] = -(real)FLT_MAX; outData[i * outStride + j] = -(real)FLT_MAX;
} }
} }
/* pool max one by one */ /* pool max one by one */
for (size_t n = 0; n < num; ++n) { // frame by frame for (size_t n = 0; n < num; ++n) { // frame by frame
if (!isContiguous()) { if (!isContiguous()) {
outData = data_ + n * getStride(); outData = data_ + n * outStride;
} }
for (size_t c = 0; c < channels; ++c) { // channel by channel for (size_t c = 0; c < channels; ++c) { // channel by channel
for (size_t ph = 0; ph < outputH; ++ph) { for (size_t ph = 0; ph < outputH; ++ph) {
...@@ -1564,10 +1565,15 @@ void CpuMatrix::maxPoolBackward(Matrix& image, size_t imgSizeH, size_t imgSizeW, ...@@ -1564,10 +1565,15 @@ void CpuMatrix::maxPoolBackward(Matrix& image, size_t imgSizeH, size_t imgSizeW,
real* inData = image.getData(); real* inData = image.getData();
real* otData = outV.getData(); real* otData = outV.getData();
real* otGrad = outGrad.getData(); real* otGrad = outGrad.getData();
size_t outStride = outV.getStride();
real* origOutData = otData;
real* origOutGrad = otGrad;
for (size_t n = 0; n < num; ++n) { for (size_t n = 0; n < num; ++n) {
if (!outV.isContiguous()) { if (!outV.isContiguous()) {
otData = outV.getData() + n * outV.getStride(); otData = origOutData + n * outStride;
otGrad = outGrad.getData() + n * outGrad.getStride(); otGrad = origOutGrad + n * outStride;
} }
for (size_t c = 0; c < channels; ++c) { for (size_t c = 0; c < channels; ++c) {
for (size_t ph = 0; ph < outputH; ++ph) { for (size_t ph = 0; ph < outputH; ++ph) {
......
...@@ -202,11 +202,11 @@ message ProjectionConfig { ...@@ -202,11 +202,11 @@ message ProjectionConfig {
optional ConvConfig conv_conf = 8; optional ConvConfig conv_conf = 8;
optional int32 num_filters = 9; optional int32 num_filters = 9;
// For pool
optional PoolConfig pool_conf = 10;
// For IdentityOffsetProjection // For IdentityOffsetProjection
optional uint64 offset = 11 [default = 0]; optional uint64 offset = 11 [default = 0];
// For pool
optional PoolConfig pool_conf = 12;
} }
message OperatorConfig { message OperatorConfig {
......
...@@ -470,6 +470,7 @@ class Input(Cfg): ...@@ -470,6 +470,7 @@ class Input(Cfg):
image=None, image=None,
block_expand=None, block_expand=None,
maxout=None, maxout=None,
spp=None,
format=None, format=None,
nnz=None, nnz=None,
is_static=None, is_static=None,
...@@ -669,7 +670,6 @@ class ConvProjection(Projection): ...@@ -669,7 +670,6 @@ class ConvProjection(Projection):
def calc_parameter_dims(self, input_size, output_size): def calc_parameter_dims(self, input_size, output_size):
return None return None
# Define a operator for mixed layer # Define a operator for mixed layer
@config_class @config_class
class Operator(Cfg): class Operator(Cfg):
...@@ -783,6 +783,15 @@ class Pool(Cfg): ...@@ -783,6 +783,15 @@ class Pool(Cfg):
padding_y = None): padding_y = None):
self.add_keys(locals()) self.add_keys(locals())
class SpatialPyramidPool(Cfg):
def __init__(
self,
pool_type,
pyramid_height,
channels,
img_width = None):
self.add_keys(locals())
# please refer to the comments in proto/ModelConfig.proto # please refer to the comments in proto/ModelConfig.proto
@config_class @config_class
class Norm(Cfg): class Norm(Cfg):
...@@ -1043,6 +1052,22 @@ def parse_pool(pool, input_layer_name, pool_conf): ...@@ -1043,6 +1052,22 @@ def parse_pool(pool, input_layer_name, pool_conf):
2*pool_conf.padding_y - pool_conf.size_y) / \ 2*pool_conf.padding_y - pool_conf.size_y) / \
float(pool_conf.stride_y))) + 1 float(pool_conf.stride_y))) + 1
def parse_spp(spp, input_layer_name, spp_conf):
spp_conf.pool_type = spp.pool_type
config_assert(spp.pool_type in ['max-projection', 'avg-projection'],
"pool-type %s is not in " "['max-projection', 'avg-projection']"
% spp.pool_type)
spp_conf.pyramid_height = spp.pyramid_height
spp_conf.channels = spp.channels
img_pixels = g_layer_map[input_layer_name].size / spp_conf.channels
spp_conf.img_size = default(spp.img_width, int(img_pixels ** 0.5))
spp_conf.img_size_y = img_pixels / spp_conf.img_size
config_assert(spp_conf.img_size * spp_conf.img_size_y == img_pixels,
"Incorrect input image size %d for input image pixels %d"
% (spp_conf.img_size, img_pixels))
def parse_image(image, input_layer_name, image_conf): def parse_image(image, input_layer_name, image_conf):
image_conf.channels = image.channels image_conf.channels = image.channels
image_pixels = g_layer_map[input_layer_name].size / image_conf.channels image_pixels = g_layer_map[input_layer_name].size / image_conf.channels
...@@ -1649,6 +1674,25 @@ class PoolLayer(LayerBase): ...@@ -1649,6 +1674,25 @@ class PoolLayer(LayerBase):
name, pool_conf.output_y, pool_conf.output_x)) name, pool_conf.output_y, pool_conf.output_x))
self.set_layer_size((pool_conf.output_x * pool_conf.output_y) * pool_conf.channels) self.set_layer_size((pool_conf.output_x * pool_conf.output_y) * pool_conf.channels)
@config_layer('spp')
class SpatialPyramidPoolLayer(LayerBase):
def __init__(
self,
name,
inputs,
device=None):
super(SpatialPyramidPoolLayer, self).__init__(name, 'spp', 0, inputs=inputs, device=device)
for input_index in xrange(len(self.inputs)):
input_layer = self.get_input_layer(input_index)
parse_spp(
self.inputs[input_index].spp,
input_layer.name,
self.config.inputs[input_index].spp_conf)
spp_conf = self.config.inputs[input_index].spp_conf
output_size = (pow(4, spp_conf.pyramid_height) - 1) / (4 - 1)
print("output size for %s is %d " % (name, output_size))
self.set_layer_size(output_size * spp_conf.channels)
@config_layer('batch_norm') @config_layer('batch_norm')
class BatchNormLayer(LayerBase): class BatchNormLayer(LayerBase):
layer_type = 'batch_norm' layer_type = 'batch_norm'
......
...@@ -55,7 +55,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel", ...@@ -55,7 +55,8 @@ __all__ = ["full_matrix_projection", "AggregateLevel", "ExpandLevel",
'multi_binary_label_cross_entropy', 'multi_binary_label_cross_entropy',
'rank_cost', 'lambda_cost', 'huber_cost', 'rank_cost', 'lambda_cost', 'huber_cost',
'block_expand_layer', 'block_expand_layer',
'maxout_layer', 'out_prod_layer', 'print_layer' 'maxout_layer', 'out_prod_layer', 'print_layer',
'spp_layer',
] ]
...@@ -111,6 +112,7 @@ class LayerType(object): ...@@ -111,6 +112,7 @@ class LayerType(object):
LINEAR_COMBINATION_LAYER = "convex_comb" LINEAR_COMBINATION_LAYER = "convex_comb"
BLOCK_EXPAND = "blockexpand" BLOCK_EXPAND = "blockexpand"
MAXOUT = "maxout" MAXOUT = "maxout"
SPP_LAYER = "spp"
PRINT_LAYER = "print" PRINT_LAYER = "print"
...@@ -868,6 +870,7 @@ def pooling_layer(input, pooling_type=None, name=None, bias_attr=None, ...@@ -868,6 +870,7 @@ def pooling_layer(input, pooling_type=None, name=None, bias_attr=None,
size=input.size) size=input.size)
@wrap_bias_attr_default() @wrap_bias_attr_default()
@wrap_param_attr_default() @wrap_param_attr_default()
@wrap_act_default(param_names=['gate_act'], @wrap_act_default(param_names=['gate_act'],
...@@ -1708,6 +1711,62 @@ def img_pool_layer(input, pool_size, name=None, ...@@ -1708,6 +1711,62 @@ def img_pool_layer(input, pool_size, name=None,
num_filters=num_channels) num_filters=num_channels)
@wrap_name_default("spp")
@layer_support()
def spp_layer(input, name=None, num_channels=None, pool_type=None,
pyramid_height=None, img_width=None, layer_attr=None):
pass
"""
Spatial Pyramid Pooling in Deep Convolutional Networks for Visual Recognition.
The details please refer to
`Kaiming He's paper <https://arxiv.org/abs/1406.4729>`_.
:param name: layer name.
:type name: basestring
:param input: layer's input.
:type input: LayerOutput
:param num_channels: number of input channel.
:type num_channels: int
:param pool_type: Pooling type. MaxPooling or AveragePooling. Default is MaxPooling.
:type scale: BasePoolingType
:param pyramid_height: pyramid height.
:type pyramid_height: int
:param img_width: the width of input feature map. If it is None, the input feature
map should be square.
:type img_width: int|None
:param layer_attr: Extra Layer Attribute.
:type layer_attr: ExtraLayerAttribute
:return: LayerOutput object.
:rtype: LayerOutput
"""
if num_channels is None:
assert input.num_filters is not None
num_channels = input.num_filters
if pool_type is None:
pool_type = MaxPooling()
elif isinstance(pool_type, AvgPooling):
pool_type.name = 'avg'
type_name = pool_type.name
if (isinstance(pool_type, AvgPooling) or isinstance(pool_type, MaxPooling)):
type_name += '-projection'
Layer(
name=name,
type=LayerType.SPP_LAYER,
inputs=Input(input.name,
spp=SpatialPyramidPool(pool_type=type_name,
channels=num_channels,
pyramid_height=pyramid_height,
img_width=img_width)
),
**ExtraLayerAttribute.to_kwargs(layer_attr)
)
return LayerOutput(name, LayerType.SPP_LAYER, parents=[input],
num_filters=num_channels)
def __img_norm_layer__(name, input, size, norm_type, scale, power, def __img_norm_layer__(name, input, size, norm_type, scale, power,
num_channels, blocked, layer_attr): num_channels, blocked, layer_attr):
if num_channels is None: if num_channels is None:
......
...@@ -20,3 +20,4 @@ fded24727338fb8ce44d9951ed8aea08 test_rnn_group.protostr ...@@ -20,3 +20,4 @@ fded24727338fb8ce44d9951ed8aea08 test_rnn_group.protostr
67d6fde3afb54f389d0ce4ff14726fe1 test_sequence_pooling.protostr 67d6fde3afb54f389d0ce4ff14726fe1 test_sequence_pooling.protostr
f586a548ef4350ba1ed47a81859a64cb unused_layers.protostr f586a548ef4350ba1ed47a81859a64cb unused_layers.protostr
f937a5a6e7e8864b4d8cf56b0f7c7f44 util_layers.protostr f937a5a6e7e8864b4d8cf56b0f7c7f44 util_layers.protostr
60c9a71e19bd4b2a1253712799d0ae70 test_spp_layer.protostr
...@@ -9,7 +9,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer ...@@ -9,7 +9,7 @@ test_sequence_pooling test_lstmemory_layer test_grumemory_layer
last_first_seq test_expand_layer test_ntm_layers test_hsigmoid last_first_seq test_expand_layer test_ntm_layers test_hsigmoid
img_layers util_layers simple_rnn_layers unused_layers test_cost_layers img_layers util_layers simple_rnn_layers unused_layers test_cost_layers
test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight test_rnn_group shared_fc shared_lstm test_cost_layers_with_weight
test_maxout test_bi_grumemory) test_maxout test_bi_grumemory test_spp_layer)
for conf in ${configs[*]} for conf in ${configs[*]}
......
from paddle.trainer_config_helpers import *
settings(
batch_size=100,
learning_rate=1e-5
)
data = data_layer(name='data', size=3200)
spp = spp_layer(input=data,
pyramid_height=2,
num_channels=16,
pool_type=MaxPooling(),
img_width=10)
outputs(spp)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册