提交 a19acd20 编写于 作者: Z Zhen Wang

Make the op's output in-place (Out reuses X). test=develop

上级 696cf626
2 合并请求!17258Add quant scale pass,!17215Quant output scale
......@@ -458,6 +458,18 @@ $$Out = X$$
}
};
// In-place inference functor for the moving_average_abs_max_scale operator.
//
// Returns a single input-to-output name pair, "X" -> "Out". The `op_desc`
// argument is not consulted: the same pair is returned for every op
// description.
// NOTE(review): presumably the framework uses this pair to let "Out" reuse
// the buffer of "X" -- confirm against framework::InplaceOpInference.
class MovingAverageAbsMaxScaleOpInplaceInToOut
    : public framework::InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc) const override {
    // One entry only: input slot "X" is paired with output slot "Out".
    return {{"X", "Out"}};
  }
};
} // namespace operators
} // namespace paddle
......@@ -482,6 +494,7 @@ REGISTER_OPERATOR(fake_quantize_moving_average_abs_max,
paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
......@@ -491,6 +504,7 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
REGISTER_OPERATOR(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
ops::MovingAverageAbsMaxScaleOpMaker,
paddle::framework::EmptyGradOpMaker);
paddle::framework::EmptyGradOpMaker,
ops::MovingAverageAbsMaxScaleOpInplaceInToOut);
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
......@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h"
namespace paddle {
......@@ -204,9 +204,8 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>();
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out);
out->ShareDataWith(*in);
bool is_test = context.Attr<bool>("is_test");
// testing
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册
反馈
建议
客服 返回
顶部