Commit 95c1816e authored by Jacek Czaja, committed by Tao Luo

[MKL-DNN] Extended LRN with reusing via Acquire API (#18675)

test=develop

- compilation fix

- Yet another compilation fix

- Even yet another compilation fix

- Surprise! Another compilation fix

- lint fixes

test=develop

- Fix to workspace acquire of LRN

test=develop

- Fix to hash of BWD LRN

test=develop

- fix to lrn BWD PD acquire

test=develop

- Fixing LRN PD creation

test=develop

- cosmetic fix in comment

test=develop

- Fixes after review

test=develop
Parent 0ae45f0b
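The "Acquire" reuse pattern this commit introduces is a get-or-create lookup against the device context's string-keyed blob cache: each memory, primitive descriptor and primitive is fetched by key and created (and cached) only when missing, so later executions reuse the same objects. Below is a minimal, illustrative sketch of that pattern; the Cache class and the Acquire name are stand-ins and not the actual PaddlePaddle or MKL-DNN API:

#include <map>
#include <memory>
#include <mutex>
#include <string>
#include <utility>

// Hypothetical stand-in for MKLDNNDeviceContext's blob map.
class Cache {
 public:
  // Get-or-create: return the object cached under `key`, constructing and
  // caching it on first use so subsequent calls reuse the same instance.
  template <typename T, typename... Args>
  std::shared_ptr<T> Acquire(const std::string& key, Args&&... args) {
    std::lock_guard<std::mutex> guard(mutex_);
    auto it = blobs_.find(key);
    if (it != blobs_.end()) {
      return std::static_pointer_cast<T>(it->second);
    }
    auto obj = std::make_shared<T>(std::forward<Args>(args)...);
    blobs_[key] = obj;
    return obj;
  }

 private:
  std::mutex mutex_;
  std::map<std::string, std::shared_ptr<void>> blobs_;
};

Repeated calls with the same key return the same instance, which is the reuse that the handler's Acquire* methods in this change rely on.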
@@ -14,7 +14,7 @@ limitations under the License. */
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/operators/lrn_op.h"
#include "paddle/fluid/platform/mkldnn_helper.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
namespace paddle {
namespace operators {
@@ -22,30 +22,6 @@ namespace operators {
using paddle::framework::Tensor;
using paddle::platform::MKLDNNDeviceContext;
namespace {
template <typename T, typename... Args>
std::shared_ptr<T> insert_to_context(const std::string& key,
const MKLDNNDeviceContext& dev_ctx,
Args&&... args) {
auto p = std::static_pointer_cast<T, void>(dev_ctx.GetBlob(key));
if (!p) {
p = std::make_shared<T>(args...);
dev_ctx.SetBlob(key, std::static_pointer_cast<void, T>(p));
}
return p;
}
template <typename... Args>
void run_primitive(Args&&... args) {
auto forward_op = mkldnn::lrn_forward{args...};
std::vector<mkldnn::primitive> pipeline = {forward_op};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
}
} // namespace
template <typename T>
class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
public:
@@ -76,66 +52,42 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
const bool is_test = ctx.Attr<bool>("is_test");
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, x->format());
auto forward_desc = mkldnn::lrn_forward::desc{mkldnn::prop_kind::forward,
mkldnn::lrn_across_channels,
src_md,
n,
alpha,
beta,
k};
auto src_memory_pd = mkldnn::memory::primitive_desc{src_md, mkldnn_engine};
if (!is_test) {
const std::string key = ctx.op().Output("Out");
const std::string key_src_memory = key + "@lrn_src_memory";
const std::string key_pd = key + "@lrn_pd";
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
auto forward_pd = insert_to_context<mkldnn::lrn_forward::primitive_desc>(
key_pd, dev_ctx, forward_desc, mkldnn_engine);
auto src_memory = insert_to_context<mkldnn::memory>(
key_src_memory, dev_ctx, src_memory_pd);
src_memory->set_data_handle(
static_cast<void*>(const_cast<T*>(input_data)));
auto dst_memory = mkldnn::memory(forward_pd->dst_primitive_desc(),
static_cast<void*>(output_data));
auto workspace_memory = insert_to_context<mkldnn::memory>(
key_workspace_memory, dev_ctx,
forward_pd->workspace_primitive_desc());
run_primitive(*forward_pd, *src_memory, *workspace_memory, dst_memory);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
} else {
auto forward_pd =
mkldnn::lrn_forward::primitive_desc{forward_desc, mkldnn_engine};
auto src_memory = mkldnn::memory{
src_memory_pd, static_cast<void*>(const_cast<T*>(input_data))};
auto workspace_memory =
mkldnn::memory{forward_pd.workspace_primitive_desc()};
auto dst_memory = mkldnn::memory(forward_pd.dst_primitive_desc(),
static_cast<void*>(output_data));
run_primitive(forward_pd, src_memory, workspace_memory, dst_memory);
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(platform::GetMKLDNNFormat(dst_memory));
}
// Format and dims are assumed to be the same for dst and src
auto md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), x->format());
const std::string key = platform::LRNMKLDNNHandler::GetHash(
dims, n, alpha, beta, k, x->format(), ctx.op().Output("Out"));
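// The key encodes dims, n, alpha, beta, k, the input format and the "Out"
// variable name (see LRNMKLDNNHandler::GetHash), so cached objects are
// reused only for a matching LRN configuration.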
platform::LRNMKLDNNHandler handler(ctx.Attr<bool>("is_test"), dev_ctx,
mkldnn_engine, key);
auto src_memory =
handler.AcquireSrcMemory(md, platform::to_void_cast<T>(input_data));
// TODO(jczaja): Hide getting the PD inside the handler for the whole Acquire API
handler.AcquireLRNPrimitiveDescriptor(md, n, alpha, beta, k);
auto dst_memory =
handler.AcquireDstMemory(md, platform::to_void_cast<T>(output_data));
auto lrn_p = handler.AcquireLRN(dst_memory, src_memory);
std::vector<mkldnn::primitive> pipeline = {*lrn_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
auto output_format =
(mkldnn::memory::format)dst_memory->get_primitive_desc()
.desc()
.data.format;
out->set_layout(framework::DataLayout::kMKLDNN);
out->set_format(output_format);
}
};
@@ -156,11 +108,6 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
const std::string key = ctx.op().Input("Out");
const std::string key_src_memory = key + "@lrn_src_memory";
const std::string key_pd = key + "@lrn_pd";
const std::string key_workspace_memory = key + "@lrn_workspace_memory";
const int n = ctx.Attr<int>("n");
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
@@ -174,42 +121,46 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto dims = paddle::framework::vectorize2int(x->dims());
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
const std::string key = platform::LRNMKLDNNHandler::GetHash(
dims, n, alpha, beta, k, x->format(), ctx.op().Input("Out"));
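// Same key composition as in the forward kernel (the suffix is the same
// "Out" variable), so the grad kernel can locate the forward PD and
// workspace cached under the TID-less common key.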
auto diff_src_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
platform::LRNMKLDNNHandler handler(false, dev_ctx, mkldnn_engine, key);
auto diff_dst_md = paddle::platform::MKLDNNMemDesc(
dims, mkldnn::memory::data_type::f32, mkldnn::memory::format::nchw);
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), x->format());
auto diff_dst_memory =
mkldnn::memory{{diff_dst_md, mkldnn_engine},
static_cast<void*>(const_cast<float*>(out_grad_data))};
// diff_dst and diff_src layouts are assumed to be the same
auto diff_md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
auto diff_src_memory = mkldnn::memory{{diff_src_md, mkldnn_engine},
static_cast<void*>(x_grad_data)};
auto workspace = handler.AcquireWorkspaceMemory();
auto backward_desc = mkldnn::lrn_backward::desc{
mkldnn::lrn_across_channels, src_md, diff_src_md, n, alpha, beta, k};
auto diff_dst_memory = handler.AcquireDiffDstMemory(
diff_md, platform::to_void_cast<T>(out_grad_data));
auto forward_pd = dev_ctx.GetBlob(key_pd);
auto diff_src_memory = handler.AcquireDiffSrcMemory(
diff_md, platform::to_void_cast<T>(x_grad_data));
auto backward_pd = mkldnn::lrn_backward::primitive_desc{
backward_desc, mkldnn_engine,
*static_cast<mkldnn::lrn_forward::primitive_desc*>(forward_pd.get())};
auto src_memory = handler.AcquireSrcMemory(
src_md, platform::to_void_cast<T>(x->data<T>()));
std::shared_ptr<void> workspace_memory =
dev_ctx.GetBlob(key_workspace_memory);
// TODO(jczaja): Hide this call inside Handler
handler.AcquireLRNBackwardPrimitiveDescriptor(src_md, diff_md, n, alpha,
beta, k);
auto src_memory = dev_ctx.GetBlob(key_src_memory);
auto backward_op = mkldnn::lrn_backward{
backward_pd, *static_cast<mkldnn::memory*>(src_memory.get()),
diff_dst_memory, *static_cast<mkldnn::memory*>(workspace_memory.get()),
diff_src_memory};
auto lrn_bwd = handler.AcquireLRNBackward(src_memory, diff_dst_memory,
workspace, diff_src_memory);
std::vector<mkldnn::primitive> pipeline = {backward_op};
std::vector<mkldnn::primitive> pipeline = {*lrn_bwd};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
auto output_format =
(mkldnn::memory::format)diff_src_memory->get_primitive_desc()
.desc()
.data.format;
x_grad->set_layout(framework::DataLayout::kMKLDNN);
x_grad->set_format(output_format);
}
};
} // namespace operators
@@ -436,6 +436,159 @@ class ActivationMKLDNNHandler : public MKLDNNHandler {
std::shared_ptr<mkldnn::eltwise_backward::primitive_desc> activation_bwd_pd_;
};
class LRNMKLDNNHandler : public MKLDNNHandler {
public:
LRNMKLDNNHandler(bool is_test, const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key), is_test_(is_test) {}
std::shared_ptr<mkldnn::lrn_forward::primitive_desc>
AcquireLRNPrimitiveDescriptor(const mkldnn::memory::desc& src_md, const int n,
const float alpha, const float beta,
const float k) {
// The LRN PD has to be passed to the Grad op, which
// may be executed by a different thread, hence
// for that one we use a key that does not contain the TID
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
if (fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
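// Re-check after taking the lock: another thread may have created and
// cached the PD while this one was waiting (double-checked locking).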
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
if (fwd_pd_ == nullptr) {
auto forward_desc = mkldnn::lrn_forward::desc{
is_test_ ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::lrn_across_channels, src_md, n, alpha, beta, k};
fwd_pd_.reset(
new mkldnn::lrn_forward::primitive_desc(forward_desc, engine_));
dev_ctx_.SetBlob(key_lrn_pd, fwd_pd_);
}
}
return fwd_pd_;
}
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) {
// The workspace has to be passed to the Grad op, which
// may be executed by a different thread, hence
// for that one we use a key that does not contain the TID
auto local_key = key_common_ + "@workspace";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
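// Double-checked locking, as in AcquireLRNPrimitiveDescriptor: re-check the
// cache after taking the lock so the workspace is created only once.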
mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE(fwd_pd_ != nullptr,
"LRN PD MKL-DNN not found in cache!");
mkldnn::memory::primitive_desc workspace_mpd =
fwd_pd_->workspace_primitive_desc();
mem_p = std::make_shared<mkldnn::memory>(workspace_mpd);
dev_ctx_.SetBlob(local_key, mem_p);
}
}
return mem_p;
}
std::shared_ptr<mkldnn::lrn_forward> AcquireLRN(
std::shared_ptr<mkldnn::memory> dst_memory,
std::shared_ptr<mkldnn::memory> src_memory) {
auto prim_key = key_ + "@lrn_p";
auto lrn_p = std::static_pointer_cast<mkldnn::lrn_forward>(
dev_ctx_.GetBlob(prim_key));
if (lrn_p == nullptr) {
if (is_test_) {
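// Inference needs no workspace, so the primitive is built from src and dst only.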
lrn_p = std::make_shared<mkldnn::lrn_forward>(*fwd_pd_, *(src_memory),
*(dst_memory));
} else {
// For training we need to create a workspace memory
// to pass data computed in forward to the backward pass
auto workspace_memory = this->AcquireWorkspaceMemory();
lrn_p = std::make_shared<mkldnn::lrn_forward>(
*fwd_pd_, *src_memory, *workspace_memory, *dst_memory);
}
dev_ctx_.SetBlob(prim_key, lrn_p);
}
return lrn_p;
}
std::shared_ptr<mkldnn::lrn_backward::primitive_desc>
AcquireLRNBackwardPrimitiveDescriptor(const mkldnn::memory::desc& src_md,
const mkldnn::memory::desc& diff_md,
const int n, const float alpha,
const float beta, const float k) {
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
const std::string key_lrn_bwd_pd = key_ + "@lrn_bwd_pd";
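// The forward PD is shared with the FWD op via the TID-less common key,
// while the backward PD is cached per thread under key_.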
bwd_pd_ = std::static_pointer_cast<mkldnn::lrn_backward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_bwd_pd));
if (bwd_pd_ == nullptr) {
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE(fwd_pd_ != nullptr, "LRN PD MKL-DNN not found in cache!");
auto backward_desc = mkldnn::lrn_backward::desc{
mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k};
bwd_pd_.reset(new mkldnn::lrn_backward::primitive_desc(
backward_desc, engine_, *fwd_pd_));
dev_ctx_.SetBlob(key_lrn_bwd_pd, bwd_pd_);
}
return bwd_pd_;
}
std::shared_ptr<mkldnn::lrn_backward> AcquireLRNBackward(
std::shared_ptr<mkldnn::memory> src_memory,
std::shared_ptr<mkldnn::memory> diff_dst_memory,
std::shared_ptr<mkldnn::memory> workspace,
std::shared_ptr<mkldnn::memory> diff_src_memory) {
auto prim_key = key_ + "@lrn_bwd_p";
auto lrn_bwd_p = std::static_pointer_cast<mkldnn::lrn_backward>(
dev_ctx_.GetBlob(prim_key));
if (lrn_bwd_p == nullptr) {
lrn_bwd_p = std::make_shared<mkldnn::lrn_backward>(
*bwd_pd_, *src_memory, *diff_dst_memory, *workspace,
*diff_src_memory);
dev_ctx_.SetBlob(prim_key, lrn_bwd_p);
}
return lrn_bwd_p;
}
static std::string GetHash(const memory::dims& input_dims, const int n,
const float alpha, const float beta, const float k,
const memory::format& fmt,
const std::string& suffix) {
std::string key;
key.reserve(platform::MKLDNNHandler::MaxKeyLength);
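// The key concatenates the input dims, n, alpha, beta, k, the memory format
// and a caller-provided suffix (the LRN kernels pass the "Out" variable name).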
platform::MKLDNNHandler::AppendKeyDims(&key, input_dims);
platform::MKLDNNHandler::AppendKey(&key, std::to_string(n));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(alpha));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(beta));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(k));
platform::MKLDNNHandler::AppendKey(&key, std::to_string(fmt));
platform::MKLDNNHandler::AppendKey(&key, suffix);
return key;
}
private:
bool is_test_;
std::shared_ptr<mkldnn::lrn_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::lrn_backward::primitive_desc> bwd_pd_;
};
class PoolingMKLDNNHandler : public MKLDNNHandler {
public:
PoolingMKLDNNHandler(const std::string& pooling_type,