提交 040dc59b 编写于 作者: Y Yang Yu

Correctly handle image operators

上级 2b259bf1
...@@ -64,7 +64,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -64,7 +64,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"Input X must have 2 to 5 dimensions."); "Input X must have 2 to 5 dimensions.");
const int C = const int64_t C =
(data_layout == DataLayout::kNCHW ? x_dims[1] (data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]); : x_dims[x_dims.size() - 1]);
...@@ -78,6 +78,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ...@@ -78,6 +78,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("VarianceOut", {C}); ctx->SetOutputDim("VarianceOut", {C});
ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C}); ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
} }
}; };
......
...@@ -44,14 +44,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -44,14 +44,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
paddings.size(), strides.size(), paddings.size(), strides.size(),
"Conv paddings dimension and Conv strides dimension should be the same."); "Conv paddings dimension and Conv strides dimension should be the same.");
int input_channels = in_dims[1]; PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
"The number of input channels should be equal to filter " "The number of input channels should be equal to filter "
"channels * groups."); "channels * groups.");
int output_channels = filter_dims[0];
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
output_channels % groups, 0, filter_dims[0] % groups, 0,
"The number of output channels should be divided by groups."); "The number of output channels should be divided by groups.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]}); std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
...@@ -66,6 +64,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { ...@@ -66,6 +64,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
dilations[i], paddings[i], strides[i])); dilations[i], paddings[i], strides[i]));
} }
ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
} }
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
...@@ -58,6 +58,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { ...@@ -58,6 +58,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
} }
ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("X", "Out");
} }
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册