From b59c1063f522d23e9d19e64103dd170c6c31d823 Mon Sep 17 00:00:00 2001 From: Haihao Shen Date: Wed, 10 Oct 2018 17:35:40 +0800 Subject: [PATCH] Add missing reshape found by GM --- paddle/fluid/operators/dequantize_op.cc | 7 +++++++ paddle/fluid/operators/quantize_op.cc | 8 ++++++++ paddle/fluid/operators/requantize_op.cc | 8 ++++++++ 3 files changed, 23 insertions(+) diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 1ee5ed61ec0..f8d87af8ae8 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -35,6 +35,13 @@ using platform::GetMKLDNNFormat; template class DeQuantOpKernel : public framework::OpKernel { public: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Output"), "Output should not be null"); + + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index 06336b025e4..d7c831ab5ad 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -33,6 +33,14 @@ using platform::GetMKLDNNFormat; template class QuantOpKernel : public framework::OpKernel { public: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Output"), "Output should not be null"); + + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } + void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); auto* scale = ctx.Input("Scale"); diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index 9abc0a4ea45..3b308400d91 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -34,6 +34,14 @@ using platform::GetMKLDNNFormat; template class ReQuantOpKernel : public framework::OpKernel { public: + void InferShape(framework::InferShapeContext *ctx) const override { + PADDLE_ENFORCE(ctx->HasInput("Input"), "Input should not be null"); + PADDLE_ENFORCE(ctx->HasOutput("Output"), "Output should not be null"); + + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } + void Compute(const framework::ExecutionContext& ctx) const override { auto* input = ctx.Input("Input"); auto* scale = ctx.Input("Scale"); -- GitLab