Unverified · Commit 7d6a4a54 · Authored by: H Hulek · Committed by: GitHub

Delete caching from requantize_mkldnn_op and change to Acquire API (#48113)

* Delete caching from requantize_mkldnn_op and change to Acquire API
* Fix code style and implementation
Parent 4244fa6e
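Editor's note: the patch below removes the hand-rolled primitive cache (platform::CreateKey / GetBlob / SetBlob keyed per output) and instead rebuilds the reorder on every call through the Acquire* methods of phi::funcs::ReorderOneDNNHandler. It also drops the old memset-plus-sum-post-op trick and expresses the requantization shift as a destination zero point. The standalone sketch below illustrates that attribute-driven reorder in isolation; it is not part of the commit, and it assumes the oneDNN 2.x API in use at the time (primitive_attr::set_output_scales was removed in oneDNN 3.x):

// Hypothetical standalone sketch: requantize s8 -> u8 with a oneDNN reorder,
// expressing "dst = scale * src + shift" purely through primitive attributes,
// as the new kernel does. Assumes oneDNN 2.x.
#include <cstdint>
#include <vector>

#include "dnnl.hpp"

int main() {
  dnnl::engine engine(dnnl::engine::kind::cpu, 0);
  dnnl::stream stream(engine);

  std::vector<int8_t> src_data = {-64, -32, 0, 32, 64};
  std::vector<uint8_t> dst_data(src_data.size(), 0);

  dnnl::memory::dims dims = {static_cast<dnnl_dim_t>(src_data.size())};
  auto src_md = dnnl::memory::desc(dims, dnnl::memory::data_type::s8,
                                   dnnl::memory::format_tag::a);
  auto dst_md = dnnl::memory::desc(dims, dnnl::memory::data_type::u8,
                                   dnnl::memory::format_tag::a);
  dnnl::memory src_mem(src_md, engine, src_data.data());
  dnnl::memory dst_mem(dst_md, engine, dst_data.data());

  // The scale plays the role of scale_out / scale_in; the zero point plays
  // the role of clip_to_uint8(shift_out - reorder_scale * shift_in).
  dnnl::primitive_attr attrs;
  attrs.set_output_scales(/*mask=*/0, {0.5f});
  attrs.set_zero_points(DNNL_ARG_DST, /*mask=*/0, {128});

  // Same construction the old kernel used, minus the caching around it.
  auto reorder_pd = dnnl::reorder::primitive_desc(src_mem, dst_mem, attrs);
  auto reorder_p = dnnl::reorder(reorder_pd);
  reorder_p.execute(stream, src_mem, dst_mem);
  stream.wait();  // dst_data is now {96, 112, 128, 144, 160}
  return 0;
}

Dropping the cache eliminates the per-call key construction, blob lookups, and thread-key extension; the handler simply wraps the input/output buffers in freshly acquired memory objects each time.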
@@ -17,7 +17,8 @@ limitations under the License. */
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/operators/requantize_op.h"
-#include "paddle/fluid/platform/mkldnn_helper.h"
+#include "paddle/phi/backends/onednn/onednn_helper.h"
+#include "paddle/phi/backends/onednn/onednn_reuse.h"
 
 namespace paddle {
 namespace operators {
@@ -56,101 +57,56 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("Scale of output cannot be 0.0"));
     if (shift_in != 0.0f) {
       PADDLE_ENFORCE_EQ(
-          framework::TransToProtoVarType(input->dtype()),
-          framework::proto::VarType::UINT8,
+          input->dtype(),
+          DataType::UINT8,
           platform::errors::Unimplemented("Requantize does not support nonzero "
                                           "shift for signed input."));
     }
 
     auto& dev_ctx =
         ctx.template device_context<platform::MKLDNNDeviceContext>();
-    const auto& engine = dev_ctx.GetEngine();
 
     auto src_tz = phi::vectorize(input->dims());
 
-    float reorder_scale = scale_out / scale_in;
-
-    std::string key = platform::CreateKey(
-        dev_ctx, src_tz, scale_in, scale_out, ctx.OutputName("Output"));
-    key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
-    const std::string key_prim = key + "@r";
-    const std::string key_src_mem = key + "@s";
-    const std::string key_dst_mem = key + "@d";
-
-    std::shared_ptr<dnnl::memory> src_memory;
-    std::shared_ptr<dnnl::memory> dst_memory;
-    std::shared_ptr<reorder> reorder_p;
-    reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
-
-    const T* input_data = input->data<T>();
-
-    if (reorder_p == nullptr) {
-      auto src_dt = framework::ToMKLDNNDataType(
-          framework::TransToProtoVarType(input->dtype()));
-      auto dst_dt = with_shift ? framework::OneDNNDataType::u8 : src_dt;
-
-      src_memory = std::make_shared<dnnl::memory>(
-          input->mem_desc(), engine, phi::funcs::to_void_cast<T>(input_data));
-
-      auto xstrides = input->mem_desc().data.format_desc.blocking.strides;
-      std::vector<dnnl_dim_t> vstrides(xstrides,
-                                       xstrides + input->mem_desc().data.ndims);
-
-      auto dst_md = dnnl::memory::desc({src_tz}, dst_dt, vstrides);
-
-      dnnl::primitive_attr attri;
-      int mask = 0;
-      attri.set_output_scales(mask, {reorder_scale});
-      if (with_shift) {
-        dnnl::post_ops post_operations;
-        post_operations.append_sum();
-        attri.set_post_ops(post_operations);
-        uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
-        uint8_t reorder_shift =
-            clip_to_uint8(shift_out - reorder_scale * shift_in);
-        std::memset(output_data, reorder_shift, output->numel());
-        dst_memory = std::make_shared<dnnl::memory>(
-            dst_md, engine, phi::funcs::to_void_cast<uint8_t>(output_data));
-      } else {
-        T* output_data = output->mutable_data<T>(ctx.GetPlace());
-        dst_memory = std::make_shared<dnnl::memory>(
-            dst_md, engine, phi::funcs::to_void_cast<T>(output_data));
-      }
-
-      auto reorder_pd =
-          reorder::primitive_desc(*src_memory, *dst_memory, attri);
-      reorder_p = std::make_shared<reorder>(reorder_pd);
-
-      dev_ctx.SetBlob(key_prim, reorder_p);
-      dev_ctx.SetBlob(key_src_mem, src_memory);
-      dev_ctx.SetBlob(key_dst_mem, dst_memory);
-    } else {
-      src_memory =
-          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
-      src_memory->set_data_handle(phi::funcs::to_void_cast<T>(input_data));
-
-      dst_memory =
-          std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
-      if (with_shift) {
-        uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
-        uint8_t reorder_shift =
-            clip_to_uint8(shift_out - reorder_scale * shift_in);
-        std::memset(output_data, reorder_shift, output->numel());
-        dst_memory->set_data_handle(output_data);
-      } else {
-        T* output_data = output->mutable_data<T>(ctx.GetPlace());
-        dst_memory->set_data_handle(output_data);
-      }
-    }
+    auto src_paddle_dt = input->dtype();
+    auto dst_paddle_dt = with_shift ? DataType::UINT8 : src_paddle_dt;
+
+    auto xstrides = input->mem_desc().data.format_desc.blocking.strides;
+    std::vector<dnnl_dim_t> vstrides(xstrides,
+                                     xstrides + input->mem_desc().data.ndims);
+
+    dnnl::primitive_attr attrs;
+    int mask = 0;
+    float reorder_scale = scale_out / scale_in;
+    attrs.set_output_scales(mask, {reorder_scale});
+    if (with_shift) {
+      uint8_t reorder_shift =
+          clip_to_uint8(shift_out - reorder_scale * shift_in);
+      attrs.set_zero_points(
+          DNNL_ARG_DST, mask, {static_cast<int32_t>(reorder_shift)});
+    }
+
+    phi::funcs::ReorderOneDNNHandler reorder_handler(
+        src_tz,
+        src_paddle_dt,
+        phi::funcs::ToOneDNNDataType(src_paddle_dt),
+        dst_paddle_dt,
+        phi::funcs::ToOneDNNDataType(dst_paddle_dt),
+        dev_ctx.GetEngine());
+
+    auto src_memory_p = reorder_handler.AcquireSrcMemory(
+        input->mem_desc(), phi::funcs::to_void_cast(input->data<T>()));
+    auto dst_memory_p = reorder_handler.AcquireDstMemory(
+        output, src_tz, vstrides, dev_ctx.GetPlace());
+
+    auto reorder_p =
+        reorder_handler.AcquireReorder(dst_memory_p, src_memory_p, attrs);
 
     auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-
-    reorder_p->execute(astream, *src_memory, *dst_memory);
+    reorder_p->execute(astream, *src_memory_p, *dst_memory_p);
    astream.wait();
 
-    output->set_mem_desc(dst_memory->get_desc());
+    output->set_mem_desc(dst_memory_p->get_desc());
   }
 };
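Editor's note: numerically, both versions compute dst = reorder_scale * src + reorder_shift, with reorder_scale = scale_out / scale_in and reorder_shift = clip_to_uint8(shift_out - reorder_scale * shift_in). The old kernel obtained the additive term by memset-ing the output buffer to reorder_shift and appending a sum post-op; the new one hands the same value to oneDNN as a DNNL_ARG_DST zero point. A minimal check of that shift arithmetic (the standalone clip_to_uint8 below is an assumption, mirroring the kernel's saturating helper):

#include <algorithm>
#include <cstdint>
#include <cstdio>

// Assumed to mirror the kernel's clip_to_uint8 helper: saturate to [0, 255].
static uint8_t clip_to_uint8(float x) {
  return static_cast<uint8_t>(std::min(255.0f, std::max(0.0f, x)));
}

int main() {
  // Example: requantize data quantized with (scale_in, shift_in) into the
  // (scale_out, shift_out) quantization domain.
  float scale_in = 2.0f, scale_out = 4.0f;
  float shift_in = 10.0f, shift_out = 30.0f;

  float reorder_scale = scale_out / scale_in;               // 2.0
  uint8_t reorder_shift =
      clip_to_uint8(shift_out - reorder_scale * shift_in);  // 10

  // dst = reorder_scale * src + reorder_shift, e.g. src = 100 -> dst = 210.
  std::printf("scale=%f shift=%u\n", reorder_scale,
              static_cast<unsigned>(reorder_shift));
  return 0;
}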