Unverified commit ce335c23, authored by heliqi, committed by GitHub

add matmul_scale_fuse_pass (#37962)

* add matmul_scale matmul_v2_scale fuse pass

* add scaletensor judge

* modify var name

* add timeout notest;test=coverage

* fix erroneous commit

* fix use_mkldnn attr

* fix use_mkldnn attr
Parent 23d9e947
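This commit adds two graph-IR fuse passes: matmul_scale_fuse_pass folds a following scale op (with bias == 0) into matmul by multiplying matmul's alpha attribute by the scale factor, and matmul_v2_scale_fuse_pass folds the scale into matmul_v2's persistable Y weight, since matmul_v2 has no alpha attribute. A minimal numpy sketch of the identity the first pass relies on (illustrative, not part of the commit):

import numpy as np

x = np.random.rand(2, 3).astype("float32")
y = np.random.rand(3, 4).astype("float32")
alpha, s = 0.5, 2.0

before = s * (alpha * (x @ y))  # matmul(alpha=0.5) followed by scale(scale=2.0, bias=0)
after = (alpha * s) * (x @ y)   # fused matmul with alpha = alpha * s
assert np.allclose(before, after, atol=1e-6)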
@@ -97,6 +97,7 @@ pass_library(adaptive_pool2d_convert_global_pass inference)
pass_library(unsqueeze2_eltwise_fuse_pass inference)
pass_library(layer_norm_fuse_pass inference)
pass_library(add_support_int8_pass inference)
pass_library(matmul_scale_fuse_pass inference)
pass_library(generate_pass DEPS pass_desc_proto)
target_link_libraries(generate_pass pass_desc_proto)
if(WITH_GPU OR WITH_ROCM)
......
@@ -1699,6 +1699,49 @@ PDNode *patterns::MatmulV2::operator()() {
return matmul_v2_out;
}
PDNode *patterns::MatmulScale::operator()() {
auto matmul_op = pattern->NewNode(matmul_op_repr())->assert_is_op("matmul");
auto matmul_in_x = pattern->NewNode(matmul_in_x_repr())
->AsInput()
->assert_is_op_input("matmul", "X");
auto matmul_in_y = pattern->NewNode(matmul_in_y_repr())
->AsInput()
->assert_is_op_input("matmul", "Y");
auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
auto scale_in_x = pattern->NewNode(scale_in_x_repr())
->assert_is_op_output("matmul", "Out")
->assert_is_op_input("scale", "X");
auto scale_out = pattern->NewNode(scale_out_repr())
->AsOutput()
->assert_is_op_output("scale", "Out");
matmul_op->LinksFrom({matmul_in_x, matmul_in_y}).LinksTo({scale_in_x});
scale_op->LinksFrom({scale_in_x}).LinksTo({scale_out});
return scale_out;
}
PDNode *patterns::MatmulV2Scale::operator()() {
auto matmul_v2_op =
pattern->NewNode(matmul_v2_op_repr())->assert_is_op("matmul_v2");
auto matmul_v2_in_x = pattern->NewNode(matmul_v2_in_x_repr())
->AsInput()
->assert_is_op_input("matmul_v2", "X");
auto matmul_v2_in_y = pattern->NewNode(matmul_v2_in_y_repr())
->AsInput()
->assert_is_persistable_var() // Y is weight
->assert_is_op_input("matmul_v2", "Y");
auto scale_op = pattern->NewNode(scale_op_repr())->assert_is_op("scale");
auto scale_in_x = pattern->NewNode(scale_in_x_repr())
->assert_is_op_output("matmul_v2", "Out")
->assert_is_op_input("scale", "X");
auto scale_out = pattern->NewNode(scale_out_repr())
->AsOutput()
->assert_is_op_output("scale", "Out");
matmul_v2_op->LinksFrom({matmul_v2_in_x, matmul_v2_in_y})
.LinksTo({scale_in_x});
scale_op->LinksFrom({scale_in_x}).LinksTo({scale_out});
return scale_out;
}
PDNode *patterns::Squeeze2Matmul::operator()() {
auto squeeze2_in_x = pattern->NewNode(squeeze2_in_x_repr())
->assert_is_op_input("squeeze2", "X")
......
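The MatmulScale and MatmulV2Scale patterns above match a matmul (or matmul_v2) whose Out output feeds the X input of a scale op; for matmul_v2 the Y input must additionally be a persistable weight. A hedged sketch of a static-graph program containing such a subgraph (paddle.matmul lowers to a matmul_v2 op in Paddle 2.x; the variable names are illustrative):

import paddle

paddle.enable_static()
x = paddle.static.data("x", shape=[2, 3], dtype="float32")
y = paddle.static.data("y", shape=[3, 4], dtype="float32")
# builds a matmul_v2 -> scale chain, the subgraph shape these patterns detect
out = paddle.scale(paddle.matmul(x, y), scale=2.0, bias=0.0)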
@@ -1032,6 +1032,36 @@ struct MatmulV2 : public PatternBase {
PATTERN_DECL_NODE(matmul_v2_out);
};
// Matmul + scale
// Forward pass.
struct MatmulScale : public PatternBase {
MatmulScale(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_scale") {}
PDNode* operator()();
PATTERN_DECL_NODE(matmul_in_x);
PATTERN_DECL_NODE(matmul_in_y);
PATTERN_DECL_NODE(matmul_op);
PATTERN_DECL_NODE(scale_in_x);
PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(scale_out);
};
// Matmul_v2 + scale
// Forward pass.
struct MatmulV2Scale : public PatternBase {
MatmulV2Scale(PDPattern* pattern, const std::string& name_scope)
: PatternBase(pattern, name_scope, "matmul_v2_scale") {}
PDNode* operator()();
PATTERN_DECL_NODE(matmul_v2_in_x);
PATTERN_DECL_NODE(matmul_v2_in_y);
PATTERN_DECL_NODE(matmul_v2_op);
PATTERN_DECL_NODE(scale_in_x);
PATTERN_DECL_NODE(scale_op);
PATTERN_DECL_NODE(scale_out);
};
// Squeeze2 + Matmul
// Forward pass.
struct Squeeze2Matmul : public PatternBase {
......
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/ir/matmul_scale_fuse_pass.h"
#include <cmath>
#include <string>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/platform/enforce.h"
namespace paddle {
namespace framework {
namespace ir {
class Node;
MatmulScaleFusePass::MatmulScaleFusePass() {
AddOpCompat(OpCompat("matmul"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("transpose_X")
.IsType<bool>()
.End()
.AddAttr("transpose_Y")
.IsType<bool>()
.End()
.AddAttr("alpha")
.IsType<float>()
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("ScaleTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("bias_after_scale")
.IsType<bool>()
.End()
.AddAttr("scale")
.End()
.AddAttr("bias")
.IsNumEQ(0.0f)
.End();
}
MatmulV2ScaleFusePass::MatmulV2ScaleFusePass() {
AddOpCompat(OpCompat("matmul_v2"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("Y")
.IsTensor()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("trans_x")
.IsType<bool>()
.End()
.AddAttr("trans_y")
.IsType<bool>()
.End();
AddOpCompat(OpCompat("scale"))
.AddInput("X")
.IsTensor()
.End()
.AddInput("ScaleTensor")
.IsTensor()
.IsOptional()
.End()
.AddOutput("Out")
.IsTensor()
.End()
.AddAttr("bias_after_scale")
.IsType<bool>()
.End()
.AddAttr("scale")
.End()
.AddAttr("bias")
.IsNumEQ(0.0f)
.End();
}
void MatmulScaleFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "matmul_scale_fuse";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::MatmulScale matmul_scale_pattern(gpd.mutable_pattern(), name_scope);
matmul_scale_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "matmul_scale_fuse pass";
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_x, matmul_in_x, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_in_y, matmul_in_y, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_op, matmul_op, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_in_x, scale_in_x, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, matmul_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, matmul_scale_pattern);
auto* scope = param_scope();
float bias = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"));
if (std::abs(bias) > 1e-5) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "matmul_scale_fuse_pass in op compat failed.";
return;
}
float scale = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
float matmul_alpha =
BOOST_GET_CONST(float, matmul_op->Op()->GetAttr("alpha"));
auto const& names = scale_op->Op()->InputNames();
bool has_scale_tensor =
std::find(names.begin(), names.end(), "ScaleTensor") != names.end();
if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
std::string scale_var_name = scale_op->Op()->Input("ScaleTensor").front();
auto* scale_var = scope->FindVar(scale_var_name);
// ScaleTensor must be a persistable weight; otherwise skip the fusion
if (scale_var == nullptr) return;
auto* scale_tensor = scale_var->GetMutable<LoDTensor>();
scale = *(scale_tensor->data<float>());
}
OpDesc* matmul_desc = matmul_op->Op();
matmul_desc->SetAttr("alpha", scale * matmul_alpha);
matmul_desc->SetOutput("Out", {scale_out->Name()});
if (!IsCompat(*matmul_desc)) {
LOG(WARNING) << "matmul_scale_fuse_pass in out mul op compat failed.";
return;
}
IR_NODE_LINK_TO(matmul_op, scale_out);
GraphSafeRemoveNodes(graph, {scale_in_x, scale_op});
++found_count;
};
gpd(graph, handler);
AddStatis(found_count);
}
void MatmulV2ScaleFusePass::ApplyImpl(ir::Graph* graph) const {
PADDLE_ENFORCE_NOT_NULL(
graph, platform::errors::InvalidArgument("Graph cannot be nullptr."));
std::string name_scope = "matmul_v2_scale_fuse";
FusePassBase::Init(name_scope, graph);
GraphPatternDetector gpd;
patterns::MatmulV2Scale matmul_v2_scale_pattern(gpd.mutable_pattern(),
name_scope);
matmul_v2_scale_pattern();
int found_count = 0;
auto handler = [&](const GraphPatternDetector::subgraph_t& subgraph,
Graph* g) {
VLOG(4) << "matmul_v2_scale_fuse pass";
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_x, matmul_v2_in_x,
matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_in_y, matmul_v2_in_y,
matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(matmul_v2_op, matmul_v2_op,
matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_in_x, scale_in_x, matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_op, scale_op, matmul_v2_scale_pattern);
GET_IR_NODE_FROM_SUBGRAPH(scale_out, scale_out, matmul_v2_scale_pattern);
auto* scope = param_scope();
float bias = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("bias"));
if (std::abs(bias) > 1e-5) return;
if (!IsCompat(subgraph, g)) {
LOG(WARNING) << "matmul_v2_scale_fuse_pass in op compat failed.";
return;
}
float scale = BOOST_GET_CONST(float, scale_op->Op()->GetAttr("scale"));
auto const& names = scale_op->Op()->InputNames();
bool has_scale_tensor =
std::find(names.begin(), names.end(), "ScaleTensor") != names.end();
if (has_scale_tensor && scale_op->Op()->Input("ScaleTensor").size() > 0) {
std::string scale_var_name = scale_op->Op()->Input("ScaleTensor").front();
auto* scale_var = scope->FindVar(scale_var_name);
// ScaleTensor must be a persistable weight; otherwise skip the fusion
if (scale_var == nullptr) return;
auto* scale_tensor = scale_var->GetMutable<LoDTensor>();
scale = *(scale_tensor->data<float>());
}
auto* matmul_y =
scope->FindVar(matmul_v2_in_y->Name())->GetMutable<LoDTensor>();
auto y_data = matmul_y->mutable_data<float>(platform::CPUPlace());
for (int i = 0; i < matmul_y->numel(); ++i) {
y_data[i] *= scale;
}
OpDesc* matmul_v2_desc = matmul_v2_op->Op();
matmul_v2_desc->SetOutput("Out", {scale_out->Name()});
if (!IsCompat(*matmul_v2_desc)) {
LOG(WARNING) << "matmul_v2_scale_fuse_pass in out mul op compat failed.";
return;
}
IR_NODE_LINK_TO(matmul_v2_op, scale_out);
GraphSafeRemoveNodes(graph, {scale_in_x, scale_op});
++found_count;
};
gpd(graph, handler);
AddStatis(found_count);
}
} // namespace ir
} // namespace framework
} // namespace paddle
REGISTER_PASS(matmul_scale_fuse_pass,
paddle::framework::ir::MatmulScaleFusePass);
REGISTER_PASS_CAPABILITY(matmul_scale_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.LE("matmul", 1)
.EQ("scale", 0));
REGISTER_PASS(matmul_v2_scale_fuse_pass,
paddle::framework::ir::MatmulV2ScaleFusePass);
REGISTER_PASS_CAPABILITY(matmul_v2_scale_fuse_pass)
.AddCombination(
paddle::framework::compatible::OpVersionComparatorCombination()
.EQ("matmul_v2", 0)
.EQ("scale", 0));
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/ir/fuse_pass_base.h"
#include "paddle/fluid/framework/ir/graph.h"
#include "paddle/fluid/framework/ir/graph_pattern_detector.h"
#include "paddle/fluid/framework/ir/pass.h"
namespace paddle {
namespace framework {
namespace ir {
class Graph;
/*
 * Fuse matmul followed by scale into a single matmul.
*/
class MatmulScaleFusePass : public FusePassBase {
public:
MatmulScaleFusePass();
virtual ~MatmulScaleFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
/*
 * Fuse matmul_v2 followed by scale into a single matmul_v2.
*/
class MatmulV2ScaleFusePass : public FusePassBase {
public:
MatmulV2ScaleFusePass();
virtual ~MatmulV2ScaleFusePass() {}
protected:
void ApplyImpl(Graph* graph) const override;
};
} // namespace ir
} // namespace framework
} // namespace paddle
@@ -205,8 +205,10 @@ CpuPassStrategy::CpuPassStrategy() : PassStrategy({}) {
"squeeze2_matmul_fuse_pass", //
"reshape2_matmul_fuse_pass", //
"flatten2_matmul_fuse_pass", //
"matmul_v2_scale_fuse_pass", //
"map_matmul_v2_to_mul_pass", //
"map_matmul_v2_to_matmul_pass", //
"matmul_scale_fuse_pass", //
"map_matmul_to_mul_pass", //
"fc_fuse_pass", //
"repeated_fc_relu_fuse_pass", //
@@ -258,8 +260,8 @@ void CpuPassStrategy::EnableMKLDNN() {
"matmul_transpose_reshape_fuse_pass", //
"matmul_v2_transpose_reshape_fuse_pass", //
// Disabled due to topology-dependent speed-up
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
// "fc_mkldnn_pass",
// "fc_act_mkldnn_fuse_pass",
"batch_norm_act_fuse_pass", //
"softplus_activation_mkldnn_fuse_pass", //
// TODO(intel): Please fix the bug on windows.
......
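Note the ordering in the CPU pass strategy: matmul_v2_scale_fuse_pass is placed before map_matmul_v2_to_mul_pass and map_matmul_v2_to_matmul_pass, and matmul_scale_fuse_pass before map_matmul_to_mul_pass, so the scale is folded in before the matmul ops are rewritten into other forms. A hedged sketch of a CPU inference setup under which this pass list runs (the model file names are placeholders):

import paddle.inference as paddle_infer

config = paddle_infer.Config("model.pdmodel", "model.pdiparams")  # hypothetical paths
config.disable_gpu()      # selects the CpuPassStrategy shown above
# config.enable_mkldnn()  # would additionally enable the mkldnn pass list
predictor = paddle_infer.create_predictor(config)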
@@ -83,7 +83,10 @@ if (WITH_MKLDNN AND TENSORRT_FOUND AND WITH_GPU)
set_tests_properties(test_conv_act_mkldnn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_eltwiseadd_bn_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_elementwise_add2_act_fuse_pass PROPERTIES TIMEOUT 120)
set_tests_properties(test_conv_elementwise_add_act_fuse_pass PROPERTIES TIMEOUT 90)
set_tests_properties(test_matmul_scale_fuse_pass PROPERTIES TIMEOUT 60)
set_tests_properties(test_matmul_v2_scale_fuse_pass PROPERTIES TIMEOUT 60)
endif()
if (WITH_MKLDNN)
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestMatmulScaleFusePass(PassAutoScanTest):
"""
x_var y_var(persistable)
\ /
matmul
|
scale
"""
def sample_predictor_configs(self, program_config):
# cpu
config = self.create_inference_config(use_gpu=False)
yield config, ["matmul", ], (1e-5, 1e-5)
# mkldnn
config = self.create_inference_config(use_mkldnn=True)
yield config, ["matmul", ], (1e-5, 1e-5)
def sample_program_config(self, draw):
# 1. Generate shape and attr of matmul
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=2, max_size=5))
x_shape_rank = len(x_shape)
y_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8),
min_size=x_shape_rank,
max_size=x_shape_rank))
y_shape_rank = len(y_shape)
y_shape[-2] = x_shape[-1]
for i in range(y_shape_rank - 3, -1, -1):
j = x_shape_rank - (y_shape_rank - i)
if j < 0 or j >= x_shape_rank:
break
y_shape[i] = x_shape[j]
transpose_X = False
transpose_Y = False
alpha = draw(st.floats(min_value=-2.0, max_value=2.0, width=32))
# scale tensor
scale_shape = [1]
scale_value = draw(st.floats(min_value=-5.0, max_value=5.0, width=32))
matmul_op = OpConfig(
"matmul",
inputs={"X": ["matmul_x"],
"Y": ["matmul_y"]},
outputs={"Out": ["matmul_out"]},
transpose_X=transpose_X,
transpose_Y=transpose_Y,
alpha=alpha,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[],
head_number=1, )
is_scale_tensor = draw(st.booleans())
if is_scale_tensor:
scale_op = OpConfig(
"scale",
inputs={"X": ["matmul_out"],
"ScaleTensor": ["scale_tensor"]},
outputs={"Out": ["scale_out"]},
scale=scale_value,
bias=0.0,
bias_after_scale=draw(st.booleans()), )
else:
scale_op = OpConfig(
"scale",
inputs={"X": ["matmul_out"], },
outputs={"Out": ["scale_out"]},
scale=scale_value,
bias=0.0,
bias_after_scale=draw(st.booleans()), )
ops = [matmul_op, scale_op]
weights = {}
inputs = {}
if is_scale_tensor:
weights = {
"matmul_y": TensorConfig(shape=y_shape),
"scale_tensor": TensorConfig(shape=scale_shape)
}
inputs = {"matmul_x": TensorConfig(shape=x_shape), }
else:
inputs = {
"matmul_x": TensorConfig(shape=x_shape),
"matmul_y": TensorConfig(shape=y_shape),
}
program_config = ProgramConfig(
ops=ops,
weights=weights,
inputs=inputs,
outputs=ops[-1].outputs["Out"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=100,
passes=["matmul_scale_fuse_pass"], )
if __name__ == "__main__":
unittest.main()
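The shape logic in sample_program_config ties y's second-to-last dimension to x's last and then copies x's leading batch dimensions into y, so every drawn pair of shapes forms a valid (optionally batched) matmul. A standalone trace of that loop with example values:

x_shape = [2, 5, 3, 4]
y_shape = [9, 9, 9, 6]
y_shape[-2] = x_shape[-1]                  # inner dims must agree: y[-2] = 4
for i in range(len(y_shape) - 3, -1, -1):  # walk the batch dims right to left
    j = len(x_shape) - (len(y_shape) - i)
    if j < 0 or j >= len(x_shape):
        break
    y_shape[i] = x_shape[j]                # copy x's batch dim into y
assert y_shape == [2, 5, 4, 6]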
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from auto_scan_test import PassAutoScanTest, IgnoreReasons
from program_config import TensorConfig, ProgramConfig, OpConfig
import numpy as np
import paddle.inference as paddle_infer
from functools import partial
from typing import Optional, List, Callable, Dict, Any, Set
import unittest
import hypothesis
from hypothesis import given, settings, seed, example, assume, reproduce_failure
import hypothesis.strategies as st
class TestMatmulV2ScaleFusePass(PassAutoScanTest):
"""
x_var y_var(persistable) x_var y_var*scale(persistable)
\ / \ /
matmul_v2 matmul_v2
| => |
scale scale_out
scale_out
"""
def sample_predictor_configs(self, program_config):
# for cpu
# config = self.create_inference_config(use_gpu=False)
# yield config, ["matmul_v2", ], (1e-5, 1e-5)
# mkldnn
config = self.create_inference_config(use_mkldnn=True)
yield config, ["matmul_v2", ], (1e-5, 1e-5)
def sample_program_config(self, draw):
# 1. Generate shape and attr of matmul
x_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8), min_size=2, max_size=5))
x_shape_rank = len(x_shape)
y_shape = draw(
st.lists(
st.integers(
min_value=1, max_value=8),
min_size=x_shape_rank,
max_size=x_shape_rank))
y_shape_rank = len(y_shape)
y_shape[-2] = x_shape[-1]
for i in range(y_shape_rank - 3, -1, -1):
j = x_shape_rank - (y_shape_rank - i)
if j < 0 or j >= x_shape_rank:
break
y_shape[i] = x_shape[j]
transpose_X = False
transpose_Y = False
# scale tensor
scale_shape = [1]
scale_value = draw(st.floats(min_value=-5.0, max_value=5.0, width=32))
matmul_v2_op = OpConfig(
"matmul_v2",
inputs={"X": ["matmul_x"],
"Y": ["matmul_y"]},
outputs={"Out": ["matmul_out"]},
trans_x=transpose_X,
trans_y=transpose_Y,
fused_reshape_X=[],
fused_reshape_Y=[],
fused_transpose_X=[],
fused_transpose_Y=[],
fused_reshape_Out=[],
fused_transpose_Out=[], )
is_scale_tensor = draw(st.booleans())
if is_scale_tensor:
scale_op = OpConfig(
"scale",
inputs={"X": ["matmul_out"],
"ScaleTensor": ["scale_tensor"]},
outputs={"Out": ["scale_out"]},
scale=scale_value,
bias=0.0,
bias_after_scale=draw(st.booleans()), )
else:
scale_op = OpConfig(
"scale",
inputs={"X": ["matmul_out"], },
outputs={"Out": ["scale_out"]},
scale=scale_value,
bias=0.0,
bias_after_scale=draw(st.booleans()), )
ops = [matmul_v2_op, scale_op]
weights = {"matmul_y": TensorConfig(shape=y_shape), }
if is_scale_tensor:
weights["scale_tensor"] = TensorConfig(shape=scale_shape)
inputs = {"matmul_x": TensorConfig(shape=x_shape), }
program_config = ProgramConfig(
ops=ops,
weights=weights,
inputs=inputs,
outputs=ops[-1].outputs["Out"], )
return program_config
def test(self):
self.run_and_statis(
quant=False,
max_examples=100,
passes=["matmul_v2_scale_fuse_pass"], )
if __name__ == "__main__":
unittest.main()
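Both tests drive PassAutoScanTest.run_and_statis, which samples up to max_examples random programs with hypothesis, runs each with and without the listed pass, checks that the optimized graph consists of the op list yielded by sample_predictor_configs (here a single matmul or matmul_v2), and compares outputs within the given (1e-5, 1e-5) tolerances.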