提交 428b2b9e 编写于 作者: A Adam 提交者: Tao Luo

MKLDNN handler cleanup (#19713)

* MKLDNN handler cleanup

* MKLDNN handler cleanup
test=develop
上级 2c30e64b
...@@ -136,7 +136,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> { ...@@ -136,7 +136,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
std::vector<memory::primitive_desc> srcs_pd; std::vector<memory::primitive_desc> srcs_pd;
std::vector<float> scales = {1.0f, 1.0f}; std::vector<float> scales = {1.0f, 1.0f};
const std::string key = platform::MKLDNNHandler::GetHash( const std::string key = platform::GetHash(
src_x_tz, ctx.op().Output("Out") + std::to_string(x->format()) + src_x_tz, ctx.op().Output("Out") + std::to_string(x->format()) +
std::to_string(y->format())); std::to_string(y->format()));
......
...@@ -72,19 +72,17 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -72,19 +72,17 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
for (size_t i = 0; i < multi_input.size(); i++) { for (size_t i = 0; i < multi_input.size(); i++) {
platform::MKLDNNHandler::AppendKeyDims( platform::AppendKeyDims(
&key, paddle::framework::vectorize<int>(multi_input[i]->dims())); &key, paddle::framework::vectorize<int>(multi_input[i]->dims()));
} }
platform::MKLDNNHandler::AppendKey(&key, std::to_string(concat_axis)); platform::AppendKey(&key, std::to_string(concat_axis));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Out")); platform::AppendKey(&key, ctx.op().Output("Out"));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, platform::AppendKey(&key, std::to_string(multi_input[0]->format()));
std::to_string(multi_input[0]->format()));
if (platform::get_cur_mkldnn_session_id() == if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) { platform::kMKLDNNSessionID_Default) {
platform::MKLDNNHandler::AppendKey(&key, "-t:"); platform::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey( platform::AppendKey(&key, platform::ThreadIDasStr());
&key, platform::MKLDNNHandler::ThreadIDasStr());
} }
return key; return key;
} }
......
...@@ -417,7 +417,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -417,7 +417,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
std::string key; std::string key;
key.reserve(MaxKeyLength); key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey( platform::ConvMKLDNNHandler::CreateKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt, &key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_activation, fuse_residual_conn, input->format(), fuse_activation, fuse_residual_conn,
ctx.op().Input("Input") + ctx.op().Input("Filter")); ctx.op().Input("Input") + ctx.op().Input("Filter"));
...@@ -439,7 +439,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -439,7 +439,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::string key_tid = ""; std::string key_tid = "";
if (platform::get_cur_mkldnn_session_id() == if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) { platform::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::MKLDNNHandler::ThreadIDasStr(); key_tid = "-t:" + platform::ThreadIDasStr();
} }
auto prim_key = key + key_tid + "@conv_p"; auto prim_key = key + key_tid + "@conv_p";
......
...@@ -36,10 +36,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -36,10 +36,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const std::vector<int>& src_tz, const float scale_data) { const std::vector<int>& src_tz, const float scale_data) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(src_dt)); platform::AppendKey(&key, std::to_string(src_dt));
platform::MKLDNNHandler::AppendKeyDims(&key, src_tz); platform::AppendKeyDims(&key, src_tz);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(scale_data)); platform::AppendKey(&key, std::to_string(scale_data));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Output")); platform::AppendKey(&key, ctx.op().Output("Output"));
return key; return key;
} }
......
...@@ -35,10 +35,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx, ...@@ -35,10 +35,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const bool is_negative) { const bool is_negative) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, src_tz); platform::AppendKeyDims(&key, src_tz);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(scale_data)); platform::AppendKey(&key, std::to_string(scale_data));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(is_negative)); platform::AppendKey(&key, std::to_string(is_negative));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Output")); platform::AppendKey(&key, ctx.op().Output("Output"));
return key; return key;
} }
......
...@@ -205,7 +205,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> { ...@@ -205,7 +205,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
memory::dims softmax_tz = {src_tz[0], src_tz[1]}; memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Generate keys for storing/retrieving primitives for this operator // Generate keys for storing/retrieving primitives for this operator
const std::string key = const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out")); platform::GetHash(softmax_tz, ctx.op().Output("Out"));
SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx, SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
mkldnn_engine, key); mkldnn_engine, key);
...@@ -276,7 +276,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> { ...@@ -276,7 +276,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
// Currently only supports NC data format // Currently only supports NC data format
// retrieve eltwise primitive desc from device context // retrieve eltwise primitive desc from device context
const std::string key = const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Input("Out")); platform::GetHash(softmax_tz, ctx.op().Input("Out"));
const std::string key_softmax_pd = key + "@softmax_pd"; const std::string key_softmax_pd = key + "@softmax_pd";
auto softmax_pd = auto softmax_pd =
......
...@@ -179,5 +179,33 @@ inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) { ...@@ -179,5 +179,33 @@ inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
} }
} }
inline std::string ThreadIDasStr(void) {
return std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
}
inline std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
}
inline void AppendKey(std::string* key, const std::string& s) {
key->append(s);
}
inline std::string GetHash(const mkldnn::memory::dims& operand_dims,
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
}
inline void AppendKeyDims(std::string* key, const mkldnn::memory::dims& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
...@@ -38,7 +38,7 @@ class MKLDNNHandler { ...@@ -38,7 +38,7 @@ class MKLDNNHandler {
platform::kMKLDNNSessionID_Default) { platform::kMKLDNNSessionID_Default) {
key_ = key_common_; key_ = key_common_;
} else { } else {
key_ = key_common_ + "-t:" + MKLDNNHandler::ThreadIDasStr(); key_ = key_common_ + "-t:" + ThreadIDasStr();
} }
} }
...@@ -47,35 +47,19 @@ class MKLDNNHandler { ...@@ -47,35 +47,19 @@ class MKLDNNHandler {
return this->AcquireMemory(md, ptr, "@user_src_mem_p"); return this->AcquireMemory(md, ptr, "@user_src_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src2_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p", custom_func);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory( std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const mkldnn::memory::desc& md, void* ptr) { const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_dst_mem_p"); return this->AcquireMemory(md, ptr, "@user_dst_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory( std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
const mkldnn::memory::desc& md, void* ptr) { const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_dst_mem_p"); return this->AcquireMemory(md, ptr, "@user_diff_src_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory( std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const mkldnn::memory::desc& md, void* ptr) { const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_src_mem_p"); return this->AcquireMemory(md, ptr, "@user_diff_dst_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
...@@ -138,18 +122,6 @@ class MKLDNNHandler { ...@@ -138,18 +122,6 @@ class MKLDNNHandler {
return mem_p; return mem_p;
} }
std::shared_ptr<mkldnn::memory> AcquireMemory(
const mkldnn::memory::primitive_desc& mpd, const std::string& suffix) {
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mpd);
dev_ctx_.SetBlob(local_key, mem_p);
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory( std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::shared_ptr<mkldnn::memory>& user_memory_p, const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p, const std::shared_ptr<mkldnn::memory>& target_memory_p,
...@@ -221,67 +193,6 @@ class MKLDNNHandler { ...@@ -221,67 +193,6 @@ class MKLDNNHandler {
return target_memory_p; return target_memory_p;
} }
static std::string ThreadIDasStr(void) {
return std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
}
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
}
static void AppendKey(
std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& srcdt,
const MKLDNNMemoryFormat& format, const std::string& fuse_activation,
const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims);
AppendKeyVec(key, strides);
AppendKeyVec(key, paddings);
AppendKeyVec(key, dilations);
AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format));
AppendKey(key, fuse_activation);
AppendKey(key, std::to_string(residual));
AppendKey(key, suffix);
}
static void AppendKeyDims(std::string* key,
const mkldnn::memory::dims& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
static void AppendKeyVec(std::string* key, const std::vector<int>& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
static void AppendKey(std::string* key, const std::string& s) {
key->append(s);
}
protected:
static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
}
protected: protected:
const MKLDNNDeviceContext& dev_ctx_; const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_; mkldnn::engine engine_;
...@@ -324,6 +235,11 @@ class SumMKLDNNHandler : public MKLDNNHandler { ...@@ -324,6 +235,11 @@ class SumMKLDNNHandler : public MKLDNNHandler {
"@dst_mem_p"); "@dst_mem_p");
} }
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src2_mem_p");
}
std::shared_ptr<mkldnn::sum> AcquireSum( std::shared_ptr<mkldnn::sum> AcquireSum(
std::shared_ptr<mkldnn::memory> dst_memory, std::shared_ptr<mkldnn::memory> dst_memory,
std::vector<mkldnn::primitive::at>* inputs) { std::vector<mkldnn::primitive::at>* inputs) {
...@@ -458,12 +374,12 @@ class ActivationMKLDNNHandler : public MKLDNNHandler { ...@@ -458,12 +374,12 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
const float beta, const std::string& suffix) { const float beta, const std::string& suffix) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); platform::AppendKeyDims(&key, input_dims);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(algorithm)); platform::AppendKey(&key, std::to_string(algorithm));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); platform::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(alpha)); platform::AppendKey(&key, std::to_string(alpha));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(beta)); platform::AppendKey(&key, std::to_string(beta));
platform::MKLDNNHandler::AppendKey(&key, suffix); platform::AppendKey(&key, suffix);
return key; return key;
} }
...@@ -609,13 +525,13 @@ class LRNMKLDNNHandler : public MKLDNNHandler { ...@@ -609,13 +525,13 @@ class LRNMKLDNNHandler : public MKLDNNHandler {
const std::string& suffix) { const std::string& suffix) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); platform::AppendKeyDims(&key, input_dims);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(n)); platform::AppendKey(&key, std::to_string(n));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(alpha)); platform::AppendKey(&key, std::to_string(alpha));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(beta)); platform::AppendKey(&key, std::to_string(beta));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(k)); platform::AppendKey(&key, std::to_string(k));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); platform::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix); platform::AppendKey(&key, suffix);
return key; return key;
} }
...@@ -803,14 +719,14 @@ class PoolingMKLDNNHandler : public MKLDNNHandler { ...@@ -803,14 +719,14 @@ class PoolingMKLDNNHandler : public MKLDNNHandler {
const MKLDNNMemoryFormat& fmt, const std::string& suffix) { const MKLDNNMemoryFormat& fmt, const std::string& suffix) {
std::string key; std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength); key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims); platform::AppendKeyDims(&key, input_dims);
platform::MKLDNNHandler::AppendKey(&key, pooling_type); platform::AppendKey(&key, pooling_type);
platform::MKLDNNHandler::AppendKeyVec(&key, ksize); platform::AppendKeyDims(&key, ksize);
platform::MKLDNNHandler::AppendKeyVec(&key, strides); platform::AppendKeyDims(&key, strides);
platform::MKLDNNHandler::AppendKeyVec(&key, paddings); platform::AppendKeyDims(&key, paddings);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt)); platform::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt)); platform::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix); platform::AppendKey(&key, suffix);
return key; return key;
} }
...@@ -1160,6 +1076,17 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1160,6 +1076,17 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
pipeline); pipeline);
} }
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p", custom_func);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p, const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline, // NOLINT std::vector<mkldnn::primitive>& pipeline, // NOLINT
...@@ -1368,6 +1295,31 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler { ...@@ -1368,6 +1295,31 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
suffix; suffix;
} }
static void CreateKey(
std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& srcdt,
const MKLDNNMemoryFormat& format, const std::string& fuse_activation,
const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims);
AppendKeyDims(key, strides);
AppendKeyDims(key, paddings);
AppendKeyDims(key, dilations);
AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format));
AppendKey(key, fuse_activation);
AppendKey(key, std::to_string(residual));
AppendKey(key, suffix);
}
private: private:
std::shared_ptr<typename forward_t::primitive_desc> conv_pd_; std::shared_ptr<typename forward_t::primitive_desc> conv_pd_;
std::shared_ptr<typename backward_weights_t::primitive_desc> std::shared_ptr<typename backward_weights_t::primitive_desc>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册