Unverified commit d414af94, authored by Jacek Czaja, committed by GitHub

[Internal reviewing] NHWC fix to am_vocoder model for oneDNN 2.6 (#42729)

* - prototype of reimplemented fixes

* - compilation fixes

* - compilation fix

* - cosmetic info

* - hopefully fix

* - compilation fix

* - support for nested blocking of cache clearing

* - fix

* - Unit test for the changes

* - Compilation fix for Windows (hopefully)

* - Moved resetting layout to ResetBlobMap

* - fixes after review
Parent commit: 0211a833
@@ -1908,7 +1908,8 @@ Scope* OperatorWithKernel::PrepareData(
         (var->IsType<LoDTensor>() == true) &&
         (expected_kernel_key.data_layout_ != DataLayout::kMKLDNN) &&
         (paddle::platform::MKLDNNDeviceContext::tls()
-             .get_cur_paddle_data_layout() == DataLayout::kNHWC)) {
+             .get_cur_paddle_data_layout() == DataLayout::kNHWC) &&
+        (tensor_in->dims().size() >= 3)) {
       // Mixed execution : MKL-DNN and GPU is not supported!
       if (!new_scope) {
         new_scope = &scope.NewScope();
......
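The new rank guard only admits tensors that can actually carry an NHWC layout. A minimal sketch of the reasoning behind `tensor_in->dims().size() >= 3`; the helper name is hypothetical and not part of the patch:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical helper (illustration only): reordering dims from NHWC to
// NCHW presumes a trailing channel axis plus at least one spatial axis,
// so a rank-1/2 tensor (e.g. {batch, features}) has no layout to
// transform and must be skipped, as the patched condition now does.
std::vector<int64_t> ToNCHW(const std::vector<int64_t>& nhwc_dims) {
  assert(nhwc_dims.size() >= 3 && "NHWC->NCHW needs rank >= 3");
  std::vector<int64_t> nchw(nhwc_dims);
  // Move the trailing channel axis into position 1: NHWC -> NCHW.
  std::rotate(nchw.begin() + 1, nchw.end() - 1, nchw.end());
  return nchw;
}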
@@ -17,6 +17,10 @@ limitations under the License. */
 #include "paddle/fluid/operators/assign_op.h"
 #include "paddle/phi/kernels/funcs/math_function.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
+
 namespace paddle {
 namespace operators {
@@ -65,6 +69,12 @@ class ConditionalBlockOp : public ConditionalOp {
       scopes->resize(1);
       scopes->front() = &scope.NewScope();
       auto &cur_scope = *scopes->front();
+#ifdef PADDLE_WITH_MKLDNN
+      // (jczaja) Executor on being destroyed clears oneDNN cache and
+      // resets registered model data layout. This is unwanted for nested
+      // Executors (executors declared inside control ops)
+      platform::DontClearMKLDNNCache(dev_place);
+#endif
       framework::Executor exec(dev_place);
       auto *block = Attr<framework::BlockDesc *>("sub_block");
       VLOG(3) << "Conditional block.idx = " << block->ID()
......
@@ -17,6 +17,9 @@
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/controlflow/while_op_helper.h"
+#ifdef PADDLE_WITH_MKLDNN
+#include "paddle/fluid/platform/mkldnn_helper.h"
+#endif
 namespace paddle {
 namespace framework {
 class InferShapeContext;
@@ -66,6 +69,12 @@ class WhileOp : public framework::OperatorBase {
                           "the Condition's shape is ",
                           cond.dims().to_str(), ".\n"));
+#ifdef PADDLE_WITH_MKLDNN
+    // (jczaja) Executor on being destroyed clears oneDNN cache and
+    // resets registered model data layout. This is unwanted for nested
+    // Executors (executors declared inside control ops)
+    platform::DontClearMKLDNNCache(dev_place);
+#endif
     framework::Executor executor(dev_place);
     auto *block = Attr<framework::BlockDesc *>(kStepBlock);
......
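Both control-flow ops apply the same guard: the destructor of the nested framework::Executor is what triggers ResetBlobMap, so the counter must be armed before the Executor is constructed. A simplified, runnable mock (not the real MKLDNNDeviceContext) of that interplay:

#include <iostream>

// Mock of the blocking counter this patch introduces. Each control op
// arms the counter once before creating its nested Executor; each
// Executor destruction decrements it, and the cache (plus the registered
// layout) is only really cleared once the counter has dropped to zero.
struct MockCtx {
  unsigned int block_next_cache_clearing_ = 0;
  void BlockNextCacheClearing() { ++block_next_cache_clearing_; }
  void ResetBlobMap() {
    if (block_next_cache_clearing_ == 0) {
      std::cout << "cache cleared, layout reset to NCHW\n";
    } else {
      --block_next_cache_clearing_;
      std::cout << "clearing prevented (nested Executor)\n";
    }
  }
};

int main() {
  MockCtx ctx;
  ctx.BlockNextCacheClearing();  // conditional_block op, before its Executor
  ctx.BlockNextCacheClearing();  // while op nested inside, before its Executor
  ctx.ResetBlobMap();            // innermost Executor destroyed -> prevented
  ctx.ResetBlobMap();            // next Executor destroyed -> prevented
  ctx.ResetBlobMap();            // top-level teardown -> actually clears
}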
@@ -97,7 +97,7 @@ Crop Operator.
 Crop input into output, as specified by offsets and shape.
 There are two ways to set the offsets:
-1. In runtime: Using the input 'Offsets', which is a Vairbale and can be
+1. In runtime: Using the input 'Offsets', which is a Variable and can be
       output of other operators. This way is suitable for
       dynamic offsets.
 2. In network configuration: Using the attribute 'offsets', which will be
......
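A hedged sketch of the two documented ways to supply offsets, modeled on the op-creation style of the test added below. "offsets_var" is illustrative only, and a real crop op additionally needs its target shape (the 'Y' input or the 'shape' attribute), omitted here:

// 1. Static offsets via the 'offsets' attribute, fixed at network build time:
std::vector<int> offsets{0, 0, 0, 0};
auto crop_static = framework::OpRegistry::CreateOp(
    "crop", {{"X", {"x"}}}, {{"Out", {"out"}}}, {{"offsets", {offsets}}});

// 2. Runtime offsets via the 'Offsets' input, produced by another operator
//    (variable name "offsets_var" is hypothetical):
auto crop_dynamic = framework::OpRegistry::CreateOp(
    "crop", {{"X", {"x"}}, {"Offsets", {"offsets_var"}}}, {{"Out", {"out"}}},
    framework::AttributeMap{});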
-cc_test(test_mkldnn_op_nhwc SRCS mkldnn/test_mkldnn_op_nhwc.cc DEPS op_registry pool_op shape_op activation_op pooling transpose_op scope device_context enforce executor)
+cc_test(test_mkldnn_op_nhwc SRCS mkldnn/test_mkldnn_op_nhwc.cc DEPS op_registry pool_op shape_op crop_op activation_op pooling transpose_op scope device_context enforce executor)
@@ -34,6 +34,8 @@ USE_OP_ITSELF(transpose);
 USE_OP_DEVICE_KERNEL(transpose, MKLDNN);
 USE_OP_ITSELF(shape);
 USE_OP_DEVICE_KERNEL(shape, MKLDNN);
+USE_OP_ITSELF(crop);
+USE_OP_DEVICE_KERNEL(crop, CPU);
 PD_DECLARE_KERNEL(pool2d, CPU, ALL_LAYOUT);
 PD_DECLARE_KERNEL(relu, CPU, ALL_LAYOUT);
@@ -211,5 +213,68 @@ TEST(test_pool2d_shape_nhwc, cpu_place) {
                         "Computed shape does not match expected shape"));
 }
+
+TEST(test_pool2d_crop_nhwc, cpu_place) {
+  framework::DDim dims({1, 4, 8, 512});           // NHWC shape
+  framework::DDim expected_dims({1, 3, 7, 512});  // NHWC expected shape
+  platform::CPUPlace p;
+  framework::Scope scope;
+
+  InputVars input_name = {"x",
+                          scope.Var("x")->GetMutable<framework::LoDTensor>()};
+  InputVars second_crop_input_name = {
+      "v", scope.Var("v")->GetMutable<framework::LoDTensor>()};
+  // Initialize input data
+  std::uniform_real_distribution<float> dist(10.0f, 20.0f);
+  std::mt19937 engine;
+  size_t numel = static_cast<size_t>(phi::product(dims));
+  input_name.tensor->Resize(dims);
+  auto data_ptr = input_name.tensor->mutable_data<float>(p);
+  for (size_t i = 0; i < numel; ++i) {
+    data_ptr[i] = dist(engine);
+  }
+  // The second input (Y) to crop has no buffer, but since it is an MKLDNN
+  // tensor its shape order should be NCHW
+  auto expected_dims_nchw = phi::vectorize<int64_t>(expected_dims);
+  std::rotate(expected_dims_nchw.begin() + 1, expected_dims_nchw.end() - 1,
+              expected_dims_nchw.end());
+  second_crop_input_name.tensor->Resize(phi::make_ddim(expected_dims_nchw));
+  const auto second_crop_input_md =
+      dnnl::memory::desc(expected_dims_nchw, dnnl::memory::data_type::f32,
+                         dnnl::memory::format_tag::nhwc);
+  second_crop_input_name.tensor->set_mem_desc(second_crop_input_md);
+
+  scope.Var("y")->GetMutable<framework::LoDTensor>();
+  auto *z = scope.Var("z")->GetMutable<framework::LoDTensor>();
+
+  auto &pool = platform::DeviceContextPool::Instance();
+
+  // Run pool2d followed by crop. crop's Y input may be non-buffered, so the
+  // executed path handles a oneDNN kernel followed by a CPU kernel with a
+  // non-buffered input
+  auto ksize = std::vector<int>(2, 2);
+  auto op_pool = framework::OpRegistry::CreateOp(
+      "pool2d", {{"X", {"x"}}}, {{"Out", {"y"}}},
+      {{"pooling_type", {std::string("max")}},
+       {"ksize", {ksize}},
+       {"data_format", {std::string("NHWC")}},
+       {"use_mkldnn", {true}}});
+
+  std::vector<int> offsets{0, 0, 0, 0};
+  auto op_crop = framework::OpRegistry::CreateOp(
+      "crop", {{"X", {"y"}}, {"Y", {"v"}}}, {{"Out", {"z"}}},
+      {{"offsets", {offsets}}});
+
+  op_pool->Run(scope, p);
+  op_crop->Run(scope, p);
+
+  pool.Get(p)->Wait();
+
+  // Verify shape of output
+  PADDLE_ENFORCE_EQ(z->dims(), expected_dims,
+                    platform::errors::InvalidArgument(
+                        "Output shape does not match expected output shape"));
+}
+
 }  // namespace operators
 }  // namespace paddle
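The std::rotate call in the test is the entire NHWC-to-NCHW dims conversion: it moves the trailing channel axis into position 1. A standalone check of that trick:

#include <algorithm>
#include <cassert>
#include <cstdint>
#include <vector>

int main() {
  std::vector<int64_t> dims{1, 3, 7, 512};  // NHWC
  // Rotate the last element (C) into position 1 -> NCHW.
  std::rotate(dims.begin() + 1, dims.end() - 1, dims.end());
  assert((dims == std::vector<int64_t>{1, 512, 3, 7}));
  return 0;
}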
@@ -750,7 +750,7 @@ dnnl::stream& MKLDNNDeviceContextThreadLocals::Body::get_stream(void) {
 void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
   VLOG(4) << tls().get_curr_exec() << " " << ptr;
   std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
-  if (!block_next_cache_clearing_) {
+  if (block_next_cache_clearing_ == 0) {
     VLOG(3) << "Clearing DNNL cache.";
     // If no specific executor pointer then clear
     // everything. For executor pointer then clear only
@@ -768,9 +768,20 @@ void MKLDNNDeviceContext::ResetBlobMap(void* ptr) {
         s.second->erase(ptr);
       }
     }
+    // Reset paddle layout to NCHW
+    VLOG(3) << "Resetting Paddle data layout to NCHW.";
+    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
+        paddle::framework::DataLayout::kNCHW);
   } else {
-    VLOG(3) << "Prevented Clearing DNNL cache.";
-    block_next_cache_clearing_ = false;
+    --block_next_cache_clearing_;
+    VLOG(3) << "Prevented Clearing DNNL cache. Updated "
+               "block_next_cache_clearing_ : "
+            << block_next_cache_clearing_;
+    PADDLE_ENFORCE_GE(block_next_cache_clearing_, 0,
+                      platform::errors::InvalidArgument(
+                          "Cache clearing mark should be non-negative. "
+                          "But received %d.",
+                          block_next_cache_clearing_));
   }
 }
@@ -796,8 +807,10 @@ void MKLDNNDeviceContext::LinkEntryWithExecutor(BlobPtr_t<KeyBlob> pblob,
 void MKLDNNDeviceContext::BlockNextCacheClearing() {
   std::lock_guard<decltype(*p_mutex_)> lock(*p_mutex_);
-  VLOG(3) << "Next DNNL cache clearing has been blocked.";
-  block_next_cache_clearing_ = true;
+  ++block_next_cache_clearing_;
+  VLOG(3) << "Next DNNL cache clearing has been blocked. Updated "
+             "block_next_cache_clearing_ : "
+          << block_next_cache_clearing_;
 }
 
 size_t MKLDNNDeviceContext::GetShapeBlobSize() const {
......
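The behavioral difference from the old bool flag is easiest to see side by side. In this simplified mock (not the Paddle class), Reset() returns true when the cache would actually be cleared:

#include <iostream>

// Old: a single bool is consumed by the first ResetBlobMap, so a second
// nested Executor's destruction already wipes the cache.
// New: the counter survives one decrement per registered block.
struct BoolCtx {
  bool block = false;
  bool Reset() {
    if (!block) return true;
    block = false;
    return false;
  }
};
struct CounterCtx {
  unsigned int block = 0;
  bool Reset() {
    if (block == 0) return true;
    --block;
    return false;
  }
};

int main() {
  BoolCtx b;
  b.block = true;  // first nested block
  b.block = true;  // second nested block: the flag cannot count it
  std::cout << b.Reset() << "\n";  // 0: prevented
  std::cout << b.Reset() << "\n";  // 1: cleared too early!
  CounterCtx c;
  ++c.block;
  ++c.block;
  std::cout << c.Reset() << "\n";  // 0: prevented
  std::cout << c.Reset() << "\n";  // 0: still prevented
  return 0;
}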
@@ -850,7 +850,8 @@ class MKLDNNDeviceContext : public CPUDeviceContext {
   // to erase
   std::shared_ptr<ExecShape> p_exec_items_;
   std::shared_ptr<std::mutex> p_mutex_;
-  bool block_next_cache_clearing_ = false;
+  // 0 - clearing is allowed; x > 0 - do not clear
+  unsigned int block_next_cache_clearing_ = 0;
 };
 #endif
......
@@ -148,8 +148,6 @@ inline void ClearMKLDNNCache(const platform::Place& place,
     platform::MKLDNNDeviceContext* dev_ctx =
         (platform::MKLDNNDeviceContext*)pool.Get(place);
     dev_ctx->ResetBlobMap(ptr);
-    platform::MKLDNNDeviceContext::tls().set_cur_paddle_data_layout(
-        paddle::framework::DataLayout::kNCHW);
   }
 }
......
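For context, DontClearMKLDNNCache (called by the two control-flow ops above) presumably mirrors ClearMKLDNNCache's structure but arms the blocking counter instead of resetting the blob map; a sketch under that assumption:

// Assumed shape of the counterpart helper (same CPU-place guard as
// ClearMKLDNNCache above); it only bumps the blocking counter so that the
// next ResetBlobMap call leaves the cache and the registered layout intact.
inline void DontClearMKLDNNCache(const platform::Place& place) {
  if (platform::is_cpu_place(place)) {
    platform::DeviceContextPool& pool =
        platform::DeviceContextPool::Instance();
    platform::MKLDNNDeviceContext* dev_ctx =
        (platform::MKLDNNDeviceContext*)pool.Get(place);
    dev_ctx->BlockNextCacheClearing();
  }
}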