未验证 提交 339c2e53 编写于 作者: C cc 提交者: GitHub

Weight quantization skip conv_conv_fuse_pass, test=develop (#4292)

上级 194e5a76
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/core/mir/fusion/conv_conv_fuse_pass.h"
#include <list>
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/conv_conv_fuser.h"
......@@ -27,13 +28,10 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// initialze fuser params
std::vector<bool> conv_has_bias_cases{true, false};
std::vector<std::string> conv_type_cases{"conv2d", "depthwise_conv2d"};
bool has_fp32 = false;
bool has_int8 = false;
bool has_weight_quant = false;
for (auto& place : graph->valid_places()) {
if (place.target == TARGET(kARM) || place.target == TARGET(kHost)) {
if (place.precision == PRECISION(kFloat)) {
has_fp32 = true;
}
if (place.precision == PRECISION(kInt8)) {
has_int8 = true;
}
......@@ -42,8 +40,18 @@ void ConvConvFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
return;
}
}
const std::list<mir::Node>& nodes = graph->nodes();
for (auto& node : nodes) {
if (node.IsStmt()) {
auto* op_info = (node.stmt())->op_info();
if (op_info->HasAttr("quantization_type")) {
has_weight_quant = true;
break;
}
}
}
// only support arm-fp32
if (has_int8 || (has_fp32 && has_int8)) {
if (has_int8 || has_weight_quant) {
return;
}
// only support fp32 fusion
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册