未验证 提交 a8b39968 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #7219 from reyoung/feature/correctly_handle_lod_information_for_image_operators

Correctly handle lod information of image operators
......@@ -64,7 +64,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5,
"Input X must have 2 to 5 dimensions.");
const int C =
const int64_t C =
(data_layout == DataLayout::kNCHW ? x_dims[1]
: x_dims[x_dims.size() - 1]);
......@@ -78,6 +78,7 @@ class BatchNormOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("VarianceOut", {C});
ctx->SetOutputDim("SavedMean", {C});
ctx->SetOutputDim("SavedVariance", {C});
ctx->ShareLoD("X", "Y");
}
};
......
......@@ -44,14 +44,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
paddings.size(), strides.size(),
"Conv paddings dimension and Conv strides dimension should be the same.");
int input_channels = in_dims[1];
PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups,
PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups,
"The number of input channels should be equal to filter "
"channels * groups.");
int output_channels = filter_dims[0];
PADDLE_ENFORCE_EQ(
output_channels % groups, 0,
filter_dims[0] % groups, 0,
"The number of output channels should be divided by groups.");
std::vector<int64_t> output_shape({in_dims[0], filter_dims[0]});
......@@ -66,6 +64,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const {
dilations[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Output", framework::make_ddim(output_shape));
ctx->ShareLoD("Input", "Output");
}
Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......
......@@ -58,6 +58,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const {
OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i]));
}
ctx->SetOutputDim("Out", framework::make_ddim(output_shape));
ctx->ShareLoD("X", "Out");
}
void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册