Unverified · Commit caf9d398 authored by joanna.wozna.intel, committed by GitHub

Add Conv Transpose BF16 (#30877)

* Add conv transpose BF16

* Share function GetWeightsTz

* Adjust to review and fix op compatibility

* Add bias to unique handler name

* Remove errors related to paddle enforce

* Add conv2d_transpose to bf16 list and kernel refactor
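The new kernel picks its output type from the two attributes this commit adds to conv2d_transpose, mkldnn_data_type and force_fp32_output. A minimal Python sketch of that selection rule, mirroring the dispatch in ConvTransposeMKLDNNOpKernel::Compute below (the function name and dict-style attrs are illustrative only, not part of the commit or of Paddle's API):

# Illustrative only: mirrors the BF16/FP32 dispatch added in this commit.
def select_output_dtype(attrs):
    is_bfloat16 = attrs.get("mkldnn_data_type", "float32") == "bfloat16"
    force_fp32_output = attrs.get("force_fp32_output", False)
    if is_bfloat16 and not force_fp32_output:
        return "bfloat16"   # Execute<platform::bfloat16>(ctx)
    return "float32"        # Execute<float>(ctx)

assert select_output_dtype({"mkldnn_data_type": "bfloat16"}) == "bfloat16"
assert select_output_dtype({"mkldnn_data_type": "bfloat16",
                            "force_fp32_output": True}) == "float32"
assert select_output_dtype({}) == "float32"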
Parent cbbe1274
@@ -2192,9 +2192,9 @@ PDNode *patterns::Bfloat16Placement::operator()(
     const std::unordered_set<std::string> &bfloat16_enabled_op_types) {
   std::unordered_set<std::string> supported_op_types =
       std::unordered_set<std::string>(
-          {"concat", "conv2d", "elementwise_add", "elementwise_mul", "fc",
-           "fusion_gru", "gelu", "layer_norm", "matmul", "pool2d", "reshape2",
-           "softmax", "sum", "transpose2"});
+          {"concat", "conv2d", "conv2d_transpose", "elementwise_add",
+           "elementwise_mul", "fc", "fusion_gru", "gelu", "layer_norm",
+           "matmul", "pool2d", "reshape2", "softmax", "sum", "transpose2"});
   if (!bfloat16_enabled_op_types.empty()) {
     supported_op_types = bfloat16_enabled_op_types;
   }
......
@@ -160,7 +160,7 @@ REGISTER_PASS(conv_transpose_bias_mkldnn_fuse_pass,
 REGISTER_PASS_CAPABILITY(conv_transpose_bias_mkldnn_fuse_pass)
     .AddCombination(
         paddle::framework::compatible::OpVersionComparatorCombination()
-            .LE("conv2d_transpose", 1)
+            .LE("conv2d_transpose", 2)
             .LE("elementwise_add", 1));
 REGISTER_PASS(conv3d_bias_mkldnn_fuse_pass,
......
@@ -329,7 +329,7 @@ REGISTER_PASS_CAPABILITY(quant_conv2d_dequant_fuse_pass)
         paddle::framework::compatible::OpVersionComparatorCombination()
             .LE("conv2d", 1)
             .EQ("fc", 0)
-            .LE("conv2d_transpose", 1)
+            .LE("conv2d_transpose", 2)
             .EQ("fake_quantize_abs_max", 0)
             .EQ("fake_quantize_range_abs_max", 0)
             .EQ("fake_quantize_moving_average_abs_max", 0)
......
@@ -390,7 +390,7 @@ REGISTER_PASS_CAPABILITY(tensorrt_subgraph_pass)
             .LE("elementwise_add", 1)
             .LE("elementwise_mul", 1)
             .EQ("prelu", 0)
-            .LE("conv2d_transpose", 1)
+            .LE("conv2d_transpose", 2)
             .LE("leaky_relu", 1)
             .EQ("fc", 0)
             .EQ("shuffle_channel", 0)
......
@@ -290,6 +290,15 @@ void Conv2DTransposeOpMaker::Make() {
   AddAttr<bool>("use_mkldnn",
                 "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
+  AddAttr<bool>("force_fp32_output",
+                "(bool, default false) Force BF16 kernel output FP32, only "
+                "used in MKL-DNN BF16")
+      .SetDefault(false);
+  AddAttr<std::string>(
+      "mkldnn_data_type",
+      "(string, default \"float32\"). Data type of mkldnn kernel")
+      .SetDefault("float32")
+      .InEnum({"float32", "bfloat16"});
   AddAttr<bool>("fuse_relu", "(bool, default false) Only used in mkldnn kernel")
       .SetDefault(false);
   AddAttr<std::string>("fuse_activation",
@@ -671,7 +680,17 @@ REGISTER_OP_VERSION(conv2d_transpose)
             "output_padding",
             "In order to add additional size to one side of each dimension "
             "in the output",
-            std::vector<int>{}));
+            std::vector<int>{}))
+    .AddCheckpoint(
+        R"ROC(
+      Upgrade conv2d transpose to add a new attributes [force_fp32_output, mkldnn_data_type].
+    )ROC",
+        paddle::framework::compatible::OpVersionDesc()
+            .NewAttr("force_fp32_output",
+                     "Force BF16 kernel output FP32, only used in MKL-DNN BF16",
+                     false)
+            .NewAttr("mkldnn_data_type", "Data type of mkldnn kernel",
+                     "float32"));

 REGISTER_OP_VERSION(conv3d_transpose)
     .AddCheckpoint(
......
@@ -33,18 +33,6 @@ using mkldnn::stream;
 using platform::GetMKLDNNFormat;
 using platform::to_void_cast;

-inline void GetWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
-                         const int groups) {
-  if (groups > 1) {
-    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
-    // else [o, i, h, w] -> [g, o/g, i, h, w]
-    weights_tz.push_back(0);
-    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
-    weights_tz[0] = groups;
-    weights_tz[1] = weights_tz[1] / groups;
-  }
-}
-
 inline MKLDNNMemoryFormat GetWeightsFormat(const MKLDNNMemoryFormat format,
                                            const int groups,
                                            const bool is_conv3d) {
@@ -198,7 +186,7 @@ class ConvMKLDNNHandlerT
       const auto src_tz = paddle::framework::vectorize(input->dims());

       auto weights_tz = paddle::framework::vectorize(filter->dims());
-      GetWeightsTz(weights_tz, groups);
+      platform::GetGroupConvWeightsTz(weights_tz, groups);

       const auto dst_tz = paddle::framework::vectorize(output->dims());
@@ -322,7 +310,7 @@ class ConvMKLDNNHandlerT
     } else {
       const K* filter_data = filter->data<K>();
       auto weights_tz = framework::vectorize(filter->dims());
-      GetWeightsTz(weights_tz, groups);
+      platform::GetGroupConvWeightsTz(weights_tz, groups);

       auto user_src_md = platform::MKLDNNMemDesc(
           weights_tz, platform::MKLDNNGetDataType<K>(),
@@ -640,7 +628,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto weights_tz = paddle::framework::vectorize(filter->dims());
     int g = std::max(groups, 1);
-    GetWeightsTz(weights_tz, g);
+    platform::GetGroupConvWeightsTz(weights_tz, g);
     auto dst_tz = paddle::framework::vectorize(output->dims());

     std::transform(dilations.begin(), dilations.end(), dilations.begin(),
@@ -959,7 +947,7 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto weights_tz = paddle::framework::vectorize(filter->dims());
     int g = std::max(groups, 1);
-    GetWeightsTz(weights_tz, g);
+    platform::GetGroupConvWeightsTz(weights_tz, g);
     auto dst_tz = paddle::framework::vectorize(output_grad->dims());
     auto src_format = input->format();
......
@@ -25,27 +25,40 @@ namespace operators {
 using Tensor = framework::Tensor;
 using framework::DataLayout;

-template <typename T>
-class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
+inline mkldnn::memory::dims GetWeightsTz(const Tensor* filter,
+                                         const int groups) {
+  auto iohw_weights_tz = framework::vectorize(filter->dims());
+  auto weights_tz = iohw_weights_tz;
+  // IOHW -> OIHW
+  weights_tz[0] = iohw_weights_tz[1];
+  weights_tz[1] = iohw_weights_tz[0];
+  int g = std::max(groups, 1);
+  platform::GetGroupConvWeightsTz(weights_tz, g);
+  return weights_tz;
+}
+
+template <typename T, typename K, typename T_out>
+class ConvTransposeMKLDNNHandlerT
+    : public platform::MKLDNNHandlerT<T, mkldnn::deconvolution_forward> {
  public:
-  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
-                      paddle::platform::errors::PreconditionNotMet(
-                          "Operator DNNL ConvTranspose must use CPUPlace"));
+  ConvTransposeMKLDNNHandlerT(const framework::ExecutionContext& ctx,
+                              const platform::MKLDNNDeviceContext& dev_ctx,
+                              const mkldnn::engine mkldnn_engine,
+                              platform::Place cpu_place, const Tensor* input,
+                              const Tensor* filter, const Tensor* bias,
+                              Tensor* output, const std::string& unique_name)
+      : platform::MKLDNNHandlerT<T, mkldnn::deconvolution_forward>(
+            dev_ctx, mkldnn_engine, cpu_place,
+            platform::CreateKey(dev_ctx, framework::vectorize(input->dims()),
+                                unique_name)) {
+    if (!this->isCached()) {
       const bool is_test = ctx.Attr<bool>("is_test");
       PADDLE_ENFORCE_EQ(is_test, true,
                         platform::errors::InvalidArgument(
                             "ConvTransposeMKLDNN works only for inference. "
-                            "Set is_test = True. but got is_test=False ."));
-
-    auto& dev_ctx =
-        ctx.template device_context<paddle::platform::MKLDNNDeviceContext>();
-    const auto& mkldnn_engine = dev_ctx.GetEngine();
-
-    auto* input = ctx.Input<Tensor>("Input");
-    auto* filter = ctx.Input<Tensor>("Filter");
-    auto* bias = ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
-    auto* output = ctx.Output<Tensor>("Output");
+                            "The attribute \'is_test\' value should be set to "
+                            "True, but got is_test=False."));

       PADDLE_ENFORCE_EQ(
           input->layout(), DataLayout::kMKLDNN,
@@ -53,7 +66,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
               "Got wrong layout = %d for Input tensor.", input->layout()));
       PADDLE_ENFORCE_NE(input->format(), MKLDNNMemoryFormat::undef,
                         platform::errors::InvalidArgument(
-                            "Got wrong format for Input tensor."));
+                            "Got wrong format for Input tensor. The input "
+                            "format is undefined."));

       PADDLE_ENFORCE_EQ(
           filter->layout(), DataLayout::kMKLDNN,
@@ -66,8 +80,8 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       PADDLE_ENFORCE_EQ(
           input->dims().size(), 4,
-          platform::errors::InvalidArgument(
-              "Input must be with 4 dimensions, i.e. NCHW. but got dimension =%d",
+          platform::errors::InvalidArgument("Input must be with 4 dimensions, "
+                                            "i.e. NCHW. but got dimension =%d",
                                             input->dims().size()));
       PADDLE_ENFORCE_EQ(
           filter->dims().size(), 4,
@@ -85,37 +99,40 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
             platform::errors::InvalidArgument(
                 "Got wrong format for Bias tensor."));

-        PADDLE_ENFORCE_EQ(
-            bias->dims().size(), 1,
-            platform::errors::InvalidArgument("Bias must only have 1 dimension, "
-                                              "i.e. X, but got dimension = %d .",
-                                              bias->dims().size()));
+        PADDLE_ENFORCE_EQ(bias->dims().size(), 1,
+                          platform::errors::InvalidArgument(
+                              "Bias must only have 1 dimension, "
+                              "i.e. X, but got dimension = %d .",
+                              bias->dims().size()));
       }

       std::vector<int> strides_temp = ctx.Attr<std::vector<int>>("strides");
-    std::vector<int64_t> strides(begin(strides_temp), end(strides_temp));
+      mkldnn::memory::dims strides(begin(strides_temp), end(strides_temp));

       std::vector<int> paddings_temp = ctx.Attr<std::vector<int>>("paddings");
-    std::vector<int64_t> paddings(begin(paddings_temp), end(paddings_temp));
+      mkldnn::memory::dims paddings(begin(paddings_temp), end(paddings_temp));

       std::vector<int> dilations_temp = ctx.Attr<std::vector<int>>("dilations");
-    std::vector<int64_t> dilations(begin(dilations_temp), end(dilations_temp));
+      mkldnn::memory::dims dilations(begin(dilations_temp),
+                                     end(dilations_temp));

       int groups = ctx.Attr<int>("groups");
-    std::string padding_algorithm = ctx.Attr<std::string>("padding_algorithm");
+      std::string padding_algorithm =
+          ctx.Attr<std::string>("padding_algorithm");

       PADDLE_ENFORCE_EQ(
           strides.size(), 2,
           platform::errors::Unimplemented(
               "Now we only support 2d oneDNN convolution transpose op"));

-    auto input_dims = input->dims();
-    auto data_dims = framework::slice_ddim(input_dims, 2, input_dims.size());
-    auto filter_dims = filter->dims();
-    auto filter_data_dims =
-        framework::slice_ddim(filter_dims, 2, filter_dims.size());
-    auto ksize = framework::vectorize(filter_data_dims);
+      const auto& input_dims = input->dims();
+      const auto data_dims =
+          framework::slice_ddim(input_dims, 2, input_dims.size());
+      const auto& filter_dims = filter->dims();
+      const auto filter_data_dims =
+          framework::slice_ddim(filter_dims, 2, filter_dims.size());
+      const auto ksize = framework::vectorize(filter_data_dims);

       UpdatePaddingAndDilation(&paddings, &dilations, padding_algorithm,
                                data_dims, strides, ksize);
@@ -123,147 +140,224 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       std::transform(dilations.begin(), dilations.end(), dilations.begin(),
                      [](int64_t i) { return i - 1; });

-    const T* input_data = input->data<T>();
-    const T* filter_data = filter->data<T>();
-
-    auto src_tz = paddle::framework::vectorize<int64_t>(input->dims());
-    auto iohw_weights_tz =
-        paddle::framework::vectorize<int64_t>(filter->dims());
-    auto weights_tz = iohw_weights_tz;
-
-    // IOHW -> OIHW
-    weights_tz[0] = iohw_weights_tz[1];
-    weights_tz[1] = iohw_weights_tz[0];
-
-    // Custom Reorder from IOHW to OIHW
-    auto iohw2oihw_reorder =
-        [&iohw_weights_tz](const T* filter_data) -> std::shared_ptr<T> {
-      int o = iohw_weights_tz[1];
-      int c = iohw_weights_tz[0];
-      int h = iohw_weights_tz[2];
-      int w = iohw_weights_tz[3];
-      std::shared_ptr<T> reordered_filter_data(new T[o * c * h * w](),
-                                               std::default_delete<T[]>());
-      for (int i = 0; i < c; ++i) {
-        for (int j = 0; j < o; ++j) {
-          int in_offset = j * h * w + i * o * h * w;
-          int out_offset = j * c * h * w + i * h * w;
-          std::memcpy(&(reordered_filter_data.get())[out_offset],
-                      &filter_data[in_offset], h * w * sizeof(T));
-        }
-      }
-      return reordered_filter_data;
-    };
-
-    int g = std::max(groups, 1);
-    if (g > 1) {
-      int o = weights_tz[0];
-      int i = weights_tz[1];
-      int h = weights_tz[2];
-      int w = weights_tz[3];
-      weights_tz.resize(5);
-      weights_tz[0] = g;
-      weights_tz[1] = o / g;
-      weights_tz[2] = i;
-      weights_tz[3] = h;
-      weights_tz[4] = w;
-    }
-    auto dst_tz = paddle::framework::vectorize<int64_t>(output->dims());
-
-    // Get unique name for storing MKLDNN primitives
-    const std::string key =
-        platform::CreateKey(dev_ctx, src_tz, ctx.OutputName("Output"));
-
-    std::vector<mkldnn::primitive> pipeline;
-
-    auto user_src_md = platform::MKLDNNMemDesc(
-        {src_tz}, platform::MKLDNNGetDataType<T>(), input->format());
-    auto user_weights_md = platform::MKLDNNMemDesc(
-        {weights_tz}, platform::MKLDNNGetDataType<T>(),
-        (g == 1) ? MKLDNNMemoryFormat::oihw : MKLDNNMemoryFormat::goihw);
-
-    /* create memory descriptor for convolution without specified format
-     * ('any') which lets a primitive (convolution in this case) choose
-     * the memory format preferred for best performance
-     */
-    auto chosen_memory_format = MKLDNNMemoryFormat::any;
-    std::string fuse_activation = ctx.Attr<std::string>("fuse_activation");
-    float fuse_alpha = ctx.Attr<float>("fuse_alpha");
-    float fuse_beta = ctx.Attr<float>("fuse_beta");
-
-    auto src_md = platform::MKLDNNMemDesc(
-        src_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-    auto weights_md = platform::MKLDNNMemDesc(
-        weights_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-    std::vector<int64_t> bias_tz;
-    auto dst_md = platform::MKLDNNMemDesc(
-        dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);
-
-    platform::ConvTransposeMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
-    // create a deconv(conv transpose) primitive descriptor and save it for
-    // usage in backward
-    std::shared_ptr<mkldnn::deconvolution_forward::primitive_desc>
-        conv_transpose_pd;
-    auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
-                                 : mkldnn::prop_kind::forward_training;
-    if (bias) {
-      bias_tz = paddle::framework::vectorize<int64_t>(bias->dims());
-      auto bias_md = platform::MKLDNNMemDesc(
-          bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
-      conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
-          src_md, weights_md, bias_md, dst_md, strides, dilations, paddings,
-          mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, false,
-          fwd_prop_kind);
-    } else {
-      conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
-          src_md, weights_md, boost::none, dst_md, strides, dilations, paddings,
-          mkldnn_engine, fuse_activation, fuse_alpha, fuse_beta, false,
-          fwd_prop_kind);
-    }
-
-    // create mkldnn memory from input tensors (data/weights)
-    auto user_src_memory_p = handler.AcquireSrcMemory(
-        user_src_md, platform::to_void_cast<T>(input_data));
-    auto user_weights_memory_p = handler.AcquireWeightsMemory(
-        user_weights_md, platform::to_void_cast<T>(filter_data),
-        is_test ? iohw2oihw_reorder : platform::user_function());
-
-    // create reorder primitive if the input format is not the preferred one
-    auto src_memory_p =
-        handler.AcquireSrcMemoryFromPrimitive(user_src_memory_p, pipeline);
-    auto weights_memory_p = handler.AcquireWeightsMemoryFromPrimitive(
-        user_weights_memory_p, pipeline, is_test);
-
-    auto output_data =
-        output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
-    auto dst_memory_p = handler.AcquireDstMemoryFromPrimitive(
-        platform::to_void_cast<T>(output_data));
-
-    auto conv_p = handler.AcquireConvolution();
-
-    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
-    if (bias) {
-      const T* bias_data = bias->data<T>();
-      auto user_bias_md = platform::MKLDNNMemDesc(
-          {bias_tz}, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
-      auto user_bias_memory_p = handler.AcquireBiasMemory(
-          user_bias_md, platform::to_void_cast<T>(bias_data));
-      auto bias_memory_p =
-          handler.AcquireBiasMemoryFromPrimitive(user_bias_memory_p, pipeline);
-      conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
-                                {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
-                                {MKLDNN_ARG_BIAS, *bias_memory_p},
-                                {MKLDNN_ARG_DST, *dst_memory_p}});
-    } else {
-      conv_p->execute(astream, {{MKLDNN_ARG_SRC, *src_memory_p},
-                                {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
-                                {MKLDNN_ARG_DST, *dst_memory_p}});
-    }
+      const auto src_tz = framework::vectorize(input->dims());
+      const auto weights_tz = GetWeightsTz(filter, groups);
+      const auto dst_tz = framework::vectorize(output->dims());
+      const auto mkldnn_paddings = platform::ToMkldnnPadding(paddings);
+
+      /* create memory descriptor for convolution without specified format
+       * ('any') which lets a primitive (convolution in this case) choose
+       * the memory format preferred for best performance
+       */
+      const auto chosen_memory_format = MKLDNNMemoryFormat::any;
+      const std::string fuse_activation =
+          ctx.Attr<std::string>("fuse_activation");
+      const float fuse_alpha = ctx.Attr<float>("fuse_alpha");
+      const float fuse_beta = ctx.Attr<float>("fuse_beta");
+
+      auto data_type = mkldnn::memory::data_type::f32;
+      if (ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16" ||
+          std::is_same<T_out, platform::bfloat16>::value)
+        data_type = mkldnn::memory::data_type::bf16;
+
+      const auto src_md =
+          platform::MKLDNNMemDesc(src_tz, data_type, chosen_memory_format);
+      const auto weights_md =
+          platform::MKLDNNMemDesc(weights_tz, data_type, chosen_memory_format);
+      const auto dst_md = platform::MKLDNNMemDesc(
+          dst_tz, platform::MKLDNNGetDataType<T_out>(), chosen_memory_format);
+
+      const mkldnn::primitive_attr conv_trans_attr =
+          CreatePostOps(fuse_activation, fuse_alpha, fuse_beta);
+      auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
+                                   : mkldnn::prop_kind::forward_training;
+      if (bias) {
+        std::vector<int64_t> bias_tz = framework::vectorize(bias->dims());
+        const auto bias_md =
+            platform::MKLDNNMemDesc(bias_tz, data_type, MKLDNNMemoryFormat::x);
+        this->AcquireForwardPrimitiveDescriptor(
+            conv_trans_attr, fwd_prop_kind,
+            dnnl::algorithm::deconvolution_direct, src_md, weights_md, bias_md,
+            dst_md, strides, dilations, mkldnn_paddings[0], mkldnn_paddings[1]);
+      } else {
+        this->AcquireForwardPrimitiveDescriptor(
+            conv_trans_attr, fwd_prop_kind,
+            dnnl::algorithm::deconvolution_direct, src_md, weights_md, dst_md,
+            strides, dilations, mkldnn_paddings[0], mkldnn_paddings[1]);
+      }
+    }
+  }
+
+  mkldnn::primitive_attr CreatePostOps(const std::string& fuse_activation,
+                                       const float& fuse_alpha,
+                                       const float& fuse_beta) {
+    mkldnn::primitive_attr conv_attr;
+    mkldnn::post_ops post_operations;
+
+    // Fusion with ReLU layer is executed through the PostOps feature. Create a
+    // PostOps object and configure it to execute an eltwise relu operation.
+    if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                     fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "relu6") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale,
+                                     mkldnn::algorithm::eltwise_bounded_relu,
+                                     fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "swish") {
+      constexpr float scale = 1.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_swish,
+                                     fuse_alpha, fuse_beta);
+    }
+    conv_attr.set_post_ops(post_operations);
+    return conv_attr;
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireSrcMemoryWithReorder(
+      const framework::Tensor* input) {
+    const T* input_data = input->data<T>();
+    const std::string user_key_suffix{"@src_mem_p_user"};
+    auto user_src_mem_p = this->AcquireMemory(user_key_suffix);
+
+    if (!user_src_mem_p) {
+      auto user_src_md = platform::MKLDNNMemDesc(
+          framework::vectorize(input->dims()), platform::MKLDNNGetDataType<T>(),
+          input->format());
+      return this->AcquireMemoryWithReorder(
+          user_src_md, this->fwd_pd_->src_desc(),
+          platform::to_void_cast<T>(input_data), "@src_mem_p");
+    } else {
+      const std::string target_key_suffix{"@src_mem_p_target"};
+      const auto target_src_mem_p = this->AcquireMemory(target_key_suffix);
+      user_src_mem_p->set_data_handle(platform::to_void_cast<T>(input_data));
+      if (user_src_mem_p != target_src_mem_p) {
+        this->AcquireReorder(user_src_mem_p, target_src_mem_p, "@src_mem_p");
+      }
+      return target_src_mem_p;
+    }
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireWeightsMemoryWithReorder(
+      const framework::Tensor* filter, const int& groups, const bool& is_test) {
+    // This is workaround to make execution faster, delete
+    // if statement after including md inside Tensor
+    auto weights_mem_p = this->AcquireMemory("@weights_mem_p_target");
+    if (is_test && weights_mem_p) {
+      return weights_mem_p;
+    } else {
+      const K* filter_data = filter->data<K>();
+      auto weights_tz = GetWeightsTz(filter, groups);
+      int g = std::max(groups, 1);
+
+      auto user_src_md = platform::MKLDNNMemDesc(
+          weights_tz, platform::MKLDNNGetDataType<K>(),
+          (g == 1) ? filter->format() : MKLDNNMemoryFormat::goihw);
+
+      auto iohw_weights_tz = framework::vectorize(filter->dims());
+      // Custom Reorder from IOHW to OIHW
+      auto iohw2oihw_reorder =
+          [&iohw_weights_tz](const K* filter_data) -> std::shared_ptr<K> {
+        int o = iohw_weights_tz[1];
+        int c = iohw_weights_tz[0];
+        int h = iohw_weights_tz[2];
+        int w = iohw_weights_tz[3];
+        std::shared_ptr<K> reordered_filter_data(new K[o * c * h * w](),
+                                                 std::default_delete<K[]>());
+        for (int i = 0; i < c; ++i) {
+          for (int j = 0; j < o; ++j) {
+            int in_offset = j * h * w + i * o * h * w;
+            int out_offset = j * c * h * w + i * h * w;
+            std::memcpy(&(reordered_filter_data.get())[out_offset],
+                        &filter_data[in_offset], h * w * sizeof(K));
+          }
+        }
+        return reordered_filter_data;
+      };
+
+      return this->template AcquireMemoryWithReorder<K>(
+          user_src_md, this->fwd_pd_->weights_desc(),
+          platform::to_void_cast<K>(filter_data), "@weights_mem_p", is_test,
+          iohw2oihw_reorder);
+    }
+  }
+
+  std::shared_ptr<mkldnn::memory> AcquireBiasMemoryWithReorder(
+      const framework::Tensor* bias, const bool& is_test) {
+    auto bias_mem_p = this->AcquireMemory("@bias_mem_p_target");
+    if (is_test && bias_mem_p) {
+      return bias_mem_p;
+    } else {
+      const K* bias_data = bias->data<K>();
+      auto user_bias_md = platform::MKLDNNMemDesc(
+          framework::vectorize(bias->dims()), platform::MKLDNNGetDataType<K>(),
+          MKLDNNMemoryFormat::x);
+      return this->AcquireMemoryWithReorder(
+          user_bias_md, this->fwd_pd_->bias_desc(),
+          platform::to_void_cast<K>(bias_data), "@bias_mem_p", is_test);
+    }
+  }
+};
+
+template <typename T, typename K>
+class ConvTransposeMKLDNNOpKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    PADDLE_ENFORCE_EQ(platform::is_cpu_place(ctx.GetPlace()), true,
+                      platform::errors::PreconditionNotMet(
+                          "Operator DNNL ConvTranspose must use CPUPlace"));
+    const bool is_bfloat16 =
+        ctx.Attr<std::string>("mkldnn_data_type") == "bfloat16";
+    const bool force_fp32_output = ctx.Attr<bool>("force_fp32_output");
+    if (is_bfloat16) {
+      if (force_fp32_output)
+        Execute<float>(ctx);
+      else
+        Execute<platform::bfloat16>(ctx);
+    } else {
+      Execute<float>(ctx);
+    }
+  }
+
+  template <typename T_out>
+  void Execute(const framework::ExecutionContext& ctx) const {
+    auto& dev_ctx =
+        ctx.template device_context<platform::MKLDNNDeviceContext>();
+    const auto& mkldnn_engine = dev_ctx.GetEngine();
+    const bool is_test = ctx.Attr<bool>("is_test");
+
+    const auto* input = ctx.Input<Tensor>("Input");
+    const auto* filter = ctx.Input<Tensor>("Filter");
+    const auto* bias =
+        ctx.HasInput("Bias") ? ctx.Input<Tensor>("Bias") : nullptr;
+    auto* output = ctx.Output<Tensor>("Output");
+    const std::string unique_name = ctx.InputName("Input") +
+                                    ctx.InputName("Filter") +
+                                    (bias ? ctx.InputName("Bias") : "");
+    ConvTransposeMKLDNNHandlerT<T, K, T_out> handler(
+        ctx, dev_ctx, mkldnn_engine, ctx.GetPlace(), input, filter, bias,
+        output, unique_name);
+    auto src_memory_p = handler.AcquireSrcMemoryWithReorder(input);
+    auto weights_memory_p = handler.AcquireWeightsMemoryWithReorder(
+        filter, ctx.Attr<int>("groups"), is_test);
+    std::shared_ptr<dnnl::memory> dst_memory_p =
+        handler.template AcquireDstMemory<T_out>(output);
+    auto conv_p = handler.AcquireForwardPrimitive();
+
+    std::unordered_map<int, dnnl::memory> args = {
+        {MKLDNN_ARG_SRC, *src_memory_p},
+        {MKLDNN_ARG_WEIGHTS, *weights_memory_p},
+        {MKLDNN_ARG_DST, *dst_memory_p}};
+
+    if (bias) {
+      auto bias_memory_p = handler.AcquireBiasMemoryWithReorder(bias, is_test);
+      args.insert({MKLDNN_ARG_BIAS, *bias_memory_p});
+    }
+    auto& astream = platform::MKLDNNDeviceContext::tls().get_stream();
+    conv_p->execute(astream, args);
     astream.wait();

     output->set_layout(DataLayout::kMKLDNN);
     output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
   }
@@ -274,5 +368,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
 namespace ops = paddle::operators;

-REGISTER_OP_KERNEL(conv2d_transpose, MKLDNN, ::paddle::platform::CPUPlace,
-                   ops::ConvTransposeMKLDNNOpKernel<float>);
+REGISTER_OP_KERNEL(
+    conv2d_transpose, MKLDNN, ::paddle::platform::CPUPlace,
+    ops::ConvTransposeMKLDNNOpKernel<float, float>,
+    ops::ConvTransposeMKLDNNOpKernel<paddle::platform::bfloat16, float>);
@@ -492,6 +492,19 @@ inline std::vector<std::vector<int64_t>> ToMkldnnPadding(
   }
 }

+// The function adjusts the vector of weight dimensions for group convolutions
+inline void GetGroupConvWeightsTz(std::vector<int64_t>& weights_tz,  // NOLINT
+                                  const int groups) {
+  if (groups > 1) {
+    // if (is_conv3d) [o, i, d, h, w]->[g, o/g, i, d, h, w]
+    // else [o, i, h, w] -> [g, o/g, i, h, w]
+    weights_tz.push_back(0);
+    std::rotate(weights_tz.begin(), weights_tz.end() - 1, weights_tz.end());
+    weights_tz[0] = groups;
+    weights_tz[1] = weights_tz[1] / groups;
+  }
+}
+
 inline bool HasOpINT8DataType(const paddle::framework::OpDesc* op) {
   return (op->GetAttrIfExists<std::string>("mkldnn_data_type") == "int8" ||
           op->GetAttrIfExists<bool>("use_quantizer"));
......
@@ -250,10 +250,12 @@ class MKLDNNHandlerT {
     astream.wait();
   }

+  template <typename F = T>
   std::shared_ptr<mkldnn::memory> AcquireMemoryWithReorder(
       const mkldnn::memory::desc& user_md,
       const mkldnn::memory::desc& target_md, void* ptr,
-      const std::string& suffix, bool is_persistent = false) {
+      const std::string& suffix, bool is_persistent = false,
+      std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {}) {
     const auto target_key = key_ + suffix + "_target";
     const auto key_reorder_p = key_ + suffix + "reorder_p";
     const auto user_key = key_ + suffix + "_user";
@@ -262,6 +264,12 @@ class MKLDNNHandlerT {
         std::static_pointer_cast<dnnl::memory>(dev_ctx_.GetBlob(target_key));

     if (target_memory_p == nullptr) {
+      if (custom_reorder_func) {
+        auto reordered_data =
+            custom_reorder_func(reinterpret_cast<const F*>(ptr));
+        dev_ctx_.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
+        ptr = reinterpret_cast<void*>(reordered_data.get());
+      }
       auto user_memory_p =
           std::make_shared<dnnl::memory>(user_md, engine_, ptr);
       if (user_md != target_md) {
@@ -1288,6 +1296,5 @@ static void SetDstMemoryQuantized(
   dst_memory.reset(
       new mkldnn::memory(*dst_md, engine, to_void_cast<T>(output_data)));
 }
-
 }  // namespace platform
 }  // namespace paddle
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle.fluid.core as core
from paddle.fluid.tests.unittests.op_test import OpTest, convert_float_to_uint16
from paddle.fluid.tests.unittests.test_conv2d_transpose_op import conv2dtranspose_forward_naive
from paddle import enable_static
def conv2d_bias_naive(out, bias):
_, out_c, _, _ = out.shape
for l in range(out_c):
out[:, l, :, :] = out[:, l, :, :] + bias[l]
return out
@unittest.skipIf(not core.supports_bfloat16(),
"place does not support BF16 evaluation")
class TestConv2DTransposeBF16MKLDNNOp(OpTest):
def test_check_output(self):
self.check_output_with_place(core.CPUPlace())
def test_check_grad(self):
pass
def test_check_grad_no_input(self):
pass
def test_check_grad_no_filter(self):
pass
def init_op_type(self):
self.data_format = "NCHW"
self.op_type = 'conv2d_transpose'
self._cpu_only = True
def init_test_case(self):
self.pad = [0, 0]
self.fuse_bias = False
self.use_mkldnn = True
self.is_test = True
self.bias_size = None
self.fuse_activation = ""
self.fuse_alpha = 0.0
self.fuse_beta = 0.0
self.stride = [1, 1]
self.dilations = [1, 1]
self.input_size = [2, 3, 5, 5] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 3, 3]
self.groups = 1
self.output_size = None
self.output_padding = []
self.data_format = "NCHW"
self.pad = [0, 0]
self.padding_algorithm = "EXPLICIT"
self.force_fp32_output = False
def setUp(self):
self.input_type = np.uint16
self.dtype = np.uint16
self.mkldnn_data_type = "bfloat16"
self.init_op_type()
self.init_test_case()
input = np.random.random(self.input_size).astype(np.float32)
filter = np.random.random(self.filter_size).astype(np.float32)
self.attrs = {
'strides': self.stride,
'paddings': self.pad,
'padding_algorithm': self.padding_algorithm,
'groups': self.groups,
'dilations': self.dilations,
'is_test': self.is_test,
'use_mkldnn': self.use_mkldnn,
'mkldnn_data_type': self.mkldnn_data_type,
'force_fp32_output': self.force_fp32_output,
'data_format': self.data_format,
'fuse_activation': self.fuse_activation,
'fuse_alpha': self.fuse_alpha,
'fuse_beta': self.fuse_beta
}
if self.output_size is not None:
self.attrs['output_size'] = self.output_size
if len(self.output_padding) > 0:
self.attrs['output_padding'] = self.output_padding
output = conv2dtranspose_forward_naive(input, filter,
self.attrs).astype(np.float32)
if self.input_type is not np.float32:
input = convert_float_to_uint16(input)
self.inputs = {
'Input': input.view(self.input_type),
'Filter': OpTest.np_dtype_to_fluid_dtype(filter)
}
if self.fuse_bias and self.bias_size is not None:
bias = np.random.random(self.bias_size).astype(np.float32)
output = conv2d_bias_naive(output, bias)
output = output.astype(np.float32)
self.attrs['fuse_bias'] = self.fuse_bias
self.inputs['Bias'] = OpTest.np_dtype_to_fluid_dtype(bias)
if self.fuse_activation == "relu":
output = np.maximum(output, 0).astype(np.float32)
output = output.astype(np.float32)
if not self.force_fp32_output:
output = convert_float_to_uint16(output, self.attrs['data_format'])
self.outputs['Output'] = output
class TestMKLDNNFuseBias(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNFuseBias, self).init_test_case()
self.pad = [1, 1]
self.fuse_bias = True
self.bias_size = [6]
class TestMKLDNNWithPad(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithPad, self).init_test_case()
self.pad = [1, 1]
self.input_size = [2, 3, 10, 10]
class TestMKLDNNWithStride(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithStride, self).init_test_case()
self.pad = [1, 1]
self.stride = [2, 2]
self.input_size = [2, 3, 6, 6] # NCHW
class TestMKLDNNWithAsymPad(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithAsymPad, self).init_test_case()
self.pad = [0, 0, 1, 2]
self.padding_algorithm = "EXPLICIT"
class TestMKLDNNWithSamePad(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithSamePad, self).init_test_case()
self.pad = [0, 0]
self.padding_algorithm = "SAME"
class TestMKLDNNWithValidPad(TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestMKLDNNWithValidPad, self).init_test_case()
self.pad = [1, 1]
self.padding_algorithm = "VALID"
class TestMKLDNNWithValidPad_NHWC(TestMKLDNNWithValidPad):
def init_test_case(self):
super(TestMKLDNNWithValidPad_NHWC, self).init_test_case()
self.data_format = 'NHWC'
N, C, H, W = self.input_size
self.input_size = [N, H, W, C]
class TestConv2DTransposeMKLDNNWithDilationsExplicitPad(
TestConv2DTransposeBF16MKLDNNOp):
def init_test_case(self):
super(TestConv2DTransposeMKLDNNWithDilationsExplicitPad,
self).init_test_case()
self.stride = [2, 1]
self.dilations = [1, 2]
self.groups = 1
self.input_size = [4, 3, 8, 7] # NCHW
f_c = self.input_size[1]
self.filter_size = [f_c, 6, 4, 3]
self.pad = [1, 3, 2, 1]
self.padding_algorithm = "EXPLICIT"
if __name__ == '__main__':
enable_static()
unittest.main()
@@ -82,6 +82,8 @@ class TestConv2DTransposeMKLDNNOp(TestConv2DTransposeOp):
         self.attrs['fuse_activation'] = self.fuse_activation
         self.attrs['fuse_alpha'] = self.fuse_alpha
         self.attrs['fuse_beta'] = self.fuse_beta
+        self.attrs['mkldnn_data_type'] = 'float32'
+        self.attrs['force_fp32_output'] = False
         self.outputs['Output'] = output
@@ -150,3 +152,8 @@ class TestConv2DTransposeMKLDNNWithDilationsExplicitPad(
         self.filter_size = [f_c, 6, 4, 3]
         self.pad = [1, 3, 2, 1]
         self.padding_algorithm = "EXPLICIT"
+
+
+if __name__ == '__main__':
+    enable_static()
+    unittest.main()
@@ -221,12 +221,18 @@ def copy_bits_from_float_to_uint16(f):
     return struct.unpack('<I', struct.pack('<f', f))[0] >> 16

-def convert_float_to_uint16(float_list):
+def convert_float_to_uint16(float_list, data_format="NCHW"):
+    if data_format == "NHWC":
+        float_list = np.transpose(float_list, [0, 3, 1, 2])
+
     new_output = []
     for x in np.nditer(float_list):
         new_output.append(np.uint16(copy_bits_from_float_to_uint16(x)))
+    new_output = np.reshape(new_output, float_list.shape).view(np.uint16)

-    return np.reshape(new_output, float_list.shape).view(np.uint16)
+    if data_format == "NHWC":
+        new_output = np.transpose(new_output, [0, 2, 3, 1])
+
+    return new_output

 class OpTest(unittest.TestCase):
......
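The op_test.py hunk above extends convert_float_to_uint16 with an optional data_format argument; the underlying conversion keeps only the upper 16 bits of each IEEE float32 value (the bfloat16 bit pattern, stored as uint16), exactly as copy_bits_from_float_to_uint16 does per element. A standalone numpy sketch of that bit-level step, vectorized for illustration (the function name is not part of the commit):

import numpy as np

# Keep the high 16 bits of each float32 value: this is the bfloat16 bit
# pattern that the BF16 tests store in a uint16 array (truncation, no rounding).
def float32_to_bf16_bits(values):
    arr = np.ascontiguousarray(values, dtype=np.float32)
    return (arr.view(np.uint32) >> 16).astype(np.uint16)

print(float32_to_bf16_bits([1.0, 0.5, -2.25]))  # [16256 16128 49168]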
@@ -590,6 +590,7 @@ STATIC_MODE_TESTING_LIST = [
     'test_conv2d_int8_mkldnn_op',
     'test_conv2d_mkldnn_op',
     'test_conv2d_transpose_mkldnn_op',
+    'test_conv2d_transpose_bf16_mkldnn_op',
     'test_conv3d_mkldnn_op',
     'test_dequantize_mkldnn_op',
     'test_elementwise_add_mkldnn_op',
......