未验证 提交 f6cca625 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Making ThreadID info in caching key optional (#29272)

上级 08f24a31
...@@ -181,8 +181,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout, ...@@ -181,8 +181,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
if (in_format != out_format) { if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type); void* in_data = GetDataFromTensor(in, in_type);
const std::string key = std::string key =
platform::CreateKey(in_tz, in_format, out_format, in_type); platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type);
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx, platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx,
cpu_engine, key); cpu_engine, key);
......
...@@ -39,20 +39,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> { ...@@ -39,20 +39,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
const std::string& unique_name) const std::string& unique_name)
: platform::MKLDNNHandlerT<T, dnnl::gru_forward>( : platform::MKLDNNHandlerT<T, dnnl::gru_forward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
CreateKey(unique_name, MKLDNNGetDataType<T>(), Ti)), CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>(), Ti)),
N(N), N(N),
Ti(Ti), Ti(Ti),
IC(IC), IC(IC),
OC(OC) { OC(OC) {
// Create memory key without Ti because weights, bias and h0 memories // Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do // do not depend on Ti size but primitive and input/output memory do
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded(
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>()));
memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>());
} else {
memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>(), "-t:",
platform::ThreadIDasStr());
}
// Is it int8 kernel // Is it int8 kernel
const bool is_INT8 = std::is_same<T, uint8_t>::value; const bool is_INT8 = std::is_same<T, uint8_t>::value;
......
...@@ -109,13 +109,8 @@ class MultiGRUHandler { ...@@ -109,13 +109,8 @@ class MultiGRUHandler {
const std::string unique_name = ctx.OutputName("Hidden"); const std::string unique_name = ctx.OutputName("Hidden");
// Create memory key without Ti because weights, bias and h0 memories // Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do // do not depend on Ti size but primitive and input/output memory do
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded(
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>()));
memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>());
} else {
memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>(), "-t:",
platform::ThreadIDasStr());
}
key_ = memory_key_; key_ = memory_key_;
key_.append("T").append(std::to_string(Ti_)); key_.append("T").append(std::to_string(Ti_));
......
...@@ -48,7 +48,8 @@ class BatchNormMKLDNNHandler ...@@ -48,7 +48,8 @@ class BatchNormMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward, : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>( mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(framework::vectorize(x->dims()), unique_name)) { platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
unique_name)) {
if (!this->isCached()) { if (!this->isCached()) {
const float epsilon = ctx.Attr<float>("epsilon"); const float epsilon = ctx.Attr<float>("epsilon");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu"); const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
...@@ -89,7 +90,7 @@ class BatchNormMKLDNNHandler ...@@ -89,7 +90,7 @@ class BatchNormMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward, : platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>( mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) { platform::CreateKey(dev_ctx, dims, uniq_name)) {
auto diff_dst_md = auto diff_dst_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = auto src_md =
......
...@@ -158,9 +158,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -158,9 +158,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// If one of the multiple inputs of concat has an input size of 0, the // If one of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change // actual size of the multi_input will change
std::string key = platform::CreateKey( std::string key = platform::CreateKey(
paddle::framework::vectorize<int>(multi_input[0]->dims()), dev_ctx, paddle::framework::vectorize<int>(multi_input[0]->dims()),
multi_input.size(), ctx.OutputName("Out"), dt, multi_input.size(), ctx.OutputName("Out"), dt,
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); platform::ThreadIDasStr());
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@concat_p"; const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd"; const std::string key_concat_pd = key + "@concat_pd";
......
...@@ -95,7 +95,7 @@ class ConvMKLDNNHandlerT ...@@ -95,7 +95,7 @@ class ConvMKLDNNHandlerT
const std::string& unique_name) const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::convolution_forward>( : platform::MKLDNNHandlerT<T, mkldnn::convolution_forward>(
dev_ctx, mkldnn_engine, cpu_place, dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(framework::vectorize(input->dims()), platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) { unique_name)) {
if (!this->isCached()) { if (!this->isCached()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -521,8 +521,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -521,8 +521,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
std::string key = platform::CreateKey( std::string key =
src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter")); platform::CreateKey(dev_ctx, src_tz, src_dt,
ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false; bool need_s8_to_u8 = false;
...@@ -537,21 +538,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -537,21 +538,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// This is workaround for hacky implementation // This is workaround for hacky implementation
// of conv int8 mkl-dnn. Once conv fp32 and conv int8 // of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear // are merged/unified, this will disappear
std::string key_tid = ""; auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { auto prim_key = key_tid + "@conv_p";
key_tid = "-t:" + platform::ThreadIDasStr(); auto dst_key = key_tid + "@dst_mem_p";
} auto src_key = key_tid + "@src_mem_p";
auto weights_key = key_tid + "@weights_mem_p";
auto prim_key = key + key_tid + "@conv_p"; auto bias_key = key_tid + "@bias_mem_p";
auto dst_key = key + key_tid + "@dst_mem_p"; auto user_src_key = key_tid + "@user_src_mem_p";
auto src_key = key + key_tid + "@src_mem_p"; auto user_residual_key = key_tid + "@user_residual_data_mem_p";
auto weights_key = key + key_tid + "@weights_mem_p"; auto src_reorder_key = key_tid + "@src_mem_preorder_p";
auto bias_key = key + key_tid + "@bias_mem_p"; auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p";
auto user_src_key = key + key_tid + "@user_src_mem_p";
auto user_residual_key = key + key_tid + "@user_residual_data_mem_p";
auto src_reorder_key = key + key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>( conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key)); dev_ctx.GetBlob(prim_key));
...@@ -972,10 +969,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -972,10 +969,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "input" and "Filter" variable // Get an unique name from "argument" name of "input" and "Filter" variable
// as well as attributes of primitive to be created // as well as attributes of primitive to be created
// This name will be used as key when saving info into device context // This name will be used as key when saving info into device context
const std::string key = platform::CreateKey( std::string key = platform::CreateKey(
src_tz, ctx.InputName("Input") + ctx.InputName("Filter")); dev_ctx, src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@fwd_pd"; const std::string key_conv_pd = key + "@fwd_pd";
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
// Create user memory descriptors // Create user memory descriptors
...@@ -1090,8 +1088,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -1090,8 +1088,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::format_tag out_format = mkldnn::memory::format_tag out_format =
weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw
: mkldnn::memory::format_tag::goihw; : mkldnn::memory::format_tag::goihw;
const std::string key = std::string key = platform::CreateKey(dev_ctx, weights_tz, filter_fmt,
platform::CreateKey(weights_tz, filter_fmt, out_format, in_type); out_format, in_type);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(), platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(),
in_type, dev_ctx, mkldnn_engine, in_type, dev_ctx, mkldnn_engine,
......
...@@ -172,9 +172,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -172,9 +172,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims()); auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
const std::string key = const std::string key =
platform::CreateKey(src_tz, ctx.OutputName("Output")); platform::CreateKey(dev_ctx, src_tz, ctx.OutputName("Output"));
std::vector<mkldnn::primitive> pipeline; std::vector<mkldnn::primitive> pipeline;
......
...@@ -67,8 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> { ...@@ -67,8 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
mkldnn::memory::data_type src_dt = mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type()); paddle::framework::ToMKLDNNDataType(input->type());
MKLDNNMemoryFormat src_fmt = input->format(); MKLDNNMemoryFormat src_fmt = input->format();
std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt,
src_tz, ctx.OutputName("Output")); std::string key =
platform::CreateKey(dev_ctx, src_dt, src_tz, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r"; const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s"; const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d"; const std::string key_dst_mem = key + "@d";
......
...@@ -370,8 +370,9 @@ class FCPrimitiveFactory { ...@@ -370,8 +370,9 @@ class FCPrimitiveFactory {
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx, void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) { const ExecutionContext& ctx) {
const std::string key = std::string key = platform::CreateKey(dev_ctx);
platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix()); key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string weights_key = key + ctx.InputName("W"); const std::string weights_key = key + ctx.InputName("W");
const std::string bias_key = key + ctx.InputName("Bias"); const std::string bias_key = key + ctx.InputName("Bias");
dev_ctx.SetBlob(weights_key, weights_); dev_ctx.SetBlob(weights_key, weights_);
...@@ -541,10 +542,11 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input, ...@@ -541,10 +542,11 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
const Tensor* w, const Tensor* bias, LoDTensor* output, const Tensor* w, const Tensor* bias, LoDTensor* output,
bool fuse_relu, bool force_fp32_output) { bool fuse_relu, bool force_fp32_output) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const std::string prim_key = platform::CreateKey( std::string prim_key = platform::CreateKey(
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(), dev_ctx, input->format(), input->dims()[0],
input->dims()[0], framework::vectorize<int>(w->dims()), framework::vectorize<int>(w->dims()), ctx.OutputName("Out"));
ctx.OutputName("Out")); prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key);
constexpr bool is_int8 = constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value; std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value; bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value;
......
...@@ -30,7 +30,7 @@ class LayerNormMKLDNNHandler ...@@ -30,7 +30,7 @@ class LayerNormMKLDNNHandler
const std::string& uniq_name) const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>( : platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) { platform::CreateKey(dev_ctx, dims, uniq_name)) {
if (!this->isCached()) { if (!this->isCached()) {
auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
if (!is_test) { if (!is_test) {
......
...@@ -336,9 +336,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory( ...@@ -336,9 +336,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const auto& out_name = ctx.OutputName("Out"); const auto& out_name = ctx.OutputName("Out");
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto batch_size = ctx.Input<Tensor>("X")->dims()[0]; const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
std::string key = platform::CreateKey(dev_ctx, batch_size, out_name);
const std::string key = platform::CreateKey( key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), batch_size, out_name);
auto factory = auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key)); std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
......
...@@ -305,9 +305,11 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory( ...@@ -305,9 +305,11 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx, const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx,
const Tensor *input_x, const Tensor *input_y, const Tensor *input_x, const Tensor *input_y,
const mkldnn::engine &mkldnn_engine) { const mkldnn::engine &mkldnn_engine) {
const std::string key = platform::CreateKey( std::string key = platform::CreateKey(
input_x->type(), framework::vectorize(input_x->dims()), input_y->type(), dev_ctx, input_x->type(), framework::vectorize(input_x->dims()),
framework::vectorize(input_y->dims()), ctx.OutputName("Out")); input_y->type(), framework::vectorize(input_y->dims()),
ctx.OutputName("Out"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>( auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>(
dev_ctx.GetBlob(key)); dev_ctx.GetBlob(key));
......
...@@ -140,7 +140,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -140,7 +140,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable // Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context // This name will be used as key when referring info from device context
const std::string key = platform::CreateKey( const std::string key = platform::CreateKey(
diff_src_tz, pooling_type, ksize, strides, paddings, dev_ctx, diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, in_x->format(), ctx.InputName("Out")); memory::data_type::f32, in_x->format(), ctx.InputName("Out"));
platform::PoolingMKLDNNHandler<T> handler( platform::PoolingMKLDNNHandler<T> handler(
......
...@@ -64,9 +64,11 @@ class QuantOpKernel : public framework::OpKernel<T> { ...@@ -64,9 +64,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
bool is_negative_input = ctx.Attr<bool>("is_negative_input"); bool is_negative_input = ctx.Attr<bool>("is_negative_input");
bool bfloat16 = ctx.Attr<bool>("bfloat16"); bool bfloat16 = ctx.Attr<bool>("bfloat16");
std::string key = platform::CreateKey( std::string key =
platform::ThreadIDasStr(), src_tz, scale_data, scale_shift, platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift,
is_negative_input, ctx.OutputName("Output")); is_negative_input, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r"; const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s"; const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d"; const std::string key_dst_mem = key + "@d";
......
...@@ -65,9 +65,9 @@ class ReQuantOpKernel : public framework::OpKernel<T> { ...@@ -65,9 +65,9 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
float reorder_scale = scale_out / scale_in; float reorder_scale = scale_out / scale_in;
std::string key = std::string key = platform::CreateKey(dev_ctx, src_tz, scale_in, scale_out,
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in, ctx.OutputName("Output"));
scale_out, ctx.OutputName("Output")); key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r"; const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s"; const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d"; const std::string key_dst_mem = key + "@d";
......
...@@ -53,8 +53,8 @@ class SoftmaxMKLDNNHandler ...@@ -53,8 +53,8 @@ class SoftmaxMKLDNNHandler
mkldnn::softmax_backward>( mkldnn::softmax_backward>(
dev_ctx, mkldnn_engine, cpu_place, dev_ctx, mkldnn_engine, cpu_place,
// Softmax may be inplace then uniq_name is no longer unique // Softmax may be inplace then uniq_name is no longer unique
platform::CreateKey(framework::vectorize(input->dims()), axis, platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
uniq_name)) { axis, uniq_name)) {
if (!this->isCached()) { if (!this->isCached()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input->dims(), output->dims(), input->dims(), output->dims(),
...@@ -78,7 +78,7 @@ class SoftmaxMKLDNNHandler ...@@ -78,7 +78,7 @@ class SoftmaxMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward, : platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>( mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, axis, uniq_name)) { platform::CreateKey(dev_ctx, dims, axis, uniq_name)) {
auto data_softmax_md = auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md = auto diff_softmax_md =
......
...@@ -54,7 +54,8 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> { ...@@ -54,7 +54,8 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
: platform::MKLDNNHandlerT<T, dnnl::sum>( : platform::MKLDNNHandlerT<T, dnnl::sum>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(framework::vectorize(z->dims()), uniq_name)), platform::CreateKey(dev_ctx, framework::vectorize(z->dims()),
uniq_name)),
num_inputs_(0) { num_inputs_(0) {
for (size_t i = 0; i < in_vars.size(); i++) { for (size_t i = 0; i < in_vars.size(); i++) {
srcs_suffix_.push_back(std::string("-") + std::to_string(i)); srcs_suffix_.push_back(std::string("-") + std::to_string(i));
...@@ -184,8 +185,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -184,8 +185,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// For in-place execution which sum does not have we need to fake it // For in-place execution which sum does not have we need to fake it
// so from oneDNN dst memory we reorder data into input // so from oneDNN dst memory we reorder data into input
if (in_place) { if (in_place) {
const std::string reorder_key = platform::CreateKey( const std::string reorder_key =
framework::vectorize(output->dims()), ctx.OutputName("Out") + "-I"); platform::CreateKey(dev_ctx, framework::vectorize(output->dims()),
ctx.OutputName("Out") + "-I");
auto& in_out = in_vars[0]->Get<framework::LoDTensor>(); auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
auto output_tz = framework::vectorize<int64_t>(output->dims()); auto output_tz = framework::vectorize<int64_t>(output->dims());
......
...@@ -48,7 +48,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -48,7 +48,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto nchw_tz = paddle::framework::vectorize<int64_t>(input->dims()); auto nchw_tz = paddle::framework::vectorize<int64_t>(input->dims());
const std::string key = platform::CreateKey(nchw_tz, ctx.OutputName("Out")); const std::string key =
platform::CreateKey(dev_ctx, nchw_tz, ctx.OutputName("Out"));
platform::TransposeMKLDNNHandler<T> handler(nchw_tz, axis, dev_ctx, platform::TransposeMKLDNNHandler<T> handler(nchw_tz, axis, dev_ctx,
mkldnn_engine, key); mkldnn_engine, key);
...@@ -103,7 +104,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -103,7 +104,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto nchw_tz = paddle::framework::vectorize<int64_t>(out_grad->dims()); auto nchw_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());
const std::string key = platform::CreateKey( const std::string key = platform::CreateKey(
nchw_tz, ctx.OutputName(framework::GradVarName("X"))); dev_ctx, nchw_tz, ctx.OutputName(framework::GradVarName("X")));
platform::TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, dev_ctx, platform::TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, dev_ctx,
mkldnn_engine, key); mkldnn_engine, key);
......
...@@ -532,6 +532,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -532,6 +532,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; } void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
const std::string& GetKeySuffix(void) const { return key_suffix_; } const std::string& GetKeySuffix(void) const { return key_suffix_; }
// Disable adding thread ID to the key
void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; };
bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; };
// Prevent next ResetBlobMap() // Prevent next ResetBlobMap()
void BlockNextCacheClearing(); void BlockNextCacheClearing();
...@@ -554,6 +558,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext { ...@@ -554,6 +558,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::shared_ptr<std::mutex> p_mutex_; std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false; bool block_next_cache_clearing_ = false;
std::string key_suffix_; // Key identifying current Executor std::string key_suffix_; // Key identifying current Executor
bool key_attach_thread_id_ = true;
}; };
#endif #endif
......
...@@ -431,11 +431,6 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) { ...@@ -431,11 +431,6 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
} }
} }
inline unsigned int HashPointer(uintptr_t ptr) {
// Get four less meaningful digits in decimal numerals
return ptr % 1000;
}
// If MKLDNN build and CPU place then register suffix in DeviceContext // If MKLDNN build and CPU place then register suffix in DeviceContext
inline void AttachPointerHashToMKLDNNKey(void* ptr, inline void AttachPointerHashToMKLDNNKey(void* ptr,
const platform::Place& place) { const platform::Place& place) {
...@@ -443,20 +438,34 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr, ...@@ -443,20 +438,34 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance(); platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx = platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place); (platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->SetKeySuffix("E" + std::to_string(platform::HashPointer( dev_ctx->SetKeySuffix("E" +
reinterpret_cast<uintptr_t>(ptr)))); std::to_string(reinterpret_cast<uintptr_t>(ptr)));
// When NaiveExecutor/Executor is used no info on thread id is needed in a
// key
dev_ctx->DisableThreadInfoInKey();
} }
} }
template <typename... ArgTypes> template <typename... ArgTypes>
inline std::string CreateKey(ArgTypes&&... args) { inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
ArgTypes&&... args) {
std::string key; std::string key;
key.reserve(64); key.reserve(64);
using expand_type = int[]; using expand_type = int[];
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...}; expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
key += dev_ctx.GetKeySuffix();
return key; return key;
} }
inline std::string ExtendKeyWithThreadInfoIfNeeded(
const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
return ((dev_ctx.IsThreadIdUsedInKey() == true) &&
(platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default))
? key + "-t:" + ThreadIDasStr()
: key;
}
inline std::vector<std::vector<int64_t>> ToMkldnnPadding( inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
const std::vector<int64_t>& paddings) { const std::vector<int64_t>& paddings) {
if (paddings.size() == 6) { if (paddings.size() == 6) {
......
...@@ -43,16 +43,9 @@ class MKLDNNHandlerT { ...@@ -43,16 +43,9 @@ class MKLDNNHandlerT {
engine_(engine), engine_(engine),
place_(cpu_place), place_(cpu_place),
key_common_(base_key), key_common_(base_key),
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)),
fwd_pd_(nullptr), fwd_pd_(nullptr),
bwd_pd_(nullptr) { bwd_pd_(nullptr) {}
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
}
key_ += dev_ctx.GetKeySuffix();
}
std::shared_ptr<TForward> AcquireForwardPrimitive() { std::shared_ptr<TForward> AcquireForwardPrimitive() {
const std::string key_p = key_ + "@fwd_p"; const std::string key_p = key_ + "@fwd_p";
...@@ -306,8 +299,8 @@ class MKLDNNHandlerT { ...@@ -306,8 +299,8 @@ class MKLDNNHandlerT {
const MKLDNNDeviceContext& dev_ctx_; const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_; mkldnn::engine engine_;
platform::Place place_; platform::Place place_;
std::string key_;
std::string key_common_; std::string key_common_;
std::string key_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_; std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_; std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
}; };
...@@ -317,15 +310,10 @@ class MKLDNNHandler { ...@@ -317,15 +310,10 @@ class MKLDNNHandler {
public: public:
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine, MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key) const std::string& base_key)
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) { : dev_ctx_(dev_ctx),
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() != engine_(engine),
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) { key_common_(base_key),
key_ = key_common_; key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) {}
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
}
key_ += dev_ctx.GetKeySuffix();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::desc& md, void* ptr) { const mkldnn::memory::desc& md, void* ptr) {
...@@ -508,8 +496,8 @@ class MKLDNNHandler { ...@@ -508,8 +496,8 @@ class MKLDNNHandler {
protected: protected:
const MKLDNNDeviceContext& dev_ctx_; const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_; mkldnn::engine engine_;
std::string key_;
std::string key_common_; std::string key_common_;
std::string key_;
}; };
template <typename T> template <typename T>
...@@ -524,7 +512,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> { ...@@ -524,7 +512,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
: platform::MKLDNNHandlerT<T, dnnl::binary>( : platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place, dev_ctx, engine, cpu_place,
platform::CreateKey( platform::CreateKey(
framework::vectorize(x->dims()), dev_ctx, framework::vectorize(x->dims()),
uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) { uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) {
// bradcasting combined with in-place may require // bradcasting combined with in-place may require
auto rankdiff = x->dims().size() - y->dims().size(); auto rankdiff = x->dims().size() - y->dims().size();
...@@ -627,7 +615,7 @@ class ActivationMKLDNNHandler ...@@ -627,7 +615,7 @@ class ActivationMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward, : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>( mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, "a", algorithm, unique_name)) { platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training, this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
...@@ -645,7 +633,7 @@ class ActivationMKLDNNHandler ...@@ -645,7 +633,7 @@ class ActivationMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward, : platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>( mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, "a", algorithm, unique_name)) { platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) {
auto diff_dst_md = platform::MKLDNNMemDesc( auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt); dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md = auto src_md =
...@@ -676,7 +664,7 @@ class LRNMKLDNNHandler ...@@ -676,7 +664,7 @@ class LRNMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>( : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, mkldnn_engine, cpu_place, dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(framework::vectorize(input->dims()), platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) { unique_name)) {
if (!this->isCached()) { if (!this->isCached()) {
const int n = ctx.Attr<int>("n"); const int n = ctx.Attr<int>("n");
...@@ -712,7 +700,7 @@ class LRNMKLDNNHandler ...@@ -712,7 +700,7 @@ class LRNMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>( : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, unique_name)) { platform::CreateKey(dev_ctx, dims, unique_name)) {
auto src_md = auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_md = auto diff_md =
...@@ -752,7 +740,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, ...@@ -752,7 +740,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward, : platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward>( mkldnn::pooling_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(framework::vectorize(input->dims()), platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
framework::ToMKLDNNDataType(input->type()), framework::ToMKLDNNDataType(input->type()),
unique_name)) { unique_name)) {
if (!this->isCached()) { if (!this->isCached()) {
...@@ -861,7 +849,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward, ...@@ -861,7 +849,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward, : platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward>( mkldnn::pooling_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place, dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(diff_src_dims, dt, unique_name)) { platform::CreateKey(dev_ctx, diff_src_dims, dt, unique_name)) {
auto diff_dst_md = mkldnn::memory::desc( auto diff_dst_md = mkldnn::memory::desc(
diff_dst_dims, platform::MKLDNNGetDataType<T>(), diff_dst_fmt); diff_dst_dims, platform::MKLDNNGetDataType<T>(), diff_dst_fmt);
auto diff_src_md = auto diff_src_md =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册