提交 619c797a 编写于 作者: J Jacek Czaja 提交者: Tao Luo

[MKL-DNN] LRN refactoring (#19798)

- LRN mkl-dnn kernel refactor

test=develop

- compilation fix

- Another compilation fix

- Compilation fix

- another compilation fix

- compilation fix

- Crash fix

- optional LRN mkldnn workspace

- Added mid allocation

- Workaround for tests

- Removed gradient from is_test ut

- Removed mid for inference

- Reverted LRN mid removal for is_test

- PADDLE_ENFORCE adjusted

- Rebase to templatization commit

- Compilation fix

- compilation fix

test=develop

- lint

test=develop

- Fix to crash

- Rebase to recent codebase

 - lin

- lint

- compilation fix
上级 439d95e1
...@@ -32,16 +32,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -32,16 +32,11 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"MKLDNN LRN must use CPUPlace."); "MKLDNN LRN must use CPUPlace.");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto x = ctx.Input<Tensor>("X"); auto x = ctx.Input<Tensor>("X");
auto out = ctx.Output<Tensor>("Out"); auto out = ctx.Output<Tensor>("Out");
auto mid = ctx.Output<Tensor>("MidOut"); auto mid = ctx.Output<Tensor>("MidOut");
auto input_data = x->data<T>();
auto output_data = out->mutable_data<T>(ctx.GetPlace());
mid->mutable_data<T>(ctx.GetPlace());
const int n = ctx.Attr<int>("n"); const int n = ctx.Attr<int>("n");
// MKL-DNN implements LRN in a caffe way: // MKL-DNN implements LRN in a caffe way:
// http://caffe.berkeleyvision.org/tutorial/layers/lrn.html // http://caffe.berkeleyvision.org/tutorial/layers/lrn.html
...@@ -52,31 +47,32 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -52,31 +47,32 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n); const float alpha = ctx.Attr<float>("alpha") * static_cast<float>(n);
const float beta = ctx.Attr<float>("beta"); const float beta = ctx.Attr<float>("beta");
const float k = ctx.Attr<float>("k"); const float k = ctx.Attr<float>("k");
bool is_test = ctx.Attr<bool>("is_test");
auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
auto dims = paddle::framework::vectorize<int>(x->dims()); auto dims = paddle::framework::vectorize<int>(x->dims());
// Format and dims are assumed to be the same for dst and src platform::LRNMKLDNNHandler<T> handler(dims, n, alpha, beta, k, x->format(),
auto md = paddle::platform::MKLDNNMemDesc( is_test, dev_ctx, ctx.GetPlace(),
dims, platform::MKLDNNGetDataType<T>(), x->format()); ctx.op().Output("Out"));
const std::string key = platform::CreateKey( auto src_memory = handler.AcquireSrcMemory(x);
dims, n, alpha, beta, k, x->format(), ctx.op().Output("Out")); auto dst_memory = handler.AcquireDstMemory(out);
platform::LRNMKLDNNHandler handler(ctx.Attr<bool>("is_test"), dev_ctx, std::shared_ptr<mkldnn::memory> workspace_memory;
mkldnn_engine, key); std::shared_ptr<mkldnn::lrn_forward> lrn_p;
auto src_memory = if (is_test == false) {
handler.AcquireSrcMemory(md, platform::to_void_cast<T>(input_data)); workspace_memory = handler.AcquireWorkspaceMemory(mid);
lrn_p = handler.AcquireForwardPrimitive(*src_memory, *workspace_memory,
// TODO(jczaja): Hide getting PD inside of handler for all Acquire API *dst_memory);
handler.AcquireLRNPrimitiveDescriptor(md, n, alpha, beta, k); } else {
// mid has to be allocated and filled
auto dst_memory = // k to pass LRN unit tests
handler.AcquireDstMemory(md, platform::to_void_cast<T>(output_data)); // TODO(jczaja): Disable checking mid in unit tests (Require API change)
mid->mutable_data<T>(ctx.GetPlace());
auto lrn_p = handler.AcquireLRN(dst_memory, src_memory); auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
e_mid = e_mid.constant(k);
lrn_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory);
}
std::vector<mkldnn::primitive> pipeline = {*lrn_p}; std::vector<mkldnn::primitive> pipeline = {*lrn_p};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
...@@ -104,6 +100,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -104,6 +100,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
"is_test attribute should be set to False in training phase."); "is_test attribute should be set to False in training phase.");
auto x = ctx.Input<Tensor>("X"); auto x = ctx.Input<Tensor>("X");
auto mid = ctx.Input<Tensor>("MidOut");
auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out")); auto out_grad = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X")); auto x_grad = ctx.Output<Tensor>(framework::GradVarName("X"));
...@@ -114,42 +111,20 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -114,42 +111,20 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
const float k = ctx.Attr<float>("k"); const float k = ctx.Attr<float>("k");
auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>(); auto& dev_ctx = ctx.template device_context<MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
auto x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
auto out_grad_data = out_grad->data<T>();
auto dims = paddle::framework::vectorize<int>(x->dims()); auto dims = paddle::framework::vectorize<int>(x->dims());
const std::string key = platform::CreateKey( platform::LRNMKLDNNHandler<T> handler(
dims, n, alpha, beta, k, x->format(), ctx.op().Input("Out")); dims, n, alpha, beta, k, x->format(), out_grad->format(), dev_ctx,
ctx.GetPlace(), ctx.op().Input("Out"));
platform::LRNMKLDNNHandler handler(false, dev_ctx, mkldnn_engine, key);
auto src_md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), x->format());
// diff_dst and diff_src layouts are assumed to be the same
auto diff_md = paddle::platform::MKLDNNMemDesc(
dims, platform::MKLDNNGetDataType<T>(), out_grad->format());
auto workspace = handler.AcquireWorkspaceMemory();
auto diff_dst_memory = handler.AcquireDiffDstMemory(
diff_md, platform::to_void_cast<T>(out_grad_data));
auto diff_src_memory = handler.AcquireDiffSrcMemory(
diff_md, platform::to_void_cast<T>(x_grad_data));
auto src_memory = handler.AcquireSrcMemory(
src_md, platform::to_void_cast<T>(x->data<T>()));
// TODO(jczaja): Hide this call inside Handler auto src_memory = handler.AcquireSrcMemory(x);
handler.AcquireLRNBackwardPrimitiveDescriptor(src_md, diff_md, n, alpha, auto workspace = handler.AcquireBackwardWorkspaceMemory(mid);
beta, k); auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
auto diff_src_memory = handler.AcquireDiffSrcMemory(x_grad);
auto lrn_bwd = handler.AcquireLRNBackward(src_memory, diff_dst_memory, auto lrn_bwd = handler.AcquireBackwardPrimitive(
workspace, diff_src_memory); *src_memory, *diff_dst_memory, *workspace, *diff_src_memory);
std::vector<mkldnn::primitive> pipeline = {*lrn_bwd}; std::vector<mkldnn::primitive> pipeline = {*lrn_bwd};
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
......
...@@ -460,141 +460,64 @@ class ActivationMKLDNNHandler ...@@ -460,141 +460,64 @@ class ActivationMKLDNNHandler
} }
}; };
class LRNMKLDNNHandler : public MKLDNNHandler { template <typename T>
class LRNMKLDNNHandler
: public MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward> {
public: public:
LRNMKLDNNHandler(bool is_test, const platform::MKLDNNDeviceContext& dev_ctx, LRNMKLDNNHandler(const std::vector<int>& dims, const int n, const float alpha,
mkldnn::engine engine, const std::string& base_key) const float beta, const float k,
: platform::MKLDNNHandler(dev_ctx, engine, base_key), is_test_(is_test) {} const MKLDNNMemoryFormat fmt, bool is_test,
const platform::MKLDNNDeviceContext& dev_ctx,
platform::Place cpu_place, const std::string& unique_name)
std::shared_ptr<mkldnn::lrn_forward::primitive_desc> : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
AcquireLRNPrimitiveDescriptor(const mkldnn::memory::desc& src_md, const int n, dev_ctx, dev_ctx.GetEngine(), cpu_place,
const float alpha, const float beta, platform::CreateKey(dims, n, alpha, beta, k, fmt, unique_name)) {
const float k) { auto src_md =
// LRN PD has to be passed to Grad op that mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
// may be executed by diffrent thread, hence this->AcquireForwardPrimitiveDescriptor(
// for that one we use key that does not contain TID is_test ? mkldnn::prop_kind::forward_inference
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
if (fwd_pd_ == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
if (fwd_pd_ == nullptr) {
auto forward_desc = mkldnn::lrn_forward::desc{
is_test_ ? mkldnn::prop_kind::forward_inference
: mkldnn::prop_kind::forward_training, : mkldnn::prop_kind::forward_training,
mkldnn::lrn_across_channels, src_md, n, alpha, beta, k}; mkldnn::lrn_across_channels, src_md, n, alpha, beta, k);
fwd_pd_.reset(
new mkldnn::lrn_forward::primitive_desc(forward_desc, engine_));
dev_ctx_.SetBlob(key_lrn_pd, fwd_pd_);
}
}
return fwd_pd_;
}
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) {
// workspace has to be passed to Grad op that
// may be executed by diffrent thread, hence
// for that one we use key that does not contain TID
auto local_key = key_common_ + "@workspace";
auto mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier);
mem_p =
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) {
const std::string key_lrn_pd = key_common_ + "@lrn_pd";
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE(fwd_pd_ != nullptr,
"LRN PD MKL-DNN not found in cache!");
mkldnn::memory::primitive_desc workspace_mpd =
fwd_pd_->workspace_primitive_desc();
mem_p = std::make_shared<mkldnn::memory>(workspace_mpd);
dev_ctx_.SetBlob(local_key, mem_p);
}
} }
return mem_p;
}
std::shared_ptr<mkldnn::lrn_forward> AcquireLRN(
std::shared_ptr<mkldnn::memory> dst_memory,
std::shared_ptr<mkldnn::memory> src_memory) {
auto prim_key = key_ + "@lrn_p";
auto lrn_p = std::static_pointer_cast<mkldnn::lrn_forward>( LRNMKLDNNHandler(const std::vector<int>& dims, const int n, const float alpha,
dev_ctx_.GetBlob(prim_key)); const float beta, const float k,
if (lrn_p == nullptr) { const MKLDNNMemoryFormat fmt,
if (is_test_) { const MKLDNNMemoryFormat diff_fmt,
lrn_p = std::make_shared<mkldnn::lrn_forward>(*fwd_pd_, *(src_memory), const platform::MKLDNNDeviceContext& dev_ctx,
*(dst_memory)); platform::Place cpu_place, const std::string& unique_name)
} else {
// For training we need to create workspace
// to store indices from backward
auto workspace_memory = this->AcquireWorkspaceMemory();
lrn_p = std::make_shared<mkldnn::lrn_forward>(
*fwd_pd_, *src_memory, *workspace_memory, *dst_memory);
}
dev_ctx_.SetBlob(prim_key, lrn_p);
}
return lrn_p;
}
std::shared_ptr<mkldnn::lrn_backward::primitive_desc> : platform::MKLDNNHandlerT<T, mkldnn::lrn_forward, mkldnn::lrn_backward>(
AcquireLRNBackwardPrimitiveDescriptor(const mkldnn::memory::desc& src_md, dev_ctx, dev_ctx.GetEngine(), cpu_place,
const mkldnn::memory::desc& diff_md, platform::CreateKey(dims, n, alpha, beta, k, fmt, unique_name)) {
const int n, const float alpha, auto src_md =
const float beta, const float k) { mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), fmt);
const std::string key_lrn_pd = key_common_ + "@lrn_pd"; auto diff_md =
const std::string key_lrn_bwd_pd = key_ + "@lrn_bwd_pd"; mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
bwd_pd_ = std::static_pointer_cast<mkldnn::lrn_backward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_bwd_pd));
if (bwd_pd_ == nullptr) {
fwd_pd_ = std::static_pointer_cast<mkldnn::lrn_forward::primitive_desc>(
dev_ctx_.GetBlob(key_lrn_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE(fwd_pd_ != nullptr, "LRN MKL-DNN not found in cache!");
auto backward_desc = mkldnn::lrn_backward::desc{ this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k}; mkldnn::lrn_across_channels, src_md,
bwd_pd_.reset(new mkldnn::lrn_backward::primitive_desc( n, alpha, beta, k);
backward_desc, engine_, *fwd_pd_)); this->AcquireBackwardPrimitiveDescriptor(
dev_ctx_.SetBlob(key_lrn_bwd_pd, bwd_pd_); mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k);
}
return bwd_pd_;
} }
std::shared_ptr<mkldnn::lrn_backward> AcquireLRNBackward( std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(
std::shared_ptr<mkldnn::memory> src_memory, framework::Tensor* workspace) {
std::shared_ptr<mkldnn::memory> diff_dst_memory, T* ptr = workspace->mutable_data<T>(
std::shared_ptr<mkldnn::memory> workspace, this->place_, this->fwd_pd_->dst_primitive_desc().get_size());
std::shared_ptr<mkldnn::memory> diff_src_memory) { return this->AcquireMemoryFromPrimitive(
auto prim_key = key_ + "@lrn_bwd_p"; this->fwd_pd_->workspace_primitive_desc(), ptr, "@wrk_mem_p");
auto lrn_bwd_p = std::static_pointer_cast<mkldnn::lrn_backward>(
dev_ctx_.GetBlob(prim_key));
if (lrn_bwd_p == nullptr) {
lrn_bwd_p = std::make_shared<mkldnn::lrn_backward>(
*bwd_pd_, *src_memory, *diff_dst_memory, *workspace,
*diff_src_memory);
dev_ctx_.SetBlob(prim_key, lrn_bwd_p);
} }
return lrn_bwd_p; std::shared_ptr<mkldnn::memory> AcquireBackwardWorkspaceMemory(
const framework::Tensor* workspace) {
const T* workspace_data = workspace->data<T>();
return this->AcquireMemoryFromPrimitive(
this->fwd_pd_->workspace_primitive_desc(),
to_void_cast<T>(workspace_data), "@bwd-wrk_mem_p");
} }
private:
bool is_test_;
std::shared_ptr<mkldnn::lrn_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::lrn_backward::primitive_desc> bwd_pd_;
}; };
class PoolingMKLDNNHandler : public MKLDNNHandler { class PoolingMKLDNNHandler : public MKLDNNHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册