未验证 提交 bd0b38e6 编写于 作者: A Adam 提交者: GitHub

Refactor of conv fp32 oneDNN operator (#25137)

* Refactor of conv fp32 oneDNN operator
test=develop

* Formatting fix
test=develop

* Return Enforces
test=develop

* GetWeights improvements
test=develop
上级 b2f5a149
......@@ -26,42 +26,24 @@ using mkldnn::memory;
using mkldnn::primitive;
using mkldnn::reorder;
using mkldnn::stream;
using platform::to_void_cast;
using platform::GetMKLDNNFormat;
using platform::to_void_cast;
inline void GetWeightsTz(std::vector<int64_t>& weights_tz, // NOLINT
int groups, bool is_conv3d) {
const int groups) {
if (groups > 1) {
if (is_conv3d) {
int output = weights_tz[0];
int input = weights_tz[1];
int dimension = weights_tz[2];
int height = weights_tz[3];
int width = weights_tz[4];
weights_tz.resize(6);
// if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
// else [o, i, h, w] -> [g, o/g, i, h, w]
weights_tz.push_back(0);
std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
weights_tz[0] = groups;
weights_tz[1] = output / groups;
weights_tz[2] = input;
weights_tz[3] = dimension;
weights_tz[4] = height;
weights_tz[5] = width;
} else {
int output = weights_tz[0];
int input = weights_tz[1];
int height = weights_tz[2];
int width = weights_tz[3];
weights_tz.resize(5);
weights_tz[0] = groups;
weights_tz[1] = output / groups;
weights_tz[2] = input;
weights_tz[3] = height;
weights_tz[4] = width;
}
weights_tz[1] = weights_tz[1] / groups;
}
}
inline MKLDNNMemoryFormat GetWeightsFormat(MKLDNNMemoryFormat format,
int groups, bool is_conv3d) {
inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
const int groups,
const bool is_conv3d) {
if (is_conv3d) {
return (groups == 1) ? format : MKLDNNMemoryFormat::goidhw;
} else {
......@@ -90,53 +72,29 @@ static mkldnn::memory::data_type GetDstType(bool is_int8,
return dst_dt;
}
template <typename T, typename K>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
template <typename T>
class ConvMKLDNNHandlerT
: public platform::MKLDNNHandlerT<T, mkldnn::convolution_forward> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Conv must use CPUPlace"));
bool is_INT8 =
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
if (!is_INT8) {
ComputeFP32(ctx);
} else {
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
fuse_residual_conn, residual_param);
if (dst_dt == mkldnn::memory::data_type::f32) {
ComputeINT8<float>(ctx);
} else if (dst_dt == mkldnn::memory::data_type::u8) {
ComputeINT8<uint8_t>(ctx);
} else if (dst_dt == mkldnn::memory::data_type::s8) {
ComputeINT8<int8_t>(ctx);
}
}
}
void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test");
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto* input = ctx.Input<Tensor>("Input");
auto* filter = ctx.Input<Tensor>("Filter");
auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output");
PADDLE_ENFORCE_EQ(input->layout(), DataLayout::kMKLDNN,
ConvMKLDNNHandlerT(const paddle::framework::ExecutionContext& ctx,
const platform::MKLDNNDeviceContext& dev_ctx,
const mkldnn::engine mkldnn_engine,
platform::Place cpu_place, const Tensor* input,
const Tensor* filter, const Tensor* bias, Tensor* output,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::convolution_forward>(
dev_ctx, mkldnn_engine, cpu_place,
platform::CreateKey(framework::vectorize(input->dims()),
unique_name)) {
if (!this->isCached()) {
PADDLE_ENFORCE_EQ(
input->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The input tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, input->layout()));
PADDLE_ENFORCE_NE(
input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument("Wrong format set for Input tensor"));
PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
"Wrong format set for Input tensor"));
PADDLE_ENFORCE_EQ(
filter->layout(), DataLayout::kMKLDNN,
......@@ -147,23 +105,27 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::errors::InvalidArgument(
"Wrong format set for Filter tensor"));
PADDLE_ENFORCE_GE(input->dims().size(), 4,
PADDLE_ENFORCE_GE(
input->dims().size(), 4,
platform::errors::InvalidArgument(
"Input must be with 4 or 5 dimensions, i.e. NCHW or "
"NCDHW, but got dimension = %d .",
input->dims().size()));
PADDLE_ENFORCE_LE(input->dims().size(), 5,
PADDLE_ENFORCE_LE(
input->dims().size(), 5,
platform::errors::InvalidArgument(
"Input must be with 4 or 5 dimensions, i.e. NCHW or "
"NCDHW, but got dimension = %d .",
input->dims().size()));
PADDLE_ENFORCE_GE(filter->dims().size(), 4,
PADDLE_ENFORCE_GE(
filter->dims().size(), 4,
platform::errors::InvalidArgument(
"Filter must be with 4 or 5 dimensions, i.e. OIHW or "
"OIDHW, but got dimension = %d .",
filter->dims().size()));
PADDLE_ENFORCE_LE(filter->dims().size(), 5,
PADDLE_ENFORCE_LE(
filter->dims().size(), 5,
platform::errors::InvalidArgument(
"Filter must be with 4 or 5 dimensions, i.e. OIHW or "
"OIDHW, but got dimension = %d .",
......@@ -179,73 +141,64 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
platform::errors::InvalidArgument(
"Got wrong format for Bias tensor."));
PADDLE_ENFORCE_EQ(
bias->dims().size(), 1,
platform::errors::InvalidArgument("Bias must only have 1 dimension, "
PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
platform::errors::InvalidArgument(
"Bias must only have 1 dimension, "
"i.e. X, but got dimension = %d .",
bias->dims().size()));
}
std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
const std::string fuse_activation =
ctx.Attr<std::string>("fuse_activation");
const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
const float fuse_beta = ctx.Attr<float>("fuse_beta");
const bool fuse_residual_conn =
ctx.Attr<bool>("fuse_residual_connection");
const int groups = ctx.Attr<int>("groups");
const std::string padding_algorithm =
ctx.Attr<std::string>("padding_algorithm");
std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
const auto input_dims = input->dims();
const auto data_dims =
framework::slice_ddim(input_dims, 2, input_dims.size());
const auto filter_dims = filter->dims();
const auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));
const auto ksize = framework::vectorize(filter_data_dims);
const bool is_test = ctx.Attr<bool>("is_test");
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
float fuse_alpha = ctx.Attr<float>("fuse_alpha");
float fuse_beta = ctx.Attr<float>("fuse_beta");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
int groups = ctx.Attr<int>("groups");
std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
bool is_conv3d = strides.size() == 3U;
auto strides_temp = ctx.Attr<std::vector<int>>("strides");
std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
auto input_dims = input->dims();
auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
auto filter_dims = filter->dims();
auto filter_data_dims =
framework::slice_ddim(filter_dims, 2, filter_dims.size());
auto paddings_temp = ctx.Attr<std::vector<int>>("paddings");
std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
auto ksize = framework::vectorize(filter_data_dims);
auto dilations_temp = ctx.Attr<std::vector<int>>("dilations");
std::vector<int64_t> dilations(begin(dilations_temp),
end(dilations_temp));
UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
data_dims, strides, ksize);
const bool is_conv3d = strides.size() == 3U;
std::vector<primitive> pipeline;
PADDLE_ENFORCE(
PADDLE_ENFORCE_EQ(
is_conv3d
? dilations.size() == 3 && dilations[0] == 1 && dilations[1] == 1 &&
dilations[2] == 1
? dilations.size() == 3 && dilations[0] == 1 &&
dilations[1] == 1 && dilations[2] == 1
: dilations.size() == 2 && dilations[0] == 1 && dilations[1] == 1,
"dilation in convolution is not implemented yet");
true, platform::errors::Unimplemented(
"Dilation in oneDNN convolution is not implemented yet"));
const T* input_data = input->data<T>();
const T* filter_data = filter->data<T>();
const auto src_tz = paddle::framework::vectorize(input->dims());
auto src_tz = paddle::framework::vectorize(input->dims());
auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d);
auto dst_tz = paddle::framework::vectorize(output->dims());
// Get unique name for storing MKLDNN primitives
const std::string key = platform::CreateKey(
src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
GetWeightsTz(weights_tz, groups);
auto src_format = input->format();
MKLDNNMemoryFormat weights_format =
GetWeightsFormat(filter->format(), g, is_conv3d);
const auto dst_tz = paddle::framework::vectorize(output->dims());
auto user_src_md = platform::MKLDNNMemDesc(
{src_tz}, platform::MKLDNNGetDataType<T>(), src_format);
auto user_weights_md = platform::MKLDNNMemDesc(
{weights_tz}, platform::MKLDNNGetDataType<T>(), weights_format);
const mkldnn::memory::dims stride_dims = strides;
const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
/* create memory descriptor for convolution without specified format
* ('any') which lets a primitive (convolution in this case) choose
......@@ -255,139 +208,250 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// gradient computation proper as this op is called directly without
// fetch op following it , so numercial grad is computed (in python)
// using block formats which will give wrong results
std::string data_format = ctx.Attr<std::string>("data_format");
const std::string data_format = ctx.Attr<std::string>("data_format");
auto chosen_memory_format =
is_test ? MKLDNNMemoryFormat::any
: platform::data_format_to_memory_format(data_format);
weights_format = MKLDNNMemoryFormat::any;
// Check the format for user's special output
if (chosen_memory_format != MKLDNNMemoryFormat::any) {
if (is_conv3d) {
chosen_memory_format =
platform::MKLDNNFormatForSize(src_tz.size(), chosen_memory_format);
chosen_memory_format = platform::MKLDNNFormatForSize(
src_tz.size(), chosen_memory_format);
}
}
auto src_md = platform::MKLDNNMemDesc(
const auto src_md = platform::MKLDNNMemDesc(
src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
auto weights_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(), weights_format);
std::vector<int64_t> bias_tz;
auto dst_md = platform::MKLDNNMemDesc(
const auto weights_md =
platform::MKLDNNMemDesc(weights_tz, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
const auto dst_md = platform::MKLDNNMemDesc(
dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
// create a conv primitive descriptor and save it for usage in backward
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
const auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training;
const mkldnn::primitive_attr conv_attr = CreatePostOps(
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn);
if (bias) {
bias_tz = paddle::framework::vectorize(bias->dims());
auto bias_tz = framework::vectorize(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(
bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
fwd_prop_kind);
this->AcquireForwardPrimitiveDescriptor(
conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
src_md, weights_md, bias_md, dst_md, stride_dims,
mkldnn_paddings[0], mkldnn_paddings[1]);
} else {
conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
src_md, weights_md, boost::none, dst_md, strides, paddings,
mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta,
fuse_residual_conn, fwd_prop_kind);
this->AcquireForwardPrimitiveDescriptor(
conv_attr, fwd_prop_kind, dnnl::algorithm::convolution_direct,
src_md, weights_md, dst_md, stride_dims, mkldnn_paddings[0],
mkldnn_paddings[1]);
}
}
}
// create mkldnn memory from input tensors (data/weights)
auto user_src_memory_p =
handler.AcquireSrcMemory(user_src_md, to_void_cast<T>(input_data));
auto user_weights_memory_p = handler.AcquireWeightsMemory(
user_weights_md, to_void_cast<T>(filter_data));
mkldnn::primitive_attr CreatePostOps(
std::string fuse_activation, float fuse_alpha, float fuse_beta,
bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
float sum_scale = 1.0f) {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
if (output_shift_scale.size() > 0) {
int mask = output_shift_scale.size() > 1 ? 1 << 1 : 0;
conv_attr.set_output_scales(mask, output_shift_scale);
}
// create reorder primitive if the input format is not the preferred one
auto src_memory_p =
handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
user_weights_memory_p, pipeline, is_test);
// Fusion with Elementwise layer relies on adding a sum post-operation with
// the scale parameter. It is assumed that when fuse_residual_connection is
// true, the output tensor contains the data coming from residual
// connection. The result of this post_op is:
// Output = scale * Output + Conv_Out.
if (fuse_residual_conn) {
post_operations.append_sum(sum_scale);
}
// Fusion with ReLU layer is executed through the PostOps feature. Create a
// PostOps object and configure it to execute an eltwise relu operation.
if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "relu6") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale,
mkldnn::algorithm::eltwise_bounded_relu,
fuse_alpha, fuse_beta);
} else if (fuse_activation == "swish") {
constexpr float scale = 1.0f;
post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
fuse_alpha, fuse_beta);
}
conv_attr.set_post_ops(post_operations);
return conv_attr;
}
std::shared_ptr<mkldnn::memory> dst_memory_p, user_residual_memory_p;
std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
const framework::Tensor* input) {
const T* input_data = input->data<T>();
auto user_src_md = platform::MKLDNNMemDesc(
framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
input->format());
if (fuse_residual_conn) {
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto residual_param_data = residual_param->data<T>();
return this->AcquireMemoryWithReorder(
user_src_md, this->fwd_pd_->src_desc(), to_void_cast<T>(input_data),
"@src_mem_p");
}
PADDLE_ENFORCE_NE(
residual_param_data, nullptr,
platform::errors::InvalidArgument(
"Provide data if you want MKLDNN conv+elementwise_add fusion"));
PADDLE_ENFORCE_EQ(
output->dims(), residual_param->dims(),
platform::errors::InvalidArgument(
"Output and elementwise parameter need to have the "
"same dimension sizes, "
"but got output's dimension = %d and residual param's dimension "
"= %d .",
output->dims().size(), residual_param->dims().size()));
std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
const framework::Tensor* filter, const int groups, const bool is_conv3d,
const bool is_test) {
// This is workaround to make execution faster, delete
// if statement after including md inside Tensor
auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
if (is_test && weights_mem_p) {
return weights_mem_p;
} else {
const T* filter_data = filter->data<T>();
auto weights_tz = framework::vectorize(filter->dims());
GetWeightsTz(weights_tz, groups);
if (residual_param->format() != handler.GetDstFormat()) {
auto output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
auto residual_data_tz =
paddle::framework::vectorize(residual_param->dims());
auto residual_data_type =
paddle::framework::ToMKLDNNDataType(residual_param->type());
auto user_src_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<T>(),
GetWeightsFormat(filter->format(), groups, is_conv3d));
return this->AcquireMemoryWithReorder(
user_src_md, this->fwd_pd_->weights_desc(),
to_void_cast<T>(filter_data), "@weights_mem_p", is_test);
}
}
std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
const framework::Tensor* bias, const bool is_test) {
const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc(
framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::x);
return this->AcquireMemoryWithReorder(
user_bias_md, this->fwd_pd_->bias_desc(), to_void_cast<T>(bias_data),
"@bias_mem_p", is_test);
}
std::shared_ptr<mkldnn::memory> AcquireResidualMemory(
const framework::Tensor* residual_param) {
const T* residual_data = residual_param->data<T>();
auto user_residual_md = platform::MKLDNNMemDesc(
residual_data_tz, residual_data_type, residual_param->format());
user_residual_memory_p = handler.AcquireResidualDataMemory(
user_residual_md, to_void_cast<T>(residual_param_data));
framework::vectorize(residual_param->dims()),
framework::ToMKLDNNDataType(residual_param->type()),
residual_param->format());
return this->AcquireMemoryFromPrimitive(user_residual_md,
to_void_cast<T>(residual_data),
"@user_residual_data_mem_p");
}
dst_memory_p = handler.AcquireDstMemoryFromResidualDataMemory(
user_residual_memory_p, to_void_cast<T>(output_data), pipeline);
std::shared_ptr<mkldnn::memory> AcquireDstMemoryWithResidual(
framework::Tensor* output, const framework::Tensor* residual_param) {
std::shared_ptr<dnnl::memory> dst_memory_p;
if (residual_param->format() !=
platform::GetMKLDNNFormat(this->fwd_pd_->dst_desc())) {
auto residual_memory_p = this->AcquireResidualMemory(residual_param);
dst_memory_p = this->AcquireDstMemory(output);
this->AcquireReorder(residual_memory_p, dst_memory_p, "@residual_dst");
} else {
// Changing ShareDataWith to TensorCopy results in performance drop
// on ResNet architectures
// (https://github.com/PaddlePaddle/Paddle/issues/22964)
output->ShareDataWith(*residual_param);
auto output_data = output->mutable_data<T>(ctx.GetPlace());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
dst_memory_p = this->AcquireDstMemory(output);
}
return dst_memory_p;
}
};
template <typename T, typename K>
class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
void Compute(const paddle::framework::ExecutionContext& ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
paddle::platform::errors::PreconditionNotMet(
"Operator DNNL Conv must use CPUPlace"));
bool is_INT8 =
std::is_same<T, int8_t>::value || std::is_same<T, uint8_t>::value;
if (!is_INT8) {
ComputeFP32(ctx);
} else {
auto output_data =
output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
dst_memory_p =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
auto residual_param = ctx.Input<Tensor>("ResidualData");
auto dst_dt = GetDstType(true, force_fp32_output, fuse_activation,
fuse_residual_conn, residual_param);
if (dst_dt == mkldnn::memory::data_type::f32) {
ComputeINT8<float>(ctx);
} else if (dst_dt == mkldnn::memory::data_type::u8) {
ComputeINT8<uint8_t>(ctx);
} else if (dst_dt == mkldnn::memory::data_type::s8) {
ComputeINT8<int8_t>(ctx);
}
}
}
void ComputeFP32(const paddle::framework::ExecutionContext& ctx) const {
auto& dev_ctx =
ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto conv_p = handler.AcquireConvolution();
const bool is_test = ctx.Attr<bool>("is_test");
const bool is_conv3d = ctx.Attr<std::vector<int>>("strides").size() == 3U;
const bool fuse_residual_conn = ctx.Attr<bool>("fuse_residual_connection");
mkldnn::stream astream(mkldnn_engine);
if (bias) {
const T* bias_data = bias->data<T>();
auto user_bias_md = platform::MKLDNNMemDesc(
{bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
auto user_bias_memory_p =
handler.AcquireBiasMemory(user_bias_md, to_void_cast<T>(bias_data));
const auto* input = ctx.Input<Tensor>("Input");
const auto* filter = ctx.Input<Tensor>("Filter");
const auto* bias =
ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
auto* output = ctx.Output<Tensor>("Output");
auto bias_memory_p =
handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
ConvMKLDNNHandlerT<T> handler(
ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
output, ctx.InputName("Input") + ctx.InputName("Filter"));
conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_BIAS, *bias_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
filter, ctx.Attr<int>("groups"), is_conv3d, is_test);
std::shared_ptr<dnnl::memory> dst_memory_p;
if (fuse_residual_conn) {
auto* residual_param = ctx.Input<Tensor>("ResidualData");
dst_memory_p =
handler.AcquireDstMemoryWithResidual(output, residual_param);
} else {
conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
dst_memory_p = handler.AcquireDstMemory(output);
}
auto conv_p = handler.AcquireForwardPrimitive();
std::unordered_map<int, dnnl::memory> args = {
{MKLDNN_ARG_SRC, *src_memory_p},
{MKLDNN_ARG_WEIGHTS, *weights_memory_p},
{MKLDNN_ARG_DST, *dst_memory_p}});
{MKLDNN_ARG_DST, *dst_memory_p}};
if (bias) {
auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
}
mkldnn::stream astream(mkldnn_engine);
conv_p->execute(astream, args);
astream.wait();
output->set_layout(DataLayout::kMKLDNN);
output->set_format(GetMKLDNNFormat(*dst_memory_p));
}
template <typename T_out>
void ComputeINT8(const paddle::framework::ExecutionContext& ctx) const {
const bool is_test = ctx.Attr<bool>("is_test");
......@@ -552,7 +616,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d);
GetWeightsTz(weights_tz, g);
auto dst_tz = paddle::framework::vectorize(output->dims());
PADDLE_ENFORCE_EQ(
......@@ -866,7 +930,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto weights_tz = paddle::framework::vectorize(filter->dims());
int g = std::max(groups, 1);
GetWeightsTz(weights_tz, g, is_conv3d);
GetWeightsTz(weights_tz, g);
auto dst_tz = paddle::framework::vectorize(output_grad->dims());
auto src_format = input->format();
......@@ -879,7 +943,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const std::string key = platform::CreateKey(
src_tz, ctx.InputName("Input") + ctx.InputName("Filter"));
const std::string key_conv_pd = key + "@conv_pd";
const std::string key_conv_pd = key + "@forward_pd";
std::vector<primitive> pipeline;
// Create user memory descriptors
......
......@@ -210,6 +210,73 @@ class MKLDNNHandlerT {
return mem_p;
}
void AcquireReorder(const std::shared_ptr<mkldnn::memory>& user_memory_p,
const std::shared_ptr<mkldnn::memory>& target_memory_p,
const std::string& suffix) {
const auto key_reorder_p = key_ + suffix + "reorder_p";
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p == nullptr) {
reorder_p =
std::make_shared<mkldnn::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
}
mkldnn::stream astream(engine_);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr,
const std::string& suffix, bool is_persistent = false) {
const auto target_key = key_ + suffix + "_target";
const auto key_reorder_p = key_ + suffix + "reorder_p";
const auto user_key = key_ + suffix + "_user";
auto target_memory_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));
if (target_memory_p == nullptr) {
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, engine_, ptr);
if (user_md != target_md) {
target_memory_p = std::make_shared<mkldnn::memory>(target_md, engine_);
auto reorder_p =
std::make_shared<dnnl::reorder>(*user_memory_p, *target_memory_p);
dev_ctx_.SetBlob(key_reorder_p, reorder_p);
mkldnn::stream astream(engine_);
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
} else {
target_memory_p = user_memory_p;
}
dev_ctx_.SetBlob(user_key, user_memory_p);
dev_ctx_.SetBlob(target_key, target_memory_p);
} else if (!is_persistent) {
mkldnn::stream astream(engine_);
auto user_memory_p =
std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(user_key));
user_memory_p->set_data_handle(ptr);
auto reorder_p = std::static_pointer_cast<mkldnn::reorder>(
dev_ctx_.GetBlob(key_reorder_p));
if (reorder_p != nullptr) {
reorder_p->execute(astream, {{MKLDNN_ARG_FROM, *user_memory_p},
{MKLDNN_ARG_TO, *target_memory_p}});
astream.wait();
}
}
return target_memory_p;
}
std::shared_ptr<mkldnn::memory> AcquireMemory(const std::string& suffix) {
const auto local_key = key_ + suffix;
return std::static_pointer_cast<mkldnn::memory>(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册