提交 f8ecc3de 编写于 作者: L lidanqing 提交者: Tao Luo

refactor the function ConvFwdPrimitiveDesc (#17897)

* refractor the function ConvFwdPrimitiveDesc
test=develop

* change according to review
test=develop

* use pointer way without boost::optional
test=develop

* pass vector to function by reference instead of raw vector
test=develop

* change pointer to shared_ptr
test=develop
上级 8462e2b8
......@@ -383,14 +383,13 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::string key_conv_pd = key + "@conv_pd";
bool need_s8_to_u8 = false;
std::shared_ptr<mkldnn::convolution_forward> conv_p = nullptr;
std::shared_ptr<mkldnn::memory> src_memory_p = nullptr;
std::shared_ptr<mkldnn::memory> user_src_memory_p = nullptr;
std::shared_ptr<mkldnn::memory> dst_memory_p = nullptr;
std::shared_ptr<mkldnn::convolution_forward> conv_p;
std::shared_ptr<mkldnn::memory> src_memory_p;
std::shared_ptr<mkldnn::memory> user_src_memory_p;
std::shared_ptr<mkldnn::memory> dst_memory_p;
std::vector<primitive> pipeline;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd =
nullptr;
std::shared_ptr<platform::ConvMKLDNNHandler> handler = nullptr;
std::shared_ptr<mkldnn::convolution_forward::primitive_desc> conv_pd;
std::shared_ptr<platform::ConvMKLDNNHandler> handler;
auto prim_key = key + "@conv_p";
auto dst_key = key + "@dst_mem_p";
......@@ -460,24 +459,17 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
// TODO(lidanqing): We use relu post-op instead of brelu post-op cause
// mkldnn v0.18 does not support INT8 brelu post-op. Use code in /**/ when
// v0.20 is enabled
std::shared_ptr<memory::desc> bias_md_p;
if (bias) {
bias_tz = paddle::framework::vectorize2int(bias->dims());
auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, bias_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
output_shift_scale, sum_scale, is_test);
} else {
conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, dst_md, strides, paddings, mkldnn_engine,
fuse_relu || fuse_brelu /*fuse_relu*/, fuse_residual_conn,
false /*fuse_brelu*/, fuse_brelu_threshold, output_shift_scale,
sum_scale, is_test);
bias_md_p = std::make_shared<memory::desc>(platform::MKLDNNMemDesc(
bias_tz, memory::data_type::s32, memory::format::x));
}
conv_pd = ConvFwdPrimitiveDesc(
src_md, weights_md, bias_md_p, dst_md, strides, paddings,
mkldnn_engine, fuse_relu || fuse_brelu /*fuse_relu*/,
fuse_residual_conn, false /*fuse_brelu*/, fuse_brelu_threshold,
output_shift_scale, sum_scale, is_test);
// Save conv_pd/src_memory/weights_memory for backward pass
dev_ctx.SetBlob(key_conv_pd, conv_pd);
handler.reset(new platform::ConvMKLDNNHandler(conv_pd, dev_ctx,
......@@ -649,7 +641,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
private:
mkldnn::primitive_attr CreatePostOps(
bool fuse_relu, bool fuse_residual_conn,
const std::vector<float> output_shift_scale, float sum_scale,
const std::vector<float>& output_shift_scale, float sum_scale,
bool fuse_brelu, float fuse_brelu_threshold) const {
mkldnn::primitive_attr conv_attr;
mkldnn::post_ops post_operations;
......@@ -679,52 +671,29 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const std::shared_ptr<memory::desc> bias_md_p,
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn, const bool fuse_brelu,
const float fuse_brelu_threshold,
const std::vector<float> output_shift_scale,
const std::vector<float>& output_shift_scale,
const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto conv_desc = mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights, dst, stride_dims,
padding_dims, padding_dims, mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
sum_scale, fuse_brelu, fuse_brelu_threshold);
auto p_conv_pd = new mkldnn::convolution_forward::primitive_desc(
conv_desc, conv_attr, engine);
return std::unique_ptr<mkldnn::convolution_forward::primitive_desc>(
p_conv_pd);
}
std::unique_ptr<mkldnn::convolution_forward::primitive_desc>
ConvFwdPrimitiveDesc(const memory::desc& src, const memory::desc& weights,
const memory::desc& bias, const memory::desc& dst,
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn, const bool fuse_brelu,
const float fuse_brelu_threshold,
const std::vector<float> output_shift_scale,
const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring
: mkldnn::prop_kind::forward_training;
auto conv_desc = mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights, bias, dst,
stride_dims, padding_dims, padding_dims, mkldnn::padding_kind::zero);
auto conv_desc =
(bias_md_p != nullptr)
? mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights,
(*bias_md_p), dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero)
: mkldnn::convolution_forward::desc(
propagation, mkldnn::convolution_direct, src, weights, dst,
stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
mkldnn::primitive_attr conv_attr =
CreatePostOps(fuse_relu, fuse_residual_conn, output_shift_scale,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册