Unverified commit bd0b38e6, authored by Adam, committed by GitHub

Refactor of conv fp32 oneDNN operator (#25137)

* Refactor of conv fp32 oneDNN operator
test=develop

* Formatting fix
test=develop

* Return Enforces
test=develop

* GetWeights improvements
test=develop
Parent b2f5a149
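The main simplification in GetWeightsTz (first hunk below) is that grouped weights no longer need separate 4-D and 5-D code paths: appending one slot and rotating it to the front turns [o, i, spatial...] into [g, o/g, i, spatial...] for any rank. A minimal standalone sketch of that reshaping, outside the Paddle codebase and with a hypothetical name (GroupWeightDims), for illustration only:

#include <algorithm>
#include <cstdint>
#include <vector>

// Illustrative re-implementation of the dimension reshuffle done by the new
// GetWeightsTz: [o, i, (d,) h, w] -> [g, o/g, i, (d,) h, w] when groups > 1.
inline void GroupWeightDims(std::vector<int64_t>& dims, int64_t groups) {
  if (groups > 1) {
    dims.push_back(0);                                      // make room for g
    std::rotate(dims.begin(), dims.end() - 1, dims.end());  // move it to front
    dims[0] = groups;                                       // g
    dims[1] = dims[1] / groups;                             // o -> o/g
  }
}

// Example: {32, 16, 3, 3} with groups = 4 becomes {4, 8, 16, 3, 3}.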
@@ -26,42 +26,24 @@ using mkldnn::memory;
 using mkldnn::primitive;
 using mkldnn::reorder;
 using mkldnn::stream;
-using platform::to_void_cast;
 using platform::GetMKLDNNFormat;
+using platform::to_void_cast;

 inline void GetWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
-                         int groups, bool is_conv3d) {
+                         const int groups) {
   if (groups > 1) {
-    if (is_conv3d) {
-      int output = weights_tz[0];
-      int input = weights_tz[1];
-      int dimension = weights_tz[2];
-      int height = weights_tz[3];
-      int width = weights_tz[4];
-      weights_tz.resize(6);
-      weights_tz[0] = groups;
-      weights_tz[1] = output / groups;
-      weights_tz[2] = input;
-      weights_tz[3] = dimension;
-      weights_tz[4] = height;
-      weights_tz[5] = width;
-    } else {
-      int output = weights_tz[0];
-      int input = weights_tz[1];
-      int height = weights_tz[2];
-      int width = weights_tz[3];
-      weights_tz.resize(5);
-      weights_tz[0] = groups;
-      weights_tz[1] = output / groups;
-      weights_tz[2] = input;
-      weights_tz[3] = height;
-      weights_tz[4] = width;
-    }
+    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
+    // else [o, i, h, w] -> [g, o/g, i, h, w]
+    weights_tz.push_back(0);
+    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
+    weights_tz[0] = groups;
+    weights_tz[1] = weights_tz[1] / groups;
   }
 }

-inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format,
-                                           int groups, bool is_conv3d) {
+inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
+                                           const int groups,
+                                           const bool is_conv3d) {
   if (is_conv3d) {
     return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw;
   } else {
@@ -90,53 +72,29 @@ static mkldnn::memory::data_type GetDstType(bool is_int8,
   return dst_dt;
 }

-template <typename T, typename K>
-class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+template <typename T>
+class ConvMKLDNNHandlerT
+    : public platform::MKLDNNHandlerT<T, mkldnn::convolution_forward> {
  public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                      paddle::platform::errors::PreconditionNotMet(
-                          "Operator DNNL Conv must use CPUPlace"));
-    bool is_INT8 =
-        std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
-    if (!is_INT8) {
-      ComputeFP32(ctx);
-    } else {
-      std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
-      auto residual_param = ctx.Input<Tensor>("ResidualData");
-      auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
-                               fuse_residual_conn, residual_param);
-      if (dst_dt == mkldnn::memory::data_type::f32) {
-        ComputeINT8<float>(ctx);
-      } else if (dst_dt == mkldnn::memory::data_type::u8) {
-        ComputeINT8<uint8_t>(ctx);
-      } else if (dst_dt == mkldnn::memory::data_type::s8) {
-        ComputeINT8<int8_t>(ctx);
-      }
-    }
-  }
-
-  void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
-    const bool is_test = ctx.Attr<bool>("is_test");
-
-    auto& dev_ctx =
-        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-    auto* input = ctx.Input<Tensor>("Input");
-    auto* filter = ctx.Input<Tensor>("Filter");
-    auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
-    auto* output = ctx.Output<Tensor>("Output");
-
-    PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
-                      platform::errors::InvalidArgument(
-                          "The input tensor's layout should be %d, but got %d.",
-                          DataLayout::kMKLDNN, input->layout()));
-    PADDLE_ENFORCE_NE(
-        input->format(), MKLDNNMemoryFormat::undef,
-        platform::errors::InvalidArgument("Wrong format set for Input tensor"));
+  ConvMKLDNNHandlerT(const paddle::framework::ExecutionContext& ctx,
+                     const platform::MKLDNNDeviceContext& dev_ctx,
+                     const mkldnn::engine mkldnn_engine,
+                     platform::Place cpu_place, const Tensor* input,
+                     const Tensor* filter, const Tensor* bias, Tensor* output,
+                     const std::string& unique_name)
+      : platform::MKLDNNHandlerT<T, mkldnn::convolution_forward>(
+            dev_ctx, mkldnn_engine, cpu_place,
+            platform::CreateKey(framework::vectorize(input->dims()),
+                                unique_name)) {
+    if (!this->isCached()) {
+      PADDLE_ENFORCE_EQ(
+          input->layout(), DataLayout::kMKLDNN,
+          platform::errors::InvalidArgument(
+              "The input tensor's layout should be %d, but got %d.",
+              DataLayout::kMKLDNN, input->layout()));
+      PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
+                        platform::errors::InvalidArgument(
+                            "Wrong format set for Input tensor"));

       PADDLE_ENFORCE_EQ(
           filter->layout(), DataLayout::kMKLDNN,
@@ -147,23 +105,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           platform::errors::InvalidArgument(
               "Wrong format set for Filter tensor"));

-    PADDLE_ENFORCE_GE(input->dims().size(), 4,
-                      platform::errors::InvalidArgument(
-                          "Input must be with 4 or 5 dimensions, i.e. NCHW or "
-                          "NCDHW, but got dimension = %d .",
-                          input->dims().size()));
-    PADDLE_ENFORCE_LE(input->dims().size(), 5,
-                      platform::errors::InvalidArgument(
-                          "Input must be with 4 or 5 dimensions, i.e. NCHW or "
-                          "NCDHW, but got dimension = %d .",
-                          input->dims().size()));
-    PADDLE_ENFORCE_GE(filter->dims().size(), 4,
-                      platform::errors::InvalidArgument(
-                          "Filter must be with 4 or 5 dimensions, i.e. OIHW or "
-                          "OIDHW, but got dimension = %d .",
-                          filter->dims().size()));
-    PADDLE_ENFORCE_LE(filter->dims().size(), 5,
-                      platform::errors::InvalidArgument(
-                          "Filter must be with 4 or 5 dimensions, i.e. OIHW or "
-                          "OIDHW, but got dimension = %d .",
-                          filter->dims().size()));
+      PADDLE_ENFORCE_GE(
+          input->dims().size(), 4,
+          platform::errors::InvalidArgument(
+              "Input must be with 4 or 5 dimensions, i.e. NCHW or "
+              "NCDHW, but got dimension = %d .",
+              input->dims().size()));
+      PADDLE_ENFORCE_LE(
+          input->dims().size(), 5,
+          platform::errors::InvalidArgument(
+              "Input must be with 4 or 5 dimensions, i.e. NCHW or "
+              "NCDHW, but got dimension = %d .",
+              input->dims().size()));
+      PADDLE_ENFORCE_GE(
+          filter->dims().size(), 4,
+          platform::errors::InvalidArgument(
+              "Filter must be with 4 or 5 dimensions, i.e. OIHW or "
+              "OIDHW, but got dimension = %d .",
+              filter->dims().size()));
+      PADDLE_ENFORCE_LE(
+          filter->dims().size(), 5,
+          platform::errors::InvalidArgument(
+              "Filter must be with 4 or 5 dimensions, i.e. OIHW or "
+              "OIDHW, but got dimension = %d .",
+              filter->dims().size()));
@@ -179,73 +141,64 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           platform::errors::InvalidArgument(
               "Got wrong format for Bias tensor."));

-    PADDLE_ENFORCE_EQ(
-        bias->dims().size(), 1,
-        platform::errors::InvalidArgument("Bias must only have 1 dimension, "
-                                          "i.e. X, but got dimension = %d .",
-                                          bias->dims().size()));
+      PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
+                        platform::errors::InvalidArgument(
+                            "Bias must only have 1 dimension, "
+                            "i.e. X, but got dimension = %d .",
+                            bias->dims().size()));
     }

-    std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
-    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
-
-    std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
-    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
-
-    std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
-    std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));
-
-    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-    float fuse_alpha = ctx.Attr<float>("fuse_alpha");
-    float fuse_beta = ctx.Attr<float>("fuse_beta");
-    bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
-    int groups = ctx.Attr<int>("groups");
-    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
-    bool is_conv3d = strides.size() == 3U;
-
-    auto input_dims = input->dims();
-    auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
-    auto filter_dims = filter->dims();
-    auto filter_data_dims =
-        framework::slice_ddim(filter_dims, 2, filter_dims.size());
-
-    auto ksize = framework::vectorize(filter_data_dims);
+      const std::string fuse_activation =
+          ctx.Attr<std::string>("fuse_activation");
+      const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
+      const float fuse_beta = ctx.Attr<float>("fuse_beta");
+      const bool fuse_residual_conn =
+          ctx.Attr<bool>("fuse_residual_connection");
+      const int groups = ctx.Attr<int>("groups");
+      const std::string padding_algorithm =
+          ctx.Attr<std::string>("padding_algorithm");
+
+      const auto input_dims = input->dims();
+      const auto data_dims =
+          framework::slice_ddim(input_dims, 2, input_dims.size());
+      const auto filter_dims = filter->dims();
+      const auto filter_data_dims =
+          framework::slice_ddim(filter_dims, 2, filter_dims.size());
+
+      const auto ksize = framework::vectorize(filter_data_dims);
+
+      const bool is_test = ctx.Attr<bool>("is_test");
+
+      auto strides_temp = ctx.Attr<std::vector<int>>("strides");
+      std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
+
+      auto paddings_temp = ctx.Attr<std::vector<int>>("paddings");
+      std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
+
+      auto dilations_temp = ctx.Attr<std::vector<int>>("dilations");
+      std::vector<int64_t> dilations(begin(dilations_temp),
+                                     end(dilations_temp));

       UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                                data_dims, strides, ksize);

-    std::vector<primitive> pipeline;
-    PADDLE_ENFORCE(
-        is_conv3d
-            ? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
-                  dilations[2] == 1
-            : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
-        "dilation in convolution is not implemented yet");
-
-    const T* input_data = input->data<T>();
-    const T* filter_data = filter->data<T>();
-
-    auto src_tz = paddle::framework::vectorize(input->dims());
-    auto weights_tz = paddle::framework::vectorize(filter->dims());
-    int g = std::max(groups, 1);
-    GetWeightsTz(weights_tz, g, is_conv3d);
-    auto dst_tz = paddle::framework::vectorize(output->dims());
-
-    // Get unique name for storing MKLDNN primitives
-    const std::string key = platform::CreateKey(
-        src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
-
-    auto src_format = input->format();
-    MKLDNNMemoryFormat weights_format =
-        GetWeightsFormat(filter->format(), g, is_conv3d);
-
-    auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
-    auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
+      const bool is_conv3d = strides.size() == 3U;
+
+      PADDLE_ENFORCE_EQ(
+          is_conv3d
+              ? dilations.size() == 3 && dilations[0] == 1 &&
+                    dilations[1] == 1 && dilations[2] == 1
+              : dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
+          true, platform::errors::Unimplemented(
+                    "Dilation in oneDNN convolution is not implemented yet"));
+
+      const auto src_tz = paddle::framework::vectorize(input->dims());
+
+      auto weights_tz = paddle::framework::vectorize(filter->dims());
+      GetWeightsTz(weights_tz, groups);
+
+      const auto dst_tz = paddle::framework::vectorize(output->dims());
+
+      const mkldnn::memory::dims stride_dims = strides;
+      const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);

       /* create memory descriptor for convolution without specified format
        * ('any') which lets a primitive (convolution in this case) choose
@@ -255,139 +208,250 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       // gradient computation proper as this op is called directly without
       // fetch op following it , so numercial grad is computed (in python)
       // using block formats which will give wrong results
-    std::string data_format = ctx.Attr<std::string>("data_format");
+      const std::string data_format = ctx.Attr<std::string>("data_format");
       auto chosen_memory_format =
           is_test ? MKLDNNMemoryFormat::any
                   : platform::data_format_to_memory_format(data_format);
-    weights_format = MKLDNNMemoryFormat::any;
+
       // Check the format for user's special output
       if (chosen_memory_format != MKLDNNMemoryFormat::any) {
         if (is_conv3d) {
-        chosen_memory_format =
-            platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
+          chosen_memory_format = platform::MKLDNNFormatForSize(
+              src_tz.size(), chosen_memory_format);
         }
       }

-    auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-    auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
-    std::vector<int64_t> bias_tz;
-    auto dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-
-    platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-
-    // create a conv primitive descriptor and save it for usage in backward
-    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
-    auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
-                                 : mkldnn::prop_kind::forward_training;
+      const auto src_md = platform::MKLDNNMemDesc(
+          src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+      const auto weights_md =
+          platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType<T>(),
+                                  MKLDNNMemoryFormat::any);
+      const auto dst_md = platform::MKLDNNMemDesc(
+          dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
+
+      const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
+                                         : mkldnn::prop_kind::forward_training;
+
+      const mkldnn::primitive_attr conv_attr = CreatePostOps(
+          fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn);

       if (bias) {
-      bias_tz = paddle::framework::vectorize(bias->dims());
+        auto bias_tz = framework::vectorize(bias->dims());
         auto bias_md = platform::MKLDNNMemDesc(
             bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
-      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
-          src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
-          fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
-          fwd_prop_kind);
+
+        this->AcquireForwardPrimitiveDescriptor(
+            conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
+            src_md, weights_md, bias_md, dst_md, stride_dims,
+            mkldnn_paddings[0], mkldnn_paddings[1]);
       } else {
-      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
-          src_md, weights_md, boost::none, dst_md, strides, paddings,
-          mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
-          fuse_residual_conn, fwd_prop_kind);
+        this->AcquireForwardPrimitiveDescriptor(
+            conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
+            src_md, weights_md, dst_md, stride_dims, mkldnn_paddings[0],
+            mkldnn_paddings[1]);
       }
+    }
+  }

-    // create mkldnn memory from input tensors (data/weights)
-    auto user_src_memory_p =
-        handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
-    auto user_weights_memory_p = handler.AcquireWeightsMemory(
-        user_weights_md, to_void_cast<T>(filter_data));
-
-    // create reorder primitive if the input format is not the preferred one
-    auto src_memory_p =
-        handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
-    auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
-        user_weights_memory_p, pipeline, is_test);
-
-    std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p;
-
-    if (fuse_residual_conn) {
-      auto residual_param = ctx.Input<Tensor>("ResidualData");
-      auto residual_param_data = residual_param->data<T>();
-
-      PADDLE_ENFORCE_NE(
-          residual_param_data, nullptr,
-          platform::errors::InvalidArgument(
-              "Provide data if you want MKLDNN conv+elementwise_add fusion"));
-      PADDLE_ENFORCE_EQ(
-          output->dims(), residual_param->dims(),
-          platform::errors::InvalidArgument(
-              "Output and elementwise parameter need to have the "
-              "same dimension sizes, "
-              "but got output's dimension = %d and residual param's dimension "
-              "= %d .",
-              output->dims().size(), residual_param->dims().size()));
-
-      if (residual_param->format() != handler.GetDstFormat()) {
-        auto output_data =
-            output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
-        auto residual_data_tz =
-            paddle::framework::vectorize(residual_param->dims());
-        auto residual_data_type =
-            paddle::framework::ToMKLDNNDataType(residual_param->type());
-
-        auto user_residual_md = platform::MKLDNNMemDesc(
-            residual_data_tz, residual_data_type, residual_param->format());
-        user_residual_memory_p = handler.AcquireResidualDataMemory(
-            user_residual_md, to_void_cast<T>(residual_param_data));
-
-        dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
-            user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
-      } else {
-        // Changing ShareDataWith to TensorCopy results in performance drop
-        // on ResNet architectures
-        // (https://github.com/PaddlePaddle/Paddle/issues/22964)
-        output->ShareDataWith(*residual_param);
-        auto output_data = output->mutable_data<T>(ctx.GetPlace());
-        dst_memory_p =
-            handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
-      }
-    } else {
-      auto output_data =
-          output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
-      dst_memory_p =
-          handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
-    }
-
-    auto conv_p = handler.AcquireConvolution();
-
-    mkldnn::stream astream(mkldnn_engine);
-    if (bias) {
-      const T* bias_data = bias->data<T>();
-      auto user_bias_md = platform::MKLDNNMemDesc(
-          {bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
-      auto user_bias_memory_p =
-          handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
-
-      auto bias_memory_p =
-          handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
-
-      conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
-                                {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
-                                {MKLDNN_ARG_BIAS, *bias_memory_p},
-                                {MKLDNN_ARG_DST, *dst_memory_p}});
-    } else {
-      conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
-                                {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
-                                {MKLDNN_ARG_DST, *dst_memory_p}});
-    }
+  mkldnn::primitive_attr CreatePostOps(
+      std::string fuse_activation, float fuse_alpha, float fuse_beta,
+      bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
+      float sum_scale = 1.0f) {
+    mkldnn::primitive_attr conv_attr;
+    mkldnn::post_ops post_operations;
+    if (output_shift_scale.size() > 0) {
+      int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
+      conv_attr.set_output_scales(mask, output_shift_scale);
+    }
+
+    // Fusion with Elementwise layer relies on adding a sum post-operation with
+    // the scale parameter. It is assumed that when fuse_residual_connection is
+    // true, the output tensor contains the data coming from residual
+    // connection. The result of this post_op is:
+    // Output = scale * Output + Conv_Out.
+    if (fuse_residual_conn) {
+      post_operations.append_sum(sum_scale);
+    }
+    // Fusion with ReLU layer is executed through the PostOps feature. Create a
+    // PostOps object and configure it to execute an eltwise relu operation.
+    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                     fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "relu6") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale,
+                                     mkldnn::algorithm::eltwise_bounded_relu,
+                                     fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "swish") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
+                                     fuse_alpha, fuse_beta);
+    }
+    conv_attr.set_post_ops(post_operations);
+    return conv_attr;
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    auto user_src_md = platform::MKLDNNMemDesc(
+        framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
+        input->format());
+
+    return this->AcquireMemoryWithReorder(
+        user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
+        "@src_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
+      const framework::Tensor* filter, const int groups, const bool is_conv3d,
+      const bool is_test) {
+    // This is workaround to make execution faster, delete
+    // if statement after including md inside Tensor
+    auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
+    if (is_test && weights_mem_p) {
+      return weights_mem_p;
+    } else {
+      const T* filter_data = filter->data<T>();
+      auto weights_tz = framework::vectorize(filter->dims());
+      GetWeightsTz(weights_tz, groups);
+
+      auto user_src_md = platform::MKLDNNMemDesc(
+          weights_tz, platform::MKLDNNGetDataType<T>(),
+          GetWeightsFormat(filter->format(), groups, is_conv3d));
+
+      return this->AcquireMemoryWithReorder(
+          user_src_md, this->fwd_pd_->weights_desc(),
+          to_void_cast<T>(filter_data), "@weights_mem_p", is_test);
+    }
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
+      const framework::Tensor* bias, const bool is_test) {
+    const T* bias_data = bias->data<T>();
+    auto user_bias_md = platform::MKLDNNMemDesc(
+        framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<T>(),
+        MKLDNNMemoryFormat::x);
+
+    return this->AcquireMemoryWithReorder(
+        user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<T>(bias_data),
+        "@bias_mem_p", is_test);
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireResidualMemory(
+      const framework::Tensor* residual_param) {
+    const T* residual_data = residual_param->data<T>();
+    auto user_residual_md = platform::MKLDNNMemDesc(
+        framework::vectorize(residual_param->dims()),
+        framework::ToMKLDNNDataType(residual_param->type()),
+        residual_param->format());
+
+    return this->AcquireMemoryFromPrimitive(user_residual_md,
+                                            to_void_cast<T>(residual_data),
+                                            "@user_residual_data_mem_p");
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireDstMemoryWithResidual(
+      framework::Tensor* output, const framework::Tensor* residual_param) {
+    std::shared_ptr<dnnl::memory> dst_memory_p;
+    if (residual_param->format() !=
+        platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
+      auto residual_memory_p = this->AcquireResidualMemory(residual_param);
+      dst_memory_p = this->AcquireDstMemory(output);
+      this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
+    } else {
+      // Changing ShareDataWith to TensorCopy results in performance drop
+      // on ResNet architectures
+      // (https://github.com/PaddlePaddle/Paddle/issues/22964)
+      output->ShareDataWith(*residual_param);
+      dst_memory_p = this->AcquireDstMemory(output);
+    }
+    return dst_memory_p;
+  }
+};
+
+template <typename T, typename K>
+class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
+                      paddle::platform::errors::PreconditionNotMet(
+                          "Operator DNNL Conv must use CPUPlace"));
+    bool is_INT8 =
+        std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
+    if (!is_INT8) {
+      ComputeFP32(ctx);
+    } else {
+      std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
+      bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+      bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+      auto residual_param = ctx.Input<Tensor>("ResidualData");
+      auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
+                               fuse_residual_conn, residual_param);
+      if (dst_dt == mkldnn::memory::data_type::f32) {
+        ComputeINT8<float>(ctx);
+      } else if (dst_dt == mkldnn::memory::data_type::u8) {
+        ComputeINT8<uint8_t>(ctx);
+      } else if (dst_dt == mkldnn::memory::data_type::s8) {
+        ComputeINT8<int8_t>(ctx);
+      }
+    }
+  }
+
+  void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+
+    const bool is_test = ctx.Attr<bool>("is_test");
+    const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
+    const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
+
+    const auto* input = ctx.Input<Tensor>("Input");
+    const auto* filter = ctx.Input<Tensor>("Filter");
+    const auto* bias =
+        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
+    auto* output = ctx.Output<Tensor>("Output");
+
+    ConvMKLDNNHandlerT<T> handler(
+        ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
+        output, ctx.InputName("Input") + ctx.InputName("Filter"));
+
+    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
+
+    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
+        filter, ctx.Attr<int>("groups"), is_conv3d, is_test);
+
+    std::shared_ptr<dnnl::memory> dst_memory_p;
+    if (fuse_residual_conn) {
+      auto* residual_param = ctx.Input<Tensor>("ResidualData");
+      dst_memory_p =
+          handler.AcquireDstMemoryWithResidual(output, residual_param);
+    } else {
+      dst_memory_p = handler.AcquireDstMemory(output);
+    }
+
+    auto conv_p = handler.AcquireForwardPrimitive();
+
+    std::unordered_map<int, dnnl::memory> args = {
+        {MKLDNN_ARG_SRC, *src_memory_p},
+        {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
+        {MKLDNN_ARG_DST, *dst_memory_p}};
+
+    if (bias) {
+      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
+      args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
+    }
+
+    mkldnn::stream astream(mkldnn_engine);
+    conv_p->execute(astream, args);
     astream.wait();

     output->set_layout(DataLayout::kMKLDNN);
     output->set_format(GetMKLDNNFormat(*dst_memory_p));
   }

   template <typename T_out>
   void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
     const bool is_test = ctx.Attr<bool>("is_test");
@@ -552,7 +616,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {

     auto weights_tz = paddle::framework::vectorize(filter->dims());
     int g = std::max(groups, 1);
-    GetWeightsTz(weights_tz, g, is_conv3d);
+    GetWeightsTz(weights_tz, g);
     auto dst_tz = paddle::framework::vectorize(output->dims());

     PADDLE_ENFORCE_EQ(
@@ -866,7 +930,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {

     auto weights_tz = paddle::framework::vectorize(filter->dims());
     int g = std::max(groups, 1);
-    GetWeightsTz(weights_tz, g, is_conv3d);
+    GetWeightsTz(weights_tz, g);
     auto dst_tz = paddle::framework::vectorize(output_grad->dims());

     auto src_format = input->format();
@@ -879,7 +943,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key = platform::CreateKey(
         src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));

-    const std::string key_conv_pd = key + "@conv_pd";
+    const std::string key_conv_pd = key + "@forward_pd";
     std::vector<primitive> pipeline;

     // Create user memory descriptors
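The CreatePostOps comments above describe the fused epilogue: an optional sum post-op for the residual connection (Output = scale * Output + Conv_Out) followed by an eltwise activation. The same attribute can be built with the plain oneDNN C++ API; the sketch below is illustrative only (MakeConvAttr and its parameters are made up, not Paddle code) and covers just the relu case:

#include "dnnl.hpp"

// Sketch: a primitive_attr roughly equivalent to what CreatePostOps produces
// for a conv fused with a residual add and a ReLU activation.
dnnl::primitive_attr MakeConvAttr(bool fuse_residual, float sum_scale,
                                  float fuse_alpha, float fuse_beta) {
  dnnl::post_ops ops;
  if (fuse_residual) {
    ops.append_sum(sum_scale);  // dst = sum_scale * dst + conv(src, weights)
  }
  // eltwise ReLU applied to the accumulated result (alpha = negative slope)
  ops.append_eltwise(1.0f, dnnl::algorithm::eltwise_relu, fuse_alpha, fuse_beta);

  dnnl::primitive_attr attr;
  attr.set_post_ops(ops);
  return attr;  // pass to the convolution_forward primitive descriptor
}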
@@ -210,6 +210,73 @@ class MKLDNNHandlerT {
     return mem_p;
   }

+  void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
+                      const std::shared_ptr<mkldnn::memory>& target_memory_p,
+                      const std::string& suffix) {
+    const auto key_reorder_p = key_ + suffix + "reorder_p";
+
+    auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
+        dev_ctx_.GetBlob(key_reorder_p));
+
+    if (reorder_p == nullptr) {
+      reorder_p =
+          std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
+      dev_ctx_.SetBlob(key_reorder_p, reorder_p);
+    }
+
+    mkldnn::stream astream(engine_);
+    reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
+                                 {MKLDNN_ARG_TO, *target_memory_p}});
+    astream.wait();
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
+      const mkldnn::memory::desc& user_md,
+      const mkldnn::memory::desc& target_md, void* ptr,
+      const std::string& suffix, bool is_persistent = false) {
+    const auto target_key = key_ + suffix + "_target";
+    const auto key_reorder_p = key_ + suffix + "reorder_p";
+    const auto user_key = key_ + suffix + "_user";
+
+    auto target_memory_p =
+        std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));
+
+    if (target_memory_p == nullptr) {
+      auto user_memory_p =
+          std::make_shared<dnnl::memory>(user_md, engine_, ptr);
+      if (user_md != target_md) {
+        target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
+        auto reorder_p =
+            std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
+        dev_ctx_.SetBlob(key_reorder_p, reorder_p);
+
+        mkldnn::stream astream(engine_);
+        reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
+                                     {MKLDNN_ARG_TO, *target_memory_p}});
+        astream.wait();
+      } else {
+        target_memory_p = user_memory_p;
+      }
+      dev_ctx_.SetBlob(user_key, user_memory_p);
+      dev_ctx_.SetBlob(target_key, target_memory_p);
+    } else if (!is_persistent) {
+      mkldnn::stream astream(engine_);
+
+      auto user_memory_p =
+          std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
+      user_memory_p->set_data_handle(ptr);
+
+      auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
+          dev_ctx_.GetBlob(key_reorder_p));
+      if (reorder_p != nullptr) {
+        reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
+                                     {MKLDNN_ARG_TO, *target_memory_p}});
+        astream.wait();
+      }
+    }
+    return target_memory_p;
+  }
+
   std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) {
     const auto local_key = key_ + suffix;
     return std::static_pointer_cast<mkldnn::memory>(
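The new AcquireMemoryWithReorder caches the user memory, the target memory, and the reorder primitive in the device-context blobs, so later iterations only refresh the data handle and re-run the cached reorder (skipped for persistent memories such as inference-time weights). Stripped of that caching, the core pattern is reorder-on-demand; a standalone sketch using the plain oneDNN API (AcquireWithReorder is a made-up name, not the Paddle helper):

#include <memory>
#include "dnnl.hpp"

// Reorder user data into the primitive's preferred layout only when the two
// descriptors differ; otherwise wrap the user pointer directly.
std::shared_ptr<dnnl::memory> AcquireWithReorder(
    const dnnl::memory::desc& user_md, const dnnl::memory::desc& target_md,
    void* ptr, const dnnl::engine& eng, dnnl::stream& strm) {
  auto user_mem = std::make_shared<dnnl::memory>(user_md, eng, ptr);
  if (user_md == target_md) {
    return user_mem;  // layouts already match, no copy needed
  }
  auto target_mem = std::make_shared<dnnl::memory>(target_md, eng);
  dnnl::reorder(*user_mem, *target_mem)
      .execute(strm, {{DNNL_ARG_FROM, *user_mem}, {DNNL_ARG_TO, *target_mem}});
  strm.wait();
  return target_mem;  // blocked, primitive-friendly layout
}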