提交 5b07ca9c 编写于 作者: J Jacek Czaja 提交者: Tao Luo

- ReImplemented pooling fwd mkldnn (#19911)

- First implementation of BWD and FWD of pooling mkl-dnn

- Compilation fix

- Fix

- Fix

 - Fix

- Fix to crash

- Compilation fix

- Combined AcquireBackward with Fwd

test=develop
上级 790d5226
...@@ -37,7 +37,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -37,7 +37,6 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
"It must use CPUPlace."); "It must use CPUPlace.");
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const auto& mkldnn_engine = dev_ctx.GetEngine();
const Tensor* input = ctx.Input<Tensor>("X"); const Tensor* input = ctx.Input<Tensor>("X");
Tensor* output = ctx.Output<Tensor>("Out"); Tensor* output = ctx.Output<Tensor>("Out");
...@@ -66,52 +65,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> { ...@@ -66,52 +65,37 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
PADDLE_ENFORCE(input->dims().size() == 4, PADDLE_ENFORCE(input->dims().size() == 4,
"Input dim must be with 4, i.e. NCHW"); "Input dim must be with 4, i.e. NCHW");
const T* input_data = input->data<T>();
T* output_data = output->mutable_data<T>(ctx.GetPlace());
auto src_tz = paddle::framework::vectorize<int>(input->dims()); auto src_tz = paddle::framework::vectorize<int>(input->dims());
auto dst_tz = paddle::framework::vectorize<int>(output->dims()); auto dst_tz = paddle::framework::vectorize<int>(output->dims());
auto input_format = input->format(); auto is_test = ctx.Attr<bool>("is_test");
MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef};
platform::PoolingMKLDNNHandler<T> handler(
mkldnn::memory::data_type dt = src_tz, dst_tz, ksize, strides, paddings, pooling_type,
paddle::framework::ToMKLDNNDataType(input->type()); ctx.Attr<bool>("ceil_mode"), input->format(),
auto fmt = input->format(); paddle::framework::ToMKLDNNDataType(input->type()), is_test, dev_ctx,
ctx.GetPlace(), ctx.op().Output("Out"));
const std::string key =
platform::CreateKey(src_tz, pooling_type, ksize, strides, paddings, dt, auto src_memory = handler.AcquireSrcMemory(input);
fmt, ctx.op().Output("Out")); auto dst_memory = handler.AcquireDstMemory(output);
platform::PoolingMKLDNNHandler handler(pooling_type, dt, std::shared_ptr<mkldnn::pooling_forward> pool_p;
ctx.Attr<bool>("is_test"), dev_ctx, std::shared_ptr<mkldnn::memory> workspace_memory;
mkldnn_engine, key); if ((is_test == false) && (pooling_type == "max")) {
// Training
auto src_md = platform::MKLDNNMemDesc(src_tz, dt, input_format); workspace_memory = handler.AcquireWorkspaceMemory();
pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory,
auto src_memory = *workspace_memory);
handler.AcquireSrcMemory(src_md, to_void_cast<T>(input_data)); } else {
// Inference
/* create memory descriptor for pooling without specified format pool_p = handler.AcquireForwardPrimitive(*src_memory, *dst_memory);
* ('any') which lets a primitive (pooling in this case) choose }
* the memory format preferred for best performance
*/
auto dst_md = platform::MKLDNNMemDesc(dst_tz, dt, MKLDNNMemoryFormat::any);
auto pooling_pd = handler.AcquirePoolingPrimitiveDescriptor(
src_tz, dst_tz, src_md, dst_md, ksize, strides, paddings,
ctx.Attr<bool>("ceil_mode"));
auto dst_memory =
handler.AcquireDstMemoryFromPrimitive(to_void_cast<T>(output_data));
auto pool_p = handler.AcquirePooling(dst_memory, src_memory);
// push primitive to stream and wait until it's executed // push primitive to stream and wait until it's executed
std::vector<mkldnn::primitive> pipeline{*pool_p}; std::vector<mkldnn::primitive> pipeline{*pool_p};
stream(stream::kind::eager).submit(pipeline).wait(); stream(stream::kind::eager).submit(pipeline).wait();
output_format = auto output_format =
(MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format; (MKLDNNMemoryFormat)dst_memory->get_primitive_desc().desc().data.format;
output->set_layout(DataLayout::kMKLDNN); output->set_layout(DataLayout::kMKLDNN);
...@@ -158,14 +142,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -158,14 +142,9 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
auto& dev_ctx = auto& dev_ctx =
ctx.template device_context<platform::MKLDNNDeviceContext>(); ctx.template device_context<platform::MKLDNNDeviceContext>();
const mkldnn::engine& mkldnn_engine = dev_ctx.GetEngine();
std::vector<mkldnn::primitive> pipeline; std::vector<mkldnn::primitive> pipeline;
const T* out_grad_data = out_grad->data<T>();
T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
MKLDNNMemoryFormat in_x_grad_format{MKLDNNMemoryFormat::format_undef};
auto diff_src_tz = paddle::framework::vectorize<int>(in_x_grad->dims()); auto diff_src_tz = paddle::framework::vectorize<int>(in_x_grad->dims());
auto diff_dst_tz = paddle::framework::vectorize<int>(out_grad->dims()); auto diff_dst_tz = paddle::framework::vectorize<int>(out_grad->dims());
...@@ -175,36 +154,35 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> { ...@@ -175,36 +154,35 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
diff_src_tz, pooling_type, ksize, strides, paddings, diff_src_tz, pooling_type, ksize, strides, paddings,
memory::data_type::f32, in_x->format(), ctx.op().Input("Out")); memory::data_type::f32, in_x->format(), ctx.op().Input("Out"));
platform::PoolingMKLDNNHandler handler( platform::PoolingMKLDNNHandler<T> handler(
pooling_type, paddle::framework::ToMKLDNNDataType(in_x_grad->type()), diff_dst_tz, diff_src_tz, ksize, strides, paddings, pooling_type,
false, dev_ctx, mkldnn_engine, key); ctx.Attr<bool>("ceil_mode"), in_x->format(), out_grad->format(),
paddle::framework::ToMKLDNNDataType(out_grad->type()), dev_ctx,
auto workspace = handler.AcquireWorkspaceMemory(); ctx.GetPlace(), ctx.op().Input("Out"));
auto diff_dst_md = platform::MKLDNNMemDesc( auto diff_dst_memory = handler.AcquireDiffDstMemory(out_grad);
{diff_dst_tz}, platform::MKLDNNGetDataType<T>(), out_grad->format()); auto diff_src_memory = handler.AcquireDiffSrcMemory(in_x_grad);
auto diff_dst_memory = handler.AcquireDiffDstMemory( std::shared_ptr<mkldnn::pooling_backward> pool_bwd_p;
diff_dst_md, to_void_cast<T>(out_grad_data)); std::shared_ptr<mkldnn::memory> workspace_memory;
if (pooling_type == "max") {
auto diff_src_md = platform::MKLDNNMemDesc( // Max - pooling needs Workspace
diff_src_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::any); workspace_memory = handler.AcquireWorkspaceMemory();
pool_bwd_p = handler.AcquireBackwardPrimitive(
auto bwd_pd = handler.AcquirePoolingBackwardPrimitiveDescriptor( *diff_dst_memory, *workspace_memory, *diff_src_memory);
diff_dst_md, diff_src_md, ksize, strides, paddings); } else {
// Average Pooling
auto diff_src_memory = handler.AcquireDiffSrcMemoryFromPrimitive( pool_bwd_p =
reinterpret_cast<void*>(in_x_grad_data)); handler.AcquireBackwardPrimitive(*diff_dst_memory, *diff_src_memory);
}
auto pool_bwd_p = handler.AcquirePoolingBackward(diff_dst_memory, workspace,
diff_src_memory);
pipeline.push_back(*pool_bwd_p); pipeline.push_back(*pool_bwd_p);
mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait(); mkldnn::stream(mkldnn::stream::kind::eager).submit(pipeline).wait();
in_x_grad_format = (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc() auto in_x_grad_format =
.desc() (MKLDNNMemoryFormat)diff_src_memory->get_primitive_desc()
.data.format; .desc()
.data.format;
in_x_grad->set_layout(DataLayout::kMKLDNN); in_x_grad->set_layout(DataLayout::kMKLDNN);
in_x_grad->set_format(in_x_grad_format); in_x_grad->set_format(in_x_grad_format);
} // Compute() } // Compute()
......
...@@ -66,8 +66,6 @@ class SoftmaxMKLDNNHandler ...@@ -66,8 +66,6 @@ class SoftmaxMKLDNNHandler
auto diff_softmax_md = auto diff_softmax_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
this->AcquireForwardPrimitiveDescriptor(prop_kind::forward_scoring,
data_softmax_md, axis);
this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md, this->AcquireBackwardPrimitiveDescriptor(diff_softmax_md, data_softmax_md,
axis); axis);
} }
......
...@@ -140,6 +140,9 @@ class MKLDNNHandlerT { ...@@ -140,6 +140,9 @@ class MKLDNNHandlerT {
template <typename... Args> template <typename... Args>
void AcquireBackwardPrimitiveDescriptor(Args&&... args) { void AcquireBackwardPrimitiveDescriptor(Args&&... args) {
const std::string key_fwd_pd = key_common_ + "@forward_pd";
fwd_pd_ = std::static_pointer_cast<typename TForward::primitive_desc>(
dev_ctx_.GetBlob(key_fwd_pd));
PADDLE_ENFORCE_NOT_NULL(fwd_pd_); PADDLE_ENFORCE_NOT_NULL(fwd_pd_);
const std::string key_pd = key_ + "@backward_pd"; const std::string key_pd = key_ + "@backward_pd";
bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>( bwd_pd_ = std::static_pointer_cast<typename TBackward::primitive_desc>(
...@@ -445,8 +448,6 @@ class ActivationMKLDNNHandler ...@@ -445,8 +448,6 @@ class ActivationMKLDNNHandler
auto src_md = auto src_md =
platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt); platform::MKLDNNMemDesc(dims, platform::MKLDNNGetDataType<T>(), fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
algorithm, src_md, alpha, beta);
this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md, this->AcquireBackwardPrimitiveDescriptor(algorithm, diff_dst_md, src_md,
alpha, beta); alpha, beta);
} }
...@@ -496,9 +497,6 @@ class LRNMKLDNNHandler ...@@ -496,9 +497,6 @@ class LRNMKLDNNHandler
auto diff_md = auto diff_md =
mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt); mkldnn::memory::desc(dims, platform::MKLDNNGetDataType<T>(), diff_fmt);
this->AcquireForwardPrimitiveDescriptor(mkldnn::prop_kind::forward_training,
mkldnn::lrn_across_channels, src_md,
n, alpha, beta, k);
this->AcquireBackwardPrimitiveDescriptor( this->AcquireBackwardPrimitiveDescriptor(
mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k); mkldnn::lrn_across_channels, src_md, diff_md, n, alpha, beta, k);
} }
...@@ -520,177 +518,97 @@ class LRNMKLDNNHandler ...@@ -520,177 +518,97 @@ class LRNMKLDNNHandler
} }
}; };
class PoolingMKLDNNHandler : public MKLDNNHandler { template <typename T>
class PoolingMKLDNNHandler : public MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward> {
public: public:
PoolingMKLDNNHandler(const std::string& pooling_type, PoolingMKLDNNHandler(
mkldnn::memory::data_type dt, bool is_test, const std::vector<int>& src_dims, const std::vector<int>& dst_dims,
const platform::MKLDNNDeviceContext& dev_ctx,
mkldnn::engine engine, const std::string& base_key)
: platform::MKLDNNHandler(dev_ctx, engine, base_key),
dt_(dt),
pooling_type_(pooling_type),
is_test_(is_test) {}
std::shared_ptr<mkldnn::pooling_forward::primitive_desc>
AcquirePoolingPrimitiveDescriptor(
const std::vector<int>& src_tz, const std::vector<int>& dst_tz,
const mkldnn::memory::desc& src_md, const mkldnn::memory::desc& dst_md,
const std::vector<int>& ksize, const std::vector<int>& strides, const std::vector<int>& ksize, const std::vector<int>& strides,
const std::vector<int>& paddings, bool ceil_mode) { const std::vector<int>& paddings, const std::string& pooling_type,
// Pooling PD has to be passed to Grad op that bool ceil_mode, const MKLDNNMemoryFormat fmt,
// may be executed by different thread, hence mkldnn::memory::data_type dt, bool is_test,
// for that one we use key that does not contain TID const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place,
const std::string key_pooling_pd = key_common_ + "@pooling_pd"; const std::string& unique_name)
fwd_pd_ = std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>( : platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
dev_ctx_.GetBlob(key_pooling_pd)); mkldnn::pooling_backward>(
if (fwd_pd_ == nullptr) { dev_ctx, dev_ctx.GetEngine(), cpu_place,
static std::mutex acquire_barrier; platform::CreateKey(src_dims, pooling_type, ksize, strides,
std::lock_guard<std::mutex> block_threads_until_finish_this_job( paddings, dt, fmt, unique_name)) {
acquire_barrier); auto src_md = mkldnn::memory::desc(src_dims, dt, fmt);
fwd_pd_ = /* create memory descriptor for pooling without specified format
std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>( * ('any') which lets a primitive (pooling in this case) choose
dev_ctx_.GetBlob(key_pooling_pd)); * the memory format preferred for best performance
if (fwd_pd_ == nullptr) { */
std::vector<int> padding_left_top(paddings); auto dst_md =
std::vector<int> padding_right_bottom(paddings); platform::MKLDNNMemDesc(dst_dims, dt, MKLDNNMemoryFormat::any);
if (ceil_mode) {
CorrectOutputSize(src_tz, dst_tz, ksize, paddings, strides, std::vector<int> padding_left_top(paddings);
padding_right_bottom); std::vector<int> padding_right_bottom(paddings);
} if (ceil_mode) {
auto mkldnn_forward_prop_kind = CorrectOutputSize(src_dims, dst_dims, ksize, paddings, strides,
is_test_ ? mkldnn::prop_kind::forward_inference padding_right_bottom);
: mkldnn::prop_kind::forward_training;
auto pooling_desc = mkldnn::pooling_forward::desc(
mkldnn_forward_prop_kind,
pooling_type_ == "max" ? mkldnn::algorithm::pooling_max
: mkldnn::algorithm::pooling_avg,
src_md, dst_md, strides, ksize, padding_left_top,
padding_right_bottom, mkldnn::padding_kind::zero);
fwd_pd_.reset(
new mkldnn::pooling_forward::primitive_desc(pooling_desc, engine_));
dev_ctx_.SetBlob(key_pooling_pd, fwd_pd_);
}
} }
return fwd_pd_;
}
std::shared_ptr<mkldnn::memory> AcquireDstMemoryFromPrimitive(void* ptr) { this->AcquireForwardPrimitiveDescriptor(
return this->AcquireMemoryFromPrimitive(fwd_pd_->dst_primitive_desc(), ptr, is_test ? mkldnn::prop_kind::forward_inference
"@dst_mem_p"); : mkldnn::prop_kind::forward_training,
pooling_type == "max" ? mkldnn::algorithm::pooling_max
: mkldnn::algorithm::pooling_avg,
src_md, dst_md, strides, ksize, padding_left_top, padding_right_bottom,
mkldnn::padding_kind::zero);
}
PoolingMKLDNNHandler(
const std::vector<int>& diff_dst_dims,
const std::vector<int>& diff_src_dims, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& paddings,
const std::string& pooling_type, bool ceil_mode,
const MKLDNNMemoryFormat fmt, const MKLDNNMemoryFormat diff_dst_fmt,
mkldnn::memory::data_type dt,
const platform::MKLDNNDeviceContext& dev_ctx, platform::Place cpu_place,
const std::string& unique_name)
: platform::MKLDNNHandlerT<T, mkldnn::pooling_forward,
mkldnn::pooling_backward>(
dev_ctx, dev_ctx.GetEngine(), cpu_place,
platform::CreateKey(diff_src_dims, pooling_type, ksize, strides,
paddings, dt, fmt, unique_name)) {
auto diff_dst_md = mkldnn::memory::desc(
diff_dst_dims, platform::MKLDNNGetDataType<T>(), diff_dst_fmt);
auto diff_src_md =
mkldnn::memory::desc(diff_src_dims, platform::MKLDNNGetDataType<T>(),
MKLDNNMemoryFormat::any);
this->AcquireBackwardPrimitiveDescriptor(
pooling_type == "max" ? mkldnn::algorithm::pooling_max
: mkldnn::algorithm::pooling_avg,
diff_src_md, diff_dst_md, strides, ksize, paddings, paddings,
mkldnn::padding_kind::zero);
} }
std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) { std::shared_ptr<mkldnn::memory> AcquireWorkspaceMemory(void) {
mkldnn::memory::primitive_desc workspace_mpd = mkldnn::memory::primitive_desc workspace_mpd =
pooling_type_ == "max" this->fwd_pd_->workspace_primitive_desc();
? fwd_pd_->workspace_primitive_desc()
: mkldnn::memory::primitive_desc(
{{}, dt_, MKLDNNMemoryFormat::nchw}, engine_);
// Pooling PD has to be passed to Grad op that // Pooling PD has to be passed to Grad op that
// may be executed by different thread, hence // may be executed by different thread, hence
// for that one we use key that does not contain TID // for that one we use key that does not contain TID
auto local_key = key_common_ + "@workspace"; auto local_key = this->key_common_ + "@workspace";
auto mem_p = auto mem_p = std::static_pointer_cast<mkldnn::memory>(
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); this->dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) { if (mem_p == nullptr) {
static std::mutex acquire_barrier; static std::mutex acquire_barrier;
std::lock_guard<std::mutex> block_threads_until_finish_this_job( std::lock_guard<std::mutex> block_threads_until_finish_this_job(
acquire_barrier); acquire_barrier);
mem_p = mem_p = std::static_pointer_cast<mkldnn::memory>(
std::static_pointer_cast<mkldnn::memory>(dev_ctx_.GetBlob(local_key)); this->dev_ctx_.GetBlob(local_key));
if (mem_p == nullptr) { if (mem_p == nullptr) {
mem_p = std::make_shared<mkldnn::memory>(workspace_mpd); mem_p = std::make_shared<mkldnn::memory>(workspace_mpd);
dev_ctx_.SetBlob(local_key, mem_p); this->dev_ctx_.SetBlob(local_key, mem_p);
} }
} }
return mem_p; return mem_p;
} }
std::shared_ptr<mkldnn::pooling_forward> AcquirePooling(
std::shared_ptr<mkldnn::memory> dst_memory,
std::shared_ptr<mkldnn::memory> src_memory) {
auto prim_key = key_ + "@pooling_p";
auto pooling_p = std::static_pointer_cast<mkldnn::pooling_forward>(
dev_ctx_.GetBlob(prim_key));
if (pooling_p == nullptr) {
if (is_test_) {
pooling_p = std::make_shared<mkldnn::pooling_forward>(
*fwd_pd_, *(src_memory), *(dst_memory));
} else {
// For training we need to create workspace
// to store indices from backward
auto workspace_memory = this->AcquireWorkspaceMemory();
pooling_p = std::make_shared<mkldnn::pooling_forward>(
*fwd_pd_, *src_memory, *dst_memory, *workspace_memory);
}
dev_ctx_.SetBlob(prim_key, pooling_p);
}
return pooling_p;
}
std::shared_ptr<mkldnn::pooling_backward::primitive_desc>
AcquirePoolingBackwardPrimitiveDescriptor(
const mkldnn::memory::desc& diff_dst_md,
const mkldnn::memory::desc& diff_src_md, const std::vector<int>& ksize,
const std::vector<int>& strides, const std::vector<int>& paddings) {
const std::string key_pooling_pd = key_common_ + "@pooling_pd";
const std::string key_pooling_bwd_pd = key_ + "@pooling_bwd_pd";
bwd_pd_ =
std::static_pointer_cast<mkldnn::pooling_backward::primitive_desc>(
dev_ctx_.GetBlob(key_pooling_bwd_pd));
if (bwd_pd_ == nullptr) {
fwd_pd_ =
std::static_pointer_cast<mkldnn::pooling_forward::primitive_desc>(
dev_ctx_.GetBlob(key_pooling_pd));
// PD from FWD op has to exist.
PADDLE_ENFORCE(fwd_pd_ != nullptr, "Pooling MKL-DNN not found in cache!");
auto backward_desc = mkldnn::pooling_backward::desc(
pooling_type_ == "max" ? mkldnn::algorithm::pooling_max
: mkldnn::algorithm::pooling_avg,
diff_src_md, diff_dst_md, strides, ksize, paddings, paddings,
mkldnn::padding_kind::zero);
bwd_pd_.reset(new mkldnn::pooling_backward::primitive_desc(
backward_desc, engine_, *fwd_pd_));
dev_ctx_.SetBlob(key_pooling_bwd_pd, bwd_pd_);
}
return bwd_pd_;
}
std::shared_ptr<mkldnn::memory> AcquireDiffDstMemoryFromDataPrimitive(
const std::shared_ptr<mkldnn::memory> user_memory_p,
std::vector<mkldnn::primitive>& pipeline) { // NOLINT
auto diff_dst_pd = bwd_pd_->diff_dst_primitive_desc();
auto user_pd = user_memory_p->get_primitive_desc();
return this->AcquireMemory(diff_dst_pd, user_pd, user_memory_p,
"@diff_dst_mem_p", pipeline);
}
std::shared_ptr<mkldnn::memory> AcquireDiffSrcMemoryFromPrimitive(void* ptr) {
return this->AcquireMemoryFromPrimitive(bwd_pd_->diff_src_primitive_desc(),
ptr, "@diff_src_mem_p");
}
std::shared_ptr<mkldnn::pooling_backward> AcquirePoolingBackward(
std::shared_ptr<mkldnn::memory> diff_dst_memory,
std::shared_ptr<mkldnn::memory> workspace,
std::shared_ptr<mkldnn::memory> diff_src_memory) {
auto prim_key = key_ + "@pooling_bwd_p";
auto pooling_bwd_p = std::static_pointer_cast<mkldnn::pooling_backward>(
dev_ctx_.GetBlob(prim_key));
if (pooling_bwd_p == nullptr) {
pooling_bwd_p = std::make_shared<mkldnn::pooling_backward>(
*bwd_pd_, *diff_dst_memory, *workspace, *diff_src_memory);
dev_ctx_.SetBlob(prim_key, pooling_bwd_p);
}
return pooling_bwd_p;
}
private: private:
static inline int ComputeCeiledOutput(int input_size, int kernel_size, static inline int ComputeCeiledOutput(int input_size, int kernel_size,
int padding, int stride) { int padding, int stride) {
...@@ -710,13 +628,6 @@ class PoolingMKLDNNHandler : public MKLDNNHandler { ...@@ -710,13 +628,6 @@ class PoolingMKLDNNHandler : public MKLDNNHandler {
} }
} }
} }
private:
mkldnn::memory::data_type dt_;
std::string pooling_type_;
bool is_test_;
std::shared_ptr<mkldnn::pooling_forward::primitive_desc> fwd_pd_;
std::shared_ptr<mkldnn::pooling_backward::primitive_desc> bwd_pd_;
}; };
class TransposeMKLDNNHandler : public MKLDNNHandler { class TransposeMKLDNNHandler : public MKLDNNHandler {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册