未验证 提交 f6cca625 编写于 作者: J Jacek Czaja 提交者: GitHub

[oneDNN] Making ThreadID info in caching key optional (#29272)

上级 08f24a31
......@@ -181,8 +181,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
if (in_format != out_format) {
void* in_data = GetDataFromTensor(in, in_type);
const std::string key =
platform::CreateKey(in_tz, in_format, out_format, in_type);
std::string key =
platform::CreateKey(*dev_ctx, in_tz, in_format, out_format, in_type);
platform::ReorderMKLDNNHandler handler(in_tz, in.type(), in_type, *dev_ctx,
cpu_engine, key);
......
......@@ -39,20 +39,15 @@ class GRUMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::gru_forward> {
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, dnnl::gru_forward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
CreateKey(unique_name, MKLDNNGetDataType<T>(), Ti)),
CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>(), Ti)),
N(N),
Ti(Ti),
IC(IC),
OC(OC) {
// Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>());
} else {
memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>(), "-t:",
platform::ThreadIDasStr());
}
memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>()));
// Is it int8 kernel
const bool is_INT8 = std::is_same<T, uint8_t>::value;
......
......@@ -109,13 +109,8 @@ class MultiGRUHandler {
const std::string unique_name = ctx.OutputName("Hidden");
// Create memory key without Ti because weights, bias and h0 memories
// do not depend on Ti size but primitive and input/output memory do
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>());
} else {
memory_key_ = CreateKey(unique_name, MKLDNNGetDataType<T>(), "-t:",
platform::ThreadIDasStr());
}
memory_key_ = platform::ExtendKeyWithThreadInfoIfNeeded(
dev_ctx, CreateKey(dev_ctx, unique_name, MKLDNNGetDataType<T>()));
key_ = memory_key_;
key_.append("T").append(std::to_string(Ti_));
......
......@@ -48,7 +48,8 @@ class BatchNormMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(framework::vectorize(x->dims()), unique_name)) {
platform::CreateKey(dev_ctx, framework::vectorize(x->dims()),
unique_name)) {
if (!this->isCached()) {
const float epsilon = ctx.Attr<float>("epsilon");
const bool fuse_with_relu = ctx.Attr<bool>("fuse_with_relu");
......@@ -89,7 +90,7 @@ class BatchNormMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::batch_normalization_forward,
mkldnn::batch_normalization_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) {
platform::CreateKey(dev_ctx, dims, uniq_name)) {
auto diff_dst_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
......
......@@ -158,9 +158,10 @@ class ConcatMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// If one of the multiple inputs of concat has an input size of 0, the
// actual size of the multi_input will change
std::string key = platform::CreateKey(
paddle::framework::vectorize<int>(multi_input[0]->dims()),
dev_ctx, paddle::framework::vectorize<int>(multi_input[0]->dims()),
multi_input.size(), ctx.OutputName("Out"), dt,
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix());
platform::ThreadIDasStr());
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@concat_p";
const std::string key_concat_pd = key + "@concat_pd";
......
......@@ -95,7 +95,7 @@ class ConvMKLDNNHandlerT
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::convolution_forward>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(framework::vectorize(input->dims()),
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
......@@ -521,8 +521,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
std::string key = platform::CreateKey(
src_tz, src_dt, ctx.InputName("Input") + ctx.InputName("Filter"));
std::string key =
platform::CreateKey(dev_ctx, src_tz, src_dt,
ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
......@@ -537,21 +538,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// This is workaround for hacky implementation
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear
std::string key_tid = "";
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::ThreadIDasStr();
}
auto prim_key = key + key_tid + "@conv_p";
auto dst_key = key + key_tid + "@dst_mem_p";
auto src_key = key + key_tid + "@src_mem_p";
auto weights_key = key + key_tid + "@weights_mem_p";
auto bias_key = key + key_tid + "@bias_mem_p";
auto user_src_key = key + key_tid + "@user_src_mem_p";
auto user_residual_key = key + key_tid + "@user_residual_data_mem_p";
auto src_reorder_key = key + key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key + key_tid + "@residual_data_mem_preorder_p";
auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
auto prim_key = key_tid + "@conv_p";
auto dst_key = key_tid + "@dst_mem_p";
auto src_key = key_tid + "@src_mem_p";
auto weights_key = key_tid + "@weights_mem_p";
auto bias_key = key_tid + "@bias_mem_p";
auto user_src_key = key_tid + "@user_src_mem_p";
auto user_residual_key = key_tid + "@user_residual_data_mem_p";
auto src_reorder_key = key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(
dev_ctx.GetBlob(prim_key));
......@@ -972,10 +969,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "input" and "Filter" variable
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
const std::string key = platform::CreateKey(
src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
std::string key = platform::CreateKey(
dev_ctx, src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@fwd_pd";
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
std::vector<primitive> pipeline;
// Create user memory descriptors
......@@ -1090,8 +1088,9 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
mkldnn::memory::format_tag out_format =
weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw
: mkldnn::memory::format_tag::goihw;
const std::string key =
platform::CreateKey(weights_tz, filter_fmt, out_format, in_type);
std::string key = platform::CreateKey(dev_ctx, weights_tz, filter_fmt,
out_format, in_type);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(),
in_type, dev_ctx, mkldnn_engine,
......
......@@ -172,9 +172,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
// Get unique name for storing MKLDNN primitives
const std::string key =
platform::CreateKey(src_tz, ctx.OutputName("Output"));
platform::CreateKey(dev_ctx, src_tz, ctx.OutputName("Output"));
std::vector<mkldnn::primitive> pipeline;
......
......@@ -67,8 +67,11 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
MKLDNNMemoryFormat src_fmt = input->format();
std::string key = platform::CreateKey(platform::ThreadIDasStr(), src_dt,
src_tz, ctx.OutputName("Output"));
std::string key =
platform::CreateKey(dev_ctx, src_dt, src_tz, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
......
......@@ -370,8 +370,9 @@ class FCPrimitiveFactory {
void CacheWeightsAndBias(const MKLDNNDeviceContext& dev_ctx,
const ExecutionContext& ctx) {
const std::string key =
platform::CreateKey(platform::ThreadIDasStr(), dev_ctx.GetKeySuffix());
std::string key = platform::CreateKey(dev_ctx);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string weights_key = key + ctx.InputName("W");
const std::string bias_key = key + ctx.InputName("Bias");
dev_ctx.SetBlob(weights_key, weights_);
......@@ -541,10 +542,11 @@ static void ExecuteFc(const ExecutionContext& ctx, const LoDTensor* input,
const Tensor* w, const Tensor* bias, LoDTensor* output,
bool fuse_relu, bool force_fp32_output) {
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const std::string prim_key = platform::CreateKey(
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), input->format(),
input->dims()[0], framework::vectorize<int>(w->dims()),
ctx.OutputName("Out"));
std::string prim_key = platform::CreateKey(
dev_ctx, input->format(), input->dims()[0],
framework::vectorize<int>(w->dims()), ctx.OutputName("Out"));
prim_key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, prim_key);
constexpr bool is_int8 =
std::is_same<T_in, int8_t>::value || std::is_same<T_in, uint8_t>::value;
bool is_bfloat16 = std::is_same<T_in, paddle::platform::bfloat16>::value;
......
......@@ -30,7 +30,7 @@ class LayerNormMKLDNNHandler
const std::string& uniq_name)
: platform::MKLDNNHandlerT<T, dnnl::layer_normalization_forward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, uniq_name)) {
platform::CreateKey(dev_ctx, dims, uniq_name)) {
if (!this->isCached()) {
auto md = dnnl::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
if (!is_test) {
......
......@@ -336,9 +336,8 @@ static std::shared_ptr<MatMulFactory<XT, YT, OT>> GetPrimitiveFactory(
const auto& out_name = ctx.OutputName("Out");
const auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto batch_size = ctx.Input<Tensor>("X")->dims()[0];
const std::string key = platform::CreateKey(
platform::ThreadIDasStr(), dev_ctx.GetKeySuffix(), batch_size, out_name);
std::string key = platform::CreateKey(dev_ctx, batch_size, out_name);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
auto factory =
std::static_pointer_cast<MatMulFactory<XT, YT, OT>>(dev_ctx.GetBlob(key));
......
......@@ -305,9 +305,11 @@ std::shared_ptr<MulPrimitiveFactory<XT, YT, OT>> GetPrimitiveFactory(
const MKLDNNDeviceContext &dev_ctx, const ExecutionContext &ctx,
const Tensor *input_x, const Tensor *input_y,
const mkldnn::engine &mkldnn_engine) {
const std::string key = platform::CreateKey(
input_x->type(), framework::vectorize(input_x->dims()), input_y->type(),
framework::vectorize(input_y->dims()), ctx.OutputName("Out"));
std::string key = platform::CreateKey(
dev_ctx, input_x->type(), framework::vectorize(input_x->dims()),
input_y->type(), framework::vectorize(input_y->dims()),
ctx.OutputName("Out"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
auto prim_creator = std::static_pointer_cast<MulPrimitiveFactory<XT, YT, OT>>(
dev_ctx.GetBlob(key));
......
......@@ -140,7 +140,7 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// Get an unique name from "argument" name of "Out" variable
// This name will be used as key when referring info from device context
const std::string key = platform::CreateKey(
diff_src_tz, pooling_type, ksize, strides, paddings,
dev_ctx, diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, in_x->format(), ctx.InputName("Out"));
platform::PoolingMKLDNNHandler<T> handler(
......
......@@ -64,9 +64,11 @@ class QuantOpKernel : public framework::OpKernel<T> {
bool is_negative_input = ctx.Attr<bool>("is_negative_input");
bool bfloat16 = ctx.Attr<bool>("bfloat16");
std::string key = platform::CreateKey(
platform::ThreadIDasStr(), src_tz, scale_data, scale_shift,
is_negative_input, ctx.OutputName("Output"));
std::string key =
platform::CreateKey(dev_ctx, src_tz, scale_data, scale_shift,
is_negative_input, ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
......
......@@ -65,9 +65,9 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
float reorder_scale = scale_out / scale_in;
std::string key =
platform::CreateKey(platform::ThreadIDasStr(), src_tz, scale_in,
scale_out, ctx.OutputName("Output"));
std::string key = platform::CreateKey(dev_ctx, src_tz, scale_in, scale_out,
ctx.OutputName("Output"));
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_prim = key + "@r";
const std::string key_src_mem = key + "@s";
const std::string key_dst_mem = key + "@d";
......
......@@ -53,8 +53,8 @@ class SoftmaxMKLDNNHandler
mkldnn::softmax_backward>(
dev_ctx, mkldnn_engine, cpu_place,
// Softmax may be inplace then uniq_name is no longer unique
platform::CreateKey(framework::vectorize(input->dims()), axis,
uniq_name)) {
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
axis, uniq_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
input->dims(), output->dims(),
......@@ -78,7 +78,7 @@ class SoftmaxMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::softmax_forward,
mkldnn::softmax_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, axis, uniq_name)) {
platform::CreateKey(dev_ctx, dims, axis, uniq_name)) {
auto data_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_softmax_md =
......
......@@ -54,7 +54,8 @@ class SumMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::sum> {
: platform::MKLDNNHandlerT<T, dnnl::sum>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(framework::vectorize(z->dims()), uniq_name)),
platform::CreateKey(dev_ctx, framework::vectorize(z->dims()),
uniq_name)),
num_inputs_(0) {
for (size_t i = 0; i < in_vars.size(); i++) {
srcs_suffix_.push_back(std::string("-") + std::to_string(i));
......@@ -184,8 +185,9 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// For in-place execution which sum does not have we need to fake it
// so from oneDNN dst memory we reorder data into input
if (in_place) {
const std::string reorder_key = platform::CreateKey(
framework::vectorize(output->dims()), ctx.OutputName("Out") + "-I");
const std::string reorder_key =
platform::CreateKey(dev_ctx, framework::vectorize(output->dims()),
ctx.OutputName("Out") + "-I");
auto& in_out = in_vars[0]->Get<framework::LoDTensor>();
auto output_tz = framework::vectorize<int64_t>(output->dims());
......
......@@ -48,7 +48,8 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto nchw_tz = paddle::framework::vectorize<int64_t>(input->dims());
const std::string key = platform::CreateKey(nchw_tz, ctx.OutputName("Out"));
const std::string key =
platform::CreateKey(dev_ctx, nchw_tz, ctx.OutputName("Out"));
platform::TransposeMKLDNNHandler<T> handler(nchw_tz, axis, dev_ctx,
mkldnn_engine, key);
......@@ -103,7 +104,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto nchw_tz = paddle::framework::vectorize<int64_t>(out_grad->dims());
const std::string key = platform::CreateKey(
nchw_tz, ctx.OutputName(framework::GradVarName("X")));
dev_ctx, nchw_tz, ctx.OutputName(framework::GradVarName("X")));
platform::TransposeMKLDNNHandler<T> handler(nchw_tz, reversed_axis, dev_ctx,
mkldnn_engine, key);
......
......@@ -532,6 +532,10 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
void SetKeySuffix(const std::string& suffix) { key_suffix_ = suffix; }
const std::string& GetKeySuffix(void) const { return key_suffix_; }
// Disable adding thread ID to the key
void DisableThreadInfoInKey(void) { key_attach_thread_id_ = false; };
bool IsThreadIdUsedInKey(void) const { return key_attach_thread_id_; };
// Prevent next ResetBlobMap()
void BlockNextCacheClearing();
......@@ -554,6 +558,7 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
std::shared_ptr<std::mutex> p_mutex_;
bool block_next_cache_clearing_ = false;
std::string key_suffix_; // Key identifying current Executor
bool key_attach_thread_id_ = true;
};
#endif
......
......@@ -431,11 +431,6 @@ inline void AppendKey(std::string* key, const std::vector<T>& dims) {
}
}
inline unsigned int HashPointer(uintptr_t ptr) {
// Get four less meaningful digits in decimal numerals
return ptr % 1000;
}
// If MKLDNN build and CPU place then register suffix in DeviceContext
inline void AttachPointerHashToMKLDNNKey(void* ptr,
const platform::Place& place) {
......@@ -443,20 +438,34 @@ inline void AttachPointerHashToMKLDNNKey(void* ptr,
platform::DeviceContextPool& pool = platform::DeviceContextPool::Instance();
platform::MKLDNNDeviceContext* dev_ctx =
(platform::MKLDNNDeviceContext*)pool.Get(place);
dev_ctx->SetKeySuffix("E" + std::to_string(platform::HashPointer(
reinterpret_cast<uintptr_t>(ptr))));
dev_ctx->SetKeySuffix("E" +
std::to_string(reinterpret_cast<uintptr_t>(ptr)));
// When NaiveExecutor/Executor is used no info on thread id is needed in a
// key
dev_ctx->DisableThreadInfoInKey();
}
}
template <typename... ArgTypes>
inline std::string CreateKey(ArgTypes&&... args) {
inline std::string CreateKey(const platform::MKLDNNDeviceContext& dev_ctx,
ArgTypes&&... args) {
std::string key;
key.reserve(64);
using expand_type = int[];
expand_type{0, (AppendKey(&key, std::forward<ArgTypes>(args)), 0)...};
key += dev_ctx.GetKeySuffix();
return key;
}
inline std::string ExtendKeyWithThreadInfoIfNeeded(
const platform::MKLDNNDeviceContext& dev_ctx, const std::string& key) {
return ((dev_ctx.IsThreadIdUsedInKey() == true) &&
(platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() ==
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default))
? key + "-t:" + ThreadIDasStr()
: key;
}
inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
const std::vector<int64_t>& paddings) {
if (paddings.size() == 6) {
......
......@@ -43,16 +43,9 @@ class MKLDNNHandlerT {
engine_(engine),
place_(cpu_place),
key_common_(base_key),
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)),
fwd_pd_(nullptr),
bwd_pd_(nullptr) {
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
}
key_ += dev_ctx.GetKeySuffix();
}
bwd_pd_(nullptr) {}
std::shared_ptr<TForward> AcquireForwardPrimitive() {
const std::string key_p = key_ + "@fwd_p";
......@@ -306,8 +299,8 @@ class MKLDNNHandlerT {
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
platform::Place place_;
std::string key_;
std::string key_common_;
std::string key_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
};
......@@ -317,15 +310,10 @@ class MKLDNNHandler {
public:
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: dev_ctx_(dev_ctx), engine_(engine), key_common_(base_key) {
if (platform::MKLDNNDeviceContext::tls().get_cur_mkldnn_session_id() !=
platform::MKLDNNDeviceContextThreadLocals::kMKLDNNSessionID_Default) {
key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + ThreadIDasStr();
}
key_ += dev_ctx.GetKeySuffix();
}
: dev_ctx_(dev_ctx),
engine_(engine),
key_common_(base_key),
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) {}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
......@@ -508,8 +496,8 @@ class MKLDNNHandler {
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
std::string key_;
std::string key_common_;
std::string key_;
};
template <typename T>
......@@ -524,7 +512,7 @@ class BinaryMKLDNNHandler : public platform::MKLDNNHandlerT<T, dnnl::binary> {
: platform::MKLDNNHandlerT<T, dnnl::binary>(
dev_ctx, engine, cpu_place,
platform::CreateKey(
framework::vectorize(x->dims()),
dev_ctx, framework::vectorize(x->dims()),
uniq_name + (algo == dnnl::algorithm::binary_mul ? "M" : ""))) {
// bradcasting combined with in-place may require
auto rankdiff = x->dims().size() - y->dims().size();
......@@ -627,7 +615,7 @@ class ActivationMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, "a", algorithm, unique_name)) {
platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) {
auto md = mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
......@@ -645,7 +633,7 @@ class ActivationMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::eltwise_forward,
mkldnn::eltwise_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, "a", algorithm, unique_name)) {
platform::CreateKey(dev_ctx, dims, "a", algorithm, unique_name)) {
auto diff_dst_md = platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto src_md =
......@@ -676,7 +664,7 @@ class LRNMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(framework::vectorize(input->dims()),
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCached()) {
const int n = ctx.Attr<int>("n");
......@@ -712,7 +700,7 @@ class LRNMKLDNNHandler
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, unique_name)) {
platform::CreateKey(dev_ctx, dims, unique_name)) {
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_md =
......@@ -752,7 +740,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(framework::vectorize(input->dims()),
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
framework::ToMKLDNNDataType(input->type()),
unique_name)) {
if (!this->isCached()) {
......@@ -861,7 +849,7 @@ class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(diff_src_dims, dt, unique_name)) {
platform::CreateKey(dev_ctx, diff_src_dims, dt, unique_name)) {
auto diff_dst_md = mkldnn::memory::desc(
diff_dst_dims, platform::MKLDNNGetDataType<T>(), diff_dst_fmt);
auto diff_src_md =
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册