From 040dc59b0f3558baee90627936716823569700d5 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Thu, 4 Jan 2018 21:57:16 +0800 Subject: [PATCH] Correctly handle image operators --- paddle/operators/batch_norm_op.cc | 3 ++- paddle/operators/conv_op.cc | 7 +++---- paddle/operators/pool_op.cc | 1 + 3 files changed, 6 insertions(+), 5 deletions(-) diff --git a/paddle/operators/batch_norm_op.cc b/paddle/operators/batch_norm_op.cc index 98db28ddee..dd7b038b00 100644 --- a/paddle/operators/batch_norm_op.cc +++ b/paddle/operators/batch_norm_op.cc @@ -64,7 +64,7 @@ class BatchNormOp : public framework::OperatorWithKernel { PADDLE_ENFORCE(x_dims.size() >= 2 && x_dims.size() <= 5, "Input X must have 2 to 5 dimensions."); - const int C = + const int64_t C = (data_layout == DataLayout::kNCHW ? x_dims[1] : x_dims[x_dims.size() - 1]); @@ -78,6 +78,7 @@ class BatchNormOp : public framework::OperatorWithKernel { ctx->SetOutputDim("VarianceOut", {C}); ctx->SetOutputDim("SavedMean", {C}); ctx->SetOutputDim("SavedVariance", {C}); + ctx->ShareLoD("X", "Y"); } }; diff --git a/paddle/operators/conv_op.cc b/paddle/operators/conv_op.cc index e65a5dce52..ad84524e17 100644 --- a/paddle/operators/conv_op.cc +++ b/paddle/operators/conv_op.cc @@ -44,14 +44,12 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { paddings.size(), strides.size(), "Conv paddings dimension and Conv strides dimension should be the same."); - int input_channels = in_dims[1]; - PADDLE_ENFORCE_EQ(input_channels, filter_dims[1] * groups, + PADDLE_ENFORCE_EQ(in_dims[1], filter_dims[1] * groups, "The number of input channels should be equal to filter " "channels * groups."); - int output_channels = filter_dims[0]; PADDLE_ENFORCE_EQ( - output_channels % groups, 0, + filter_dims[0] % groups, 0, "The number of output channels should be divided by groups."); std::vector output_shape({in_dims[0], filter_dims[0]}); @@ -66,6 +64,7 @@ void ConvOp::InferShape(framework::InferShapeContext* ctx) const { dilations[i], paddings[i], strides[i])); } ctx->SetOutputDim("Output", framework::make_ddim(output_shape)); + ctx->ShareLoD("Input", "Output"); } Conv2DOpMaker::Conv2DOpMaker(OpProto* proto, OpAttrChecker* op_checker) diff --git a/paddle/operators/pool_op.cc b/paddle/operators/pool_op.cc index 50057eb648..d3cf5fa638 100644 --- a/paddle/operators/pool_op.cc +++ b/paddle/operators/pool_op.cc @@ -58,6 +58,7 @@ void PoolOp::InferShape(framework::InferShapeContext *ctx) const { OutputSizePool(in_x_dims[i + 2], ksize[i], paddings[i], strides[i])); } ctx->SetOutputDim("Out", framework::make_ddim(output_shape)); + ctx->ShareLoD("X", "Out"); } void PoolOpGrad::InferShape(framework::InferShapeContext *ctx) const { -- GitLab