Commit 87a44b11 authored by Jacek Czaja, committed by tensor-tang

[MKL-DNN] Added reusing of primitive descriptors (fp32) (#16667)

* - Reuse of conv PD

- conv transpose pd reused

- Added PD reusing of softmax and Batch Norm

- Refactoring and removal of unneeded routines from MKL-DNN ops

test=develop

- Fix to reusing conv

test=develop

- Lint fixes

test=develop

- Further lint fixes

test=develop

- Lint fixes

test=develop

- Lint fixes

test=develop

- Lint workaround

test=develop

* - Fix after review on including boost as third party header

test=develop

* - Fix after review. Name change to something more descriptive

test=develop
Parent 072db093
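The commit applies the same pattern to every op touched here: instead of constructing a primitive descriptor on each kernel invocation and pushing it into the device context, the handler now acquires it through the device context's blob map, creating and caching it only on a miss (and flagging `is_reusing_` on a hit). Below is a minimal standalone sketch of that acquire-or-create pattern; `BlobCache` and `AcquireOrCreate` are hypothetical stand-ins for `MKLDNNDeviceContext::GetBlob`/`SetBlob` and the new `Acquire*PrimitiveDescriptor` methods, not the actual Paddle API.

```cpp
// Minimal sketch of the acquire-or-create caching introduced by this commit.
// BlobCache is a hypothetical stand-in for MKLDNNDeviceContext's
// string-keyed blob map; the real methods are GetBlob/SetBlob.
#include <memory>
#include <string>
#include <unordered_map>
#include <utility>

class BlobCache {
 public:
  std::shared_ptr<void> GetBlob(const std::string& name) const {
    auto it = blobs_.find(name);
    return it == blobs_.end() ? nullptr : it->second;
  }
  void SetBlob(const std::string& name, std::shared_ptr<void> blob) {
    blobs_[name] = std::move(blob);
  }

 private:
  std::unordered_map<std::string, std::shared_ptr<void>> blobs_;
};

// Look up a cached object by key; construct and publish it only on a miss.
template <typename PD, typename... Args>
std::shared_ptr<PD> AcquireOrCreate(BlobCache* cache, const std::string& key,
                                    Args&&... args) {
  auto pd = std::static_pointer_cast<PD>(cache->GetBlob(key));
  if (pd == nullptr) {
    pd = std::make_shared<PD>(std::forward<Args>(args)...);  // first run
    cache->SetBlob(key, pd);
  }
  return pd;  // later runs reuse the cached instance
}
```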
@@ -39,13 +39,9 @@ struct bn_type_traits {
 class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
  public:
-  BatchNormMKLDNNHandler(
-      std::shared_ptr<batch_norm_fwd::primitive_desc> batch_norm_pd,
-      const platform::MKLDNNDeviceContext &dev_ctx, mkldnn::engine engine,
-      const std::string &base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {
-    batch_norm_pd_ = batch_norm_pd;
-  }
+  BatchNormMKLDNNHandler(const platform::MKLDNNDeviceContext &dev_ctx,
+                         mkldnn::engine engine, const std::string &base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}

   std::shared_ptr<memory> AcquireScaleshiftMemoryFromPrimitive(void *ptr) {
     return this->AcquireMemoryFromPrimitive(
@@ -62,6 +58,26 @@ class BatchNormMKLDNNHandler : public platform::MKLDNNHandler {
         batch_norm_pd_->variance_primitive_desc(), ptr, "@variance_mem_p");
   }

+  std::shared_ptr<batch_norm_fwd::primitive_desc>
+  AcquireBatchNormPrimitiveDescriptor(const batch_norm_fwd::desc &bn_fwd_desc,
+                                      const mkldnn::engine &engine) {
+    const std::string key_batch_norm_fwd_pd = key_ + "@bn_fwd_pd";
+    auto batch_norm_pd =
+        std::static_pointer_cast<batch_norm_fwd::primitive_desc>(
+            dev_ctx_.GetBlob(key_batch_norm_fwd_pd));
+
+    if (batch_norm_pd == nullptr) {
+      batch_norm_pd_.reset(
+          new batch_norm_fwd::primitive_desc(bn_fwd_desc, engine));
+      dev_ctx_.SetBlob(key_batch_norm_fwd_pd, batch_norm_pd_);
+    } else {
+      batch_norm_pd_ = batch_norm_pd;
+      is_reusing_ = true;
+    }
+
+    return batch_norm_pd_;
+  }
+
   std::shared_ptr<batch_norm_fwd> AcquireTestTrainingBatchNormFwd(
       std::shared_ptr<memory> src_memory,
       std::shared_ptr<memory> scaleshift_memory,
@@ -213,7 +229,7 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key = BatchNormMKLDNNHandler::GetHash(
         src_tz, epsilon, flags, global_stats, input_format,
         ctx.op().Output("SavedMean"));
-    const std::string key_batch_norm_fwd_pd = key + "@bn_fwd_pd";
+    BatchNormMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

     auto user_src_md = platform::MKLDNNMemDesc(
         {src_tz}, platform::MKLDNNGetDataType<T>(), input_format);
@@ -222,13 +238,9 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     using bn_fwd_types = bn_type_traits<mkldnn::batch_normalization_forward>;
     auto batch_norm_fwd_desc =
         bn_fwd_types::op_desc{propagation, user_src_md, epsilon, flags};
-    auto batch_norm_fwd_pd = std::make_shared<batch_norm_fwd::primitive_desc>(
-        batch_norm_fwd_desc, mkldnn_engine);
-    // Save conv_pd/src_memory/weights_memory for backward pass
-    dev_ctx.SetBlob(key_batch_norm_fwd_pd, batch_norm_fwd_pd);
-
-    BatchNormMKLDNNHandler handler(batch_norm_fwd_pd, dev_ctx, mkldnn_engine,
-                                   key);
+    auto batch_norm_fwd_pd = handler.AcquireBatchNormPrimitiveDescriptor(
+        batch_norm_fwd_desc, mkldnn_engine);

     auto src_memory =
         handler.AcquireSrcMemory(user_src_md, to_void_cast(x_data));
......
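With the handler constructed before any primitive descriptor exists, everything hangs off the base key returned by `GetHash`, and each cached object gets its own suffix (here `"@bn_fwd_pd"`), so different primitive kinds sharing one base key cannot collide. A small, self-contained illustration of the keying scheme; `GetHash` and `dims_to_str` below are simplified stand-ins for the real hashing helpers:

```cpp
// Illustration of the two-level keying scheme: a base key built from shapes
// and attributes, plus a per-object suffix appended inside the handler.
#include <iostream>
#include <sstream>
#include <string>
#include <vector>

std::string dims_to_str(const std::vector<int>& dims) {
  std::ostringstream os;
  for (int d : dims) os << d << "-";
  return os.str();
}

// Base key: one per operator instance, derived from its configuration.
std::string GetHash(const std::vector<int>& src_tz, float epsilon,
                    const std::string& out_name) {
  return dims_to_str(src_tz) + std::to_string(epsilon) + "-" + out_name;
}

int main() {
  const std::string key = GetHash({8, 64, 28, 28}, 1e-5f, "batch_norm_0.out");
  // Per-object suffixes keep distinct cached objects apart under one base key.
  const std::string key_pd = key + "@bn_fwd_pd";
  const std::string key_src_mem = key + "@src_mem_p";
  std::cout << key_pd << "\n" << key_src_mem << "\n";
  return 0;
}
```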
@@ -144,7 +144,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key = platform::ConvMKLDNNHandler::GetHash(
         src_tz, weights_tz, strides, paddings, dilations, groups,
         ctx.op().Input("Input") + ctx.op().Input("Filter"));
-    const std::string key_conv_pd = key + "@conv_pd";

     std::vector<primitive> pipeline;
@@ -183,6 +182,8 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dst_md = platform::MKLDNNMemDesc(
         dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

+    platform::ConvMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
+
     // create a conv primitive descriptor and save it for usage in backward
     std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
     auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
@@ -191,18 +192,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), memory::format::x);
-      conv_pd = ConvFwdPrimitiveDesc(
+      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
           src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
           fuse_relu, fuse_residual_conn, fwd_prop_kind);
     } else {
-      conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides,
-                                     paddings, mkldnn_engine, fuse_relu,
-                                     fuse_residual_conn, fwd_prop_kind);
+      conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
+          src_md, weights_md, boost::none, dst_md, strides, paddings,
+          mkldnn_engine, fuse_relu, fuse_residual_conn, fwd_prop_kind);
     }
-    // Save conv_pd/src_memory/weights_memory for backward pass
-    if (!is_test) dev_ctx.SetBlob(key_conv_pd, conv_pd);
-
-    platform::ConvMKLDNNHandler handler(conv_pd, dev_ctx, mkldnn_engine, key);

     // create mkldnn memory from input tensors (data/weights)
     auto user_src_memory_p =
@@ -633,31 +630,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   }

  private:
-  mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
-                                       bool fuse_residual_conn) const {
-    mkldnn::primitive_attr conv_attr;
-    mkldnn::post_ops post_operations;
-    // Fusion with Elementwise layer relies on adding a sum post-operation with
-    // the scale parameter. It is assumed that when fuse_residual_connection is
-    // true, the output tensor contains the data coming from residual
-    // connection. The result of this post_op is:
-    // Output = scale * Output + Conv_Out.
-    if (fuse_residual_conn) {
-      post_operations.append_sum(1.0f);
-    }
-    // Fusion with ReLU layer is executed through the PostOps feature. Create a
-    // PostOps object and configure it to execute an eltwise relu operation.
-    if (fuse_relu) {
-      constexpr float scale = 1.0f;
-      constexpr float negative_slope = 0.0f;
-      constexpr float placeholder = 0.0f;
-      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
-                                     negative_slope, placeholder);
-    }
-    conv_attr.set_post_ops(post_operations);
-    return conv_attr;
-  }
-
   mkldnn::primitive_attr CreatePostOps(
       bool fuse_relu, bool fuse_residual_conn,
       const std::vector<float> output_shift_scale, float sum_scale) const {
@@ -679,30 +651,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     return conv_attr;
   }

-  std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
-  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
-                       const memory::desc& dst, const std::vector<int>& strides,
-                       const std::vector<int>& paddings,
-                       const mkldnn::engine& engine, const bool fuse_relu,
-                       const bool fuse_residual_conn,
-                       mkldnn::prop_kind fwd_prop_kind) const {
-    memory::dims stride_dims = strides;
-    memory::dims padding_dims = paddings;
-
-    auto conv_desc = mkldnn::convolution_forward::desc(
-        fwd_prop_kind, mkldnn::convolution_direct, src, weights, dst,
-        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
-
-    mkldnn::primitive_attr conv_attr =
-        CreatePostOps(fuse_relu, fuse_residual_conn);
-
-    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
-        conv_desc, conv_attr, engine);
-
-    return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
-        p_conv_pd);
-  }
-
   std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
   ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                        const memory::desc& dst, const std::vector<int>& strides,
@@ -731,31 +679,6 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
         p_conv_pd);
   }

-  std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
-  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
-                       const memory::desc& bias, const memory::desc& dst,
-                       const std::vector<int>& strides,
-                       const std::vector<int>& paddings,
-                       const mkldnn::engine& engine, const bool fuse_relu,
-                       const bool fuse_residual_conn,
-                       mkldnn::prop_kind fwd_prop_kind) const {
-    memory::dims stride_dims = strides;
-    memory::dims padding_dims = paddings;
-
-    auto conv_desc = mkldnn::convolution_forward::desc(
-        fwd_prop_kind, mkldnn::convolution_direct, src, weights, bias, dst,
-        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
-
-    mkldnn::primitive_attr conv_attr =
-        CreatePostOps(fuse_relu, fuse_residual_conn);
-
-    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
-        conv_desc, conv_attr, engine);
-
-    return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
-        p_conv_pd);
-  }
-
   std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
   ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
                        const memory::desc& bias, const memory::desc& dst,
......
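The two `ConvFwdPrimitiveDesc` overloads collapse into one `AcquireConvolutionPrimitiveDescriptor` thanks to `boost::optional<const mkldnn::memory::desc&>`: the kernel passes either `bias_md` or `boost::none`, and the handler branches once on whether the bias is present. A standalone sketch of that dispatch, using a hypothetical `memory_desc` struct in place of the mkldnn type:

```cpp
// Sketch of the optional-reference dispatch used to merge the bias and
// no-bias code paths. memory_desc is a hypothetical placeholder for
// mkldnn::memory::desc.
#include <iostream>
#include <string>
#include "boost/optional.hpp"

struct memory_desc {
  std::string name;
};

std::string describe_conv(const memory_desc& src, const memory_desc& weights,
                          boost::optional<const memory_desc&> bias,
                          const memory_desc& dst) {
  // One branch replaces two near-identical overloads.
  if (bias) {
    return "conv(" + src.name + ", " + weights.name + ", " + bias->name +
           ") -> " + dst.name;
  }
  return "conv(" + src.name + ", " + weights.name + ") -> " + dst.name;
}

int main() {
  memory_desc src{"src"}, weights{"w"}, bias{"b"}, dst{"dst"};
  std::cout << describe_conv(src, weights, bias, dst) << "\n";         // with bias
  std::cout << describe_conv(src, weights, boost::none, dst) << "\n";  // without
  return 0;
}
```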
@@ -12,6 +12,7 @@
 See the License for the specific language governing permissions and
 limitations under the License. */

+#include "boost/optional.hpp"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/memory/malloc.h"
@@ -124,7 +125,6 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key = platform::ConvTransposeMKLDNNHandler::GetHash(
         src_tz, weights_tz, strides, paddings, dilations, groups,
         ctx.op().Output("Output"));
-    const std::string key_conv_transpose_pd = key + "@conv_transpose_pd";

     std::vector<mkldnn::primitive> pipeline;
@@ -153,6 +153,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto dst_md = platform::MKLDNNMemDesc(
         dst_tz, platform::MKLDNNGetDataType<T>(), chosen_memory_format);

+    platform::ConvTransposeMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);
     // create a deconv(conv transpose) primitive descriptor and save it for
     // usage in backward
     std::shared_ptr<mkldnn::deconvolution_forward::primitive_desc>
@@ -163,19 +164,14 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), mkldnn::memory::format::x);
-      conv_transpose_pd = ConvTransposeFwdPrimitiveDesc(
+      conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
           src_md, weights_md, bias_md, dst_md, strides, paddings, mkldnn_engine,
-          fuse_relu, fwd_prop_kind);
+          fuse_relu, false, fwd_prop_kind);
     } else {
-      conv_transpose_pd = ConvTransposeFwdPrimitiveDesc(
-          src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
-          fuse_relu, fwd_prop_kind);
+      conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
+          src_md, weights_md, boost::none, dst_md, strides, paddings,
+          mkldnn_engine, fuse_relu, false, fwd_prop_kind);
     }
-    // Save conv_pd/src_memory/weights_memory for backward pass
-    if (!is_test) dev_ctx.SetBlob(key_conv_transpose_pd, conv_transpose_pd);
-
-    platform::ConvTransposeMKLDNNHandler handler(conv_transpose_pd, dev_ctx,
-                                                 mkldnn_engine, key);

     // create mkldnn memory from input tensors (data/weights)
     auto user_src_memory_p = handler.AcquireSrcMemory(
@@ -224,70 +220,6 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     output->set_layout(DataLayout::kMKLDNN);
     output->set_format(platform::GetMKLDNNFormat(*dst_memory_p));
   }
-
- private:
-  mkldnn::primitive_attr CreatePostOps(bool fuse_relu) const {
-    mkldnn::primitive_attr conv_attr;
-    mkldnn::post_ops post_operations;
-    // Fusion with ReLU layer is executed through the PostOps feature. Create a
-    // PostOps object and configure it to execute an eltwise relu operation.
-    if (fuse_relu) {
-      constexpr float scale = 1.0f;
-      constexpr float negative_slope = 0.0f;
-      constexpr float placeholder = 0.0f;
-      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
-                                     negative_slope, placeholder);
-    }
-    conv_attr.set_post_ops(post_operations);
-    return conv_attr;
-  }
-
-  std::unique_ptr<mkldnn::deconvolution_forward::primitive_desc>
-  ConvTransposeFwdPrimitiveDesc(
-      const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights,
-      const mkldnn::memory::desc& dst, const std::vector<int>& strides,
-      const std::vector<int>& paddings, const mkldnn::engine& engine,
-      const bool fuse_relu, mkldnn::prop_kind fwd_prop_kind) const {
-    mkldnn::memory::dims stride_dims = {strides[0], strides[1]};
-    mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]};
-
-    auto deconv_desc = mkldnn::deconvolution_forward::desc(
-        fwd_prop_kind, mkldnn::deconvolution_direct, src, weights, dst,
-        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
-
-    mkldnn::primitive_attr deconv_attr = CreatePostOps(fuse_relu);
-
-    auto p_conv_transpose_pd =
-        new mkldnn::deconvolution_forward::primitive_desc(deconv_desc,
-                                                          deconv_attr, engine);
-
-    return std::unique_ptr<mkldnn::deconvolution_forward::primitive_desc>(
-        p_conv_transpose_pd);
-  }
-
-  std::unique_ptr<mkldnn::deconvolution_forward::primitive_desc>
-  ConvTransposeFwdPrimitiveDesc(
-      const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights,
-      const mkldnn::memory::desc& bias, const mkldnn::memory::desc& dst,
-      const std::vector<int>& strides, const std::vector<int>& paddings,
-      const mkldnn::engine& engine, const bool fuse_relu,
-      mkldnn::prop_kind fwd_prop_kind) const {
-    mkldnn::memory::dims stride_dims = {strides[0], strides[1]};
-    mkldnn::memory::dims padding_dims = {paddings[0], paddings[1]};
-
-    auto deconv_desc = mkldnn::deconvolution_forward::desc(
-        fwd_prop_kind, mkldnn::deconvolution_direct, src, weights, bias, dst,
-        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
-
-    mkldnn::primitive_attr deconv_attr = CreatePostOps(fuse_relu);
-
-    auto p_conv_transpose_pd =
-        new mkldnn::deconvolution_forward::primitive_desc(deconv_desc,
-                                                          deconv_attr, engine);
-
-    return std::unique_ptr<mkldnn::deconvolution_forward::primitive_desc>(
-        p_conv_transpose_pd);
-  }
 };

 }  // namespace operators
......
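For conv transpose only ReLU fusion applies, so the kernel passes `false` for `fuse_residual_conn` when calling the shared acquire method. The fusion flags end up in mkldnn post-ops, composed by the shared `CreatePostOps` shown in the header diff further down; the sketch below reproduces that composition against the mkldnn 0.x C++ API, assuming `mkldnn.hpp` from that release line is on the include path:

```cpp
// Hedged sketch of post-op composition with the mkldnn 0.x C++ API,
// mirroring the shared CreatePostOps in the header diff below.
#include "mkldnn.hpp"

mkldnn::primitive_attr MakeConvAttr(bool fuse_relu, bool fuse_residual_conn) {
  mkldnn::primitive_attr attr;
  mkldnn::post_ops ops;
  // Residual fusion: accumulate into the existing output buffer,
  // Output = 1.0 * Output + Conv_Out.
  if (fuse_residual_conn) {
    ops.append_sum(1.0f);
  }
  // ReLU fusion: eltwise post-op applied to the convolution result.
  if (fuse_relu) {
    ops.append_eltwise(1.0f /*scale*/, mkldnn::algorithm::eltwise_relu,
                       0.0f /*negative_slope*/, 0.0f /*unused*/);
  }
  attr.set_post_ops(ops);
  return attr;  // passed to the primitive_desc constructor alongside the desc
}
```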
@@ -34,12 +34,9 @@ using platform::to_void_cast;
 class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
  public:
-  SoftmaxMKLDNNHandler(
-      std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
-      const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
-      const std::string& base_key)
-      : platform::MKLDNNHandler(dev_ctx, engine, base_key),
-        softmax_pd_(softmax_pd) {}
+  SoftmaxMKLDNNHandler(const platform::MKLDNNDeviceContext& dev_ctx,
+                       mkldnn::engine engine, const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}

   SoftmaxMKLDNNHandler(
       std::shared_ptr<mkldnn::softmax_forward::primitive_desc> softmax_pd,
@@ -54,6 +51,26 @@ class SoftmaxMKLDNNHandler : public platform::MKLDNNHandler {
     key_ += "-BWD";
   }

+  std::shared_ptr<softmax_forward::primitive_desc>
+  AcquireSoftmaxPrimitiveDescriptor(const softmax_forward::desc& softmax_desc,
+                                    const mkldnn::engine& engine) {
+    const std::string key_softmax_pd = key_ + "@softmax_pd";
+    auto softmax_pd = std::static_pointer_cast<softmax_forward::primitive_desc>(
+        dev_ctx_.GetBlob(key_softmax_pd));
+
+    if (softmax_pd == nullptr) {
+      softmax_pd_.reset(
+          new softmax_forward::primitive_desc(softmax_desc, engine));
+      dev_ctx_.SetBlob(key_softmax_pd, softmax_pd_);
+    } else {
+      softmax_pd_ = softmax_pd;
+      is_reusing_ = true;
+    }
+
+    return softmax_pd_;
+  }
+
   std::shared_ptr<mkldnn::softmax_forward> AcquireSoftmax(
       std::shared_ptr<mkldnn::memory> dst_memory_p,
       std::shared_ptr<mkldnn::memory> src_memory_p) {
@@ -138,19 +155,18 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     // Generate keys for storing/retriving primitives for this operator
     const std::string key =
         platform::MKLDNNHandler::GetHash(softmax_tz, ctx.op().Output("Out"));
-    const std::string key_softmax_pd = key + "@softmax_pd";
+    SoftmaxMKLDNNHandler handler(dev_ctx, mkldnn_engine, key);

     // Currently only NC data format is supported
     auto softmax_md = MKLDNNMemDesc(
         {softmax_tz}, platform::MKLDNNGetDataType<T>(), memory::format::nc);
     // Normalization is made after innermost dimension eg. C out of NC
     auto softmax_desc = softmax_forward::desc(prop_kind::forward_scoring,
                                               softmax_md, 1 /*dim: C*/);
-    auto softmax_pd = std::make_shared<mkldnn::softmax_forward::primitive_desc>(
-        softmax_desc, mkldnn_engine);
-    dev_ctx.SetBlob(key_softmax_pd, softmax_pd);
-
-    SoftmaxMKLDNNHandler handler(softmax_pd, dev_ctx, mkldnn_engine, key);
+    auto softmax_pd =
+        handler.AcquireSoftmaxPrimitiveDescriptor(softmax_desc, mkldnn_engine);

     auto softmax_src_memory_p =
         handler.AcquireSrcMemory(softmax_md, to_void_cast<T>(input_data));
     auto softmax_dst_memory_p =
......
@@ -16,6 +16,7 @@ limitations under the License. */
 #include <memory>
 #include <string>
 #include <vector>
+#include "boost/optional.hpp"
 #include "paddle/fluid/framework/data_layout_transform.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/platform/mkldnn_helper.h"
@@ -395,9 +396,28 @@ class TransposeMKLDNNHandler : public MKLDNNHandler {
   std::vector<int> logical_axis_;
 };

+template <typename T>
+struct convolutional_algorithm;
+
+template <>
+struct convolutional_algorithm<mkldnn::convolution_forward> {
+  static constexpr mkldnn::algorithm T = mkldnn::algorithm::convolution_direct;
+};
+
+template <>
+struct convolutional_algorithm<mkldnn::deconvolution_forward> {
+  static constexpr mkldnn::algorithm T =
+      mkldnn::algorithm::deconvolution_direct;
+};
+
 template <class forward_t, class backward_data_t, class backward_weights_t>
 class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
  public:
+  ConvMKLDNNTemplateHandler(const platform::MKLDNNDeviceContext& dev_ctx,
+                            mkldnn::engine engine, const std::string& base_key)
+      : platform::MKLDNNHandler(dev_ctx, engine, base_key) {}
+
+  // TODO(jczaja): remove after conv int8 is adapted
   ConvMKLDNNTemplateHandler(
       std::shared_ptr<typename forward_t::primitive_desc> conv_pd,
       const platform::MKLDNNDeviceContext& dev_ctx, mkldnn::engine engine,
@@ -542,6 +562,73 @@ class ConvMKLDNNTemplateHandler : public MKLDNNHandler {
                                             scale_data, mask);
   }

+  mkldnn::primitive_attr CreatePostOps(bool fuse_relu,
+                                       bool fuse_residual_conn = false) const {
+    mkldnn::primitive_attr conv_attr;
+    mkldnn::post_ops post_operations;
+    // Fusion with Elementwise layer relies on adding a sum post-operation with
+    // the scale parameter. It is assumed that when fuse_residual_connection is
+    // true, the output tensor contains the data coming from residual
+    // connection. The result of this post_op is:
+    // Output = scale * Output + Conv_Out.
+    if (fuse_residual_conn) {
+      post_operations.append_sum(1.0f);
+    }
+    // Fusion with ReLU layer is executed through the PostOps feature. Create a
+    // PostOps object and configure it to execute an eltwise relu operation.
+    if (fuse_relu) {
+      constexpr float scale = 1.0f;
+      constexpr float negative_slope = 0.0f;
+      constexpr float placeholder = 0.0f;
+      post_operations.append_eltwise(scale, mkldnn::algorithm::eltwise_relu,
+                                     negative_slope, placeholder);
+    }
+    conv_attr.set_post_ops(post_operations);
+    return conv_attr;
+  }
+
+  std::shared_ptr<typename forward_t::primitive_desc>
+  AcquireConvolutionPrimitiveDescriptor(
+      const mkldnn::memory::desc& src, const mkldnn::memory::desc& weights,
+      boost::optional<const mkldnn::memory::desc&> bias,
+      const mkldnn::memory::desc& dst, const std::vector<int>& strides,
+      const std::vector<int>& paddings, const mkldnn::engine& engine,
+      const bool fuse_relu, const bool fuse_residual_conn,
+      mkldnn::prop_kind fwd_prop_kind) {
+    const std::string key_conv_pd = key_ + "@conv_pd";
+
+    auto conv_pd = std::static_pointer_cast<typename forward_t::primitive_desc>(
+        dev_ctx_.GetBlob(key_conv_pd));
+
+    if (conv_pd == nullptr) {
+      mkldnn::memory::dims stride_dims = strides;
+      mkldnn::memory::dims padding_dims = paddings;
+
+      auto conv_desc =
+          bias ? typename forward_t::desc(
+                     fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
+                     weights, *bias, dst, stride_dims, padding_dims,
+                     padding_dims, mkldnn::padding_kind::zero)
+               : typename forward_t::desc(
+                     fwd_prop_kind, convolutional_algorithm<forward_t>::T, src,
+                     weights, dst, stride_dims, padding_dims, padding_dims,
+                     mkldnn::padding_kind::zero);
+
+      mkldnn::primitive_attr conv_attr =
+          CreatePostOps(fuse_relu, fuse_residual_conn);
+
+      conv_pd_.reset(
+          new typename forward_t::primitive_desc(conv_desc, conv_attr, engine));
+      // Save conv_pd/src_memory/weights_memory for backward pass
+      dev_ctx_.SetBlob(key_conv_pd, conv_pd_);
+    } else {
+      conv_pd_ = conv_pd;
+      is_reusing_ = true;
+    }
+
+    return conv_pd_;
+  }
+
   std::shared_ptr<forward_t> AcquireConvolution(
       std::shared_ptr<mkldnn::memory> src_memory_p,
       std::shared_ptr<mkldnn::memory> weights_memory_p,
......
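The `convolutional_algorithm` trait is what lets one templated `AcquireConvolutionPrimitiveDescriptor` serve both `convolution_forward` and `deconvolution_forward`: the algorithm enum value is resolved from the forward primitive type at compile time. A self-contained illustration of the same trait technique, with hypothetical stand-in types rather than the mkldnn primitives:

```cpp
// Standalone illustration of the trait-based algorithm selection used by
// ConvMKLDNNTemplateHandler. conv_fwd/deconv_fwd and the algorithm enum are
// hypothetical stand-ins for the mkldnn types.
#include <iostream>

enum class algorithm { convolution_direct, deconvolution_direct };

struct conv_fwd {};
struct deconv_fwd {};

template <typename T>
struct convolutional_algorithm;  // primary template left undefined

template <>
struct convolutional_algorithm<conv_fwd> {
  static constexpr algorithm T = algorithm::convolution_direct;
};

template <>
struct convolutional_algorithm<deconv_fwd> {
  static constexpr algorithm T = algorithm::deconvolution_direct;
};

// One templated entry point; the right algorithm is picked at compile time.
template <typename forward_t>
void describe() {
  std::cout << (convolutional_algorithm<forward_t>::T ==
                        algorithm::convolution_direct
                    ? "direct convolution\n"
                    : "direct deconvolution\n");
}

int main() {
  describe<conv_fwd>();    // prints: direct convolution
  describe<deconv_fwd>();  // prints: direct deconvolution
  return 0;
}
```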