Unverified commit 8c6bbb48, authored by Jacek Czaja, committed by GitHub

[oneDNN] Accesses to oneDNN cache optimized for conv2d (#33048)

Parent 9b203ef3
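
This change moves the conv2d grad kernel onto the templated ConvMKLDNNHandlerT and switches primitive-descriptor acquisition to the *NonBlocking variants, so primitive descriptors, primitives, and memory objects are fetched from the device-context blob cache by key instead of being rebuilt on every call. Below is a hypothetical, much-simplified sketch of the get-or-create blob-cache pattern the handler relies on; the class and method names are illustrative only and do not reflect the actual MKLDNNDeviceContext implementation.

#include <memory>
#include <string>
#include <unordered_map>

// Simplified illustration: objects are cached under a string key built from
// input dims, operator attributes and (when needed) the thread id.
class BlobCacheSketch {
 public:
  template <typename T, typename Factory>
  std::shared_ptr<T> GetOrCreate(const std::string& key, Factory factory) {
    auto it = blobs_.find(key);
    if (it != blobs_.end()) {
      return std::static_pointer_cast<T>(it->second);  // cache hit: reuse
    }
    std::shared_ptr<T> obj = factory();  // cache miss: build once (e.g. a primitive_desc)
    blobs_[key] = obj;
    return obj;
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};
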
@@ -74,7 +74,9 @@ static mkldnn::memory::data_type GetDstType(bool is_int8, bool is_bfloat16,
template <typename T, typename K, typename T_out>
class ConvMKLDNNHandlerT
: public platform::MKLDNNHandlerT<T, mkldnn::convolution_forward> {
: public platform::MKLDNNHandlerT<T, mkldnn::convolution_forward,
mkldnn::convolution_backward_data,
mkldnn::convolution_backward_weights> {
public:
ConvMKLDNNHandlerT(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
@@ -82,11 +84,13 @@ class ConvMKLDNNHandlerT
platform::Place cpu_place, const Tensor* input,
const Tensor* filter, const Tensor* bias, Tensor* output,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::convolution_forward>(
: platform::MKLDNNHandlerT<T, mkldnn::convolution_forward,
mkldnn::convolution_backward_data,
mkldnn::convolution_backward_weights>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCached()) {
if (!this->isCachedNonBlocking()) {
PADDLE_ENFORCE_EQ(
input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
@@ -224,12 +228,12 @@ class ConvMKLDNNHandlerT
auto bias_md =
platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
this->AcquireForwardPrimitiveDescriptor(
this->AcquireForwardPrimitiveDescriptorNonBlocking(
conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
src_md, weights_md, bias_md, dst_md, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);
} else {
this->AcquireForwardPrimitiveDescriptor(
this->AcquireForwardPrimitiveDescriptorNonBlocking(
conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
src_md, weights_md, dst_md, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);
@@ -237,6 +241,142 @@ class ConvMKLDNNHandlerT
}
}
ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const Tensor* in,
const Tensor* filter, const Tensor* bias,
const Tensor* out_grad, Tensor* filter_grad,
Tensor* in_x_grad, const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::convolution_forward,
mkldnn::convolution_backward_data,
mkldnn::convolution_backward_weights>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dev_ctx, framework::vectorize(in->dims()),
unique_name)) {
if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ(
in->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The input tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, in->layout()));
PADDLE_ENFORCE_NE(in->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Got wrong format for Input tensor."));
PADDLE_ENFORCE_EQ(
filter->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, filter->layout()));
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Got wrong format for Filter tensor."));
PADDLE_ENFORCE_EQ(
out_grad->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The output_grad tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, out_grad->layout()));
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for output_grad tensor"));
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
platform::errors::InvalidArgument(
"is_test attribute should be set to False in training phase."));
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp),
end(dilations_temp));
std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
int groups = ctx.Attr<int>("groups");
auto input_dims = in->dims();
auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
auto filter_dims = filter->dims();
auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = framework::vectorize(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize);
auto src_tz = framework::vectorize(in->dims());
auto weights_tz = framework::vectorize(filter->dims());
int g = std::max(groups, 1);
platform::GetGroupConvWeightsTz(weights_tz, g);
auto dst_tz = paddle::framework::vectorize(out_grad->dims());
/* create memory descriptor for conv backward without specified format
* ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance
*/
const auto chosen_memory_format = MKLDNNMemoryFormat::any;
const auto weights_format = MKLDNNMemoryFormat::any;
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
auto diff_src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
std::transform(dilations.begin(), dilations.end(), dilations.begin(),
[](int64_t i) { return i - 1; });
const mkldnn::memory::dims dilations_dims = dilations;
const mkldnn::memory::dims stride_dims = strides;
// Recreating FWD PD. For training there are no post ops in convolution
mkldnn::primitive_attr conv_attr;
if (bias) {
auto bias_tz = framework::vectorize(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, mkldnn::memory::data_type::f32, MKLDNNMemoryFormat::x);
this->AcquireForwardPrimitiveDescriptorNonBlocking(
conv_attr, mkldnn::prop_kind::forward_training,
dnnl::algorithm::convolution_direct, src_md, weights_md, bias_md,
dst_md, stride_dims, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
} else {
this->AcquireForwardPrimitiveDescriptorNonBlocking(
conv_attr, mkldnn::prop_kind::forward_training,
dnnl::algorithm::convolution_direct, src_md, weights_md, dst_md,
stride_dims, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
}
this->AcquireBackwardPrimitiveDescriptorNonBlocking(
mkldnn::algorithm::convolution_direct, diff_src_md, weights_md,
diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
this->AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(
mkldnn::algorithm::convolution_direct, src_md, diff_weights_md,
diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
}
}
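
The constructor above builds its memory descriptors with MKLDNNMemoryFormat::any and recreates the forward primitive descriptor before the backward ones, because oneDNN backward primitive descriptors take the forward descriptor as a hint. A standalone oneDNN sketch of both ideas, with illustrative shapes (plain oneDNN API, not the Paddle handler):

#include "dnnl.hpp"
using namespace dnnl;

void conv_bwd_pd_example() {
  engine eng(engine::kind::cpu, 0);
  memory::dims strides = {1, 1}, pad = {1, 1};
  // format_tag::any defers the layout choice to the primitive.
  auto src_md = memory::desc({8, 16, 28, 28}, memory::data_type::f32,
                             memory::format_tag::any);
  auto wei_md = memory::desc({32, 16, 3, 3}, memory::data_type::f32,
                             memory::format_tag::any);
  auto dst_md = memory::desc({8, 32, 28, 28}, memory::data_type::f32,
                             memory::format_tag::any);

  // Forward PD for training (no post-ops); it serves as the hint below.
  auto fwd_d = convolution_forward::desc(prop_kind::forward_training,
                                         algorithm::convolution_direct, src_md,
                                         wei_md, dst_md, strides, pad, pad);
  auto fwd_pd = convolution_forward::primitive_desc(fwd_d, eng);

  // Backward-weights and backward-data PDs take the forward PD as a hint.
  auto bwd_w_d = convolution_backward_weights::desc(
      algorithm::convolution_direct, src_md, wei_md, dst_md, strides, pad, pad);
  auto bwd_w_pd =
      convolution_backward_weights::primitive_desc(bwd_w_d, eng, fwd_pd);

  auto bwd_d_d = convolution_backward_data::desc(
      algorithm::convolution_direct, src_md, wei_md, dst_md, strides, pad, pad);
  auto bwd_d_pd =
      convolution_backward_data::primitive_desc(bwd_d_d, eng, fwd_pd);

  // fwd_pd.src_desc(), bwd_w_pd.diff_weights_desc(), etc. now report the
  // layouts oneDNN selected; user tensors are reordered to/from them.
}
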
mkldnn::primitive_attr CreatePostOps(
std::string fuse_activation, float fuse_alpha, float fuse_beta,
bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
@@ -280,27 +420,75 @@ class ConvMKLDNNHandlerT
return conv_attr;
}
std::shared_ptr<mkldnn::memory>
AcquireWeightsMemoryWithReorderFromDataPrimitive(
const framework::Tensor* filter, const int groups, const bool is_conv3d) {
const K* filter_data = filter->data<K>();
auto weights_tz = framework::vectorize(filter->dims());
platform::GetGroupConvWeightsTz(weights_tz, groups);
auto user_src_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<K>(),
GetWeightsFormat(filter->format(), groups, is_conv3d));
return this->AcquireMemoryWithReorder(
user_src_md, this->bwd_pd_->weights_desc(),
to_void_cast<K>(filter_data), "@weights_mem_d_p", false);
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
const std::string user_key_suffix{"@src_mem_p_user"};
auto user_src_mem_p = this->AcquireMemory(user_key_suffix);
return this->AcquireMemoryWithReorderPrimitive(
input, "@src_mem_p_user", "@src_mem_p_target", "@src_mem_p",
this->fwd_pd_->src_desc());
}
if (!user_src_mem_p) {
auto user_src_md = platform::MKLDNNMemDesc(
framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
input->format());
std::shared_ptr<mkldnn::memory>
AcquireSrcMemoryWithReorderFromWeightsPrimitive(
const framework::Tensor* input) {
return this->AcquireMemoryWithReorderPrimitive(
input, "@src_mem_w_p_user", "@src_mem_w_p_target", "@src_mem_w_p",
this->bwd_w_pd_->src_desc());
}
std::shared_ptr<mkldnn::memory>
AcquireDiffDstMemoryWithReorderFromWeightsPrimitive(
const framework::Tensor* out_grad) {
return this->AcquireMemoryWithReorderPrimitive(
out_grad, "@diff_dst_mem_w_p_user", "@diff_dst_mem_w_p_target",
"@diff_dst_mem_w_p", this->bwd_w_pd_->diff_dst_desc());
}
std::shared_ptr<mkldnn::memory>
AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive(
const framework::Tensor* out_grad) {
return this->AcquireMemoryWithReorderPrimitive(
out_grad, "@diff_dst_mem_p_user", "@diff_dst_mem_p_target",
"@diff_dst_mem_p", this->bwd_pd_->diff_dst_desc());
}
std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorderPrimitive(
const framework::Tensor* in_mem, const char* key_mem_user,
const char* key_mem_target, const char* key_mem,
const mkldnn::memory::desc& mem_md) {
const T* in_mem_data = in_mem->data<T>();
const std::string user_key_suffix{key_mem_user};
auto user_mem_p = this->AcquireMemory(user_key_suffix);
if (!user_mem_p) {
auto user_mem_md = platform::MKLDNNMemDesc(
framework::vectorize(in_mem->dims()),
platform::MKLDNNGetDataType<T>(), in_mem->format());
return this->AcquireMemoryWithReorder(
user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
"@src_mem_p");
user_mem_md, mem_md, to_void_cast<T>(in_mem_data), key_mem);
} else {
const std::string target_key_suffix{"@src_mem_p_target"};
const auto target_src_mem_p = this->AcquireMemory(target_key_suffix);
user_src_mem_p->set_data_handle(to_void_cast<T>(input_data));
if (user_src_mem_p != target_src_mem_p) {
this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p");
const std::string target_key_suffix{key_mem_target};
const auto target_mem_p = this->AcquireMemory(target_key_suffix);
user_mem_p->set_data_handle(to_void_cast<T>(in_mem_data));
if (user_mem_p != target_mem_p) {
this->AcquireReorder(user_mem_p, target_mem_p, key_mem);
}
return target_src_mem_p;
return target_mem_p;
}
}
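
AcquireMemoryWithReorderPrimitive above folds the previous per-tensor logic into one helper: on first use it creates the user memory, the target memory in the primitive's preferred layout, and the reorder between them; on later calls it only swaps the user data handle and re-runs the cached reorder. A simplified, standalone sketch of that reuse pattern using raw oneDNN objects (names are illustrative, not the Paddle handler API):

#include "dnnl.hpp"
using namespace dnnl;

struct ReorderedInput {
  memory user_mem;    // wraps the framework tensor buffer (plain layout)
  memory target_mem;  // layout required by the convolution primitive
  reorder to_target;

  ReorderedInput(const memory::desc& user_md, const memory::desc& target_md,
                 void* user_ptr, const engine& eng)
      : user_mem(user_md, eng, user_ptr),
        target_mem(target_md, eng),
        to_target(user_mem, target_mem) {}

  // On subsequent iterations only the data handle changes; the reorder and
  // the target memory are reused as-is.
  const memory& Update(void* user_ptr, stream& s) {
    user_mem.set_data_handle(user_ptr);
    to_target.execute(s, user_mem, target_mem);
    return target_mem;
  }
};
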
@@ -866,7 +1054,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
}
};
template <typename T>
template <typename T, typename K>
class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
@@ -879,189 +1067,44 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const Tensor* input = ctx.Input<Tensor>("Input");
const Tensor* filter = ctx.Input<Tensor>("Filter");
const Tensor* bias =
ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
const Tensor* output_grad =
ctx.Input<Tensor>(framework::GradVarName("Output"));
Tensor* input_grad = ctx.Output<Tensor>(framework::GradVarName("Input"));
Tensor* filter_grad = ctx.Output<Tensor>(framework::GradVarName("Filter"));
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The input tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, input->layout()));
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Got wrong format for Input tensor."));
PADDLE_ENFORCE_EQ(
filter->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, filter->layout()));
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Got wrong format for Filter tensor."));
PADDLE_ENFORCE_EQ(
output_grad->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The output_grad tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, output_grad->layout()));
PADDLE_ENFORCE_NE(output_grad->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for output_grad tensor"));
PADDLE_ENFORCE_EQ(
ctx.Attr<bool>("is_test"), false,
platform::errors::InvalidArgument(
"is_test attribute should be set to False in training phase."));
if (!input_grad && !filter_grad) return;
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
int groups = ctx.Attr<int>("groups");
bool is_conv3d = strides.size() == 3U;
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
const T* output_grad_data = output_grad->data<T>();
T* input_grad_data = nullptr;
T* filter_grad_data = nullptr;
auto input_dims = input->dims();
auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
auto filter_dims = filter->dims();
auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = framework::vectorize(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize);
auto src_tz = paddle::framework::vectorize(input->dims());
auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1);
platform::GetGroupConvWeightsTz(weights_tz, g);
auto dst_tz = paddle::framework::vectorize(output_grad->dims());
auto src_format = input->format();
MKLDNNMemoryFormat weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
// Get a unique name from "argument" name of "input" and "Filter" variable
// as well as attributes of primitive to be created
// This name will be used as key when saving info into device context
std::string key = platform::CreateKey(
dev_ctx, src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@fwd_pd";
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
std::vector<primitive> pipeline;
// Create user memory descriptors
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
auto user_diff_dst_md = platform::MKLDNNMemDesc(
{dst_tz}, platform::MKLDNNGetDataType<T>(), output_grad->format());
/* create memory descriptor for conv backward without specified format
* ('any') which lets a primitive (conv backward in this case) choose
* the memory format preferred for best performance
*/
auto chosen_memory_format = MKLDNNMemoryFormat::any;
weights_format = MKLDNNMemoryFormat::any;
auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto diff_src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
auto diff_dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
// Retrieve conv_pd from device context
auto conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd));
PADDLE_ENFORCE_NE(conv_pd, nullptr,
platform::errors::InvalidArgument(
"Fail to find conv_pd in device context"));
auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
std::transform(dilations.begin(), dilations.end(), dilations.begin(),
[](int64_t i) { return i - 1; });
const mkldnn::memory::dims dilations_dims = dilations;
// create backward convolution weights primitive descriptor
auto conv_bwd_weights_desc = mkldnn::convolution_backward_weights::desc(
mkldnn::algorithm::convolution_direct, src_md, diff_weights_md,
diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
auto conv_bwd_weights_pd =
std::make_shared<mkldnn::convolution_backward_weights::primitive_desc>(
conv_bwd_weights_desc, mkldnn_engine, *conv_pd);
// create backward convolution data primitive descriptor
auto conv_bwd_data_desc = mkldnn::convolution_backward_data::desc(
mkldnn::algorithm::convolution_direct, diff_src_md, weights_md,
diff_dst_md, strides, dilations_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
auto conv_bwd_data_pd =
std::make_shared<mkldnn::convolution_backward_data::primitive_desc>(
conv_bwd_data_desc, mkldnn_engine, *conv_pd);
platform::ConvMKLDNNHandler handler(conv_pd, conv_bwd_data_pd,
conv_bwd_weights_pd, dev_ctx,
mkldnn_engine, key);
// TODO(jczaja): Are all tensors really needed?
ConvMKLDNNHandlerT<T, K, T> handler(
ctx, dev_ctx, ctx.GetPlace(), input, filter, bias, output_grad,
filter_grad, input_grad,
ctx.InputName("Input") + ctx.InputName("Filter"));
// create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data));
auto user_diff_dst_memory_p = handler.AcquireDiffDstMemory(
user_diff_dst_md, to_void_cast<T>(output_grad_data));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (filter_grad) {
auto src_memory_p = handler.AcquireSrcMemoryFromWeightsPrimitive(
user_src_memory_p, pipeline);
auto diff_dst_memory_4filter_p =
handler.AcquireDiffDstMemoryFromWeightsPrimitive(
user_diff_dst_memory_p, pipeline);
const size_t size = handler.GetDiffWeightsMemorySize();
filter_grad_data = filter_grad->mutable_data<T>(ctx.GetPlace(), size);
if (filter_grad) {
auto src_memory_p =
handler.AcquireSrcMemoryWithReorderFromWeightsPrimitive(input);
auto diff_dst_memory_p =
handler.AcquireDiffDstMemoryWithReorderFromWeightsPrimitive(
output_grad);
// For convolution with groups write filter grad into
// oneDNN buffer and then we reorder it into filter_grad tensor
int g = std::max(ctx.Attr<int>("groups"), 1);
auto diff_weights_memory_p =
g > 1 ? handler.AcquireDiffWeightsMemoryFromWeightsPrimitive()
: handler.AcquireDiffWeightsMemoryFromWeightsPrimitive(
reinterpret_cast<void*>(filter_grad_data));
g > 1 ? handler.AcquireDiffWeightsMemory()
: handler.AcquireDiffWeightsMemory(filter_grad);
auto conv_bwd_weights_p = handler.AcquireConvolutionBackwardWeights();
auto conv_bwd_weights_p = handler.AcquireBackwardWeightsPrimitive();
// TODO(grygielski) why no bias_diff?
conv_bwd_weights_p->execute(
astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory_4filter_p},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
{MKLDNN_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
astream.wait();
@@ -1073,10 +1116,12 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
// For convolution with groups convert from blocked to NCHW
// otherwise there will be problems in next operators working on this data
if (g > 1) {
memory::data_type in_type =
framework::ToMKLDNNDataType(filter_grad->type());
memory::data_type in_type = framework::ToMKLDNNDataType(filter->type());
// for 3d conv with groups (six dimensional data reorder to goidhw)
// for 2d conv with groups (five dimensional data reorder to goihw)
// auto weights_tz = paddle::framework::vectorize(filter->dims());
auto weights_tz = diff_weights_memory_p->get_desc().dims();
mkldnn::memory::format_tag out_format =
weights_tz.size() == 6 ? mkldnn::memory::format_tag::goidhw
: mkldnn::memory::format_tag::goihw;
@@ -1084,9 +1129,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
out_format, in_type);
key = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
platform::ReorderMKLDNNHandler handler(weights_tz, filter_grad->type(),
in_type, dev_ctx, mkldnn_engine,
key);
platform::ReorderMKLDNNHandler handler(
weights_tz, filter->type(), in_type, dev_ctx, mkldnn_engine, key);
auto reorder_dst_memory_p =
handler.AcquireDstMemory(filter_grad, out_format, ctx.GetPlace());
@@ -1113,24 +1157,21 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
}
}
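
For grouped convolution (g > 1) the diff weights come back in a blocked, group-aware layout chosen by oneDNN, so the code above writes them into a oneDNN-owned buffer and then reorders them to plain goihw (goidhw for conv3d) before the tensor reaches the next operator. A standalone sketch of such a reorder; the shapes and the blocked source format are assumptions for illustration:

#include "dnnl.hpp"
using namespace dnnl;

void grouped_diff_weights_reorder_example() {
  engine eng(engine::kind::cpu, 0);
  stream s(eng);
  // 2 groups, 16 output and 16 input channels per group, 3x3 kernel.
  memory::dims wtz = {2, 16, 16, 3, 3};
  auto blocked_md = memory::desc(wtz, memory::data_type::f32,
                                 memory::format_tag::gOIhw16i16o);  // assumed blocked layout
  auto plain_md = memory::desc(wtz, memory::data_type::f32,
                               memory::format_tag::goihw);
  memory blocked_mem(blocked_md, eng);  // e.g. the oneDNN-owned diff-weights buffer
  memory plain_mem(plain_md, eng);      // layout expected by downstream operators
  reorder(blocked_mem, plain_mem).execute(s, blocked_mem, plain_mem);
  s.wait();
}
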
if (input_grad) {
auto weights_memory_p = handler.AcquireWeightsMemoryFromDataPrimitive(
user_weights_memory_p, pipeline);
auto diff_dst_memory_4data_p =
handler.AcquireDiffDstMemoryFromDataPrimitive(user_diff_dst_memory_p,
pipeline);
const size_t size = handler.GetDiffSourceMemorySize();
input_grad_data = input_grad->mutable_data<T>(ctx.GetPlace(), size);
auto weights_memory_p =
handler.AcquireWeightsMemoryWithReorderFromDataPrimitive(
filter, ctx.Attr<int>("groups"),
ctx.Attr<std::vector<int>>("strides").size() == 3U);
auto diff_src_memory_p = handler.AcquireDiffSrcMemoryFromDataPrimitive(
reinterpret_cast<void*>(input_grad_data));
auto diff_dst_memory_p =
handler.AcquireDiffDstMemoryWithReorderMemoryFromDataPrimitive(
output_grad);
auto diff_src_memory_p = handler.AcquireDiffSrcMemory(input_grad);
auto conv_bwd_data_p = handler.AcquireConvolutionBackwardData();
auto conv_bwd_data_p = handler.AcquireBackwardPrimitive();
conv_bwd_data_p->execute(astream,
{{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory_4data_p},
{MKLDNN_ARG_DIFF_DST, *diff_dst_memory_p},
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait();
@@ -1167,7 +1208,7 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d, MKLDNN,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv2d_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float>);
ops::ConvMKLDNNGradOpKernel<float, float>);
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
::paddle::platform::CPUPlace, FP32,
@@ -1177,4 +1218,4 @@ REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d, MKLDNN,
REGISTER_OP_KERNEL_WITH_CUSTOM_TYPE(conv3d_grad, MKLDNN,
::paddle::platform::CPUPlace, FP32,
ops::kConvMKLDNNFP32,
ops::ConvMKLDNNGradOpKernel<float>);
ops::ConvMKLDNNGradOpKernel<float, float>);
@@ -35,7 +35,8 @@ using user_function = std::function<std::shared_ptr<float>(const float*)>;
using memory = mkldnn::memory;
template <typename T, typename TForward,
typename TBackward = mkldnn_dummy_primitive>
typename TBackward = mkldnn_dummy_primitive,
typename TBackward_params = mkldnn_dummy_primitive>
class MKLDNNHandlerT {
public:
MKLDNNHandlerT(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
@@ -72,6 +73,21 @@ class MKLDNNHandlerT {
return backward_p;
}
std::shared_ptr<TBackward_params> AcquireBackwardWeightsPrimitive() {
const std::string key_p = key_ + "@bwd_w_p";
auto backward_p =
std::static_pointer_cast<TBackward_params>(dev_ctx_.GetBlob(key_p));
if (backward_p == nullptr) {
PADDLE_ENFORCE_NOT_NULL(bwd_w_pd_, platform::errors::Unavailable(
"Error: BWD_PD should be set when "
"getting BWD prim witk key: %s .",
key_p));
backward_p = std::make_shared<TBackward_params>(*bwd_w_pd_);
dev_ctx_.SetBlob(key_p, backward_p);
}
return backward_p;
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
@@ -116,6 +132,29 @@ class MKLDNNHandlerT {
"@diff_src_mem_p");
}
// Buffer of the given Tensor is used for oneDNN computation
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(
framework::Tensor* diff_weights) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"Error: BWD_W_PD should be set when getting BWD grad of weights."));
T* ptr = diff_weights->mutable_data<T>(
place_, bwd_w_pd_->diff_weights_desc().get_size());
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(), ptr,
"@diff_wei_mem_p");
}
// Buffer is allocated by oneDNN to store computation results
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemory(void) {
PADDLE_ENFORCE_NOT_NULL(
bwd_w_pd_,
platform::errors::Unavailable(
"Error: BWD_W_PD should be set when getting BWD grad of weights."));
return this->AcquireMemoryFromPrimitive(bwd_w_pd_->diff_weights_desc(),
"@diff_wei_mem_p");
}
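
The two AcquireDiffWeightsMemory overloads above differ only in who owns the buffer: the first binds the oneDNN memory to the framework tensor so results land there directly, while the second lets oneDNN allocate its own buffer (used when a post-reorder for grouped weights is needed). A minimal standalone sketch of the two binding modes, with assumed shapes:

#include <vector>
#include "dnnl.hpp"
using namespace dnnl;

void diff_weights_buffer_modes_example() {
  engine eng(engine::kind::cpu, 0);
  auto md = memory::desc({32, 16, 3, 3}, memory::data_type::f32,
                         memory::format_tag::oihw);
  // Mode 1: reuse an existing (framework-owned) buffer.
  std::vector<float> tensor_buffer(md.get_size() / sizeof(float));
  memory bound_to_tensor(md, eng, tensor_buffer.data());
  // Mode 2: let oneDNN allocate and own the buffer.
  memory library_allocated(md, eng);
}
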
protected:
bool isCached() {
const std::string key_pd = key_common_ + "@fwd_pd";
@@ -243,6 +282,27 @@ class MKLDNNHandlerT {
}
}
template <typename... Args>
void AcquireBackwardWeightsPrimitiveDescriptorNonBlocking(Args&&... args) {
// fwd_pd_ is set during grad by calling
// AcquireForwardPrimitiveDescriptorNonBlocking
PADDLE_ENFORCE_NOT_NULL(
fwd_pd_,
platform::errors::Unavailable("Get MKLDNN Forward primitive %s failed.",
key_ + "@fwd_pd"));
const std::string key_pd = key_ + "@bwd_w_pd";
bwd_w_pd_ =
std::static_pointer_cast<typename TBackward_params::primitive_desc>(
dev_ctx_.GetBlob(key_pd));
if (bwd_w_pd_ == nullptr) {
auto bwd_desc =
typename TBackward_params::desc(std::forward<Args>(args)...);
bwd_w_pd_ = std::make_shared<typename TBackward_params::primitive_desc>(
bwd_desc, engine_, *fwd_pd_);
dev_ctx_.SetBlob(key_pd, bwd_w_pd_);
}
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
const std::string& suffix) {
return std::static_pointer_cast<mkldnn::memory>(
@@ -370,6 +430,7 @@ class MKLDNNHandlerT {
std::string key_;
std::shared_ptr<typename TForward::primitive_desc> fwd_pd_;
std::shared_ptr<typename TBackward::primitive_desc> bwd_pd_;
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
};
// TODO(grygielski) this class will be deleted later.