未验证 提交 8bcb1f29 编写于 作者: A Adam 提交者: GitHub

Add conv+affine_channel fuse pass to MKLDNN pass strategy and fix it (#26779)

上级 2675cae7
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h" #include "paddle/fluid/framework/ir/conv_affine_channel_fuse_pass.h"
#include <cmath>
#include <functional> #include <functional>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -74,12 +75,17 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight, ...@@ -74,12 +75,17 @@ void recompute_bias_and_weights(const Scope* scope, ir::Node* conv_weight,
auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>(); auto* weights = scope->FindVar(conv_weight->Name())->GetMutable<LoDTensor>();
auto weights_shape = weights->dims(); auto weights_shape = weights->dims();
auto weights_shape_2d = flatten_to_2d(weights_shape, 1); auto weights_shape_2d = flatten_to_2d(weights_shape, 1);
auto* weights_data = weights->mutable_data<float>(platform::CPUPlace());
EigenMatrixArrayMap weights_array_2d( EigenMatrixArrayMap weights_array_2d(weights_data, weights_shape_2d[0],
weights->mutable_data<float>(platform::CPUPlace()), weights_shape_2d[0],
weights_shape_2d[1]); weights_shape_2d[1]);
weights_array_2d.colwise() *= scale_array; weights_array_2d.colwise() *= scale_array;
// Check for subnormal values that slows down convolution execution
for (int i = 0; i < weights->numel(); ++i) {
if (std::fpclassify(weights_data[i]) == FP_SUBNORMAL) weights_data[i] = 0;
}
} }
void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
...@@ -108,13 +114,6 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -108,13 +114,6 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
GET_CONV_BN_NODES(conv_ac_pattern); GET_CONV_BN_NODES(conv_ac_pattern);
// check if fuse can be done and if MKL-DNN should be used
FuseOptions fuse_option = FindFuseOption(*conv, *affine_channel);
if (fuse_option == DO_NOT_FUSE) {
VLOG(3) << "do not perform conv+affinechannel fuse";
return;
}
// Create eltwise_y (conv bias) variable // Create eltwise_y (conv bias) variable
VarDesc eltwise_y_in_desc( VarDesc eltwise_y_in_desc(
patterns::PDNodeName(name_scope_, "eltwise_y_in")); patterns::PDNodeName(name_scope_, "eltwise_y_in"));
...@@ -143,6 +142,7 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const { ...@@ -143,6 +142,7 @@ void ConvAffineChannelFusePass::ApplyImpl(ir::Graph* graph) const {
desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()})); desc.SetOutput("Out", std::vector<std::string>({ac_out->Name()}));
desc.SetType("elementwise_add"); desc.SetType("elementwise_add");
desc.SetAttr("axis", 1); desc.SetAttr("axis", 1);
desc.SetAttr("use_mkldnn", conv->Op()->GetAttrIfExists<bool>("use_mkldnn"));
auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied. auto eltwise_op = g->CreateOpNode(&desc); // OpDesc will be copied.
GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel}); GraphSafeRemoveNodes(graph, {ac_scale, ac_bias, affine_channel});
......
...@@ -188,6 +188,8 @@ void CpuPassStrategy::EnableMKLDNN() { ...@@ -188,6 +188,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"depthwise_conv_mkldnn_pass", // "depthwise_conv_mkldnn_pass", //
"conv_bn_fuse_pass", // Execute BN passes again to "conv_bn_fuse_pass", // Execute BN passes again to
"conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order "conv_eltwiseadd_bn_fuse_pass", // preserve correct pass order
"conv_affine_channel_fuse_pass", //
"conv_eltwiseadd_affine_channel_fuse_pass", //
"conv_transpose_bn_fuse_pass", // "conv_transpose_bn_fuse_pass", //
"conv_transpose_eltwiseadd_bn_fuse_pass", // "conv_transpose_eltwiseadd_bn_fuse_pass", //
"conv_bias_mkldnn_fuse_pass", // "conv_bias_mkldnn_fuse_pass", //
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册