未验证 提交 d9d47dc6 编写于 作者: P Paulina Gacek 提交者: GitHub

Rewrite mat reshape transpose testers (#49580)

* reshape_transpose_matmul_pass_tester rewritten

* matmul_transpose_reshape_pass_tester rewritten

* mkldnn to onednn
上级 f287b1e9
...@@ -460,14 +460,6 @@ if(WITH_MKLDNN) ...@@ -460,14 +460,6 @@ if(WITH_MKLDNN)
test_cpu_quantize_squash_pass test_cpu_quantize_squash_pass
SRCS mkldnn/cpu_quantize_squash_pass_tester.cc SRCS mkldnn/cpu_quantize_squash_pass_tester.cc
DEPS cpu_quantize_squash_pass naive_executor) DEPS cpu_quantize_squash_pass naive_executor)
cc_test(
test_reshape_transpose_matmul_mkldnn_fuse_pass
SRCS mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass_tester.cc
DEPS reshape_transpose_matmul_mkldnn_fuse_pass)
cc_test(
test_matmul_transpose_reshape_fuse_pass
SRCS mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass_tester.cc
DEPS matmul_transpose_reshape_mkldnn_fuse_pass)
cc_test( cc_test(
test_shuffle_channel_mkldnn_detect_pass test_shuffle_channel_mkldnn_detect_pass
SRCS mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc SRCS mkldnn/shuffle_channel_mkldnn_detect_pass_tester.cc
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/matmul_transpose_reshape_mkldnn_fuse_pass.h"
namespace paddle {
namespace framework {
namespace ir {
void SetOp(ProgramDesc *prog,
const std::string &type,
const std::vector<std::string> &inputs,
const std::vector<std::string> &outputs) {
auto *op = prog->MutableBlock(0)->AppendOp();
op->SetType(type);
op->SetInput("X", {inputs[0]});
op->SetOutput("Out", {outputs[0]});
if (type == "transpose2") {
op->SetAttr("axis", std::vector<int>({0, 2, 1, 3}));
op->SetOutput("XShape", {outputs[1]});
}
if (type == "reshape2") {
op->SetAttr("shape", std::vector<int>({4, 5, 6}));
op->SetOutput("XShape", {outputs[1]});
}
if (type == "matmul") {
op->SetInput("Y", {inputs[1]});
op->SetAttr("use_mkldnn", true);
op->SetAttr("alpha", 1.0f);
op->SetAttr("transpose_X", true);
op->SetAttr("transpose_Y", true);
}
if (type == "matmul_v2") {
op->SetInput("Y", {inputs[1]});
op->SetAttr("use_mkldnn", true);
op->SetAttr("trans_x", true);
op->SetAttr("trans_y", true);
}
}
ProgramDesc BuildProgramDesc(const std::string &op_name) {
ProgramDesc prog;
for (auto &v : std::initializer_list<std::string>(
{"a1", "a2", "b", "c", "cx", "d", "dx", "e"})) {
auto *var = prog.MutableBlock(0)->Var(v);
var->SetType(proto::VarType::SELECTED_ROWS);
}
SetOp(&prog, op_name, {"a1", "a2"}, {"b"});
SetOp(&prog, "transpose2", {"b"}, {"c", "cx"});
SetOp(&prog, "reshape2", {"c"}, {"d", "dx"});
SetOp(&prog, "fc", {"d"}, {"e"});
return prog;
}
void MainTest(const ProgramDesc &prog, const std::string &op_name) {
std::unique_ptr<ir::Graph> graph(new ir::Graph(prog));
int original_nodes_num = graph->Nodes().size();
auto pass =
PassRegistry::Instance().Get("matmul_transpose_reshape_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int current_nodes_num = graph->Nodes().size();
EXPECT_EQ(original_nodes_num - 6, current_nodes_num);
for (auto *node : graph->Nodes()) {
if (node->IsOp()) {
auto *op = node->Op();
if (op->Type() == op_name) {
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_reshape_Out"),
std::vector<int>({4, 5, 6}));
EXPECT_EQ(op->GetAttrIfExists<std::vector<int>>("fused_transpose_Out"),
std::vector<int>({0, 2, 1, 3}));
}
}
}
}
TEST(MatmulTransposeReshapeFusePass, matmul_fuse_pass) {
auto prog = BuildProgramDesc("matmul");
MainTest(prog, "matmul");
}
TEST(MatmulTransposeReshapeFusePass, matmul_v2_fuse_pass) {
auto prog = BuildProgramDesc("matmul_v2");
MainTest(prog, "matmul_v2");
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(matmul_transpose_reshape_mkldnn_fuse_pass);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/ir/mkldnn/reshape_transpose_matmul_mkldnn_fuse_pass.h"
#include "paddle/fluid/framework/ir/pass_tester_helper.h"
namespace paddle {
namespace framework {
namespace ir {
void AddVarToScope(Scope* param_scope,
const std::string& name,
const DDim& dims) {
auto* tensor = param_scope->Var(name)->GetMutable<phi::DenseTensor>();
tensor->Resize(dims);
tensor->mutable_data<float>(phi::CPUPlace());
}
Scope* CreateParamScope() {
auto param_scope = new Scope();
AddVarToScope(param_scope, "w1", {768, 768});
AddVarToScope(param_scope, "bias1", {768});
AddVarToScope(param_scope, "w2", {768, 768});
AddVarToScope(param_scope, "bias2", {768});
return param_scope;
}
void TestMain(const std::string& op_name, bool with_xshapes) {
// inputs operator output
// -----------------------------------------------
// a1,w1,bias1 fc -> b1
// b1 reshape -> c1
// c1 transpose -> d1
// a2,w2,bias2 fc -> b2
// b2 reshape -> c2
// c2 transpose -> d2
// (d1, d2) matmul(_v2) -> (...)
Layers layers;
auto* a1 = layers.data("a1", {-1, 128, 768});
auto* w1 = layers.data("w1", {768, 768}, true);
auto* bias1 = layers.data("bias1", {768}, true);
auto* b1 = layers.fc(a1, w1, bias1, 2);
b1->SetShape({-1, 128, 768});
auto* c1 = layers.reshape2(b1, {0, 0, 12, 64}, with_xshapes);
c1->SetShape({-1, 128, 12, 64});
auto* d1 = layers.transpose2(c1, {0, 2, 1, 3}, with_xshapes);
d1->SetShape({-1, 12, 128, 64});
auto* a2 = layers.data("a2", {-1, 128, 768});
auto* w2 = layers.data("w2", {768, 768}, true);
auto* bias2 = layers.data("bias2", {768}, true);
auto* b2 = layers.fc(a2, w2, bias2, 2);
b2->SetShape({-1, 128, 768});
auto* c2 = layers.reshape2(b2, {0, 0, 12, 64});
c2->SetShape({-1, 128, 12, 64});
auto* d2 = layers.transpose2(c2, {0, 2, 1, 3});
d2->SetShape({-1, 12, 128, 64});
if (op_name == "matmul_v2") {
layers.matmul_v2(d1, d2);
} else {
layers.matmul(d1, d2);
}
std::unique_ptr<ir::Graph> graph(new ir::Graph(layers.main_program()));
graph->Set("__param_scope__", CreateParamScope());
int num_reshape_nodes_before = GetNumOpNodes(graph, "reshape2");
int num_transpose_nodes_before = GetNumOpNodes(graph, "transpose2");
int total_nodes_before = graph->Nodes().size();
VLOG(3) << DebugString(graph);
auto pass =
PassRegistry::Instance().Get("reshape_transpose_matmul_mkldnn_fuse_pass");
graph.reset(pass->Apply(graph.release()));
int num_reshape_nodes_after = GetNumOpNodes(graph, "reshape2");
int num_transpose_nodes_after = GetNumOpNodes(graph, "transpose2");
int total_nodes_after = graph->Nodes().size();
VLOG(3) << DebugString(graph);
EXPECT_EQ(num_reshape_nodes_before, 2);
EXPECT_EQ(num_reshape_nodes_after, 0);
EXPECT_EQ(num_transpose_nodes_before, 2);
EXPECT_EQ(num_transpose_nodes_after, 0);
int removed = 8; // 2* reshape, reshape_out, transpose, transpose_out
if (with_xshapes) removed += 2; // transpose_xshape, reshape_xshape
EXPECT_EQ(total_nodes_before - removed, total_nodes_after);
auto* matmul_op_desc = GetOpNodes(graph, op_name).at(0)->Op();
auto check = [&matmul_op_desc](std::string a) {
std::string shape_str = "fused_reshape_" + a;
auto shape = matmul_op_desc->GetAttrIfExists<std::vector<int>>(shape_str);
EXPECT_EQ(shape, (std::vector<int>{0, 0, 12, 64}));
std::string axis_str = "fused_transpose_" + a;
auto axis = matmul_op_desc->GetAttrIfExists<std::vector<int>>(axis_str);
EXPECT_EQ(axis, (std::vector<int>{0, 2, 1, 3}));
};
check("X");
check("Y");
}
TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose) {
TestMain("matmul", false);
}
TEST(ReshapeTransposeMatmulMkldnnFusePass,
both_matmul_inputs_reshape_transpose_one_with_xshapes) {
TestMain("matmul", true);
}
TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
both_matmulv2_inputs_reshape_transpose) {
TestMain("matmul_v2", false);
}
TEST(ReshapeTransposeMatmulV2MkldnnFusePass,
both_matmulv2_inputs_reshape_transpose_one_with_xshapes) {
TestMain("matmul_v2", true);
}
} // namespace ir
} // namespace framework
} // namespace paddle
USE_PASS(reshape_transpose_matmul_mkldnn_fuse_pass);
...@@ -232,7 +232,7 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -232,7 +232,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_mkldnn_conv_elementwise_add_fuse_pass set_tests_properties(test_mkldnn_conv_elementwise_add_fuse_pass
PROPERTIES TIMEOUT 120) PROPERTIES TIMEOUT 120)
set_tests_properties(test_mkldnn_depthwise_conv_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_mkldnn_depthwise_conv_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_mkldnn_reshape_transpose_matmul_fuse_pass set_tests_properties(test_onednn_reshape_transpose_matmul_fuse_pass
PROPERTIES TIMEOUT 100) PROPERTIES TIMEOUT 100)
set_tests_properties(test_mkldnn_mish_op PROPERTIES TIMEOUT 300) set_tests_properties(test_mkldnn_mish_op PROPERTIES TIMEOUT 300)
set_tests_properties(test_mkldnn_conv3d_op PROPERTIES TIMEOUT 300) set_tests_properties(test_mkldnn_conv3d_op PROPERTIES TIMEOUT 300)
...@@ -240,7 +240,7 @@ if(WITH_GPU AND TENSORRT_FOUND) ...@@ -240,7 +240,7 @@ if(WITH_GPU AND TENSORRT_FOUND)
set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120) set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_transpose_eltwiseadd_bn_fuse_pass set_tests_properties(test_conv_transpose_eltwiseadd_bn_fuse_pass
PROPERTIES TIMEOUT 250) PROPERTIES TIMEOUT 250)
set_tests_properties(test_mkldnn_matmul_transpose_reshape_fuse_pass set_tests_properties(test_onednn_matmul_transpose_reshape_fuse_pass
PROPERTIES TIMEOUT 100) PROPERTIES TIMEOUT 100)
set_tests_properties(test_conv_transpose_bn_fuse_pass PROPERTIES TIMEOUT set_tests_properties(test_conv_transpose_bn_fuse_pass PROPERTIES TIMEOUT
300) 300)
......
...@@ -21,7 +21,7 @@ from auto_scan_test import PassAutoScanTest ...@@ -21,7 +21,7 @@ from auto_scan_test import PassAutoScanTest
from program_config import OpConfig, ProgramConfig, TensorConfig from program_config import OpConfig, ProgramConfig, TensorConfig
class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest): class TestOneDNNMatmulTransposeReshapeFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
attrs = [ attrs = [
program_config.ops[i].attrs for i in range(len(program_config.ops)) program_config.ops[i].attrs for i in range(len(program_config.ops))
...@@ -57,42 +57,42 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest): ...@@ -57,42 +57,42 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest):
shape_x = [batch_size, channel, 32, input_dim] shape_x = [batch_size, channel, 32, input_dim]
shape_y = [batch_size, channel, input_dim, 16] shape_y = [batch_size, channel, input_dim, 16]
if type == "x": if type == 'x':
return np.random.random(shape_x).astype(np.float32) return np.random.random(shape_x).astype(np.float32)
else: else:
return np.random.random(shape_y).astype(np.float32) return np.random.random(shape_y).astype(np.float32)
matmul_op = OpConfig( matmul_op = OpConfig(
type="matmul", type='matmul',
inputs={"X": ["input_data1"], "Y": ["input_data2"]}, inputs={'X': ['input_data1'], 'Y': ['input_data2']},
outputs={"Out": ["matmul_output"]}, outputs={'Out': ['matmul_output']},
attrs={ attrs={
"transpose_X": transpose_X, 'transpose_X': transpose_X,
"transpose_Y": transpose_Y, 'transpose_Y': transpose_Y,
"alpha": alpha, 'alpha': alpha,
"fused_reshape_X": [], 'fused_reshape_X': [],
"fused_reshape_Y": [], 'fused_reshape_Y': [],
"fused_transpose_X": [], 'fused_transpose_X': [],
"fused_transpose_Y": [], 'fused_transpose_Y': [],
"fused_reshape_Out": [], 'fused_reshape_Out': [],
"fused_transpose_Out": [], 'fused_transpose_Out': [],
}, },
) )
transpose2_op = OpConfig( transpose2_op = OpConfig(
type="transpose2", type='transpose2',
inputs={"X": ["matmul_output"]}, inputs={'X': ['matmul_output']},
outputs={ outputs={
"Out": ["transpose2_output"], 'Out': ['transpose2_output'],
"XShape": ["transpose2_xshape"], 'XShape': ['transpose2_xshape'],
}, },
attrs={'axis': axis}, attrs={'axis': axis},
) )
reshape2_op = OpConfig( reshape2_op = OpConfig(
type="reshape2", type='reshape2',
inputs={"X": ["transpose2_output"]}, inputs={'X': ['transpose2_output']},
outputs={"Out": ["reshape2_output"], "XShape": ["reshape2_xshape"]}, outputs={'Out': ['reshape2_output'], 'XShape': ['reshape2_xshape']},
attrs={'shape': shape}, attrs={'shape': shape},
) )
...@@ -102,27 +102,27 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest): ...@@ -102,27 +102,27 @@ class TestMatmulTransposeReshapeMkldnnFusePass(PassAutoScanTest):
ops=model_net, ops=model_net,
weights={}, weights={},
inputs={ inputs={
"input_data1": TensorConfig( 'input_data1': TensorConfig(
data_gen=partial(generate_input, "x") data_gen=partial(generate_input, 'x')
), ),
"input_data2": TensorConfig( 'input_data2': TensorConfig(
data_gen=partial(generate_input, "y") data_gen=partial(generate_input, 'y')
), ),
}, },
outputs=["reshape2_output"], outputs=['reshape2_output'],
) )
return program_config return program_config
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ["matmul"], (1e-5, 1e-5) yield config, ['matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
quant=False, passes=["matmul_transpose_reshape_mkldnn_fuse_pass"] quant=False, passes=['matmul_transpose_reshape_mkldnn_fuse_pass']
) )
if __name__ == "__main__": if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -20,10 +20,11 @@ import numpy as np ...@@ -20,10 +20,11 @@ import numpy as np
from auto_scan_test import PassAutoScanTest from auto_scan_test import PassAutoScanTest
from program_config import ProgramConfig, TensorConfig from program_config import ProgramConfig, TensorConfig
num = 32 * 64
class TestOneDNNReshapeTransposeMatmulFusePass(PassAutoScanTest):
def setUp(self):
self.num = 32 * 64
class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest):
def is_program_valid(self, program_config: ProgramConfig) -> bool: def is_program_valid(self, program_config: ProgramConfig) -> bool:
return True return True
...@@ -40,11 +41,11 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -40,11 +41,11 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest):
input_dim = draw(st.sampled_from([32, 64])) input_dim = draw(st.sampled_from([32, 64]))
def generate_input1(attrs): def generate_input1(attrs):
shape_x = [attrs[3]['batch_size'], attrs[3]['channel'], num] shape_x = [attrs[3]['batch_size'], attrs[3]['channel'], self.num]
return np.random.random(shape_x).astype(np.float32) return np.random.random(shape_x).astype(np.float32)
def generate_input2(attrs): def generate_input2(attrs):
shape_x = [attrs[3]['batch_size'], attrs[3]['channel'], num] shape_x = [attrs[3]['batch_size'], attrs[3]['channel'], self.num]
input_volume = reduce(lambda x, y: x * y, shape_x) input_volume = reduce(lambda x, y: x * y, shape_x)
matmul_shape = [i for i in attrs[0]['shape']] matmul_shape = [i for i in attrs[0]['shape']]
if 0 in matmul_shape: if 0 in matmul_shape:
...@@ -66,7 +67,7 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -66,7 +67,7 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest):
matmul_shape[0], matmul_shape[0],
matmul_shape[1], matmul_shape[1],
matmul_shape[-1], matmul_shape[-1],
int(num / matmul_shape[-1]), int(self.num / matmul_shape[-1]),
] ]
elif attrs[2]['transpose_X']: elif attrs[2]['transpose_X']:
shape_y = matmul_shape shape_y = matmul_shape
...@@ -77,17 +78,17 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -77,17 +78,17 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest):
matmul_shape[0], matmul_shape[0],
matmul_shape[1], matmul_shape[1],
matmul_shape[-1], matmul_shape[-1],
int(num / matmul_shape[-1]), int(self.num / matmul_shape[-1]),
] ]
return np.random.random(shape_y).astype(np.float32) return np.random.random(shape_y).astype(np.float32)
attrs = [ attrs = [
{"shape": shape}, {'shape': shape},
{"axis": axis}, {'axis': axis},
{ {
"transpose_X": transpose_X, 'transpose_X': transpose_X,
"transpose_Y": transpose_Y, 'transpose_Y': transpose_Y,
"alpha": alpha, 'alpha': alpha,
}, },
{ {
'batch_size': batch_size, 'batch_size': batch_size,
...@@ -98,37 +99,37 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -98,37 +99,37 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest):
ops_config = [ ops_config = [
{ {
"op_type": "reshape2", 'op_type': 'reshape2',
"op_inputs": {"X": ["input_data1"]}, 'op_inputs': {'X': ['input_data1']},
"op_outputs": { 'op_outputs': {
"Out": ["reshape2_output"], 'Out': ['reshape2_output'],
"XShape": ["reshape2_xshape"], 'XShape': ['reshape2_xshape'],
}, },
"op_attrs": {'shape': attrs[0]['shape']}, 'op_attrs': {'shape': attrs[0]['shape']},
}, },
{ {
"op_type": "transpose2", 'op_type': 'transpose2',
"op_inputs": {"X": ["reshape2_output"]}, 'op_inputs': {'X': ['reshape2_output']},
"op_outputs": { 'op_outputs': {
"Out": ["transpose2_output"], 'Out': ['transpose2_output'],
"XShape": ["transpose2_xshape"], 'XShape': ['transpose2_xshape'],
}, },
"op_attrs": {'axis': attrs[1]['axis']}, 'op_attrs': {'axis': attrs[1]['axis']},
}, },
{ {
"op_type": "matmul", 'op_type': 'matmul',
"op_inputs": {"X": ["transpose2_output"], "Y": ["input_data2"]}, 'op_inputs': {'X': ['transpose2_output'], 'Y': ['input_data2']},
"op_outputs": {"Out": ["matmul_output"]}, 'op_outputs': {'Out': ['matmul_output']},
"op_attrs": { 'op_attrs': {
'transpose_X': attrs[2]['transpose_X'], 'transpose_X': attrs[2]['transpose_X'],
'transpose_Y': attrs[2]['transpose_Y'], 'transpose_Y': attrs[2]['transpose_Y'],
'alpha': attrs[2]['alpha'], 'alpha': attrs[2]['alpha'],
"fused_reshape_X": [], 'fused_reshape_X': [],
"fused_reshape_Y": [], 'fused_reshape_Y': [],
"fused_transpose_X": [], 'fused_transpose_X': [],
"fused_transpose_Y": [], 'fused_transpose_Y': [],
"fused_reshape_Out": [], 'fused_reshape_Out': [],
"fused_transpose_Out": [], 'fused_transpose_Out': [],
}, },
}, },
] ]
...@@ -139,27 +140,27 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest): ...@@ -139,27 +140,27 @@ class TestReshapeTransposeMatmulMkldnnFusePass(PassAutoScanTest):
ops=ops, ops=ops,
weights={}, weights={},
inputs={ inputs={
"input_data1": TensorConfig( 'input_data1': TensorConfig(
data_gen=partial(generate_input1, attrs) data_gen=partial(generate_input1, attrs)
), ),
"input_data2": TensorConfig( 'input_data2': TensorConfig(
data_gen=partial(generate_input2, attrs) data_gen=partial(generate_input2, attrs)
), ),
}, },
outputs=["matmul_output"], outputs=['matmul_output'],
) )
return program_config return program_config
def sample_predictor_configs(self, program_config): def sample_predictor_configs(self, program_config):
config = self.create_inference_config(use_mkldnn=True) config = self.create_inference_config(use_mkldnn=True)
yield config, ["matmul"], (1e-5, 1e-5) yield config, ['matmul'], (1e-5, 1e-5)
def test(self): def test(self):
self.run_and_statis( self.run_and_statis(
quant=False, passes=["reshape_transpose_matmul_mkldnn_fuse_pass"] quant=False, passes=['reshape_transpose_matmul_mkldnn_fuse_pass']
) )
if __name__ == "__main__": if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册