提交 e82f1008 编写于 作者: W wanghaoshuang

Finish block expand op

1. Add lod to output
2. Fix im2col arguments list
3. Refine code and doc
4. Fix output shape
上级 25a3d2d7
......@@ -32,37 +32,27 @@ class BlockExpandOp : public framework::OperatorWithKernel {
auto in_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(in_dim.size(), 4,
"Input(X) format must be 4D tensor, eg., NCHW.");
PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1.");
int block_height = ctx->Attrs().Get<int>("blockHeight");
int block_width = ctx->Attrs().Get<int>("blockWidth");
int stride_height = ctx->Attrs().Get<int>("strideHeight");
int stride_width = ctx->Attrs().Get<int>("strideWidth");
int padding_height = ctx->Attrs().Get<int>("paddingHeight");
int padding_width = ctx->Attrs().Get<int>("paddingWidth");
int block_height = ctx->Attrs().Get<int>("block_height");
int block_width = ctx->Attrs().Get<int>("block_width");
int stride_height = ctx->Attrs().Get<int>("stride_height");
int stride_width = ctx->Attrs().Get<int>("stride_width");
int padding_height = ctx->Attrs().Get<int>("padding_height");
int padding_width = ctx->Attrs().Get<int>("padding_width");
int N = in_dim[0];
int C = in_dim[1];
int batch_size = in_dim[0];
int img_channels = in_dim[1];
int img_height = in_dim[2];
int img_width = in_dim[3];
int output_height = 0;
int output_width = 0;
int output_height = get_output_size(img_height, block_height, stride_height,
padding_height);
int output_width =
get_output_size(img_width, block_width, stride_width, padding_width);
get_blockexpand_output_shape(img_height, img_width, block_height,
block_width, stride_height, stride_width,
padding_height, padding_width, output_height,
output_width);
// The result of im2col is [output_height, output_width,
// inputChannels, filterHeight, filterWidth], and it is easy to
// reshape into [seqLength, stepSize], where seqLength is equal
// output_height * output_width, stepSize is equal
// input_channels * blockHeight * blockWidth
ctx->SetOutputDim(
"Out", {N, output_height, output_width, C, block_height, block_width});
// ctx->ShareLoD("X", /*->*/ "Out");
ctx->SetOutputDim("Out", {batch_size * output_height * output_width,
img_channels * block_height * block_width});
// TODO(wanghaoshuang): cal lod in complie time
}
};
......@@ -79,28 +69,69 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
W: width
)DOC");
AddOutput("Out", "(LodTensor)The output data of block_expand op,");
AddAttr<int>("blockHeight", "(int)height of block.");
AddAttr<int>("blockWidth", "(int)width of block.");
AddAttr<int>("strideHeight", "(int)height of stride.");
AddAttr<int>("strideWidth", "(int)width of stride.");
AddAttr<int>("paddingHeight", "(int)height of padding.");
AddAttr<int>("paddingWidth", "(int)width of padding.");
AddAttr<int>("block_height", "(int)height of block.");
AddAttr<int>("block_width", "(int)width of block.");
AddAttr<int>("stride_height", "(int)height of stride.");
AddAttr<int>("stride_width", "(int)width of stride.");
AddAttr<int>("padding_height", "(int)height of padding.");
AddAttr<int>("padding_width", "(int)width of padding.");
AddComment(R"DOC(
Expand feature map to minibatch matrix.
- matirx height is: output_height * output_width
- matrix width is: blockHeight * blockWidth * channels
output_height =
1 + (2 * paddingHeight + img_height - blockHeight + strideHeight - 1) /
strideHeight;
output_width =
1 + (2 * paddingWidth + img_width - blockWidth + strideWidth - 1) /
strideWidth;
The expand method is the same with ExpandConvLayer, but saved the transposed
value. After expanding, The number of time steps are output_height * output_width
and the dimension of each time step is blockHeight * blockWidth * channels.
This layer can be used after convolution neural network, and before recurrent neural network.
- matrix width is: block_height * block_width * channels
output_height =
1 + (2 * padding_height + img_height - block_height + stride_height - 1) /
stride_height;
output_width =
1 + (2 * padding_width + img_width - block_width + stride_width - 1) /
stride_width;
After expanding, The number of time steps are output_height * output_width
and the dimension of each time step is block_height * block_width * channels.
This op can be used after convolution neural network, and before recurrent neural network.
Given:
x = [[[[ 6. 2. 1.]
[ 8. 3. 5.]
[ 0. 2. 6.]]
[[ 2. 4. 4.]
[ 6. 3. 0.]
[ 6. 4. 7.]]]
[[[ 6. 7. 1.]
[ 5. 7. 9.]
[ 2. 4. 8.]]
[[ 1. 2. 1.]
[ 1. 3. 5.]
[ 9. 0. 8.]]]]
x.dims = {2, 2, 3, 3}
And:
block_height = 2
block_width = 2
stride_height = 1
stride_width = 1
padding_height = 0
padding_width = 0
Then:
output.data = [[ 6. 2. 8. 3. 2. 4. 6. 3.]
[ 2. 1. 3. 5. 4. 4. 3. 0.]
[ 8. 3. 0. 2. 6. 3. 6. 4.]
[ 3. 5. 2. 6. 3. 0. 4. 7.]
[ 6. 7. 5. 7. 1. 2. 1. 3.]
[ 7. 1. 7. 9. 2. 1. 3. 5.]
[ 5. 7. 2. 4. 1. 3. 9. 0.]
[ 7. 9. 4. 8. 3. 5. 0. 8.]]
output.dims = {8, 9}
output.lod = [[0, 4, 8]]
)DOC");
}
};
......
......@@ -23,20 +23,9 @@
namespace paddle {
namespace operators {
inline void get_blockexpand_output_shape(int img_height, int img_width,
int block_height, int block_width,
int stride_height, int stride_width,
int padding_height, int padding_width,
int& outputHeight, int& outputWidth) {
outputHeight =
1 +
(img_height + 2 * padding_height - block_height + stride_height - 1) /
stride_height;
outputWidth =
1 +
(img_width + 2 * padding_width - block_width + stride_width - 1) /
stride_width;
inline int get_output_size(int img_size, int block_size, int stride,
int padding) {
return (1 + (img_size + 2 * padding - block_size + stride - 1) / stride);
}
template <typename Place, typename T>
......@@ -45,40 +34,54 @@ class BlockExpandKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
using namespace framework;
const Tensor* in = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out");
LoDTensor* out = ctx.Output<LoDTensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
auto in_dim = in->dims();
int N = in_dim[0];
int C = in_dim[1];
int batch_size = in_dim[0];
int img_channels = in_dim[1];
int img_height = in_dim[2];
int img_width = in_dim[3];
int block_height = ctx.Attr<int>("blockHeight");
int block_width = ctx.Attr<int>("blockWidth");
int stride_height = ctx.Attr<int>("strideHeight");
int stride_width = ctx.Attr<int>("strideWidth");
int padding_height = ctx.Attr<int>("paddingHeight");
int padding_width = ctx.Attr<int>("paddingWidth");
int outputHeight = 0;
int outputWidth = 0;
get_blockexpand_output_shape(
img_height, img_width, block_height, block_width, stride_height,
stride_width, padding_height, padding_width, outputHeight, outputWidth);
std::vector<int> stride({stride_height, stride_width});
std::vector<int> padding({padding_height, padding_width});
for (int i = 0; i < N; i++) {
Tensor src = in->Slice(i, i + 1).Resize({C, img_height, img_width});
Tensor dst = out->Slice(i, i + 1).Resize(
{outputHeight, outputWidth, C, block_height, block_width});
int block_height = ctx.Attr<int>("block_height");
int block_width = ctx.Attr<int>("block_width");
int stride_height = ctx.Attr<int>("stride_height");
int stride_width = ctx.Attr<int>("stride_width");
int padding_height = ctx.Attr<int>("padding_height");
int padding_width = ctx.Attr<int>("padding_width");
int output_height = get_output_size(img_height, block_height, stride_height,
padding_height);
int output_width =
get_output_size(img_width, block_width, stride_width, padding_width);
const std::vector<int> dilations({1, 1});
const std::vector<int> strides(
{stride_height, stride_width, stride_height, stride_width});
const std::vector<int> paddings(
{padding_height, padding_width, padding_height, padding_width});
auto out_dims = out->dims();
out->Resize({batch_size, out->numel() / batch_size});
for (int i = 0; i < batch_size; i++) {
const Tensor src =
in->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
Tensor dst = out->Slice(i, i + 1).Resize({output_height, output_width,
img_channels, block_height,
block_width});
math::Im2ColFunctor<math::ColFormat::kOCF, Place, T> f;
f(ctx.device_context(), src, stride, padding, &dst);
f(ctx.device_context(), src, dilations, strides, paddings, &dst);
}
out->Resize(out_dims);
// set lod information
// TODO(wanghaoshuang): Move this to InferShape
framework::LoD lod(1);
for (int i = 0, offset = 0; i < batch_size + 1; ++i) {
lod[0].push_back(offset);
offset += output_height * output_width;
}
out->set_lod(lod);
}
};
......@@ -88,7 +91,8 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& ctx) const override {
using namespace framework;
auto* in = ctx.Input<Tensor>("X");
auto* d_out = ctx.Input<Tensor>(framework::GradVarName("Out"));
Tensor* d_out =
const_cast<Tensor*>(ctx.Input<Tensor>(framework::GradVarName("Out")));
auto* d_x = ctx.Output<Tensor>(GradVarName("X"));
d_x->mutable_data<T>(ctx.GetPlace());
......@@ -96,36 +100,40 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
x_v.device(ctx.GetEigenDevice<Place>()) = x_v.constant(0.0);
auto in_dim = in->dims();
int N = in_dim[0];
int C = in_dim[1];
int batch_size = in_dim[0];
int img_channels = in_dim[1];
int img_height = in_dim[2];
int img_width = in_dim[3];
int block_height = ctx.Attr<int>("blockHeight");
int block_width = ctx.Attr<int>("blockWidth");
int stride_height = ctx.Attr<int>("strideHeight");
int stride_width = ctx.Attr<int>("strideWidth");
int padding_height = ctx.Attr<int>("paddingHeight");
int padding_width = ctx.Attr<int>("paddingWidth");
int outputHeight = 0;
int outputWidth = 0;
get_blockexpand_output_shape(
img_height, img_width, block_height, block_width, stride_height,
stride_width, padding_height, padding_width, outputHeight, outputWidth);
std::vector<int> stride({stride_height, stride_width});
std::vector<int> padding({padding_height, padding_width});
// std::vector<int> stride({stride_height, stride_width});
for (int i = 0; i < N; i++) {
Tensor dst = d_x->Slice(i, i + 1).Resize({C, img_height, img_width});
Tensor src = d_out->Slice(i, i + 1).Resize(
{outputHeight, outputWidth, C, block_height, block_width});
int block_height = ctx.Attr<int>("block_height");
int block_width = ctx.Attr<int>("block_width");
int stride_height = ctx.Attr<int>("stride_height");
int stride_width = ctx.Attr<int>("stride_width");
int padding_height = ctx.Attr<int>("padding_height");
int padding_width = ctx.Attr<int>("padding_width");
int output_height = get_output_size(img_height, block_height, stride_height,
padding_height);
int output_width =
get_output_size(img_width, block_width, stride_width, padding_width);
const std::vector<int> dilations({1, 1});
const std::vector<int> strides(
{stride_height, stride_width, stride_height, stride_width});
const std::vector<int> paddings(
{padding_height, padding_width, padding_height, padding_width});
auto d_out_dims = d_out->dims();
d_out->Resize({batch_size, d_out->numel() / batch_size});
for (int i = 0; i < batch_size; i++) {
Tensor dst =
d_x->Slice(i, i + 1).Resize({img_channels, img_height, img_width});
const Tensor src = d_out->Slice(i, i + 1).Resize(
{output_height, output_width, img_channels, block_height,
block_width});
math::Col2ImFunctor<math::ColFormat::kOCF, Place, T> f;
f(ctx.device_context(), dst, stride, padding, &src);
f(ctx.device_context(), src, dilations, strides, paddings, &dst);
}
d_out->Resize(d_out_dims);
}
};
......
......@@ -4,20 +4,20 @@ from op_test import OpTest
def get_output_shape(attrs, x):
img_height = x.shape[1]
img_width = x.shape[2]
img_height = x.shape[2]
img_width = x.shape[3]
padding_height = attrs['paddingHeight']
padding_width = attrs['paddingWidth']
block_height = attrs['blockHeight']
block_width = attrs['blockWidth']
stride_height = attrs['strideHeight']
stride_width = attrs['strideWidth']
padding_height = attrs['padding_height']
padding_width = attrs['padding_width']
block_height = attrs['block_height']
block_width = attrs['block_width']
stride_height = attrs['stride_height']
stride_width = attrs['stride_width']
output_height = \
1 + \
(img_height + 2 * padding_height - block_height + stride_height - 1) / \
strideHeight
stride_height
output_width = \
1 + \
......@@ -42,10 +42,10 @@ def im2col(attrs, im, col):
filter_height = col.shape[3]
filter_width = col.shape[4]
stride_height = attrs['strideHeight']
stride_width = attrs['strideWidth']
padding_height = attrs['paddingHeight']
padding_width = attrs['paddingWidth']
stride_height = attrs['stride_height']
stride_width = attrs['stride_width']
padding_height = attrs['padding_height']
padding_width = attrs['padding_width']
for col_row_idx in range(0, output_height):
for col_col_idx in range(0, output_width):
......@@ -73,83 +73,51 @@ def im2col(attrs, im, col):
im_row_offset][im_col_offset]
def col2img(attrs, col, img):
"""
img: {CHW}
col:
{output_height, outputWidth, inputChannels, filterHeight, filterWidth}
"""
input_channels = im.shape[0]
input_height = im.shape[1]
input_width = im.shape[2]
output_height = col.shape[0]
output_width = col.shape[1]
filter_height = col.shape[3]
filter_width = col.shape[4]
def block_expand(inputs, attrs):
output_height, output_width = get_output_shape(attrs, inputs)
img_channels = inputs.shape[1]
batch_size = inputs.shape[0]
out = np.zeros([
batch_size, output_height, output_width, img_channels,
attrs['block_height'], attrs['block_width']
]).astype("float32")
stride_height = attrs['strideHeight']
stride_width = attrs['strideWidth']
padding_height = attrs['paddingHeight']
padding_width = attrs['paddingWidth']
for i in range(len(inputs)):
im2col(attrs, inputs[i], out[i])
for col_row_idx in range(0, output_height):
for col_col_idx in range(0, output_width):
for channel in range(0, input_channels):
for filter_row_idx in range(0, filter_height):
for filter_col_idx in range(0, filter_width):
im_row_offset = \
col_row_idx * stride_height + filter_row_idx - padding_height
im_col_offset = \
col_col_idx * stride_width + filter_col_idx - padding_width
if (im_row_offset >= 0 and
im_row_offset < input_height and
im_col_offset >= 0 and
im_col_offset < input_width):
im[channel][im_row_offset][im_col_offset] = \
col[col_row_idx][col_col_idx][channel][filter_row_idx][filter_col_idx]
def get_input_data(C, H, W):
x = np.random.uniform(0.1, 1, [C, H, W]).astype("float32")
for c in range(0, C):
for h in range(0, H):
for w in range(0, W):
#x[c][h][w] = c * H * W + h *W + w
x[c][h][w] = 0.2 + 0.01 * (c * H * W + h * W + w)
return x
out = out.reshape([
batch_size * output_height * output_width,
img_channels * attrs['block_height'] * attrs['block_width']
])
return out
class TestBlockExpandOp(OpTest):
def setUp(self):
C = 3
H = 4
W = 4
x = get_input_data(C, H, W)
attrs = {
'blockHeight': 2,
'blockWidth': 2,
'strideHeight': 1,
'strideWidth': 1,
'paddingHeight': 1,
'paddingWidth': 1,
def config(self):
self.batch_size = 1
self.img_channels = 3
self.img_height = 4
self.img_width = 4
self.attrs = {
'block_height': 2,
'block_width': 2,
'stride_height': 1,
'stride_width': 1,
'padding_height': 1,
'padding_width': 1,
}
output_height, output_width = get_output_shape(attrs, x)
out = np.random.uniform(0.1, 1,\
[output_height, output_width, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth']]).astype("float32")
def setUp(self):
self.config()
self.op_type = "block_expand"
self.inputs = {'X': x.reshape(1, C, H, W)}
self.attrs = attrs
#x = np.random.uniform(0.1, 1,
x = np.random.randint(0, 10, [
self.batch_size, self.img_channels, self.img_height, self.img_width
]).astype("float32")
im2col(attrs, x, out)
self.outputs = {
'Out':out.reshape(1, output_height, output_width, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth'])
}
out = block_expand(x, self.attrs)
self.inputs = {'X': x}
self.outputs = {'Out': out}
def test_check_output(self):
self.check_output()
......@@ -158,42 +126,52 @@ class TestBlockExpandOp(OpTest):
self.check_grad(['X'], 'Out')
class TestBlockExpandOp2(OpTest):
def setUp(self):
C = 3
H = 4
W = 5
x = get_input_data(C, H, W)
attrs = {
'blockHeight': 2,
'blockWidth': 1,
'strideHeight': 2,
'strideWidth': 1,
'paddingHeight': 2,
'paddingWidth': 1,
class TestBlockExpandOpCase2(TestBlockExpandOp):
def config(self):
self.batch_size = 2
self.img_channels = 3
self.img_height = 4
self.img_width = 5
self.attrs = {
'block_height': 2,
'block_width': 1,
'stride_height': 2,
'stride_width': 1,
'padding_height': 2,
'padding_width': 1,
}
output_height, output_width = get_output_shape(attrs, x)
out = np.random.uniform(0.1, 1,\
[output_height, output_width, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth']]).astype("float32")
self.op_type = "block_expand"
self.inputs = {'X': x.reshape(1, C, H, W)}
self.attrs = attrs
im2col(attrs, x, out)
self.outputs = {
'Out':out.reshape(1, output_height, output_width, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth'])
}
class TestBlockExpandOpCase3(TestBlockExpandOp):
def config(self):
self.batch_size = 3
self.img_channels = 1
self.img_height = 4
self.img_width = 5
self.attrs = {
'block_height': 2,
'block_width': 1,
'stride_height': 2,
'stride_width': 1,
'padding_height': 2,
'padding_width': 0,
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
class TestBlockExpandOpCase4(TestBlockExpandOp):
def config(self):
self.batch_size = 2
self.img_channels = 2
self.img_height = 3
self.img_width = 3
self.attrs = {
'block_height': 2,
'block_width': 2,
'stride_height': 1,
'stride_width': 1,
'padding_height': 0,
'padding_width': 0,
}
if __name__ == '__main__':
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册