提交 51e5950e 编写于 作者: X xiaolil1

modify prop_kind with is_test

上级 aac5ee72
......@@ -480,12 +480,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale[0]);
output_shift_scale, sum_scale[0], is_test);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn,
output_shift_scale, sum_scale[0]);
output_shift_scale, sum_scale[0], is_test);
}
} else{
auto src_md = platform::MKLDNNMemDesc(
......@@ -501,11 +501,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
bias_tz, platform::MKLDNNGetDataType<float>(), memory::format::x);
conv_pd = ConvFwdPrimitiveDesc(src_md, weights_md, bias_md, dst_md,
strides, paddings, mkldnn_engine,
fuse_relu, fuse_residual_conn);
fuse_relu, fuse_residual_conn, is_test);
} else {
conv_pd =
ConvFwdPrimitiveDesc(src_md, weights_md, dst_md, strides, paddings,
mkldnn_engine, fuse_relu, fuse_residual_conn);
mkldnn_engine, fuse_relu, fuse_residual_conn, is_test);
}
}
// Save conv_pd/src_memory/weights_memory for backward pass
......@@ -743,12 +743,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
const std::vector<float> output_shift_scale, const float sum_scale) const {
const std::vector<float> output_shift_scale, const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training;
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights,
propagation, mkldnn::convolution_direct, src, weights,
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
......@@ -767,12 +769,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const memory::desc& dst, const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn) const{
const bool fuse_residual_conn, bool is_test) const{
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training;
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights,
propagation, mkldnn::convolution_direct, src, weights,
dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
......@@ -792,12 +796,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn,
const std::vector<float> output_shift_scale, const float sum_scale) const {
const std::vector<float> output_shift_scale, const float sum_scale, bool is_test) const {
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training;
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights,
propagation, mkldnn::convolution_direct, src, weights,
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
......@@ -817,12 +823,14 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
const std::vector<int>& strides,
const std::vector<int>& paddings,
const mkldnn::engine& engine, const bool fuse_relu,
const bool fuse_residual_conn) const{
const bool fuse_residual_conn, bool is_test) const{
memory::dims stride_dims = {strides[0], strides[1]};
memory::dims padding_dims = {paddings[0], paddings[1]};
auto propagation = is_test ? mkldnn::prop_kind::forward_scoring : mkldnn::prop_kind::forward_training;
auto conv_desc = mkldnn::convolution_forward::desc(
mkldnn::prop_kind::forward_scoring, mkldnn::convolution_direct, src, weights,
propagation, mkldnn::convolution_direct, src, weights,
bias, dst, stride_dims, padding_dims, padding_dims,
mkldnn::padding_kind::zero);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册