提交 32db8db5 编写于 作者: G gongweibao

fix bugs

上级 45f16c90
...@@ -23,6 +23,7 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -23,6 +23,7 @@ class BlockExpandOp : public framework::OperatorWithKernel {
protected: protected:
void InferShape(framework::InferShapeContext* ctx) const override { void InferShape(framework::InferShapeContext* ctx) const override {
printf("op infershape\n");
using namespace framework; using namespace framework;
PADDLE_ENFORCE(ctx->HasInput("X"), PADDLE_ENFORCE(ctx->HasInput("X"),
"Input of BlockExpandOp should not be null."); "Input of BlockExpandOp should not be null.");
...@@ -33,6 +34,7 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -33,6 +34,7 @@ class BlockExpandOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input format must be NCHW."); PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input format must be NCHW.");
PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1."); PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1.");
printf("op infershape2\n");
int block_height = ctx->Attrs().Get<int>("blockHeight"); int block_height = ctx->Attrs().Get<int>("blockHeight");
int block_width = ctx->Attrs().Get<int>("blockWidth"); int block_width = ctx->Attrs().Get<int>("blockWidth");
int stride_height = ctx->Attrs().Get<int>("strideHeight"); int stride_height = ctx->Attrs().Get<int>("strideHeight");
...@@ -42,8 +44,8 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -42,8 +44,8 @@ class BlockExpandOp : public framework::OperatorWithKernel {
int N = in_dim[0]; int N = in_dim[0];
int C = in_dim[1]; int C = in_dim[1];
int img_height = in_dim[3]; int img_height = in_dim[2];
int img_width = in_dim[4]; int img_width = in_dim[3];
int output_height = 0; int output_height = 0;
int output_width = 0; int output_width = 0;
...@@ -58,6 +60,8 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -58,6 +60,8 @@ class BlockExpandOp : public framework::OperatorWithKernel {
// reshape into [seqLength, stepSize], where seqLength is equal // reshape into [seqLength, stepSize], where seqLength is equal
// output_height * output_width, stepSize is equal // output_height * output_width, stepSize is equal
// input_channels * blockHeight * blockWidth // input_channels * blockHeight * blockWidth
printf("N:%d, o_h:%d o_w:%d C:%d b_h:%d b_w:%d\n", N, output_height,
output_width, C, block_height, block_width);
ctx->SetOutputDim( ctx->SetOutputDim(
"Out", {N, output_height, output_width, C, block_height, block_width}); "Out", {N, output_height, output_width, C, block_height, block_width});
...@@ -77,6 +81,7 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -77,6 +81,7 @@ class BlockExpandOpMaker : public framework::OpProtoAndCheckerMaker {
H: height H: height
W: width W: width
)DOC"); )DOC");
printf("opmakeer\n");
AddOutput("Out", "(LodTensor)The output data of block_expand op,"); AddOutput("Out", "(LodTensor)The output data of block_expand op,");
AddAttr<int>("blockHeight", "(int)height of block."); AddAttr<int>("blockHeight", "(int)height of block.");
AddAttr<int>("blockWidth", "(int)width of block."); AddAttr<int>("blockWidth", "(int)width of block.");
......
...@@ -44,7 +44,7 @@ class BlockExpandKernel : public framework::OpKernel<T> { ...@@ -44,7 +44,7 @@ class BlockExpandKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
using namespace framework; using namespace framework;
const Tensor* in = ctx.Input<Tensor>("input"); const Tensor* in = ctx.Input<Tensor>("X");
Tensor* out = ctx.Output<Tensor>("Out"); Tensor* out = ctx.Output<Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace()); out->mutable_data<T>(ctx.GetPlace());
...@@ -68,7 +68,11 @@ class BlockExpandKernel : public framework::OpKernel<T> { ...@@ -68,7 +68,11 @@ class BlockExpandKernel : public framework::OpKernel<T> {
img_height, img_width, block_height, block_width, stride_height, img_height, img_width, block_height, block_width, stride_height,
stride_width, padding_height, padding_width, outputHeight, outputWidth); stride_width, padding_height, padding_width, outputHeight, outputWidth);
printf("N:%d, o_h:%d o_w:%d C:%d b_h:%d b_w:%d\n", N, outputHeight,
outputWidth, C, block_height, block_width);
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
printf("i:%d\n", i);
Tensor src = in->Slice<T>(i, i + 1).Resize({C, img_height, img_width}); Tensor src = in->Slice<T>(i, i + 1).Resize({C, img_height, img_width});
Tensor dst = out->Slice<T>(i, i + 1).Resize( Tensor dst = out->Slice<T>(i, i + 1).Resize(
{outputHeight, outputWidth, C, block_height, block_width}); {outputHeight, outputWidth, C, block_height, block_width});
...@@ -109,6 +113,9 @@ class BlockExpandGradKernel : public framework::OpKernel<T> { ...@@ -109,6 +113,9 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
img_height, img_width, block_height, block_width, stride_height, img_height, img_width, block_height, block_width, stride_height,
stride_width, padding_height, padding_width, outputHeight, outputWidth); stride_width, padding_height, padding_width, outputHeight, outputWidth);
printf("N:%d, o_h:%d o_w:%d C:%d b_h:%d b_w:%d\n", N, outputHeight,
outputWidth, C, block_height, block_width);
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
Tensor dst = Tensor dst =
out_grad->Slice<T>(i, i + 1).Resize({C, img_height, img_width}); out_grad->Slice<T>(i, i + 1).Resize({C, img_height, img_width});
......
...@@ -3,119 +3,153 @@ import numpy as np ...@@ -3,119 +3,153 @@ import numpy as np
from op_test import OpTest from op_test import OpTest
def get_output_shape(attrs, X): def get_output_shape(attrs, x):
img_height = X.shape[2] imgHeight = x.shape[1]
img_width = X.shpe[3] imgWidth = x.shape[2]
padding_height = attrs['padding_height']
padding_width = attrs['padding_width'] paddingHeight = attrs['paddingHeight']
block_height = attrs['block_height'] paddingWidth = attrs['paddingWidth']
block_width = attrs['block_width'] blockHeight = attrs['blockHeight']
stride_height = attrs['stride_height'] blockWidth = attrs['blockWidth']
stride_width = attrs['stride_width'] strideHeight = attrs['strideHeight']
output_height = \ strideWidth = attrs['strideWidth']
outputHeight = \
1 + \ 1 + \
(img_height + 2 * padding_height - block_height + stride_height - 1) / \ (imgHeight + 2 * paddingHeight - blockHeight + strideHeight - 1) / \
stride_height strideHeight
output_width = \ outputWidth = \
1 + \ 1 + \
(img_width + 2 * padding_width - block_width + stride_width - 1) / \ (imgWidth + 2 * paddingWidth - blockWidth + strideWidth - 1) / \
stride_width strideWidth
return output_height, output_width return outputHeight, outputWidth
""" """
img: {CHW} im: {CHW}
col: col:
{output_height, output_width, inputChannels, filterHeight, filterWidth} {outputHeight, outputWidth, inputChannels, filterHeight, filterWidth}
""" """
def img2col(attrs, im, col): def im2col(attrs, im, col):
input_channels = im.shape.dims[0] input_channels = im.shape[0]
input_height = im.shape.dims[1] inputHeight = im.shape[1]
input_width = im.shape.dims[2] inputWidth = im.shape[2]
filter_height = col.shape.dims[3]
filter_width = col.shape.dims[4] outputHeight = col.shape[0]
output_height = col.shape.dims[0] outputWidth = col.shape[1]
output_width = col.shape.dims[1] filterHeight = col.shape[3]
filterWidth = col.shape[4]
for col_row_idx in range(0, output_height): strideHeight = attrs['strideHeight']
for col_col_idx in range(0, output_width): strideWidth = attrs['strideWidth']
paddingHeight = attrs['paddingHeight']
paddingWidth = attrs['paddingWidth']
for col_row_idx in range(0, outputHeight):
for col_col_idx in range(0, outputWidth):
for channel in range(0, input_channels): for channel in range(0, input_channels):
for filter_row_idx in range(0, filter_height): for filter_row_idx in range(0, filterHeight):
for filter_col_idx in range(0, filter_width): for filter_col_idx in range(0, filterWidth):
im_row_offset = col_row_idx * stride_height \ im_row_offset = col_row_idx * strideHeight \
+ filter_row_idx - padding_height + filter_row_idx - paddingHeight
im_col_offset = col_col_idx * stride_width \
+ filter_col_idx - padding_width im_col_offset = col_col_idx * strideWidth \
if (im_row_offset < 0 or + filter_col_idx - paddingWidth
im_row_offset >= input_height or
if (im_row_offset < 0 or im_row_offset >= inputHeight or
im_col_offset < 0 or im_col_offset < 0 or
im_col_offset >= input_width): im_col_offset >= inputWidth):
col[col_row_idx][col_col_idx][channel][ col[col_row_idx][col_col_idx][channel][\
filter_row_idx][filter_col_idx] = 0.0 filter_row_idx][filter_col_idx] = 0.0
else: else:
im_offset = (channel * input_height + im_row_offset im_offset = (channel * inputHeight + im_row_offset \
) * input_width + im_col_offset ) * inputWidth + im_col_offset
col[col_row_idx][col_col_idx][channel][
filter_row_idx][filter_col_idx] = im[channel][ col[col_row_idx][col_col_idx][channel][\
filter_row_idx][filter_col_idx] = im[channel][ \
im_row_offset][im_col_offset] im_row_offset][im_col_offset]
""" """
img: {CHW} img: {CHW}
col: col:
{output_height, output_width, inputChannels, filterHeight, filterWidth} {outputHeight, outputWidth, inputChannels, filterHeight, filterWidth}
""" """
def col2img(attrs, col, img): def col2img(attrs, col, img):
input_channels = im.shape.dims[0] input_channels = im.shape[0]
input_height = im.shape.dims[1] inputHeight = im.shape[1]
input_width = im.shape.dims[2] inputWidth = im.shape[2]
filter_height = col.shape.dims[3]
filter_width = col.shape.dims[4] outputHeight = col.shape[0]
output_height = col.shape.dims[0] outputWidth = col.shape[1]
output_width = col.shape.dims[1] filterHeight = col.shape[3]
filterWidth = col.shape[4]
for col_row_idx in range(0, output_height):
for col_col_idx in range(0, output_width): strideHeight = attrs['strideHeight']
strideWidth = attrs['strideWidth']
paddingHeight = attrs['paddingHeight']
paddingWidth = attrs['paddingWidth']
for col_row_idx in range(0, outputHeight):
for col_col_idx in range(0, outputWidth):
for channel in range(0, input_channels): for channel in range(0, input_channels):
for filter_row_idx in range(0, filter_height): for filter_row_idx in range(0, filterHeight):
for filter_col_idx in range(0, filter_width): for filter_col_idx in range(0, filterWidth):
im_row_offset = \ im_row_offset = \
col_row_idx * stride_height + filter_row_idx - padding_height col_row_idx * strideHeight + filter_row_idx - paddingHeight
im_col_offset = \ im_col_offset = \
col_col_idx * stride_width + filter_col_idx - padding_width col_col_idx * strideWidth + filter_col_idx - paddingWidth
if (im_row_offset >= 0 and if (im_row_offset >= 0 and
im_row_offset < input_height and im_row_offset < inputHeight and
im_col_offset >= 0 and im_col_offset >= 0 and
im_col_offset < input_width): im_col_offset < inputWidth):
im[channel][im_row_offset][im_col_offset] = \ im[channel][im_row_offset][im_col_offset] = \
col[col_row_idx][col_col_idx][channel][filter_row_idx][filter_col_idx] col[col_row_idx][col_col_idx][channel][filter_row_idx][filter_col_idx]
class TestBlockExpandMulOp(OpTest): class TestBlockExpandMulOp(OpTest):
def setUp(self): def setUp(self):
self.op_type = "block_expand" x = np.random.uniform(0.1, 1, [3, 9, 9]).astype("float32")
self.inputs = { attrs = {
'X': np.random.uniform(0.1, 1, [2, 3, 9, 9]).astype("float64"), 'blockHeight': 3,
} 'blockWidth': 3,
self.attrs = { 'strideHeight': 2,
'block_height': 3, 'strideWidth': 2,
'block_width': 3, 'paddingHeight': 3,
'stride_height': 2, 'paddingWidth': 3,
'stride_width': 2,
'padding_height': 3,
'padding_width': 3,
} }
self.outputs = {'Out': np.multiply(self.inputs['X'], self.inputs['Y'])} outputHeight, outputWidth = get_output_shape(attrs, x)
out = np.random.uniform(0.1, 1,\
[outputHeight, outputWidth, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth']]).astype("float32")
self.op_type = "block_expand"
self.inputs = {'X': x.reshape(1, 3, 9, 9)}
self.attrs = attrs
im2col(attrs, x, out)
self.outputs = {
'Out':out.reshape(1, outputHeight, outputWidth, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth'])
}
#print out
def test_check_output(self): def test_check_output(self):
self.check_output() self.check_output()
print 1
"""
def test_check_grad_normal(self): def test_check_grad_normal(self):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
"""
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册