From e44b87604f6eb9f7b75eda5d56c8afbbbcffb3b9 Mon Sep 17 00:00:00 2001 From: Zhen Wang Date: Thu, 13 Jun 2019 18:08:57 +0800 Subject: [PATCH] Add Fc fuse pass (#17994) --- paddle/fluid/lite/CMakeLists.txt | 5 +- paddle/fluid/lite/api/CMakeLists.txt | 6 +- paddle/fluid/lite/core/CMakeLists.txt | 5 +- paddle/fluid/lite/core/mir/CMakeLists.txt | 21 +++- paddle/fluid/lite/core/mir/fc_fuse_pass.cc | 34 ++++++ paddle/fluid/lite/core/mir/fc_fuse_pass.h | 32 +++++ .../fluid/lite/core/mir/fc_fuse_pass_test.cc | 112 ++++++++++++++++++ .../fluid/lite/core/mir/fusion/CMakeLists.txt | 3 + paddle/fluid/lite/core/mir/fusion/fc_fuser.cc | 78 ++++++++++++ paddle/fluid/lite/core/mir/fusion/fc_fuser.h | 38 ++++++ paddle/fluid/lite/core/mir/passes.h | 1 + paddle/fluid/lite/core/mir/pattern_matcher.cc | 43 ++++++- paddle/fluid/lite/core/mir/pattern_matcher.h | 13 +- .../lite/core/mir/pattern_matcher_high_api.h | 1 - .../core/mir/pattern_matcher_high_api_test.cc | 23 ++-- paddle/fluid/lite/core/optimizer.h | 1 + .../fluid/lite/core/profile/basic_profiler.h | 2 +- paddle/fluid/lite/kernels/host/CMakeLists.txt | 4 +- paddle/fluid/lite/kernels/x86/CMakeLists.txt | 8 +- paddle/fluid/lite/kernels/x86/fc_compute.cc | 41 +++---- paddle/fluid/lite/operators/CMakeLists.txt | 36 +++--- 21 files changed, 430 insertions(+), 77 deletions(-) create mode 100644 paddle/fluid/lite/core/mir/fc_fuse_pass.cc create mode 100644 paddle/fluid/lite/core/mir/fc_fuse_pass.h create mode 100644 paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc create mode 100644 paddle/fluid/lite/core/mir/fusion/CMakeLists.txt create mode 100644 paddle/fluid/lite/core/mir/fusion/fc_fuser.cc create mode 100644 paddle/fluid/lite/core/mir/fusion/fc_fuser.h diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt index ac9ff84da44..269cc95b658 100644 --- a/paddle/fluid/lite/CMakeLists.txt +++ b/paddle/fluid/lite/CMakeLists.txt @@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}") message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}") set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install") +set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url") function(lite_download_and_uncompress INSTALL_DIR URL FILENAME) message(STATUS "Download inference test stuff from ${URL}/${FILENAME}") @@ -161,13 +162,13 @@ function(lite_cc_test TARGET) file(APPEND ${offline_test_registry_file} "${TARGET}\n") endfunction() +add_subdirectory(operators) +add_subdirectory(kernels) add_subdirectory(core) add_subdirectory(x86) add_subdirectory(arm) add_subdirectory(host) add_subdirectory(cuda) -add_subdirectory(operators) -add_subdirectory(kernels) add_subdirectory(model_parser) add_subdirectory(utils) add_subdirectory(api) diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt index 0adaeffbb4f..78f85a8caeb 100644 --- a/paddle/fluid/lite/api/CMakeLists.txt +++ b/paddle/fluid/lite/api/CMakeLists.txt @@ -5,7 +5,7 @@ if(LITE_WITH_CUDA) nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda) endif() -cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite}) +cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite} program_lite) set(light_api_deps scope_lite target_wrapper_host model_parser_lite) @@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}") message(STATUS "get ARM kernels ${arm_kernels}") include(ExternalProject) -set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url") set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING "A path setting inference demo download directories.") if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING) lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc - DEPS cxx_api_lite model_parser_lite target_wrapper_host + DEPS cxx_api_lite mir_passes ${ops_lite} ${host_kernels} ${x86_kernels} - PROFILE_DEPS basic_profiler_lite ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt index e5aef8d84fa..3edd5db08fd 100644 --- a/paddle/fluid/lite/core/CMakeLists.txt +++ b/paddle/fluid/lite/core/CMakeLists.txt @@ -30,7 +30,10 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp cc_library(types_lite SRCS types.cc) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) -lite_cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto) +lite_cc_library(program_lite SRCS program.cc + DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite + HVY_DEPS framework_proto + PROFILE_DEPS basic_profiler_lite) cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite) add_subdirectory(mir) diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt index 84cba88d11d..8e824727468 100644 --- a/paddle/fluid/lite/core/mir/CMakeLists.txt +++ b/paddle/fluid/lite/core/mir/CMakeLists.txt @@ -3,8 +3,10 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node) cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) +add_subdirectory(fusion) cc_library(mir_passes - SRCS static_kernel_pick_pass.cc + SRCS fc_fuse_pass.cc + static_kernel_pick_pass.cc variable_place_inference_pass.cc type_target_transform_pass.cc io_copy_kernel_pick_pass.cc @@ -13,7 +15,7 @@ cc_library(mir_passes argument_type_display_pass.cc demo_pass.cc runtime_context_assign_pass.cc - DEPS mir_pass types_lite context_lite) + DEPS mir_pass types_lite context_lite mir_fusers) # for mobile, unnecessary to compile the following testings. if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) @@ -53,9 +55,22 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite) # TODO(wz) replace framework/proto to lite proto. -if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) +if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) # it depends on the fluid/framework/proto, that is too heavy for mobile execution. lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite mir_passes compatible_pb_lite program_lite ${ops_lite}) endif() + +message(STATUS "----> Ops lite: ${ops_lite}") +message(STATUS "----> Host kernels: ${host_kernels}") +message(STATUS "----> X86 kernels: ${x86_kernels}") + +lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc + DEPS cxx_api_lite mir_passes + ${ops_lite} ${host_kernels} ${x86_kernels} + ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model + --optimized_model=${LITE_MODEL_DIR}/lite_fc_model_opt SERIAL) + +lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz") +add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz) diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.cc b/paddle/fluid/lite/core/mir/fc_fuse_pass.cc new file mode 100644 index 00000000000..008f05ce5cb --- /dev/null +++ b/paddle/fluid/lite/core/mir/fc_fuse_pass.cc @@ -0,0 +1,34 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" +#include +#include +#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h" +#include "paddle/fluid/lite/core/mir/pass_registry.h" + +namespace paddle { +namespace lite { +namespace mir { + +void FcFusePass::Apply(const std::unique_ptr& graph) { + fusion::FcFuser fuser; + fuser(graph.get()); +} + +} // namespace mir +} // namespace lite +} // namespace paddle + +REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass); diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.h b/paddle/fluid/lite/core/mir/fc_fuse_pass.h new file mode 100644 index 00000000000..f1b548c43f9 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fc_fuse_pass.h @@ -0,0 +1,32 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pass.h" + +namespace paddle { +namespace lite { +namespace mir { + +class FcFusePass : public ProgramPass { + public: + void Apply(const std::unique_ptr& graph) override; +}; + +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc new file mode 100644 index 00000000000..35efedb5797 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc @@ -0,0 +1,112 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" +#include +#include +#include +#include "paddle/fluid/lite/api/cxx_api.h" +#include "paddle/fluid/lite/core/mir/passes.h" +#include "paddle/fluid/lite/core/op_registry.h" + +DEFINE_string(model_dir, "", ""); +DEFINE_string(optimized_model, "", ""); + +namespace paddle { +namespace lite { +namespace mir { + +TEST(fc_fuse_pass, fuse_test) { + lite::ExecutorLite predictor; +#ifndef LITE_WITH_CUDA + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); +#else + std::vector valid_places({ + Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)}, + Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)}, + Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)}, + }); +#endif + + predictor.Build(FLAGS_model_dir, + Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda + valid_places); + + auto* input_tensor = predictor.GetInput(0); + input_tensor->Resize(DDim(std::vector({100, 100}))); + auto* data = input_tensor->mutable_data(); + for (int i = 0; i < 100 * 100; i++) { + data[i] = i; + } + + predictor.Run(); + + auto* out = predictor.GetOutput(0); + LOG(INFO) << out << " memory size " << out->data_size(); + LOG(INFO) << "out " << out->data()[0]; + LOG(INFO) << "out " << out->data()[1]; + LOG(INFO) << "dims " << out->dims(); + EXPECT_NEAR(out->data()[0], 38.120617f, 1e-5); + EXPECT_NEAR(out->data()[1], 10.109812f, 1e-5); + CHECK_EQ(out->dims()[0], 100); + CHECK_EQ(out->dims()[1], 500); +} + +#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK +TEST(fc_fuse_pass, save_model_test) { + lite::ExecutorLite predictor; + std::vector valid_places({Place{TARGET(kHost), PRECISION(kFloat)}, + Place{TARGET(kX86), PRECISION(kFloat)}}); + predictor.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)}, + valid_places); + + LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model; + predictor.SaveModel(FLAGS_optimized_model); +} +#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK + +} // namespace mir +} // namespace lite +} // namespace paddle + +USE_LITE_OP(mul); +USE_LITE_OP(elementwise_add); +USE_LITE_OP(elementwise_sub); +USE_LITE_OP(fc); +USE_LITE_OP(feed); +USE_LITE_OP(fetch); +USE_LITE_OP(io_copy); +USE_LITE_OP(softmax); +USE_LITE_OP(scale); +USE_LITE_KERNEL(feed, kHost, kAny, kAny, def); +USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def); + +#ifdef LITE_WITH_X86 +USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def); +USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def); +#endif + +#ifdef LITE_WITH_CUDA +USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device); +USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host); +#endif diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt new file mode 100644 index 00000000000..6a0626b9eb5 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt @@ -0,0 +1,3 @@ +cc_library(mir_fusers + SRCS fc_fuser.cc + DEPS pattern_matcher_high_api) diff --git a/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc new file mode 100644 index 00000000000..a8b6336595c --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc @@ -0,0 +1,78 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h" +#include +#include + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +void FcFuser::BuildPattern() { + // create nodes. + auto* x = VarNode("x")->assert_is_op_input("mul", "X"); + auto* W = VarNode("W")->assert_is_op_input("mul", "Y"); + auto* b = VarNode("b"); + auto* mul = OpNode("mul", "mul"); + auto* mul_out = VarNode("mul_out"); + auto* add = OpNode("add", "elementwise_add"); + auto* Out = VarNode("Out"); + + // create topology. + std::vector mul_inputs{W, x}; + std::vector add_inputs{mul_out, b}; + mul_inputs >> *mul >> *mul_out; + add_inputs >> *add >> *Out; + + // Some op specialities. + mul_out->AsIntermediate(); + mul->AsIntermediate(); + add->AsIntermediate(); +} + +void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { + auto op_desc = GenOpDesc(matched); + auto fc_op = LiteOpRegistry::Global().Create("fc"); + auto mul = matched.at("mul")->stmt()->op; + auto* scope = mul->scope(); + auto& valid_places = mul->valid_places(); + fc_op->Attach(op_desc, scope); + + auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places); + + IR_NODE_LINK_TO(matched.at("W"), new_op_node); + IR_NODE_LINK_TO(matched.at("x"), new_op_node); + IR_NODE_LINK_TO(matched.at("b"), new_op_node); + IR_NODE_LINK_TO(new_op_node, matched.at("Out")); +} + +cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) { + cpp::OpDesc op_desc; + op_desc.SetType("fc"); + op_desc.SetInput("Input", {matched.at("x")->arg()->name}); + op_desc.SetInput("W", {matched.at("W")->arg()->name}); + op_desc.SetInput("Bias", {matched.at("b")->arg()->name}); + op_desc.SetOutput("Out", {matched.at("Out")->arg()->name}); + op_desc.SetAttr( + "in_num_col_dims", + matched.at("mul")->stmt()->op_info()->GetAttr("x_num_col_dims")); + return op_desc; +} + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/fusion/fc_fuser.h b/paddle/fluid/lite/core/mir/fusion/fc_fuser.h new file mode 100644 index 00000000000..0e2bc3bc3c3 --- /dev/null +++ b/paddle/fluid/lite/core/mir/fusion/fc_fuser.h @@ -0,0 +1,38 @@ +// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include +#include +#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h" + +namespace paddle { +namespace lite { +namespace mir { +namespace fusion { + +class FcFuser : public FuseBase { + public: + void BuildPattern() override; + void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override; + + private: + cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override; +}; + +} // namespace fusion +} // namespace mir +} // namespace lite +} // namespace paddle diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h index 60e53257ba0..217e4d5dbdf 100644 --- a/paddle/fluid/lite/core/mir/passes.h +++ b/paddle/fluid/lite/core/mir/passes.h @@ -22,6 +22,7 @@ namespace mir {} // namespace mir } // namespace paddle USE_MIR_PASS(demo); +USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(static_kernel_pick_pass); USE_MIR_PASS(variable_place_inference_pass); USE_MIR_PASS(type_target_transform_pass); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.cc b/paddle/fluid/lite/core/mir/pattern_matcher.cc index c7fa42ac5a7..8a83bd242bd 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher.cc @@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector &nodes) { return *this; } -void operator>>(std::vector &others, PMNode &me) { +PMNode &operator>>(std::vector &others, PMNode &me) { for (auto *o : others) { *o >> me; } + return me; } PMNode *PMPattern::NewNode(const std::string &name) { @@ -422,6 +423,46 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) { return this; } +PMNode *PMNode::assert_is_op_input(const std::string &op_type, + const std::string &argument) { + assert_is_var(); + assert_is_op_nth_input(op_type, argument, 0); + return this; +} + +PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type, + const std::string &argument, int nth) { + assert_is_var(); + assert_is_op_input(op_type); + asserts_.emplace_back([=](const Node *x) { + for (auto *op : x->outlinks) { + if (op->IsStmt() && op->stmt()->op_info()->Type() == op_type && + IsNthInput(*x, *op, argument, nth)) + return true; + } + return false; + }); + return this; +} + +bool IsNthInput(const Node &var, const Node &op, const std::string &argument, + int nth) { + CHECK(var.IsArg()); + CHECK(op.IsStmt()); + if (!HasInput(op, argument) || + static_cast(op.stmt()->op_info()->Input(argument).size()) <= nth) + return false; + return var.arg()->name == op.stmt()->op_info()->Input(argument)[nth]; +} + +bool HasInput(const Node &op, const std::string &argument) { + CHECK(op.IsStmt()); + auto const &names = op.stmt()->op_info()->input_argnames(); + if (std::find(names.begin(), names.end(), argument) == names.end()) + return false; + return true; +} + void GraphSafeRemoveNodes(SSAGraph *graph, const std::unordered_set &nodes) { for (auto *node : nodes) { diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.h b/paddle/fluid/lite/core/mir/pattern_matcher.h index 2241e71af3d..f2862a229e3 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher.h +++ b/paddle/fluid/lite/core/mir/pattern_matcher.h @@ -62,7 +62,7 @@ struct PMNode { PMNode& operator>>(PMNode& right); // Link many nodes to this node. - friend void operator>>(std::vector& others, PMNode& me); + friend PMNode& operator>>(std::vector& others, PMNode& me); // Link this to many other nodes. PMNode& operator>>(std::vector& nodes); @@ -127,6 +127,10 @@ struct PMNode { PMNode* assert_is_persistable_var(); PMNode* assert_is_op_output(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type); + PMNode* assert_is_op_input(const std::string& op_type, + const std::string& argument); + PMNode* assert_is_op_nth_input(const std::string& op_type, + const std::string& argument, int nth); template PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { @@ -297,6 +301,13 @@ class PatternMatcher { std::unordered_map> pmnodes2nodes_; }; +// Check whether a var node is a op node's nth input. +bool IsNthInput(const Node& var, const Node& op, const std::string& argument, + int nth); + +// Check whether the op node has input of given name. +bool HasInput(const Node& op, const std::string& argument); + // Graph safely remove some nodes, will automatically clean up the edges. void GraphSafeRemoveNodes(SSAGraph* graph, const std::unordered_set& nodes); diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h index 645e33165f4..b3a23c654bd 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h @@ -64,7 +64,6 @@ class FuseBase { // Delete nodes that are marked as Intermediate void DeleteInterNodes(SSAGraph* graph); - private: PMNode* GetOrCreateNode(const std::string& key); protected: diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc index 44f95dab754..beee4d32acb 100644 --- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc +++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc @@ -29,8 +29,8 @@ class FcFuser : public FuseBase { public: void BuildPattern() override { // create nodes. - auto* x = VarNode("x"); - auto* W = VarNode("W"); + auto* x = VarNode("x")->assert_is_op_input("mul", "X"); + auto* W = VarNode("W")->assert_is_op_input("mul", "Y"); auto* b = VarNode("b"); auto* mul = OpNode("mul", "mul"); auto* mul_out = VarNode("mul_out"); @@ -38,12 +38,10 @@ class FcFuser : public FuseBase { auto* Out = VarNode("Out"); // create topology. - // std::vector({W, x}) >> *mul >> *mul_out; - // std::vector({mul_out, b}) >> *add >> *Out; - *W >> *mul; - *x >> *mul >> *mul_out; - *b >> *add; - *mul_out >> *add >> *Out; + std::vector mul_inputs{W, x}; + std::vector add_inputs{mul_out, b}; + mul_inputs >> *mul >> *mul_out; + add_inputs >> *add >> *Out; // Some op specialities. mul_out->AsIntermediate(); @@ -91,14 +89,12 @@ std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, main_block->Var("mul_out"); main_block->Var("w"); main_block->Var("out"); - main_block->Var("out1"); scope->Var("w")->GetMutable(); scope->Var("b")->GetMutable(); scope->Var("mul_out")->GetMutable(); scope->Var("w")->GetMutable(); scope->Var("out")->GetMutable(); - scope->Var("out1")->GetMutable(); mul->SetInput("X", {"x"}); mul->SetInput("Y", {"w"}); @@ -122,18 +118,18 @@ std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, return graph; } -TEST(pattern_matcher2, graph_test) { +TEST(pattern_matcher_high_api, graph_test) { framework::ProgramDesc program_desc; std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; auto scope = std::make_shared(); auto graph = BuildGraph(&program_desc, scope, places); ASSERT_EQ(graph->nodes().size(), - 8UL /*real nodes*/ + 2UL /*feed op + fetch op*/); + 7UL /*real nodes*/ + 2UL /*feed op + fetch op*/); Visualize(graph.get()); } -TEST(pattern_matcher2, test) { +TEST(pattern_matcher_high_api, fuse_test) { framework::ProgramDesc program_desc; std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; auto scope = std::make_shared(); @@ -143,6 +139,7 @@ TEST(pattern_matcher2, test) { fuser(graph.get()); ASSERT_EQ(graph->nodes().size(), num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/); + Visualize(graph.get()); } } // namespace mir diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h index 161e765a98b..b78408a6740 100644 --- a/paddle/fluid/lite/core/optimizer.h +++ b/paddle/fluid/lite/core/optimizer.h @@ -49,6 +49,7 @@ class Optimizer { #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK if (passes.empty()) { RunPasses(std::vector{{ + "lite_fc_fuse_pass", // "static_kernel_pick_pass", // "variable_place_inference_pass", // "argument_type_display_pass", // diff --git a/paddle/fluid/lite/core/profile/basic_profiler.h b/paddle/fluid/lite/core/profile/basic_profiler.h index c50aeab4af5..16a9905f1ae 100644 --- a/paddle/fluid/lite/core/profile/basic_profiler.h +++ b/paddle/fluid/lite/core/profile/basic_profiler.h @@ -152,8 +152,8 @@ class BasicProfiler { } record_t *mutable_record(int id) { - CHECK_LT(id, records_.size()); CHECK_GE(id, 0); + CHECK_LT(static_cast(id), records_.size()); return &records_[id]; } diff --git a/paddle/fluid/lite/kernels/host/CMakeLists.txt b/paddle/fluid/lite/kernels/host/CMakeLists.txt index a71a8e13ab8..d1f33477aaa 100644 --- a/paddle/fluid/lite/kernels/host/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/host/CMakeLists.txt @@ -10,6 +10,4 @@ set(host_kernels feed_compute_host fetch_compute_host reshape_compute_host - ) - -set(host_kernels "${host_kernels}" CACHE GLOBAL "host kernels") + CACHE INTERNAL "host kernels") diff --git a/paddle/fluid/lite/kernels/x86/CMakeLists.txt b/paddle/fluid/lite/kernels/x86/CMakeLists.txt index 3747351d562..6309267dd06 100644 --- a/paddle/fluid/lite/kernels/x86/CMakeLists.txt +++ b/paddle/fluid/lite/kernels/x86/CMakeLists.txt @@ -30,8 +30,6 @@ set(x86_kernels softmax_compute_x86 dropout_compute_x86 concat_compute_x86 - conv_compute_x86 - pool_compute_x86 - ) - -set(x86_kernels "${x86_kernels}" CACHE INTERNAL "x86 kernels") + conv_compute_x86 + pool_compute_x86 + CACHE INTERNAL "x86 kernels") diff --git a/paddle/fluid/lite/kernels/x86/fc_compute.cc b/paddle/fluid/lite/kernels/x86/fc_compute.cc index c89f0f19dad..dad37febc80 100644 --- a/paddle/fluid/lite/kernels/x86/fc_compute.cc +++ b/paddle/fluid/lite/kernels/x86/fc_compute.cc @@ -27,8 +27,8 @@ namespace kernels { namespace x86 { template -void fc_compute_eigen(const T* x, int x_w, int x_h, // - const T* w, int w_w, int w_h, // +void fc_compute_eigen(const T* x, int x_h, int x_w, // + const T* w, int w_h, int w_w, // const T* b, // T* out) { using matrix_t = @@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, // Eigen::Map X(x, x_h, x_w); Eigen::Map W(w, w_h, w_w); - Eigen::Map Out(out, x_h, w_h); + Eigen::Map Out(out, x_h, w_w); - Out = X * W.transpose(); + Out = X * W; if (b) { - Eigen::Map> B(b, w_h); + Eigen::Map> B(b, w_w); Out = Out.array().rowwise() + B.transpose().array(); } } template -__attribute__((optimize("unroll-loops"))) // -T dot(const T* x, const T* y, int dim) { - T out{}; - for (int i = 0; i < dim; i++) { - out += x[i] * y[i]; - } - return out; -} - -template -void fc_compute_naive(const T* x, int x_w, int x_h, // - const T* w, int w_w, int w_h, // +void fc_compute_naive(const T* x, int x_h, int x_w, // + const T* w, int w_h, int w_w, // const T* b, // T* out) { - CHECK_EQ(x_w, w_w); + CHECK_EQ(x_w, w_h); // out shape: (x_h, w_w) - memset(out, 0, x_h * w_h * sizeof(T)); - - for (int r = 0; r < x_h; r++) { - for (int c = 0; c < w_h; c++) { - out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c]; + memset(out, 0, x_h * w_w * sizeof(T)); + for (int i = 0; i < x_h; i++) { + for (int j = 0; j < w_w; j++) { + T tmp = static_cast(0); + for (int k = 0; k < x_w; k++) { + tmp += x[i * x_w + k] * w[k * w_w + j]; + } + out[i * w_w + j] = tmp + b[j]; } } } @@ -89,8 +82,8 @@ class FcCompute : public KernelLite { .Slice(param.in_num_col_dims, param.input->dims().size()) .production(), param.w->data(), // w - param.w->dims()[1], // w_w param.w->dims()[0], // w_h + param.w->dims()[1], // w_w param.bias->data(), // b param.output->mutable_data()); } diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt index ed26f5fdb1f..691ff743b17 100644 --- a/paddle/fluid/lite/operators/CMakeLists.txt +++ b/paddle/fluid/lite/operators/CMakeLists.txt @@ -21,24 +21,24 @@ cc_library(conv_op_lite SRCS conv_op.cc DEPS ${op_DEPS}) cc_library(pool_op_lite SRCS pool_op.cc DEPS ${op_DEPS}) set(ops_lite - fc_op_lite - relu_op_lite - mul_op_lite - scale_op_lite - softmax_op_lite - reshape_op_lite - feed_op_lite - fetch_op_lite - io_copy_op_lite - elementwise_ops_lite - mean_op_lite - fill_constant_op_lite - activation_ops_lite - dropout_op_lite - concat_op_lite - conv_op_lite - pool_op_lite - PARENT_SCOPE) + fc_op_lite + relu_op_lite + mul_op_lite + scale_op_lite + softmax_op_lite + reshape_op_lite + feed_op_lite + fetch_op_lite + io_copy_op_lite + elementwise_ops_lite + mean_op_lite + fill_constant_op_lite + activation_ops_lite + dropout_op_lite + concat_op_lite + conv_op_lite + pool_op_lite + CACHE INTERNAL "ops lite") lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc DEPS fc_op_lite memory_lite -- GitLab