diff --git a/paddle/fluid/operators/fake_quantize_op.cc b/paddle/fluid/operators/fake_quantize_op.cc index 25ca1f7e0a0086b803d48aa892b0888e0d5635b1..beaa4392f204103449f9ad167573c3746f1e0abd 100644 --- a/paddle/fluid/operators/fake_quantize_op.cc +++ b/paddle/fluid/operators/fake_quantize_op.cc @@ -458,6 +458,18 @@ $$Out = X$$ } }; +class MovingAverageAbsMaxScaleOpInplaceInToOut + : public framework::InplaceOpInference { + public: + std::unordered_map operator()( + const framework::OpDesc& op_desc) const override { + std::unordered_map inplace_in_to_out = { + {"X", "Out"}, + }; + return inplace_in_to_out; + } +}; + } // namespace operators } // namespace paddle @@ -482,6 +494,7 @@ REGISTER_OPERATOR(fake_quantize_moving_average_abs_max, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max, ops::FakeQuantizeMovingAverageAbsMaxKernel); + REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max, ops::FakeChannelWiseQuantizeAbsMaxOp, ops::FakeChannelWiseQuantizeAbsMaxOpMaker, @@ -491,6 +504,7 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max, REGISTER_OPERATOR(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp, ops::MovingAverageAbsMaxScaleOpMaker, - paddle::framework::EmptyGradOpMaker); + paddle::framework::EmptyGradOpMaker, + ops::MovingAverageAbsMaxScaleOpInplaceInToOut); REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleKernel); diff --git a/paddle/fluid/operators/fake_quantize_op.h b/paddle/fluid/operators/fake_quantize_op.h index 87bcece582442e7336049d65bcabc87eadd52342..b0fd459095134cc21fbdcb1214647e1af73e5e6b 100644 --- a/paddle/fluid/operators/fake_quantize_op.h +++ b/paddle/fluid/operators/fake_quantize_op.h @@ -15,9 +15,9 @@ limitations under the License. */ #pragma once #include +#include #include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/framework/tensor_util.h" #include "paddle/fluid/operators/math/blas.h" namespace paddle { @@ -204,9 +204,8 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext& context) const override { auto* in = context.Input("X"); auto* out = context.Output("Out"); - out->mutable_data(context.GetPlace()); auto& dev_ctx = context.template device_context(); - framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); + out->ShareDataWith(*in); bool is_test = context.Attr("is_test"); // testing