提交 9942d9ed 编写于 作者: A Adam 提交者: Tao Luo

Add caching mechanizm to requantize_mkldnn_op (#22223)

上级 1230c110
......@@ -21,14 +21,10 @@ limitations under the License. */
namespace paddle {
namespace operators {
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using dnnl::memory;
using dnnl::reorder;
using platform::to_void_cast;
using Tensor = framework::Tensor;
using framework::DataLayout;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
template <typename T>
class ReQuantOpKernel : public framework::OpKernel<T> {
......@@ -42,42 +38,66 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& engine = dev_ctx.GetEngine();
std::vector<primitive> pipeline;
auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::data_type dst_dt = src_dt;
MKLDNNMemoryFormat src_fmt = MKLDNNMemoryFormat::nhwc;
MKLDNNMemoryFormat dst_fmt = MKLDNNMemoryFormat::nhwc;
auto src_tz = paddle::framework::vectorize(input->dims());
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
float scale_shift = scale_out / scale_in;
mkldnn::primitive_attr attri;
int mask = 0;
attri.set_output_scales(mask, {scale_shift});
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
auto src_memory = std::make_shared<mkldnn::memory>(
src_md, engine, to_void_cast<T>(input_data));
std::string key = platform::CreateKey(src_tz, scale_in, scale_out,
ctx.OutputName("Output"));
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, dst_dt, dst_fmt);
auto dst_memory =
mkldnn::memory(dst_md, engine, to_void_cast<T>(output_data));
std::shared_ptr<dnnl::memory> src_memory;
std::shared_ptr<dnnl::memory> dst_memory;
std::shared_ptr<reorder> reorder_p;
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(*src_memory, dst_memory, attri));
auto reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
mkldnn::stream astream(engine);
reorder_p->execute(astream, *src_memory, dst_memory);
if (reorder_p == nullptr) {
dnnl::primitive_attr attri;
int mask = 0;
float scale_shift = scale_out / scale_in;
attri.set_output_scales(mask, {scale_shift});
auto dst_tz = paddle::framework::vectorize(output->dims());
dnnl::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
dnnl::memory::data_type dst_dt = src_dt;
auto src_md =
platform::MKLDNNMemDesc({src_tz}, src_dt, MKLDNNMemoryFormat::nhwc);
src_memory = std::make_shared<dnnl::memory>(src_md, engine,
to_void_cast<T>(input_data));
auto dst_md =
platform::MKLDNNMemDesc({dst_tz}, dst_dt, MKLDNNMemoryFormat::nhwc);
dst_memory = std::make_shared<dnnl::memory>(dst_md, engine,
to_void_cast<T>(output_data));
auto reorder_pd =
reorder::primitive_desc(*src_memory, *dst_memory, attri);
reorder_p = std::make_shared<reorder>(reorder_pd);
dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory);
dev_ctx.SetBlob(key_dst_mem, dst_memory);
} else {
src_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_src_mem));
src_memory->set_data_handle(to_void_cast<T>(input_data));
dst_memory =
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(key_dst_mem));
dst_memory->set_data_handle(output_data);
}
dnnl::stream astream(engine);
reorder_p->execute(astream, *src_memory, *dst_memory);
astream.wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(dst_memory));
output->set_layout(framework::DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory));
}
};
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册