提交 f8cbc4f3 编写于 作者: X xiaoli.liu@intel.com 提交者: ceci3

Optimize INT8 DeQuantize Op with primitive reuse.

test=develop
上级 701af439
...@@ -17,6 +17,7 @@ limitations under the License. */ ...@@ -17,6 +17,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/dequantize_op.h" #include "paddle/fluid/operators/dequantize_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -30,6 +31,18 @@ using framework::DataLayout; ...@@ -30,6 +31,18 @@ using framework::DataLayout;
using mkldnn::stream; using mkldnn::stream;
using platform::GetMKLDNNFormat; using platform::GetMKLDNNFormat;
std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const mkldnn::memory::data_type& src_dt,
const std::vector<int>& src_tz, const float scale_data) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(src_dt));
platform::MKLDNNHandler::AppendKeyDims(&key, src_tz);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(scale_data));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Output"));
return key;
}
template <typename T> template <typename T>
class DeQuantOpKernel : public framework::OpKernel<T> { class DeQuantOpKernel : public framework::OpKernel<T> {
public: public:
...@@ -51,14 +64,24 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -51,14 +64,24 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
mkldnn::memory::format src_fmt = input->format(); mkldnn::memory::format src_fmt = input->format();
std::string key = CreateKey(ctx, src_dt, src_tz, reorder_scale[0]);
const std::string key_prim = key + "@reorder_p";
const std::string key_src_mem = key + "@src_mem";
const std::string key_dst_mem = key + "@dst_mem";
std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory;
std::shared_ptr<reorder> reorder_p;
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
if (reorder_p == nullptr) {
mkldnn::primitive_attr attri; mkldnn::primitive_attr attri;
int mask = 0; int mask = 0;
attri.set_output_scales(mask, reorder_scale); attri.set_output_scales(mask, reorder_scale);
auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt); auto src_md = platform::MKLDNNMemDesc({src_tz}, src_dt, src_fmt);
auto src_pd = mkldnn::memory::primitive_desc(src_md, engine); auto src_pd = mkldnn::memory::primitive_desc(src_md, engine);
auto src_memory = src_memory =
std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data)); std::make_shared<mkldnn::memory>(src_pd, to_void_cast<T>(input_data));
std::shared_ptr<primitive::at> src_memory_p = std::shared_ptr<primitive::at> src_memory_p =
std::shared_ptr<primitive::at>(new primitive::at(*src_memory)); std::shared_ptr<primitive::at>(new primitive::at(*src_memory));
...@@ -66,16 +89,30 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -66,16 +89,30 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32, auto dst_md = platform::MKLDNNMemDesc({dst_tz}, memory::data_type::f32,
memory::format::nchw); memory::format::nchw);
auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine); auto dst_pd = mkldnn::memory::primitive_desc(dst_md, engine);
auto dst_memory = mkldnn::memory(dst_pd, to_void_cast<float>(output_data)); dst_memory = std::make_shared<mkldnn::memory>(
dst_pd, to_void_cast<float>(output_data));
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>( auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(src_pd, dst_pd, attri)); new reorder::primitive_desc(src_pd, dst_pd, attri));
auto reorder_p = std::shared_ptr<reorder>( reorder_p = std::shared_ptr<reorder>(
new reorder(*reorder_pd, *src_memory_p, dst_memory)); new reorder(*reorder_pd, *src_memory_p, *dst_memory));
dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory);
dev_ctx.SetBlob(key_dst_mem, dst_memory);
} else {
src_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_src_mem));
src_memory->set_data_handle(to_void_cast<T>(input_data));
dst_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_dst_mem));
dst_memory->set_data_handle(output->mutable_data<float>(ctx.GetPlace()));
}
pipeline.push_back(*reorder_p); pipeline.push_back(*reorder_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_format(GetMKLDNNFormat(dst_memory)); output->set_format(GetMKLDNNFormat(*dst_memory));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册