Unverified · Commit 9333a627 authored by Bai Yifan, committed by GitHub


Add flatten op interface and enhance detection APIs to support variable-length images. (#12422)

* add flatten api & enhance detection api

* unify shape_op data type

* update API.spec
Parent f276006f
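For orientation, a minimal usage sketch of the new layer (mirroring the docstring and the unit test added below; the import and variable names are illustrative, not part of the diff):

    import paddle.fluid as fluid

    # Flatten everything after the first dimension: a [4, 4, 3] input
    # becomes a [4, 12] matrix; axis=2 would instead give [16, 3].
    x = fluid.layers.data(
        name="x", shape=[4, 4, 3], dtype="float32", append_batch_size=False)
    out = fluid.layers.flatten(x=x, axis=1)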
@@ -159,6 +159,7 @@ paddle.fluid.layers.relu ArgSpec(args=['x'], varargs=None, keywords=None, defaul
 paddle.fluid.layers.log ArgSpec(args=['x'], varargs=None, keywords=None, defaults=None)
 paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs=None, keywords=None, defaults=(None, None, None))
 paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
+paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
 paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
 paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
 paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
......
@@ -38,7 +38,7 @@ class ShapeOpMaker : public framework::OpProtoAndCheckerMaker {
     AddInput("Input", "(Tensor), The input tensor.");
     AddOutput("Out",
               "(Tensor), The shape of input tensor, the data type of the shape"
-              " is int64_t, will be on the same device with the input Tensor.");
+              " is int32_t, will be on the same device with the input Tensor.");
     AddComment(R"DOC(
 Shape Operator
@@ -53,5 +53,5 @@ Get the shape of input tensor. Only support CPU input Tensor now.
 namespace ops = paddle::operators;
 REGISTER_OPERATOR(shape, ops::ShapeOp, ops::ShapeOpMaker,
                   paddle::framework::EmptyGradOpMaker);
-REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<int>, ops::ShapeKernel<int64_t>,
+REGISTER_OP_CPU_KERNEL(shape, ops::ShapeKernel<int>, ops::ShapeKernel<int32_t>,
                        ops::ShapeKernel<float>, ops::ShapeKernel<double>);
@@ -15,6 +15,6 @@ limitations under the License. */
 #include "paddle/fluid/operators/shape_op.h"
 REGISTER_OP_CUDA_KERNEL(shape, paddle::operators::ShapeKernel<int>,
-                        paddle::operators::ShapeKernel<int64_t>,
+                        paddle::operators::ShapeKernel<int32_t>,
                         paddle::operators::ShapeKernel<float>,
                         paddle::operators::ShapeKernel<double>);
@@ -27,7 +27,7 @@ class ShapeKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto* in_t = ctx.Input<Tensor>("Input");
     auto* out_t = ctx.Output<Tensor>("Out");
-    auto out_data = out_t->mutable_data<int64_t>(platform::CPUPlace());
+    auto out_data = out_t->mutable_data<int32_t>(platform::CPUPlace());
     auto in_dims = in_t->dims();
     for (int i = 0; i < in_dims.size(); ++i) {
       out_data[i] = in_dims[i];
......
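The int64_t-to-int32_t switch above ("unify shape_op data type" in the commit message) keeps the shape op's output compatible with the reshape/slice calls that consume it in detection.py below. A one-line sketch of the resulting usage (assuming the op is exposed as fluid.layers.shape, which is how ops.shape resolves in detection.py; the variable name is illustrative):

    # The shape op now emits the runtime dimensions of its input as an
    # int32 1-D tensor, e.g. [N, Np, C] for a 3-D confidence tensor.
    conf_shape = fluid.layers.shape(confidence)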
@@ -20,7 +20,9 @@ from .layer_function_generator import autodoc, templatedoc
 from ..layer_helper import LayerHelper
 from . import tensor
 from . import nn
+from . import ops
 import math
+import numpy
 from functools import reduce
 __all__ = [
@@ -264,10 +266,11 @@ def detection_output(loc,
         prior_box_var=prior_box_var,
         target_box=loc,
         code_type='decode_center_size')
-    old_shape = scores.shape
-    scores = nn.reshape(x=scores, shape=(-1, old_shape[-1]))
+    compile_shape = scores.shape
+    run_shape = ops.shape(scores)
+    scores = nn.flatten(x=scores, axis=2)
     scores = nn.softmax(input=scores)
-    scores = nn.reshape(x=scores, shape=old_shape)
+    scores = nn.reshape(x=scores, shape=compile_shape, actual_shape=run_shape)
     scores = nn.transpose(scores, perm=[0, 2, 1])
     scores.stop_gradient = True
     nmsed_outs = helper.create_tmp_variable(dtype=decoded_box.dtype)
@@ -677,9 +680,10 @@ def ssd_loss(location,
         raise ValueError("Only support mining_type == max_negative now.")
     num, num_prior, num_class = confidence.shape
+    conf_shape = ops.shape(confidence)
     def __reshape_to_2d(var):
-        return nn.reshape(x=var, shape=[-1, var.shape[-1]])
+        return nn.flatten(x=var, axis=2)
     # 1. Find matched boundding box by prior box.
     #   1.1 Compute IOU similarity between ground-truth boxes and prior boxes.
@@ -690,7 +694,8 @@ def ssd_loss(location,
     # 2. Compute confidence for mining hard examples
     # 2.1. Get the target label based on matched indices
-    gt_label = nn.reshape(x=gt_label, shape=gt_label.shape + (1, ))
+    gt_label = nn.reshape(
+        x=gt_label, shape=(len(gt_label.shape) - 1) * (0, ) + (-1, 1))
     gt_label.stop_gradient = True
     target_label, _ = target_assign(
         gt_label, matched_indices, mismatch_value=background_label)
@@ -701,9 +706,12 @@ def ssd_loss(location,
     target_label = __reshape_to_2d(target_label)
     target_label.stop_gradient = True
     conf_loss = nn.softmax_with_cross_entropy(confidence, target_label)
     # 3. Mining hard examples
-    conf_loss = nn.reshape(x=conf_loss, shape=(num, num_prior))
+    conf_loss = nn.reshape(
+        x=conf_loss,
+        shape=(num, num_prior),
+        actual_shape=ops.slice(
+            conf_shape, axes=[0], starts=[0], ends=[2]))
     conf_loss.stop_gradient = True
     neg_indices = helper.create_tmp_variable(dtype='int32')
     dtype = matched_indices.dtype
@@ -772,7 +780,11 @@ def ssd_loss(location,
     # 5.3 Compute overall weighted loss.
     loss = conf_loss_weight * conf_loss + loc_loss_weight * loc_loss
     # reshape to [N, Np], N is the batch size and Np is the prior box number.
-    loss = nn.reshape(x=loss, shape=[-1, num_prior])
+    loss = nn.reshape(
+        x=loss,
+        shape=(num, num_prior),
+        actual_shape=ops.slice(
+            conf_shape, axes=[0], starts=[0], ends=[2]))
     loss = nn.reduce_sum(loss, dim=1, keep_dim=True)
     if normalize:
         normalizer = nn.reduce_sum(target_loc_weight)
@@ -1005,13 +1017,7 @@ def multi_box_head(inputs,
     """
     def _reshape_with_axis_(input, axis=1):
-        if not (axis > 0 and axis < len(input.shape)):
-            raise ValueError("The axis should be smaller than "
-                             "the arity of input and bigger than 0.")
-        new_shape = [
-            -1, reduce(lambda x, y: x * y, input.shape[axis:len(input.shape)])
-        ]
-        out = nn.reshape(x=input, shape=new_shape)
+        out = nn.flatten(x=input, axis=axis)
         return out
     def _is_list_or_tuple_(data):
@@ -1101,11 +1107,13 @@ def multi_box_head(inputs,
             stride=stride)
         mbox_loc = nn.transpose(mbox_loc, perm=[0, 2, 3, 1])
-        new_shape = [
+        compile_shape = [
             mbox_loc.shape[0],
             mbox_loc.shape[1] * mbox_loc.shape[2] * mbox_loc.shape[3] / 4, 4
         ]
-        mbox_loc_flatten = nn.reshape(mbox_loc, shape=new_shape)
+        run_shape = tensor.assign(numpy.array([0, -1, 4]).astype("int32"))
+        mbox_loc_flatten = nn.reshape(
+            mbox_loc, shape=compile_shape, actual_shape=run_shape)
         mbox_locs.append(mbox_loc_flatten)
         # get conf
@@ -1117,11 +1125,15 @@ def multi_box_head(inputs,
             padding=pad,
             stride=stride)
         conf_loc = nn.transpose(conf_loc, perm=[0, 2, 3, 1])
-        new_shape = [
+        new_shape = [0, -1, num_classes]
+        compile_shape = [
             conf_loc.shape[0], conf_loc.shape[1] * conf_loc.shape[2] *
             conf_loc.shape[3] / num_classes, num_classes
         ]
-        conf_loc_flatten = nn.reshape(conf_loc, shape=new_shape)
+        run_shape = tensor.assign(
+            numpy.array([0, -1, num_classes]).astype("int32"))
+        conf_loc_flatten = nn.reshape(
+            conf_loc, shape=compile_shape, actual_shape=run_shape)
         mbox_confs.append(conf_loc_flatten)
     if len(box_results) == 1:
......
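Taken together, the detection.py edits above apply one pattern to support variable-length images: keep the compile-time shape (which may contain -1 for dimensions unknown until runtime) for static graph construction, fetch the concrete shape at run time with the shape op, and hand both to reshape via actual_shape. A condensed sketch of that pattern, paraphrasing detection_output above (nn and ops are the layer modules imported at the top of the file):

    # Static shape for graph building; concrete values recovered at run time.
    compile_shape = scores.shape
    run_shape = ops.shape(scores)
    # Flatten to 2-D for softmax, then restore the layout using the runtime
    # shape so a variable-sized spatial extent still reshapes correctly.
    scores = nn.flatten(x=scores, axis=2)
    scores = nn.softmax(input=scores)
    scores = nn.reshape(x=scores, shape=compile_shape, actual_shape=run_shape)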
@@ -112,6 +112,7 @@ __all__ = [
     'log',
     'crop',
     'rank_loss',
+    'flatten',
 ]
@@ -5361,3 +5362,70 @@ def rank_loss(label, left, right, name=None):
                 "Right": right},
         outputs={'Out': out})
     return out
+
+
+def flatten(x, axis=1, name=None):
+    """
+    **Flatten layer**
+    Flattens the input tensor into a 2D matrix.
+
+    Examples:
+    Case 1:
+      Given
+        X.shape = (3, 100, 100, 4)
+      and
+        axis = 2
+      We get:
+        Out.shape = (3 * 100, 4 * 100)
+
+    Case 2:
+      Given
+        X.shape = (3, 100, 100, 4)
+      and
+        axis = 0
+      We get:
+        Out.shape = (1, 3 * 100 * 100 * 4)
+
+    Args:
+        x (Variable): A tensor of rank >= axis.
+        axis (int): Indicate up to which input dimensions (exclusive) should
+                    be flattened to the outer dimension of the output.
+                    The value for axis must be in the range [0, R], where R
+                    is the rank of the input tensor. When axis = 0, the shape
+                    of the output tensor is (1, (d_0 X d_1 ... d_n), where the
+                    shape of the input tensor is (d_0, d_1, ... d_n).
+        name(str|None): A name for this layer(optional). If set None, the layer
+                        will be named automatically.
+
+    Returns:
+        Variable: A 2D tensor with the contents of the input tensor, with input
+                  dimensions up to axis flattened to the outer dimension of
+                  the output and remaining input dimensions flattened into the
+                  inner dimension of the output.
+
+    Raises:
+        ValueError: If x is not a variable.
+        ValueError: If axis is not in range [0, rank(x)].
+
+    Examples:
+
+        .. code-block:: python
+
+            x = fluid.layers.data(name="x", shape=[4, 4, 3], dtype="float32")
+            out = fluid.layers.flatten(x=x, axis=2)
+    """
+    helper = LayerHelper('flatten', **locals())
+
+    if not (isinstance(x, Variable)):
+        raise ValueError("The input x should be a Variable")
+
+    if not (isinstance(axis, int)) or axis > len(x.shape) or axis < 0:
+        raise ValueError("The axis should be a int, and in range [0, rank(x)]")
+
+    out = helper.create_tmp_variable(x.dtype)
+    helper.append_op(
+        type='flatten',
+        inputs={"X": x},
+        outputs={'Out': out},
+        attrs={"axis": axis})
+    return out
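As a quick check of the two docstring cases above: with X.shape = (3, 100, 100, 4), axis = 2 keeps the first two dimensions on the outer side, so Out.shape = (3 * 100, 100 * 4) = (300, 400); with axis = 0 everything collapses into the inner dimension, giving Out.shape = (1, 3 * 100 * 100 * 4) = (1, 120000).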
@@ -465,6 +465,17 @@ class TestBook(unittest.TestCase):
         self.assertIsNotNone(out)
         print(str(program))
+
+    def test_flatten(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.data(
+                name='x',
+                append_batch_size=False,
+                shape=[4, 4, 3],
+                dtype="float32")
+            out = layers.flatten(x, axis=1, name="flatten")
+            self.assertIsNotNone(out)
     def test_shape(self):
         program = Program()
         with program_guard(program):
......