diff --git a/paddle/fluid/operators/dequantize_op.h b/paddle/fluid/operators/dequantize_op.h index 350b0376c2fcd19f2a1e59a33a7702b3e904be7e..b0d42311af6a3a5b3f87263aab6dd32cf9e6836f 100644 --- a/paddle/fluid/operators/dequantize_op.h +++ b/paddle/fluid/operators/dequantize_op.h @@ -28,7 +28,10 @@ class DeQuantOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override{ + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } protected: framework::OpKernelType GetExpectedKernelType( diff --git a/paddle/fluid/operators/quantize_op.h b/paddle/fluid/operators/quantize_op.h index ffdbe8400c0be8eeb3361c5822204bfc1fb1ed95..4763f523ecb24706d1f2e0117b6a1d80ab050151 100644 --- a/paddle/fluid/operators/quantize_op.h +++ b/paddle/fluid/operators/quantize_op.h @@ -29,7 +29,7 @@ class QuantOp : public framework::OperatorWithKernel { using framework::OperatorWithKernel::OperatorWithKernel; void InferShape(framework::InferShapeContext* ctx) const override{ - ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); ctx->ShareLoD("Input", /*->*/ "Output"); } diff --git a/paddle/fluid/operators/requantize_op.h b/paddle/fluid/operators/requantize_op.h index f96df360691a4f6f1b8c66f3402f3c5c9a62bcf3..75401f4d516b3fd64d641c4b5f031ed5f8212812 100644 --- a/paddle/fluid/operators/requantize_op.h +++ b/paddle/fluid/operators/requantize_op.h @@ -28,7 +28,10 @@ class ReQuantOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; - void InferShape(framework::InferShapeContext* ctx) const override {} + void InferShape(framework::InferShapeContext* ctx) const override{ + ctx->SetOutputDim("Output", ctx->GetInputDim("Input")); + ctx->ShareLoD("Input", /*->*/ "Output"); + } protected: framework::OpKernelType GetExpectedKernelType(