提交 a19acd20 编写于 作者: Z Zhen Wang

change the output into inplace. test=develop

上级 696cf626
...@@ -458,6 +458,18 @@ $$Out = X$$ ...@@ -458,6 +458,18 @@ $$Out = X$$
} }
}; };
// In-place inference rule for the moving_average_abs_max_scale operator:
// declares that the output "Out" may reuse the buffer of input "X", letting
// the framework schedule the op in-place (see the corresponding
// REGISTER_OPERATOR call, which attaches this inference to the op).
class MovingAverageAbsMaxScaleOpInplaceInToOut
    : public framework::InplaceOpInference {
 public:
  std::unordered_map<std::string, std::string> operator()(
      const framework::OpDesc& op_desc) const override {
    // The mapping is static — it does not depend on op_desc — so return the
    // single {input, output} pair directly.
    return {{"X", "Out"}};
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
...@@ -482,6 +494,7 @@ REGISTER_OPERATOR(fake_quantize_moving_average_abs_max, ...@@ -482,6 +494,7 @@ REGISTER_OPERATOR(fake_quantize_moving_average_abs_max,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max, REGISTER_OP_CPU_KERNEL(fake_quantize_moving_average_abs_max,
ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>); ops::FakeQuantizeMovingAverageAbsMaxKernel<CPU, float>);
REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max, REGISTER_OPERATOR(fake_channel_wise_quantize_abs_max,
ops::FakeChannelWiseQuantizeAbsMaxOp, ops::FakeChannelWiseQuantizeAbsMaxOp,
ops::FakeChannelWiseQuantizeAbsMaxOpMaker, ops::FakeChannelWiseQuantizeAbsMaxOpMaker,
...@@ -491,6 +504,7 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max, ...@@ -491,6 +504,7 @@ REGISTER_OP_CPU_KERNEL(fake_channel_wise_quantize_abs_max,
REGISTER_OPERATOR(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp, REGISTER_OPERATOR(moving_average_abs_max_scale, ops::MovingAverageAbsMaxScaleOp,
ops::MovingAverageAbsMaxScaleOpMaker, ops::MovingAverageAbsMaxScaleOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker,
ops::MovingAverageAbsMaxScaleOpInplaceInToOut);
REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale, REGISTER_OP_CPU_KERNEL(moving_average_abs_max_scale,
ops::MovingAverageAbsMaxScaleKernel<CPU, float>); ops::MovingAverageAbsMaxScaleKernel<CPU, float>);
...@@ -15,9 +15,9 @@ limitations under the License. */ ...@@ -15,9 +15,9 @@ limitations under the License. */
#pragma once #pragma once
#include <string> #include <string>
#include <unordered_map>
#include "paddle/fluid/framework/eigen.h" #include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/operators/math/blas.h"
namespace paddle { namespace paddle {
...@@ -204,9 +204,8 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> { ...@@ -204,9 +204,8 @@ class MovingAverageAbsMaxScaleKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<framework::Tensor>("X"); auto* in = context.Input<framework::Tensor>("X");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
out->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<DeviceContext>(); auto& dev_ctx = context.template device_context<DeviceContext>();
framework::TensorCopy(*in, context.GetPlace(), dev_ctx, out); out->ShareDataWith(*in);
bool is_test = context.Attr<bool>("is_test"); bool is_test = context.Attr<bool>("is_test");
// testing // testing
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册