提交 04804674 编写于 作者: X xiaolil1

fix fp32 init regression for se-resnext50 caused by dwc

上级 b5c44fd4
......@@ -378,20 +378,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p;
std::shared_ptr<mkldnn::memory> user_src_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p;
std::vector<primitive> pipeline;
auto prim_key = key + "@conv_p";
auto dst_key = key + "@dst_mem_p";
auto src_key = key + "@src_mem_p";
auto user_src_key = key + "@user_src_mem_p";
auto src_reorder_key = key + "@src_mem_p" + "reorder_p";
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>(dev_ctx.GetBlob(prim_key));
auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_reorder_key));
src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
if (src_memory_p) {
if(src_memory_reorder_p){
user_src_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(user_src_key));
user_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
} else if(src_memory_p){
src_memory_p->set_data_handle(to_void_cast<T>(input_data));
}
dst_memory_p = std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key));
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
conv_pd = std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(dev_ctx.GetBlob(key_conv_pd));
std::shared_ptr<ConvMKLDNNHandler> handler;
......@@ -414,7 +421,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler->AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
} else {
......@@ -462,11 +468,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn);
fuse_relu, fuse_residual_conn, is_test);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn);
mkldnn_engine, fuse_relu, fuse_residual_conn, is_test);
}
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
......@@ -474,7 +480,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
handler.reset(new ConvMKLDNNHandler(conv_pd, dev_ctx, mkldnn_engine, key));
// create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p =
user_src_memory_p =
handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler->AcquireWeightsMemory(
user_weights_md, to_void_cast<float>(filter_data));
......@@ -508,7 +514,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
residual_data_tz, residual_data_type, residual_param->format());
auto user_residual_memory_p = handler->AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data));
dst_memory_p = handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
} else {
......@@ -547,6 +552,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p));
} else {
if(src_memory_reorder_p){
pipeline.push_back(*src_memory_reorder_p);
}
pipeline.push_back(*conv_p);
stream(stream::kind::eager).submit(pipeline).wait();
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册