未验证 提交 339c2e53 编写于 作者: C cc 提交者: GitHub

Weight quantization skip conv_conv_fuse_pass, test=develop (#4292)

上级 194e5a76
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuse_pass.h" #include "lite/core/mir/fusion/conv_conv_fuse_pass.h"
#include <list>
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "lite/core/mir/fusion/conv_conv_fuser.h" #include "lite/core/mir/fusion/conv_conv_fuser.h"
...@@ -27,13 +28,10 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -27,13 +28,10 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialze fuser params // initialze fuser params
std::vector<bool> conv_has_bias_cases{true, false}; std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"}; std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
bool has_fp32 = false;
bool has_int8 = false; bool has_int8 = false;
bool has_weight_quant = false;
for (auto& place : graph->valid_places()) { for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kARM) || place.target == TARGET(kHost)) { if (place.target == TARGET(kARM) || place.target == TARGET(kHost)) {
if (place.precision == PRECISION(kFloat)) {
has_fp32 = true;
}
if (place.precision == PRECISION(kInt8)) { if (place.precision == PRECISION(kInt8)) {
has_int8 = true; has_int8 = true;
} }
...@@ -42,8 +40,18 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -42,8 +40,18 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
return; return;
} }
} }
const std::list<mir::Node>& nodes = graph->nodes();
for (auto& node : nodes) {
if (node.IsStmt()) {
auto* op_info = (node.stmt())->op_info();
if (op_info->HasAttr("quantization_type")) {
has_weight_quant = true;
break;
}
}
}
// only support arm-fp32 // only support arm-fp32
if (has_int8 || (has_fp32 && has_int8)) { if (has_int8 || has_weight_quant) {
return; return;
} }
// only support fp32 fusion // only support fp32 fusion
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册