未验证 提交 162ac048 编写于 作者: S Sławomir Siwek 提交者: GitHub

Replace custom IOHW -> OIHW reorder with build-in oneDNN reorder (#37175)

* Use oneDNN reorder instead of custom one

* Fix whitespace typo

* Fix Code format error

* Incorporating feedback

* Remove unncessary reorder

* Support GIOHW format

* Fix code format error
上级 6d6642c8
...@@ -68,7 +68,7 @@ class ConvTransposeMKLDNNHandlerT ...@@ -68,7 +68,7 @@ class ConvTransposeMKLDNNHandlerT
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
filter->layout(), DataLayout::kMKLDNN, filter->layout(), DataLayout::kMKLDNN,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"The filter tensor's laytout should be %d, but got %d.", "The filter tensor's layout should be %d, but got %d.",
DataLayout::kMKLDNN, filter->layout())); DataLayout::kMKLDNN, filter->layout()));
PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef, PADDLE_ENFORCE_NE(filter->format(), MKLDNNMemoryFormat::undef,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -227,34 +227,12 @@ class ConvTransposeMKLDNNHandlerT ...@@ -227,34 +227,12 @@ class ConvTransposeMKLDNNHandlerT
auto user_src_md = platform::MKLDNNMemDesc( auto user_src_md = platform::MKLDNNMemDesc(
weights_tz, platform::MKLDNNGetDataType<K>(), weights_tz, platform::MKLDNNGetDataType<K>(),
(g == 1) ? filter->format() : MKLDNNMemoryFormat::goihw); (g == 1) ? MKLDNNMemoryFormat::iohw : MKLDNNMemoryFormat::giohw);
auto iohw_weights_tz = framework::vectorize(filter->dims());
// Custom Reorder from IOHW to OIHW
auto iohw2oihw_reorder =
[&iohw_weights_tz](const K* filter_data) -> std::shared_ptr<K> {
int o = iohw_weights_tz[1];
int c = iohw_weights_tz[0];
int h = iohw_weights_tz[2];
int w = iohw_weights_tz[3];
std::shared_ptr<K> reordered_filter_data(new K[o * c * h * w](),
std::default_delete<K[]>());
for (int i = 0; i < c; ++i) {
for (int j = 0; j < o; ++j) {
int in_offset = j * h * w + i * o * h * w;
int out_offset = j * c * h * w + i * h * w;
std::memcpy(&(reordered_filter_data.get())[out_offset],
&filter_data[in_offset], h * w * sizeof(K));
}
}
return reordered_filter_data;
};
return this->template AcquireMemoryWithReorder<K>( return this->template AcquireMemoryWithReorder<K>(
dev_ctx, user_src_md, this->fwd_pd_->weights_desc(), dev_ctx, user_src_md, this->fwd_pd_->weights_desc(),
platform::to_void_cast<K>(filter_data), key, "@weights_mem_p", is_test_, platform::to_void_cast<K>(filter_data), key, "@weights_mem_p",
iohw2oihw_reorder); is_test_);
} }
template <typename F = T> template <typename F = T>
...@@ -263,7 +241,6 @@ class ConvTransposeMKLDNNHandlerT ...@@ -263,7 +241,6 @@ class ConvTransposeMKLDNNHandlerT
const mkldnn::memory::desc& user_md, const mkldnn::memory::desc& user_md,
const mkldnn::memory::desc& target_md, void* ptr, const std::string& key, const mkldnn::memory::desc& target_md, void* ptr, const std::string& key,
const std::string& suffix, bool is_persistent = false, const std::string& suffix, bool is_persistent = false,
std::function<std::shared_ptr<F>(const F*)> custom_reorder_func = {},
const std::vector<float>& scale_data = {1.0f}, int mask = 0) { const std::vector<float>& scale_data = {1.0f}, int mask = 0) {
const auto target_key = key + suffix + "_target"; const auto target_key = key + suffix + "_target";
const auto key_reorder_p = key + suffix + "reorder_p"; const auto key_reorder_p = key + suffix + "reorder_p";
...@@ -273,12 +250,6 @@ class ConvTransposeMKLDNNHandlerT ...@@ -273,12 +250,6 @@ class ConvTransposeMKLDNNHandlerT
std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key)); std::static_pointer_cast<dnnl::memory>(dev_ctx.GetBlob(target_key));
if (target_memory_p == nullptr) { if (target_memory_p == nullptr) {
if (custom_reorder_func) {
auto reordered_data =
custom_reorder_func(reinterpret_cast<const F*>(ptr));
dev_ctx.SetBlob(key_reorder_p + "-custom_reorder", reordered_data);
ptr = reinterpret_cast<void*>(reordered_data.get());
}
auto user_memory_p = auto user_memory_p =
std::make_shared<dnnl::memory>(user_md, this->engine_, ptr); std::make_shared<dnnl::memory>(user_md, this->engine_, ptr);
if (user_md != target_md) { if (user_md != target_md) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册