提交 ff7c0a23 编写于 作者: X xiaolil1

enable key reuse

上级 ed1be6a8
...@@ -293,6 +293,102 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler { ...@@ -293,6 +293,102 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
conv_bwd_data_pd_; conv_bwd_data_pd_;
}; };
// Composite key describing one conv configuration (shapes, strides, paddings,
// dilations, groups, output name). Used to look up a previously generated
// primitive-cache key string so expensive key construction can be reused.
struct key_desc{
  // Hash functor for std::unordered_map. Mixes every field — including the
  // suffix, which operator== also compares — with a hash_combine-style mixer.
  // The previous implementation summed each vector's elements (permutation-
  // blind) and reused overlapping shift amounts, causing heavy collisions.
  struct Hash{
    std::size_t operator()(const key_desc &key) const{
      std::size_t seed = 0;
      auto combine = [&seed](std::size_t v) {
        seed ^= v + 0x9e3779b9 + (seed << 6) + (seed >> 2);
      };
      auto combine_vec = [&combine](const std::vector<int> &vec) {
        combine(vec.size());
        for (int v : vec) combine(static_cast<std::size_t>(v));
      };
      combine_vec(key.input_tz);
      combine_vec(key.weights_tz);
      combine_vec(key.strides);
      combine_vec(key.paddings);
      combine_vec(key.dilations);
      combine(static_cast<std::size_t>(key.groups));
      combine(std::hash<std::string>()(key.suffix));
      return seed;
    }
  };

  std::vector<int> input_tz;
  std::vector<int> weights_tz;
  std::vector<int> strides;
  std::vector<int> paddings;
  std::vector<int> dilations;
  int groups;
  const std::string suffix;

  key_desc(std::vector<int> input_tz, std::vector<int> weights_tz,
           std::vector<int> strides, std::vector<int> paddings,
           std::vector<int> dilations, int groups, std::string suffix)
      : input_tz(std::move(input_tz)),
        weights_tz(std::move(weights_tz)),
        strides(std::move(strides)),
        paddings(std::move(paddings)),
        dilations(std::move(dilations)),
        groups(groups),
        suffix(std::move(suffix)) {}

