提交 29145e1e 编写于 作者: L lemon34 提交者: whs

change im2sequence for ctc batch inference (#11696)

* change im2sequence for ctc batch inference

* Update im2sequence_op.cc

* change im2sequence for ctc batch inference

* update

* change PR by comment

* fix ocr test error

* fix test_im2sequence

* modify the old name to standard name

* fix test_layers failed
上级 74fa603c
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/im2sequence_op.h"
#include <string>
#include <vector>
namespace paddle {
......@@ -28,20 +29,19 @@ class Im2SequenceOp : public framework::OperatorWithKernel {
"Input(X) of Im2SequenceOp should not be null.");
PADDLE_ENFORCE(ctx->HasOutput("Out"),
"Output(Out) of Im2SequenceOp op should not be null.");
auto in_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(in_dim.size(), 4,
"Input(X) format must be 4D tensor, eg., NCHW.");
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int batch_size = in_dim[0];
int img_channels = in_dim[1];
int img_height = in_dim[2];
int img_width = in_dim[3];
auto kernels = ctx->Attrs().Get<std::vector<int>>("kernels");
auto strides = ctx->Attrs().Get<std::vector<int>>("strides");
auto paddings = ctx->Attrs().Get<std::vector<int>>("paddings");
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
......@@ -61,6 +61,10 @@ class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
"C: channels"
"H: height"
"W: width");
AddInput("Y",
"(Tensor) The input tensor of image real size(H, W)."
"2-D with shape [batchsize, 2]")
.AsDispensable();
AddOutput("Out", "(LodTensor) The output data of im2sequence op,");
AddAttr<std::vector<int>>("kernels",
"(vector<int>), the "
......@@ -73,6 +77,13 @@ class Im2SequenceOpMaker : public framework::OpProtoAndCheckerMaker {
"(vector<int> default:{0, 0, 0, 0}), the "
"paddings(up_pad, left_pad, down_pad, right_pad)")
.SetDefault({0, 0, 0, 0});
AddAttr<std::vector<int>>("out_stride",
"the attribute is valid only when input(Y)"
"is not NULL.this attribute represents the"
"scaling of the pic through the CNN"
"(vector<int> dedault:{1,1}),the out_stride"
" (out_stride_height, out_stride_width)")
.SetDefault({1, 1});
AddComment(R"DOC(
This op uses kernels to scan images and converts these images to sequences.
After expanding, The number of time steps are output_height * output_width
......@@ -123,7 +134,7 @@ output.data = [[ 6. 2. 8. 3. 2. 4. 6. 3.]
[ 7. 1. 7. 9. 2. 1. 3. 5.]
[ 5. 7. 2. 4. 1. 3. 9. 0.]
[ 7. 9. 4. 8. 3. 5. 0. 8.]]
output.dims = {8, 9}
output.dims = {8, 8}
output.lod = [[0, 4, 8]]
)DOC");
......
......@@ -13,6 +13,7 @@
limitations under the License. */
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/eigen.h"
......@@ -39,50 +40,106 @@ class Im2SequenceKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
const Tensor* in = ctx.Input<Tensor>("X");
LoDTensor* out = ctx.Output<LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
// TODO(wanghaoshuang): Add layout checker after 'set_layout'
// being available for python API
// PADDLE_ENFORCE_EQ(in->layout(), framework::DataLayout::kNCHW,
// "Input(X) layout must be NCHW");
auto in_dim = in->dims();
int batch_size = in_dim[0];
int img_channels = in_dim[1];
int img_height = in_dim[2];
int img_width = in_dim[3];
auto kernels = ctx.Attr<std::vector<int>>("kernels");
auto strides = ctx.Attr<std::vector<int>>("strides");
auto paddings = ctx.Attr<std::vector<int>>("paddings");
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]);
const std::vector<int> dilations({1, 1});
auto out_dims = out->dims();
out->Resize({batch_size, out->numel() / batch_size});
for (int i = 0; i < batch_size; i++) {
const Tensor src =
in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
Tensor dst = out->Slice(i, i + 1).Resize(
{output_height, output_width, img_channels, kernels[0], kernels[1]});
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
f(dev_ctx, src, dilations, strides, paddings, &dst);
}
out->Resize(out_dims);
// set lod information
// TODO(wanghaoshuang): Move this to InferShape
framework::LoD lod(1);
lod[0].reserve(batch_size + 1);
for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
if (ctx.HasInput("Y") && batch_size > 1) {
const Tensor* imgrealsize = ctx.Input<Tensor>("Y");
auto out_stride = ctx.Attr<std::vector<int>>("out_stride");
Tensor cpu_shape_tensor;
TensorCopySync(*imgrealsize, platform::CPUPlace(), &cpu_shape_tensor);
std::vector<int> imgreal_h;
std::vector<int> imgreal_w;
std::vector<int> output_height;
std::vector<int> output_width;
int result = 0;
for (int i = 0; i < batch_size; i++) {
int tmp_real_h = static_cast<int>((cpu_shape_tensor.data<T>())[2 * i]);
int tmp_real_w =
static_cast<int>((cpu_shape_tensor.data<T>())[2 * i + 1]);
if (tmp_real_h % out_stride[0] == 0) {
tmp_real_h = tmp_real_h / out_stride[0];
} else {
tmp_real_h = tmp_real_h / out_stride[0] + 1;
}
if (tmp_real_w % out_stride[1] == 0) {
tmp_real_w = tmp_real_w / out_stride[1];
} else {
tmp_real_w = tmp_real_w / out_stride[1] + 1;
}
imgreal_h.push_back(tmp_real_h);
imgreal_w.push_back(tmp_real_w);
output_height.push_back(Im2SeqOutputSize(
imgreal_h[i], kernels[0], paddings[0], paddings[2], strides[0]));
output_width.push_back(Im2SeqOutputSize(
imgreal_w[i], kernels[1], paddings[1], paddings[3], strides[1]));
result += output_height[i] * output_width[i];
}
out->mutable_data<T>({result, img_channels * kernels[0] * kernels[1]},
ctx.GetPlace());
const std::vector<int> dilations({1, 1});
int offset_out = 0;
for (int i = 0; i < batch_size; i++) {
const Tensor src =
in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
Tensor dst = out->Slice(offset_out,
offset_out + output_height[i] * output_width[i])
.Resize({output_height[i], output_width[i],
img_channels, kernels[0], kernels[1]});
offset_out += output_height[i] * output_width[i];
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
f(dev_ctx, src, dilations, strides, paddings, &dst);
}
framework::LoD lod(1);
lod[0].reserve(batch_size + 1);
int offset = 0;
lod[0].push_back(offset);
for (int i = 0; i < batch_size; ++i) {
offset += output_height[i] * output_width[i];
lod[0].push_back(offset);
}
out->set_lod(lod);
} else {
out->mutable_data<T>(ctx.GetPlace());
int output_height = Im2SeqOutputSize(img_height, kernels[0], paddings[0],
paddings[2], strides[0]);
int output_width = Im2SeqOutputSize(img_width, kernels[1], paddings[1],
paddings[3], strides[1]);
const std::vector<int> dilations({1, 1});
auto out_dims = out->dims();
out->Resize({batch_size, out->numel() / batch_size});
for (int i = 0; i < batch_size; i++) {
const Tensor src =
in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
Tensor dst =
out->Slice(i, i + 1).Resize({output_height, output_width,
img_channels, kernels[0], kernels[1]});
math::Im2ColFunctor<math::ColFormat::kOCF, DeviceContext, T> f;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
f(dev_ctx, src, dilations, strides, paddings, &dst);
}
out->Resize(out_dims);
framework::LoD lod(1);
lod[0].reserve(batch_size + 1);
int offset = 0;
lod[0].push_back(offset);
offset += output_height * output_width;
for (int i = 0; i < batch_size; ++i) {
offset += output_height * output_width;
lod[0].push_back(offset);
}
out->set_lod(lod);
}
out->set_lod(lod);
}
};
......
......@@ -43,21 +43,6 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int col_height = col->dims()[3];
int col_width = col->dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
((dilation[0] * (filter_height - 1) + 1))) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
((dilation[1] * (filter_width - 1) + 1))) /
stride[1] +
1,
col_width,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
int channels_col = im_channels * filter_height * filter_width;
const T* im_data = im.data<T>();
......@@ -178,17 +163,6 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int col_height = col->dims()[0];
int col_width = col->dims()[1];
PADDLE_ENFORCE_EQ(
(im_height + padding[0] + padding[2] - filter_height) / stride[0] + 1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ(
(im_width + padding[1] + padding[3] - filter_width) / stride[1] + 1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
const T* im_data = im.data<T>();
T* col_data = col->data<T>();
......
......@@ -77,21 +77,6 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kCFO,
int col_height = col->dims()[3];
int col_width = col->dims()[4];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(dilation[1] * (filter_width - 1) + 1)) /
stride[1] +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int num_outputs = im_channels * col_height * col_width;
int blocks = (num_outputs + 1024 - 1) / 1024;
int block_x = 512;
......@@ -274,21 +259,6 @@ class Im2ColFunctor<paddle::operators::math::ColFormat::kOCF,
int col_height = col->dims()[0];
int col_width = col->dims()[1];
PADDLE_ENFORCE_EQ((im_height + padding[0] + padding[2] -
(dilation[0] * (filter_height - 1) + 1)) /
stride[0] +
1,
col_height,
"Output_height and padding(padding_up, padding_down) are "
"inconsistent.");
PADDLE_ENFORCE_EQ((im_width + padding[1] + padding[3] -
(dilation[1] * (filter_width - 1) + 1)) /
stride[1] +
1,
col_width,
"col_width and padding(padding_left, padding_right) are "
"inconsistent.");
int block_dim_x = 0;
int block_dim_y = 0;
if (filter_height <= 4 && filter_width <= 4) {
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copyright (c ) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
......@@ -3900,7 +3914,13 @@ def transpose(x, perm, name=None):
return out
def im2sequence(input, filter_size=1, stride=1, padding=0, name=None):
def im2sequence(input,
filter_size=1,
stride=1,
padding=0,
input_image_size=None,
out_stride=1,
name=None):
"""
Extracts image patches from the input tensor to form a tensor of shape
{input.batch_size * output_height * output_width, filter_size_H *
......@@ -3937,6 +3957,15 @@ def im2sequence(input, filter_size=1, stride=1, padding=0, name=None):
padding_up = padding_down = padding_left = padding_right = padding
Default: padding = 0.
input_image_size(Variable): the input contains image real size.It's dim
is [batchsize, 2]. It is dispensable.It is just for batch inference.
out_stride(int|tuple): The scaling of image through CNN. It is
dispensable. It is valid only when input_image_size is not null.
If out_stride is tuple, it must contain two intergers,
(out_stride_H, out_stride_W). Otherwise,
the out_stride_H = out_stride_W = out_stride.
name (int): The name of this layer. It is optional.
Returns:
......@@ -3987,7 +4016,7 @@ def im2sequence(input, filter_size=1, stride=1, padding=0, name=None):
[ 5. 7. 2. 4. 1. 3. 9. 0.]
[ 7. 9. 4. 8. 3. 5. 0. 8.]]
output.dims = {8, 9}
output.dims = {8, 8}
output.lod = [[4, 4]]
......@@ -4009,18 +4038,17 @@ def im2sequence(input, filter_size=1, stride=1, padding=0, name=None):
if len(padding) == 2:
padding.append(padding[0])
padding.append(padding[1])
inputs = {"X": input}
attrs = {"kernels": filter_size, "strides": stride, "padding": padding}
if input_image_size:
if isinstance(out_stride, int):
out_stride = [out_stride, out_stride]
inputs["Y"] = input_image_size
attrs["out_stride"] = out_stride
helper = LayerHelper('im2sequence', **locals())
out = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='im2sequence',
inputs={'X': input},
outputs={'Out': out},
attrs={
'kernels': filter_size,
'strides': stride,
'paddings': padding,
})
type='im2sequence', inputs=inputs, outputs={'Out': out}, attrs=attrs)
return out
......
......@@ -16,23 +16,48 @@ import numpy as np
from op_test import OpTest
def get_output_shape(attrs, in_shape):
def get_output_shape(attrs, in_shape, img_real_size):
batchsize = in_shape[0]
img_height = in_shape[2]
img_width = in_shape[3]
paddings = np.array(attrs['paddings']).astype("int32")
kernels = np.array(attrs['kernels']).astype("int32")
strides = np.array(attrs['strides']).astype("int32")
output_height = np.zeros((1, batchsize)).astype("int32")
output_width = np.zeros((1, batchsize)).astype("int32")
if len(img_real_size):
out_stride = np.array(attrs['out_stride']).astype("int32")
imgreal_h = 0
imgreal_w = 0
for index in range(batchsize):
if img_real_size[index, 0] % out_stride[0] == 0:
imgreal_h = img_real_size[index, 0] / out_stride[0]
else:
imgreal_h = img_real_size[index, 0] / out_stride[0] + 1
if img_real_size[index, 0] % out_stride[1] == 0:
imgreal_w = img_real_size[index, 1] / out_stride[1]
else:
imgreal_w = img_real_size[index, 0] / out_stride[1] + 1
output_height[0,index] = \
1 + \
(imgreal_h + paddings[0] + paddings[2] - kernels[0] + strides[0] - 1) / \
strides[0]
paddings = attrs['paddings']
kernels = attrs['kernels']
strides = attrs['strides']
output_width[0,index] = \
1 + \
(imgreal_w + paddings[1] + paddings[3] - kernels[1] + strides[1] - 1) / \
strides[1]
else:
for index in range(batchsize):
output_height[0,index] = \
1 + \
(img_height + paddings[0] + paddings[2] - kernels[0] + strides[0] - 1) / \
strides[0]
output_height = \
1 + \
(img_height + paddings[0] + paddings[2] - kernels[0] + strides[0] - 1) / \
strides[0]
output_width = \
1 + \
(img_width + paddings[1] + paddings[3] - kernels[1] + strides[1] - 1) / \
strides[1]
output_width[0,index] = \
1 + \
(img_width + paddings[1] + paddings[3] - kernels[1] + strides[1] - 1) / \
strides[1]
return output_height, output_width
......@@ -75,22 +100,25 @@ def im2col(attrs, im, col):
im_row_offset][im_col_offset]
def Im2Sequence(inputs, attrs):
output_height, output_width = get_output_shape(attrs, inputs.shape)
def Im2Sequence(inputs, img_real_size, attrs):
output_height, output_width = get_output_shape(attrs, inputs.shape,
img_real_size)
img_channels = inputs.shape[1]
batch_size = inputs.shape[0]
out = np.zeros([
batch_size, output_height, output_width, img_channels,
attrs['kernels'][0], attrs['kernels'][1]
]).astype("float32")
for i in range(len(inputs)):
im2col(attrs, inputs[i], out[i])
out = out.reshape([
batch_size * output_height * output_width,
img_channels * attrs['kernels'][0] * attrs['kernels'][1]
])
out = []
for index in range(batch_size):
tmp = np.zeros([
output_height[0, index], output_width[0, index], img_channels,
attrs['kernels'][0], attrs['kernels'][1]
]).astype("float32")
out.append(tmp)
for index in range(len(inputs)):
im2col(attrs, inputs[index], out[index])
out[index] = out[index].reshape([
output_height[0, index] * output_width[0, index],
img_channels * attrs['kernels'][0] * attrs['kernels'][1]
])
out = np.concatenate(out, axis=0)
return out
......@@ -103,7 +131,7 @@ class TestBlockExpandOp(OpTest):
self.attrs = {
'kernels': [2, 2],
'strides': [1, 1],
'paddings': [1, 1, 1, 1]
'paddings': [1, 1, 1, 1],
}
def setUp(self):
......@@ -113,7 +141,8 @@ class TestBlockExpandOp(OpTest):
self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32")
out = Im2Sequence(x, self.attrs)
real_size = np.array([]).astype("float32")
out = Im2Sequence(x, real_size, self.attrs)
self.inputs = {'X': x}
self.outputs = {'Out': out}
......@@ -133,20 +162,20 @@ class TestBlockExpandOpCase2(TestBlockExpandOp):
self.attrs = {
'kernels': [2, 1],
'strides': [2, 1],
'paddings': [2, 1, 2, 1]
'paddings': [2, 1, 2, 1],
}
class TestBlockExpandOpCase3(TestBlockExpandOp):
def config(self):
self.batch_size = 3
self.batch_size = 2
self.img_channels = 1
self.img_height = 4
self.img_width = 5
self.attrs = {
'kernels': [2, 1],
'strides': [2, 1],
'paddings': [2, 0, 2, 0]
'paddings': [2, 0, 2, 0],
}
......@@ -159,9 +188,94 @@ class TestBlockExpandOpCase4(TestBlockExpandOp):
self.attrs = {
'kernels': [2, 2],
'strides': [1, 1],
'paddings': [0, 0, 0, 0]
'paddings': [0, 0, 0, 0],
}
class TestBlockExpandOpCase5(OpTest):
def config(self):
self.batch_size = 1
self.img_channels = 3
self.img_height = 4
self.img_width = 5
self.attrs = {
'kernels': [2, 1],
'strides': [2, 1],
'paddings': [2, 1, 2, 1],
'out_stride': [2, 2],
}
def setUp(self):
self.config()
self.op_type = "im2sequence"
x = np.random.uniform(0.1, 1, [
self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32")
real_size = np.array([[8, 10], [5, 8]]).astype("float32")
out = np.array(Im2Sequence(x, real_size, self.attrs))
self.inputs = {'X': x, 'Y': real_size} #l ??
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
class TestBlockExpandOpCase6(OpTest):
def config(self):
self.batch_size = 3
self.img_channels = 1
self.img_height = 4
self.img_width = 5
self.attrs = {
'kernels': [2, 1],
'strides': [1, 1],
'paddings': [0, 0, 0, 0],
'out_stride': [1, 1],
}
def setUp(self):
self.config()
self.op_type = "im2sequence"
x = np.random.uniform(0.1, 1, [
self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32")
real_size = np.array([[8, 10], [5, 8], [5, 8]]).astype("float32")
out = np.array(Im2Sequence(x, real_size, self.attrs))
self.inputs = {'X': x, 'Y': real_size} #l ??
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
class TestBlockExpandOpCase7(OpTest):
def config(self):
self.batch_size = 2
self.img_channels = 2
self.img_height = 3
self.img_width = 3
self.attrs = {
'kernels': [2, 2],
'strides': [1, 1],
'paddings': [1, 0, 1, 0],
'out_stride': [2, 2],
}
def setUp(self):
self.config()
self.op_type = "im2sequence"
x = np.random.uniform(0.1, 1, [
self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32")
real_size = np.array([[6, 6], [4, 4]]).astype("float32")
out = np.array(Im2Sequence(x, real_size, self.attrs))
self.inputs = {'X': x, 'Y': real_size}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
if __name__ == '__main__':
unittest.main()
#set shiftwidth=4 set expandtab set tabstop=4
......@@ -251,12 +251,16 @@ class TestBook(unittest.TestCase):
print(str(program))
def test_im2sequence(self):
print("test_im2sequence")
program = Program()
with program_guard(program):
x = layers.data(name='x', shape=[3, 128, 128], dtype='float32')
y = layers.data(name='y', shape=[], dtype='float32')
output = layers.im2sequence(
input=x, stride=[1, 1], filter_size=[2, 2])
input=x,
input_image_size=y,
stride=[1, 1],
filter_size=[2, 2],
out_stride=[1, 1])
self.assertIsNotNone(output)
print(str(program))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册