Commit f8ecc3de authored by lidanqing, committed by Tao Luo

refactor the function ConvFwdPrimitiveDesc (#17897)

* refactor the function ConvFwdPrimitiveDesc
test=develop

* change according to review
test=develop

* use pointer way without boost::optional
test=develop

* pass vector to function by reference instead of raw vector
test=develop

* change pointer to shared_ptr
test=develop
Parent 8462e2b8
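The patch collapses the two ConvFwdPrimitiveDesc overloads (with and without a bias descriptor) into a single function whose bias argument is a possibly-null std::shared_ptr, used in place of boost::optional. Below is a minimal, self-contained sketch of that pattern; the names Desc and MakePrimitiveDesc are hypothetical stand-ins for the real mkldnn types, not part of the patch:

#include <iostream>
#include <memory>

struct Desc { int id; };

// One entry point instead of two overloads: a null shared_ptr means
// "no bias", playing the role boost::optional would otherwise play.
void MakePrimitiveDesc(const Desc& src,
                       const std::shared_ptr<Desc> bias_md_p) {
  if (bias_md_p != nullptr) {
    std::cout << "with bias " << bias_md_p->id << "\n";
  } else {
    std::cout << "without bias\n";
  }
}

int main() {
  Desc src{0};
  MakePrimitiveDesc(src, nullptr);                          // bias-free path
  MakePrimitiveDesc(src, std::make_shared<Desc>(Desc{1}));  // bias path
}

The caller only constructs the descriptor when a bias tensor actually exists, and the callee branches once on nullness, as the diff below does with a ternary.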
@@ -383,14 +383,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const std::string key_conv_pd = key + "@conv_pd";
     bool need_s8_to_u8 = false;
-    std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
-    std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
-    std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
-    std::shared_ptr<mkldnn::memory> dst_memory_p = nullptr;
+    std::shared_ptr<mkldnn::convolution_forward> conv_p;
+    std::shared_ptr<mkldnn::memory> src_memory_p;
+    std::shared_ptr<mkldnn::memory> user_src_memory_p;
+    std::shared_ptr<mkldnn::memory> dst_memory_p;
     std::vector<primitive> pipeline;
-    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
-        nullptr;
-    std::shared_ptr<platform::ConvMKLDNNHandler> handler = nullptr;
+    std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
+    std::shared_ptr<platform::ConvMKLDNNHandler> handler;
     auto prim_key = key + "@conv_p";
     auto dst_key = key + "@dst_mem_p";
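A note on the hunk above: a default-constructed std::shared_ptr is already null, so the dropped = nullptr initializers were redundant. A quick stand-alone check (not part of the patch):

#include <cassert>
#include <memory>

int main() {
  std::shared_ptr<int> p;            // default-constructed: already null
  std::shared_ptr<int> q = nullptr;  // equivalent, but redundant
  assert(p == nullptr && q == nullptr);
}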
@@ -460,24 +459,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     // TODO(lidanqing): We use relu post-op instead of brelu post-op cause
     // mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
     // v0.20 is enabled
+    std::shared_ptr<memory::desc> bias_md_p;
     if (bias) {
       bias_tz = paddle::framework::vectorize2int(bias->dims());
-      auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
-                                             memory::format::x);
-      conv_pd = ConvFwdPrimitiveDesc(
-          src_md, weights_md, bias_md, dst_md, strides, paddings,
-          mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
-          fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
-          output_shift_scale, sum_scale, is_test);
-    } else {
-      conv_pd = ConvFwdPrimitiveDesc(
-          src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
-          fuse_relu || fuse_brelu /*fuse_relu*/, fuse_residual_conn,
-          false /*fuse_brelu*/, fuse_brelu_threshold, output_shift_scale,
-          sum_scale, is_test);
+      bias_md_p = std::make_shared<memory::desc>(platform::MKLDNNMemDesc(
+          bias_tz, memory::data_type::s32, memory::format::x));
     }
+    conv_pd = ConvFwdPrimitiveDesc(
+        src_md, weights_md, bias_md_p, dst_md, strides, paddings,
+        mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
+        fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
+        output_shift_scale, sum_scale, is_test);
     // Save conv_pd/src_memory/weights_memory for backward pass
     dev_ctx.SetBlob(key_conv_pd, conv_pd);
     handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
@@ -649,7 +641,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
  private:
   mkldnn::primitive_attr CreatePostOps(
       bool fuse_relu, bool fuse_residual_conn,
-      const std::vector<float> output_shift_scale, float sum_scale,
+      const std::vector<float>& output_shift_scale, float sum_scale,
       bool fuse_brelu, float fuse_brelu_threshold) const {
     mkldnn::primitive_attr conv_attr;
     mkldnn::post_ops post_operations;
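The signature change above passes output_shift_scale by const reference instead of by value, so the vector is no longer copied on every call. A small stand-alone comparison of the two conventions (the SumByValue/SumByRef functions are hypothetical, purely for illustration):

#include <iostream>
#include <vector>

// By value: the caller's vector is copied on every call.
float SumByValue(const std::vector<float> v) {
  float s = 0;
  for (float x : v) s += x;
  return s;
}

// By const reference: binds directly to the caller's vector, no copy.
float SumByRef(const std::vector<float>& v) {
  float s = 0;
  for (float x : v) s += x;
  return s;
}

int main() {
  std::vector<float> scale(1 << 20, 1.0f);  // ~4 MB copied per SumByValue call
  std::cout << SumByValue(scale) << " " << SumByRef(scale) << "\n";
}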
@@ -679,52 +671,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
   std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
   ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
+                       const std::shared_ptr<memory::desc> bias_md_p,
                        const memory::desc& dst, const std::vector<int>& strides,
                        const std::vector<int>& paddings,
                        const mkldnn::engine& engine, const bool fuse_relu,
                        const bool fuse_residual_conn, const bool fuse_brelu,
                        const float fuse_brelu_threshold,
-                       const std::vector<float> output_shift_scale,
+                       const std::vector<float>& output_shift_scale,
                        const float sum_scale, bool is_test) const {
     memory::dims stride_dims = {strides[0], strides[1]};
     memory::dims padding_dims = {paddings[0], paddings[1]};
     auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
                                : mkldnn::prop_kind::forward_training;
-    auto conv_desc = mkldnn::convolution_forward::desc(
-        propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
-        padding_dims, padding_dims, mkldnn::padding_kind::zero);
-    mkldnn::primitive_attr conv_attr =
-        CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
-                      sum_scale, fuse_brelu, fuse_brelu_threshold);
-
-    auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
-        conv_desc, conv_attr, engine);
-    return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
-        p_conv_pd);
-  }
-
-  std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
-  ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
-                       const memory::desc& bias, const memory::desc& dst,
-                       const std::vector<int>& strides,
-                       const std::vector<int>& paddings,
-                       const mkldnn::engine& engine, const bool fuse_relu,
-                       const bool fuse_residual_conn, const bool fuse_brelu,
-                       const float fuse_brelu_threshold,
-                       const std::vector<float> output_shift_scale,
-                       const float sum_scale, bool is_test) const {
-    memory::dims stride_dims = {strides[0], strides[1]};
-    memory::dims padding_dims = {paddings[0], paddings[1]};
-    auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
-                               : mkldnn::prop_kind::forward_training;
-    auto conv_desc = mkldnn::convolution_forward::desc(
-        propagation, mkldnn::convolution_direct, src, weights, bias, dst,
-        stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
+    auto conv_desc =
+        (bias_md_p != nullptr)
+            ? mkldnn::convolution_forward::desc(
+                  propagation, mkldnn::convolution_direct, src, weights,
+                  (*bias_md_p), dst, stride_dims, padding_dims, padding_dims,
+                  mkldnn::padding_kind::zero)
+            : mkldnn::convolution_forward::desc(
+                  propagation, mkldnn::convolution_direct, src, weights, dst,
+                  stride_dims, padding_dims, padding_dims,
+                  mkldnn::padding_kind::zero);
     mkldnn::primitive_attr conv_attr =
         CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
...