diff --git a/paddle/operators/batch_norm_op.cc b/paddle/operators/batch_norm_op.cc index 98db28ddee7c6cdb37fe7732649d4fc38de7b873..dd7b038b00813b192177c05dc06aa165a60b5156 100644 --- a/paddle/operators/batch_norm_op.cc +++ b/paddle/operators/batch_norm_op.cc @@ -64,7 +64,7 @@ class BatchNormOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, "Input X must have 2 to 5 dimensions."); - const int C = + const int64_t C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); @@ -78,6 +78,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ctx->SetOutputDim("VarianceOut", {C}); ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedVariance", {C}); + ctx->ShareLoD("X", "Y"); } }; diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index e65a5dce52c3c51d3d6bee1684c1e97230203d38..ad84524e1785e3b6b4586c83001852b6dba7afe8 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -44,14 +44,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { paddings.size(), strides.size(), "Conv paddings dimension and Conv strides dimension should be the same."); - int input_channels = in_dims[1]; - PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, + PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, "The number of input channels should be equal to filter " "channels * groups."); - int output_channels = filter_dims[0]; PADDLE_ENFORCE_EQ( - output_channels % groups, 0, + filter_dims[0] % groups, 0, "The number of output channels should be divided by groups."); std::vector output_shape({in_dims[0], filter_dims[0]}); @@ -66,6 +64,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { dilations[i], paddings[i], strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); + ctx->ShareLoD("Input", "Output"); } Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index 50057eb6483e9c9e745bc07dee26a0bbbbb5a48c..d3cf5fa638c53dfdfacec153211f447a1e2fa3bf 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -58,6 +58,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->ShareLoD("X", "Out"); } void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const {