You need to sign in or sign up before continuing.
提交 428b2b9e 编写于 作者: A Adam 提交者: Tao Luo

MKLDNN handler cleanup (#19713)

* MKLDNN handler cleanup

* MKLDNN handler cleanup
test=develop
上级 2c30e64b
......@@ -136,7 +136,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
std::vector<memory::primitive_desc> srcs_pd;
std::vector<float> scales = {1.0f, 1.0f};
const std::string key = platform::MKLDNNHandler::GetHash(
const std::string key = platform::GetHash(
src_x_tz, ctx.op().Output("Out") + std::to_string(x->format()) +
std::to_string(y->format()));
......
......@@ -72,19 +72,17 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
for (size_t i = 0; i < multi_input.size(); i++) {
platform::MKLDNNHandler::AppendKeyDims(
platform::AppendKeyDims(
&key, paddle::framework::vectorize<int>(multi_input[i]->dims()));
}
platform::MKLDNNHandler::AppendKey(&key, std::to_string(concat_axis));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Out"));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key,
std::to_string(multi_input[0]->format()));
platform::AppendKey(&key, std::to_string(concat_axis));
platform::AppendKey(&key, ctx.op().Output("Out"));
platform::AppendKey(&key, std::to_string(dt));
platform::AppendKey(&key, std::to_string(multi_input[0]->format()));
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
platform::MKLDNNHandler::AppendKey(&key, "-t:");
platform::MKLDNNHandler::AppendKey(
&key, platform::MKLDNNHandler::ThreadIDasStr());
platform::AppendKey(&key, "-t:");
platform::AppendKey(&key, platform::ThreadIDasStr());
}
return key;
}
......
......@@ -417,7 +417,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// Get unique name for storing MKLDNN primitives
std::string key;
key.reserve(MaxKeyLength);
platform::ConvMKLDNNHandler::AppendKey(
platform::ConvMKLDNNHandler::CreateKey(
&key, src_tz, weights_tz, strides, paddings, dilations, groups, src_dt,
input->format(), fuse_activation, fuse_residual_conn,
ctx.op().Input("Input") + ctx.op().Input("Filter"));
......@@ -439,7 +439,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::string key_tid = "";
if (platform::get_cur_mkldnn_session_id() ==
platform::kMKLDNNSessionID_Default) {
key_tid = "-t:" + platform::MKLDNNHandler::ThreadIDasStr();
key_tid = "-t:" + platform::ThreadIDasStr();
}
auto prim_key = key + key_tid + "@conv_p";
......
......@@ -36,10 +36,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const std::vector<int>& src_tz, const float scale_data) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(src_dt));
platform::MKLDNNHandler::AppendKeyDims(&key, src_tz);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(scale_data));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Output"));
platform::AppendKey(&key, std::to_string(src_dt));
platform::AppendKeyDims(&key, src_tz);
platform::AppendKey(&key, std::to_string(scale_data));
platform::AppendKey(&key, ctx.op().Output("Output"));
return key;
}
......
......@@ -35,10 +35,10 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
const bool is_negative) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, src_tz);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(scale_data));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(is_negative));
platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Output"));
platform::AppendKeyDims(&key, src_tz);
platform::AppendKey(&key, std::to_string(scale_data));
platform::AppendKey(&key, std::to_string(is_negative));
platform::AppendKey(&key, ctx.op().Output("Output"));
return key;
}
......
......@@ -205,7 +205,7 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
memory::dims softmax_tz = {src_tz[0], src_tz[1]};
// Generate keys for storing/retriving primitives for this operator
const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
platform::GetHash(softmax_tz, ctx.op().Output("Out"));
SoftmaxMKLDNNHandler<T> handler(softmax_tz, MKLDNNMemoryFormat::nc, dev_ctx,
mkldnn_engine, key);
......@@ -276,7 +276,7 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
// Currently only supports NC data format
// retrieve eltwise primitive desc from device context
const std::string key =
platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Input("Out"));
platform::GetHash(softmax_tz, ctx.op().Input("Out"));
const std::string key_softmax_pd = key + "@softmax_pd";
auto softmax_pd =
......
......@@ -179,5 +179,33 @@ inline MKLDNNMemoryFormat StringToMKLDNNFormat(std::string* format) {
}
}
inline std::string ThreadIDasStr(void) {
return std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
}
inline std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
}
inline void AppendKey(std::string* key, const std::string& s) {
key->append(s);
}
inline std::string GetHash(const mkldnn::memory::dims& operand_dims,
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
}
inline void AppendKeyDims(std::string* key, const mkldnn::memory::dims& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
} // namespace platform
} // namespace paddle
......@@ -38,7 +38,7 @@ class MKLDNNHandler {
platform::kMKLDNNSessionID_Default) {
key_ = key_common_;
} else {
key_ = key_common_ + "-t:" + MKLDNNHandler::ThreadIDasStr();
key_ = key_common_ + "-t:" + ThreadIDasStr();
}
}
......@@ -47,35 +47,19 @@ class MKLDNNHandler {
return this->AcquireMemory(md, ptr, "@user_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src2_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p", custom_func);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_dst_mem_p");
return this->AcquireMemory(md, ptr, "@user_diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_src_mem_p");
return this->AcquireMemory(md, ptr, "@user_diff_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
......@@ -138,18 +122,6 @@ class MKLDNNHandler {
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const mkldnn::memory::primitive_desc& mpd, const std::string& suffix) {
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(mpd);
dev_ctx_.SetBlob(local_key, mem_p);
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
......@@ -221,67 +193,6 @@ class MKLDNNHandler {
return target_memory_p;
}
static std::string ThreadIDasStr(void) {
return std::to_string(
std::hash<std::thread::id>()(std::this_thread::get_id()));
}
static std::string GetHash(mkldnn::memory::dims& operand_dims, // NOLINT
const std::string& suffix) {
return dims2str(operand_dims) + suffix;
}
static void AppendKey(
std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& srcdt,
const MKLDNNMemoryFormat& format, const std::string& fuse_activation,
const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims);
AppendKeyVec(key, strides);
AppendKeyVec(key, paddings);
AppendKeyVec(key, dilations);
AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format));
AppendKey(key, fuse_activation);
AppendKey(key, std::to_string(residual));
AppendKey(key, suffix);
}
static void AppendKeyDims(std::string* key,
const mkldnn::memory::dims& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
static void AppendKeyVec(std::string* key, const std::vector<int>& dims) {
for (unsigned int i = 0; i < dims.size(); i++) {
AppendKey(key, std::to_string(dims[i]));
}
}
static void AppendKey(std::string* key, const std::string& s) {
key->append(s);
}
protected:
static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = "";
for (size_t i = 0; i < operand_dims.size(); ++i) {
dstr += std::to_string(operand_dims[i]) + "-";
}
return dstr;
}
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
......@@ -324,6 +235,11 @@ class SumMKLDNNHandler : public MKLDNNHandler {
"@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireSecondSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src2_mem_p");
}
std::shared_ptr<mkldnn::sum> AcquireSum(
std::shared_ptr<mkldnn::memory> dst_memory,
std::vector<mkldnn::primitive::at>* inputs) {
......@@ -458,12 +374,12 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
const float beta, const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(algorithm));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(alpha));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(beta));
platform::MKLDNNHandler::AppendKey(&key, suffix);
platform::AppendKeyDims(&key, input_dims);
platform::AppendKey(&key, std::to_string(algorithm));
platform::AppendKey(&key, std::to_string(fmt));
platform::AppendKey(&key, std::to_string(alpha));
platform::AppendKey(&key, std::to_string(beta));
platform::AppendKey(&key, suffix);
return key;
}
......@@ -609,13 +525,13 @@ class LRNMKLDNNHandler : public MKLDNNHandler {
const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(n));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(alpha));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(beta));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(k));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix);
platform::AppendKeyDims(&key, input_dims);
platform::AppendKey(&key, std::to_string(n));
platform::AppendKey(&key, std::to_string(alpha));
platform::AppendKey(&key, std::to_string(beta));
platform::AppendKey(&key, std::to_string(k));
platform::AppendKey(&key, std::to_string(fmt));
platform::AppendKey(&key, suffix);
return key;
}
......@@ -803,14 +719,14 @@ class PoolingMKLDNNHandler : public MKLDNNHandler {
const MKLDNNMemoryFormat& fmt, const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
platform::MKLDNNHandler::AppendKey(&key, pooling_type);
platform::MKLDNNHandler::AppendKeyVec(&key, ksize);
platform::MKLDNNHandler::AppendKeyVec(&key, strides);
platform::MKLDNNHandler::AppendKeyVec(&key, paddings);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(dt));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix);
platform::AppendKeyDims(&key, input_dims);
platform::AppendKey(&key, pooling_type);
platform::AppendKeyDims(&key, ksize);
platform::AppendKeyDims(&key, strides);
platform::AppendKeyDims(&key, paddings);
platform::AppendKey(&key, std::to_string(dt));
platform::AppendKey(&key, std::to_string(fmt));
platform::AppendKey(&key, suffix);
return key;
}
......@@ -1160,6 +1076,17 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p", custom_func);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
......@@ -1368,6 +1295,31 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
suffix;
}
static void CreateKey(
std::string* key, const mkldnn::memory::dims& input_dims,
const mkldnn::memory::dims& weights_dims, const std::vector<int>& strides,
const std::vector<int>& paddings, const std::vector<int>& dilations,
const int& groups, const mkldnn::memory::data_type& srcdt,
const MKLDNNMemoryFormat& format, const std::string& fuse_activation,
const bool& residual, const std::string& suffix) {
AppendKeyDims(key, input_dims);
AppendKeyDims(key, weights_dims);
AppendKeyDims(key, strides);
AppendKeyDims(key, paddings);
AppendKeyDims(key, dilations);
AppendKey(key, std::to_string(groups));
AppendKey(key, std::to_string(srcdt));
AppendKey(key, std::to_string(format));
AppendKey(key, fuse_activation);
AppendKey(key, std::to_string(residual));
AppendKey(key, suffix);
}
private:
std::shared_ptr<typename forward_t::primitive_desc> conv_pd_;
std::shared_ptr<typename backward_weights_t::primitive_desc>
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册