提交 619c797a 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] LRN refactoring (#19798)

- LRN mkl-dnn kernel refactor

test=develop

- compilation fix

- Another compilation fix

- Compilation fix

- another compilation fix

- compilation fix

- Crash fix

- optional LRN mkldnn workspace

- Added mid allocation

- Workaround for tests

- Removed gradient from is_test ut

- Removed mid for inference

- Reverted LRN mid removal for is_test

- PADDLE_ENFORCE adjusted

- Rebase to templatization commit

- Compilation fix

- compilation fix

test=develop

- lint

test=develop

- Fix to crash

- Rebase to recent codebase

- lint

- lint

- compilation fix
上级 439d95e1
......@@ -32,16 +32,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"MKLDNN LRN must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto x = ctx.Input<Tensor>("X");
auto out = ctx.Output<Tensor>("Out");
auto mid = ctx.Output<Tensor>("MidOut");
auto input_data = x->data<T>();
auto output_data = out->mutable_data<T>(ctx.GetPlace());
mid->mutable_data<T>(ctx.GetPlace());
const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
......@@ -52,31 +47,32 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k");
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
bool is_test = ctx.Attr<bool>("is_test");
auto dims = paddle::framework::vectorize<int>(x->dims());
// Format and dims are assumed to be the same for dst and src
auto md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), x->format());
const std::string key = platform::CreateKey(
dims, n, alpha, beta, k, x->format(), ctx.op().Output("Out"));
platform::LRNMKLDNNHandler handler(ctx.Attr<bool>("is_test"), dev_ctx,
mkldnn_engine, key);
auto src_memory =
handler.AcquireSrcMemory(md, platform::to_void_cast<T>(input_data));
// TODO(jczaja): Hide getting PD inside of handler for all Acquire API
handler.AcquireLRNPrimitiveDescriptor(md, n, alpha, beta, k);
auto dst_memory =
handler.AcquireDstMemory(md, platform::to_void_cast<T>(output_data));
auto lrn_p = handler.AcquireLRN(dst_memory, src_memory);
platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
is_test, dev_ctx, ctx.GetPlace(),
ctx.op().Output("Out"));
auto src_memory = handler.AcquireSrcMemory(x);
auto dst_memory = handler.AcquireDstMemory(out);
std::shared_ptr<mkldnn::memory> workspace_memory;
std::shared_ptr<mkldnn::lrn_forward> lrn_p;
if (is_test == false) {
workspace_memory = handler.AcquireWorkspaceMemory(mid);
lrn_p = handler.AcquireForwardPrimitive(*src_memory, *workspace_memory,
*dst_memory);
} else {
// mid has to be allocated and filled
// k to pass LRN unit tests
// TODO(jczaja): Disable checking mid in unit tests (Require API change)
mid->mutable_data<T>(ctx.GetPlace());
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
lrn_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory);
}
std::vector<mkldnn::primitive> pipeline = {*lrn_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
......@@ -104,6 +100,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
"is_test attribute should be set to False in training phase.");
auto x = ctx.Input<Tensor>("X");
auto mid = ctx.Input<Tensor>("MidOut");
auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
......@@ -114,42 +111,20 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const float k = ctx.Attr<float>("k");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
auto out_grad_data = out_grad->data<T>();
auto dims = paddle::framework::vectorize<int>(x->dims());
const std::string key = platform::CreateKey(
dims, n, alpha, beta, k, x->format(), ctx.op().Input("Out"));
platform::LRNMKLDNNHandler handler(false, dev_ctx, mkldnn_engine, key);
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), x->format());
// diff_dst and diff_src layouts are assumed to be the same
auto diff_md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
auto workspace = handler.AcquireWorkspaceMemory();
auto diff_dst_memory = handler.AcquireDiffDstMemory(
diff_md, platform::to_void_cast<T>(out_grad_data));
auto diff_src_memory = handler.AcquireDiffSrcMemory(
diff_md, platform::to_void_cast<T>(x_grad_data));
auto src_memory = handler.AcquireSrcMemory(
src_md, platform::to_void_cast<T>(x->data<T>()));
platform::LRNMKLDNNHandler<T> handler(
dims, n, alpha, beta, k, x->format(), out_grad->format(), dev_ctx,
ctx.GetPlace(), ctx.op().Input("Out"));
// TODO(jczaja): Hide this call inside Handler
handler.AcquireLRNBackwardPrimitiveDescriptor(src_md, diff_md, n, alpha,
beta, k);
auto src_memory = handler.AcquireSrcMemory(x);
auto workspace = handler.AcquireBackwardWorkspaceMemory(mid);
auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad);
auto lrn_bwd = handler.AcquireLRNBackward(src_memory, diff_dst_memory,
workspace, diff_src_memory);
auto lrn_bwd = handler.AcquireBackwardPrimitive(
*src_memory, *diff_dst_memory, *workspace, *diff_src_memory);
std::vector<mkldnn::primitive> pipeline = {*lrn_bwd};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
......
......@@ -460,141 +460,64 @@ class ActivationMKLDNNHandler
}
};
class LRNMKLDNNHandler : public MKLDNNHandler {
template <typename T>
class LRNMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> {
public:
LRNMKLDNNHandler(bool is_test, const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key), is_test_(is_test) {}
std::shared_ptr<mkldnn::lrn_forward::primitive_desc>
AcquireLRNPrimitiveDescriptor(const mkldnn::memory::desc& src_md, const int n,
const float alpha, const float beta,
const float k) {
// LRN PD has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
if (fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
if (fwd_pd_ == nullptr) {
auto forward_desc = mkldnn::lrn_forward::desc{
is_test_ ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::lrn_across_channels, src_md, n, alpha, beta, k};
fwd_pd_.reset(
new mkldnn::lrn_forward::primitive_desc(forward_desc, engine_));
dev_ctx_.SetBlob(key_lrn_pd, fwd_pd_);
}
}
return fwd_pd_;
}
LRNMKLDNNHandler(const std::vector<int>& dims, const int n, const float alpha,
const float beta, const float k,
const MKLDNNMemoryFormat fmt, bool is_test,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& unique_name)
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) {
// workspace has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
auto local_key = key_common_ + "@workspace";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE(fwd_pd_ != nullptr,
"LRN PD MKL-DNN not found in cache!");
mkldnn::memory::primitive_desc workspace_mpd =
fwd_pd_->workspace_primitive_desc();
mem_p = std::make_shared<mkldnn::memory>(workspace_mpd);
dev_ctx_.SetBlob(local_key, mem_p);
}
}
return mem_p;
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, n, alpha, beta, k, fmt, unique_name)) {
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(
is_test ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training,
mkldnn::lrn_across_channels, src_md, n, alpha, beta, k);
}
std::shared_ptr<mkldnn::lrn_forward> AcquireLRN(
std::shared_ptr<mkldnn::memory> dst_memory,
std::shared_ptr<mkldnn::memory> src_memory) {
auto prim_key = key_ + "@lrn_p";
auto lrn_p = std::static_pointer_cast<mkldnn::lrn_forward>(
dev_ctx_.GetBlob(prim_key));
if (lrn_p == nullptr) {
if (is_test_) {
lrn_p = std::make_shared<mkldnn::lrn_forward>(*fwd_pd_, *(src_memory),
*(dst_memory));
} else {
// For training we need to create workspace
// to store indices from backward
auto workspace_memory = this->AcquireWorkspaceMemory();
LRNMKLDNNHandler(const std::vector<int>& dims, const int n, const float alpha,
const float beta, const float k,
const MKLDNNMemoryFormat fmt,
const MKLDNNMemoryFormat diff_fmt,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& unique_name)
lrn_p = std::make_shared<mkldnn::lrn_forward>(
*fwd_pd_, *src_memory, *workspace_memory, *dst_memory);
}
dev_ctx_.SetBlob(prim_key, lrn_p);
}
return lrn_p;
}
std::shared_ptr<mkldnn::lrn_backward::primitive_desc>
AcquireLRNBackwardPrimitiveDescriptor(const mkldnn::memory::desc& src_md,
const mkldnn::memory::desc& diff_md,
const int n, const float alpha,
const float beta, const float k) {
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
const std::string key_lrn_bwd_pd = key_ + "@lrn_bwd_pd";
bwd_pd_ = std::static_pointer_cast<mkldnn::lrn_backward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_bwd_pd));
if (bwd_pd_ == nullptr) {
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE(fwd_pd_ != nullptr, "LRN MKL-DNN not found in cache!");
: platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(dims, n, alpha, beta, k, fmt, unique_name)) {
auto src_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
auto diff_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
auto backward_desc = mkldnn::lrn_backward::desc{
mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k};
bwd_pd_.reset(new mkldnn::lrn_backward::primitive_desc(
backward_desc, engine_, *fwd_pd_));
dev_ctx_.SetBlob(key_lrn_bwd_pd, bwd_pd_);
}
return bwd_pd_;
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
mkldnn::lrn_across_channels, src_md,
n, alpha, beta, k);
this->AcquireBackwardPrimitiveDescriptor(
mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k);
}
std::shared_ptr<mkldnn::lrn_backward> AcquireLRNBackward(
std::shared_ptr<mkldnn::memory> src_memory,
std::shared_ptr<mkldnn::memory> diff_dst_memory,
std::shared_ptr<mkldnn::memory> workspace,
std::shared_ptr<mkldnn::memory> diff_src_memory) {
auto prim_key = key_ + "@lrn_bwd_p";
auto lrn_bwd_p = std::static_pointer_cast<mkldnn::lrn_backward>(
dev_ctx_.GetBlob(prim_key));
if (lrn_bwd_p == nullptr) {
lrn_bwd_p = std::make_shared<mkldnn::lrn_backward>(
*bwd_pd_, *src_memory, *diff_dst_memory, *workspace,
*diff_src_memory);
dev_ctx_.SetBlob(prim_key, lrn_bwd_p);
}
return lrn_bwd_p;
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
framework::Tensor* workspace) {
T* ptr = workspace->mutable_data<T>(
this->place_, this->fwd_pd_->dst_primitive_desc().get_size());
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->workspace_primitive_desc(), ptr, "@wrk_mem_p");
}
private:
bool is_test_;
std::shared_ptr<mkldnn::lrn_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::lrn_backward::primitive_desc> bwd_pd_;
std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
const framework::Tensor* workspace) {
const T* workspace_data = workspace->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->workspace_primitive_desc(),
to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p");
}
};
class PoolingMKLDNNHandler : public MKLDNNHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册