未验证 提交 5c2b9258 编写于 作者: W Wojciech Uss 提交者: GitHub

Fix (de/re)quantize cache keys (#26549)

上级 eeda90d6
......@@ -51,11 +51,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
MKLDNNMemoryFormat src_fmt = input->format();
std::string key =
platform::CreateKey(src_dt, src_tz, ctx.OutputName("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt,
src_tz, ctx.OutputName("Output"));
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory;
......
......@@ -48,11 +48,12 @@ class QuantOpKernel : public framework::OpKernel<T> {
const T* input_data = input->data<T>();
bool is_negative = ctx.Attr<bool>("is_negative_input");
std::string key = platform::CreateKey(src_tz, scale_data, is_negative,
ctx.OutputName("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
std::string key =
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_data,
is_negative, ctx.OutputName("Output"));
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory;
......
......@@ -40,11 +40,12 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
auto src_tz = paddle::framework::vectorize(input->dims());
std::string key = platform::CreateKey(src_tz, scale_in, scale_out,
ctx.OutputName("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
std::string key =
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in,
scale_out, ctx.OutputName("Output"));
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<dnnl::memory> src_memory;
std::shared_ptr<dnnl::memory> dst_memory;
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册