Unverified · Commit c11d9b30 authored by Jacek Czaja, committed by GitHub

[oneDNN] conv2d fwd&bwd optimization (#27871)

Parent d932b561
......@@ -211,22 +211,8 @@ class ConvMKLDNNHandlerT
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
- // TODO(jczaja): This is a workaround to make the grad op UTs' numerical
- // gradient computation correct, as this op is called directly without a
- // fetch op following it, so the numerical grad is computed (in Python)
- // using block formats, which gives wrong results
- const std::string data_format = ctx.Attr<std::string>("data_format");
- auto chosen_memory_format =
-     is_test ? MKLDNNMemoryFormat::any
-             : platform::data_format_to_memory_format(data_format);
- // Check the format for user's special output
- if (chosen_memory_format != MKLDNNMemoryFormat::any) {
-   if (is_conv3d) {
-     chosen_memory_format = platform::MKLDNNFormatForSize(
-         src_tz.size(), chosen_memory_format);
-   }
- }
+ auto chosen_memory_format = MKLDNNMemoryFormat::any;
auto data_type = mkldnn::memory::data_type::f32;
if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
std::is_same<T_out, platform::bfloat16>::value)
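The replacement above drops the data_format-driven layout selection and always passes `MKLDNNMemoryFormat::any`, letting oneDNN choose the layout. A minimal standalone sketch of what `any` does, using the plain oneDNN 1.x API (shapes and names are illustrative, not from the patch):

```cpp
#include <mkldnn.hpp>
using namespace mkldnn;

int main() {
  engine eng(engine::kind::cpu, 0);

  const memory::dims src_tz = {1, 3, 224, 224};  // NCHW
  const memory::dims wei_tz = {8, 3, 3, 3};      // OIHW
  const memory::dims dst_tz = {1, 8, 222, 222};

  // format_tag::any defers the layout decision to the primitive, which
  // selects the (usually blocked) format that runs fastest on this CPU.
  auto src_md = memory::desc(src_tz, memory::data_type::f32, memory::format_tag::any);
  auto wei_md = memory::desc(wei_tz, memory::data_type::f32, memory::format_tag::any);
  auto dst_md = memory::desc(dst_tz, memory::data_type::f32, memory::format_tag::any);

  auto conv_pd = convolution_forward::primitive_desc(
      convolution_forward::desc(prop_kind::forward_inference,
                                algorithm::convolution_direct, src_md, wei_md,
                                dst_md, /*strides=*/{1, 1},
                                /*padding_l=*/{0, 0}, /*padding_r=*/{0, 0}),
      eng);

  // The layouts the primitive actually chose can be queried back; plain
  // NCHW user data is then reordered into them before execution.
  auto chosen_src_md = conv_pd.src_desc();
  (void)chosen_src_md;
  return 0;
}
```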
......@@ -351,14 +337,16 @@ class ConvMKLDNNHandlerT
std::shared_ptr<mkldnn::memory> AcquireResidualMemory(
const framework::Tensor* residual_param) {
- const T* residual_data = residual_param->data<T>();
+ void* residual_data =
+     residual_param->type() == framework::DataTypeTrait<T_out>::DataType()
+         ? to_void_cast<T_out>(residual_param->data<T_out>())
+         : to_void_cast<T>(residual_param->data<T>());
auto user_residual_md = platform::MKLDNNMemDesc(
framework::vectorize(residual_param->dims()),
framework::ToMKLDNNDataType(residual_param->type()),
residual_param->format());
- return this->AcquireMemoryFromPrimitive(user_residual_md,
-                                         to_void_cast<T>(residual_data),
+ return this->AcquireMemoryFromPrimitive(user_residual_md, residual_data,
"@user_residual_data_mem_p");
}
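The rewritten `AcquireResidualMemory` now dispatches on the residual tensor's runtime dtype, so the buffer is erased to `void*` under the element type it actually holds (`T_out` when the residual already matches the output type, `T` otherwise). A simplified standalone sketch of that dispatch pattern, with a hypothetical minimal `Tensor` standing in for Paddle's:

```cpp
#include <typeindex>

// Hypothetical minimal tensor: a buffer plus its runtime dtype.
struct Tensor {
  const void* buf;
  std::type_index dtype;
  template <typename T>
  const T* data() const { return static_cast<const T*>(buf); }
  std::type_index type() const { return dtype; }
};

template <typename T>
void* to_void_cast(const T* p) {
  return const_cast<void*>(static_cast<const void*>(p));
}

// Pick the pointer interpretation matching the tensor's runtime dtype
// before erasing it to void*, so it agrees with the memory descriptor
// built from residual_param->type().
template <typename T, typename T_out>
void* residual_ptr(const Tensor* residual_param) {
  return residual_param->type() == std::type_index(typeid(T_out))
             ? to_void_cast<T_out>(residual_param->data<T_out>())
             : to_void_cast<T>(residual_param->data<T>());
}
```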
......@@ -973,22 +961,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
* the memory format preferred for best performance
*/
- // TODO(jczaja): Once GRAD NHWC is working, format 'any' should be used
- // exclusively. But while the forward pass enforces NCHW for training,
- // we need NCHW here as well to avoid performance degradation in
- // relu_grad and pool2d_grad
- std::string data_format = ctx.Attr<std::string>("data_format");
- auto chosen_memory_format =
-     platform::data_format_to_memory_format(data_format);
+ auto chosen_memory_format = MKLDNNMemoryFormat::any;
weights_format = MKLDNNMemoryFormat::any;
- // Check the format for user's special output
- if (chosen_memory_format != MKLDNNMemoryFormat::any) {
-   if (is_conv3d) {
-     chosen_memory_format =
-         platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
-   }
- }
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
......@@ -1055,9 +1029,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const size_t size = handler.GetDiffWeightsMemorySize();
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
+ // For convolution with groups, write the filter grad into a oneDNN
+ // buffer and then reorder it into the filter_grad tensor
  auto diff_weights_memory_p =
-     handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
-         reinterpret_cast<void*>(filter_grad_data));
+     g > 1 ? handler.AcquireDiffWeightsMemoryFromWeightsPrimitive()
+           : handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
+                 reinterpret_cast<void*>(filter_grad_data));
auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights();
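In the grouped case the diff-weights memory is acquired without a user pointer, so oneDNN allocates its own scratch buffer in whatever blocked layout the primitive prefers; only the ungrouped case writes straight into filter_grad's buffer. The distinction boils down to the two mkldnn::memory constructors, sketched here outside the handler (names are illustrative):

```cpp
#include <mkldnn.hpp>
using namespace mkldnn;

// Sketch: grouped convolutions get a library-allocated scratch buffer that
// is reordered into the output tensor afterwards; ungrouped convolutions
// let the primitive write directly into the user-provided buffer.
memory acquire_diff_weights(const memory::desc& md, const engine& eng,
                            int groups, void* filter_grad_data) {
  return groups > 1 ? memory(md, eng)                     // oneDNN allocates
                    : memory(md, eng, filter_grad_data);  // wraps user buffer
}
```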
......@@ -1072,8 +1049,43 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// in oneDNN, groups in convolution are treated as a separate dimension,
// which is not the case in PaddlePaddle
auto filter_fmt = GetMKLDNNFormat(*diff_weights_memory_p);
- filter_grad->set_format(platform::MKLDNNFormatForSize(
-     g > 1 ? weights_tz.size() - 1 : weights_tz.size(), filter_fmt));
+ // For convolution with groups, convert from blocked to NCHW; otherwise
+ // there will be problems in the next operators working on this data
+ if (g > 1) {
+   memory::data_type in_type =
+       framework::ToMKLDNNDataType(filter_grad->type());
+   // for 3d conv with groups (six-dimensional data, reorder to goidhw)
+   // for 2d conv with groups (five-dimensional data, reorder to goihw)
+   mkldnn::memory::format_tag out_format =
+       weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw
+                              : mkldnn::memory::format_tag::goihw;
+   const std::string key =
+       platform::CreateKey(weights_tz, filter_fmt, out_format, in_type);
+   platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(),
+                                          in_type, dev_ctx, mkldnn_engine,
+                                          key);
+   auto reorder_dst_memory_p =
+       handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace());
+
+   auto reorder_p =
+       handler.AcquireReorder(reorder_dst_memory_p, diff_weights_memory_p);
+
+   reorder_p->execute(astream, *diff_weights_memory_p,
+                      *reorder_dst_memory_p);
+   astream.wait();
+
+   // Here the data is in goihw, which can be interpreted as OIHW
+   // (OIDHW for conv3d) because the filter_grad shape is set for
+   // OIHW (OIDHW for conv3d)
+   mkldnn::memory::format_tag target_format =
+       weights_tz.size() == 6 ? mkldnn::memory::format_tag::oidhw
+                              : mkldnn::memory::format_tag::oihw;
+   filter_grad->set_format(target_format);
+ } else {
+   filter_grad->set_format(filter_fmt);
+ }
}
if (input_grad) {
auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
......
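The new branch above keeps the grouped filter gradient in the blocked layout the backward-weights primitive chose and then reorders it to plain goihw (goidhw for conv3d), which downstream operators can read as OIHW/OIDHW. A standalone sketch of such a reorder with the plain oneDNN 1.x API (the source layout gOIhw16o16i merely stands in for whatever the primitive actually selected):

```cpp
#include <mkldnn.hpp>
#include <cstring>
#include <vector>
using namespace mkldnn;

int main() {
  engine eng(engine::kind::cpu, 0);
  stream s(eng);

  // Grouped 2-D conv filter gradient: {g, O/g, I/g, H, W}.
  const memory::dims wei_tz = {2, 8, 4, 3, 3};

  auto blocked_md = memory::desc(wei_tz, memory::data_type::f32,
                                 memory::format_tag::gOIhw16o16i);
  auto plain_md = memory::desc(wei_tz, memory::data_type::f32,
                               memory::format_tag::goihw);

  auto blocked_mem = memory(blocked_md, eng);  // library-allocated source
  std::memset(blocked_mem.get_data_handle(), 0, blocked_md.get_size());

  std::vector<float> dst_buf(2 * 8 * 4 * 3 * 3);
  auto plain_mem = memory(plain_md, eng, dst_buf.data());

  // reorder converts between layouts (and data types when they differ).
  reorder r(blocked_mem, plain_mem);
  r.execute(s, blocked_mem, plain_mem);
  s.wait();

  // dst_buf now holds goihw data, readable as OIHW once g folds into O.
  return 0;
}
```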
......@@ -289,6 +289,10 @@ inline mkldnn::memory::format_tag GetMKLDNNFormat(
strides[3] >= strides[4] && strides[4] >= strides[1]) {
return mkldnn::memory::format_tag::Acdeb16a;
}
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
return mkldnn::memory::format_tag::Abcde16a;
}
} else if (inner_blks[0] == 16 && inner_idxs[0] == 1) {
if (strides[0] >= strides[1] && strides[1] >= strides[2] &&
strides[2] >= strides[3] && strides[3] >= strides[4]) {
......
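GetMKLDNNFormat recovers a named format_tag by inspecting a descriptor's blocking structure: the inner block size and index identify the blocked dimension (16a vs 16b), and the stride ordering separates Abcde16a from Acdeb16a. A hedged sketch of that inspection reduced to the single case added above, reading the same C-level descriptor fields the function uses:

```cpp
#include <mkldnn.hpp>
using namespace mkldnn;

// Sketch: classify a 5-D blocked memory descriptor, reduced to the new
// case (one 16-element inner block on dim 0 with plain abcde stride
// order => Abcde16a); anything else falls through to undef here.
memory::format_tag classify_5d(const memory::desc& md) {
  const auto& blk = md.data.format_desc.blocking;
  const auto* strides = blk.strides;
  if (md.data.ndims == 5 && blk.inner_nblks == 1 &&
      blk.inner_blks[0] == 16 && blk.inner_idxs[0] == 0 &&
      strides[0] >= strides[1] && strides[1] >= strides[2] &&
      strides[2] >= strides[3] && strides[3] >= strides[4]) {
    return memory::format_tag::Abcde16a;
  }
  return memory::format_tag::undef;
}
```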
......@@ -346,6 +346,18 @@ class MKLDNNHandler {
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md, const std::string& suffix) {
const auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(md, engine_);
dev_ctx_.SetBlob(local_key, mem_p);
}
return mem_p;
}
// This incarnation of AcquireMemory can call a user function, e.g. a
// custom reorder or preprocessing routine, if needed
std::shared_ptr<mkldnn::memory> AcquireMemory(
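The overload added above follows the handler's cache-or-create pattern: look the memory up in the device-context blob cache under key + suffix and construct it only on a miss, with oneDNN allocating the buffer since no user handle is supplied. A minimal standalone sketch of the pattern, with an ordinary map standing in for dev_ctx_'s blob cache:

```cpp
#include <mkldnn.hpp>
#include <memory>
#include <string>
#include <unordered_map>
using namespace mkldnn;

// Stand-in for the device-context blob cache keyed by operator state.
using BlobCache = std::unordered_map<std::string, std::shared_ptr<memory>>;

std::shared_ptr<memory> acquire_memory(BlobCache& cache,
                                       const std::string& key,
                                       const std::string& suffix,
                                       const memory::desc& md,
                                       const engine& eng) {
  const auto local_key = key + suffix;
  auto it = cache.find(local_key);
  if (it != cache.end()) return it->second;  // reuse across invocations
  // No user handle: oneDNN allocates a buffer matching md (incl. blocking).
  auto mem_p = std::make_shared<memory>(md, eng);
  cache.emplace(local_key, mem_p);
  return mem_p;
}
```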
......@@ -1199,6 +1211,12 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
conv_bwd_weights_pd_->diff_weights_desc(), ptr, "@diff_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive(
void) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_weights_pd_->diff_weights_desc(), "@diff_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
......
......@@ -216,4 +216,6 @@ class TestWithInput1x1Filter1x1(TestConv2DBf16Op):
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()
......@@ -233,4 +233,6 @@ class TestMKLDNNDilations(TestConv2DMKLDNNOp):
if __name__ == '__main__':
from paddle import enable_static
enable_static()
unittest.main()
......@@ -110,4 +110,6 @@ class TestFusionGRUINT8MKLDNNOp3(TestFusionGRUBF16MKLDNNOp):
if __name__ == "__main__":
from paddle import enable_static
enable_static()
unittest.main()
......@@ -1320,6 +1320,13 @@ class OpTest(unittest.TestCase):
cache_list = None
if hasattr(self, "cache_name_list"):
cache_list = self.cache_name_list
# the numeric gradient should be computed by the plain CPU kernel,
# so temporarily disable oneDNN
use_onednn = False
if "use_mkldnn" in op_attrs and op_attrs["use_mkldnn"] == True:
op_attrs["use_mkldnn"] = False
use_onednn = True
self.op = create_op(
self.scope,
self.op_type,
......@@ -1328,6 +1335,9 @@ class OpTest(unittest.TestCase):
op_attrs,
cache_list=cache_list)
if use_onednn:
op_attrs["use_mkldnn"] = True
if no_grad_set is None:
no_grad_set = set()
else:
......