未验证 提交 5c2b9258 编写于 作者: W Wojciech Uss 提交者: GitHub

Fix (de/re)quantize cache keys (#26549)

上级 eeda90d6
...@@ -51,11 +51,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -51,11 +51,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
MKLDNNMemoryFormat src_fmt = input->format(); MKLDNNMemoryFormat src_fmt = input->format();
std::string key = std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt,
platform::CreateKey(src_dt, src_tz, ctx.OutputName("Output")); src_tz, ctx.OutputName("Output"));
const std::string key_prim = key + "@reorder_p"; const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@src_mem"; const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@dst_mem"; const std::string key_dst_mem = key + "@d";
std::shared_ptr<mkldnn::memory> src_memory; std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory; std::shared_ptr<mkldnn::memory> dst_memory;
......
...@@ -48,11 +48,12 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -48,11 +48,12 @@ class QuantOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>(); const T* input_data = input->data<T>();
bool is_negative = ctx.Attr<bool>("is_negative_input"); bool is_negative = ctx.Attr<bool>("is_negative_input");
std::string key = platform::CreateKey(src_tz, scale_data, is_negative, std::string key =
ctx.OutputName("Output")); platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_data,
const std::string key_prim = key + "@reorder_p"; is_negative, ctx.OutputName("Output"));
const std::string key_src_mem = key + "@src_mem"; const std::string key_prim = key + "@r";
const std::string key_dst_mem = key + "@dst_mem"; const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<mkldnn::memory> src_memory; std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory; std::shared_ptr<mkldnn::memory> dst_memory;
......
...@@ -40,11 +40,12 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -40,11 +40,12 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
auto src_tz = paddle::framework::vectorize(input->dims()); auto src_tz = paddle::framework::vectorize(input->dims());
std::string key = platform::CreateKey(src_tz, scale_in, scale_out, std::string key =
ctx.OutputName("Output")); platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in,
const std::string key_prim = key + "@reorder_p"; scale_out, ctx.OutputName("Output"));
const std::string key_src_mem = key + "@src_mem"; const std::string key_prim = key + "@r";
const std::string key_dst_mem = key + "@dst_mem"; const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<dnnl::memory> src_memory; std::shared_ptr<dnnl::memory> src_memory;
std::shared_ptr<dnnl::memory> dst_memory; std::shared_ptr<dnnl::memory> dst_memory;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册