Unverified commit e9288340 authored by Adam Osewski, committed by GitHub

[OneDNN] Conv op refactor. (#36252)

* Remove unused header.

* Use ConvMKLDNNHandlerT for conv2d INT8.

* Use absolute module path to import.
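The refactor routes conv2d INT8 through ConvMKLDNNHandlerT, which now derives the per-output-channel scales itself (see get_int8_scales in the diff below). A minimal standalone sketch of that scale math, with illustrative names that are not part of the Paddle API:

#include <iostream>
#include <vector>

// output_shift_scale[i] = scale_out / (scale_in * scale_weights[i]);
// channels whose weight scale is 0 (all-zero weights in some models)
// fall back to scale_out alone.
std::vector<float> ComputeOutputShiftScale(
    float scale_in, float scale_out, const std::vector<float>& scale_weights) {
  std::vector<float> output_shift_scale(scale_weights.size());
  for (size_t i = 0; i < scale_weights.size(); ++i) {
    output_shift_scale[i] =
        scale_weights[i] == 0.0f
            ? scale_out
            : static_cast<float>(static_cast<double>(scale_out) /
                                 (static_cast<double>(scale_in) *
                                  static_cast<double>(scale_weights[i])));
  }
  return output_shift_scale;
}

int main() {
  const std::vector<float> scale_weights = {127.0f, 63.5f, 0.0f};
  for (float s : ComputeOutputShiftScale(127.0f, 255.0f, scale_weights)) {
    std::cout << s << '\n';
  }
  return 0;
}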
Parent dc4d5719
...@@ -23,7 +23,6 @@ limitations under the License. */ ...@@ -23,7 +23,6 @@ limitations under the License. */
#include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
......
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. /* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License"); Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License. you may not use this file except in compliance with the License.
...@@ -12,27 +12,16 @@ ...@@ -12,27 +12,16 @@
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/framework/data_layout_transform.h" #include <tuple>
#include "paddle/fluid/operators/conv_op.h" #include "paddle/fluid/operators/conv_op.h"
#include "paddle/fluid/platform/cpu_info.h" #include "paddle/fluid/platform/cpu_info.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h" #include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace platform {
class MKLDNNDeviceContext;
} // namespace platform
} // namespace paddle
namespace paddle { namespace paddle {
namespace operators { namespace operators {
namespace {
using framework::DataLayout;
using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::GetMKLDNNFormat;
using platform::to_void_cast;
inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format, inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
const int groups, const int groups,
...@@ -78,7 +67,7 @@ class ConvMKLDNNHandlerT ...@@ -78,7 +67,7 @@ class ConvMKLDNNHandlerT
mkldnn::convolution_backward_data, mkldnn::convolution_backward_data,
mkldnn::convolution_backward_weights> { mkldnn::convolution_backward_weights> {
public: public:
ConvMKLDNNHandlerT(const paddle::framework::ExecutionContext& ctx, ConvMKLDNNHandlerT(const framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx, const platform::MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine, const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input, platform::Place cpu_place, const Tensor* input,
...@@ -92,19 +81,19 @@ class ConvMKLDNNHandlerT ...@@ -92,19 +81,19 @@ class ConvMKLDNNHandlerT
unique_name)) { unique_name)) {
if (!this->isCached()) { if (!this->isCached()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
input->layout(), DataLayout::kMKLDNN, input->layout(), framework::DataLayout::kMKLDNN,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input tensor's layout should be %d, but got %d.", "The input tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, input->layout())); framework::DataLayout::kMKLDNN, input->layout()));
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef, PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong format set for Input tensor")); "Wrong format set for Input tensor"));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
filter->layout(), DataLayout::kMKLDNN, filter->layout(), framework::DataLayout::kMKLDNN,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The Filter tensor's layout should be %d, but got %d.", "The Filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, filter->layout())); framework::DataLayout::kMKLDNN, filter->layout()));
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef, PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong format set for Filter tensor")); "Wrong format set for Filter tensor"));
...@@ -137,10 +126,10 @@ class ConvMKLDNNHandlerT ...@@ -137,10 +126,10 @@ class ConvMKLDNNHandlerT
if (bias) { if (bias) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
bias->layout(), DataLayout::kMKLDNN, bias->layout(), framework::DataLayout::kMKLDNN,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The Bias tensor's layout should be %d, but got %d.", "The Bias tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, bias->layout())); framework::DataLayout::kMKLDNN, bias->layout()));
PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef, PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Got wrong format for Bias tensor.")); "Got wrong format for Bias tensor."));
...@@ -188,12 +177,12 @@ class ConvMKLDNNHandlerT ...@@ -188,12 +177,12 @@ class ConvMKLDNNHandlerT
std::transform(dilations.begin(), dilations.end(), dilations.begin(), std::transform(dilations.begin(), dilations.end(), dilations.begin(),
[](int64_t i) { return i - 1; }); [](int64_t i) { return i - 1; });
const auto src_tz = paddle::framework::vectorize(input->dims()); const auto src_tz = framework::vectorize(input->dims());
auto weights_tz = paddle::framework::vectorize(filter->dims()); auto weights_tz = framework::vectorize(filter->dims());
platform::GetGroupConvWeightsTz(weights_tz, groups); platform::GetGroupConvWeightsTz(weights_tz, groups);
const auto dst_tz = paddle::framework::vectorize(output->dims()); const auto dst_tz = framework::vectorize(output->dims());
const mkldnn::memory::dims stride_dims = strides; const mkldnn::memory::dims stride_dims = strides;
const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings); const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
...@@ -204,29 +193,48 @@ class ConvMKLDNNHandlerT ...@@ -204,29 +193,48 @@ class ConvMKLDNNHandlerT
* the memory format preferred for best performance * the memory format preferred for best performance
*/ */
auto chosen_memory_format = MKLDNNMemoryFormat::any; auto chosen_memory_format = MKLDNNMemoryFormat::any;
auto data_type = mkldnn::memory::data_type::f32; auto data_type = mkldnn::memory::data_type::f32;
if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" || if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
std::is_same<T_out, platform::bfloat16>::value) std::is_same<T_out, platform::bfloat16>::value)
data_type = mkldnn::memory::data_type::bf16; data_type = mkldnn::memory::data_type::bf16;
const auto src_md = mkldnn::memory::desc src_md, weights_md;
platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format); if (platform::is_int8<T>()) {
const auto weights_md = platform::MKLDNNMemDesc(weights_tz, data_type, src_md = platform::MKLDNNMemDesc(
MKLDNNMemoryFormat::any); src_tz, framework::ToMKLDNNDataType(input->type()),
chosen_memory_format);
weights_md = platform::MKLDNNMemDesc(
weights_tz, mkldnn::memory::data_type::s8, chosen_memory_format);
} else {
src_md =
platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
weights_md = platform::MKLDNNMemDesc(weights_tz, data_type,
MKLDNNMemoryFormat::any);
}
const auto dst_md = platform::MKLDNNMemDesc( const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format); dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training; : mkldnn::prop_kind::forward_training;
float sum_scale;
std::vector<float> output_shift_scale;
std::tie(sum_scale, output_shift_scale) = get_int8_scales(ctx);
const mkldnn::primitive_attr conv_attr = CreatePostOps( const mkldnn::primitive_attr conv_attr = CreatePostOps(
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn); fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
output_shift_scale, sum_scale); // for INT8 only!
if (bias) { if (bias) {
auto bias_tz = framework::vectorize(bias->dims()); auto bias_tz = framework::vectorize(bias->dims());
auto bias_md = mkldnn::memory::desc bias_md;
platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x); if (platform::is_int8<T>()) {
bias_md = platform::MKLDNNMemDesc(
bias_tz, mkldnn::memory::data_type::s32, MKLDNNMemoryFormat::x);
} else {
bias_md = platform::MKLDNNMemDesc(bias_tz, data_type,
MKLDNNMemoryFormat::x);
}
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct, conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
...@@ -255,28 +263,28 @@ class ConvMKLDNNHandlerT ...@@ -255,28 +263,28 @@ class ConvMKLDNNHandlerT
unique_name)) { unique_name)) {
if (!this->isBwdCached()) { if (!this->isBwdCached()) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
in->layout(), DataLayout::kMKLDNN, in->layout(), framework::DataLayout::kMKLDNN,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The input tensor's layout should be %d, but got %d.", "The input tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, in->layout())); framework::DataLayout::kMKLDNN, in->layout()));
PADDLE_ENFORCE_NE(in->format(), MKLDNNMemoryFormat::undef, PADDLE_ENFORCE_NE(in->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Got wrong format for Input tensor.")); "Got wrong format for Input tensor."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
filter->layout(), DataLayout::kMKLDNN, filter->layout(), framework::DataLayout::kMKLDNN,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The filter tensor's layout should be %d, but got %d.", "The filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, filter->layout())); framework::DataLayout::kMKLDNN, filter->layout()));
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef, PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Got wrong format for Filter tensor.")); "Got wrong format for Filter tensor."));
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
out_grad->layout(), DataLayout::kMKLDNN, out_grad->layout(), framework::DataLayout::kMKLDNN,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The output_grad tensor's layout should be %d, but got %d.", "The output_grad tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, out_grad->layout())); framework::DataLayout::kMKLDNN, out_grad->layout()));
PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef, PADDLE_ENFORCE_NE(out_grad->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"Wrong format set for output_grad tensor")); "Wrong format set for output_grad tensor"));
...@@ -296,28 +304,25 @@ class ConvMKLDNNHandlerT ...@@ -296,28 +304,25 @@ class ConvMKLDNNHandlerT
std::vector<int64_t> dilations(begin(dilations_temp), std::vector<int64_t> dilations(begin(dilations_temp),
end(dilations_temp)); end(dilations_temp));
std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
int groups = ctx.Attr<int>("groups");
auto input_dims = in->dims(); auto input_dims = in->dims();
auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size()); auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
auto filter_dims = filter->dims(); auto filter_dims = filter->dims();
auto filter_data_dims = auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size()); framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = framework::vectorize(filter_data_dims); auto ksize = framework::vectorize(filter_data_dims);
std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm, UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize); data_dims, strides, ksize);
auto src_tz = framework::vectorize(in->dims()); auto src_tz = framework::vectorize(in->dims());
auto weights_tz = framework::vectorize(filter->dims()); auto weights_tz = framework::vectorize(filter->dims());
int groups = ctx.Attr<int>("groups");
int g = std::max(groups, 1); int g = std::max(groups, 1);
platform::GetGroupConvWeightsTz(weights_tz, g); platform::GetGroupConvWeightsTz(weights_tz, g);
auto dst_tz = paddle::framework::vectorize(out_grad->dims()); auto dst_tz = framework::vectorize(out_grad->dims());
/* create memory descriptor for conv backward without specified format /* create memory descriptor for conv backward without specified format
* ('any') which lets a primitive (conv backward in this case) choose * ('any') which lets a primitive (conv backward in this case) choose
...@@ -349,8 +354,14 @@ class ConvMKLDNNHandlerT ...@@ -349,8 +354,14 @@ class ConvMKLDNNHandlerT
mkldnn::primitive_attr conv_attr; mkldnn::primitive_attr conv_attr;
if (bias) { if (bias) {
auto bias_tz = framework::vectorize(bias->dims()); auto bias_tz = framework::vectorize(bias->dims());
auto bias_md = platform::MKLDNNMemDesc( mkldnn::memory::desc bias_md;
bias_tz, mkldnn::memory::data_type::f32, MKLDNNMemoryFormat::x); if (platform::is_int8<T>()) {
bias_md = platform::MKLDNNMemDesc(
bias_tz, mkldnn::memory::data_type::s32, MKLDNNMemoryFormat::x);
} else {
bias_md = platform::MKLDNNMemDesc(
bias_tz, mkldnn::memory::data_type::f32, MKLDNNMemoryFormat::x);
}
this->AcquireForwardPrimitiveDescriptor( this->AcquireForwardPrimitiveDescriptor(
conv_attr, mkldnn::prop_kind::forward_training, conv_attr, mkldnn::prop_kind::forward_training,
...@@ -377,6 +388,71 @@ class ConvMKLDNNHandlerT ...@@ -377,6 +388,71 @@ class ConvMKLDNNHandlerT
} }
} }
std::tuple<float, std::vector<float>> get_int8_scales(
const framework::ExecutionContext& ctx) const {
const auto* filter = ctx.Input<Tensor>("Filter");
const auto& weights_tz = framework::vectorize(filter->dims());
const bool& force_fp32_output = ctx.Attr<bool>("force_fp32_output");
const bool& fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
const int groups = std::max(ctx.Attr<int>("groups"), 1);
const auto& scale_in_data = ctx.Attr<float>("Scale_in");
const auto& scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
bool is_multi_channel = scale_weights_data.size() > 1;
auto scale_out_data =
force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
int count =
is_multi_channel
? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
: 1;
std::vector<float> output_shift_scale(count);
#pragma omp parallel for if (count > 50)
for (int i = 0; i < count; i++) {
if (scale_weights_data[i] == 0.0)
// weights data will contain 0 in some models, then weights
// scale couldn't be calculated
output_shift_scale[i] = scale_out_data;
else
output_shift_scale[i] =
static_cast<float>(static_cast<double>(scale_out_data) /
(static_cast<double>(scale_in_data) *
static_cast<double>(scale_weights_data[i])));
}
return std::make_tuple(sum_scale, output_shift_scale);
}
std::tuple<float, std::vector<float>> get_int8_bias_scales(
const framework::ExecutionContext& ctx) const {
const auto* filter = ctx.Input<Tensor>("Filter");
const auto& weights_tz = framework::vectorize(filter->dims());
const int groups = std::max(ctx.Attr<int>("groups"), 1);
const auto& scale_weights_data =
ctx.Attr<std::vector<float>>("Scale_weights");
const auto& scale_in_data = ctx.Attr<float>("Scale_in");
bool is_multi_channel = scale_weights_data.size() > 1;
int mask_reorder = is_multi_channel ? 1 << 0 : 1;
int count =
is_multi_channel
? (groups > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
: 1;
std::vector<float> scale_bias_data(count);
#pragma omp parallel for if (count > 50)
for (int i = 0; i < count; i++) {
scale_bias_data[i] = scale_in_data * scale_weights_data[i];
}
return std::make_tuple(mask_reorder, scale_bias_data);
}
mkldnn::primitive_attr CreatePostOps( mkldnn::primitive_attr CreatePostOps(
std::string fuse_activation, float fuse_alpha, float fuse_beta, std::string fuse_activation, float fuse_alpha, float fuse_beta,
bool fuse_residual_conn, const std::vector<float> output_shift_scale = {}, bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
...@@ -433,7 +509,7 @@ class ConvMKLDNNHandlerT ...@@ -433,7 +509,7 @@ class ConvMKLDNNHandlerT
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_src_md, this->bwd_pd_->weights_desc(), user_src_md, this->bwd_pd_->weights_desc(),
to_void_cast<K>(filter_data), "@weights_mem_d_p", false); platform::to_void_cast<K>(filter_data), "@weights_mem_d_p", false);
} }
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder( std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
...@@ -480,11 +556,11 @@ class ConvMKLDNNHandlerT ...@@ -480,11 +556,11 @@ class ConvMKLDNNHandlerT
framework::vectorize(in_mem->dims()), framework::vectorize(in_mem->dims()),
platform::MKLDNNGetDataType<T>(), in_mem->format()); platform::MKLDNNGetDataType<T>(), in_mem->format());
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_mem_md, mem_md, to_void_cast<T>(in_mem_data), key_mem); user_mem_md, mem_md, platform::to_void_cast<T>(in_mem_data), key_mem);
} else { } else {
const std::string target_key_suffix{key_mem_target}; const std::string target_key_suffix{key_mem_target};
const auto target_mem_p = this->AcquireMemory(target_key_suffix); const auto target_mem_p = this->AcquireMemory(target_key_suffix);
user_mem_p->set_data_handle(to_void_cast<T>(in_mem_data)); user_mem_p->set_data_handle(platform::to_void_cast<T>(in_mem_data));
if (user_mem_p != target_mem_p) { if (user_mem_p != target_mem_p) {
this->AcquireReorder(user_mem_p, target_mem_p, key_mem); this->AcquireReorder(user_mem_p, target_mem_p, key_mem);
} }
...@@ -494,7 +570,8 @@ class ConvMKLDNNHandlerT ...@@ -494,7 +570,8 @@ class ConvMKLDNNHandlerT
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder( std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
const framework::Tensor* filter, const int groups, const bool is_conv3d, const framework::Tensor* filter, const int groups, const bool is_conv3d,
const bool is_test) { const bool is_test, const std::vector<float>& scale_data = {1.0f},
int mask = 0) {
// This is workaround to make execution faster, delete // This is workaround to make execution faster, delete
// if statement after including md inside Tensor // if statement after including md inside Tensor
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target"); auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
...@@ -511,12 +588,14 @@ class ConvMKLDNNHandlerT ...@@ -511,12 +588,14 @@ class ConvMKLDNNHandlerT
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_src_md, this->fwd_pd_->weights_desc(), user_src_md, this->fwd_pd_->weights_desc(),
to_void_cast<K>(filter_data), "@weights_mem_p", is_test); platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test, {},
scale_data, mask);
} }
} }
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder( std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
const framework::Tensor* bias, const bool is_test) { const framework::Tensor* bias, const bool is_test,
const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target"); auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
if (is_test && bias_mem_p) { if (is_test && bias_mem_p) {
return bias_mem_p; return bias_mem_p;
...@@ -527,8 +606,9 @@ class ConvMKLDNNHandlerT ...@@ -527,8 +606,9 @@ class ConvMKLDNNHandlerT
MKLDNNMemoryFormat::x); MKLDNNMemoryFormat::x);
return this->AcquireMemoryWithReorder( return this->AcquireMemoryWithReorder(
user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<K>(bias_data), user_bias_md, this->fwd_pd_->bias_desc(),
"@bias_mem_p", is_test); platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test, {},
scale_data, mask);
} }
} }
...@@ -536,8 +616,8 @@ class ConvMKLDNNHandlerT ...@@ -536,8 +616,8 @@ class ConvMKLDNNHandlerT
const framework::Tensor* residual_param) { const framework::Tensor* residual_param) {
void* residual_data = void* residual_data =
residual_param->type() == framework::DataTypeTrait<T_out>::DataType() residual_param->type() == framework::DataTypeTrait<T_out>::DataType()
? to_void_cast<T_out>(residual_param->data<T_out>()) ? platform::to_void_cast<T_out>(residual_param->data<T_out>())
: to_void_cast<T>(residual_param->data<T>()); : platform::to_void_cast<T>(residual_param->data<T>());
auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p"); auto residual_mem_p = this->AcquireMemory("@user_residual_data_mem_p");
if (residual_mem_p) { if (residual_mem_p) {
residual_mem_p->set_data_handle(residual_data); residual_mem_p->set_data_handle(residual_data);
...@@ -572,12 +652,14 @@ class ConvMKLDNNHandlerT ...@@ -572,12 +652,14 @@ class ConvMKLDNNHandlerT
} }
}; };
} // anonymous namespace
template <typename T, typename K> template <typename T, typename K>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { class ConvMKLDNNOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true, PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Operator DNNL Conv must use CPUPlace")); "Operator DNNL Conv must use CPUPlace"));
bool is_INT8 = bool is_INT8 =
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value; std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
...@@ -607,9 +689,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -607,9 +689,9 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
} }
template <typename T_out> template <typename T_out>
void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const { void ComputeFP32(const framework::ExecutionContext& ctx) const {
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
const bool is_test = ctx.Attr<bool>("is_test"); const bool is_test = ctx.Attr<bool>("is_test");
...@@ -656,407 +738,112 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -656,407 +738,112 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
conv_p->execute(astream, args); conv_p->execute(astream, args);
astream.wait(); astream.wait();
output->set_layout(DataLayout::kMKLDNN); output->set_layout(framework::DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
} }
template <typename T_out> template <typename T_out>
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const { void ComputeINT8(const framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine(); const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* input = ctx.Input<Tensor>("Input"); const std::string& fuse_activation =
auto* output = ctx.Output<Tensor>("Output"); ctx.Attr<std::string>("fuse_activation");
const bool& fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN, const bool& force_fp32_output = ctx.Attr<bool>("force_fp32_output");
platform::errors::InvalidArgument( const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
"The input tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, input->layout()));
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Got wrong format for Input tensor."));
PADDLE_ENFORCE_GE(input->dims().size(), 4,
platform::errors::InvalidArgument(
"Input must be with 4 or 5 dimensions, i.e. NCHW or "
"NCDHW, but got dimension = %d .",
input->dims().size()));
PADDLE_ENFORCE_LE(input->dims().size(), 5,
platform::errors::InvalidArgument(
"Input must be with 4 or 5 dimensions, i.e. NCHW or "
"NCDHW, but got dimension = %d .",
input->dims().size()));
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool unsigned_output = bool unsigned_output =
(fuse_activation == "relu" || fuse_activation == "relu6"); (fuse_activation == "relu" || fuse_activation == "relu6");
const T* input_data = input->data<T>();
auto src_tz = paddle::framework::vectorize(input->dims());
mkldnn::memory::data_type src_dt =
paddle::framework::ToMKLDNNDataType(input->type());
std::string key =
platform::CreateKey(dev_ctx, src_tz, src_dt,
ctx.InputName("Input") + ctx.InputName("Filter"));
bool need_s8_to_u8 = false; bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p;
std::shared_ptr<mkldnn::memory> user_src_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p;
std::vector<primitive> pipeline;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
std::shared_ptr<platform::ConvMKLDNNHandler> handler;
// This is workaround for hacky implementation
// of conv int8 mkl-dnn. Once conv fp32 and conv int8
// are merged/unified, this will disappear
auto key_tid = platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, key);
const std::string key_conv_pd = key_tid + "@conv_pd";
auto prim_key = key_tid + "@conv_p";
auto dst_key = key_tid + "@dst_mem_p";
auto src_key = key_tid + "@src_mem_p";
auto weights_key = key_tid + "@weights_mem_p";
auto bias_key = key_tid + "@bias_mem_p";
auto user_src_key = key_tid + "@user_src_mem_p";
auto user_residual_key = key_tid + "@user_residual_data_mem_p";
auto src_reorder_key = key_tid + "@src_mem_preorder_p";
auto residual_reorder_key = key_tid + "@residual_data_mem_preorder_p";
conv_pd =
std::static_pointer_cast<mkldnn::convolution_forward::primitive_desc>(
dev_ctx.GetBlob(key_conv_pd));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); PADDLE_ENFORCE_NE(
is_conv3d, true,
platform::errors::Unimplemented(
"OneDNN int8 convolution does not support 3D inputs currently"));
PADDLE_ENFORCE_EQ(
fuse_residual_conn && force_fp32_output, false,
platform::errors::Unimplemented(
"residual fusion does not support force output with fp32"));
if (conv_pd == nullptr || !is_test) { auto* input = ctx.Input<Tensor>("Input");
float fuse_alpha = ctx.Attr<float>("fuse_alpha"); auto* filter = ctx.Input<Tensor>("Filter");
float fuse_beta = ctx.Attr<float>("fuse_beta"); auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output"); auto* output = ctx.Output<Tensor>("Output");
auto* filter = ctx.Input<Tensor>("Filter"); ConvMKLDNNHandlerT<T, K, T_out> handler(
ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
output, ctx.InputName("Input") + ctx.InputName("Filter"));
PADDLE_ENFORCE_EQ( auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
filter->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, filter->layout()));
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Got wrong format for Filter tensor."));
PADDLE_ENFORCE_GE(filter->dims().size(), 4, const auto& scale_weights_data =
platform::errors::InvalidArgument( ctx.Attr<std::vector<float>>("Scale_weights");
"Filter must be with 4 or 5 dimensions, i.e. OIHW " const bool is_multi_channel = scale_weights_data.size() > 1;
"or OIDHW, but got dimensions = %d .", const int& groups = ctx.Attr<int>("groups");
filter->dims().size())); const bool& is_test = ctx.Attr<bool>("is_test");
PADDLE_ENFORCE_LE(filter->dims().size(), 5, int mask_reorder =
platform::errors::InvalidArgument( is_multi_channel ? ((groups != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
"Filter must be with 4 or 5 dimensions, i.e. OIHW " auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
"or OIDHW, but got dimensions = %d .", filter, groups, false, is_test, scale_weights_data, mask_reorder);
filter->dims().size()));
std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) {
auto* residual_param = ctx.Input<Tensor>("ResidualData");
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
!fuse_residual_conn || !force_fp32_output, true, output->dims(), residual_param->dims(),
platform::errors::Unimplemented( platform::errors::InvalidArgument(
"residual fusion does not support force output with fp32")); "Output and elementwise parameter need to have the "
"same dimension sizes, but got output's dimension = %d"
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr; " and residual param's dimension =%d .",
output->dims().size(), residual_param->dims().size()));
if (bias) {
PADDLE_ENFORCE_EQ(
bias->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The bias tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, bias->layout()));
PADDLE_ENFORCE_NE(bias->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Got wrong format for Bias tensor."));
PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
platform::errors::InvalidArgument(
"Bias must only have 1 dimension, i.e. X, but "
"got dimension = %d .",
bias->dims().size()));
}
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp),
end(dilations_temp));
std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
bool is_conv3d = strides.size() == 3U;
PADDLE_ENFORCE_NE(is_conv3d, true,
platform::errors::Unimplemented(
"int8 does not support conv3d currently"));
auto input_dims = input->dims();
auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
auto filter_dims = filter->dims();
auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto ksize = framework::vectorize(filter_data_dims);
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize);
int groups = ctx.Attr<int>("groups");
auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1);
platform::GetGroupConvWeightsTz(weights_tz, g);
auto dst_tz = paddle::framework::vectorize(output->dims());
std::transform(dilations.begin(), dilations.end(), dilations.begin(),
[](int64_t i) { return i - 1; });
const K* filter_data = filter->data<K>();
auto scale_in_data = ctx.Attr<float>("Scale_in");
auto scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
auto scale_out_data =
force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
float sum_scale =
fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
bool is_multi_channel = scale_weights_data.size() > 1;
int count = is_multi_channel ? (g > 1 ? (weights_tz)[1] * (weights_tz)[0]
: (weights_tz)[0])
: 1;
std::vector<float> output_shift_scale(count);
#pragma omp parallel for if (count > 1)
for (int i = 0; i < count; i++) {
if (scale_weights_data[i] == 0.0)
output_shift_scale[i] =
scale_out_data; // weights data will contain 0
// in some models, then weights
// scale couldn't be calculated
else
output_shift_scale[i] =
static_cast<float>(static_cast<double>(scale_out_data) /
(static_cast<double>(scale_in_data) *
static_cast<double>(scale_weights_data[i])));
}
auto user_src_md =
platform::MKLDNNMemDesc({src_tz}, src_dt, input->format());
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<K>(),
((g) == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw);
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
* the memory format preferred for best performance
*/
auto chosen_memory_format = MKLDNNMemoryFormat::any;
std::vector<int64_t> bias_tz;
auto src_md =
platform::MKLDNNMemDesc(src_tz, src_dt, chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, memory::data_type::s8, chosen_memory_format);
auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
handler.reset(
new platform::ConvMKLDNNHandler(dev_ctx, mkldnn_engine, key));
// create a conv primitive descriptor and save it for usage in backward
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
if (bias) {
bias_tz = paddle::framework::vectorize(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
MKLDNNMemoryFormat::x);
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, dilations, paddings,
mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, propagation, output_shift_scale, sum_scale);
} else {
conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, paddle::none, dst_md, strides, dilations,
paddings, mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, propagation, output_shift_scale, sum_scale);
}
// create mkldnn memory from input tensors (data/weights)
user_src_memory_p =
handler->AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler->AcquireWeightsMemory(
user_weights_md, to_void_cast<K>(filter_data));
// create reorder primitive if the input format is not the preferred one
src_memory_p =
handler->AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
std::shared_ptr<mkldnn::memory> weights_memory_p;
int mask_reorder =
is_multi_channel ? ((g != 1) ? (1 << 1) + (1 << 0) : 1 << 0) : 0;
weights_memory_p = handler->AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test, true, scale_weights_data,
mask_reorder);
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
PADDLE_ENFORCE_EQ(
output->dims(), residual_param->dims(),
platform::errors::InvalidArgument(
"Output and elementwise parameter need to have the "
"same dimension sizes, but got output's dimension = %d"
" and residual param's dimension =%d .",
output->dims().size(), residual_param->dims().size()));
auto residual_dt =
paddle::framework::ToMKLDNNDataType(residual_param->type());
if (residual_param->format() != handler->GetDstFormat()) {
auto residual_data_tz =
paddle::framework::vectorize(residual_param->dims());
auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_dt, residual_param->format());
dst_memory_p = platform::SetDstMemory<T_out>(
ctx, output, residual_param, user_residual_md, handler,
&pipeline);
} else {
output->ShareDataWith(*residual_param);
dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
}
need_s8_to_u8 =
(platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
unsigned_output;
} else {
dst_memory_p = platform::SetDstMemory<T_out>(ctx, output, handler);
}
// create convolution op primitive
conv_p = handler->AcquireConvolution();
if (bias) {
const K* bias_data = bias->data<K>();
auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<K>(), MKLDNNMemoryFormat::x);
auto user_bias_memory_p = handler->AcquireBiasMemory(
user_bias_md, to_void_cast<K>(bias_data));
std::shared_ptr<mkldnn::memory> bias_memory_p;
int mask_reorder = is_multi_channel ? 1 << 0 : 1;
int count =
is_multi_channel
? (g > 1 ? (weights_tz)[1] * (weights_tz)[0] : (weights_tz)[0])
: 1;
std::vector<float> scale_bias_data(count);
#pragma omp parallel for if (count > 1)
for (int i = 0; i < count; i++) {
scale_bias_data[i] = scale_in_data * scale_weights_data[i];
}
bias_memory_p = handler->AcquireBiasMemoryFromPrimitive(
user_bias_memory_p, pipeline, is_test, true, scale_bias_data,
mask_reorder);
conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_BIAS, *bias_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} else {
conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
}
} else {
auto src_memory_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx.GetBlob(src_reorder_key));
src_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(src_key));
if (src_memory_reorder_p) {
user_src_memory_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(user_src_key));
user_src_memory_p->set_data_handle(to_void_cast<T>(input_data));
{
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
src_memory_reorder_p->execute(astream, *user_src_memory_p,
*src_memory_p);
astream.wait();
}
} else if (src_memory_p) {
src_memory_p->set_data_handle(to_void_cast<T>(input_data));
}
auto weights_memory_p = std::static_pointer_cast<mkldnn::memory>(
dev_ctx.GetBlob(weights_key));
dst_memory_p = dst_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(dst_key)); handler.AcquireDstMemoryWithResidual(output, residual_param);
conv_p = std::static_pointer_cast<mkldnn::convolution_forward>( need_s8_to_u8 = (platform::MKLDNNGetDataType<T_out>() ==
dev_ctx.GetBlob(prim_key)); mkldnn::memory::data_type::s8) &&
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx, unsigned_output;
mkldnn_engine, key)); } else {
dst_memory_p = handler.template AcquireDstMemory<T_out>(output);
if (fuse_residual_conn) { }
auto residual_param = ctx.Input<Tensor>("ResidualData");
output->ShareDataWith(*residual_param);
need_s8_to_u8 =
(platform::MKLDNNGetDataType<T_out>() == memory::data_type::s8) &&
unsigned_output;
}
platform::SetDstMemoryHandler<T_out>(ctx, output, handler, dst_memory_p);
auto residual_reorder_p = std::static_pointer_cast<mkldnn::reorder>( auto conv_p = handler.AcquireForwardPrimitive();
dev_ctx.GetBlob(residual_reorder_key));
if (residual_reorder_p) { std::unordered_map<int, dnnl::memory> args = {
auto user_residual_data_p = std::static_pointer_cast<mkldnn::memory>( {MKLDNN_ARG_SRC, *src_memory_p},
dev_ctx.GetBlob(user_residual_key)); {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{ {MKLDNN_ARG_DST, *dst_memory_p}};
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
residual_reorder_p->execute(astream, *user_residual_data_p,
*dst_memory_p);
astream.wait();
}
}
auto bias_memory_p = if (bias) {
std::static_pointer_cast<mkldnn::memory>(dev_ctx.GetBlob(bias_key)); float mask_reorder;
std::vector<float> scale_bias_data;
std::tie(mask_reorder, scale_bias_data) =
handler.get_int8_bias_scales(ctx);
if (bias_memory_p) { auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(
conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p}, bias, is_test, scale_bias_data, mask_reorder);
{MKLDNN_ARG_WEIGHTS, *weights_memory_p}, args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
{MKLDNN_ARG_BIAS, *bias_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
} else {
conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
}
} }
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
conv_p->execute(astream, args);
astream.wait(); astream.wait();
if (need_s8_to_u8) { if (need_s8_to_u8) {
output->mutable_data<uint8_t>(ctx.GetPlace()); output->mutable_data<uint8_t>(ctx.GetPlace());
} }
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p)); output->set_layout(framework::DataLayout::kMKLDNN);
output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
} }
}; };
template <typename T, typename K> template <typename T, typename K>
class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { class ConvMKLDNNGradOpKernel : public framework::OpKernel<T> {
public: public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true, PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet( platform::errors::PreconditionNotMet(
"Operator DNNL ConvGrad must use CPUPlace")); "Operator DNNL ConvGrad must use CPUPlace"));
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
...@@ -1105,18 +892,19 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -1105,18 +892,19 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
{MKLDNN_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}}); {MKLDNN_ARG_DIFF_WEIGHTS, *diff_weights_memory_p}});
astream.wait(); astream.wait();
filter_grad->set_layout(DataLayout::kMKLDNN); filter_grad->set_layout(framework::DataLayout::kMKLDNN);
// in OneDNN groups in convolution are treated as separate dimension // in OneDNN groups in convolution are treated as separate dimension
// which is not the case in paddlepaddle // which is not the case in paddlepaddle
auto filter_fmt = GetMKLDNNFormat(*diff_weights_memory_p); auto filter_fmt = platform::GetMKLDNNFormat(*diff_weights_memory_p);
// For convolution with groups convert from blocked to NCHW // For convolution with groups convert from blocked to NCHW
// otherwise there will be problems in next operators working on this data // otherwise there will be problems in next operators working on this data
if (g > 1) { if (g > 1) {
memory::data_type in_type = framework::ToMKLDNNDataType(filter->type()); mkldnn::memory::data_type in_type =
framework::ToMKLDNNDataType(filter->type());
// for 3d conv with groups (six dimensional data reorder to goidhw) // for 3d conv with groups (six dimensional data reorder to goidhw)
// for 2d conv with groups (five dimensional data reorder to goihw) // for 2d conv with groups (five dimensional data reorder to goihw)
// auto weights_tz = paddle::framework::vectorize(filter->dims()); // auto weights_tz = framework::vectorize(filter->dims());
auto weights_tz = diff_weights_memory_p->get_desc().dims(); auto weights_tz = diff_weights_memory_p->get_desc().dims();
mkldnn::memory::format_tag out_format = mkldnn::memory::format_tag out_format =
...@@ -1168,8 +956,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -1168,8 +956,8 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
{MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}}); {MKLDNN_ARG_DIFF_SRC, *diff_src_memory_p}});
astream.wait(); astream.wait();
input_grad->set_layout(DataLayout::kMKLDNN); input_grad->set_layout(framework::DataLayout::kMKLDNN);
input_grad->set_format(GetMKLDNNFormat(*diff_src_memory_p)); input_grad->set_format(platform::GetMKLDNNFormat(*diff_src_memory_p));
} }
} }
}; };
......
...@@ -531,7 +531,13 @@ inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) { ...@@ -531,7 +531,13 @@ inline bool HasOpBFLOAT16DataType(const paddle::framework::OpDesc* op) {
inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) { inline bool HasOpFLOAT32DataType(const paddle::framework::OpDesc* op) {
return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32"; return op->GetAttrIfExists<std::string>("mkldnn_data_type") == "float32";
} }
enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP }; enum class RNNReorderType { PP_NTC, PP_TNC, NTC_PP, TNC_PP };
template <typename T>
bool constexpr is_int8() {
return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}
} // namespace platform } // namespace platform
} // namespace paddle } // namespace paddle
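The new platform::is_int8<T>() trait lets the conv handler pick oneDNN data types at compile time instead of branching into a separate INT8 code path. A hedged, self-contained sketch of that dispatch pattern (the local is_int8 and pick_src_type helpers are stand-ins, not the Paddle implementation):

#include <cstdint>
#include <iostream>
#include <type_traits>

template <typename T>
constexpr bool is_int8() {
  return std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
}

// Stand-in for mkldnn::memory::data_type.
enum class data_type { f32, s8, u8 };

template <typename T>
data_type pick_src_type() {
  // INT8 kernels keep the source in its quantized type (u8/s8);
  // the FP32/BF16 path is collapsed to f32 here for brevity.
  return is_int8<T>()
             ? (std::is_same<T, uint8_t>::value ? data_type::u8 : data_type::s8)
             : data_type::f32;
}

int main() {
  std::cout << static_cast<int>(pick_src_type<uint8_t>()) << ' '
            << static_cast<int>(pick_src_type<float>()) << '\n';
  return 0;
}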
...@@ -527,7 +527,8 @@ class MKLDNNHandlerT { ...@@ -527,7 +527,8 @@ class MKLDNNHandlerT {
const mkldnn::memory::desc& user_md, const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr, const mkldnn::memory::desc& target_md, void* ptr,
const std::string& suffix, bool is_persistent = false, const std::string& suffix, bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) { std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
const auto target_key = key_ + suffix + "_target"; const auto target_key = key_ + suffix + "_target";
const auto key_reorder_p = key_ + suffix + "reorder_p"; const auto key_reorder_p = key_ + suffix + "reorder_p";
const auto user_key = key_ + suffix + "_user"; const auto user_key = key_ + suffix + "_user";
...@@ -546,8 +547,17 @@ class MKLDNNHandlerT { ...@@ -546,8 +547,17 @@ class MKLDNNHandlerT {
std::make_shared<dnnl::memory>(user_md, engine_, ptr); std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) { if (user_md != target_md) {
target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_); target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
auto reorder_p = dnnl::reorder::primitive_desc reorder_pdesc;
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p); if (is_int8<T>()) {
dnnl::primitive_attr attr;
attr.set_output_scales(mask, scale_data);
reorder_pdesc = dnnl::reorder::primitive_desc(*user_memory_p,
*target_memory_p, attr);
} else {
reorder_pdesc =
dnnl::reorder::primitive_desc(*user_memory_p, *target_memory_p);
}
auto reorder_p = std::make_shared<dnnl::reorder>(reorder_pdesc);
dev_ctx_.SetBlob(key_reorder_p, reorder_p); dev_ctx_.SetBlob(key_reorder_p, reorder_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream(); auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
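For INT8, AcquireMemoryWithReorder now attaches output scales to the reorder primitive through a primitive_attr. A small sketch of that oneDNN (v2.x API) pattern in isolation; the shapes and scale values below are made up for illustration:

#include <vector>
#include "dnnl.hpp"

int main() {
  dnnl::engine eng(dnnl::engine::kind::cpu, 0);
  dnnl::stream strm(eng);

  // f32 user weights reordered into s8 target weights,
  // scaled per output channel (mask = 1 << 0 for dim 0 of oihw).
  dnnl::memory::desc user_md({8, 4, 3, 3}, dnnl::memory::data_type::f32,
                             dnnl::memory::format_tag::oihw);
  dnnl::memory::desc target_md({8, 4, 3, 3}, dnnl::memory::data_type::s8,
                               dnnl::memory::format_tag::oihw);
  dnnl::memory user_mem(user_md, eng);
  dnnl::memory target_mem(target_md, eng);

  dnnl::primitive_attr attr;
  attr.set_output_scales(/*mask=*/1 << 0, std::vector<float>(8, 127.0f));

  auto reorder_pd = dnnl::reorder::primitive_desc(user_mem, target_mem, attr);
  dnnl::reorder(reorder_pd).execute(strm, user_mem, target_mem);
  strm.wait();
  return 0;
}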
...@@ -597,201 +607,6 @@ class MKLDNNHandlerT { ...@@ -597,201 +607,6 @@ class MKLDNNHandlerT {
std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_; std::shared_ptr<typename TBackward_params::primitive_desc> bwd_w_pd_;
}; };
// TODO(grygielski) this class will be deleted later.
class MKLDNNHandler {
public:
MKLDNNHandler(const MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: dev_ctx_(dev_ctx),
engine_(engine),
key_(platform::ExtendKeyWithThreadInfoIfNeeded(dev_ctx, base_key)) {
platform::MKLDNNDeviceContext::tls().log_lib_version();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_diff_dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md, void* ptr, const std::string& suffix) {
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(md, engine_, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemoryFromPrimitive(
mkldnn::memory::desc md, const std::string& suffix) {
const auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(md, engine_);
dev_ctx_.SetBlob(local_key, mem_p);
}
return mem_p;
}
// This incarnation of AcquireMemory can call user function eg. custom reorder
// or preprocessing routine if needed
std::shared_ptr<mkldnn::memory> AcquireMemory(
const mkldnn::memory::desc& md, void* ptr, const std::string& suffix,
user_function custom_func = {}) {
/*Generate key*/
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
// Call custom reorder/preprocessing func if available
if (custom_func) {
auto reordered_data = custom_func(reinterpret_cast<const float*>(ptr));
dev_ctx_.SetBlob(local_key + "-custom_reorder", reordered_data);
ptr = reinterpret_cast<void*>(reordered_data.get());
}
mem_p = std::make_shared<mkldnn::memory>(md, engine_, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::vector<int64_t>& dims, const mkldnn::memory::data_type dtype,
const MKLDNNMemoryFormat& fmt, void* ptr, const std::string& suffix) {
/*Generate key*/
auto local_key = key_ + suffix;
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
auto md = mkldnn::memory::desc(dims, dtype, fmt);
mem_p = std::make_shared<mkldnn::memory>(md, engine_, ptr);
dev_ctx_.SetBlob(local_key, mem_p);
} else {
mem_p->set_data_handle(ptr);
}
return mem_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
auto stored_reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (stored_reorder_p) {
pipeline.push_back(*stored_reorder_p);
} else {
auto reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
return target_memory_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(
mkldnn::memory::desc& md, // NOLINT
mkldnn::memory::desc& user_md, // NOLINT
const std::shared_ptr<mkldnn::memory> user_memory_p,
const std::string& suffix,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false, bool is_INT8 = false,
std::vector<float> scale_data = {1.0f}, int mask = 0) {
// create reorder primitive if the input format is not the preferred one
auto local_key = key_ + suffix;
auto key_reorder_p = key_ + suffix + "reorder_p";
auto target_memory_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
if (target_memory_p == nullptr) {
target_memory_p = user_memory_p;
if (md != user_md) {
target_memory_p = std::make_shared<mkldnn::memory>(md, engine_);
std::shared_ptr<mkldnn::reorder::primitive_desc> reorder_pd;
if (is_INT8) {
mkldnn::primitive_attr
attri; // attribute for int8 weights and bias data reorder.
attri.set_output_scales(mask, scale_data);
reorder_pd = std::shared_ptr<mkldnn::reorder::primitive_desc>(
new mkldnn::reorder::primitive_desc(*user_memory_p,
*target_memory_p, attri));
} else {
reorder_pd = std::shared_ptr<mkldnn::reorder::primitive_desc>(
new mkldnn::reorder::primitive_desc(*user_memory_p,
*target_memory_p));
}
auto reorder_p =
std::shared_ptr<mkldnn::reorder>(new mkldnn::reorder(*reorder_pd));
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
dev_ctx_.SetBlob(local_key, target_memory_p);
} else if (!is_persistent) {
// Make reorder if needed
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
platform::RecordEvent record_reorder("int_reorder",
platform::EventRole::kUniqueOp);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
}
return target_memory_p;
}
protected:
const MKLDNNDeviceContext& dev_ctx_;
mkldnn::engine engine_;
std::string key_;
};
template <typename T> template <typename T>
class BinaryMKLDNNHandler class BinaryMKLDNNHandler
: public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> { : public platform::MKLDNNHandlerNoCachingT<T, dnnl::binary> {
...@@ -1143,362 +958,6 @@ class ReorderMKLDNNHandler { ...@@ -1143,362 +958,6 @@ class ReorderMKLDNNHandler {
mkldnn::engine engine_; mkldnn::engine engine_;
}; };
template <typename T>
struct convolutional_algorithm;
template <>
struct convolutional_algorithm<mkldnn::convolution_forward> {
static constexpr mkldnn::algorithm T = mkldnn::algorithm::convolution_direct;
};
template <>
struct convolutional_algorithm<mkldnn::deconvolution_forward> {
static constexpr mkldnn::algorithm T =
mkldnn::algorithm::deconvolution_direct;
};
template <class forward_t, class backward_data_t, class backward_weights_t>
class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
public:
ConvMKLDNNTemplateHandler(const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
// TODO(jczaja): remove after conv int8 is adapted
ConvMKLDNNTemplateHandler(
std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key) {
conv_pd_ = conv_pd;
}
ConvMKLDNNTemplateHandler(
std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
std::shared_ptr<typename backward_data_t::primitive_desc>
conv_bwd_data_pd,
std::shared_ptr<typename backward_weights_t::primitive_desc>
conv_bwd_weights_pd,
const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
conv_pd_(conv_pd),
conv_bwd_weights_pd_(conv_bwd_weights_pd),
conv_bwd_data_pd_(conv_bwd_data_pd) {
// If we are in Grad operatgor then update a key with BWD suffix to
// distinguish from FWD memory primitives
key_ += "-BWD";
}
size_t GetDstMemorySize() const { return conv_pd_->dst_desc().get_size(); }
MKLDNNMemoryFormat GetDstFormat() const {
return paddle::platform::GetMKLDNNFormat(conv_pd_->dst_desc());
}
size_t GetDiffWeightsMemorySize() const {
return conv_bwd_weights_pd_->diff_weights_desc().get_size();
}
size_t GetDiffSourceMemorySize() const {
return conv_bwd_data_pd_->diff_src_desc().get_size();
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto src_pd = conv_bwd_weights_pd_->src_desc();
auto user_pd = user_memory_p->get_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p,
"@weights-src_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromWeightsPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = conv_bwd_weights_pd_->diff_dst_desc();
auto user_pd = user_memory_p->get_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@weights-diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive(
void* ptr) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_weights_pd_->diff_weights_desc(), ptr, "@diff_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffWeightsMemoryFromWeightsPrimitive(
void) {
return this->AcquireMemoryFromPrimitive(
conv_bwd_weights_pd_->diff_weights_desc(), "@diff_weights_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = conv_bwd_data_pd_->diff_dst_desc();
auto user_pd = user_memory_p->get_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@data-diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto weights_pd = conv_bwd_data_pd_->weights_desc();
auto user_pd = user_weights_memory_p->get_desc();
return this->AcquireMemory(weights_pd, user_pd, user_weights_memory_p,
"@data-weights_mem_p", pipeline);
}
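  // The residual (elementwise_add) input is exposed as a user memory and then
  // reordered into the convolution dst memory, so the fused sum post-op can
  // accumulate the convolution result on top of it.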
std::shared_ptr<mkldnn::memory> AcquireResidualDataMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_residual_data_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromResidualDataMemory(
const std::shared_ptr<mkldnn::memory>& user_residual_memory_p,
void* dst_ptr,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
return this->AcquireMemory(user_residual_memory_p,
this->AcquireDstMemoryFromPrimitive(dst_ptr),
"@residual_data_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromDataPrimitive(
void* ptr) {
return this->AcquireMemoryFromPrimitive(conv_bwd_data_pd_->diff_src_desc(),
ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(conv_pd_->dst_desc(), ptr,
"@dst_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto src_pd = conv_pd_->src_desc();
auto user_pd = user_memory_p->get_desc();
return this->AcquireMemory(src_pd, user_pd, user_memory_p, "@src_mem_p",
pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemory(
const mkldnn::memory::desc& md, void* ptr,
user_function custom_func = {}) {
return this->AcquireMemory(md, ptr, "@user_weights_mem_p", custom_func);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemory(
const mkldnn::memory::desc& md, void* ptr) {
return this->AcquireMemory(md, ptr, "@user_bias_mem_p");
}
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_weights_memory_p,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false, bool is_INT8 = false,
std::vector<float> scale_data = {1.0f}, int mask = 0) {
auto user_weights_pd = user_weights_memory_p->get_desc();
auto weights_pd = conv_pd_->weights_desc();
return this->AcquireMemory(
weights_pd, user_weights_pd, user_weights_memory_p, "@weights_mem_p",
pipeline, is_persistent, is_INT8, scale_data, mask);
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryFromPrimitive(
const std::shared_ptr<mkldnn::memory> user_bias_memory_p,
std::vector<mkldnn::primitive>& pipeline, // NOLINT
bool is_persistent = false, bool is_INT8 = false,
std::vector<float> scale_data = {1.0f},
int mask = 0) { // NOLINT
auto user_bias_pd = user_bias_memory_p->get_desc();
auto bias_pd = conv_pd_->bias_desc();
return this->AcquireMemory(bias_pd, user_bias_pd, user_bias_memory_p,
"@bias_mem_p", pipeline, is_persistent, is_INT8,
scale_data, mask);
}
mkldnn::primitive_attr CreatePostOps(
std::string fuse_activation, float fuse_alpha, float fuse_beta,
bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
float sum_scale = 1.0f) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
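    // Mask 1 << 1 selects per-output-channel scales (dim 1 of the dst tensor);
    // mask 0 applies one common scale to the whole tensor.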
if (output_shift_scale.size() > 0) {
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
}
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_residual_connection is
// true, the output tensor contains the data coming from residual
// connection. The result of this post_op is:
// Output = scale * Output + Conv_Out.
if (fuse_residual_conn) {
post_operations.append_sum(sum_scale);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "relu6") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "swish") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
fuse_alpha, fuse_beta);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
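  // Builds the forward primitive descriptor with the fused post-ops attached,
  // or fetches it from the device context cache if it was already created.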
std::shared_ptr<typename forward_t::primitive_desc>
AcquireConvolutionPrimitiveDescriptor(
const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights,
paddle::optional<const mkldnn::memory::desc&> bias,
const mkldnn::memory::desc& dst, const std::vector<int64_t>& strides,
const std::vector<int64_t>& dilations,
const std::vector<int64_t>& paddings, const mkldnn::engine& engine,
const std::string& fuse_activation, float fuse_alpha, float fuse_beta,
const bool fuse_residual_conn, mkldnn::prop_kind fwd_prop_kind,
const std::vector<float> output_shift_scale = {},
const float sum_scale = 1.0f) {
    // The conv PD has to be passed to the Grad op, which
    // may be executed by a different thread, hence
    // for that one we use a key that does not contain the TID
const std::string key_conv_pd = key_ + "@conv_pd";
conv_pd_ = std::static_pointer_cast<typename forward_t::primitive_desc>(
dev_ctx_.GetBlob(key_conv_pd));
if (conv_pd_ == nullptr) {
mkldnn::memory::dims stride_dims = strides;
mkldnn::memory::dims dilations_dims = dilations;
auto mkldnn_paddings = ToMkldnnPadding(paddings);
auto conv_desc =
bias ? typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
weights, *bias, dst, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1])
: typename forward_t::desc(
fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
weights, dst, stride_dims, dilations_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, output_shift_scale, sum_scale);
conv_pd_.reset(
new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
}
return conv_pd_;
}
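  // Returns the cached forward convolution primitive, creating and caching it
  // on first use.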
std::shared_ptr<forward_t> AcquireConvolution() {
auto prim_key = key_ + "@conv_p";
auto conv_p =
std::static_pointer_cast<forward_t>(dev_ctx_.GetBlob(prim_key));
if (conv_p == nullptr) {
conv_p = std::make_shared<forward_t>(*conv_pd_);
dev_ctx_.SetBlob(prim_key, conv_p);
}
return conv_p;
}
std::shared_ptr<backward_weights_t> AcquireConvolutionBackwardWeights() {
auto prim_key = key_ + "@conv_bwd_weights_p";
auto conv_bwd_weights_p = std::static_pointer_cast<backward_weights_t>(
dev_ctx_.GetBlob(prim_key));
if (conv_bwd_weights_p == nullptr) {
// create backward conv primitive for weights
conv_bwd_weights_p =
std::make_shared<backward_weights_t>(*conv_bwd_weights_pd_);
dev_ctx_.SetBlob(prim_key, conv_bwd_weights_p);
}
return conv_bwd_weights_p;
}
std::shared_ptr<backward_data_t> AcquireConvolutionBackwardData() {
auto prim_key = key_ + "@conv_bwd_data_p";
auto conv_bwd_data_p =
std::static_pointer_cast<backward_data_t>(dev_ctx_.GetBlob(prim_key));
if (conv_bwd_data_p == nullptr) {
conv_bwd_data_p = std::make_shared<backward_data_t>(*conv_bwd_data_pd_);
dev_ctx_.SetBlob(prim_key, conv_bwd_data_p);
}
return conv_bwd_data_p;
}
private:
std::shared_ptr<typename forward_t::primitive_desc> conv_pd_;
std::shared_ptr<typename backward_weights_t::primitive_desc>
conv_bwd_weights_pd_;
std::shared_ptr<typename backward_data_t::primitive_desc> conv_bwd_data_pd_;
};
using ConvMKLDNNHandler =
ConvMKLDNNTemplateHandler<mkldnn::convolution_forward,
mkldnn::convolution_backward_data,
mkldnn::convolution_backward_weights>;
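// The helpers below allocate the output tensor and bind it to the destination
// memory of the convolution primitive, optionally loading the residual
// (elementwise_add) data into it first.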
template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler) {
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize());
std::shared_ptr<mkldnn::memory> dst_memory_p =
handler->AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
return dst_memory_p;
}
template <typename T>
static std::shared_ptr<mkldnn::memory> SetDstMemory(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const framework::Tensor* residual_param,
const mkldnn::memory::desc& user_residual_md,
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::vector<mkldnn::primitive>* pipeline) {
const T* residual_param_data = residual_param->data<T>();
PADDLE_ENFORCE_NOT_NULL(
residual_param_data,
platform::errors::PreconditionNotMet("Residual parameter is required for "
"the DNNL conv+elementwise_add "
"fusion, but now it is missing."));
std::shared_ptr<mkldnn::memory> user_residual_memory_p =
handler->AcquireResidualDataMemory(user_residual_md,
to_void_cast<T>(residual_param_data));
T* output_data = output->mutable_data<T>(ctx.GetPlace());
std::shared_ptr<mkldnn::memory> dst_memory_p =
handler->AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), *pipeline);
return dst_memory_p;
}
template <typename T>
static void SetDstMemoryHandler(
const framework::ExecutionContext& ctx, framework::Tensor* output,
const std::shared_ptr<ConvMKLDNNHandler>& handler,
std::shared_ptr<mkldnn::memory> dst_memory_p) {
T* output_data =
output->mutable_data<T>(ctx.GetPlace(), handler->GetDstMemorySize());
dst_memory_p->set_data_handle(to_void_cast<T>(output_data));
}
template <typename T>
static void SetDstMemoryQuantized(
    const framework::ExecutionContext& ctx, framework::Tensor* output,
...@@ -1524,5 +983,6 @@ static void SetDstMemoryQuantized(
  dst_memory.reset(
      new mkldnn::memory(*dst_md, engine, to_void_cast<T>(output_data)));
}
}  // namespace platform
}  // namespace paddle
...@@ -20,7 +20,8 @@ import numpy as np
import paddle
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest, convert_float_to_uint16, get_numeric_gradient
from paddle.fluid.tests.unittests.op_test import (
    OpTest, convert_float_to_uint16, get_numeric_gradient)
from paddle.fluid.tests.unittests.testsuite import create_op
from paddle.fluid import Program, program_guard
...
...@@ -22,7 +22,7 @@ import paddle.nn as nn
paddle.enable_static()
import paddle.fluid.core as core
import paddle.fluid as fluid
from op_test import OpTest
from paddle.fluid.tests.unittests.op_test import OpTest
def conv2dtranspose_forward_naive(input_, filter_, attrs):
...