提交 505e450d 编写于 作者: X xingzhaolong

Merge branch 'xzl/incubate/lite' into 'incubate/lite'

merge github code to gitlab.

See merge request inference/paddlelite!13
...@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}") ...@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}") message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}")
set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install") set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install")
set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
function(lite_download_and_uncompress INSTALL_DIR URL FILENAME) function(lite_download_and_uncompress INSTALL_DIR URL FILENAME)
message(STATUS "Download inference test stuff from ${URL}/${FILENAME}") message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
...@@ -161,13 +162,13 @@ function(lite_cc_test TARGET) ...@@ -161,13 +162,13 @@ function(lite_cc_test TARGET)
file(APPEND ${offline_test_registry_file} "${TARGET}\n") file(APPEND ${offline_test_registry_file} "${TARGET}\n")
endfunction() endfunction()
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(core) add_subdirectory(core)
add_subdirectory(x86) add_subdirectory(x86)
add_subdirectory(arm) add_subdirectory(arm)
add_subdirectory(host) add_subdirectory(host)
add_subdirectory(cuda) add_subdirectory(cuda)
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(model_parser) add_subdirectory(model_parser)
add_subdirectory(utils) add_subdirectory(utils)
add_subdirectory(api) add_subdirectory(api)
......
...@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA) ...@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda) nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda)
endif() endif()
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite}) cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite} program_lite)
set(light_api_deps set(light_api_deps
scope_lite target_wrapper_host model_parser_lite) scope_lite target_wrapper_host model_parser_lite)
...@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}") ...@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}")
message(STATUS "get ARM kernels ${arm_kernels}") message(STATUS "get ARM kernels ${arm_kernels}")
include(ExternalProject) include(ExternalProject)
set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.") "A path setting inference demo download directories.")
if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING) if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite model_parser_lite target_wrapper_host DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels} ${ops_lite} ${host_kernels} ${x86_kernels}
PROFILE_DEPS basic_profiler_lite
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
...@@ -45,7 +43,6 @@ endif() ...@@ -45,7 +43,6 @@ endif()
# lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL) # lite_cc_test(test_light_api SRCS light_api_test.cc DEPS light_api_lite ARGS --optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
# endif() # endif()
lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc lite_cc_binary(cxx_api_lite_bin SRCS cxx_api_bin.cc
DEPS DEPS
cxx_api_lite cxx_api_lite
......
...@@ -13,13 +13,22 @@ ...@@ -13,13 +13,22 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h" #include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono>
#include "paddle/fluid/lite/core/mir/passes.h" #include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
void Run(const char* model_dir) { using Time = decltype(std::chrono::high_resolution_clock::now());
Time time() { return std::chrono::high_resolution_clock::now(); }
double time_diff(Time t1, Time t2) {
typedef std::chrono::microseconds ms;
auto diff = t2 - t1;
ms counter = std::chrono::duration_cast<ms>(diff);
return counter.count() / 1000.0;
}
void Run(const char* model_dir, int repeat) {
#ifdef LITE_WITH_ARM #ifdef LITE_WITH_ARM
DeviceInfo::Init(); DeviceInfo::Init();
#endif #endif
...@@ -34,10 +43,16 @@ void Run(const char* model_dir) { ...@@ -34,10 +43,16 @@ void Run(const char* model_dir) {
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224}))); input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>(); auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < input_tensor->dims().production(); i++) { for (int i = 0; i < input_tensor->dims().production(); i++) {
data[i] = i; data[i] = 1;
} }
predictor.Run(); for (int i = 0; i < 10; i++) predictor.Run();
auto time1 = time();
for (int i = 0; i < repeat; i++) predictor.Run();
auto time2 = time();
std::cout << " predict cost: " << time_diff(time1, time2) / repeat << "ms"
<< std::endl;
auto* out = predictor.GetOutput(0); auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->data_size(); LOG(INFO) << out << " memory size " << out->data_size();
...@@ -52,7 +67,7 @@ void Run(const char* model_dir) { ...@@ -52,7 +67,7 @@ void Run(const char* model_dir) {
int main(int argc, char** argv) { int main(int argc, char** argv) {
CHECK_EQ(argc, 2) << "usage: ./cmd <model_dir>"; CHECK_EQ(argc, 2) << "usage: ./cmd <model_dir>";
paddle::lite::Run(argv[1]); paddle::lite::Run(argv[1], 1);
return 0; return 0;
} }
......
...@@ -30,7 +30,9 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp ...@@ -30,7 +30,9 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
cc_library(types_lite SRCS types.cc) cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite) cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
lite_cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto lite_cc_library(program_lite SRCS program.cc
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite
HVY_DEPS framework_proto
PROFILE_DEPS basic_profiler_lite) PROFILE_DEPS basic_profiler_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite) cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
......
...@@ -3,8 +3,13 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program_lite) ...@@ -3,8 +3,13 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program_lite)
cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph) cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes) cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion)
cc_library(mir_passes cc_library(mir_passes
SRCS static_kernel_pick_pass.cc SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc variable_place_inference_pass.cc
type_target_transform_pass.cc type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc io_copy_kernel_pick_pass.cc
...@@ -13,13 +18,8 @@ cc_library(mir_passes ...@@ -13,13 +18,8 @@ cc_library(mir_passes
argument_type_display_pass.cc argument_type_display_pass.cc
demo_pass.cc demo_pass.cc
runtime_context_assign_pass.cc runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite) DEPS mir_pass types_lite context_lite ${mir_fusers})
# for mobile, unnecessary to compile the following testings.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
return()
endif()
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
#cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS #cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
#mir_ssa_graph scope_lite op_lite #mir_ssa_graph scope_lite op_lite
#fc_op_lite #fc_op_lite
...@@ -52,11 +52,37 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern ...@@ -52,11 +52,37 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite) lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite)
# TODO(wz) replace framework/proto to lite proto.
# for mobile, unnecessary to compile the following testings.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
return()
endif()
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
# TODO(wz) replace framework/proto to lite proto.
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# it depends on the fluid/framework/proto, that is too heavy for mobile execution. # it depends on the fluid/framework/proto, that is too heavy for mobile execution.
lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite ${ops_lite}) mir_passes compatible_pb_lite program_lite ${ops_lite})
endif() endif()
message(STATUS "----> Ops lite: ${ops_lite}")
message(STATUS "----> Host kernels: ${host_kernels}")
message(STATUS "----> X86 kernels: ${x86_kernels}")
lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model
--optimized_model=${LITE_MODEL_DIR}/lite_fc_model_opt SERIAL)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz")
add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test(test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvBNFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvBNFuser fuser2("depthwise_conv2d");
fuser2(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvBNFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvElementwiseAddReLUFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
depthwise_fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
paddle::lite::mir::ConvElementwiseAddReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddReLUFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv2d_1 = main_block->AppendOp();
auto* conv2d_2 = main_block->AppendOp();
auto* add_1 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp();
main_block->Var("input_1");
main_block->Var("input_2");
main_block->Var("filter_1");
main_block->Var("filter_2");
main_block->Var("conv2d_1_out");
main_block->Var("conv2d_2_out");
main_block->Var("bias_1");
main_block->Var("add_1_out");
main_block->Var("add_2_out");
main_block->Var("relu_1_out");
main_block->Var("out");
scope->Var("input_1")->GetMutable<lite::Tensor>();
scope->Var("input_2")->GetMutable<lite::Tensor>();
scope->Var("filter_1")->GetMutable<lite::Tensor>();
scope->Var("filter_2")->GetMutable<lite::Tensor>();
scope->Var("conv2d_1_out")->GetMutable<lite::Tensor>();
scope->Var("conv2d_2_out")->GetMutable<lite::Tensor>();
scope->Var("bias_1")->GetMutable<lite::Tensor>();
scope->Var("add_1_out")->GetMutable<lite::Tensor>();
scope->Var("add_2_out")->GetMutable<lite::Tensor>();
scope->Var("relu_1_out")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
conv2d_1->SetType("conv2d");
conv2d_1->SetInput("Input", {"input_1"});
conv2d_1->SetInput("Filter", {"filter_1"});
conv2d_1->SetOutput("Output", {"conv2d_1_out"});
conv2d_1->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_1->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_1->SetAttr("groups", 1);
conv2d_1->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_1->SetAttr("fuse_relu", false);
add_1->SetType("elementwise_add");
add_1->SetInput("X", {"conv2d_1_out"});
add_1->SetInput("Y", {"bias_1"});
add_1->SetOutput("Out", {"add_1_out"});
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("X", {"add_1_out"});
relu_1->SetOutput("Out", {"relu_1_out"});
conv2d_2->SetType("conv2d");
conv2d_2->SetInput("Input", {"input_2"});
conv2d_2->SetInput("Filter", {"filter_2"});
conv2d_2->SetOutput("Output", {"conv2d_2_out"});
conv2d_2->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_2->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_2->SetAttr("groups", 1);
conv2d_2->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_2->SetAttr("fuse_relu", false);
add_2->SetType("elementwise_add");
add_2->SetInput("X", {"conv2d_2_out"});
add_2->SetInput("Y", {"relu_1_out"});
add_2->SetOutput("Out", {"add_2_out"});
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("X", {"add_2_out"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/);
Visualize(graph.get());
}
TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvElementwiseAddReLUFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ +
1UL * 2 /* fused fc node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_LITE_OP(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(relu);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::FcFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class FcFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
TEST(fc_fuse_pass, fuse_test) {
lite::ExecutorLite predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
#else
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)},
});
#endif
predictor.Build(FLAGS_model_dir,
Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({100, 100})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor.Run();
auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->data_size();
LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims();
EXPECT_NEAR(out->data<float>()[0], 38.120617f, 1e-5);
EXPECT_NEAR(out->data<float>()[1], 10.109812f, 1e-5);
CHECK_EQ(out->dims()[0], 100);
CHECK_EQ(out->dims()[1], 500);
}
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(fc_fuse_pass, save_model_test) {
lite::ExecutorLite predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
predictor.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)},
valid_places);
LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
predictor.SaveModel(FLAGS_optimized_model);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(mul);
USE_LITE_OP(elementwise_add);
USE_LITE_OP(elementwise_sub);
USE_LITE_OP(fc);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_OP(softmax);
USE_LITE_OP(scale);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
cc_library(fuse_fc
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_conv_elementwise_add_relu
SRCS conv_elementwise_add_relu_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
fuse_conv_elementwise_add_relu
fuse_conv_bn
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
return()
endif()
lite_cc_test(test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
DEPS elementwise_ops_lite batch_norm_op_lite conv_op_lite proto_desc compatible_pb_lite program_lite mir_pass mir_pass_manager pattern_matcher_high_api)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv_op = main_block->AppendOp();
auto* bn_op = main_block->AppendOp();
main_block->Var("conv_i");
main_block->Var("conv_param");
main_block->Var("conv_out");
main_block->Var("bn_scale");
main_block->Var("bn_bias");
main_block->Var("bn_mean");
main_block->Var("bn_var");
main_block->Var("bn_out");
main_block->Var("bn_mean_out");
main_block->Var("bn_var_out");
main_block->Var("bn_saved_mean");
main_block->Var("bn_saved_var");
scope->Var("conv_i")->GetMutable<lite::Tensor>();
auto conv_param_t = scope->Var("conv_param")->GetMutable<lite::Tensor>();
std::vector<int64_t> conv_param_shape = {3, 1, 2, 2};
conv_param_t->Resize(lite::DDim(conv_param_shape));
conv_param_t->mutable_data<float>();
scope->Var("conv_out")->GetMutable<lite::Tensor>();
auto bn_scale_t = scope->Var("bn_scale")->GetMutable<lite::Tensor>();
std::vector<int64_t> bn_scale_shape = {3};
bn_scale_t->Resize(lite::DDim(bn_scale_shape));
bn_scale_t->mutable_data<float>();
auto bn_bias_t = scope->Var("bn_bias")->GetMutable<lite::Tensor>();
std::vector<int64_t> bn_bias_shape = {3};
bn_bias_t->Resize(lite::DDim(bn_bias_shape));
bn_bias_t->mutable_data<float>();
auto bn_mean_t = scope->Var("bn_mean")->GetMutable<lite::Tensor>();
bn_mean_t->Resize(lite::DDim(bn_bias_shape));
bn_mean_t->mutable_data<float>();
auto bn_var_t = scope->Var("bn_var")->GetMutable<lite::Tensor>();
bn_var_t->Resize(lite::DDim(bn_bias_shape));
bn_var_t->mutable_data<float>();
scope->Var("bn_out")->GetMutable<lite::Tensor>();
scope->Var("bn_mean_out")->GetMutable<lite::Tensor>();
scope->Var("bn_var_out")->GetMutable<lite::Tensor>();
scope->Var("bn_saved_mean")->GetMutable<lite::Tensor>();
scope->Var("bn_saved_var")->GetMutable<lite::Tensor>();
conv_op->SetType("conv2d");
conv_op->SetInput("Input", {"conv_i"});
conv_op->SetInput("Filter", {"conv_param"});
conv_op->SetOutput("Output", {"conv_out"});
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({1, 1});
const std::vector<int> dilations({1, 1});
const int groups = 1;
conv_op->SetAttr("strides", strides);
conv_op->SetAttr("paddings", paddings);
conv_op->SetAttr("dilations", dilations);
conv_op->SetAttr("groups", groups);
conv_op->SetAttr("fuse_relu", false);
bn_op->SetType("batch_norm");
bn_op->SetInput("X", {"conv_out"});
bn_op->SetInput("Bias", {"bn_bias"});
bn_op->SetInput("Mean", {"bn_mean"});
bn_op->SetInput("Scale", {"bn_scale"});
bn_op->SetInput("Variance", {"bn_var"});
bn_op->SetOutput("Y", {"bn_out"});
bn_op->SetOutput("MeanOut", {"bn_mean_out"});
bn_op->SetOutput("VarianceOut", {"bn_var_out"});
bn_op->SetOutput("SavedMean", {"bn_saved_mean"});
bn_op->SetOutput("SavedVariance", {"bn_saved_var"});
float eps = 1e-5;
bn_op->SetAttr("epsilon", eps);
bn_op->SetAttr("is_test", static_cast<int>(1));
bn_op->SetAttr("use_global_stats", false);
bn_op->SetAttr("momentum", 0.9f);
bn_op->SetAttr("data_layout", std::string("NCHW"));
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(pattern_matcher2, test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvBNFusePass;
fuser->Apply(graph);
ASSERT_EQ(graph->nodes().size(),
num_nodes - 8UL /*nodes removed */ + 1UL /* eltwise_add node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(conv2d);
USE_LITE_OP(batch_norm);
USE_LITE_OP(elementwise_add);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ConvBNFuser::BuildPattern() {
auto* conv_input =
VarNode("conv_input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* conv_weight = VarNode("conv_weight")
->assert_is_op_input(conv_type_, "Filter")
->AsInput();
auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
auto* conv_out = VarNode("conv_out")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input("batch_norm", "X");
auto* bn_scale = VarNode("bn_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* bn_bias =
VarNode("bn_bias")->assert_is_op_input("batch_norm", "Bias")->AsInput();
auto* bn_mean = VarNode("bn_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* bn_var = VarNode("bn_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* bn =
OpNode("bn", "batch_norm")->assert_is_op("batch_norm")->AsIntermediate();
auto* bn_out =
VarNode("bn_out")->assert_is_op_output("batch_norm", "Y")->AsOutput();
auto* bn_mean_out = VarNode("bn_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* bn_var_out = VarNode("bn_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* bn_saved_mean = VarNode("bn_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* bn_saved_var = VarNode("bn_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
.LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
}
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
auto conv = matched.at("conv2d")->stmt()->op;
auto* scope = conv->scope();
auto& valid_places = conv->valid_places();
auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
->GetMutable<lite::Tensor>();
auto conv_weight_d = conv_weight_t->mutable_data<float>();
auto conv_weight_dims = conv_weight_t->dims();
size_t weight_num = conv_weight_t->data_size();
auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
->GetMutable<lite::Tensor>();
size_t bias_size = bn_scale_t->data_size();
auto bn_scale_d = bn_scale_t->mutable_data<float>();
CHECK(bias_size == conv_weight_dims[0])
<< "The BN bias's size should be equal to the size of the first "
<< "dim size of the conv weights";
auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_mean_d = bn_mean_t->mutable_data<float>();
auto bn_var_t = scope->FindVar(matched.at("bn_variance")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_var_d = bn_var_t->mutable_data<float>();
auto bn_bias_t = scope->FindVar(matched.at("bn_bias")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_bias_d = bn_bias_t->mutable_data<float>();
auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
ComputeFusedWeight(bn_scale_d, bn_mean_d, bn_var_d, bn_bias_d, conv_weight_d,
eps, bias_size, weight_num / bias_size);
eltwise_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
}
cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("elementwise_add");
op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
op_desc.SetAttr("axis", 1);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvBNFuser : public FuseBase {
public:
explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
void ComputeFusedWeight(float* scale_d, float* mean_d, float* var_d,
float* bias_d, float* conv_weight_d, float eps, int h,
int w) {
for (int i = 0; i < h; i++) {
var_d[i] = scale_d[i] / std::sqrt(var_d[i] + eps);
}
for (int i = 0; i < h; i++) {
bias_d[i] += (-mean_d[i]) * var_d[i];
}
for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
conv_weight_d[i * w + j] *= var_d[i];
}
}
}
private:
std::string conv_type_{"conv2d"};
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ConvElementwiseAddReLUFuser::BuildPattern() {
// create input nodes.
auto* input =
VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* filter =
VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
auto* bias =
VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput();
// create op nodes
auto* conv2d =
OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
->AsIntermediate();
auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();
// create intermediate nodes
auto* conv2d_out = VarNode("conv2d_out")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("relu", "X")
->AsIntermediate();
// create output node
auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();
// create topology.
std::vector<PMNode*> conv2d_inputs{filter, input};
std::vector<PMNode*> add_inputs{conv2d_out, bias};
conv2d_inputs >> *conv2d >> *conv2d_out;
add_inputs >> *add >> *add_out;
*add_out >> *relu >> *out;
}
void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op;
auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places();
conv_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
IR_NODE_LINK_TO(matched.at("input"), new_op_node);
IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info();
cpp::OpDesc op_desc;
op_desc.SetType(conv_type_);
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
// Other inputs. See operators/conv_op.h
std::vector<std::string> input_arg_names = desc->InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
}
// Only consider strides, padding, groups, dilations, fuse_relu for now
op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
op_desc.SetAttr("fuse_relu", true);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvElementwiseAddReLUFuser : public FuseBase {
public:
explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
: conv_type_(conv_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void FcFuser::BuildPattern() {
// create nodes.
auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b");
auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out");
auto* add = OpNode("add", "elementwise_add");
auto* Out = VarNode("Out");
// create topology.
std::vector<PMNode*> mul_inputs{W, x};
std::vector<PMNode*> add_inputs{mul_out, b};
mul_inputs >> *mul >> *mul_out;
add_inputs >> *add >> *Out;
// Some op specialities.
mul_out->AsIntermediate();
mul->AsIntermediate();
add->AsIntermediate();
}
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op;
auto* scope = mul->scope();
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("b"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name});
op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class FcFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -28,7 +28,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -28,7 +28,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
for (auto& item : graph->StmtTopologicalOrder()) { for (auto& item : graph->StmtTopologicalOrder()) {
if (item->IsStmt()) { if (item->IsStmt()) {
auto& stmt = item->AsStmt(); auto& stmt = item->AsStmt();
LOG(INFO) << stmt; VLOG(4) << stmt;
insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front())); insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front()));
} }
} }
......
...@@ -71,12 +71,20 @@ class Node { ...@@ -71,12 +71,20 @@ class Node {
struct Arg { struct Arg {
std::string name; std::string name;
int id{0};
const Type* type{}; const Type* type{};
// Weight is a special kind of argument, it is marked as weight explicitly // Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place. // so that some weight related optimization can take place.
bool is_weight{false}; bool is_weight{false};
}; };
Arg& AsArg(const std::string& name, int id) {
auto& x = AsArg();
x.name = name;
x.id = id;
return x;
}
Arg& AsArg(const std::string& name) { Arg& AsArg(const std::string& name) {
auto& x = AsArg(); auto& x = AsArg();
x.name = name; x.name = name;
......
...@@ -31,3 +31,7 @@ USE_MIR_PASS(io_copy_kernel_pick_pass); ...@@ -31,3 +31,7 @@ USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(argument_type_display_pass);
#endif #endif
USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
...@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) { ...@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
return *this; return *this;
} }
void operator>>(std::vector<PMNode *> &others, PMNode &me) { PMNode &operator>>(std::vector<PMNode *> &others, PMNode &me) {
for (auto *o : others) { for (auto *o : others) {
*o >> me; *o >> me;
} }
return me;
} }
PMNode *PMPattern::NewNode(const std::string &name) { PMNode *PMPattern::NewNode(const std::string &name) {
...@@ -406,6 +407,67 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) { ...@@ -406,6 +407,67 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
return this; return this;
} }
bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
size_t nth) {
CHECK(var->IsArg());
CHECK(op->IsStmt());
auto op_info = op->stmt()->op_info();
if (op_info->Output(argument).size() <= nth) return false;
return var->arg()->name == op_info->Output(argument)[nth];
}
bool IsNthInput(const Node *var, const Node *op, const std::string &argument,
size_t nth) {
CHECK(var->IsArg());
CHECK(op->IsStmt());
auto op_info = op->stmt()->op_info();
if (op_info->Input(argument).size() <= nth) return false;
return var->arg()->name == op_info->Input(argument)[nth];
}
PMNode *PMNode::assert_is_op_input(const std::string &op_type,
const std::string &argument) {
assert_is_var();
assert_is_op_nth_input(op_type, argument, 0);
return this;
}
PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
const std::string &argument, int nth) {
assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->outlinks) {
if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
IsNthInput(x, op, argument, nth))
return true;
}
return false;
});
return this;
}
PMNode *PMNode::assert_is_op_output(const std::string &op_type,
const std::string &argument) {
assert_is_var();
assert_is_op_nth_output(op_type, argument, 0);
return this;
}
PMNode *PMNode::assert_is_op_nth_output(const std::string &op_type,
const std::string &argument, int nth) {
assert_is_var();
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->inlinks) {
if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
IsNthOutput(x, op, argument, nth))
return true;
}
return false;
});
return this;
}
PMNode *PMNode::assert_is_op_input(const std::string &op_type) { PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
assert_is_var(); assert_is_var();
asserts_.emplace_back([=](const Node *x) { asserts_.emplace_back([=](const Node *x) {
...@@ -422,6 +484,14 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) { ...@@ -422,6 +484,14 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
return this; return this;
} }
bool HasInput(const Node &op, const std::string &argument) {
CHECK(op.IsStmt());
auto const &names = op.stmt()->op_info()->input_argnames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
return true;
}
void GraphSafeRemoveNodes(SSAGraph *graph, void GraphSafeRemoveNodes(SSAGraph *graph,
const std::unordered_set<const Node *> &nodes) { const std::unordered_set<const Node *> &nodes) {
for (auto *node : nodes) { for (auto *node : nodes) {
......
...@@ -62,7 +62,7 @@ struct PMNode { ...@@ -62,7 +62,7 @@ struct PMNode {
PMNode& operator>>(PMNode& right); PMNode& operator>>(PMNode& right);
// Link many nodes to this node. // Link many nodes to this node.
friend void operator>>(std::vector<PMNode*>& others, PMNode& me); friend PMNode& operator>>(std::vector<PMNode*>& others, PMNode& me);
// Link this to many other nodes. // Link this to many other nodes.
PMNode& operator>>(std::vector<PMNode*>& nodes); PMNode& operator>>(std::vector<PMNode*>& nodes);
...@@ -127,6 +127,15 @@ struct PMNode { ...@@ -127,6 +127,15 @@ struct PMNode {
PMNode* assert_is_persistable_var(); PMNode* assert_is_persistable_var();
PMNode* assert_is_op_output(const std::string& op_type); PMNode* assert_is_op_output(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type); PMNode* assert_is_op_input(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type,
const std::string& argument);
PMNode* assert_is_op_output(const std::string& op_type,
const std::string& argument);
PMNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth);
PMNode* assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth);
template <typename T> template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
...@@ -297,6 +306,13 @@ class PatternMatcher { ...@@ -297,6 +306,13 @@ class PatternMatcher {
std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_; std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_;
}; };
// Check whether a var node is a op node's nth input.
bool IsNthInput(const Node& var, const Node& op, const std::string& argument,
int nth);
// Check whether the op node has input of given name.
bool HasInput(const Node& op, const std::string& argument);
// Graph safely remove some nodes, will automatically clean up the edges. // Graph safely remove some nodes, will automatically clean up the edges.
void GraphSafeRemoveNodes(SSAGraph* graph, void GraphSafeRemoveNodes(SSAGraph* graph,
const std::unordered_set<const Node*>& nodes); const std::unordered_set<const Node*>& nodes);
......
...@@ -20,7 +20,7 @@ namespace lite { ...@@ -20,7 +20,7 @@ namespace lite {
namespace mir { namespace mir {
void FuseBase::PerformPatternMatcher(SSAGraph *graph) { void FuseBase::PerformPatternMatcher(SSAGraph *graph) {
LOG(INFO) << "\n" << matcher_.pattern().DotString(); VLOG(4) << "\n" << matcher_.pattern().DotString();
// Get subgraphs and record the mir::Node pointers for each PMNode. // Get subgraphs and record the mir::Node pointers for each PMNode.
auto handler = [&](const PatternMatcher::subgraph_t &subgraph, SSAGraph *g) { auto handler = [&](const PatternMatcher::subgraph_t &subgraph, SSAGraph *g) {
// get all the reigistered nodes. // get all the reigistered nodes.
...@@ -41,17 +41,14 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) { ...@@ -41,17 +41,14 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
} }
} }
LOG(INFO) << "keys.size " << keys.size();
std::unordered_set<const Node *> nodes2rm; std::unordered_set<const Node *> nodes2rm;
for (auto &matched : key2nodes_) { for (auto &matched : key2nodes_) {
LOG(INFO) << "get matched " << matched.size();
for (const auto &key : keys) { for (const auto &key : keys) {
nodes2rm.insert(matched.at(key)); nodes2rm.insert(matched.at(key));
} }
} }
LOG(INFO) << "clean nodes " << nodes2rm.size(); VLOG(3) << "clean nodes " << nodes2rm.size();
GraphSafeRemoveNodes(graph, nodes2rm); GraphSafeRemoveNodes(graph, nodes2rm);
} }
......
...@@ -64,7 +64,6 @@ class FuseBase { ...@@ -64,7 +64,6 @@ class FuseBase {
// Delete nodes that are marked as Intermediate // Delete nodes that are marked as Intermediate
void DeleteInterNodes(SSAGraph* graph); void DeleteInterNodes(SSAGraph* graph);
private:
PMNode* GetOrCreateNode(const std::string& key); PMNode* GetOrCreateNode(const std::string& key);
protected: protected:
......
...@@ -29,8 +29,8 @@ class FcFuser : public FuseBase { ...@@ -29,8 +29,8 @@ class FcFuser : public FuseBase {
public: public:
void BuildPattern() override { void BuildPattern() override {
// create nodes. // create nodes.
auto* x = VarNode("x"); auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W"); auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b"); auto* b = VarNode("b");
auto* mul = OpNode("mul", "mul"); auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out"); auto* mul_out = VarNode("mul_out");
...@@ -38,12 +38,10 @@ class FcFuser : public FuseBase { ...@@ -38,12 +38,10 @@ class FcFuser : public FuseBase {
auto* Out = VarNode("Out"); auto* Out = VarNode("Out");
// create topology. // create topology.
// std::vector<PMNode*>({W, x}) >> *mul >> *mul_out; std::vector<PMNode*> mul_inputs{W, x};
// std::vector<PMNode*>({mul_out, b}) >> *add >> *Out; std::vector<PMNode*> add_inputs{mul_out, b};
*W >> *mul; mul_inputs >> *mul >> *mul_out;
*x >> *mul >> *mul_out; add_inputs >> *add >> *Out;
*b >> *add;
*mul_out >> *add >> *Out;
// Some op specialities. // Some op specialities.
mul_out->AsIntermediate(); mul_out->AsIntermediate();
...@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc, ...@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block->Var("mul_out"); main_block->Var("mul_out");
main_block->Var("w"); main_block->Var("w");
main_block->Var("out"); main_block->Var("out");
main_block->Var("out1");
scope->Var("w")->GetMutable<lite::Tensor>(); scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("b")->GetMutable<lite::Tensor>(); scope->Var("b")->GetMutable<lite::Tensor>();
scope->Var("mul_out")->GetMutable<lite::Tensor>(); scope->Var("mul_out")->GetMutable<lite::Tensor>();
scope->Var("w")->GetMutable<lite::Tensor>(); scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>(); scope->Var("out")->GetMutable<lite::Tensor>();
scope->Var("out1")->GetMutable<lite::Tensor>();
mul->SetInput("X", {"x"}); mul->SetInput("X", {"x"});
mul->SetInput("Y", {"w"}); mul->SetInput("Y", {"w"});
...@@ -122,18 +118,17 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc, ...@@ -122,18 +118,17 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
return graph; return graph;
} }
TEST(pattern_matcher2, graph_test) { TEST(pattern_matcher_high_api, graph_test) {
framework::ProgramDesc program_desc; framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}}; std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places); auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(), ASSERT_EQ(graph->nodes().size(), 7UL /*real nodes*/);
8UL /*real nodes*/ + 2UL /*feed op + fetch op*/);
Visualize(graph.get()); Visualize(graph.get());
} }
TEST(pattern_matcher2, test) { TEST(pattern_matcher_high_api, fuse_test) {
framework::ProgramDesc program_desc; framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}}; std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>(); auto scope = std::make_shared<Scope>();
...@@ -143,6 +138,7 @@ TEST(pattern_matcher2, test) { ...@@ -143,6 +138,7 @@ TEST(pattern_matcher2, test) {
fuser(graph.get()); fuser(graph.get());
ASSERT_EQ(graph->nodes().size(), ASSERT_EQ(graph->nodes().size(),
num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/); num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/);
Visualize(graph.get());
} }
} // namespace mir } // namespace mir
......
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
#include <algorithm> #include <algorithm>
#include <memory> #include <memory>
#include <set> #include <set>
#include <unordered_map>
#include <utility> #include <utility>
namespace paddle { namespace paddle {
...@@ -93,31 +94,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() { ...@@ -93,31 +94,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
return res; return res;
} }
void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
for (const auto &name : program.tmp_vars()) {
CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
<< name;
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArg(name);
arguments_[name] = &new_node;
}
}
void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
// create weight nodes.
for (const auto &name : program.weights()) {
CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
<< name;
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArg(name);
arguments_[name] = &new_node;
}
}
Node *SSAGraph::GraphCreateInstructNode( Node *SSAGraph::GraphCreateInstructNode(
const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) { const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) {
node_storage_.emplace_back(); node_storage_.emplace_back();
...@@ -135,29 +111,45 @@ Node *SSAGraph::GraphCreateInstructNode( ...@@ -135,29 +111,45 @@ Node *SSAGraph::GraphCreateInstructNode(
void SSAGraph::Build(const Program &program, void SSAGraph::Build(const Program &program,
const std::vector<Place> &valid_places) { const std::vector<Place> &valid_places) {
CHECK(node_storage_.empty()); CHECK(node_storage_.empty());
GraphCreateTmpVarNodes(program);
GraphCreateWeightVarNodes(program);
CHECK(CheckNodesRoleSet());
auto weights_name = program.weights();
auto is_weights = [&](const std::string &name) -> bool {
auto it = std::find(weights_name.begin(), weights_name.end(), name);
if (it == weights_name.end()) return false;
return true;
};
std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
for (auto &op : program.ops()) { for (auto &op : program.ops()) {
auto *op_node = GraphCreateInstructNode(op, valid_places); auto *op_node = GraphCreateInstructNode(op, valid_places);
for (const std::string &name : op->op_info()->input_names()) { for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name); mir::Node *arg_node = nullptr;
CHECK(arg->IsRoleSet()); if (arg_update_node_map_.count(name)) {
DirectedLink(arg, op_node); arg_node = arg_update_node_map_.at(name);
} else {
node_storage_.emplace_back();
arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
} }
for (const std::string &name : op->op_info()->output_names()) { if (is_weights(name)) arg_node->AsArg().is_weight = true;
if (!arguments_.count(name)) { CHECK(arg_node->IsRoleSet());
NewArgumentNode(name); DirectedLink(arg_node, op_node);
} }
auto *arg = arguments_.at(name); for (const std::string &name : op->op_info()->output_names()) {
CHECK(arg->IsRoleSet()); node_storage_.emplace_back();
DirectedLink(op_node, arg); auto *arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet());
DirectedLink(op_node, arg_node);
} }
CHECK(CheckLinksRoleSet()); CHECK(CheckLinksRoleSet());
} }
MarkArgumentWeights(program); CHECK(CheckNodesRoleSet());
CheckValid(); CheckValid();
} }
...@@ -227,10 +219,9 @@ bool SSAGraph::CheckLinksRoleSet() { ...@@ -227,10 +219,9 @@ bool SSAGraph::CheckLinksRoleSet() {
Node *SSAGraph::NewArgumentNode(const std::string &name) { Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_.emplace_back(); node_storage_.emplace_back();
CHECK(!arguments_.count(name)) << "duplicate argument called " << name; auto &arg_node = node_storage_.back();
arguments_[name] = &node_storage_.back(); arg_node.AsArg(name, node_storage_.size() - 1);
node_storage_.back().AsArg(name); return &arg_node;
return &node_storage_.back();
} }
Node *SSAGraph::NewInstructNode() { Node *SSAGraph::NewInstructNode() {
......
...@@ -40,8 +40,6 @@ class SSAGraph : GraphBase { ...@@ -40,8 +40,6 @@ class SSAGraph : GraphBase {
void Build(const Program &program, const std::vector<Place> &valid_places); void Build(const Program &program, const std::vector<Place> &valid_places);
void RemoveNode(const mir::Node *node); void RemoveNode(const mir::Node *node);
mir::Node *Argument(const std::string &name);
std::vector<mir::Node *> StmtTopologicalOrder(); std::vector<mir::Node *> StmtTopologicalOrder();
// The inputs of the graph. // The inputs of the graph.
...@@ -68,9 +66,7 @@ class SSAGraph : GraphBase { ...@@ -68,9 +66,7 @@ class SSAGraph : GraphBase {
const std::vector<Place> &valid_places); const std::vector<Place> &valid_places);
private: private:
void GraphCreateTmpVarNodes(const Program &program); mir::Node *Argument(const std::string &name);
void GraphCreateWeightVarNodes(const Program &program);
// Check the bidirectional connection. // Check the bidirectional connection.
bool CheckBidirectionalConnection(); bool CheckBidirectionalConnection();
bool CheckNodesRoleSet(); bool CheckNodesRoleSet();
......
...@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node, ...@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
<< " for kernel " << inst.op->DebugString() << " " << " for kernel " << inst.op->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type; << *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist. // Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in->AsArg().name, graph, AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
inst_node, valid_places_); valid_places_);
} }
} }
void TypeTargetTransformPass::AddIoCopyInst( void TypeTargetTransformPass::AddIoCopyInst(
const Type& from, const Type& to, const std::string& var, SSAGraph* graph, const Type& from, const Type& to, Node* in, SSAGraph* graph,
Node* inst_node, const std::vector<Place>& valid_places) { Node* inst_node, const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set"; CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst // var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Statement Node. // So there will be a new Argument node and a new IoCopy Statement Node.
CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); }; auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = var + "/trans/" + std::to_string(node_id()); auto io_copy_output_name =
in->AsArg().name + "/trans/" + std::to_string(node_id());
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name); auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
auto* io_copy_inst = graph->NewInstructNode(); auto* io_copy_inst = graph->NewInstructNode();
...@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Create IoCopy Instruction. // Create IoCopy Instruction.
cpp::OpDesc op_desc; cpp::OpDesc op_desc;
op_desc.SetType("io_copy"); op_desc.SetType("io_copy");
op_desc.SetInput("Input", {var}); op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {io_copy_output_name}); op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope()); io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
...@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst(
io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op); io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
// Remove the old link // Remove the old link
RemoveDirectedLink(graph->Argument(var), inst_node); RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc. // Update the original instruction OpDesc.
// Update its input to the io_copy_output_name // Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst // Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(graph->Argument(var), io_copy_inst); DirectedLink(in, io_copy_inst);
DirectedLink(io_copy_inst, io_copy_output_arg); DirectedLink(io_copy_inst, io_copy_output_arg);
DirectedLink(io_copy_output_arg, inst_node); DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information // reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), var, UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name,
io_copy_output_name); io_copy_output_name);
inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(), inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(),
......
...@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass { ...@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass {
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in); void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void AddIoCopyInst(const Type& from, const Type& to, const std::string& var, void AddIoCopyInst(const Type& from, const Type& to, Node* in,
SSAGraph* graph, Node* inst_node, SSAGraph* graph, Node* inst_node,
const std::vector<Place>& valid_places); const std::vector<Place>& valid_places);
......
...@@ -13,7 +13,10 @@ ...@@ -13,7 +13,10 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <map>
#include <memory> #include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass.h" #include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h" #include "paddle/fluid/lite/core/target_wrapper.h"
...@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass {
// LOG(INFO) << "- inferencing type " << // LOG(INFO) << "- inferencing type " <<
// deal with inputs // deal with inputs
VLOG(4) << "inferencing op " << inst.op_type; VLOG(4) << "inferencing op " << inst.op_type;
for (auto& arg_name : inst.op_info()->input_argnames()) { // TODO(zhaolong): Add check if the node's name in op's arguments.
auto get_argname = [&](
const std::string& node_name,
const std::map<std::string, std::vector<std::string>>& argname_map)
-> std::string {
for (auto& ele : argname_map) {
auto it =
std::find(ele.second.begin(), ele.second.end(), node_name);
if (it != ele.second.end()) return ele.first;
}
return "";
};
for (auto* x_in : x->inlinks) {
std::string node_name = x_in->AsArg().name;
std::string arg_name = get_argname(node_name, inst.op_info()->inputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
VLOG(3) << "-- input arg_name " << arg_name; VLOG(3) << "-- input arg_name " << arg_name;
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto type = inst.picked_kernel().GetInputDeclType(arg_name); auto type = inst.picked_kernel().GetInputDeclType(arg_name);
auto arg_names = inst.op_info()->inputs().at(arg_name); if (!x_in->AsArg().type) {
VLOG(4) << "set type " << *type << " " << x_in;
for (auto& arg_name : arg_names) { x_in->AsArg().type = type;
VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArg();
if (!arg_node.type) {
VLOG(4) << "set type " << *type << " " << node;
arg_node.type = type;
}
} }
} }
for (auto& arg_name : inst.op_info()->output_argnames()) { for (auto* x_out : x->outlinks) {
std::string node_name = x_out->AsArg().name;
std::string arg_name =
get_argname(node_name, inst.op_info()->outputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
VLOG(3) << "-- output arg_name " << arg_name; VLOG(3) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name); auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
auto arg_names = inst.op_info()->outputs().at(arg_name); if (!x_out->AsArg().type) {
// check if outputs's place is set, if not set, update them with the VLOG(4) << "set type " << *type << " " << x_out;
// kernel's declaration. x_out->AsArg().type = type;
for (auto& arg_name : arg_names) {
VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArg();
if (!arg_node.type) {
node->AsArg().type = type;
VLOG(3) << "set type " << *type;
}
} }
} }
} }
......
...@@ -59,7 +59,7 @@ class OpLite : public Registry { ...@@ -59,7 +59,7 @@ class OpLite : public Registry {
} }
void SetValidPlaces(const std::vector<Place> &places) { void SetValidPlaces(const std::vector<Place> &places) {
LOG(INFO) << "valid places " << valid_places_.size(); VLOG(3) << "valid places " << valid_places_.size();
valid_places_ = places; valid_places_ = places;
} }
const std::vector<Place> &valid_places() const { return valid_places_; } const std::vector<Place> &valid_places() const { return valid_places_; }
......
...@@ -48,6 +48,9 @@ class Optimizer { ...@@ -48,6 +48,9 @@ class Optimizer {
if (passes.empty()) { if (passes.empty()) {
RunPasses(std::vector<std::string>{{ RunPasses(std::vector<std::string>{{
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_act_fuse_pass", //
"lite_fc_fuse_pass", //
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"static_kernel_pick_pass", // "static_kernel_pick_pass", //
"variable_place_inference_pass", // "variable_place_inference_pass", //
......
...@@ -152,8 +152,8 @@ class BasicProfiler { ...@@ -152,8 +152,8 @@ class BasicProfiler {
} }
record_t *mutable_record(int id) { record_t *mutable_record(int id) {
CHECK_LT(id, records_.size());
CHECK_GE(id, 0); CHECK_GE(id, 0);
CHECK_LT(static_cast<size_t>(id), records_.size());
return &records_[id]; return &records_[id];
} }
......
...@@ -140,7 +140,7 @@ class RuntimeProgram { ...@@ -140,7 +140,7 @@ class RuntimeProgram {
void Run() { void Run() {
for (auto& inst : instructions_) { for (auto& inst : instructions_) {
LOG(INFO) << ">> Running kernel: " << inst; VLOG(4) << ">> Running kernel: " << inst;
inst.Run(); inst.Run();
} }
} }
......
...@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
lite::Tensor col_matrix; lite::Tensor col_matrix;
if (is_expand) { if (is_expand) {
col.Resize(col_shape); col.Resize(col_shape);
col.mutable_data<T>();
col_matrix.ShareDataWith(col); col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape); col_matrix.Resize(col_matrix_shape);
} }
...@@ -104,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -104,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data())); param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
lite::Tensor out_batch; lite::Tensor out_batch;
out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize( out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
input_shape.data())); output_matrix_shape.data()));
for (int g = 0; g < param.groups; g++) { for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice; lite::Tensor in_slice;
...@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW, ...@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW, ...@@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -27,8 +27,8 @@ namespace kernels { ...@@ -27,8 +27,8 @@ namespace kernels {
namespace x86 { namespace x86 {
template <typename T> template <typename T>
void fc_compute_eigen(const T* x, int x_w, int x_h, // void fc_compute_eigen(const T* x, int x_h, int x_w, //
const T* w, int w_w, int w_h, // const T* w, int w_h, int w_w, //
const T* b, // const T* b, //
T* out) { T* out) {
using matrix_t = using matrix_t =
...@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, // ...@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, //
Eigen::Map<const matrix_t> X(x, x_h, x_w); Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w); Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_h); Eigen::Map<matrix_t> Out(out, x_h, w_w);
Out = X * W.transpose(); Out = X * W;
if (b) { if (b) {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_h); Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_w);
Out = Out.array().rowwise() + B.transpose().array(); Out = Out.array().rowwise() + B.transpose().array();
} }
} }
template <typename T> template <typename T>
__attribute__((optimize("unroll-loops"))) // void fc_compute_naive(const T* x, int x_h, int x_w, //
T dot(const T* x, const T* y, int dim) { const T* w, int w_h, int w_w, //
T out{};
for (int i = 0; i < dim; i++) {
out += x[i] * y[i];
}
return out;
}
template <typename T>
void fc_compute_naive(const T* x, int x_w, int x_h, //
const T* w, int w_w, int w_h, //
const T* b, // const T* b, //
T* out) { T* out) {
CHECK_EQ(x_w, w_w); CHECK_EQ(x_w, w_h);
// out shape: (x_h, w_w) // out shape: (x_h, w_w)
memset(out, 0, x_h * w_h * sizeof(T)); memset(out, 0, x_h * w_w * sizeof(T));
for (int i = 0; i < x_h; i++) {
for (int r = 0; r < x_h; r++) { for (int j = 0; j < w_w; j++) {
for (int c = 0; c < w_h; c++) { T tmp = static_cast<T>(0);
out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c]; for (int k = 0; k < x_w; k++) {
tmp += x[i * x_w + k] * w[k * w_w + j];
}
out[i * w_w + j] = tmp + b[j];
} }
} }
} }
...@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
.Slice(param.in_num_col_dims, param.input->dims().size()) .Slice(param.in_num_col_dims, param.input->dims().size())
.production(), .production(),
param.w->data<T>(), // w param.w->data<T>(), // w
param.w->dims()[1], // w_w
param.w->dims()[0], // w_h param.w->dims()[0], // w_h
param.w->dims()[1], // w_w
param.bias->data<T>(), // b param.bias->data<T>(), // b
param.output->mutable_data<T>()); param.output->mutable_data<T>());
} }
......
...@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW, REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ReluCompute<float>, def) paddle::lite::kernels::x86::ReluCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))}) .BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize(); .Finalize();
...@@ -91,7 +91,7 @@ void LoadLoDTensor(std::istream &is, Variable *var) { ...@@ -91,7 +91,7 @@ void LoadLoDTensor(std::istream &is, Variable *var) {
auto *tensor = var->GetMutable<lite::Tensor>(); auto *tensor = var->GetMutable<lite::Tensor>();
uint32_t version{}; uint32_t version{};
is.read(reinterpret_cast<char *>(&version), sizeof(version)); is.read(reinterpret_cast<char *>(&version), sizeof(version));
LOG(INFO) << "model version " << version; VLOG(3) << "model version " << version;
// Load LoD information // Load LoD information
uint64_t lod_level{}; uint64_t lod_level{};
...@@ -154,7 +154,7 @@ void LoadModel(const std::string &model_dir, Scope *scope, ...@@ -154,7 +154,7 @@ void LoadModel(const std::string &model_dir, Scope *scope,
continue; continue;
std::string file_path = model_dir + "/" + var.name(); std::string file_path = model_dir + "/" + var.name();
LOG(INFO) << "reading weight " << var.name(); VLOG(4) << "reading weight " << var.name();
std::ifstream file(file_path); std::ifstream file(file_path);
switch (var.type().type()) { switch (var.type().type()) {
......
...@@ -20,7 +20,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS}) ...@@ -20,7 +20,7 @@ cc_library(fill_constant_op_lite SRCS fill_constant_op.cc DEPS ${op_DEPS})
cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite) cc_library(op_params_lite SRCS op_params.cc DEPS ${tensor_lite} any_lite framework_proto_lite)
cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS}) cc_library(dropout_op_lite SRCS dropout_op.cc DEPS ${op_DEPS})
cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS}) cc_library(concat_op_lite SRCS concat_op.cc DEPS ${op_DEPS})
cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS}) # cc_library(split_op_lite SRCS split_op.cc DEPS ${op_DEPS})
set(ops_lite set(ops_lite
conv_op_lite conv_op_lite
...@@ -41,7 +41,7 @@ set(ops_lite ...@@ -41,7 +41,7 @@ set(ops_lite
activation_ops_lite activation_ops_lite
dropout_op_lite dropout_op_lite
concat_op_lite concat_op_lite
split_op_lite #split_op_lite
PARENT_SCOPE) PARENT_SCOPE)
lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc lite_cc_test(test_fc_op_lite SRCS fc_op_test.cc
...@@ -56,4 +56,3 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m ...@@ -56,4 +56,3 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite) lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite) lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite) lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
...@@ -30,25 +30,27 @@ class ConvOpLite : public OpLite { ...@@ -30,25 +30,27 @@ class ConvOpLite : public OpLite {
public: public:
ConvOpLite() {} ConvOpLite() {}
explicit ConvOpLite(const std::string &type) : OpLite(type) {} explicit ConvOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override; bool CheckShape() const override;
bool InferShape() const override; bool InferShape() const override;
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
auto input = op_desc.Input("Input").front(); auto X = op_desc.Input("Input").front();
auto filter = op_desc.Input("Filter").front(); auto Filter = op_desc.Input("Filter").front();
auto output = op_desc.Output("Output").front(); auto Out = op_desc.Output("Output").front();
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(filter)->GetMutable<lite::Tensor>(); param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(output)); param_.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(output)->GetMutable<lite::Tensor>(); param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides"); param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings"); param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups"); param_.groups = op_desc.GetAttr<int>("groups");
param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations"); param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations");
// optional params // optional params
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames(); std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") != if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
...@@ -58,7 +60,7 @@ class ConvOpLite : public OpLite { ...@@ -58,7 +60,7 @@ class ConvOpLite : public OpLite {
auto bias_var = scope->FindVar(bias_arguments.front()); auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) { if (bias_var != nullptr) {
param_.bias = param_.bias =
const_cast<lite::Tensor *>(&(bias_var->Get<lite::Tensor>())); const_cast<lite::Tensor*>(&(bias_var->Get<lite::Tensor>()));
} }
} }
} }
...@@ -68,7 +70,7 @@ class ConvOpLite : public OpLite { ...@@ -68,7 +70,7 @@ class ConvOpLite : public OpLite {
if (res_data_arguments.size() > 0) { if (res_data_arguments.size() > 0) {
auto residual_data_var = scope->FindVar(res_data_arguments.front()); auto residual_data_var = scope->FindVar(res_data_arguments.front());
if (residual_data_var != nullptr) { if (residual_data_var != nullptr) {
param_.residualData = const_cast<lite::Tensor *>( param_.residualData = const_cast<lite::Tensor*>(
&(residual_data_var->Get<lite::Tensor>())); &(residual_data_var->Get<lite::Tensor>()));
} }
} }
...@@ -77,7 +79,7 @@ class ConvOpLite : public OpLite { ...@@ -77,7 +79,7 @@ class ConvOpLite : public OpLite {
return true; return true;
} }
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "conv2d"; } std::string DebugString() const override { return "conv2d"; }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册