未验证 提交 162ac048 编写于 作者: S Sławomir Siwek 提交者: GitHub

Replace custom IOHW -> OIHW reorder with build-in oneDNN reorder (#37175)

* Use oneDNN reorder instead of custom one

* Fix whitespace typo

* Fix Code format error

* Incorporating feedback

* Remove unncessary reorder

* Support GIOHW format

* Fix code format error
上级 6d6642c8
......@@ -68,7 +68,7 @@ class ConvTransposeMKLDNNHandlerT
PADDLE_ENFORCE_EQ(
filter->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument(
"The filter tensor's laytout should be %d, but got %d.",
"The filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, filter->layout()));
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument(
......@@ -227,34 +227,12 @@ class ConvTransposeMKLDNNHandlerT
auto user_src_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<K>(),
(g == 1) ? filter->format() : MKLDNNMemoryFormat::goihw);
auto iohw_weights_tz = framework::vectorize(filter->dims());
// Custom Reorder from IOHW to OIHW
auto iohw2oihw_reorder =
[&iohw_weights_tz](const K* filter_data) -> std::shared_ptr<K> {
int o = iohw_weights_tz[1];
int c = iohw_weights_tz[0];
int h = iohw_weights_tz[2];
int w = iohw_weights_tz[3];
std::shared_ptr<K> reordered_filter_data(new K[o * c * h * w](),
std::default_delete<K[]>());
for (int i = 0; i < c; ++i) {
for (int j = 0; j < o; ++j) {
int in_offset = j * h * w + i * o * h * w;
int out_offset = j * c * h * w + i * h * w;
std::memcpy(&(reordered_filter_data.get())[out_offset],
&filter_data[in_offset], h * w * sizeof(K));
}
}
return reordered_filter_data;
};
(g == 1) ? MKLDNNMemoryFormat::iohw : MKLDNNMemoryFormat::giohw);
return this->template AcquireMemoryWithReorder<K>(
dev_ctx, user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data), key, "@weights_mem_p", is_test_,
iohw2oihw_reorder);
platform::to_void_cast<K>(filter_data), key, "@weights_mem_p",
is_test_);
}
template <typename F = T>
......@@ -263,7 +241,6 @@ class ConvTransposeMKLDNNHandlerT
const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr, const std::string& key,
const std::string& suffix, bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
const auto target_key = key + suffix + "_target";
const auto key_reorder_p = key + suffix + "reorder_p";
......@@ -273,12 +250,6 @@ class ConvTransposeMKLDNNHandlerT
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));
if (target_memory_p == nullptr) {
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
dev_ctx.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
if (user_md != target_md) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册