未验证 提交 7d6a4a54 编写于 作者: H Hulek 提交者: GitHub

Delete caching from requantize_mkldnn_op and changed to Acquire API (#48113)

* Delete caching from requantize_mkldnn_op and changed to Acquire API
* Fixed codestyle and implementation
上级 4244fa6e
......@@ -17,7 +17,8 @@ limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/requantize_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/phi/backends/onednn/onednn_helper.h"
#include "paddle/phi/backends/onednn/onednn_reuse.h"
namespace paddle {
namespace operators {
......@@ -56,101 +57,56 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument("Scale of output cannot be 0.0"));
if (shift_in != 0.0f) {
PADDLE_ENFORCE_EQ(
framework::TransToProtoVarType(input->dtype()),
framework::proto::VarType::UINT8,
input->dtype(),
DataType::UINT8,
platform::errors::Unimplemented("Requantize does not support nonzero "
"shift for signed input."));
}
auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
auto src_tz = phi::vectorize(input->dims());
float reorder_scale = scale_out / scale_in;
auto src_paddle_dt = input->dtype();
auto dst_paddle_dt = with_shift ? DataType::UINT8 : src_paddle_dt;
auto xstrides = input->mem_desc().data.format_desc.blocking.strides;
std::vector<dnnl_dim_t> vstrides(xstrides,
xstrides + input->mem_desc().data.ndims);
std::string key = platform::CreateKey(
dev_ctx, src_tz, scale_in, scale_out, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<dnnl::memory> src_memory;
std::shared_ptr<dnnl::memory> dst_memory;
std::shared_ptr<reorder> reorder_p;
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
const T* input_data = input->data<T>();
if (reorder_p == nullptr) {
auto src_dt = framework::ToMKLDNNDataType(
framework::TransToProtoVarType(input->dtype()));
auto dst_dt = with_shift ? framework::OneDNNDataType::u8 : src_dt;
src_memory = std::make_shared<dnnl::memory>(
input->mem_desc(), engine, phi::funcs::to_void_cast<T>(input_data));
auto xstrides = input->mem_desc().data.format_desc.blocking.strides;
std::vector<dnnl_dim_t> vstrides(xstrides,
xstrides + input->mem_desc().data.ndims);
auto dst_md = dnnl::memory::desc({src_tz}, dst_dt, vstrides);
dnnl::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, {reorder_scale});
if (with_shift) {
dnnl::post_ops post_operations;
post_operations.append_sum();
attri.set_post_ops(post_operations);
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
uint8_t reorder_shift =
clip_to_uint8(shift_out - reorder_scale * shift_in);
std::memset(output_data, reorder_shift, output->numel());
dst_memory = std::make_shared<dnnl::memory>(
dst_md, engine, phi::funcs::to_void_cast<uint8_t>(output_data));
} else {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory = std::make_shared<dnnl::memory>(
dst_md, engine, phi::funcs::to_void_cast<T>(output_data));
}
auto reorder_pd =
reorder::primitive_desc(*src_memory, *dst_memory, attri);
reorder_p = std::make_shared<reorder>(reorder_pd);
dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory);
dev_ctx.SetBlob(key_dst_mem, dst_memory);
} else {
src_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
src_memory->set_data_handle(phi::funcs::to_void_cast<T>(input_data));
dst_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
if (with_shift) {
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
uint8_t reorder_shift =
clip_to_uint8(shift_out - reorder_scale * shift_in);
std::memset(output_data, reorder_shift, output->numel());
dst_memory->set_data_handle(output_data);
} else {
T* output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory->set_data_handle(output_data);
}
dnnl::primitive_attr attrs;
int mask = 0;
float reorder_scale = scale_out / scale_in;
attrs.set_output_scales(mask, {reorder_scale});
if (with_shift) {
uint8_t reorder_shift =
clip_to_uint8(shift_out - reorder_scale * shift_in);
attrs.set_zero_points(
DNNL_ARG_DST, mask, {static_cast<int32_t>(reorder_shift)});
}
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
phi::funcs::ReorderOneDNNHandler reorder_handler(
src_tz,
src_paddle_dt,
phi::funcs::ToOneDNNDataType(src_paddle_dt),
dst_paddle_dt,
phi::funcs::ToOneDNNDataType(dst_paddle_dt),
dev_ctx.GetEngine());
reorder_p->execute(astream, *src_memory, *dst_memory);
auto src_memory_p = reorder_handler.AcquireSrcMemory(
input->mem_desc(), phi::funcs::to_void_cast(input->data<T>()));
auto dst_memory_p = reorder_handler.AcquireDstMemory(
output, src_tz, vstrides, dev_ctx.GetPlace());
auto reorder_p =
reorder_handler.AcquireReorder(dst_memory_p, src_memory_p, attrs);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
reorder_p->execute(astream, *src_memory_p, *dst_memory_p);
astream.wait();
output->set_mem_desc(dst_memory->get_desc());
output->set_mem_desc(dst_memory_p->get_desc());
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册