  // Field-by-field equality. std::vector::operator== checks sizes first,
  // fixing the old element-wise loop that read out of bounds when the two
  // descriptors held vectors of different lengths. Argument taken by
  // const reference to avoid copying five vectors and a string per compare.
  bool operator==(const key_desc &o) const{
    return input_tz == o.input_tz && weights_tz == o.weights_tz &&
           strides == o.strides && paddings == o.paddings &&
           dilations == o.dilations && groups == o.groups &&
           suffix == o.suffix;
  }
  bool operator!=(const key_desc& o) const { return !(*this == o); }
};
class handle_key{
public:
void SetKeyMap(std::unordered_map<key_desc, std::string, key_desc::Hash> &key_map, key_desc key_dsr, std::string key){
auto it = key_map.find(key_dsr);
if (it == key_map.end()) {
key_map[key_dsr] = key; // create new blob
} else {
(*it).second = key; // set data to existing blob
}
return;
}
std::string GetKeyMap(std::unordered_map<key_desc, std::string, key_desc::Hash> &key_map, key_desc key_dsr){
auto it = key_map.find(key_dsr);
if (it != key_map.end()) {
return (*it).second;
}
return "";
}
};
template <typename T> template <typename T>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public: public:
...@@ -353,7 +449,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -353,7 +449,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float* filter_data = filter->data<float>(); const float* filter_data = filter->data<float>();
std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims()); std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
std::vector<int> weights_tz = std::vector<int> weights_tz =
paddle::framework::vectorize2int(filter->dims()); paddle::framework::vectorize2int(filter->dims());
int g = std::max(groups, 1); int g = std::max(groups, 1);
if (g > 1) { if (g > 1) {
...@@ -371,20 +467,28 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -371,20 +467,28 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims()); std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
// Get unique name for storing MKLDNN primitives // Get unique name for storing MKLDNN primitives
const std::string key = ConvMKLDNNHandler::GetHash( handle_key keyhandler;
src_tz, weights_tz, strides, paddings, dilations, groups, key_desc key_dsr = {src_tz, weights_tz, strides, paddings, dilations, groups, ctx.op().Output("Output")};
ctx.op().Output("Output"));
static std::unordered_map<key_desc, std::string, key_desc::Hash> key_map;
static std::shared_ptr<std::unordered_map<ConvMKLDNNHandler::key_suffix_desc, std::string, ConvMKLDNNHandler::key_suffix_desc::Hash>> key_suffix_map(new std::unordered_map<ConvMKLDNNHandler::key_suffix_desc, std::string, ConvMKLDNNHandler::key_suffix_desc::Hash>({}));
bool key_reuse = true;
std::string none_key = "";
if(keyhandler.GetKeyMap(key_map, key_dsr) == none_key){
key_reuse = false;
}
std::string key;
if(!key_reuse){
key = ConvMKLDNNHandler::GetHash(
src_tz, weights_tz, strides, paddings, dilations, groups,
ctx.op().Output("Output"));
keyhandler.SetKeyMap(key_map, key_dsr, key);
} else{
key = keyhandler.GetKeyMap(key_map, key_dsr);
}
const std::string key_conv_pd = key + "@conv_pd"; const std::string key_conv_pd = key + "@conv_pd";
static std::unordered_map<std::string, std::vector<std::vector<float>>> scale_map; static std::unordered_map<std::string, std::vector<std::vector<float>>> scale_map;
//scale_map.insert({key_conv_pd,{1.0f}});
//scale_map[key_conv_pd]={0.1f};
bool scale_reuse = true; bool scale_reuse = true;
//auto scale_in_key = key + "@scale_in";
//auto scale_weights_key = key + "@scale_weights";
//auto scale_out_key = key + "@scale_out";
//auto output_shift_scale_key = key + "@output_shift_scale";
//auto sum_scale_key = key + "@sum_scale";
//auto scale_in_eltwise_key = key + "@scale_in_eltwise";
std::vector<float> scale_in_data; std::vector<float> scale_in_data;
std::vector<float> scale_out_data; std::vector<float> scale_out_data;
std::vector<float> scale_weights_data; std::vector<float> scale_weights_data;
...@@ -610,6 +714,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -610,6 +714,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key); ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);
handler.key_suffix_map_ = key_suffix_map;
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = auto user_src_memory_p =
......
...@@ -115,6 +115,28 @@ class MKLDNNHandler { ...@@ -115,6 +115,28 @@ class MKLDNNHandler {
key_(base_key), key_(base_key),
is_reusing_(false) {} is_reusing_(false) {}
// Identifies a (base key, suffix) pair used to cache the fully composed
// blob key, so the string concatenation is done only once per pair.
struct key_suffix_desc{
  struct Hash{
    std::size_t operator()(const key_suffix_desc &dsr) const{
      // Hash the actual string contents. The previous implementation ran
      // std::atoi over the strings, which returns 0 for any non-numeric
      // key/suffix — collapsing essentially every entry into one bucket
      // and degrading the unordered_map to a linear list.
      std::size_t h1 = std::hash<std::string>()(dsr.key);
      std::size_t h2 = std::hash<std::string>()(dsr.suffix);
      return h1 ^ (h2 + 0x9e3779b9 + (h1 << 6) + (h1 >> 2));
    }
  };

  std::string key;
  std::string suffix;

  key_suffix_desc(std::string key, std::string suffix)
      : key(std::move(key)), suffix(std::move(suffix)) {}

  // Compare by const reference (the old by-value form copied both strings
  // on every map lookup).
  bool operator==(const key_suffix_desc &o) const{
    return key == o.key && suffix == o.suffix;
  }
  bool operator!=(const key_suffix_desc& o) const { return !(*this == o); }
};
std::shared_ptr<mkldnn::memory> AcquireSrcMemory( std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::desc& md, void* ptr) { const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src_mem_p"); return this->AcquireMemory(md, ptr, "@user_src_mem_p");
...@@ -148,7 +170,20 @@ class MKLDNNHandler { ...@@ -148,7 +170,20 @@ class MKLDNNHandler {
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive( std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::primitive_desc mdp, void* ptr, mkldnn::memory::primitive_desc mdp, void* ptr,
const std::string& suffix) { const std::string& suffix) {
auto local_key = key_ + suffix; std::string local_key;
if(key_suffix_map_) {
key_suffix_desc dsr = {key_ , suffix};
if(GetKeySuffixMap(key_suffix_map_, dsr) == ""){
//std::cout<<"create key!!!!!!!"<<std::endl;
local_key = key_ + suffix;
SetKeySuffixMap(key_suffix_map_, dsr, local_key);
} else{
local_key = GetKeySuffixMap(key_suffix_map_, dsr);
}
} else{
local_key = key_ + suffix;
}
auto mem_p = auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false), PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
...@@ -170,7 +205,20 @@ class MKLDNNHandler { ...@@ -170,7 +205,20 @@ class MKLDNNHandler {
void* ptr, void* ptr,
const std::string& suffix) { const std::string& suffix) {
/*Generate key*/ /*Generate key*/
auto local_key = key_ + suffix; std::string local_key;
if(key_suffix_map_){
key_suffix_desc dsr = {key_ , suffix};
if(GetKeySuffixMap(key_suffix_map_, dsr) == ""){
//std::cout<<"create key!!!!!!!"<<std::endl;
local_key = key_ + suffix;
SetKeySuffixMap(key_suffix_map_, dsr, local_key);
} else{
local_key = GetKeySuffixMap(key_suffix_map_, dsr);
}
} else{
local_key = key_ + suffix;
}
auto mem_p = auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false), PADDLE_ENFORCE((mem_p != nullptr) || (is_reusing_ == false),
...@@ -193,7 +241,21 @@ class MKLDNNHandler { ...@@ -193,7 +241,21 @@ class MKLDNNHandler {
const std::shared_ptr<mkldnn::memory>& target_memory_p, const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix, const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto local_key = key_ + suffix;
std::string local_key;
if(key_suffix_map_){
key_suffix_desc dsr = {key_ , suffix};
if(GetKeySuffixMap(key_suffix_map_, dsr) == ""){
//std::cout<<"create key!!!!!!!"<<std::endl;
local_key = key_ + suffix;
SetKeySuffixMap(key_suffix_map_, dsr, local_key);
} else{
local_key = GetKeySuffixMap(key_suffix_map_, dsr);
}
} else{
local_key = key_ + suffix;
}
auto key_reorder_p = key_ + suffix + "reorder_p"; auto key_reorder_p = key_ + suffix + "reorder_p";
auto stored_reorder_p = std::static_pointer_cast<mkldnn::reorder>( auto stored_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
...@@ -222,7 +284,20 @@ class MKLDNNHandler { ...@@ -222,7 +284,20 @@ class MKLDNNHandler {
std::vector<float> scale_data = {1.0f}, std::vector<float> scale_data = {1.0f},
int mask = 0) { int mask = 0) {
// create reorder primitive if the input format is not the preferred one // create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix; std::string local_key;
if(key_suffix_map_){
key_suffix_desc dsr = {key_ , suffix};
if(GetKeySuffixMap(key_suffix_map_, dsr) == ""){
//std::cout<<"create key!!!!!!!"<<std::endl;
local_key = key_ + suffix;
SetKeySuffixMap(key_suffix_map_, dsr, local_key);
} else{
local_key = GetKeySuffixMap(key_suffix_map_, dsr);
}
} else{
local_key = key_ + suffix;
}
auto key_reorder_p = key_ + suffix + "reorder_p"; auto key_reorder_p = key_ + suffix + "reorder_p";
auto target_memory_p = auto target_memory_p =
...@@ -268,6 +343,26 @@ class MKLDNNHandler { ...@@ -268,6 +343,26 @@ class MKLDNNHandler {
return dims2str(operand_dims) + suffix; return dims2str(operand_dims) + suffix;
} }
void SetKeySuffixMap(std::shared_ptr<std::unordered_map<key_suffix_desc, std::string, key_suffix_desc::Hash>> key_suffix_map, key_suffix_desc key_suffix_dsr, std::string key){
auto it = (*key_suffix_map).find(key_suffix_dsr);
if (it == (*key_suffix_map).end()) {
(*key_suffix_map)[key_suffix_dsr] = key; // create new blob
} else {
(*it).second = key; // set data to existing blob
}
return;
}
std::string GetKeySuffixMap(std::shared_ptr<std::unordered_map<key_suffix_desc, std::string, key_suffix_desc::Hash>> key_suffix_map, key_suffix_desc key_suffix_dsr){
auto it = (*key_suffix_map).find(key_suffix_dsr);
if (it != (*key_suffix_map).end()) {
return (*it).second;
}
return "";
}
std::shared_ptr<std::unordered_map<key_suffix_desc, std::string, key_suffix_desc::Hash>> key_suffix_map_;
protected: protected:
static std::string dims2str(const mkldnn::memory::dims& operand_dims) { static std::string dims2str(const mkldnn::memory::dims& operand_dims) {
std::string dstr = ""; std::string dstr = "";
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册