diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 1ee5ed61ec0ed6be2d6b32b9cb8d031175ac6ae7..f8d87af8ae8273c527feba0392b6ccc7ea0f2b5f 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -35,6 +35,13 @@ using platform::GetMKLDNNFormat; template class DeQuantOpKernel : public framework::OpKernel { public: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Output"), "Output should not be null"); + + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index 06336b025e4cb4ad7886a94597d23aa472041358..d7c831ab5ad50b1be1ba3f2cd9048fe7d3a01beb 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -33,6 +33,14 @@ using platform::GetMKLDNNFormat; template class QuantOpKernel : public framework::OpKernel { public: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Output"), "Output should not be null"); + + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } + void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); auto* scale = ctx.Input("Scale"); diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index 9abc0a4ea45ef8ebcd8a96571823d9dace05fa7f..3b308400d916d2831d61d0955d44bc0a691be16d 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -34,6 +34,14 @@ using platform::GetMKLDNNFormat; template class ReQuantOpKernel : public framework::OpKernel { public: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Output"), "Output should not be null"); + + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } + void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); auto* scale = ctx.Input("Scale");