Unverified commit b7a23adb, authored by Sławomir Siwek, committed by GitHub

FC + activation fuse passes (#45183)

* git

* style

* leave default relu in kernel

* style

* cleanup FCMKLDNN pattern

* merge conflicts

* update develop

* update develop

* add const

* rename to oneDNN and adjust attributes

* whitespace
Parent da051350
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
......@@ -14,9 +14,8 @@
#include "paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/mkldnn_reuse.h"
#include "paddle/fluid/string/pretty_log.h"
namespace paddle {
......@@ -26,20 +25,20 @@ namespace ir {
using string::PrettyLogDetail;
void FuseFCActOneDNNPass::ApplyImpl(Graph *graph) const {
std::vector<std::string> act_types = {
"gelu", "tanh", "sigmoid", "mish", "hard_swish"};
auto act_types = paddle::platform::GetSupportedActivations();
for (std::string act_type : act_types) FuseFCAct(graph, act_type);
for (auto act_type : act_types) FuseFCAct(graph, act_type);
}
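// For reference, a minimal sketch of the list that
// paddle::platform::GetSupportedActivations() is assumed to return; it mirrors
// the operators registered for this pass further below. The actual helper
// lives in mkldnn_reuse.h and may differ.
#include <string>
#include <vector>
static std::vector<std::string> GetSupportedActivationsSketch() {
  return {"abs",        "clip",  "gelu",  "hard_sigmoid", "hard_swish",
          "leaky_relu", "mish",  "relu",  "relu6",        "sigmoid",
          "sqrt",       "swish", "tanh"};
}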
void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
const std::string &act_type) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
FusePassBase::Init("fc_act", graph);
FusePassBase::Init("fc_" + act_type + "_mkldnn_fuse_pass", graph);
GraphPatternDetector gpd;
patterns::OperatorActivation fc_act_pattern(gpd.mutable_pattern(), "fc_act");
patterns::OperatorActivation fc_act_pattern(
gpd.mutable_pattern(), "fc_" + act_type + "_mkldnn_fuse_pass");
fc_act_pattern("fc", act_type);
int found_fc_act_count = 0;
......@@ -62,15 +61,23 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
"is used."));
}
auto attr_map = paddle::platform::GetAttributeMap(act_type);
for (const auto &attr : attr_map) {
if (act_op->HasAttr(attr.first)) {
fc_op->SetAttr(attr.second, act_op->GetAttr(attr.first));
}
}
if (act_type == "gelu" && act_op->HasAttr("approximate")) {
bool approximate = PADDLE_GET_CONST(bool, act_op->GetAttr("approximate"));
std::string type = approximate ? "_tanh" : "_erf";
fc_op->SetAttr("activation_type", act_type + type);
std::string gelu_act_type =
PADDLE_GET_CONST(bool, act_op->GetAttr("approximate")) ? "gelu_tanh"
: "gelu_erf";
fc_op->SetAttr("fuse_activation", gelu_act_type);
} else {
fc_op->SetAttr("activation_type", act_type);
fc_op->SetAttr("fuse_activation", act_type);
}
fc_op->SetAttr("use_mkldnn", true);
fc_op->SetAttr("use_mkldnn", true);
fc_op->SetOutput("Out", {act_out->Name()});
IR_OP_VAR_LINK(fc, act_out);
......@@ -80,7 +87,8 @@ void FuseFCActOneDNNPass::FuseFCAct(Graph *graph,
gpd(graph, handler);
AddStatis(found_fc_act_count);
if (!Has("disable_logs") || !Get<bool>("disable_logs"))
if ((!Has("disable_logs") || !Get<bool>("disable_logs")) &&
found_fc_act_count > 0)
PrettyLogDetail(
"--- fused %d fc with %s activation", found_fc_act_count, act_type);
}
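// As an illustration of the attribute remapping above, a hedged sketch of what
// paddle::platform::GetAttributeMap(act_type) might return: activation-specific
// attribute names mapped to the generic fused names read by the oneDNN FC
// kernel. The concrete pairs below are assumptions, not the verbatim helper.
#include <string>
#include <unordered_map>
static std::unordered_map<std::string, std::string> GetAttributeMapSketch(
    const std::string &act_type) {
  if (act_type == "clip")
    return {{"min", "fuse_alpha"}, {"max", "fuse_beta"}};
  if (act_type == "leaky_relu") return {{"alpha", "fuse_alpha"}};
  if (act_type == "swish") return {{"beta", "fuse_alpha"}};
  if (act_type == "relu6") return {{"threshold", "fuse_alpha"}};
  return {};  // activations without tunable parameters
}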
......@@ -95,8 +103,16 @@ REGISTER_PASS_CAPABILITY(fc_act_mkldnn_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("fc", 0)
.LE("gelu", 0)
.LE("sigmoid", 0)
.LE("mish", 1)
.EQ("abs", 0)
.LE("clip", 1)
.EQ("gelu", 0)
.EQ("hard_sigmoid", 0)
.LE("hard_swish", 0)
.LE("tanh", 0));
.LE("leaky_relu", 1)
.LE("mish", 1)
.EQ("relu", 0)
.EQ("relu6", 0)
.EQ("sigmoid", 0)
.EQ("sqrt", 0)
.EQ("swish", 0)
.EQ("tanh", 0));
......@@ -23,21 +23,14 @@ namespace paddle {
namespace framework {
namespace ir {
/*
 * \brief Fuse the FC and activation operators into a single oneDNN
 * FC operator with an activation post-op.
 *
 * \note Currently only GeLU, hard_swish, sigmoid, mish and tanh are supported
 * as the activation function.
*/
class FuseFCActOneDNNPass : public FusePassBase {
public:
virtual ~FuseFCActOneDNNPass() {}
protected:
void ApplyImpl(ir::Graph *graph) const override;
void ApplyImpl(Graph *graph) const override;
void FuseFCAct(ir::Graph *graph, const std::string &act_types) const;
void FuseFCAct(Graph *graph, const std::string &act_types) const;
};
} // namespace ir
......
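// For context, a minimal usage sketch showing how a registered pass such as
// this one is usually fetched and applied to a graph. It assumes the
// framework's existing PassRegistry/Pass API and is not part of this diff.
#include <memory>
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/pass.h"
// Sketch only: apply the pass by name. Assumes the pass library is linked,
// e.g. via USE_PASS(fc_act_mkldnn_fuse_pass).
void ApplyFcActFusePassSketch(
    std::unique_ptr<paddle::framework::ir::Graph> *graph) {
  auto pass = paddle::framework::ir::PassRegistry::Instance().Get(
      "fc_act_mkldnn_fuse_pass");
  graph->reset(pass->Apply(graph->release()));
}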
......@@ -34,12 +34,12 @@ TEST(FuseFCActOneDNNPass, ThrowUseMkldnn) {
"fc",
{
{"Input", "x"},
{"Weights", "weights"},
{"W", "weights"},
{"Bias", "bias"},
},
{{"Out", "fc_y"}},
false);
test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
// No fusion in this attribute configuration
......@@ -58,12 +58,12 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
"fc",
{
{"Input", "x"},
{"Weights", "weights"},
{"W", "weights"},
{"Bias", "bias"},
},
{{"Out", "fc_y"}});
auto* act_op = test::CreateOp(
&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
auto* act_op =
test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
act_op->SetAttr("approximate", true);
Graph graph(prog);
......@@ -78,9 +78,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluTanh) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("gelu_tanh"), 0);
}
}
......@@ -93,12 +93,12 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
"fc",
{
{"Input", "x"},
{"Weights", "weights"},
{"W", "weights"},
{"Bias", "bias"},
},
{{"Out", "fc_y"}});
auto* act_op = test::CreateOp(
&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
auto* act_op =
test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
act_op->SetAttr("approximate", false);
Graph graph(prog);
......@@ -113,9 +113,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluErf) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("gelu_erf"), 0);
}
}
......@@ -128,11 +128,11 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
"fc",
{
{"Input", "x"},
{"Weights", "weights"},
{"W", "weights"},
{"Bias", "bias"},
},
{{"Out", "fc_y"}});
test::CreateOp(&prog, "gelu", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "gelu", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
......@@ -146,9 +146,9 @@ TEST(FuseFCActOneDNNPass, FuseWithGeluAuto) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("gelu"), 0);
}
}
......@@ -161,11 +161,11 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
"fc",
{
{"Input", "x"},
{"Weights", "weights"},
{"W", "weights"},
{"Bias", "bias"},
},
{{"Out", "fc_y"}});
test::CreateOp(&prog, "tanh", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "tanh", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
......@@ -179,9 +179,9 @@ TEST(FuseFCActOneDNNPass, FuseWithTanh) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("tanh"), 0);
}
}
......@@ -194,12 +194,11 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
"fc",
{
{"Input", "x"},
{"Weights", "weights"},
{"W", "weights"},
{"Bias", "bias"},
},
{{"Out", "fc_y"}});
test::CreateOp(
&prog, "sigmoid", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "sigmoid", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
......@@ -213,9 +212,9 @@ TEST(FuseFCActOneDNNPass, FuseWithSigmoid) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("sigmoid"), 0);
}
}
......@@ -228,11 +227,11 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
"fc",
{
{"Input", "x"},
{"Weights", "weights"},
{"W", "weights"},
{"Bias", "bias"},
},
{{"Out", "fc_y"}});
test::CreateOp(&prog, "mish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
test::CreateOp(&prog, "mish", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
......@@ -246,9 +245,9 @@ TEST(FuseFCActOneDNNPass, FuseWithMish) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("mish"), 0);
}
}
......@@ -261,12 +260,12 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
"fc",
{
{"Input", "x"},
{"Weights", "weights"},
{"W", "weights"},
{"Bias", "bias"},
},
{{"Out", "fc_y"}});
test::CreateOp(
&prog, "hard_swish", {{"Input", "fc_y"}}, {{"Out", "act_y"}}, false);
&prog, "hard_swish", {{"X", "fc_y"}}, {{"Out", "act_y"}}, false);
Graph graph(prog);
constexpr int removed_nodes_count = 2;
......@@ -280,9 +279,9 @@ TEST(FuseFCActOneDNNPass, FuseWithHardSwish) {
const auto* op = node->Op();
ASSERT_TRUE(op->HasAttr("use_mkldnn"));
EXPECT_TRUE(PADDLE_GET_CONST(bool, op->GetAttr("use_mkldnn")));
ASSERT_TRUE(op->HasAttr("activation_type"));
ASSERT_TRUE(op->HasAttr("fuse_activation"));
auto act_type =
PADDLE_GET_CONST(std::string, op->GetAttr("activation_type"));
PADDLE_GET_CONST(std::string, op->GetAttr("fuse_activation"));
EXPECT_EQ(act_type.compare("hard_swish"), 0);
}
}
......
......@@ -242,7 +242,7 @@ bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) {
if (input_compats_.find(input_desc.first) == input_compats_.end()) {
if (!input_desc.second.empty()) {
LOG(WARNING) << "The Input (" << input_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
<< op_name_ << ") not registered in OpCompat!";
return false;
}
}
......@@ -269,7 +269,7 @@ bool OpCompat::Judge(const OpDesc& op_desc, const std::string& pass_name) {
if (output_compats_.find(output_desc.first) == output_compats_.end()) {
if (!output_desc.second.empty()) {
LOG(WARNING) << "The Output (" << output_desc.first << ") of Operator ("
<< op_name_ << ") not reigistered in OpCompat!";
<< op_name_ << ") not registered in OpCompat!";
return false;
}
}
......
......@@ -87,8 +87,7 @@ class FCMKLDNNHandler
dnnl::memory::format_tag::a);
}
dnnl::primitive_attr attrs;
HandlePostOps(ctx, &attrs);
const auto attrs = CreateFCAttrs(ctx);
this->AcquireForwardPrimitiveDescriptor(attrs,
prop_kind::forward_inference,
......@@ -99,44 +98,33 @@ class FCMKLDNNHandler
}
private:
void HandlePostOps(const paddle::framework::ExecutionContext& ctx,
dnnl::primitive_attr* attrs) {
static std::unordered_map<std::string, dnnl::algorithm> algo_map = {
{"relu", dnnl::algorithm::eltwise_relu},
{"gelu", dnnl::algorithm::eltwise_gelu},
{"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
{"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
{"tanh", dnnl::algorithm::eltwise_tanh},
{"sigmoid", dnnl::algorithm::eltwise_logistic},
{"hard_swish", dnnl::algorithm::eltwise_hardswish},
{"mish", dnnl::algorithm::eltwise_mish}};
dnnl::primitive_attr CreateFCAttrs(const ExecutionContext& ctx) {
dnnl::primitive_attr attributes;
dnnl::post_ops post_operations;
std::vector<float> output_shift_scale;
float scale = 1.0f;
if (IsInt8<T_w>()) {
std::tie(output_shift_scale, scale) = ComputeOutputShiftScale(ctx);
int mask = CreateMask(1, output_shift_scale.size() > 1);
attrs->set_output_scales(mask, output_shift_scale);
attributes.set_output_scales(mask, output_shift_scale);
}
dnnl::post_ops post_ops;
constexpr float sum_scale = 1.0f;
float sum_scale = 1.0f;
if (ctx.HasAttr("fuse_residual_connection") &&
ctx.Attr<bool>("fuse_residual_connection")) {
post_ops.append_sum(sum_scale);
post_operations.append_sum(sum_scale);
}
std::string activation_type = ctx.Attr<std::string>("activation_type");
if (activation_type.empty() == false) {
constexpr float alpha = 0.0f;
constexpr float beta = 0.0f;
post_ops.append_eltwise(scale, algo_map[activation_type], alpha, beta);
// ReLU from "fc_fuse_pass"
if (ctx.Attr<std::string>("activation_type") == "relu") {
post_operations.append_eltwise(
scale, dnnl::algorithm::eltwise_relu, 0.0f, 0.0f);
}
platform::AppendActivation(ctx, post_operations, scale);
attrs->set_post_ops(post_ops);
attributes.set_post_ops(post_operations);
return attributes;
}
// Compute the bias scales so that its values correspond to the
......
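// For readers unfamiliar with the new helper, a simplified sketch of what
// platform::AppendActivation is assumed to do, based on the algo_map removed
// above; the real function in mkldnn_reuse.h handles more activations and
// attribute defaults.
void AppendActivationSketch(const paddle::framework::ExecutionContext& ctx,
                            dnnl::post_ops& post_operations,
                            float activation_scale = 1.0f) {
  // Sketch only (assumption): append the fused activation as a oneDNN eltwise
  // post-op, reading the attributes written by the fuse passes.
  if (!ctx.HasAttr("fuse_activation")) return;
  const auto act_type = ctx.Attr<std::string>("fuse_activation");
  if (act_type.empty()) return;

  static const std::unordered_map<std::string, dnnl::algorithm> algo_map = {
      {"relu", dnnl::algorithm::eltwise_relu},
      {"gelu", dnnl::algorithm::eltwise_gelu},
      {"gelu_tanh", dnnl::algorithm::eltwise_gelu_tanh},
      {"gelu_erf", dnnl::algorithm::eltwise_gelu_erf},
      {"tanh", dnnl::algorithm::eltwise_tanh},
      {"sigmoid", dnnl::algorithm::eltwise_logistic},
      {"hard_swish", dnnl::algorithm::eltwise_hardswish},
      {"mish", dnnl::algorithm::eltwise_mish}};

  const float alpha =
      ctx.HasAttr("fuse_alpha") ? ctx.Attr<float>("fuse_alpha") : 0.0f;
  const float beta =
      ctx.HasAttr("fuse_beta") ? ctx.Attr<float>("fuse_beta") : 0.0f;
  post_operations.append_eltwise(
      activation_scale, algo_map.at(act_type), alpha, beta);
}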
......@@ -226,7 +226,8 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT
300)
set_tests_properties(test_mkldnn_conv_mish_fuse_pass PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_fc_mish_fuse_pass PROPERTIES TIMEOUT 300)
set_tests_properties(test_onednn_fc_activation_fuse_pass PROPERTIES TIMEOUT
300)
set_tests_properties(test_mkldnn_fc_elementwise_add_fuse_pass
PROPERTIES TIMEOUT 120)
set_tests_properties(test_mkldnn_conv_affine_channel_fuse_pass
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig
import numpy as np
import unittest
import hypothesis.strategies as st
class TestFCMishMkldnnFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
x_shape = draw(
st.lists(st.integers(min_value=1, max_value=128),
min_size=2,
max_size=3))
in_num_col_dims = len(x_shape) - 1
w_shape = draw(
st.lists(st.integers(min_value=1, max_value=128),
min_size=2,
max_size=2))
w_shape[0] = int(np.prod(x_shape[in_num_col_dims:]))
fc_bias_shape = [w_shape[1]]
ops_config = [{
"op_type": "fc",
"op_inputs": {
"Input": ["fc_x"],
"W": ["fc_w"],
"Bias": ["fc_bias"]
},
"op_outputs": {
"Out": ["fc_out"]
},
"op_attrs": {
"activation_type": "",
"padding_weights": False,
"in_num_col_dims": in_num_col_dims,
"use_mkldnn": True
}
}, {
"op_type": "mish",
"op_inputs": {
"X": ["fc_out"]
},
"op_outputs": {
"Out": ["mish_output"]
},
"op_attrs": {},
}]
ops = self.generate_op_config(ops_config)
program_config = ProgramConfig(ops=ops,
weights={
"fc_w":
TensorConfig(shape=w_shape),
"fc_bias":
TensorConfig(shape=fc_bias_shape),
},
inputs={
"fc_x": TensorConfig(shape=x_shape),
},
outputs=["mish_output"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True, passes=["fc_act_mkldnn_fuse_pass"])
yield config, ["fc"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False, passes=["fc_act_mkldnn_fuse_pass"])
if __name__ == "__main__":
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
from functools import partial
import unittest
import hypothesis.strategies as st
class TestFCActivationOneDNNFusePass(PassAutoScanTest):
def sample_program_config(self, draw):
fc_in = draw(st.sampled_from([32, 64]))
fc_wei = draw(st.sampled_from([64]))
activation_type = draw(
st.sampled_from([
'relu', 'gelu', 'swish', 'mish', 'sqrt', 'hard_swish',
'sigmoid', 'abs', 'relu6', 'clip', 'tanh', 'hard_sigmoid',
'leaky_relu'
]))
def generate_input(shape):
return np.random.random(shape).astype(np.float32)
fc_op = OpConfig(type="fc",
inputs={
"Input": ["fc_input"],
"W": ["fc_weight"],
"Bias": ["fc_bias"]
},
outputs={"Out": ["fc_output"]},
attrs={
"use_mkldnn": True,
"padding_weights": False,
"in_num_col_dims": 1,
})
if activation_type == "clip":
activation_op = OpConfig(
activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
min=draw(st.floats(min_value=0.1, max_value=0.49)),
max=draw(st.floats(min_value=0.5, max_value=1.0)))
elif activation_type == "gelu":
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
approximate=draw(st.booleans()))
elif activation_type == "leaky_relu":
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
alpha=draw(
st.floats(min_value=0.1,
max_value=1.0)))
elif activation_type == "relu6":
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
threshold=6)
elif activation_type == "swish":
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]},
beta=draw(
st.floats(min_value=0.1,
max_value=10.0)))
else:
activation_op = OpConfig(activation_type,
inputs={"X": ["fc_output"]},
outputs={"Out": ["activation_output"]})
model_net = [fc_op, activation_op]
program_config = ProgramConfig(
ops=model_net,
weights={
"fc_weight":
TensorConfig(
data_gen=partial(generate_input, [fc_wei, fc_wei])),
"fc_bias":
TensorConfig(data_gen=partial(generate_input, [fc_wei])),
},
inputs={
"fc_input":
TensorConfig(data_gen=partial(generate_input, [fc_in, fc_wei]))
},
outputs=["activation_output"])
return program_config
def sample_predictor_configs(self, program_config):
config = self.create_inference_config(
use_mkldnn=True, passes=["fc_act_mkldnn_fuse_pass"])
yield config, ["fc"], (1e-5, 1e-5)
def test(self):
self.run_and_statis(quant=False, passes=["fc_act_mkldnn_fuse_pass"])
if __name__ == "__main__":
unittest.main()