Commit 03fd37e4 authored by jiweibo

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle-Lite into add_matmul_op

@@ -29,6 +29,7 @@
 USE_MIR_PASS(graph_visualze);
 USE_MIR_PASS(lite_conv_bn_fuse_pass);
 USE_MIR_PASS(lite_fc_fuse_pass);
+USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
 USE_MIR_PASS(identity_scale_eliminate_pass);
 USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
 USE_MIR_PASS(lite_conv_activation_fuse_pass);
......
@@ -11,6 +11,7 @@ add_subdirectory(subgraph)
 lite_cc_library(mir_passes
     SRCS
         fusion/fc_fuse_pass.cc
+        fusion/shuffle_channel_fuse_pass.cc
        fusion/conv_elementwise_fuse_pass.cc
        fusion/conv_activation_fuse_pass.cc
        fusion/conv_bn_fuse_pass.cc
......
 lite_cc_library(fuse_fc
     SRCS fc_fuser.cc
     DEPS pattern_matcher_high_api)
+lite_cc_library(fuse_shuffle_channel
+    SRCS shuffle_channel_fuser.cc
+    DEPS pattern_matcher_high_api)
 lite_cc_library(fuse_conv_elementwise
     SRCS conv_elementwise_fuser.cc
     DEPS pattern_matcher_high_api)
@@ -19,6 +22,7 @@ lite_cc_library(fuse_quant_dequant
 set(mir_fusers
     fuse_fc
+    fuse_shuffle_channel
     fuse_conv_elementwise
     fuse_conv_activation
     fuse_conv_bn
......
+++ lite/core/mir/fusion/shuffle_channel_fuse_pass.cc (new file)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/shuffle_channel_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/shuffle_channel_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ShuffleChannelFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ShuffleChannelFuser fuser("reshape", "transpose");
fuser(graph.get());
fusion::ShuffleChannelFuser fuser2("reshape2", "transpose2");
fuser2(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_shuffle_channel_fuse_pass,
paddle::lite::mir::ShuffleChannelFusePass);
+++ lite/core/mir/fusion/shuffle_channel_fuse_pass.h (new file)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ShuffleChannelFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
+++ lite/core/mir/fusion/shuffle_channel_fuser.cc (new file)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/shuffle_channel_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ShuffleChannelFuser::BuildPattern() {
// create nodes.
auto* x1 = VarNode("x1")->assert_is_op_input(reshape_type_, "X");
auto* y1 = VarNode("y1")->assert_is_op_output(reshape_type_, "Out");
auto* y2 = VarNode("y2")->assert_is_op_output(transpose_type_, "Out");
auto* out = VarNode("out")->assert_is_op_output(reshape_type_, "Out");
auto* xshape1 =
VarNode("xshape1")->assert_is_op_output(reshape_type_, "XShape");
auto* xshape2 =
VarNode("xshape2")->assert_is_op_output(transpose_type_, "XShape");
auto* xshape3 =
VarNode("xshape3")->assert_is_op_output(reshape_type_, "XShape");
auto* reshape1 = OpNode("reshape1", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] > 0;
});
auto* transpose =
OpNode("transpose_op", transpose_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"axis", [](const std::vector<int>& attr) {
return attr.size() >= 5 && attr[1] == 2 && attr[2] == 1;
});
auto* reshape2 = OpNode("reshape2", reshape_type_)
->assert_op_attr_satisfied<std::vector<int>>(
"shape", [](const std::vector<int>& attr) {
return attr.size() >= 4;
});
// create topology.
*x1 >> *reshape1 >> *y1 >> *transpose >> *y2 >> *reshape2 >> *out;
*reshape1 >> *xshape1;
*transpose >> *xshape2;
*reshape2 >> *xshape3;
// Some op specialities.
y1->AsIntermediate();
y2->AsIntermediate();
xshape1->AsIntermediate();
xshape2->AsIntermediate();
xshape3->AsIntermediate();
reshape1->AsIntermediate();
transpose->AsIntermediate();
reshape2->AsIntermediate();
}
void ShuffleChannelFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto shuffle_channel_op = LiteOpRegistry::Global().Create("shuffle_channel");
auto transpose = matched.at("transpose_op")->stmt()->op();
auto* scope = transpose->scope();
auto& valid_places = transpose->valid_places();
shuffle_channel_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(shuffle_channel_op, valid_places);
IR_NODE_LINK_TO(matched.at("x1"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("out"));
}
cpp::OpDesc ShuffleChannelFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("shuffle_channel");
op_desc.SetInput("X", {matched.at("x1")->arg()->name});
op_desc.SetOutput("Out", {matched.at("out")->arg()->name});
op_desc.SetAttr("group",
matched.at("reshape1")
->stmt()
->op_info()
->GetAttr<std::vector<int>>("shape")[1]);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
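
The pattern matched above is the textbook channel-shuffle decomposition: reshape to (N, g, C/g, H, W), transpose axes 1 and 2, then reshape back to (N, C, H, W). A single shuffle_channel op with group = g computes the same channel permutation, which is why GenOpDesc can read the group straight out of shape[1] of the first reshape. A minimal NumPy sketch of that equivalence (illustrative only; shuffle_channel_ref is a hand-written reference, not a Paddle-Lite API):

```python
import numpy as np

def shuffle_channel_ref(x, group):
    # Output channel j is sourced from input channel
    # (j % group) * (c // group) + j // group.
    n, c, h, w = x.shape
    assert c % group == 0
    idx = [(j % group) * (c // group) + j // group for j in range(c)]
    return x[:, idx, :, :]

x = np.arange(2 * 6 * 4 * 4, dtype=np.float32).reshape(2, 6, 4, 4)
g = 3

# Subgraph the fuser matches: reshape -> transpose(axis=[0, 2, 1, 3, 4]) -> reshape.
y_subgraph = (x.reshape(2, g, 6 // g, 4, 4)
               .transpose(0, 2, 1, 3, 4)
               .reshape(2, 6, 4, 4))

# Single op the pass emits; group comes from shape[1] of the first reshape.
y_fused = shuffle_channel_ref(x, group=g)

assert np.array_equal(y_subgraph, y_fused)
```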
+++ lite/core/mir/fusion/shuffle_channel_fuser.h (new file)

// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ShuffleChannelFuser : public FuseBase {
public:
explicit ShuffleChannelFuser(const std::string& reshape_type,
const std::string& transpose_type)
: reshape_type_(reshape_type), transpose_type_(transpose_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string reshape_type_;
std::string transpose_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
@@ -141,18 +141,26 @@ struct PMNode {
                   int nth);

   template <typename T>
-  PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
+  PMNode* assert_op_attr_satisfied(
+      const std::string& attr_name,
+      const std::function<bool(const T&)>& condition) {
     asserts_.push_back([=](const Node* x) {
       if (x && x->IsStmt()) {
         auto* op_info = x->stmt()->op_info();
         return op_info->HasAttr(attr_name) &&
-               op_info->GetAttr<T>(attr_name) == attr;
+               condition(op_info->GetAttr<T>(attr_name));
       }
       return false;
     });
     return this;
   }

+  template <typename T>
+  PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
+    return assert_op_attr_satisfied<T>(
+        attr_name, [&](const T& src) { return src == attr; });
+  }
+
  private:
   PMNode(PMPattern* pattern,
          const std::string& name = "",
......
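
The change above is a small generalization rather than new machinery: the exact-match assert_op_attr becomes a thin wrapper around the new predicate form assert_op_attr_satisfied, and the shuffle-channel fuser relies on the predicate form to accept any reshape whose shape attribute has rank >= 5 with a positive group, instead of one fixed attribute value. A toy sketch of the same design in Python (all names hypothetical, not Paddle-Lite API):

```python
from typing import Any, Callable, Dict, List

class PatternNode:
    """Toy analogue of PMNode: accumulates boolean checks over an op's attributes."""

    def __init__(self):
        self.asserts: List[Callable[[Dict[str, Any]], bool]] = []

    def assert_op_attr_satisfied(self, attr_name, condition):
        # General form: the attribute must exist and satisfy an arbitrary predicate.
        self.asserts.append(
            lambda attrs: attr_name in attrs and condition(attrs[attr_name]))
        return self

    def assert_op_attr(self, attr_name, value):
        # Exact match is just the equality special case of the predicate form.
        return self.assert_op_attr_satisfied(attr_name, lambda v: v == value)

    def matches(self, attrs):
        return all(check(attrs) for check in self.asserts)

# Same predicate the fuser installs on the first reshape node.
node = PatternNode().assert_op_attr_satisfied(
    "shape", lambda s: len(s) >= 5 and s[1] > 0)
print(node.matches({"shape": [0, 3, 2, 32, 32]}))  # True: 5-D with group 3
print(node.matches({"shape": [0, -1, 32, 32]}))    # False: rank 4
```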
@@ -64,6 +64,7 @@ class Optimizer {
          "lite_conv_elementwise_fuse_pass",  //
          "lite_conv_activation_fuse_pass",   //
          "lite_fc_fuse_pass",                //
+         "lite_shuffle_channel_fuse_pass",   //
          "identity_scale_eliminate_pass",    //
 #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
          "lite_elementwise_add_activation_fuse_pass",  //
......
@@ -22,6 +22,9 @@ namespace operators {
 template <typename Dtype, typename T>
 void Reshape2Op<Dtype, T>::InferShape() const {
+  if (this->param_.InputShape() != nullptr) {
+    return;
+  }
   auto &shape = this->param_.Shape();
   auto input_x_dims = this->param_.InputX()->dims();
 #ifdef PADDLE_MOBILE_CL
......
@@ -214,10 +214,6 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test_yolo_combined net/test_yolo_combined.cpp test_helper.h test_include.h executor_for_test.h)
     target_link_libraries(test_yolo_combined paddle-mobile)

-    # gen test
-    ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h)
-    target_link_libraries(test-net paddle-mobile)
-
     # gen test
     ADD_EXECUTABLE(test-op-in-net net/test_op_in_net.cpp test_helper.h test_include.h executor_for_test.h)
     target_link_libraries(test-op-in-net paddle-mobile)
@@ -527,4 +523,8 @@ if (NOT FOUND_MATCH)
     ADD_EXECUTABLE(test-net-benchmark net/test_net_benchmark.cpp test_helper.h test_include.h)
     target_link_libraries(test-net-benchmark paddle-mobile)

+    # gen test
+    ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h)
+    target_link_libraries(test-net paddle-mobile)
+
 endif ()
@@ -93,6 +93,8 @@ void test(int argc, char *argv[]) {
     var_names.push_back(var_name);
   }
   arg_index += var_count;
+  bool check_shape = std::stoi(argv[arg_index]) == 1;
+  arg_index++;

   auto time1 = time();
   if (paddle_mobile.Load("./checked_model/model", "./checked_model/params",
@@ -194,6 +196,11 @@ void test(int argc, char *argv[]) {
           auto data = tensor_data;
           std::string sample = "";
+          if (check_shape) {
+            for (int i = 0; i < cl_image->dims().size(); i++) {
+              sample += " " + std::to_string(cl_image->dims()[i]);
+            }
+          }
           if (!is_sample_step) {
             sample_step = len / sample_num;
           }
@@ -219,6 +226,11 @@ void test(int argc, char *argv[]) {
         if (out->type() == type_id<int>()) {
           auto data = out->data<int>();
           std::string sample = "";
+          if (check_shape) {
+            for (int i = 0; i < out->dims().size(); i++) {
+              sample += " " + std::to_string(out->dims()[i]);
+            }
+          }
           if (!is_sample_step) {
             sample_step = len / sample_num;
           }
@@ -233,6 +245,11 @@ void test(int argc, char *argv[]) {
         } else if (out->type() == type_id<float>()) {
           auto data = out->data<float>();
           std::string sample = "";
+          if (check_shape) {
+            for (int i = 0; i < out->dims().size(); i++) {
+              sample += " " + std::to_string(out->dims()[i]);
+            }
+          }
           if (!is_sample_step) {
             sample_step = len / sample_num;
           }
......
@@ -19,6 +19,9 @@
 sample_step = 1
 sample_num = 20
 need_encrypt = False
 checked_encrypt_model_path = "checked_encrypt_model"
+output_var_filter = []
+output_key_filter = {}
+check_shape = False

 np.set_printoptions(linewidth=150)
@@ -282,6 +285,8 @@ def save_all_op_output(feed_kv=None):
     for fetch in fetches:
         fetch_names.append(fetch.name)
     feed_names = feeds
+    for fetch_name in fetch_names:
+        output_var_filter.append(fetch_name)
     for i in range(len(ops)):
         op = ops[i]
         var_name = None
@@ -297,6 +302,53 @@ def save_all_op_output(feed_kv=None):
             var_name = name
             if "tmp" in name:
                 break
+        if len(output_var_filter) > 0:
+            if var_name not in output_var_filter:
+                continue
+        # real_var_name = None
+        # if op.type == "fetch":
+        #     for name in op.input_arg_names:
+        #         real_var_name = name
+        #         if "tmp" in name:
+        #             break
+        # else:
+        #     real_var_name = var_name
+        if fast_check:
+            if var_name not in fetch_names and var_name not in feed_names:
+                continue
+        try:
+            data = get_var_data(var_name, feed_kv=feed_kv).flatten().tolist()
+            sample = tensor_sample(data)
+            output_var_cache[var_name] = (sample)
+            op_cache[i] = (var_name, op)
+            file_name = var_name.replace("/", "_")
+            out_file = open(output_path + "/" + file_name, "w")
+            if var_name in feed_names:
+                for item in data:
+                    out_file.write("{}\n".format(item))
+            else:
+                for item in sample:
+                    out_file.write("{}\n".format(item))
+            out_file.close()
+        except:
+            pass
+    for i in range(len(ops)):
+        op = ops[i]
+        if op.type not in output_key_filter:
+            continue
+        var_name = None
+        var_name_index = -1
+        for index in range(len(op.output_names)):
+            if op.output_names[index] in output_key_filter[op.type]:
+                var_name_index = index
+                break
+        if var_name_index != -1:
+            var_name = op.output_arg_names[var_name_index]
+        else:
+            continue
+        if len(output_var_filter) > 0:
+            if var_name not in output_var_filter:
+                continue
         # real_var_name = None
         # if op.type == "fetch":
         #     for name in op.input_arg_names:
@@ -386,12 +438,19 @@ def check_mobile_results(args, fuse, mem_opt):
                 continue
             values1 = output_var_cache[op_output_var_name]
             values2 = mobile_var_cache[op_output_var_name]
-            if len(values1) != len(values2):
+            shape = get_var_shape(op_output_var_name) if check_shape else []
+            if len(values1) + len(shape) != len(values2):
                 error_index = index
+            for i in range(len(shape)):
+                v1 = shape[i]
+                v2 = values2[i]
+                if v1 != v2:
+                    error_index = index
+                    break
             if error_index == None:
                 for i in range(len(values1)):
                     v1 = values1[i]
-                    v2 = values2[i]
+                    v2 = values2[len(shape) + i]
                     if abs(v1 - v2) > diff_threshold:
                         error_index = index
                         break
@@ -496,6 +555,7 @@ def main():
         args += " " + str(sample_num)
         for var_name in output_var_cache.keys():
             args += " " + var_name
+        args += " " + str(1 if check_shape else 0)
         if not fast_check:
             check_mobile_results(args, False, False)
             check_mobile_results(args, False, True)
......
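
Taken together, the C++ and Python changes agree on a simple wire format: with check_shape enabled the mobile binary prints each output's dims before its sampled values, so a cached mobile line is [dims..., samples...], and check_mobile_results first matches the dims prefix exactly, then compares samples at an offset of len(shape). A self-contained sketch of that alignment (hypothetical values, simplified from the logic above):

```python
def compare(desktop_samples, mobile_values, shape, diff_threshold=1e-4):
    # mobile_values = shape prefix + sampled values when check_shape is on.
    if len(desktop_samples) + len(shape) != len(mobile_values):
        return False
    # The dims prefix must match exactly.
    if any(shape[i] != mobile_values[i] for i in range(len(shape))):
        return False
    # Sampled values are compared at an offset of len(shape).
    return all(abs(desktop_samples[i] - mobile_values[len(shape) + i]) <= diff_threshold
               for i in range(len(desktop_samples)))

desktop = [0.5, 1.25, -3.0]
mobile = [1, 3, 224, 224] + [0.5, 1.25, -3.0]  # dims first, then samples
print(compare(desktop, mobile, shape=[1, 3, 224, 224]))  # True
```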