提交 234fb8f4 编写于 作者: Z Zhen Wang 提交者: Yan Chunwei

Add Fc fuse pass (#17994)

上级 016445c4
...@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}") ...@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}") message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}")
set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install") set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install")
set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
function(lite_download_and_uncompress INSTALL_DIR URL FILENAME) function(lite_download_and_uncompress INSTALL_DIR URL FILENAME)
message(STATUS "Download inference test stuff from ${URL}/${FILENAME}") message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
...@@ -161,13 +162,13 @@ function(lite_cc_test TARGET) ...@@ -161,13 +162,13 @@ function(lite_cc_test TARGET)
file(APPEND ${offline_test_registry_file} "${TARGET}\n") file(APPEND ${offline_test_registry_file} "${TARGET}\n")
endfunction() endfunction()
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(core) add_subdirectory(core)
add_subdirectory(x86) add_subdirectory(x86)
add_subdirectory(arm) add_subdirectory(arm)
add_subdirectory(host) add_subdirectory(host)
add_subdirectory(cuda) add_subdirectory(cuda)
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(model_parser) add_subdirectory(model_parser)
add_subdirectory(utils) add_subdirectory(utils)
add_subdirectory(api) add_subdirectory(api)
......
...@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA) ...@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda) nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda)
endif() endif()
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite}) cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite} program_lite)
set(light_api_deps set(light_api_deps
scope_lite target_wrapper_host model_parser_lite) scope_lite target_wrapper_host model_parser_lite)
...@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}") ...@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}")
message(STATUS "get ARM kernels ${arm_kernels}") message(STATUS "get ARM kernels ${arm_kernels}")
include(ExternalProject) include(ExternalProject)
set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.") "A path setting inference demo download directories.")
if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING) if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite model_parser_lite target_wrapper_host DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels} ${ops_lite} ${host_kernels} ${x86_kernels}
PROFILE_DEPS basic_profiler_lite
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
......
...@@ -30,7 +30,10 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp ...@@ -30,7 +30,10 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
cc_library(types_lite SRCS types.cc) cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
lite_cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto) lite_cc_library(program_lite SRCS program.cc
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite
HVY_DEPS framework_proto
PROFILE_DEPS basic_profiler_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite) cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
add_subdirectory(mir) add_subdirectory(mir)
......
...@@ -3,8 +3,10 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node) ...@@ -3,8 +3,10 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node)
cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion)
cc_library(mir_passes cc_library(mir_passes
SRCS static_kernel_pick_pass.cc SRCS fc_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
type_target_transform_pass.cc type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc io_copy_kernel_pick_pass.cc
...@@ -13,7 +15,7 @@ cc_library(mir_passes ...@@ -13,7 +15,7 @@ cc_library(mir_passes
argument_type_display_pass.cc argument_type_display_pass.cc
demo_pass.cc demo_pass.cc
runtime_context_assign_pass.cc runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite) DEPS mir_pass types_lite context_lite mir_fusers)
# for mobile, unnecessary to compile the following testings. # for mobile, unnecessary to compile the following testings.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
...@@ -53,9 +55,22 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern ...@@ -53,9 +55,22 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite) lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite)
# TODO(wz) replace framework/proto to lite proto. # TODO(wz) replace framework/proto to lite proto.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# it depends on the fluid/framework/proto, that is too heavy for mobile execution. # it depends on the fluid/framework/proto, that is too heavy for mobile execution.
lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite ${ops_lite}) mir_passes compatible_pb_lite program_lite ${ops_lite})
endif() endif()
message(STATUS "----> Ops lite: ${ops_lite}")
message(STATUS "----> Host kernels: ${host_kernels}")
message(STATUS "----> X86 kernels: ${x86_kernels}")
lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model
--optimized_model=${LITE_MODEL_DIR}/lite_fc_model_opt SERIAL)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz")
add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::FcFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class FcFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
TEST(fc_fuse_pass, fuse_test) {
lite::ExecutorLite predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
#else
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)},
});
#endif
predictor.Build(FLAGS_model_dir,
Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({100, 100})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor.Run();
auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->data_size();
LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims();
EXPECT_NEAR(out->data<float>()[0], 38.120617f, 1e-5);
EXPECT_NEAR(out->data<float>()[1], 10.109812f, 1e-5);
CHECK_EQ(out->dims()[0], 100);
CHECK_EQ(out->dims()[1], 500);
}
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(fc_fuse_pass, save_model_test) {
lite::ExecutorLite predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
predictor.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)},
valid_places);
LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
predictor.SaveModel(FLAGS_optimized_model);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(mul);
USE_LITE_OP(elementwise_add);
USE_LITE_OP(elementwise_sub);
USE_LITE_OP(fc);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_OP(softmax);
USE_LITE_OP(scale);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
cc_library(mir_fusers
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void FcFuser::BuildPattern() {
// create nodes.
auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b");
auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out");
auto* add = OpNode("add", "elementwise_add");
auto* Out = VarNode("Out");
// create topology.
std::vector<PMNode*> mul_inputs{W, x};
std::vector<PMNode*> add_inputs{mul_out, b};
mul_inputs >> *mul >> *mul_out;
add_inputs >> *add >> *Out;
// Some op specialities.
mul_out->AsIntermediate();
mul->AsIntermediate();
add->AsIntermediate();
}
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op;
auto* scope = mul->scope();
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("b"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name});
op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class FcFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -22,6 +22,7 @@ namespace mir {} // namespace mir ...@@ -22,6 +22,7 @@ namespace mir {} // namespace mir
} // namespace paddle } // namespace paddle
USE_MIR_PASS(demo); USE_MIR_PASS(demo);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(type_target_transform_pass); USE_MIR_PASS(type_target_transform_pass);
......
...@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) { ...@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
return *this; return *this;
} }
void operator>>(std::vector<PMNode *> &others, PMNode &me) { PMNode &operator>>(std::vector<PMNode *> &others, PMNode &me) {
for (auto *o : others) { for (auto *o : others) {
*o >> me; *o >> me;
} }
return me;
} }
PMNode *PMPattern::NewNode(const std::string &name) { PMNode *PMPattern::NewNode(const std::string &name) {
...@@ -422,6 +423,46 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) { ...@@ -422,6 +423,46 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
return this; return this;
} }
PMNode *PMNode::assert_is_op_input(const std::string &op_type,
const std::string &argument) {
assert_is_var();
assert_is_op_nth_input(op_type, argument, 0);
return this;
}
PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
const std::string &argument, int nth) {
assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->outlinks) {
if (op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
IsNthInput(*x, *op, argument, nth))
return true;
}
return false;
});
return this;
}
bool IsNthInput(const Node &var, const Node &op, const std::string &argument,
int nth) {
CHECK(var.IsArg());
CHECK(op.IsStmt());
if (!HasInput(op, argument) ||
static_cast<int>(op.stmt()->op_info()->Input(argument).size()) <= nth)
return false;
return var.arg()->name == op.stmt()->op_info()->Input(argument)[nth];
}
bool HasInput(const Node &op, const std::string &argument) {
CHECK(op.IsStmt());
auto const &names = op.stmt()->op_info()->input_argnames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
return true;
}
void GraphSafeRemoveNodes(SSAGraph *graph, void GraphSafeRemoveNodes(SSAGraph *graph,
const std::unordered_set<const Node *> &nodes) { const std::unordered_set<const Node *> &nodes) {
for (auto *node : nodes) { for (auto *node : nodes) {
......
...@@ -62,7 +62,7 @@ struct PMNode { ...@@ -62,7 +62,7 @@ struct PMNode {
PMNode& operator>>(PMNode& right); PMNode& operator>>(PMNode& right);
// Link many nodes to this node. // Link many nodes to this node.
friend void operator>>(std::vector<PMNode*>& others, PMNode& me); friend PMNode& operator>>(std::vector<PMNode*>& others, PMNode& me);
// Link this to many other nodes. // Link this to many other nodes.
PMNode& operator>>(std::vector<PMNode*>& nodes); PMNode& operator>>(std::vector<PMNode*>& nodes);
...@@ -127,6 +127,10 @@ struct PMNode { ...@@ -127,6 +127,10 @@ struct PMNode {
PMNode* assert_is_persistable_var(); PMNode* assert_is_persistable_var();
PMNode* assert_is_op_output(const std::string& op_type); PMNode* assert_is_op_output(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type,
const std::string& argument);
PMNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth);
template <typename T> template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
...@@ -297,6 +301,13 @@ class PatternMatcher { ...@@ -297,6 +301,13 @@ class PatternMatcher {
std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_; std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_;
}; };
// Check whether a var node is a op node's nth input.
bool IsNthInput(const Node& var, const Node& op, const std::string& argument,
int nth);
// Check whether the op node has input of given name.
bool HasInput(const Node& op, const std::string& argument);
// Graph safely remove some nodes, will automatically clean up the edges. // Graph safely remove some nodes, will automatically clean up the edges.
void GraphSafeRemoveNodes(SSAGraph* graph, void GraphSafeRemoveNodes(SSAGraph* graph,
const std::unordered_set<const Node*>& nodes); const std::unordered_set<const Node*>& nodes);
......
...@@ -64,7 +64,6 @@ class FuseBase { ...@@ -64,7 +64,6 @@ class FuseBase {
// Delete nodes that are marked as Intermediate // Delete nodes that are marked as Intermediate
void DeleteInterNodes(SSAGraph* graph); void DeleteInterNodes(SSAGraph* graph);
private:
PMNode* GetOrCreateNode(const std::string& key); PMNode* GetOrCreateNode(const std::string& key);
protected: protected:
......
...@@ -29,8 +29,8 @@ class FcFuser : public FuseBase { ...@@ -29,8 +29,8 @@ class FcFuser : public FuseBase {
public: public:
void BuildPattern() override { void BuildPattern() override {
// create nodes. // create nodes.
auto* x = VarNode("x"); auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W"); auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b"); auto* b = VarNode("b");
auto* mul = OpNode("mul", "mul"); auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out"); auto* mul_out = VarNode("mul_out");
...@@ -38,12 +38,10 @@ class FcFuser : public FuseBase { ...@@ -38,12 +38,10 @@ class FcFuser : public FuseBase {
auto* Out = VarNode("Out"); auto* Out = VarNode("Out");
// create topology. // create topology.
// std::vector<PMNode*>({W, x}) >> *mul >> *mul_out; std::vector<PMNode*> mul_inputs{W, x};
// std::vector<PMNode*>({mul_out, b}) >> *add >> *Out; std::vector<PMNode*> add_inputs{mul_out, b};
*W >> *mul; mul_inputs >> *mul >> *mul_out;
*x >> *mul >> *mul_out; add_inputs >> *add >> *Out;
*b >> *add;
*mul_out >> *add >> *Out;
// Some op specialities. // Some op specialities.
mul_out->AsIntermediate(); mul_out->AsIntermediate();
...@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc, ...@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block->Var("mul_out"); main_block->Var("mul_out");
main_block->Var("w"); main_block->Var("w");
main_block->Var("out"); main_block->Var("out");
main_block->Var("out1");
scope->Var("w")->GetMutable<lite::Tensor>(); scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("b")->GetMutable<lite::Tensor>(); scope->Var("b")->GetMutable<lite::Tensor>();
scope->Var("mul_out")->GetMutable<lite::Tensor>(); scope->Var("mul_out")->GetMutable<lite::Tensor>();
scope->Var("w")->GetMutable<lite::Tensor>(); scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>(); scope->Var("out")->GetMutable<lite::Tensor>();
scope->Var("out1")->GetMutable<lite::Tensor>();
mul->SetInput("X", {"x"}); mul->SetInput("X", {"x"});
mul->SetInput("Y", {"w"}); mul->SetInput("Y", {"w"});
...@@ -122,18 +118,18 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc, ...@@ -122,18 +118,18 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
return graph; return graph;
} }
TEST(pattern_matcher2, graph_test) { TEST(pattern_matcher_high_api, graph_test) {
framework::ProgramDesc program_desc; framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}}; std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places); auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(), ASSERT_EQ(graph->nodes().size(),
8UL /*real nodes*/ + 2UL /*feed op + fetch op*/); 7UL /*real nodes*/ + 2UL /*feed op + fetch op*/);
Visualize(graph.get()); Visualize(graph.get());
} }
TEST(pattern_matcher2, test) { TEST(pattern_matcher_high_api, fuse_test) {
framework::ProgramDesc program_desc; framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}}; std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
...@@ -143,6 +139,7 @@ TEST(pattern_matcher2, test) { ...@@ -143,6 +139,7 @@ TEST(pattern_matcher2, test) {
fuser(graph.get()); fuser(graph.get());
ASSERT_EQ(graph->nodes().size(), ASSERT_EQ(graph->nodes().size(),
num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/); num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/);
Visualize(graph.get());
} }
} // namespace mir } // namespace mir
......
...@@ -49,6 +49,7 @@ class Optimizer { ...@@ -49,6 +49,7 @@ class Optimizer {
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
if (passes.empty()) { if (passes.empty()) {
RunPasses(std::vector<std::string>{{ RunPasses(std::vector<std::string>{{
"lite_fc_fuse_pass", //
"static_kernel_pick_pass", // "static_kernel_pick_pass", //
"variable_place_inference_pass", // "variable_place_inference_pass", //
"argument_type_display_pass", // "argument_type_display_pass", //
......
...@@ -152,8 +152,8 @@ class BasicProfiler { ...@@ -152,8 +152,8 @@ class BasicProfiler {
} }
record_t *mutable_record(int id) { record_t *mutable_record(int id) {
CHECK_LT(id, records_.size());
CHECK_GE(id, 0); CHECK_GE(id, 0);
CHECK_LT(static_cast<size_t>(id), records_.size());
return &records_[id]; return &records_[id];
} }
......
...@@ -10,6 +10,4 @@ set(host_kernels ...@@ -10,6 +10,4 @@ set(host_kernels
feed_compute_host feed_compute_host
fetch_compute_host fetch_compute_host
reshape_compute_host reshape_compute_host
) CACHE INTERNAL "host kernels")
set(host_kernels "${host_kernels}" CACHE GLOBAL "host kernels")
...@@ -32,6 +32,4 @@ set(x86_kernels ...@@ -32,6 +32,4 @@ set(x86_kernels
concat_compute_x86 concat_compute_x86
conv_compute_x86 conv_compute_x86
pool_compute_x86 pool_compute_x86
) CACHE INTERNAL "x86 kernels")
set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels")
...@@ -27,8 +27,8 @@ namespace kernels { ...@@ -27,8 +27,8 @@ namespace kernels {
namespace x86 { namespace x86 {
template <typename T> template <typename T>
void fc_compute_eigen(const T* x, int x_w, int x_h, // void fc_compute_eigen(const T* x, int x_h, int x_w, //
const T* w, int w_w, int w_h, // const T* w, int w_h, int w_w, //
const T* b, // const T* b, //
T* out) { T* out) {
using matrix_t = using matrix_t =
...@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, // ...@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, //
Eigen::Map<const matrix_t> X(x, x_h, x_w); Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w); Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_h); Eigen::Map<matrix_t> Out(out, x_h, w_w);
Out = X * W.transpose(); Out = X * W;
if (b) { if (b) {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_h); Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_w);
Out = Out.array().rowwise() + B.transpose().array(); Out = Out.array().rowwise() + B.transpose().array();
} }
} }
template <typename T> template <typename T>
__attribute__((optimize("unroll-loops"))) // void fc_compute_naive(const T* x, int x_h, int x_w, //
T dot(const T* x, const T* y, int dim) { const T* w, int w_h, int w_w, //
T out{};
for (int i = 0; i < dim; i++) {
out += x[i] * y[i];
}
return out;
}
template <typename T>
void fc_compute_naive(const T* x, int x_w, int x_h, //
const T* w, int w_w, int w_h, //
const T* b, // const T* b, //
T* out) { T* out) {
CHECK_EQ(x_w, w_w); CHECK_EQ(x_w, w_h);
// out shape: (x_h, w_w) // out shape: (x_h, w_w)
memset(out, 0, x_h * w_h * sizeof(T)); memset(out, 0, x_h * w_w * sizeof(T));
for (int i = 0; i < x_h; i++) {
for (int r = 0; r < x_h; r++) { for (int j = 0; j < w_w; j++) {
for (int c = 0; c < w_h; c++) { T tmp = static_cast<T>(0);
out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c]; for (int k = 0; k < x_w; k++) {
tmp += x[i * x_w + k] * w[k * w_w + j];
}
out[i * w_w + j] = tmp + b[j];
} }
} }
} }
...@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
.Slice(param.in_num_col_dims, param.input->dims().size()) .Slice(param.in_num_col_dims, param.input->dims().size())
.production(), .production(),
param.w->data<T>(), // w param.w->data<T>(), // w
param.w->dims()[1], // w_w
param.w->dims()[0], // w_h param.w->dims()[0], // w_h
param.w->dims()[1], // w_w
param.bias->data<T>(), // b param.bias->data<T>(), // b
param.output->mutable_data<T>()); param.output->mutable_data<T>());
} }
......
...@@ -38,7 +38,7 @@ set(ops_lite ...@@ -38,7 +38,7 @@ set(ops_lite
concat_op_lite concat_op_lite
conv_op_lite conv_op_lite
pool_op_lite pool_op_lite
PARENT_SCOPE) CACHE INTERNAL "ops lite")
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
DEPS fc_op_lite memory_lite DEPS fc_op_lite memory_lite
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册