diff --git a/paddle/fluid/lite/CMakeLists.txt b/paddle/fluid/lite/CMakeLists.txt
index 301dbea2b7601d43b20095685d82a11ae5dcc2f6..978fb0eec8ae5e52f7d6833233417b35a6890524 100644
--- a/paddle/fluid/lite/CMakeLists.txt
+++ b/paddle/fluid/lite/CMakeLists.txt
@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
 message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}")
 
 set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install")
+set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
 
 function(lite_download_and_uncompress INSTALL_DIR URL FILENAME)
   message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
@@ -161,13 +162,13 @@ function(lite_cc_test TARGET)
   file(APPEND ${offline_test_registry_file} "${TARGET}\n")
 endfunction()
 
+add_subdirectory(operators)
+add_subdirectory(kernels)
 add_subdirectory(core)
 add_subdirectory(x86)
 add_subdirectory(arm)
 add_subdirectory(host)
 add_subdirectory(cuda)
-add_subdirectory(operators)
-add_subdirectory(kernels)
 add_subdirectory(model_parser)
 add_subdirectory(utils)
 add_subdirectory(api)
diff --git a/paddle/fluid/lite/api/CMakeLists.txt b/paddle/fluid/lite/api/CMakeLists.txt
index ec65fb69a4fc4e8c5c18c7476ca6a3d170f6447f..2fac7f3d9b2c7b59dbdbcd6c01734e67c8ffe2d2 100644
--- a/paddle/fluid/lite/api/CMakeLists.txt
+++ b/paddle/fluid/lite/api/CMakeLists.txt
@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA)
   nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda)
 endif()
 
-cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite})
+cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite} program_lite)
 
 set(light_api_deps scope_lite target_wrapper_host model_parser_lite)
 
@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}")
 message(STATUS "get ARM kernels ${arm_kernels}")
 
 include(ExternalProject)
-set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
 set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
     "A path setting inference demo download directories.")
 
 if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
     lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
-       DEPS cxx_api_lite model_parser_lite target_wrapper_host
+       DEPS cxx_api_lite mir_passes
        ${ops_lite} ${host_kernels} ${x86_kernels}
-       PROFILE_DEPS basic_profiler_lite
        ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
             --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
diff --git a/paddle/fluid/lite/core/CMakeLists.txt b/paddle/fluid/lite/core/CMakeLists.txt
index 227216990fc3af39529c40ffc14d06339ca20047..89101aa03272d98ac08d7830830de6acb9adf271 100644
--- a/paddle/fluid/lite/core/CMakeLists.txt
+++ b/paddle/fluid/lite/core/CMakeLists.txt
@@ -30,7 +30,9 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
 cc_library(types_lite SRCS types.cc)
 cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
 
-lite_cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto
+lite_cc_library(program_lite SRCS program.cc
+    DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite
+    HVY_DEPS framework_proto
     PROFILE_DEPS basic_profiler_lite)
 cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
diff --git a/paddle/fluid/lite/core/mir/CMakeLists.txt b/paddle/fluid/lite/core/mir/CMakeLists.txt
index c3d3df9c6778eee53bf6492f4c4bfae36ae80687..fe7defcf73e6bea6819c62ae36c87b59eb4f09b2 100644
--- a/paddle/fluid/lite/core/mir/CMakeLists.txt
+++ b/paddle/fluid/lite/core/mir/CMakeLists.txt
@@ -3,8 +3,13 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program_lite)
 cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
 cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
 cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
+
+add_subdirectory(fusion)
 cc_library(mir_passes
-  SRCS static_kernel_pick_pass.cc
+  SRCS fc_fuse_pass.cc
+       conv_elementwise_add_relu_fuse_pass.cc
+       conv_bn_fuse_pass.cc
+       static_kernel_pick_pass.cc
        variable_place_inference_pass.cc
        type_target_transform_pass.cc
        io_copy_kernel_pick_pass.cc
@@ -13,13 +18,8 @@ cc_library(mir_passes
        argument_type_display_pass.cc
        demo_pass.cc
        runtime_context_assign_pass.cc
-  DEPS mir_pass types_lite context_lite)
+  DEPS mir_pass types_lite context_lite ${mir_fusers})
 
-# for mobile, unnecessary to compile the following testings.
-if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
-  return()
-endif()
-cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
 #cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
         #mir_ssa_graph scope_lite op_lite
         #fc_op_lite
@@ -52,11 +52,37 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
 
 lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite)
 
-# TODO(wz) replace framework/proto to lite proto.
+
+# for mobile, unnecessary to compile the following testings.
 if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
+  return()
+endif()
+cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
+
+
+# TODO(wz) replace framework/proto to lite proto.
+if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
     # it depends on the fluid/framework/proto, that is too heavy for mobile execution.
     lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc
         DEPS pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite
         elementwise_ops_lite mir_passes compatible_pb_lite program_lite ${ops_lite})
 endif()
-
+
+message(STATUS "----> Ops lite: ${ops_lite}")
+message(STATUS "----> Host kernels: ${host_kernels}")
+message(STATUS "----> X86 kernels: ${x86_kernels}")
+
+lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
+    DEPS cxx_api_lite mir_passes
+    ${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels}
+    ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model
+         --optimized_model=${LITE_MODEL_DIR}/lite_fc_model_opt SERIAL)
+
+lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz")
+add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
+
+
+lite_cc_test(test_lite_conv_elementwise_add_relu_fuse
+    SRCS conv_elementwise_add_relu_fuse_pass_test.cc
+    DEPS cxx_api_lite mir_passes
+    ${ops_lite} ${host_kernels} ${x86_kernels})
diff --git a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc b/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..562ec7f45073a13f37c7f44ebcae0fb13fbb8b42
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.cc
@@ -0,0 +1,37 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  fusion::ConvBNFuser fuser("conv2d");
+  fuser(graph.get());
+
+  fusion::ConvBNFuser fuser2("depthwise_conv2d");
+  fuser2(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass);
diff --git a/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h b/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..d5164c906525a55f04d83a7cb22f1a75b3a20c5d
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class ConvBNFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..3110c7aa6d408d2520d982ec76a77baea7babdbc
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.cc
@@ -0,0 +1,39 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
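+//
+// This pass fuses a conv2d (or depthwise_conv2d) -> elementwise_add -> relu
+// chain into a single conv op; the matching and rewriting logic lives in
+// fusion::ConvElementwiseAddReLUFuser.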
+
+#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void ConvElementwiseAddReLUFusePass::Apply(
+    const std::unique_ptr<SSAGraph>& graph) {
+  fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
+  fuser(graph.get());
+
+  fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
+  depthwise_fuser(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
+                  paddle::lite::mir::ConvElementwiseAddReLUFusePass);
diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..4276f1ffc8c258b0b4266abd950fa1ccf541c4a7
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class ConvElementwiseAddReLUFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..30991313ad3ed9ef39c3fb8183f4cfc43c9c49b9
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass_test.cc
@@ -0,0 +1,153 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
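+//
+// BuildGraph() below wires up two conv2d -> elementwise_add -> relu chains
+// (the second add consumes relu_1_out as its Y input), so fuse_test_op
+// expects the pattern to match twice: each match removes five nodes (the
+// conv2d, add and relu ops plus the conv2d_out and add_out vars) and
+// inserts one fused conv2d node.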
+
+#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <vector>
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/lite/api/cxx_api.h"
+#include "paddle/fluid/lite/core/compatible_tensor.h"
+#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
+#include "paddle/fluid/lite/core/mir/passes.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+#include "paddle/fluid/lite/core/program.h"
+
+DEFINE_string(model_dir, "", "");
+DEFINE_string(optimized_model, "", "");
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
+                                     const std::shared_ptr<Scope>& scope,
+                                     const std::vector<Place>& valid_places) {
+  auto* main_block = program_desc->MutableBlock(0);
+
+  auto* conv2d_1 = main_block->AppendOp();
+  auto* conv2d_2 = main_block->AppendOp();
+  auto* add_1 = main_block->AppendOp();
+  auto* relu_1 = main_block->AppendOp();
+  auto* add_2 = main_block->AppendOp();
+  auto* relu_2 = main_block->AppendOp();
+
+  main_block->Var("input_1");
+  main_block->Var("input_2");
+  main_block->Var("filter_1");
+  main_block->Var("filter_2");
+  main_block->Var("conv2d_1_out");
+  main_block->Var("conv2d_2_out");
+  main_block->Var("bias_1");
+  main_block->Var("add_1_out");
+  main_block->Var("add_2_out");
+  main_block->Var("relu_1_out");
+  main_block->Var("out");
+
+  scope->Var("input_1")->GetMutable<lite::Tensor>();
+  scope->Var("input_2")->GetMutable<lite::Tensor>();
+  scope->Var("filter_1")->GetMutable<lite::Tensor>();
+  scope->Var("filter_2")->GetMutable<lite::Tensor>();
+  scope->Var("conv2d_1_out")->GetMutable<lite::Tensor>();
+  scope->Var("conv2d_2_out")->GetMutable<lite::Tensor>();
+  scope->Var("bias_1")->GetMutable<lite::Tensor>();
+  scope->Var("add_1_out")->GetMutable<lite::Tensor>();
+  scope->Var("add_2_out")->GetMutable<lite::Tensor>();
+  scope->Var("relu_1_out")->GetMutable<lite::Tensor>();
+  scope->Var("out")->GetMutable<lite::Tensor>();
+
+  conv2d_1->SetType("conv2d");
+  conv2d_1->SetInput("Input", {"input_1"});
+  conv2d_1->SetInput("Filter", {"filter_1"});
+  conv2d_1->SetOutput("Output", {"conv2d_1_out"});
+  conv2d_1->SetAttr("strides", std::vector<int>({1, 1}));
+  conv2d_1->SetAttr("paddings", std::vector<int>({0, 0}));
+  conv2d_1->SetAttr("groups", 1);
+  conv2d_1->SetAttr("dilations", std::vector<int>({1, 1}));
+  conv2d_1->SetAttr("fuse_relu", false);
+
+  add_1->SetType("elementwise_add");
+  add_1->SetInput("X", {"conv2d_1_out"});
+  add_1->SetInput("Y", {"bias_1"});
+  add_1->SetOutput("Out", {"add_1_out"});
+  add_1->SetAttr("axis", 1);
+
+  relu_1->SetType("relu");
+  relu_1->SetInput("X", {"add_1_out"});
+  relu_1->SetOutput("Out", {"relu_1_out"});
+
+  conv2d_2->SetType("conv2d");
+  conv2d_2->SetInput("Input", {"input_2"});
+  conv2d_2->SetInput("Filter", {"filter_2"});
+  conv2d_2->SetOutput("Output", {"conv2d_2_out"});
+  conv2d_2->SetAttr("strides", std::vector<int>({1, 1}));
+  conv2d_2->SetAttr("paddings", std::vector<int>({0, 0}));
+  conv2d_2->SetAttr("groups", 1);
+  conv2d_2->SetAttr("dilations", std::vector<int>({1, 1}));
+  conv2d_2->SetAttr("fuse_relu", false);
+
+  add_2->SetType("elementwise_add");
+  add_2->SetInput("X", {"conv2d_2_out"});
+  add_2->SetInput("Y", {"relu_1_out"});
+  add_2->SetOutput("Out", {"add_2_out"});
+  add_2->SetAttr("axis", 1);
+
+  relu_2->SetType("relu");
+  relu_2->SetInput("X", {"add_2_out"});
+  relu_2->SetOutput("Out", {"out"});
+
+  program_desc->Flush();
+
+  lite::Program program(*program_desc->Proto(), scope, valid_places);
+  auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
+  graph->Build(program, valid_places);
+
+  return graph;
+}
+
+TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+
+  Visualize(graph.get());
+  ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/);
+  Visualize(graph.get());
+}
+
+TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+  Visualize(graph.get());
+  const int num_nodes = graph->nodes().size();
+  auto* fuser = new ConvElementwiseAddReLUFusePass;
+  fuser->Apply(graph);
+  Visualize(graph.get());
+  ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed*/ +
+                                       1UL * 2 /*fused conv2d nodes*/);
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(elementwise_add);
+USE_LITE_OP(conv2d);
+USE_LITE_OP(depthwise_conv2d);
+USE_LITE_OP(relu);
diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.cc b/paddle/fluid/lite/core/mir/fc_fuse_pass.cc
new file mode 100644
index 0000000000000000000000000000000000000000..008f05ce5cbd5f6f14d67e79f732e51ab2aa3ddd
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fc_fuse_pass.cc
@@ -0,0 +1,34 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
+#include <memory>
+#include <vector>
+#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
+#include "paddle/fluid/lite/core/mir/pass_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
+  fusion::FcFuser fuser;
+  fuser(graph.get());
+}
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass);
diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass.h b/paddle/fluid/lite/core/mir/fc_fuse_pass.h
new file mode 100644
index 0000000000000000000000000000000000000000..f1b548c43f99939028735e317107604bd0871945
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fc_fuse_pass.h
@@ -0,0 +1,32 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
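+//
+// FcFusePass rewrites mul + elementwise_add subgraphs into a single fc op;
+// see fusion::FcFuser for the pattern definition.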
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pass.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+class FcFusePass : public ProgramPass {
+ public:
+  void Apply(const std::unique_ptr<SSAGraph>& graph) override;
+};
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..35efedb57971d19551ee144e47f87bcfd4d73ce4
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fc_fuse_pass_test.cc
@@ -0,0 +1,112 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
+#include <gflags/gflags.h>
+#include <gtest/gtest.h>
+#include <vector>
+#include "paddle/fluid/lite/api/cxx_api.h"
+#include "paddle/fluid/lite/core/mir/passes.h"
+#include "paddle/fluid/lite/core/op_registry.h"
+
+DEFINE_string(model_dir, "", "");
+DEFINE_string(optimized_model, "", "");
+
+namespace paddle {
+namespace lite {
+namespace mir {
+
+TEST(fc_fuse_pass, fuse_test) {
+  lite::ExecutorLite predictor;
+#ifndef LITE_WITH_CUDA
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+#else
+  std::vector<Place> valid_places({
+      Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)},
+      Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)},
+      Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)},
+      Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)},
+      Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)},
+      Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)},
+  });
+#endif
+
+  predictor.Build(FLAGS_model_dir,
+                  Place{TARGET(kX86), PRECISION(kFloat)},  // origin cuda
+                  valid_places);
+
+  auto* input_tensor = predictor.GetInput(0);
+  input_tensor->Resize(DDim(std::vector<DDim::value_type>({100, 100})));
+  auto* data = input_tensor->mutable_data<float>();
+  for (int i = 0; i < 100 * 100; i++) {
+    data[i] = i;
+  }
+
+  predictor.Run();
+
+  auto* out = predictor.GetOutput(0);
+  LOG(INFO) << out << " memory size " << out->data_size();
+  LOG(INFO) << "out " << out->data<float>()[0];
+  LOG(INFO) << "out " << out->data<float>()[1];
+  LOG(INFO) << "dims " << out->dims();
+  EXPECT_NEAR(out->data<float>()[0], 38.120617f, 1e-5);
+  EXPECT_NEAR(out->data<float>()[1], 10.109812f, 1e-5);
+  CHECK_EQ(out->dims()[0], 100);
+  CHECK_EQ(out->dims()[1], 500);
+}
+
+#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+TEST(fc_fuse_pass, save_model_test) {
+  lite::ExecutorLite predictor;
+  std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
+                                   Place{TARGET(kX86), PRECISION(kFloat)}});
+  predictor.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)},
+                  valid_places);
+
+  LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
+  predictor.SaveModel(FLAGS_optimized_model);
+}
+#endif  // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
+
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(mul);
+USE_LITE_OP(elementwise_add);
+USE_LITE_OP(elementwise_sub);
+USE_LITE_OP(fc);
+USE_LITE_OP(feed);
+USE_LITE_OP(fetch);
+USE_LITE_OP(io_copy);
+USE_LITE_OP(softmax);
+USE_LITE_OP(scale);
+USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
+USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
+
+#ifdef LITE_WITH_X86
+USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
+USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
+#endif
+
+#ifdef LITE_WITH_CUDA
+USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
+USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
+USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
+#endif
diff --git a/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
new file mode 100644
index 0000000000000000000000000000000000000000..1aecfdaed02d6f82e3829d076126adfddf686763
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/CMakeLists.txt
@@ -0,0 +1,18 @@
+cc_library(fuse_fc
+    SRCS fc_fuser.cc
+    DEPS pattern_matcher_high_api)
+cc_library(fuse_conv_elementwise_add_relu
+    SRCS conv_elementwise_add_relu_fuser.cc
+    DEPS pattern_matcher_high_api)
+cc_library(fuse_conv_bn
+    SRCS conv_bn_fuser.cc
+    DEPS pattern_matcher_high_api)
+
+set(mir_fusers
+    fuse_fc
+    fuse_conv_elementwise_add_relu
+    fuse_conv_bn
+    CACHE INTERNAL "fusers")
+
+lite_cc_test(test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
+    DEPS elementwise_ops_lite batch_norm_op_lite conv_op_lite proto_desc compatible_pb_lite program_lite mir_pass mir_pass_manager pattern_matcher_high_api)
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
new file mode 100644
index 0000000000000000000000000000000000000000..7ce20c4d6e28d8368397510ea912ede647224226
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass_test.cc
@@ -0,0 +1,135 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
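+//
+// The graph built below is a single conv2d (3x1x2x2 filter) followed by a
+// batch_norm over its three output channels. After ConvBNFusePass runs,
+// the bn op and seven of its scale/mean/variance/saved-stat vars should be
+// replaced by one elementwise_add, hence the "- 8UL + 1UL" assertion.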
+
+#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
+#include <gtest/gtest.h>
+#include <memory>
+#include <vector>
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/lite/core/compatible_tensor.h"
+#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
+#include "paddle/fluid/lite/core/program.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
+                                     const std::shared_ptr<Scope>& scope,
+                                     const std::vector<Place>& valid_places) {
+  auto* main_block = program_desc->MutableBlock(0);
+  auto* conv_op = main_block->AppendOp();
+  auto* bn_op = main_block->AppendOp();
+  main_block->Var("conv_i");
+  main_block->Var("conv_param");
+  main_block->Var("conv_out");
+
+  main_block->Var("bn_scale");
+  main_block->Var("bn_bias");
+  main_block->Var("bn_mean");
+  main_block->Var("bn_var");
+  main_block->Var("bn_out");
+  main_block->Var("bn_mean_out");
+  main_block->Var("bn_var_out");
+  main_block->Var("bn_saved_mean");
+  main_block->Var("bn_saved_var");
+
+  scope->Var("conv_i")->GetMutable<lite::Tensor>();
+  auto conv_param_t = scope->Var("conv_param")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> conv_param_shape = {3, 1, 2, 2};
+  conv_param_t->Resize(lite::DDim(conv_param_shape));
+  conv_param_t->mutable_data<float>();
+  scope->Var("conv_out")->GetMutable<lite::Tensor>();
+  auto bn_scale_t = scope->Var("bn_scale")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> bn_scale_shape = {3};
+  bn_scale_t->Resize(lite::DDim(bn_scale_shape));
+  bn_scale_t->mutable_data<float>();
+
+  auto bn_bias_t = scope->Var("bn_bias")->GetMutable<lite::Tensor>();
+  std::vector<int64_t> bn_bias_shape = {3};
+  bn_bias_t->Resize(lite::DDim(bn_bias_shape));
+  bn_bias_t->mutable_data<float>();
+
+  auto bn_mean_t = scope->Var("bn_mean")->GetMutable<lite::Tensor>();
+  bn_mean_t->Resize(lite::DDim(bn_bias_shape));
+  bn_mean_t->mutable_data<float>();
+
+  auto bn_var_t = scope->Var("bn_var")->GetMutable<lite::Tensor>();
+  bn_var_t->Resize(lite::DDim(bn_bias_shape));
+  bn_var_t->mutable_data<float>();
+
+  scope->Var("bn_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_mean_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_var_out")->GetMutable<lite::Tensor>();
+  scope->Var("bn_saved_mean")->GetMutable<lite::Tensor>();
+  scope->Var("bn_saved_var")->GetMutable<lite::Tensor>();
+
+  conv_op->SetType("conv2d");
+  conv_op->SetInput("Input", {"conv_i"});
+  conv_op->SetInput("Filter", {"conv_param"});
+  conv_op->SetOutput("Output", {"conv_out"});
+  const std::vector<int> strides({1, 1});
+  const std::vector<int> paddings({1, 1});
+  const std::vector<int> dilations({1, 1});
+  const int groups = 1;
+  conv_op->SetAttr("strides", strides);
+  conv_op->SetAttr("paddings", paddings);
+  conv_op->SetAttr("dilations", dilations);
+  conv_op->SetAttr("groups", groups);
+
+  bn_op->SetType("batch_norm");
+  bn_op->SetInput("X", {"conv_out"});
+  bn_op->SetInput("Bias", {"bn_bias"});
+  bn_op->SetInput("Mean", {"bn_mean"});
+  bn_op->SetInput("Scale", {"bn_scale"});
+  bn_op->SetInput("Variance", {"bn_var"});
+
+  bn_op->SetOutput("Y", {"bn_out"});
+  bn_op->SetOutput("MeanOut", {"bn_mean_out"});
+  bn_op->SetOutput("VarianceOut", {"bn_var_out"});
+  bn_op->SetOutput("SavedMean", {"bn_saved_mean"});
+  bn_op->SetOutput("SavedVariance", {"bn_saved_var"});
+  float eps = 1e-5;
+  bn_op->SetAttr("epsilon", eps);
+
+  program_desc->Flush();
+
+  lite::Program program(*program_desc->Proto(), scope, valid_places);
+  auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
+  graph->Build(program, valid_places);
+
+  return graph;
+}
+
+TEST(pattern_matcher2, test) {
+  framework::ProgramDesc program_desc;
+  std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
+  auto scope = std::make_shared<Scope>();
+  auto graph = BuildGraph(&program_desc, scope, places);
+  const int num_nodes = graph->nodes().size();
+  auto* fuser = new ConvBNFusePass;
+  fuser->Apply(graph);
+  ASSERT_EQ(graph->nodes().size(),
+            num_nodes - 8UL /*nodes removed */ + 1UL /* eltwise_add node*/);
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
+
+USE_LITE_OP(conv2d);
+USE_LITE_OP(batch_norm);
+USE_LITE_OP(elementwise_add);
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..e753f8a858dbfe9cbe7a5f29e473524ac9196f70
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.cc
@@ -0,0 +1,128 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void ConvBNFuser::BuildPattern() {
+  auto* conv_input =
+      VarNode("conv_input")->assert_is_op_input(conv_type_, "Input")->AsInput();
+  auto* conv_weight = VarNode("conv_weight")
+                          ->assert_is_op_input(conv_type_, "Filter")
+                          ->AsInput();
+  auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
+  auto* conv_out = VarNode("conv_out")
+                       ->assert_is_op_output(conv_type_, "Output")
+                       ->assert_is_op_input("batch_norm", "X");
+
+  auto* bn_scale = VarNode("bn_scale")
+                       ->assert_is_op_input("batch_norm", "Scale")
+                       ->AsIntermediate();
+  auto* bn_bias =
+      VarNode("bn_bias")->assert_is_op_input("batch_norm", "Bias")->AsInput();
+  auto* bn_mean = VarNode("bn_mean")
+                      ->assert_is_op_input("batch_norm", "Mean")
+                      ->AsIntermediate();
+  auto* bn_var = VarNode("bn_variance")
+                     ->assert_is_op_input("batch_norm", "Variance")
+                     ->AsIntermediate();
+  auto* bn =
+      OpNode("bn", "batch_norm")->assert_is_op("batch_norm")->AsIntermediate();
+
+  auto* bn_out =
+      VarNode("bn_out")->assert_is_op_output("batch_norm", "Y")->AsOutput();
+  auto* bn_mean_out = VarNode("bn_mean_out")
+                          ->assert_is_op_output("batch_norm", "MeanOut")
+                          ->AsIntermediate();
+  auto* bn_var_out = VarNode("bn_var_out")
+                         ->assert_is_op_output("batch_norm", "VarianceOut")
+                         ->AsIntermediate();
+  auto* bn_saved_mean = VarNode("bn_saved_mean")
+                            ->assert_is_op_output("batch_norm", "SavedMean")
+                            ->AsIntermediate();
+  auto* bn_saved_var = VarNode("bn_saved_var")
+                           ->assert_is_op_output("batch_norm", "SavedVariance")
+                           ->AsIntermediate();
+
+  conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
+
+  bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
+      .LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
+}
+
+void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
+  auto conv = matched.at("conv2d")->stmt()->op;
+  auto* scope = conv->scope();
+  auto& valid_places = conv->valid_places();
+
+  auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
+                           ->GetMutable<lite::Tensor>();
+  auto conv_weight_d = conv_weight_t->mutable_data<float>();
+  auto conv_weight_dims = conv_weight_t->dims();
+  size_t weight_num = conv_weight_t->data_size();
+
+  auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
+                        ->GetMutable<lite::Tensor>();
+  size_t bias_size = bn_scale_t->data_size();
+  auto bn_scale_d = bn_scale_t->mutable_data<float>();
+  PADDLE_ENFORCE(bias_size == conv_weight_dims[0],
+                 "The BN bias size should be equal to the first dim of the "
+                 "conv weights");
+
+  auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
+                       ->GetMutable<lite::Tensor>();
+  auto bn_mean_d = bn_mean_t->mutable_data<float>();
+
+  auto bn_var_t = scope->FindVar(matched.at("bn_variance")->arg()->name)
+                      ->GetMutable<lite::Tensor>();
+  auto bn_var_d = bn_var_t->mutable_data<float>();
+
+  auto bn_bias_t = scope->FindVar(matched.at("bn_bias")->arg()->name)
+                       ->GetMutable<lite::Tensor>();
+  auto bn_bias_d = bn_bias_t->mutable_data<float>();
+  auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
+
+  ComputeFusedWeight(bn_scale_d, bn_mean_d, bn_var_d, bn_bias_d, conv_weight_d,
+                     eps, bias_size, weight_num / bias_size);
+
+  eltwise_op->Attach(op_desc, scope);
+  auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
+}
+
+cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  op_desc.SetType("elementwise_add");
+  op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
+  op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
+  op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
+  op_desc.SetAttr("axis", 1);
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
new file mode 100644
index 0000000000000000000000000000000000000000..a591d20717e2b18771f27b709580d6a07d32bca2
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h
@@ -0,0 +1,57 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
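+//
+// A sketch of the folding math implemented by ComputeFusedWeight() below:
+//   bn(x) = scale * (x - mean) / sqrt(var + eps) + bias
+// With alpha = scale / sqrt(var + eps) this is the affine map
+//   bn(x) = alpha * x + (bias - alpha * mean)
+// so each conv output channel i absorbs alpha[i] into its filter weights,
+// and the remaining per-channel shift is applied by the elementwise_add
+// that InsertNewNode() creates from the updated bn_bias tensor.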
+
+#pragma once
+
+#include <cmath>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class ConvBNFuser : public FuseBase {
+ public:
+  explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  void ComputeFusedWeight(float* scale_d, float* mean_d, float* var_d,
+                          float* bias_d, float* conv_weight_d, float eps, int h,
+                          int w) {
+    for (int i = 0; i < h; i++) {
+      var_d[i] = scale_d[i] / std::sqrt(var_d[i] + eps);
+    }
+    for (int i = 0; i < h; i++) {
+      bias_d[i] += (-mean_d[i]) * var_d[i];
+    }
+    for (int i = 0; i < h; i++) {
+      for (int j = 0; j < w; j++) {
+        conv_weight_d[i * w + j] *= var_d[i];
+      }
+    }
+  }
+
+ private:
+  std::string conv_type_{"conv2d"};
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..497c8f4f0d3c6ee08112794b04937bb4ec1cf0cd
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.cc
@@ -0,0 +1,109 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void ConvElementwiseAddReLUFuser::BuildPattern() {
+  // create input nodes.
+  auto* input =
+      VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
+  auto* filter =
+      VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
+  auto* bias =
+      VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput();
+
+  // create op nodes
+  auto* conv2d =
+      OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
+  auto* add = OpNode("add", "elementwise_add")
+                  ->assert_is_op("elementwise_add")
+                  ->AsIntermediate();
+  auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();
+
+  // create intermediate nodes
+  auto* conv2d_out = VarNode("conv2d_out")
+                         ->assert_is_op_output(conv_type_, "Output")
+                         ->assert_is_op_input("elementwise_add", "X")
+                         ->AsIntermediate();
+  auto* add_out = VarNode("add_out")
+                      ->assert_is_op_output("elementwise_add", "Out")
+                      ->assert_is_op_input("relu", "X")
+                      ->AsIntermediate();
+
+  // create output node
+  auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();
+
+  // create topology.
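+  // operator>> (overloaded for PMNode in pattern_matcher.h) records the
+  // edges of the pattern: {filter, input} feed conv2d, whose output feeds
+  // elementwise_add together with bias, and relu consumes the sum.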
+  std::vector<PMNode*> conv2d_inputs{filter, input};
+  std::vector<PMNode*> add_inputs{conv2d_out, bias};
+  conv2d_inputs >> *conv2d >> *conv2d_out;
+  add_inputs >> *add >> *add_out;
+  *add_out >> *relu >> *out;
+}
+
+void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
+                                                const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
+  auto conv_old = matched.at("conv2d")->stmt()->op;
+  auto* scope = conv_old->scope();
+  auto& valid_places = conv_old->valid_places();
+  conv_op->Attach(op_desc, scope);
+
+  auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("input"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("output"));
+}
+
+cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
+  auto* desc = matched.at("conv2d")->stmt()->op_info();
+
+  cpp::OpDesc op_desc;
+  op_desc.SetType(conv_type_);
+  op_desc.SetInput("Input", {matched.at("input")->arg()->name});
+  op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
+  op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
+  op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
+  // Other inputs. See operators/conv_op.h
+  std::vector<std::string> input_arg_names = desc->InputArgumentNames();
+  for (auto name : input_arg_names) LOG(INFO) << name;
+
+  if (std::find(input_arg_names.begin(), input_arg_names.end(),
+                "ResidualData") != input_arg_names.end()) {
+    op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
+  }
+
+  // Only consider strides, padding, groups, dilations, fuse_relu for now
+  op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
+  op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
+  op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
+  op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
+  op_desc.SetAttr("fuse_relu", true);
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
new file mode 100644
index 0000000000000000000000000000000000000000..3e21368234f36a5afafb08958930943599955090
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h
@@ -0,0 +1,41 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
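+//
+// ConvElementwiseAddReLUFuser collapses conv + elementwise_add + relu into
+// one conv op: the add's Y operand becomes the conv "Bias" input and the
+// fuse_relu attribute is set to true (see GenOpDesc in the .cc file).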
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class ConvElementwiseAddReLUFuser : public FuseBase {
+ public:
+  explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
+      : conv_type_(conv_type) {}
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+  std::string conv_type_;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
new file mode 100644
index 0000000000000000000000000000000000000000..a8b6336595c0fe63d64d75d6434fcfd559c185c9
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/fc_fuser.cc
@@ -0,0 +1,78 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
+#include <memory>
+#include <vector>
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+void FcFuser::BuildPattern() {
+  // create nodes.
+  auto* x = VarNode("x")->assert_is_op_input("mul", "X");
+  auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
+  auto* b = VarNode("b");
+  auto* mul = OpNode("mul", "mul");
+  auto* mul_out = VarNode("mul_out");
+  auto* add = OpNode("add", "elementwise_add");
+  auto* Out = VarNode("Out");
+
+  // create topology.
+  std::vector<PMNode*> mul_inputs{W, x};
+  std::vector<PMNode*> add_inputs{mul_out, b};
+  mul_inputs >> *mul >> *mul_out;
+  add_inputs >> *add >> *Out;
+
+  // Some op specialities.
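+  // Nodes marked AsIntermediate below are deleted once the match is
+  // rewritten (FuseBase::DeleteInterNodes), leaving x, W, b and Out wired
+  // to the new fc op.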
+  mul_out->AsIntermediate();
+  mul->AsIntermediate();
+  add->AsIntermediate();
+}
+
+void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
+  auto op_desc = GenOpDesc(matched);
+  auto fc_op = LiteOpRegistry::Global().Create("fc");
+  auto mul = matched.at("mul")->stmt()->op;
+  auto* scope = mul->scope();
+  auto& valid_places = mul->valid_places();
+  fc_op->Attach(op_desc, scope);
+
+  auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
+
+  IR_NODE_LINK_TO(matched.at("W"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("x"), new_op_node);
+  IR_NODE_LINK_TO(matched.at("b"), new_op_node);
+  IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
+}
+
+cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
+  cpp::OpDesc op_desc;
+  op_desc.SetType("fc");
+  op_desc.SetInput("Input", {matched.at("x")->arg()->name});
+  op_desc.SetInput("W", {matched.at("W")->arg()->name});
+  op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
+  op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
+  op_desc.SetAttr(
+      "in_num_col_dims",
+      matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
+  return op_desc;
+}
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/fusion/fc_fuser.h b/paddle/fluid/lite/core/mir/fusion/fc_fuser.h
new file mode 100644
index 0000000000000000000000000000000000000000..0e2bc3bc3c338559a301e232e2b7bf7542d8186c
--- /dev/null
+++ b/paddle/fluid/lite/core/mir/fusion/fc_fuser.h
@@ -0,0 +1,38 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+
+#include <memory>
+#include <string>
+#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
+
+namespace paddle {
+namespace lite {
+namespace mir {
+namespace fusion {
+
+class FcFuser : public FuseBase {
+ public:
+  void BuildPattern() override;
+  void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
+
+ private:
+  cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
+};
+
+}  // namespace fusion
+}  // namespace mir
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/core/mir/node.h b/paddle/fluid/lite/core/mir/node.h
index 67ee47a9e12fde139a81e5b21759645a87e6b098..a5fd90dac482d434afb624216aad875e12350c36 100644
--- a/paddle/fluid/lite/core/mir/node.h
+++ b/paddle/fluid/lite/core/mir/node.h
@@ -71,12 +71,20 @@ class Node {
 
   struct Arg {
     std::string name;
+    int id{0};
     const Type* type{};
    // Weight is a special kind of argument, it is marked as weight explicitly
    // so that some weight related optimization can take place.
     bool is_weight{false};
   };
 
+  Arg& AsArg(const std::string& name, int id) {
+    auto& x = AsArg();
+    x.name = name;
+    x.id = id;
+    return x;
+  }
+
   Arg& AsArg(const std::string& name) {
     auto& x = AsArg();
     x.name = name;
diff --git a/paddle/fluid/lite/core/mir/passes.h b/paddle/fluid/lite/core/mir/passes.h
index 60e53257ba01006e71095faa62b083d47e894c60..6e329a192277a9f0a76afa0ed54018cc3f12d7b7 100644
--- a/paddle/fluid/lite/core/mir/passes.h
+++ b/paddle/fluid/lite/core/mir/passes.h
@@ -22,6 +22,8 @@ namespace mir {}  // namespace mir
 }  // namespace paddle
 
 USE_MIR_PASS(demo);
+USE_MIR_PASS(lite_fc_fuse_pass);
+USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
 USE_MIR_PASS(static_kernel_pick_pass);
 USE_MIR_PASS(variable_place_inference_pass);
 USE_MIR_PASS(type_target_transform_pass);
@@ -29,3 +31,5 @@ USE_MIR_PASS(generate_program_pass);
 USE_MIR_PASS(io_copy_kernel_pick_pass);
 USE_MIR_PASS(argument_type_display_pass);
 USE_MIR_PASS(runtime_context_assign_pass);
+USE_MIR_PASS(lite_conv_bn_fuse_pass);
+USE_MIR_PASS(graph_visualze);
diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.cc b/paddle/fluid/lite/core/mir/pattern_matcher.cc
index c7fa42ac5a786e5a8994a5fba3e2d427d752dcad..7524312db8b88055fe27344bb860906ee9c0d63f 100644
--- a/paddle/fluid/lite/core/mir/pattern_matcher.cc
+++ b/paddle/fluid/lite/core/mir/pattern_matcher.cc
@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
   return *this;
 }
 
-void operator>>(std::vector<PMNode *> &others, PMNode &me) {
+PMNode &operator>>(std::vector<PMNode *> &others, PMNode &me) {
   for (auto *o : others) {
     *o >> me;
   }
+  return me;
 }
 
 PMNode *PMPattern::NewNode(const std::string &name) {
@@ -406,6 +407,67 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
   return this;
 }
 
+bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
+                 size_t nth) {
+  PADDLE_ENFORCE(var->IsArg());
+  PADDLE_ENFORCE(op->IsStmt());
+  auto op_info = op->stmt()->op_info();
+  if (op_info->Output(argument).size() <= nth) return false;
+  return var->arg()->name == op_info->Output(argument)[nth];
+}
+
+bool IsNthInput(const Node *var, const Node *op, const std::string &argument,
+                size_t nth) {
+  PADDLE_ENFORCE(var->IsArg());
+  PADDLE_ENFORCE(op->IsStmt());
+  auto op_info = op->stmt()->op_info();
+  if (op_info->Input(argument).size() <= nth) return false;
+  return var->arg()->name == op_info->Input(argument)[nth];
+}
+
+PMNode *PMNode::assert_is_op_input(const std::string &op_type,
+                                   const std::string &argument) {
+  assert_is_var();
+  assert_is_op_nth_input(op_type, argument, 0);
+  return this;
+}
+
+PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
+                                       const std::string &argument, int nth) {
+  assert_is_var();
+  assert_is_op_input(op_type);
+  asserts_.emplace_back([=](const Node *x) {
+    for (auto *op : x->outlinks) {
+      if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
+          IsNthInput(x, op, argument, nth))
+        return true;
+    }
+    return false;
+  });
+  return this;
+}
+
+PMNode *PMNode::assert_is_op_output(const std::string &op_type,
+                                    const std::string &argument) {
+  assert_is_var();
+  assert_is_op_nth_output(op_type, argument, 0);
+  return this;
+}
+
+PMNode *PMNode::assert_is_op_nth_output(const std::string &op_type,
+                                        const std::string &argument, int nth) {
+  assert_is_var();
+  asserts_.emplace_back([=](const Node *x) {
+    for (auto *op : x->inlinks) {
+      if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
+          IsNthOutput(x, op, argument, nth))
+        return true;
+    }
+    return false;
+  });
+  return this;
+}
+
 PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
   assert_is_var();
   asserts_.emplace_back([=](const Node *x) {
@@ -422,6 +484,14 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
   return this;
 }
 
+bool HasInput(const Node &op, const std::string &argument) {
+  CHECK(op.IsStmt());
+  auto const &names = op.stmt()->op_info()->input_argnames();
+  if (std::find(names.begin(), names.end(), argument) == names.end())
+    return false;
+  return true;
+}
+
 void GraphSafeRemoveNodes(SSAGraph *graph,
                           const std::unordered_set<const Node *> &nodes) {
   for (auto *node : nodes) {
diff --git a/paddle/fluid/lite/core/mir/pattern_matcher.h b/paddle/fluid/lite/core/mir/pattern_matcher.h
index 2241e71af3de9e9692b2fd740c1e91ee7839fa91..ff9fbce35ddf3f601a441bb6105dc658505cbe0e 100644
--- a/paddle/fluid/lite/core/mir/pattern_matcher.h
+++ b/paddle/fluid/lite/core/mir/pattern_matcher.h
@@ -62,7 +62,7 @@ struct PMNode {
   PMNode& operator>>(PMNode& right);
 
   // Link many nodes to this node.
-  friend void operator>>(std::vector<PMNode*>& others, PMNode& me);
+  friend PMNode& operator>>(std::vector<PMNode*>& others, PMNode& me);
 
   // Link this to many other nodes.
   PMNode& operator>>(std::vector<PMNode*>& nodes);
@@ -127,6 +127,15 @@ struct PMNode {
   PMNode* assert_is_persistable_var();
   PMNode* assert_is_op_output(const std::string& op_type);
   PMNode* assert_is_op_input(const std::string& op_type);
+  PMNode* assert_is_op_input(const std::string& op_type,
+                             const std::string& argument);
+  PMNode* assert_is_op_output(const std::string& op_type,
+                              const std::string& argument);
+
+  PMNode* assert_is_op_nth_input(const std::string& op_type,
+                                 const std::string& argument, int nth);
+  PMNode* assert_is_op_nth_output(const std::string& op_type,
+                                  const std::string& argument, int nth);
 
   template <typename T>
   PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
@@ -297,6 +306,13 @@ class PatternMatcher {
   std::unordered_map<const PMNode*, std::vector<Node*>> pmnodes2nodes_;
 };
 
+// Check whether a var node is a op node's nth input.
+bool IsNthInput(const Node& var, const Node& op, const std::string& argument,
+                int nth);
+
+// Check whether the op node has input of given name.
+bool HasInput(const Node& op, const std::string& argument);
+
 // Graph safely remove some nodes, will automatically clean up the edges.
 void GraphSafeRemoveNodes(SSAGraph* graph,
                           const std::unordered_set<const Node*>& nodes);
diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
index 645e33165f4c07c304554d1289c447c59526ea3c..b3a23c654bdb36974fd1a0419c199ba04a1d66bf 100644
--- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
+++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api.h
@@ -64,7 +64,6 @@ class FuseBase {
   // Delete nodes that are marked as Intermediate
   void DeleteInterNodes(SSAGraph* graph);
 
- private:
   PMNode* GetOrCreateNode(const std::string& key);
 
  protected:
diff --git a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
index 44f95dab754c70290470773f221255778280f0da..7a46bb9a93d95b9379c961d8044fbdfcd04e7ab4 100644
--- a/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
+++ b/paddle/fluid/lite/core/mir/pattern_matcher_high_api_test.cc
@@ -29,8 +29,8 @@ class FcFuser : public FuseBase {
  public:
   void BuildPattern() override {
     // create nodes.
- auto* x = VarNode("x"); - auto* W = VarNode("W"); + auto* x = VarNode("x")->assert_is_op_input("mul", "X"); + auto* W = VarNode("W")->assert_is_op_input("mul", "Y"); auto* b = VarNode("b"); auto* mul = OpNode("mul", "mul"); auto* mul_out = VarNode("mul_out"); @@ -38,12 +38,10 @@ class FcFuser : public FuseBase { auto* Out = VarNode("Out"); // create topology. - // std::vector({W, x}) >> *mul >> *mul_out; - // std::vector({mul_out, b}) >> *add >> *Out; - *W >> *mul; - *x >> *mul >> *mul_out; - *b >> *add; - *mul_out >> *add >> *Out; + std::vector mul_inputs{W, x}; + std::vector add_inputs{mul_out, b}; + mul_inputs >> *mul >> *mul_out; + add_inputs >> *add >> *Out; // Some op specialities. mul_out->AsIntermediate(); @@ -91,14 +89,12 @@ std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, main_block->Var("mul_out"); main_block->Var("w"); main_block->Var("out"); - main_block->Var("out1"); scope->Var("w")->GetMutable(); scope->Var("b")->GetMutable(); scope->Var("mul_out")->GetMutable(); scope->Var("w")->GetMutable(); scope->Var("out")->GetMutable(); - scope->Var("out1")->GetMutable(); mul->SetInput("X", {"x"}); mul->SetInput("Y", {"w"}); @@ -122,18 +118,17 @@ std::unique_ptr BuildGraph(framework::ProgramDesc* program_desc, return graph; } -TEST(pattern_matcher2, graph_test) { +TEST(pattern_matcher_high_api, graph_test) { framework::ProgramDesc program_desc; std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; auto scope = std::make_shared(); auto graph = BuildGraph(&program_desc, scope, places); - ASSERT_EQ(graph->nodes().size(), - 8UL /*real nodes*/ + 2UL /*feed op + fetch op*/); + ASSERT_EQ(graph->nodes().size(), 7UL /*real nodes*/); Visualize(graph.get()); } -TEST(pattern_matcher2, test) { +TEST(pattern_matcher_high_api, fuse_test) { framework::ProgramDesc program_desc; std::vector places{{TARGET(kHost), PRECISION(kFloat)}}; auto scope = std::make_shared(); @@ -143,6 +138,7 @@ TEST(pattern_matcher2, test) { fuser(graph.get()); ASSERT_EQ(graph->nodes().size(), num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/); + Visualize(graph.get()); } } // namespace mir diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc index 82507067c4726b271013cf4a69e95c5045b091a8..ba99a681f79db0406ce1ddd0bb53c0c4ad19a0bc 100644 --- a/paddle/fluid/lite/core/mir/ssa_graph.cc +++ b/paddle/fluid/lite/core/mir/ssa_graph.cc @@ -16,6 +16,7 @@ #include #include #include +#include #include namespace paddle { @@ -25,6 +26,8 @@ namespace mir { bool SSAGraph::CheckBidirectionalConnection() { LOG(INFO) << "node count " << node_storage_.size(); for (auto &node : node_storage_) { + if (node.IsStmt()) LOG(INFO) << node.AsStmt().op_info()->Type(); + if (node.IsArg()) LOG(INFO) << node.AsArg().name << " " << node.AsArg().id; for (auto *in : node.inlinks) { CHECK(in->outlinks.end() != std::find(in->outlinks.begin(), in->outlinks.end(), &node)); @@ -93,31 +96,6 @@ std::vector SSAGraph::StmtTopologicalOrder() { return res; } -void SSAGraph::GraphCreateTmpVarNodes(const Program &program) { - for (const auto &name : program.tmp_vars()) { - CHECK(!arguments_.count(name)) << "duplicate creating temp variable: " - << name; - VLOG(5) << "create arg node " << name; - node_storage_.emplace_back(); - auto &new_node = node_storage_.back(); - new_node.AsArg(name); - arguments_[name] = &new_node; - } -} - -void SSAGraph::GraphCreateWeightVarNodes(const Program &program) { - // create weight nodes. 
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.cc b/paddle/fluid/lite/core/mir/ssa_graph.cc
index 82507067c4726b271013cf4a69e95c5045b091a8..ba99a681f79db0406ce1ddd0bb53c0c4ad19a0bc 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.cc
+++ b/paddle/fluid/lite/core/mir/ssa_graph.cc
@@ -16,6 +16,7 @@
 #include <algorithm>
 #include <memory>
 #include <set>
+#include <unordered_map>
 #include <utility>
 
 namespace paddle {
@@ -25,6 +26,8 @@ namespace mir {
 bool SSAGraph::CheckBidirectionalConnection() {
   LOG(INFO) << "node count " << node_storage_.size();
   for (auto &node : node_storage_) {
+    if (node.IsStmt()) LOG(INFO) << node.AsStmt().op_info()->Type();
+    if (node.IsArg()) LOG(INFO) << node.AsArg().name << " " << node.AsArg().id;
     for (auto *in : node.inlinks) {
       CHECK(in->outlinks.end() !=
             std::find(in->outlinks.begin(), in->outlinks.end(), &node));
@@ -93,31 +96,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
   return res;
 }
 
-void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
-  for (const auto &name : program.tmp_vars()) {
-    CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
-                                   << name;
-    VLOG(5) << "create arg node " << name;
-    node_storage_.emplace_back();
-    auto &new_node = node_storage_.back();
-    new_node.AsArg(name);
-    arguments_[name] = &new_node;
-  }
-}
-
-void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
-  // create weight nodes.
-  for (const auto &name : program.weights()) {
-    CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
-                                   << name;
-    VLOG(5) << "create arg node " << name;
-    node_storage_.emplace_back();
-    auto &new_node = node_storage_.back();
-    new_node.AsArg(name);
-    arguments_[name] = &new_node;
-  }
-}
-
 Node *SSAGraph::GraphCreateInstructNode(
     const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) {
   node_storage_.emplace_back();
@@ -135,29 +113,50 @@ Node *SSAGraph::GraphCreateInstructNode(
 void SSAGraph::Build(const Program &program,
                      const std::vector<Place> &valid_places) {
   CHECK(node_storage_.empty());
-  GraphCreateTmpVarNodes(program);
-  GraphCreateWeightVarNodes(program);
-  CHECK(CheckNodesRoleSet());
+  auto weights_name = program.weights();
+  auto is_weights = [&](const std::string &name) -> bool {
+    auto it = std::find(weights_name.begin(), weights_name.end(), name);
+    if (it == weights_name.end()) return false;
+    return true;
+  };
+
+  std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
   for (auto &op : program.ops()) {
+    LOG(INFO) << op->op_info()->Type();
     auto *op_node = GraphCreateInstructNode(op, valid_places);
+    LOG(INFO) << "input:";
     for (const std::string &name : op->op_info()->input_names()) {
-      auto *arg = Argument(name);
-      CHECK(arg->IsRoleSet());
-      DirectedLink(arg, op_node);
+      LOG(INFO) << name;
+      mir::Node *arg_node = nullptr;
+      if (arg_update_node_map_.count(name)) {
+        arg_node = arg_update_node_map_.at(name);
+      } else {
+        node_storage_.emplace_back();
+        arg_node = &node_storage_.back();
+        arg_node->AsArg(name, node_storage_.size() - 1);
+        arg_update_node_map_[name] = arg_node;
+      }
+      if (is_weights(name)) arg_node->AsArg().is_weight = true;
+      CHECK(arg_node->IsRoleSet());
+      DirectedLink(arg_node, op_node);
     }
+    LOG(INFO) << "output:";
     for (const std::string &name : op->op_info()->output_names()) {
-      if (!arguments_.count(name)) {
-        NewArgumentNode(name);
-      }
-      auto *arg = arguments_.at(name);
-      CHECK(arg->IsRoleSet());
-      DirectedLink(op_node, arg);
+      LOG(INFO) << name;
+      node_storage_.emplace_back();
+      auto *arg_node = &node_storage_.back();
+      arg_node->AsArg(name, node_storage_.size() - 1);
+      arg_update_node_map_[name] = arg_node;
+
+      if (is_weights(name)) arg_node->AsArg().is_weight = true;
+      CHECK(arg_node->IsRoleSet());
+      DirectedLink(op_node, arg_node);
     }
     CHECK(CheckLinksRoleSet());
   }
-  MarkArgumentWeights(program);
+
+  CHECK(CheckNodesRoleSet());
   CheckValid();
 }
@@ -227,10 +226,9 @@ bool SSAGraph::CheckLinksRoleSet() {
 
 Node *SSAGraph::NewArgumentNode(const std::string &name) {
   node_storage_.emplace_back();
-  CHECK(!arguments_.count(name)) << "duplicate argument called " << name;
-  arguments_[name] = &node_storage_.back();
-  node_storage_.back().AsArg(name);
-  return &node_storage_.back();
+  auto &arg_node = node_storage_.back();
+  arg_node.AsArg(name, node_storage_.size() - 1);
+  return &arg_node;
 }
 
 Node *SSAGraph::NewInstructNode() {
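Build now creates argument nodes on the fly while walking the ops: a read reuses the latest definition recorded in arg_update_node_map_, while every write allocates a fresh node. That is what makes the graph SSA even when a program reassigns a variable name. A small trace under those rules (ops and names invented for illustration):

    // t = A(x);  ->  Arg x#0 -> Stmt A -> Arg t#1   (new node for t)
    // y = B(t);  ->  Arg t#1 -> Stmt B -> Arg y#2   (reads the latest t)
    // t = C(y);  ->  Arg y#2 -> Stmt C -> Arg t#3   (second, distinct node for t)

The numeric suffix mirrors the id passed to AsArg(name, node_storage_.size() - 1), and is_weight is set on any node whose name appears in program.weights().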
diff --git a/paddle/fluid/lite/core/mir/ssa_graph.h b/paddle/fluid/lite/core/mir/ssa_graph.h
index 5cad1478c225a6551fcd653ca4e79b58360e3724..7c0e6cef498c5c555c1cee6ab334e6be556a9897 100644
--- a/paddle/fluid/lite/core/mir/ssa_graph.h
+++ b/paddle/fluid/lite/core/mir/ssa_graph.h
@@ -40,8 +40,6 @@ class SSAGraph : GraphBase {
   void Build(const Program &program, const std::vector<Place> &valid_places);
   void RemoveNode(const mir::Node *node);
 
-  mir::Node *Argument(const std::string &name);
-
   std::vector<mir::Node *> StmtTopologicalOrder();
 
   // The inputs of the graph.
@@ -68,9 +66,7 @@ class SSAGraph : GraphBase {
                                 const std::vector<Place> &valid_places);
 
  private:
-  void GraphCreateTmpVarNodes(const Program &program);
-  void GraphCreateWeightVarNodes(const Program &program);
-
+  mir::Node *Argument(const std::string &name);
   // Check the bidirectional connection.
   bool CheckBidirectionalConnection();
   bool CheckNodesRoleSet();
diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc
index 25789d34dca2fa90dbb8c7a415da651c44cc6d12..12dd2dcff0607bea46f41e7f5698ad2fb7e12404 100644
--- a/paddle/fluid/lite/core/mir/type_target_transform_pass.cc
+++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.cc
@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
             << " for kernel " << inst.op->DebugString() << " "
             << *in->AsArg().type << " -> " << *decl_arg_type;
     // Add an IoCopy instruction to make the input compatible with other dist.
-    AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in->AsArg().name, graph,
-                  inst_node, valid_places_);
+    AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
+                  valid_places_);
   }
 }
 
 void TypeTargetTransformPass::AddIoCopyInst(
-    const Type& from, const Type& to, const std::string& var, SSAGraph* graph,
+    const Type& from, const Type& to, Node* in, SSAGraph* graph,
     Node* inst_node, const std::vector<Place>& valid_places) {
   CHECK(!valid_places.empty()) << "valid_place should be set";
   // var -> new_transform_op -> new_var -> inst
   // So there will be a new Argument node and a new IoCopy Statement Node.
+  CHECK(in->IsArg());
   auto node_id = [&] { return graph->nodes().size(); };
-  auto io_copy_output_name = var + "/trans/" + std::to_string(node_id());
+  auto io_copy_output_name =
+      in->AsArg().name + "/trans/" + std::to_string(node_id());
   auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
   auto* io_copy_inst = graph->NewInstructNode();
 
@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
   // Create IoCopy Instruction.
   cpp::OpDesc op_desc;
   op_desc.SetType("io_copy");
-  op_desc.SetInput("Input", {var});
+  op_desc.SetInput("Input", {in->AsArg().name});
   op_desc.SetOutput("Out", {io_copy_output_name});
 
   io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst(
   io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
 
   // Remove the old link
-  RemoveDirectedLink(graph->Argument(var), inst_node);
+  RemoveDirectedLink(in, inst_node);
 
   // Update the original instruction OpDesc.
   // Update its input to the io_copy_output_name
 
   // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
-  DirectedLink(graph->Argument(var), io_copy_inst);
+  DirectedLink(in, io_copy_inst);
   DirectedLink(io_copy_inst, io_copy_output_arg);
   DirectedLink(io_copy_output_arg, inst_node);
 
   // reset opdesc and update kernel information
-  UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), var,
+  UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name,
                 io_copy_output_name);
 
   inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(),
diff --git a/paddle/fluid/lite/core/mir/type_target_transform_pass.h b/paddle/fluid/lite/core/mir/type_target_transform_pass.h
index 838c0bcdabc92717d4b62bda25b77df1bad6dc5d..052e3297abbe806c24f89eb7469cb1fe69246ff3 100644
--- a/paddle/fluid/lite/core/mir/type_target_transform_pass.h
+++ b/paddle/fluid/lite/core/mir/type_target_transform_pass.h
@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass {
 
   void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
 
-  void AddIoCopyInst(const Type& from, const Type& to, const std::string& var,
+  void AddIoCopyInst(const Type& from, const Type& to, Node* in,
                      SSAGraph* graph, Node* inst_node,
                      const std::vector<Place>& valid_places);
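AddIoCopyInst now receives the concrete Node* instead of a variable name: with the SSA build above, one name can map to several nodes, so the old graph->Argument(var) lookup (now private) would be ambiguous. The rewiring it performs, sketched:

    // before:  in ------------------------------------------> inst
    // after:   in --> io_copy --> "<name>/trans/N" --> inst
    //          (io_copy's OpDesc: Input = in->AsArg().name, Out = io_copy_output_name)

where N is the node count at creation time, which guarantees a unique name for the transformed copy.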
diff --git a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
index 4d555d638a91e17796a68ed3397c22d138084e5a..2128c6d2014bf8879743ebf7190b3a95a3bc4186 100644
--- a/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
+++ b/paddle/fluid/lite/core/mir/variable_place_inference_pass.h
@@ -13,7 +13,10 @@
 // limitations under the License.
 
 #pragma once
+#include <map>
 #include <memory>
+#include <string>
+#include <vector>
 #include "paddle/fluid/lite/core/mir/pass.h"
 #include "paddle/fluid/lite/core/target_wrapper.h"
 
@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass {
     //   LOG(INFO) << "- inferencing type " <<
     // deal with inputs
     VLOG(4) << "inferencing op " << inst.op_type;
-    for (auto& arg_name : inst.op_info()->input_argnames()) {
+    // TODO(zhaolong): Add check if the node's name in op's arguments.
+
+    auto get_argname = [&](
+        const std::string& node_name,
+        const std::map<std::string, std::vector<std::string>>& argname_map)
+        -> std::string {
+      for (auto& ele : argname_map) {
+        auto it =
+            std::find(ele.second.begin(), ele.second.end(), node_name);
+        if (it != ele.second.end()) return ele.first;
+      }
+      return "";
+    };
+
+    for (auto* x_in : x->inlinks) {
+      std::string node_name = x_in->AsArg().name;
+      std::string arg_name = get_argname(node_name, inst.op_info()->inputs());
+      CHECK(arg_name.size() > 0) << "can not found op arguments for node "
+                                 << node_name;
       VLOG(3) << "-- input arg_name " << arg_name;
-      // check if inputs's place is set, if not set, update them with the
-      // kernel's declaration.
       auto type = inst.picked_kernel().GetInputDeclType(arg_name);
-      auto arg_names = inst.op_info()->inputs().at(arg_name);
-
-      for (auto& arg_name : arg_names) {
-        VLOG(3) << "--- var " << arg_name;
-        auto* node = graph->RetrieveArgument(arg_name);
-        CHECK(node) << "argument " << arg_name << " not exists in the graph";
-        auto& arg_node = node->AsArg();
-        if (!arg_node.type) {
-          VLOG(4) << "set type " << *type << " " << node;
-          arg_node.type = type;
-        }
+      if (!x_in->AsArg().type) {
+        VLOG(4) << "set type " << *type << " " << x_in;
+        x_in->AsArg().type = type;
       }
     }
 
-    for (auto& arg_name : inst.op_info()->output_argnames()) {
+    for (auto* x_out : x->outlinks) {
+      std::string node_name = x_out->AsArg().name;
+      std::string arg_name =
+          get_argname(node_name, inst.op_info()->outputs());
+      CHECK(arg_name.size() > 0) << "can not found op arguments for node "
+                                 << node_name;
       VLOG(3) << "-- output arg_name " << arg_name;
       auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
-      auto arg_names = inst.op_info()->outputs().at(arg_name);
-      // check if outputs's place is set, if not set, update them with the
-      // kernel's declaration.
-      for (auto& arg_name : arg_names) {
-        VLOG(3) << "--- var " << arg_name;
-        auto* node = graph->RetrieveArgument(arg_name);
-        CHECK(node) << "argument " << arg_name << " not exists in the graph";
-        auto& arg_node = node->AsArg();
-        if (!arg_node.type) {
-          node->AsArg().type = type;
-          VLOG(3) << "set type " << *type;
-        }
+      if (!x_out->AsArg().type) {
+        VLOG(4) << "set type " << *type << " " << x_out;
+        x_out->AsArg().type = type;
       }
     }
   }
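get_argname is a reverse lookup from a graph node's variable name to the op-argument slot it fills, needed because the pass now iterates the real in/out links rather than the OpInfo argument lists. A toy trace (the argument map is invented for illustration):

    // inst.op_info()->inputs() == { {"Input", {"x"}}, {"W", {"w"}}, {"Bias", {"b"}} }
    // get_argname("w", inst.op_info()->inputs())  -> "W"
    // get_argname("q", inst.op_info()->inputs())  -> ""   (the CHECK above then fires)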
diff --git a/paddle/fluid/lite/core/optimizer.h b/paddle/fluid/lite/core/optimizer.h
index 161e765a98ba54bfaee11fb7b6f3ae1b4bde23d4..651cd981c76d0d88f4c7294c0d19b1b0acbc76d4 100644
--- a/paddle/fluid/lite/core/optimizer.h
+++ b/paddle/fluid/lite/core/optimizer.h
@@ -49,16 +49,19 @@ class Optimizer {
 #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
     if (passes.empty()) {
       RunPasses(std::vector<std::string>{{
-          "static_kernel_pick_pass",        //
-          "variable_place_inference_pass",  //
-          "argument_type_display_pass",     //
-          "type_target_transform_pass",     //
-          "argument_type_display_pass",     //
-          "variable_place_inference_pass",  //
-          "argument_type_display_pass",     //
-          "io_copy_kernel_pick_pass",       //
-          "variable_place_inference_pass",  //
-          "runtime_context_assign_pass",    //
+          "lite_conv_bn_fuse_pass",                   //
+          "lite_conv_elementwise_add_act_fuse_pass",  //
+          "lite_fc_fuse_pass",                        //
+          "static_kernel_pick_pass",                  //
+          "variable_place_inference_pass",            //
+          "argument_type_display_pass",               //
+          "type_target_transform_pass",               //
+          "argument_type_display_pass",               //
+          "variable_place_inference_pass",            //
+          "argument_type_display_pass",               //
+          "io_copy_kernel_pick_pass",                 //
+          "variable_place_inference_pass",            //
+          "runtime_context_assign_pass",              //
       }});
     } else {
       RunPasses(passes);
diff --git a/paddle/fluid/lite/core/profile/basic_profiler.h b/paddle/fluid/lite/core/profile/basic_profiler.h
index c50aeab4af58a84407b6d91dd7946e7abaa14ba8..16a9905f1ae6d4a69004650b07f9479869b35ebe 100644
--- a/paddle/fluid/lite/core/profile/basic_profiler.h
+++ b/paddle/fluid/lite/core/profile/basic_profiler.h
@@ -152,8 +152,8 @@ class BasicProfiler {
   }
 
   record_t *mutable_record(int id) {
-    CHECK_LT(id, records_.size());
     CHECK_GE(id, 0);
+    CHECK_LT(static_cast<size_t>(id), records_.size());
     return &records_[id];
   }
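The three fusion passes are deliberately placed ahead of static_kernel_pick_pass: fusion rewrites the op graph (for example mul + elementwise_add becomes fc), so kernels must be picked for the fused ops rather than the originals. A non-empty `passes` argument still bypasses the default list entirely; a sketch of such a custom pipeline (the enclosing entry point that supplies `passes` is assumed, as in the else branch above):

    // Run only the fc fusion, then the minimum needed to execute.
    std::vector<std::string> passes{
        "lite_fc_fuse_pass",              //
        "static_kernel_pick_pass",        //
        "variable_place_inference_pass",  //
        "runtime_context_assign_pass",    //
    };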
a/paddle/fluid/lite/kernels/x86/conv_compute.cc
+++ b/paddle/fluid/lite/kernels/x86/conv_compute.cc
@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
     lite::Tensor col_matrix;
     if (is_expand) {
       col.Resize(col_shape);
+      col.mutable_data<T>();
       col_matrix.ShareDataWith(col);
       col_matrix.Resize(col_matrix_shape);
     }
@@ -104,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
         param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
     lite::Tensor out_batch;
     out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
-        input_shape.data()));
+        output_matrix_shape.data()));
 
     for (int g = 0; g < param.groups; g++) {
       lite::Tensor in_slice;
@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
 
@@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW,
     .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
-    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
diff --git a/paddle/fluid/lite/kernels/x86/fc_compute.cc b/paddle/fluid/lite/kernels/x86/fc_compute.cc
index c89f0f19dad91c2ad205a92f41a5d3e66359d7ae..dad37febc80433f0cf3a6859c985e22a5425b405 100644
--- a/paddle/fluid/lite/kernels/x86/fc_compute.cc
+++ b/paddle/fluid/lite/kernels/x86/fc_compute.cc
@@ -27,8 +27,8 @@ namespace kernels {
 namespace x86 {
 
 template <typename T>
-void fc_compute_eigen(const T* x, int x_w, int x_h,  //
-                      const T* w, int w_w, int w_h,  //
+void fc_compute_eigen(const T* x, int x_h, int x_w,  //
+                      const T* w, int w_h, int w_w,  //
                       const T* b,                    //
                       T* out) {
   using matrix_t =
@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_h, int x_w,  //
       Eigen::Matrix<T, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor>;
   Eigen::Map<const matrix_t> X(x, x_h, x_w);
   Eigen::Map<const matrix_t> W(w, w_h, w_w);
-  Eigen::Map<matrix_t> Out(out, x_h, w_h);
+  Eigen::Map<matrix_t> Out(out, x_h, w_w);
 
-  Out = X * W.transpose();
+  Out = X * W;
 
   if (b) {
-    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_h);
+    Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_w);
     Out = Out.array().rowwise() + B.transpose().array();
   }
 }
 
 template <typename T>
-__attribute__((optimize("unroll-loops")))  //
-T dot(const T* x, const T* y, int dim) {
-  T out{};
-  for (int i = 0; i < dim; i++) {
-    out += x[i] * y[i];
-  }
-  return out;
-}
-
-template <typename T>
-void fc_compute_naive(const T* x, int x_w, int x_h,  //
-                      const T* w, int w_w, int w_h,  //
+void fc_compute_naive(const T* x, int x_h, int x_w,  //
+                      const T* w, int w_h, int w_w,  //
                       const T* b,                    //
                       T* out) {
-  CHECK_EQ(x_w, w_w);
+  CHECK_EQ(x_w, w_h);
   // out shape: (x_h, w_w)
-  memset(out, 0, x_h * w_h * sizeof(T));
-
-  for (int r = 0; r < x_h; r++) {
-    for (int c = 0; c < w_h; c++) {
-      out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c];
+  memset(out, 0, x_h * w_w * sizeof(T));
+  for (int i = 0; i < x_h; i++) {
+    for (int j = 0; j < w_w; j++) {
+      T tmp = static_cast<T>(0);
+      for (int k = 0; k < x_w; k++) {
+        tmp += x[i * x_w + k] * w[k * w_w + j];
+      }
+      out[i * w_w + j] = tmp + b[j];
     }
   }
 }
@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
             .Slice(param.in_num_col_dims, param.input->dims().size())
             .production(),
         param.w->data<T>(),  // w
-        param.w->dims()[1],  // w_w
         param.w->dims()[0],  // w_h
+        param.w->dims()[1],  // w_w
         param.bias->data<T>(),  // b
         param.output->mutable_data<T>());
   }
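The rewritten FC kernels compute Out = X * W directly instead of X * W.transpose(): shapes are (x_h, x_w) * (w_h, w_w) -> (x_h, w_w), with the contraction dimension checked by CHECK_EQ(x_w, w_h). A minimal numeric check of fc_compute_naive under its new (x, x_h, x_w, w, w_h, w_w, b, out) argument order:

    // x = [1 2] (1x2), w = [3; 4] (2x1), b = [5]
    // out[0] = 1*3 + 2*4 + 5 = 16
    float x[] = {1.f, 2.f}, w[] = {3.f, 4.f}, b[] = {5.f}, out[1];
    fc_compute_naive(x, 1, 2, w, 2, 1, b, out);  // expect out[0] == 16.f

The hunk above also fixes the swapped (w_w, w_h) parameter order at the FcCompute call site.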
diff --git a/paddle/fluid/lite/kernels/x86/relu_compute.cc b/paddle/fluid/lite/kernels/x86/relu_compute.cc
index 44b1f525ab05edec3f4b8d0f528704bb3d13a973..52fffb579816cd70a748d59cb3750ebaaadb10c7 100644
--- a/paddle/fluid/lite/kernels/x86/relu_compute.cc
+++ b/paddle/fluid/lite/kernels/x86/relu_compute.cc
@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 
 REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW,
                      paddle::lite::kernels::x86::ReluCompute<float>, def)
-    .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
+    .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
     .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
     .Finalize();
diff --git a/paddle/fluid/lite/operators/CMakeLists.txt b/paddle/fluid/lite/operators/CMakeLists.txt
index 009f2fb98a92401cad8bcfbf1037b3131c58457b..536fcb75ef47c33c3bb0ef1996526fca50bf5497 100644
--- a/paddle/fluid/lite/operators/CMakeLists.txt
+++ b/paddle/fluid/lite/operators/CMakeLists.txt
@@ -56,4 +56,3 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
 lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
 lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
 lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
-
diff --git a/paddle/fluid/lite/operators/batch_norm.cc b/paddle/fluid/lite/operators/batch_norm.cc
new file mode 100644
index 0000000000000000000000000000000000000000..80388e13050eaaaccf145ea3784c0e1e34886d81
--- /dev/null
+++ b/paddle/fluid/lite/operators/batch_norm.cc
@@ -0,0 +1,31 @@
+// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/lite/operators/batch_norm.h"
+#include <memory>
+#include "paddle/fluid/lite/core/op_registry.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+bool BatchNormOpLite::CheckShape() const { return true; }
+
+bool BatchNormOpLite::InferShape() const { return true; }
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+
+REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOpLite);
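CheckShape and InferShape are stubs for now, both returning true unconditionally. If shape propagation is filled in later, InferShape would presumably at least forward the input dims, roughly (a sketch only, not part of this patch; param_ fields as declared in batch_norm.h below):

    bool BatchNormOpLite::InferShape() const {
      // batch_norm is elementwise over X, so Y keeps X's dims.
      param_.out->Resize(param_.x->dims());
      return true;
    }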
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#pragma once
+#include <string>
+#include <vector>
+#include "paddle/fluid/lite/core/kernel.h"
+#include "paddle/fluid/lite/core/op_lite.h"
+#include "paddle/fluid/lite/core/scope.h"
+#include "paddle/fluid/lite/operators/op_params.h"
+#include "paddle/fluid/lite/utils/all.h"
+
+namespace paddle {
+namespace lite {
+namespace operators {
+
+class BatchNormOpLite : public OpLite {
+ public:
+  BatchNormOpLite() {}
+
+  explicit BatchNormOpLite(const std::string &type) : OpLite(type) {}
+
+  bool CheckShape() const override;
+
+  bool InferShape() const override;
+
+  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  // TODO(Superjomn) replace framework::OpDesc with a lite one.
+  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
+    auto x = op_desc.Input("X").front();
+    auto bias = op_desc.Input("Bias").front();
+    auto mean = op_desc.Input("Mean").front();
+    auto scale = op_desc.Input("Scale").front();
+    auto variance = op_desc.Input("Variance").front();
+
+    auto out = op_desc.Output("Y").front();
+    auto mean_out = op_desc.Output("MeanOut").front();
+    auto var_out = op_desc.Output("VarianceOut").front();
+    auto saved_mean = op_desc.Output("SavedMean").front();
+    auto saved_var = op_desc.Output("SavedVariance").front();
+
+    auto *var = scope->FindVar(x);
+    param_.x = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(bias);
+    param_.bias = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(mean);
+    param_.mean = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(scale);
+    param_.scale = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(variance);
+    param_.var = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(out);
+    param_.out = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(mean_out);
+    param_.mean_out = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(var_out);
+    param_.var_out = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(saved_mean);
+    param_.saved_mean = var->GetMutable<lite::Tensor>();
+    var = scope->FindVar(saved_var);
+    param_.saved_var = var->GetMutable<lite::Tensor>();
+
+    param_.eps = op_desc.GetAttr<float>("epsilon");
+
+    return true;
+  }
+
+  std::string DebugString() const override { return "batch_norm"; }
+
+ private:
+  mutable BatchNormParam param_;
+};
+
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
diff --git a/paddle/fluid/lite/operators/conv_op.h b/paddle/fluid/lite/operators/conv_op.h
index 393b5dc2a8e5e9aa8d94784bc4f5a8d041414200..3f974ea24890f3596d44fadeae5151a454dcf06d 100644
--- a/paddle/fluid/lite/operators/conv_op.h
+++ b/paddle/fluid/lite/operators/conv_op.h
@@ -30,48 +30,53 @@ class ConvOpLite : public OpLite {
  public:
   ConvOpLite() {}
 
-  explicit ConvOpLite(const std::string &type) : OpLite(type) {}
+  explicit ConvOpLite(const std::string& type) : OpLite(type) {}
 
   bool CheckShape() const override;
 
   bool InferShape() const override;
 
   // TODO(Superjomn) replace framework::OpDesc with a lite one.
-  bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
-    auto input = op_desc.Input("Input").front();
-    auto filter = op_desc.Input("Filter").front();
-    auto out = op_desc.Output("Out").front();
-    param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
-    param_.filter = scope->FindVar(filter)->GetMutable<lite::Tensor>();
-    CHECK(scope->FindVar(out));
-    param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
+  bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
+    auto X = op_desc.Input("Input").front();
+    auto Filter = op_desc.Input("Filter").front();
+    auto Out = op_desc.Output("Output").front();
+
+    param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>();
+    param_.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>();
+    param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
+
     param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
     param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
     param_.groups = op_desc.GetAttr<int>("groups");
     param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations");
 
+    // optional params
     std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
     if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
         input_arg_names.end()) {
-      auto bias_var = scope->FindVar(op_desc.Input("Bias").front());
-      if (bias_var != nullptr) {
-        param_.bias =
-            const_cast<lite::Tensor*>(&(bias_var->Get<lite::Tensor>()));
+      auto bias_arguments = op_desc.Input("Bias");
+      if (bias_arguments.size() != 0) {
+        auto bias_var = scope->FindVar(bias_arguments.front());
+        if (bias_var != nullptr) {
+          param_.bias = bias_var->GetMutable<lite::Tensor>();
+        }
       }
     }
     if (std::find(input_arg_names.begin(), input_arg_names.end(),
                   "ResidualData") != input_arg_names.end()) {
-      auto residual_data_var =
-          scope->FindVar(op_desc.Input("ResidualData").front());
-      if (residual_data_var != nullptr) {
-        param_.residualData = const_cast<lite::Tensor*>(
-            &(residual_data_var->Get<lite::Tensor>()));
+      auto res_argument = op_desc.Input("ResidualData");
+      if (res_argument.size() != 0) {
+        auto residual_data_var = scope->FindVar(res_argument.front());
+        if (residual_data_var != nullptr) {
+          param_.residualData = residual_data_var->GetMutable<lite::Tensor>();
+        }
       }
     }
 
     return true;
   }
 
-  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
+  void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
 
   std::string DebugString() const override { return "conv2d"; }
 
diff --git a/paddle/fluid/lite/operators/relu_op.cc b/paddle/fluid/lite/operators/relu_op.cc
index b073e2db43a4891defeb95750424941969323ba0..47251c72dfa5183e19ace3e36a1d3a9dd27a6bb0 100644
--- a/paddle/fluid/lite/operators/relu_op.cc
+++ b/paddle/fluid/lite/operators/relu_op.cc
@@ -32,12 +32,11 @@ bool ReluOp::InferShape() const {
 
 bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
   param_.input = const_cast<lite::Tensor *>(
-      &scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
+      &scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
   param_.output =
       scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
   CHECK(param_.input);
   CHECK(param_.output);
-  kernel_->SetParam(param_);
 
   return true;
 }
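Two behavioral fixes ride along in relu_op.cc: the op now reads its input from the standard activation argument "X" (matching the kernel-registration rename in relu_compute.cc above), and the premature kernel_->SetParam(param_) is dropped, leaving AttachKernel to bind the param once a kernel has actually been picked. A relu OpDesc consistent with the renamed arguments, for illustration (variable names invented):

    cpp::OpDesc desc;
    desc.SetType("relu");
    desc.SetInput("X", {"relu_in"});     // was "Input" before this patch
    desc.SetOutput("Out", {"relu_out"});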