Commit fa164241 authored by Zhang, Guoming

merge PR-14185

Parent 03df18c9
@@ -60,6 +60,11 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
  return conv_pd_->dst_primitive_desc().get_size();
}
mkldnn::memory::format GetDstFormat() const {
return static_cast<mkldnn::memory::format>(
conv_pd_->dst_primitive_desc().desc().data.format);
}
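GetDstFormat exposes the memory layout that the convolution primitive descriptor chose for its output. The kernel hunk further down compares it against the residual tensor's layout to decide whether an in-place sum is possible or a reorder is needed first; a minimal sketch of that check, using the names from that hunk:

    if (residual_param->format() != handler.GetDstFormat()) {
      // residual data is laid out differently from the conv output, so it
      // must be reordered into the dst layout before the fused sum
    }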
size_t GetDiffWeightsMemorySize() const {
  return conv_bwd_weights_pd_->diff_weights_primitive_desc().get_size();
}
@@ -111,6 +116,21 @@ class ConvMKLDNNHandler : public platform::MKLDNNHandler {
"@data-weights_mem_p", pipeline); "@data-weights_mem_p", pipeline);
} }
std::shared_ptr<mkldnn::memory> AcquireResidualDataMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_residual_data_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromResidualDataMemory(
const std::shared_ptr<mkldnn::memory>& user_residual_memory_p,
void* dst_ptr,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
return this->AcquireMemory(user_residual_memory_p,
this->AcquireDstMemoryFromPrimitive(dst_ptr),
"@residual_data_mem_p", pipeline);
}
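These two helpers work as a pair: AcquireResidualDataMemory wraps the raw residual buffer in an mkldnn::memory carrying the user's layout, and AcquireDstMemoryFromResidualDataMemory reorders it into the convolution's destination memory via the reorder-caching AcquireMemory overload added in the second hunk of this commit. The intended call sequence, as it appears in the kernel hunk below:

    auto user_residual_md = platform::MKLDNNMemDesc(
        residual_data_tz, residual_data_type, residual_param->format());
    // 1) wrap the user's residual buffer in its original layout
    auto user_residual_memory_p = handler.AcquireResidualDataMemory(
        user_residual_md, to_void_cast<T>(residual_param_data));
    // 2) reorder it into the conv dst memory; the reorder primitive is
    //    appended to `pipeline` and executed with the rest of the graph
    dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
        user_residual_memory_p, to_void_cast<T>(output_data), pipeline);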
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive(
    void* ptr) {
  return this->AcquireMemoryFromPrimitive(
@@ -505,6 +525,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
    handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
std::shared_ptr<mkldnn::memory> weights_memory_p;
if (is_INT8) {
  // scale mask: bit i marks a weights dimension with its own scale; with
  // groups (g != 1) both the group and output-channel dimensions vary,
  // otherwise only the output channel does; a single scale uses mask 0
  int mask_reorder =
      is_multi_channel ? ((g != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
@@ -549,7 +570,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
    }
  }
} else {
  // create reorder primitive if the input format is not the preferred one
// auto src_memory_p =
// handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
// auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
// user_weights_memory_p, pipeline, is_test);
// std::shared_ptr<mkldnn::memory> dst_memory_p;
if (fuse_residual_conn) {
  auto residual_param = ctx.Input<Tensor>("ResidualData");
  auto residual_param_data = residual_param->data<T>();
@@ -561,15 +588,35 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"Output and elementwise parameter need to have the " "Output and elementwise parameter need to have the "
"same dimension sizes"); "same dimension sizes");
if (residual_param->format() != handler.GetDstFormat()) {
auto output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
auto residual_data_tz =
paddle::framework::vectorize2int(residual_param->dims());
auto residual_data_type =
paddle::framework::ToMKLDNNDataType(residual_param->type());
auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler.AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
} else {
  output->ShareDataWith(*residual_param);
  auto output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
}
} else {
  auto output_data =
      output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
  dst_memory_p =
      handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
}
// dst_memory_p =
// handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
}
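The reason dst memory must be pre-filled with the residual data at all: with fuse_residual_conn the convolution primitive is built with a sum post-op, which accumulates the conv result on top of whatever dst already holds. A minimal sketch of that attribute setup (assuming it mirrors how this file builds the fused primitive elsewhere; not part of this hunk):

    // dst = conv(src, weights) + 1.0f * dst, so dst must already contain
    // the residual values, in the primitive's preferred dst layout
    mkldnn::post_ops post_operations;
    post_operations.append_sum(1.0f);
    mkldnn::primitive_attr conv_attr;
    conv_attr.set_post_ops(post_operations);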
// create convolution op primitive
std::shared_ptr<mkldnn::convolution_forward> conv_p;
@@ -912,7 +959,6 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
    user_weights_md, to_void_cast<T>(filter_data));
auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
    user_diff_dst_md, to_void_cast<T>(output_grad_data));
if (filter_grad) {
  auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
...
@@ -188,6 +188,26 @@ class MKLDNNHandler {
  return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
auto stored_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (stored_reorder_p) {
pipeline.push_back(*stored_reorder_p);
} else {
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
pipeline.push_back(*reorder_p);
}
return target_memory_p;
}
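Unlike the neighboring AcquireMemory overloads, this one takes two already-constructed memories and caches the reorder primitive itself under key_ + suffix + "reorder_p", so repeated kernel invocations re-push the stored reorder instead of rebuilding it. A sketch of the calling pattern (hypothetical user_mem_p and target_mem_p; the suffix is the one used by the conv handler above):

    std::vector<mkldnn::primitive> pipeline;
    // first call: creates reorder(user -> target) and stores it as a blob
    auto mem = handler.AcquireMemory(user_mem_p, target_mem_p,
                                     "@residual_data_mem_p", pipeline);
    // a later call with the same key finds the stored reorder and simply
    // appends it to the pipeline again; no new primitive is constructed
    mem = handler.AcquireMemory(user_mem_p, target_mem_p,
                                "@residual_data_mem_p", pipeline);
    mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();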
std::shared_ptr<mkldnn::memory> AcquireMemory(
    mkldnn::memory::primitive_desc& mpd,       // NOLINT
    mkldnn::memory::primitive_desc& user_mpd,  // NOLINT
...