未验证 提交 bf748f24 编写于 作者: J Jacek Czaja 提交者: GitHub

Implemented LRU based cache clearing (#36290)

- Lint

- Merge with develop

- lint
上级 59e425cd
...@@ -78,7 +78,8 @@ class ConvMKLDNNHandlerT ...@@ -78,7 +78,8 @@ class ConvMKLDNNHandlerT
mkldnn::convolution_backward_weights>( mkldnn::convolution_backward_weights>(
dev_ctx, mkldnn_engine, cpu_place, dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) { unique_name)),
is_test_(ctx.Attr<bool>("is_test")) {
if (!this->isCached()) { if (!this->isCached()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input->layout(), framework::DataLayout::kMKLDNN, input->layout(), framework::DataLayout::kMKLDNN,
...@@ -159,7 +160,6 @@ class ConvMKLDNNHandlerT ...@@ -159,7 +160,6 @@ class ConvMKLDNNHandlerT
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
const auto ksize = framework::vectorize(filter_data_dims); const auto ksize = framework::vectorize(filter_data_dims);
const bool is_test = ctx.Attr<bool>("is_test");
auto strides_temp = ctx.Attr<std::vector<int>>("strides"); auto strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp)); std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
...@@ -214,9 +214,8 @@ class ConvMKLDNNHandlerT ...@@ -214,9 +214,8 @@ class ConvMKLDNNHandlerT
const auto dst_md = platform::MKLDNNMemDesc( const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference const auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
float sum_scale = 1.0f; float sum_scale = 1.0f;
std::vector<float> output_shift_scale; std::vector<float> output_shift_scale;
if (platform::is_int8<T>()) if (platform::is_int8<T>())
...@@ -261,7 +260,8 @@ class ConvMKLDNNHandlerT ...@@ -261,7 +260,8 @@ class ConvMKLDNNHandlerT
mkldnn::convolution_backward_weights>( mkldnn::convolution_backward_weights>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(in->dims()), platform::CreateKey(dev_ctx, framework::vectorize(in->dims()),
unique_name)) { unique_name)),
is_test_(false) {
if (!this->isBwdCached()) { if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in->layout(), framework::DataLayout::kMKLDNN, in->layout(), framework::DataLayout::kMKLDNN,
...@@ -291,7 +291,7 @@ class ConvMKLDNNHandlerT ...@@ -291,7 +291,7 @@ class ConvMKLDNNHandlerT
"Wrong format set for output_grad tensor")); "Wrong format set for output_grad tensor"));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false, is_test_, false,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"is_test attribute should be set to False in training phase.")); "is_test attribute should be set to False in training phase."));
...@@ -557,13 +557,14 @@ class ConvMKLDNNHandlerT ...@@ -557,13 +557,14 @@ class ConvMKLDNNHandlerT
framework::vectorize(in_mem->dims()), framework::vectorize(in_mem->dims()),
platform::MKLDNNGetDataType<T>(), in_mem->format()); platform::MKLDNNGetDataType<T>(), in_mem->format());
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem); user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem,
is_test_);
} else { } else {
const std::string target_key_suffix{key_mem_target}; const std::string target_key_suffix{key_mem_target};
const auto target_mem_p = this->AcquireMemory(target_key_suffix); const auto target_mem_p = this->AcquireMemory(target_key_suffix);
user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data)); user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data));
if (user_mem_p != target_mem_p) { if (user_mem_p != target_mem_p) {
this->AcquireReorder(user_mem_p, target_mem_p, key_mem); this->AcquireReorder(user_mem_p, target_mem_p);
} }
return target_mem_p; return target_mem_p;
} }
...@@ -571,12 +572,11 @@ class ConvMKLDNNHandlerT ...@@ -571,12 +572,11 @@ class ConvMKLDNNHandlerT
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
const framework::Tensor* filter, const int groups, const bool is_conv3d, const framework::Tensor* filter, const int groups, const bool is_conv3d,
const bool is_test, const std::vector<float>& scale_data = {1.0f}, const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
int mask = 0) {
// This is workaround to make execution faster, delete // This is workaround to make execution faster, delete
// if statement after including md inside Tensor // if statement after including md inside Tensor
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
if (is_test && weights_mem_p) { if (is_test_ && weights_mem_p) {
return weights_mem_p; return weights_mem_p;
} else { } else {
const K* filter_data = filter->data<K>(); const K* filter_data = filter->data<K>();
...@@ -589,16 +589,16 @@ class ConvMKLDNNHandlerT ...@@ -589,16 +589,16 @@ class ConvMKLDNNHandlerT
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_src_md, this->fwd_pd_->weights_desc(), user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {}, platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test_,
scale_data, mask); {}, scale_data, mask);
} }
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder( std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
const framework::Tensor* bias, const bool is_test, const framework::Tensor* bias,
const std::vector<float>& scale_data = {1.0f}, int mask = 0) { const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
if (is_test && bias_mem_p) { if (is_test_ && bias_mem_p) {
return bias_mem_p; return bias_mem_p;
} else { } else {
const K* bias_data = bias->data<K>(); const K* bias_data = bias->data<K>();
...@@ -608,7 +608,7 @@ class ConvMKLDNNHandlerT ...@@ -608,7 +608,7 @@ class ConvMKLDNNHandlerT
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_bias_md, this->fwd_pd_->bias_desc(), user_bias_md, this->fwd_pd_->bias_desc(),
platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test, {}, platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test_, {},
scale_data, mask); scale_data, mask);
} }
} }
...@@ -641,7 +641,7 @@ class ConvMKLDNNHandlerT ...@@ -641,7 +641,7 @@ class ConvMKLDNNHandlerT
platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) { platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
auto residual_memory_p = this->AcquireResidualMemory(residual_param); auto residual_memory_p = this->AcquireResidualMemory(residual_param);
dst_memory_p = this->template AcquireDstMemory<T_out>(output); dst_memory_p = this->template AcquireDstMemory<T_out>(output);
this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst"); this->AcquireReorder(residual_memory_p, dst_memory_p);
} else { } else {
// Changing ShareDataWith to TensorCopy results in performance drop // Changing ShareDataWith to TensorCopy results in performance drop
// on ResNet architectures // on ResNet architectures
...@@ -651,6 +651,9 @@ class ConvMKLDNNHandlerT ...@@ -651,6 +651,9 @@ class ConvMKLDNNHandlerT
} }
return dst_memory_p; return dst_memory_p;
} }
private:
const bool is_test_;
}; };
} // anonymous namespace } // anonymous namespace
...@@ -695,7 +698,6 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -695,7 +698,6 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const bool is_test = ctx.Attr<bool>("is_test");
const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U; const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection"); const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
...@@ -712,7 +714,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -712,7 +714,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, ctx.Attr<int>("groups"), is_conv3d, is_test); filter, ctx.Attr<int>("groups"), is_conv3d);
std::shared_ptr<dnnl::memory> dst_memory_p; std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) { if (fuse_residual_conn) {
...@@ -731,7 +733,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -731,7 +733,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
{MKLDNN_ARG_DST, *dst_memory_p}}; {MKLDNN_ARG_DST, *dst_memory_p}};
if (bias) { if (bias) {
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test); auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias);
args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
} }
...@@ -783,11 +785,10 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -783,11 +785,10 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.Attr<std::vector<float>>("Scale_weights"); ctx.Attr<std::vector<float>>("Scale_weights");
const bool is_multi_channel = scale_weights_data.size() > 1; const bool is_multi_channel = scale_weights_data.size() > 1;
const int& groups = ctx.Attr<int>("groups"); const int& groups = ctx.Attr<int>("groups");
const bool& is_test = ctx.Attr<bool>("is_test");
int mask_reorder = int mask_reorder =
is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0; is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, groups, false, is_test, scale_weights_data, mask_reorder); filter, groups, false, scale_weights_data, mask_reorder);
std::shared_ptr<dnnl::memory> dst_memory_p; std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) { if (fuse_residual_conn) {
...@@ -822,7 +823,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -822,7 +823,7 @@ class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
handler.get_int8_bias_scales(ctx); handler.get_int8_bias_scales(ctx);
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder( auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
bias, is_test, scale_bias_data, mask_reorder); bias, scale_bias_data, mask_reorder);
args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
} }
......
...@@ -51,10 +51,10 @@ class ConvTransposeMKLDNNHandlerT ...@@ -51,10 +51,10 @@ class ConvTransposeMKLDNNHandlerT
: platform::MKLDNNHandlerT<T, mkldnn::deconvolution_forward>( : platform::MKLDNNHandlerT<T, mkldnn::deconvolution_forward>(
dev_ctx, mkldnn_engine, cpu_place, dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()), platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) { unique_name)),
is_test_(ctx.Attr<bool>("is_test")) {
if (!this->isCached()) { if (!this->isCached()) {
const bool is_test = ctx.Attr<bool>("is_test"); PADDLE_ENFORCE_EQ(is_test_, true,
PADDLE_ENFORCE_EQ(is_test, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"ConvTransposeMKLDNN works only for inference. " "ConvTransposeMKLDNN works only for inference. "
"The attribute \'is_test\' value should be set to " "The attribute \'is_test\' value should be set to "
...@@ -169,7 +169,7 @@ class ConvTransposeMKLDNNHandlerT ...@@ -169,7 +169,7 @@ class ConvTransposeMKLDNNHandlerT
const mkldnn::primitive_attr conv_trans_attr = const mkldnn::primitive_attr conv_trans_attr =
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta); CreatePostOps(fuse_activation, fuse_alpha, fuse_beta);
auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference auto fwd_prop_kind = is_test_ ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
if (bias) { if (bias) {
std::vector<int64_t> bias_tz = framework::vectorize(bias->dims()); std::vector<int64_t> bias_tz = framework::vectorize(bias->dims());
...@@ -231,18 +231,18 @@ class ConvTransposeMKLDNNHandlerT ...@@ -231,18 +231,18 @@ class ConvTransposeMKLDNNHandlerT
const auto target_src_mem_p = this->AcquireMemory(target_key_suffix); const auto target_src_mem_p = this->AcquireMemory(target_key_suffix);
user_src_mem_p->set_data_handle(platform::to_void_cast<T>(input_data)); user_src_mem_p->set_data_handle(platform::to_void_cast<T>(input_data));
if (user_src_mem_p != target_src_mem_p) { if (user_src_mem_p != target_src_mem_p) {
this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p"); this->AcquireReorder(user_src_mem_p, target_src_mem_p);
} }
return target_src_mem_p; return target_src_mem_p;
} }
} }
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
const framework::Tensor* filter, const int& groups, const bool& is_test) { const framework::Tensor* filter, const int& groups) {
// This is workaround to make execution faster, delete // This is workaround to make execution faster, delete
// if statement after including md inside Tensor // if statement after including md inside Tensor
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
if (is_test && weights_mem_p) { if (is_test_ && weights_mem_p) {
return weights_mem_p; return weights_mem_p;
} else { } else {
const K* filter_data = filter->data<K>(); const K* filter_data = filter->data<K>();
...@@ -277,15 +277,15 @@ class ConvTransposeMKLDNNHandlerT ...@@ -277,15 +277,15 @@ class ConvTransposeMKLDNNHandlerT
return this->template AcquireMemoryWithReorder<K>( return this->template AcquireMemoryWithReorder<K>(
user_src_md, this->fwd_pd_->weights_desc(), user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test_,
iohw2oihw_reorder); iohw2oihw_reorder);
} }
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder( std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
const framework::Tensor* bias, const bool& is_test) { const framework::Tensor* bias) {
auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
if (is_test && bias_mem_p) { if (is_test_ && bias_mem_p) {
return bias_mem_p; return bias_mem_p;
} else { } else {
const K* bias_data = bias->data<K>(); const K* bias_data = bias->data<K>();
...@@ -294,9 +294,12 @@ class ConvTransposeMKLDNNHandlerT ...@@ -294,9 +294,12 @@ class ConvTransposeMKLDNNHandlerT
MKLDNNMemoryFormat::x); MKLDNNMemoryFormat::x);
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_bias_md, this->fwd_pd_->bias_desc(), user_bias_md, this->fwd_pd_->bias_desc(),
platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test); platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test_);
} }
} }
private:
const bool is_test_;
}; };
template <typename T, typename K> template <typename T, typename K>
...@@ -325,8 +328,6 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -325,8 +328,6 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const bool is_test = ctx.Attr<bool>("is_test");
const auto* input = ctx.Input<Tensor>("Input"); const auto* input = ctx.Input<Tensor>("Input");
const auto* filter = ctx.Input<Tensor>("Filter"); const auto* filter = ctx.Input<Tensor>("Filter");
const auto* bias = const auto* bias =
...@@ -340,7 +341,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -340,7 +341,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
output, unique_name); output, unique_name);
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input); auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder( auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, ctx.Attr<int>("groups"), is_test); filter, ctx.Attr<int>("groups"));
std::shared_ptr<dnnl::memory> dst_memory_p = std::shared_ptr<dnnl::memory> dst_memory_p =
handler.template AcquireDstMemory<T_out>(output); handler.template AcquireDstMemory<T_out>(output);
...@@ -352,7 +353,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> { ...@@ -352,7 +353,7 @@ class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
{MKLDNN_ARG_DST, *dst_memory_p}}; {MKLDNN_ARG_DST, *dst_memory_p}};
if (bias) { if (bias) {
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test); auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias);
args.insert({MKLDNN_ARG_BIAS, *bias_memory_p}); args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
} }
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
......
...@@ -64,21 +64,11 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -64,21 +64,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
bool is_negative_input = ctx.Attr<bool>("is_negative_input"); bool is_negative_input = ctx.Attr<bool>("is_negative_input");
bool bfloat16 = ctx.Attr<bool>("bfloat16"); bool bfloat16 = ctx.Attr<bool>("bfloat16");
std::string key = // TODO(jczaja): Refactor with Acquire API
platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift,
is_negative_input, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
std::shared_ptr<mkldnn::memory> src_memory; std::shared_ptr<mkldnn::memory> src_memory;
std::shared_ptr<mkldnn::memory> dst_memory; std::shared_ptr<mkldnn::memory> dst_memory;
std::shared_ptr<reorder> reorder_p; std::shared_ptr<reorder> reorder_p;
reorder_p = std::static_pointer_cast<reorder>(dev_ctx.GetBlob(key_prim));
if (reorder_p == nullptr) {
std::string out_layout = ctx.Attr<std::string>("output_format"); std::string out_layout = ctx.Attr<std::string>("output_format");
MKLDNNMemoryFormat out_format = MKLDNNMemoryFormat out_format =
platform::data_format_to_memory_format(out_layout); platform::data_format_to_memory_format(out_layout);
...@@ -97,8 +87,8 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -97,8 +87,8 @@ class QuantOpKernel : public framework::OpKernel<T> {
auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32, auto src_md = platform::MKLDNNMemDesc({src_tz}, memory::data_type::f32,
input->format()); input->format());
src_memory = std::make_shared<mkldnn::memory>( src_memory = std::make_shared<mkldnn::memory>(src_md, engine,
src_md, engine, to_void_cast<T>(input_data)); to_void_cast<T>(input_data));
std::shared_ptr<mkldnn::memory::desc> dst_md; std::shared_ptr<mkldnn::memory::desc> dst_md;
if (bfloat16) { if (bfloat16) {
...@@ -108,38 +98,13 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -108,38 +98,13 @@ class QuantOpKernel : public framework::OpKernel<T> {
platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine, platform::SetDstMemoryQuantized<int8_t>(ctx, output, dst_tz, engine,
dst_md, dst_memory, out_format); dst_md, dst_memory, out_format);
} else { } else {
platform::SetDstMemoryQuantized<uint8_t>( platform::SetDstMemoryQuantized<uint8_t>(ctx, output, dst_tz, engine,
ctx, output, dst_tz, engine, dst_md, dst_memory, out_format); dst_md, dst_memory, out_format);
} }
auto reorder_pd = std::shared_ptr<reorder::primitive_desc>( auto reorder_pd = std::shared_ptr<reorder::primitive_desc>(
new reorder::primitive_desc(*src_memory, *dst_memory, attri)); new reorder::primitive_desc(*src_memory, *dst_memory, attri));
reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd)); reorder_p = std::shared_ptr<reorder>(new reorder(*reorder_pd));
dev_ctx.SetBlob(key_prim, reorder_p);
dev_ctx.SetBlob(key_src_mem, src_memory);
dev_ctx.SetBlob(key_dst_mem, dst_memory);
} else {
src_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_src_mem));
src_memory->set_data_handle(to_void_cast<T>(input_data));
dst_memory = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(key_dst_mem));
auto place = ctx.GetPlace();
if (bfloat16) {
dst_memory->set_data_handle(
output->mutable_data<paddle::platform::bfloat16>(place));
} else if (with_shift || !is_negative_input) {
uint8_t* output_data = output->mutable_data<uint8_t>(ctx.GetPlace());
if (with_shift) std::memset(output_data, scale_shift, output->numel());
dst_memory->set_data_handle(output_data);
} else {
dst_memory->set_data_handle(
output->mutable_data<int8_t>(ctx.GetPlace()));
}
}
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
{ {
platform::RecordEvent record_reorder("int_reorder", platform::RecordEvent record_reorder("int_reorder",
......
...@@ -11,6 +11,12 @@ See the License for the specific language governing permissions and ...@@ -11,6 +11,12 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include <set> #include <set>
#include <utility>
#ifdef _WIN32
#include <intrin.h>
#else
#include <x86intrin.h>
#endif
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
#include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h" #include "paddle/fluid/memory/allocation/cuda_device_context_allocator.h"
...@@ -666,7 +672,7 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) { ...@@ -666,7 +672,7 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
// of this executor // of this executor
for (auto& s : *p_exec_items_) { for (auto& s : *p_exec_items_) {
for (auto& v : (*s.second)[ptr]) { for (auto& v : (*s.second)[ptr]) {
(v.first)->erase(v.second); (v.first)->second.erase(v.second);
} }
s.second->erase(ptr); s.second->erase(ptr);
} }
...@@ -677,11 +683,26 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) { ...@@ -677,11 +683,26 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
} }
} }
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(void) const { std::string MKLDNNDeviceContext::PickLeastUsedShape(
p_exec_items_->erase(p_exec_items_->begin()); BlobPtr_t<ShapeBlob> sb) const {
auto ancient_one = sb->begin();
for (auto v = std::next(sb->begin()); v != sb->end(); ++v) {
if (v->second->first < ancient_one->second->first) {
ancient_one = v;
}
}
VLOG(2) << "num_shapes: " << sb->size()
<< ", remove all blobs of shape: " << ancient_one->first;
return ancient_one->first;
}
void MKLDNNDeviceContext::RemoveShapeEntriesWithExecutor(
std::string shape_to_be_removed) const {
p_exec_items_->erase(shape_to_be_removed);
} }
void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob, void MKLDNNDeviceContext::LinkEntryWithExecutor(
BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pblob,
KeyBlob::iterator it) const { KeyBlob::iterator it) const {
// Take current input shape from TLS // Take current input shape from TLS
// Take current executor addess from TLS // Take current executor addess from TLS
...@@ -719,7 +740,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -719,7 +740,7 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
BlobPtr_t<void> data) const { BlobPtr_t<void> data) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
BlobPtr_t<ShapeBlob> sBlob = nullptr; BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr; BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pBlob = nullptr;
int sid = tls().get_cur_mkldnn_session_id(); int sid = tls().get_cur_mkldnn_session_id();
...@@ -748,22 +769,24 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name, ...@@ -748,22 +769,24 @@ void MKLDNNDeviceContext::SetBlob(const std::string& name,
sBlob->size() && sBlob->size() &&
(sBlob->size() >= (sBlob->size() >=
static_cast<size_t>(tls().cur_input_shape_cache_capacity))) { static_cast<size_t>(tls().cur_input_shape_cache_capacity))) {
VLOG(2) << "sid=" << sid auto shape_to_be_erased = PickLeastUsedShape(sBlob);
<< ", remove all blobs of shape: " << sBlob->begin()->first; sBlob->erase(shape_to_be_erased);
sBlob->erase(sBlob->begin()->first); RemoveShapeEntriesWithExecutor(shape_to_be_erased);
RemoveShapeEntriesWithExecutor();
} }
pBlob = std::make_shared<KeyBlob>(); pBlob = std::make_shared<std::pair<unsigned long long, KeyBlob>>();
pBlob->first = __rdtsc();
(*sBlob)[tls().cur_input_shape_str] = pBlob; (*sBlob)[tls().cur_input_shape_str] = pBlob;
} else { } else {
pBlob = key_it->second; pBlob = key_it->second;
// Update time stamp
pBlob->first = __rdtsc();
} }
// Find Blob via name // Find Blob via name
auto blob_it = pBlob->find(name); auto blob_it = pBlob->second.find(name);
if (blob_it == pBlob->end()) { if (blob_it == pBlob->second.end()) {
auto el = auto el = pBlob->second.insert(
pBlob->insert(std::make_pair(name, data)); // (*pBlob)[name] = data; std::make_pair(name, data)); // (*pBlob)[name] = data;
// Register new element in per executor map // Register new element in per executor map
// to have easily erased when executor terminated // to have easily erased when executor terminated
LinkEntryWithExecutor(pBlob, el.first); LinkEntryWithExecutor(pBlob, el.first);
...@@ -779,7 +802,7 @@ unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const { ...@@ -779,7 +802,7 @@ unsigned int MKLDNNDeviceContext::GetCachedObjectsNumber(void) const {
unsigned int num_entries = 0; unsigned int num_entries = 0;
for (auto const& l3 : *p_blobmap_) { for (auto const& l3 : *p_blobmap_) {
for (auto const& l2 : *(l3.second)) { for (auto const& l2 : *(l3.second)) {
num_entries += (l2.second)->size(); num_entries += (l2.second->second).size();
} }
} }
return num_entries; return num_entries;
...@@ -789,7 +812,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob( ...@@ -789,7 +812,7 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
const std::string& name) const { const std::string& name) const {
BlobMap* pMap = p_blobmap_.get(); BlobMap* pMap = p_blobmap_.get();
BlobPtr_t<ShapeBlob> sBlob = nullptr; BlobPtr_t<ShapeBlob> sBlob = nullptr;
BlobPtr_t<KeyBlob> pBlob = nullptr; BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pBlob = nullptr;
int sid = tls().get_cur_mkldnn_session_id(); int sid = tls().get_cur_mkldnn_session_id();
...@@ -813,12 +836,14 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob( ...@@ -813,12 +836,14 @@ MKLDNNDeviceContext::BlobPtr_t<void> MKLDNNDeviceContext::GetBlob(
pBlob = sBlob_it->second; pBlob = sBlob_it->second;
// Find Blob via name // Find Blob via name
auto key_it = pBlob->find(name); auto key_it = pBlob->second.find(name);
if (key_it == pBlob->end()) { if (key_it == pBlob->second.end()) {
VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n"; VLOG(2) << "GetBlob sid=" << sid << ", miss blob=" << name << "\n";
return nullptr; return nullptr;
} }
// Update timestamp
sBlob_it->second->first = __rdtsc(); // TODO(windows)
VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n"; VLOG(2) << "GetBlob sid=" << sid << ", get blob=" << name << "\n";
// lock will be automatically released when out of scope // lock will be automatically released when out of scope
......
...@@ -757,18 +757,20 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -757,18 +757,20 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
// Following three maps are used to cache MKLDNN primitives. // Following three maps are used to cache MKLDNN primitives.
// There relations are: // There relations are:
// - BlobMap = Map<cur_thread_id, ShapeBlob> // - BlobMap = Map<cur_thread_id, ShapeBlob>
// - ShapeBlob = Map<cur_input_shape_str, KeyBlob> // - ShapeBlob = Map<cur_input_shape_str,<unsigned long long, KeyBlob>>
// - KeyBlob = Map<blob_name, blob> // - KeyBlob = Map<blob_name, blob>
using KeyBlob = umap_key_string_t<void>; using KeyBlob = umap_key_string_t<void>;
using ShapeBlob = umap_key_string_t<KeyBlob>; using ShapeBlob = umap_key_string_t<std::pair<unsigned long long, KeyBlob>>;
using BlobMap = umap_value_smart_t<int, ShapeBlob>; using BlobMap = umap_value_smart_t<int, ShapeBlob>;
// Auxillary two-level structure (shape, executor) to easier control // Auxillary two-level structure (shape, executor) to easier control
// clearing cache objects related to specific executor // clearing cache objects related to specific executor
using ExecKey = void*; using ExecKey = void*;
using ExecMapCacheIterPair = std::pair<BlobPtr_t<KeyBlob>, KeyBlob::iterator>; using ExecMapCacheIterPair =
std::pair<BlobPtr_t<std::pair<unsigned long long, KeyBlob>>,
KeyBlob::iterator>;
using ExecMap = using ExecMap =
std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>; std::unordered_map<ExecKey, std::vector<ExecMapCacheIterPair>>;
using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>; using ExecShape = std::unordered_map<std::string, std::shared_ptr<ExecMap>>;
...@@ -779,8 +781,11 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -779,8 +781,11 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
const mkldnn::engine& GetEngine() const { return tls().get_engine(); } const mkldnn::engine& GetEngine() const { return tls().get_engine(); }
// Register object to currently used executor's map // Register object to currently used executor's map
void LinkEntryWithExecutor(BlobPtr_t<KeyBlob>, KeyBlob::iterator) const; void LinkEntryWithExecutor(
void RemoveShapeEntriesWithExecutor(void) const; BlobPtr_t<std::pair<unsigned long long, KeyBlob>> pblob,
KeyBlob::iterator it) const;
void RemoveShapeEntriesWithExecutor(std::string) const;
std::string PickLeastUsedShape(BlobPtr_t<ShapeBlob> sb) const;
// Remove all entries from the blob map // Remove all entries from the blob map
void ResetBlobMap(void* ptr); void ResetBlobMap(void* ptr);
......
...@@ -500,18 +500,9 @@ class MKLDNNHandlerT { ...@@ -500,18 +500,9 @@ class MKLDNNHandlerT {
} }
void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p, void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p, const std::shared_ptr<mkldnn::memory>& target_memory_p) {
const std::string& suffix) { auto reorder_p =
const auto key_reorder_p = key_ + suffix + "reorder_p";
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p == nullptr) {
reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p); std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
}
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
...@@ -578,6 +569,8 @@ class MKLDNNHandlerT { ...@@ -578,6 +569,8 @@ class MKLDNNHandlerT {
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key)); std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
user_memory_p->set_data_handle(ptr); user_memory_p->set_data_handle(ptr);
// TODO(jczaja): Here we detect if reorder is cached it means it is needed
// need to change this to get rid of keys
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>( auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p)); dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) { if (reorder_p != nullptr) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册