提交 04804674 编写于 作者: X xiaolil1

fix fp32 init regression for se-resnext50 caused by dwc

上级 b5c44fd4
...@@ -378,20 +378,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -378,20 +378,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::convolution_forward> conv_p; std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p; std::shared_ptr<mkldnn::memory> src_memory_p;
std::shared_ptr<mkldnn::memory> user_src_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p; std::shared_ptr<mkldnn::memory> dst_memory_p;
std::vector<primitive> pipeline; std::vector<primitive> pipeline;
auto prim_key = key + "@conv_p"; auto prim_key = key + "@conv_p";
auto dst_key = key + "@dst_mem_p"; auto dst_key = key + "@dst_mem_p";
auto src_key = key + "@src_mem_p"; auto src_key = key + "@src_mem_p";
auto user_src_key = key + "@user_src_mem_p";
auto src_reorder_key = key + "@src_mem_p" + "reorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key)); conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key));
auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_reorder_key));
src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key)); src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key)); if(src_memory_reorder_p){
user_src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(user_src_key));
if (src_memory_p) { user_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
} else if(src_memory_p){
src_memory_p->set_data_handle(to_void_cast<T>(input_data)); src_memory_p->set_data_handle(to_void_cast<T>(input_data));
} }
dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd; std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
conv_pd = std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(dev_ctx.GetBlob(key_conv_pd)); conv_pd = std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(dev_ctx.GetBlob(key_conv_pd));
std::shared_ptr<ConvMKLDNNHandler> handler; std::shared_ptr<ConvMKLDNNHandler> handler;
...@@ -414,7 +421,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -414,7 +421,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
residual_data_tz, residual_data_type, residual_param->format()); residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler->AcquireResidualDataMemory( auto user_residual_memory_p = handler->AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data)); user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory( dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline); user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
} else { } else {
...@@ -462,11 +468,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -462,11 +468,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x); bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md, conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine, strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn); fuse_relu, fuse_residual_conn, is_test);
} else { } else {
conv_pd = conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings, ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn); mkldnn_engine, fuse_relu, fuse_residual_conn, is_test);
} }
// Save conv_pd/src_memory/weights_memory for backward pass // Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd); dev_ctx.SetBlob(key_conv_pd, conv_pd);
...@@ -474,7 +480,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -474,7 +480,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key)); handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
// create mkldnn memory from input tensors (data/weights) // create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p = user_src_memory_p =
handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data)); handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler->AcquireWeightsMemory( auto user_weights_memory_p = handler->AcquireWeightsMemory(
user_weights_md, to_void_cast<float>(filter_data)); user_weights_md, to_void_cast<float>(filter_data));
...@@ -508,7 +514,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -508,7 +514,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
residual_data_tz, residual_data_type, residual_param->format()); residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler->AcquireResidualDataMemory( auto user_residual_memory_p = handler->AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data)); user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory( dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline); user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
} else { } else {
...@@ -546,10 +551,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -546,10 +551,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} else { } else {
if(src_memory_reorder_p){
pipeline.push_back(*src_memory_reorder_p);
}
pipeline.push_back(*conv_p); pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(GetMKLDNNFormat(*dst_memory_p));
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册