提交 25a3d2d7 编写于 作者: G gongweibao

fix by comments

上级 e11d4424
...@@ -30,7 +30,8 @@ class BlockExpandOp : public framework::OperatorWithKernel { ...@@ -30,7 +30,8 @@ class BlockExpandOp : public framework::OperatorWithKernel {
"Output of BlockExpandOp op should not be null."); "Output of BlockExpandOp op should not be null.");
auto in_dim = ctx->GetInputDim("X"); auto in_dim = ctx->GetInputDim("X");
PADDLE_ENFORCE_EQ(in_dim.size(), 4, "Input format must be NCHW."); PADDLE_ENFORCE_EQ(in_dim.size(), 4,
"Input(X) format must be 4D tensor, eg., NCHW.");
PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1."); PADDLE_ENFORCE_GE(in_dim[0], 1, "Input batchsize must >= 1.");
int block_height = ctx->Attrs().Get<int>("blockHeight"); int block_height = ctx->Attrs().Get<int>("blockHeight");
......
...@@ -68,13 +68,16 @@ class BlockExpandKernel : public framework::OpKernel<T> { ...@@ -68,13 +68,16 @@ class BlockExpandKernel : public framework::OpKernel<T> {
img_height, img_width, block_height, block_width, stride_height, img_height, img_width, block_height, block_width, stride_height,
stride_width, padding_height, padding_width, outputHeight, outputWidth); stride_width, padding_height, padding_width, outputHeight, outputWidth);
std::vector<int> stride({stride_height, stride_width});
std::vector<int> padding({padding_height, padding_width});
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
Tensor src = in->Slice<T>(i, i + 1).Resize({C, img_height, img_width}); Tensor src = in->Slice(i, i + 1).Resize({C, img_height, img_width});
Tensor dst = out->Slice<T>(i, i + 1).Resize( Tensor dst = out->Slice(i, i + 1).Resize(
{outputHeight, outputWidth, C, block_height, block_width}); {outputHeight, outputWidth, C, block_height, block_width});
math::Im2ColFunctor<math::ColFormat::kOCF, Place, T> f; math::Im2ColFunctor<math::ColFormat::kOCF, Place, T> f;
f(ctx.device_context(), src, dst, stride_height, stride_width, f(ctx.device_context(), src, stride, padding, &dst);
padding_height, padding_width);
} }
} }
}; };
...@@ -112,13 +115,16 @@ class BlockExpandGradKernel : public framework::OpKernel<T> { ...@@ -112,13 +115,16 @@ class BlockExpandGradKernel : public framework::OpKernel<T> {
img_height, img_width, block_height, block_width, stride_height, img_height, img_width, block_height, block_width, stride_height,
stride_width, padding_height, padding_width, outputHeight, outputWidth); stride_width, padding_height, padding_width, outputHeight, outputWidth);
std::vector<int> stride({stride_height, stride_width});
std::vector<int> padding({padding_height, padding_width});
// std::vector<int> stride({stride_height, stride_width});
for (int i = 0; i < N; i++) { for (int i = 0; i < N; i++) {
Tensor dst = d_x->Slice<T>(i, i + 1).Resize({C, img_height, img_width}); Tensor dst = d_x->Slice(i, i + 1).Resize({C, img_height, img_width});
Tensor src = d_out->Slice<T>(i, i + 1).Resize( Tensor src = d_out->Slice(i, i + 1).Resize(
{outputHeight, outputWidth, C, block_height, block_width}); {outputHeight, outputWidth, C, block_height, block_width});
math::Col2ImFunctor<math::ColFormat::kOCF, Place, T> f; math::Col2ImFunctor<math::ColFormat::kOCF, Place, T> f;
f(ctx.device_context(), dst, src, stride_height, stride_width, f(ctx.device_context(), dst, stride, padding, &src);
padding_height, padding_width);
} }
} }
}; };
......
...@@ -4,27 +4,27 @@ from op_test import OpTest ...@@ -4,27 +4,27 @@ from op_test import OpTest
def get_output_shape(attrs, x): def get_output_shape(attrs, x):
imgHeight = x.shape[1] img_height = x.shape[1]
imgWidth = x.shape[2] img_width = x.shape[2]
paddingHeight = attrs['paddingHeight'] padding_height = attrs['paddingHeight']
paddingWidth = attrs['paddingWidth'] padding_width = attrs['paddingWidth']
blockHeight = attrs['blockHeight'] block_height = attrs['blockHeight']
blockWidth = attrs['blockWidth'] block_width = attrs['blockWidth']
strideHeight = attrs['strideHeight'] stride_height = attrs['strideHeight']
strideWidth = attrs['strideWidth'] stride_width = attrs['strideWidth']
outputHeight = \ output_height = \
1 + \ 1 + \
(imgHeight + 2 * paddingHeight - blockHeight + strideHeight - 1) / \ (img_height + 2 * padding_height - block_height + stride_height - 1) / \
strideHeight strideHeight
outputWidth = \ output_width = \
1 + \ 1 + \
(imgWidth + 2 * paddingWidth - blockWidth + strideWidth - 1) / \ (img_width + 2 * padding_width - block_width + stride_width - 1) / \
strideWidth stride_width
return outputHeight, outputWidth return output_height, output_width
def im2col(attrs, im, col): def im2col(attrs, im, col):
...@@ -34,38 +34,39 @@ def im2col(attrs, im, col): ...@@ -34,38 +34,39 @@ def im2col(attrs, im, col):
{outputHeight, outputWidth, inputChannels, filterHeight, filterWidth} {outputHeight, outputWidth, inputChannels, filterHeight, filterWidth}
""" """
input_channels = im.shape[0] input_channels = im.shape[0]
inputHeight = im.shape[1] input_height = im.shape[1]
inputWidth = im.shape[2] input_width = im.shape[2]
outputHeight = col.shape[0] output_height = col.shape[0]
outputWidth = col.shape[1] output_width = col.shape[1]
filterHeight = col.shape[3] filter_height = col.shape[3]
filterWidth = col.shape[4] filter_width = col.shape[4]
strideHeight = attrs['strideHeight'] stride_height = attrs['strideHeight']
strideWidth = attrs['strideWidth'] stride_width = attrs['strideWidth']
paddingHeight = attrs['paddingHeight'] padding_height = attrs['paddingHeight']
paddingWidth = attrs['paddingWidth'] padding_width = attrs['paddingWidth']
for col_row_idx in range(0, outputHeight): for col_row_idx in range(0, output_height):
for col_col_idx in range(0, outputWidth): for col_col_idx in range(0, output_width):
for channel in range(0, input_channels): for channel in range(0, input_channels):
for filter_row_idx in range(0, filterHeight): for filter_row_idx in range(0, filter_height):
for filter_col_idx in range(0, filterWidth): for filter_col_idx in range(0, filter_width):
im_row_offset = col_row_idx * strideHeight \ im_row_offset = col_row_idx * stride_height \
+ filter_row_idx - paddingHeight + filter_row_idx - padding_height
im_col_offset = col_col_idx * strideWidth \ im_col_offset = col_col_idx * stride_width \
+ filter_col_idx - paddingWidth + filter_col_idx - padding_width
if (im_row_offset < 0 or im_row_offset >= inputHeight or if (im_row_offset < 0 or
im_row_offset >= input_height or
im_col_offset < 0 or im_col_offset < 0 or
im_col_offset >= inputWidth): im_col_offset >= input_width):
col[col_row_idx][col_col_idx][channel][\ col[col_row_idx][col_col_idx][channel][\
filter_row_idx][filter_col_idx] = 0.0 filter_row_idx][filter_col_idx] = 0.0
else: else:
im_offset = (channel * inputHeight + im_row_offset \ im_offset = (channel * input_height + im_row_offset \
) * inputWidth + im_col_offset ) * input_width + im_col_offset
col[col_row_idx][col_col_idx][channel][\ col[col_row_idx][col_col_idx][channel][\
filter_row_idx][filter_col_idx] = im[channel][ \ filter_row_idx][filter_col_idx] = im[channel][ \
...@@ -76,55 +77,55 @@ def col2img(attrs, col, img): ...@@ -76,55 +77,55 @@ def col2img(attrs, col, img):
""" """
img: {CHW} img: {CHW}
col: col:
{outputHeight, outputWidth, inputChannels, filterHeight, filterWidth} {output_height, outputWidth, inputChannels, filterHeight, filterWidth}
""" """
input_channels = im.shape[0] input_channels = im.shape[0]
inputHeight = im.shape[1] input_height = im.shape[1]
inputWidth = im.shape[2] input_width = im.shape[2]
outputHeight = col.shape[0] output_height = col.shape[0]
outputWidth = col.shape[1] output_width = col.shape[1]
filterHeight = col.shape[3] filter_height = col.shape[3]
filterWidth = col.shape[4] filter_width = col.shape[4]
strideHeight = attrs['strideHeight'] stride_height = attrs['strideHeight']
strideWidth = attrs['strideWidth'] stride_width = attrs['strideWidth']
paddingHeight = attrs['paddingHeight'] padding_height = attrs['paddingHeight']
paddingWidth = attrs['paddingWidth'] padding_width = attrs['paddingWidth']
for col_row_idx in range(0, outputHeight): for col_row_idx in range(0, output_height):
for col_col_idx in range(0, outputWidth): for col_col_idx in range(0, output_width):
for channel in range(0, input_channels): for channel in range(0, input_channels):
for filter_row_idx in range(0, filterHeight): for filter_row_idx in range(0, filter_height):
for filter_col_idx in range(0, filterWidth): for filter_col_idx in range(0, filter_width):
im_row_offset = \ im_row_offset = \
col_row_idx * strideHeight + filter_row_idx - paddingHeight col_row_idx * stride_height + filter_row_idx - padding_height
im_col_offset = \ im_col_offset = \
col_col_idx * strideWidth + filter_col_idx - paddingWidth col_col_idx * stride_width + filter_col_idx - padding_width
if (im_row_offset >= 0 and if (im_row_offset >= 0 and
im_row_offset < inputHeight and im_row_offset < input_height and
im_col_offset >= 0 and im_col_offset >= 0 and
im_col_offset < inputWidth): im_col_offset < input_width):
im[channel][im_row_offset][im_col_offset] = \ im[channel][im_row_offset][im_col_offset] = \
col[col_row_idx][col_col_idx][channel][filter_row_idx][filter_col_idx] col[col_row_idx][col_col_idx][channel][filter_row_idx][filter_col_idx]
class TestBlockExpandOp(OpTest): def get_input_data(C, H, W):
def get_input_data(self, C, H, W): x = np.random.uniform(0.1, 1, [C, H, W]).astype("float32")
x = np.random.uniform(0.1, 1, [C, H, W]).astype("float32") for c in range(0, C):
for c in range(0, C): for h in range(0, H):
for h in range(0, H): for w in range(0, W):
for w in range(0, W): #x[c][h][w] = c * H * W + h *W + w
#x[c][h][w] = c * H * W + h *W + w x[c][h][w] = 0.2 + 0.01 * (c * H * W + h * W + w)
x[c][h][w] = 0.2 + 0.01 * (c * H * W + h * W + w)
return x return x
class TestBlockExpandOp(OpTest):
def setUp(self): def setUp(self):
C = 3 C = 3
H = 4 H = 4
W = 4 W = 4
x = self.get_input_data(C, H, W) x = get_input_data(C, H, W)
#print x
attrs = { attrs = {
'blockHeight': 2, 'blockHeight': 2,
...@@ -135,9 +136,47 @@ class TestBlockExpandOp(OpTest): ...@@ -135,9 +136,47 @@ class TestBlockExpandOp(OpTest):
'paddingWidth': 1, 'paddingWidth': 1,
} }
outputHeight, outputWidth = get_output_shape(attrs, x) output_height, output_width = get_output_shape(attrs, x)
out = np.random.uniform(0.1, 1,\
[output_height, output_width, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth']]).astype("float32")
self.op_type = "block_expand"
self.inputs = {'X': x.reshape(1, C, H, W)}
self.attrs = attrs
im2col(attrs, x, out)
self.outputs = {
'Out':out.reshape(1, output_height, output_width, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth'])
}
def test_check_output(self):
self.check_output()
def test_check_grad_normal(self):
self.check_grad(['X'], 'Out')
class TestBlockExpandOp2(OpTest):
def setUp(self):
C = 3
H = 4
W = 5
x = get_input_data(C, H, W)
attrs = {
'blockHeight': 2,
'blockWidth': 1,
'strideHeight': 2,
'strideWidth': 1,
'paddingHeight': 2,
'paddingWidth': 1,
}
output_height, output_width = get_output_shape(attrs, x)
out = np.random.uniform(0.1, 1,\ out = np.random.uniform(0.1, 1,\
[outputHeight, outputWidth, x.shape[0], \ [output_height, output_width, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth']]).astype("float32") attrs['blockHeight'], attrs['blockWidth']]).astype("float32")
self.op_type = "block_expand" self.op_type = "block_expand"
...@@ -146,7 +185,7 @@ class TestBlockExpandOp(OpTest): ...@@ -146,7 +185,7 @@ class TestBlockExpandOp(OpTest):
im2col(attrs, x, out) im2col(attrs, x, out)
self.outputs = { self.outputs = {
'Out':out.reshape(1, outputHeight, outputWidth, x.shape[0], \ 'Out':out.reshape(1, output_height, output_width, x.shape[0], \
attrs['blockHeight'], attrs['blockWidth']) attrs['blockHeight'], attrs['blockWidth'])
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册