提交 0c3191f4 编写于 作者: N nhzlx

Merge branch 'incubate/lite' of https://github.com/PaddlePaddle/Paddle into xzl/incubate/lite

1. fc bias fuse
2. conv + bias + relu fuse
3. conv + bn fuse
4. graph -> ssa graph
......@@ -10,6 +10,7 @@ message(STATUS "LITE_WITH_ARM:\t${LITE_WITH_ARM}")
message(STATUS "LITE_WITH_PROFILE:\t${LITE_WITH_PROFILE}")
set(LITE_MODEL_DIR "${THIRD_PARTY_PATH}/install")
set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
function(lite_download_and_uncompress INSTALL_DIR URL FILENAME)
message(STATUS "Download inference test stuff from ${URL}/${FILENAME}")
......@@ -161,13 +162,13 @@ function(lite_cc_test TARGET)
file(APPEND ${offline_test_registry_file} "${TARGET}\n")
endfunction()
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(core)
add_subdirectory(x86)
add_subdirectory(arm)
add_subdirectory(host)
add_subdirectory(cuda)
add_subdirectory(operators)
add_subdirectory(kernels)
add_subdirectory(model_parser)
add_subdirectory(utils)
add_subdirectory(api)
......
......@@ -5,7 +5,7 @@ if(LITE_WITH_CUDA)
nv_test(test_cxx_api_lite_cuda SRCS cxx_api_test.cc DEPS cxx_api_lite_cuda)
endif()
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite})
cc_library(cxx_api_lite SRCS cxx_api.cc DEPS ${cxx_api_lite_deps} ${ops_lite} program_lite)
set(light_api_deps
scope_lite target_wrapper_host model_parser_lite)
......@@ -21,15 +21,13 @@ message(STATUS "get Host kernels ${host_kernels}")
message(STATUS "get ARM kernels ${arm_kernels}")
include(ExternalProject)
set(LITE_URL "http://paddle-inference-dist.bj.bcebos.com" CACHE STRING "inference download url")
set(LITE_DEMO_INSTALL_DIR "${THIRD_PARTY_PATH}/inference_demo" CACHE STRING
"A path setting inference demo download directories.")
if((NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) AND WITH_TESTING)
lite_cc_test(test_cxx_api_lite SRCS cxx_api_test.cc
DEPS cxx_api_lite model_parser_lite target_wrapper_host
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels}
PROFILE_DEPS basic_profiler_lite
ARGS --model_dir=${LITE_MODEL_DIR}/lite_naive_model
--optimized_model=${LITE_MODEL_DIR}/lite_naive_model_opt SERIAL)
......
......@@ -30,7 +30,9 @@ cc_library(op_lite SRCS op_lite.cc DEPS scope_lite op_registry_lite target_wrapp
cc_library(types_lite SRCS types.cc)
cc_library(type_system SRCS type_system.cc DEPS ${tensor_lite} target_wrapper_lite)
lite_cc_library(program_lite SRCS program.cc DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite HVY_DEPS framework_proto
lite_cc_library(program_lite SRCS program.cc
DEPS op_lite kernel_lite compatible_pb_lite model_parser_lite
HVY_DEPS framework_proto
PROFILE_DEPS basic_profiler_lite)
cc_library(optimizer_lite SRCS optimizer.cc DEPS mir_pass_manager model_parser_lite program_lite)
......
......@@ -3,8 +3,13 @@ cc_library(mir_ssa_graph SRCS ssa_graph.cc DEPS mir_node program_lite)
cc_library(mir_pass SRCS pass.cc DEPS mir_ssa_graph)
cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir_passes)
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion)
cc_library(mir_passes
SRCS static_kernel_pick_pass.cc
SRCS fc_fuse_pass.cc
conv_elementwise_add_relu_fuse_pass.cc
conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
io_copy_kernel_pick_pass.cc
......@@ -13,13 +18,8 @@ cc_library(mir_passes
argument_type_display_pass.cc
demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite)
DEPS mir_pass types_lite context_lite ${mir_fusers})
# for mobile, unnecessary to compile the following testings.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
return()
endif()
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
#cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
#mir_ssa_graph scope_lite op_lite
#fc_op_lite
......@@ -52,11 +52,37 @@ lite_cc_test(test_pattern_matcher_lite SRCS pattern_matcher_test.cc DEPS pattern
lite_cc_library(pattern_matcher_high_api SRCS pattern_matcher_high_api.cc DEPS pattern_matcher_lite)
# TODO(wz) replace framework/proto to lite proto.
# for mobile, unnecessary to compile the following testings.
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
return()
endif()
cc_test(test_mir_pass_manager SRCS pass_manager_test.cc DEPS mir_pass_manager mir_passes)
# TODO(wz) replace framework/proto to lite proto.
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# it depends on the fluid/framework/proto, that is too heavy for mobile execution.
lite_cc_test(test_pattern_matcher_high_api SRCS pattern_matcher_high_api_test.cc DEPS
pattern_matcher_high_api proto_desc mir_pass_manager fc_op_lite mul_op_lite elementwise_ops_lite
mir_passes compatible_pb_lite program_lite ${ops_lite})
endif()
message(STATUS "----> Ops lite: ${ops_lite}")
message(STATUS "----> Host kernels: ${host_kernels}")
message(STATUS "----> X86 kernels: ${x86_kernels}")
lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model
--optimized_model=${LITE_MODEL_DIR}/lite_fc_model_opt SERIAL)
lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz")
add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test(test_lite_conv_elementwise_add_relu_fuse
SRCS conv_elementwise_add_relu_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvBNFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvBNFuser fuser2("depthwise_conv2d");
fuser2(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvBNFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvElementwiseAddReLUFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
depthwise_fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
paddle::lite::mir::ConvElementwiseAddReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddReLUFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv2d_1 = main_block->AppendOp();
auto* conv2d_2 = main_block->AppendOp();
auto* add_1 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp();
main_block->Var("input_1");
main_block->Var("input_2");
main_block->Var("filter_1");
main_block->Var("filter_2");
main_block->Var("conv2d_1_out");
main_block->Var("conv2d_2_out");
main_block->Var("bias_1");
main_block->Var("add_1_out");
main_block->Var("add_2_out");
main_block->Var("relu_1_out");
main_block->Var("out");
scope->Var("input_1")->GetMutable<lite::Tensor>();
scope->Var("input_2")->GetMutable<lite::Tensor>();
scope->Var("filter_1")->GetMutable<lite::Tensor>();
scope->Var("filter_2")->GetMutable<lite::Tensor>();
scope->Var("conv2d_1_out")->GetMutable<lite::Tensor>();
scope->Var("conv2d_2_out")->GetMutable<lite::Tensor>();
scope->Var("bias_1")->GetMutable<lite::Tensor>();
scope->Var("add_1_out")->GetMutable<lite::Tensor>();
scope->Var("add_2_out")->GetMutable<lite::Tensor>();
scope->Var("relu_1_out")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
conv2d_1->SetType("conv2d");
conv2d_1->SetInput("Input", {"input_1"});
conv2d_1->SetInput("Filter", {"filter_1"});
conv2d_1->SetOutput("Output", {"conv2d_1_out"});
conv2d_1->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_1->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_1->SetAttr("groups", 1);
conv2d_1->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_1->SetAttr("fuse_relu", false);
add_1->SetType("elementwise_add");
add_1->SetInput("X", {"conv2d_1_out"});
add_1->SetInput("Y", {"bias_1"});
add_1->SetOutput("Out", {"add_1_out"});
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("X", {"add_1_out"});
relu_1->SetOutput("Out", {"relu_1_out"});
conv2d_2->SetType("conv2d");
conv2d_2->SetInput("Input", {"input_2"});
conv2d_2->SetInput("Filter", {"filter_2"});
conv2d_2->SetOutput("Output", {"conv2d_2_out"});
conv2d_2->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_2->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_2->SetAttr("groups", 1);
conv2d_2->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_2->SetAttr("fuse_relu", false);
add_2->SetType("elementwise_add");
add_2->SetInput("X", {"conv2d_2_out"});
add_2->SetInput("Y", {"relu_1_out"});
add_2->SetOutput("Out", {"add_2_out"});
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("X", {"add_2_out"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/);
Visualize(graph.get());
}
TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvElementwiseAddReLUFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ +
1UL * 2 /* fused fc node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_LITE_OP(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(relu);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::FcFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class FcFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
TEST(fc_fuse_pass, fuse_test) {
lite::ExecutorLite predictor;
#ifndef LITE_WITH_CUDA
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
#else
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kNCHW)},
Place{TARGET(kCUDA), PRECISION(kAny), DATALAYOUT(kAny)},
Place{TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)},
});
#endif
predictor.Build(FLAGS_model_dir,
Place{TARGET(kX86), PRECISION(kFloat)}, // origin cuda
valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({100, 100})));
auto* data = input_tensor->mutable_data<float>();
for (int i = 0; i < 100 * 100; i++) {
data[i] = i;
}
predictor.Run();
auto* out = predictor.GetOutput(0);
LOG(INFO) << out << " memory size " << out->data_size();
LOG(INFO) << "out " << out->data<float>()[0];
LOG(INFO) << "out " << out->data<float>()[1];
LOG(INFO) << "dims " << out->dims();
EXPECT_NEAR(out->data<float>()[0], 38.120617f, 1e-5);
EXPECT_NEAR(out->data<float>()[1], 10.109812f, 1e-5);
CHECK_EQ(out->dims()[0], 100);
CHECK_EQ(out->dims()[1], 500);
}
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
TEST(fc_fuse_pass, save_model_test) {
lite::ExecutorLite predictor;
std::vector<Place> valid_places({Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
predictor.Build(FLAGS_model_dir, Place{TARGET(kX86), PRECISION(kFloat)},
valid_places);
LOG(INFO) << "Save optimized model to " << FLAGS_optimized_model;
predictor.SaveModel(FLAGS_optimized_model);
}
#endif // LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(mul);
USE_LITE_OP(elementwise_add);
USE_LITE_OP(elementwise_sub);
USE_LITE_OP(fc);
USE_LITE_OP(feed);
USE_LITE_OP(fetch);
USE_LITE_OP(io_copy);
USE_LITE_OP(softmax);
USE_LITE_OP(scale);
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
#ifdef LITE_WITH_X86
USE_LITE_KERNEL(mul, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(fc, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_sub, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(elementwise_add, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kX86, kFloat, kNCHW, def);
USE_LITE_KERNEL(scale, kX86, kFloat, kNCHW, def);
#endif
#ifdef LITE_WITH_CUDA
USE_LITE_KERNEL(mul, kCUDA, kFloat, kNCHW, def);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, host_to_device);
USE_LITE_KERNEL(io_copy, kCUDA, kAny, kAny, device_to_host);
#endif
cc_library(fuse_fc
SRCS fc_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_conv_elementwise_add_relu
SRCS conv_elementwise_add_relu_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_conv_bn
SRCS conv_bn_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
fuse_conv_elementwise_add_relu
fuse_conv_bn
CACHE INTERNAL "fusers")
lite_cc_test(test_lite_conv_bn_fuse SRCS conv_bn_fuse_pass_test.cc
DEPS elementwise_ops_lite batch_norm_op_lite conv_op_lite proto_desc compatible_pb_lite program_lite mir_pass mir_pass_manager pattern_matcher_high_api)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/program.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv_op = main_block->AppendOp();
auto* bn_op = main_block->AppendOp();
main_block->Var("conv_i");
main_block->Var("conv_param");
main_block->Var("conv_out");
main_block->Var("bn_scale");
main_block->Var("bn_bias");
main_block->Var("bn_mean");
main_block->Var("bn_var");
main_block->Var("bn_out");
main_block->Var("bn_mean_out");
main_block->Var("bn_var_out");
main_block->Var("bn_saved_mean");
main_block->Var("bn_saved_var");
scope->Var("conv_i")->GetMutable<lite::Tensor>();
auto conv_param_t = scope->Var("conv_param")->GetMutable<lite::Tensor>();
std::vector<int64_t> conv_param_shape = {3, 1, 2, 2};
conv_param_t->Resize(lite::DDim(conv_param_shape));
conv_param_t->mutable_data<float>();
scope->Var("conv_out")->GetMutable<lite::Tensor>();
auto bn_scale_t = scope->Var("bn_scale")->GetMutable<lite::Tensor>();
std::vector<int64_t> bn_scale_shape = {3};
bn_scale_t->Resize(lite::DDim(bn_scale_shape));
bn_scale_t->mutable_data<float>();
auto bn_bias_t = scope->Var("bn_bias")->GetMutable<lite::Tensor>();
std::vector<int64_t> bn_bias_shape = {3};
bn_bias_t->Resize(lite::DDim(bn_bias_shape));
bn_bias_t->mutable_data<float>();
auto bn_mean_t = scope->Var("bn_mean")->GetMutable<lite::Tensor>();
bn_mean_t->Resize(lite::DDim(bn_bias_shape));
bn_mean_t->mutable_data<float>();
auto bn_var_t = scope->Var("bn_var")->GetMutable<lite::Tensor>();
bn_var_t->Resize(lite::DDim(bn_bias_shape));
bn_var_t->mutable_data<float>();
scope->Var("bn_out")->GetMutable<lite::Tensor>();
scope->Var("bn_mean_out")->GetMutable<lite::Tensor>();
scope->Var("bn_var_out")->GetMutable<lite::Tensor>();
scope->Var("bn_saved_mean")->GetMutable<lite::Tensor>();
scope->Var("bn_saved_var")->GetMutable<lite::Tensor>();
conv_op->SetType("conv2d");
conv_op->SetInput("Input", {"conv_i"});
conv_op->SetInput("Filter", {"conv_param"});
conv_op->SetOutput("Output", {"conv_out"});
const std::vector<int> strides({1, 1});
const std::vector<int> paddings({1, 1});
const std::vector<int> dilations({1, 1});
const int groups = 1;
conv_op->SetAttr("strides", strides);
conv_op->SetAttr("paddings", paddings);
conv_op->SetAttr("dilations", dilations);
conv_op->SetAttr("groups", groups);
bn_op->SetType("batch_norm");
bn_op->SetInput("X", {"conv_out"});
bn_op->SetInput("Bias", {"bn_bias"});
bn_op->SetInput("Mean", {"bn_mean"});
bn_op->SetInput("Scale", {"bn_scale"});
bn_op->SetInput("Variance", {"bn_var"});
bn_op->SetOutput("Y", {"bn_out"});
bn_op->SetOutput("MeanOut", {"bn_mean_out"});
bn_op->SetOutput("VarianceOut", {"bn_var_out"});
bn_op->SetOutput("SavedMean", {"bn_saved_mean"});
bn_op->SetOutput("SavedVariance", {"bn_saved_var"});
float eps = 1e-5;
bn_op->SetAttr("epsilon", eps);
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(pattern_matcher2, test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvBNFusePass;
fuser->Apply(graph);
ASSERT_EQ(graph->nodes().size(),
num_nodes - 8UL /*nodes removed */ + 1UL /* eltwise_add node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(conv2d);
USE_LITE_OP(batch_norm);
USE_LITE_OP(elementwise_add);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ConvBNFuser::BuildPattern() {
auto* conv_input =
VarNode("conv_input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* conv_weight = VarNode("conv_weight")
->assert_is_op_input(conv_type_, "Filter")
->AsInput();
auto* conv = OpNode("conv2d", conv_type_)->assert_is_op(conv_type_);
auto* conv_out = VarNode("conv_out")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input("batch_norm", "X");
auto* bn_scale = VarNode("bn_scale")
->assert_is_op_input("batch_norm", "Scale")
->AsIntermediate();
auto* bn_bias =
VarNode("bn_bias")->assert_is_op_input("batch_norm", "Bias")->AsInput();
auto* bn_mean = VarNode("bn_mean")
->assert_is_op_input("batch_norm", "Mean")
->AsIntermediate();
auto* bn_var = VarNode("bn_variance")
->assert_is_op_input("batch_norm", "Variance")
->AsIntermediate();
auto* bn =
OpNode("bn", "batch_norm")->assert_is_op("batch_norm")->AsIntermediate();
auto* bn_out =
VarNode("bn_out")->assert_is_op_output("batch_norm", "Y")->AsOutput();
auto* bn_mean_out = VarNode("bn_mean_out")
->assert_is_op_output("batch_norm", "MeanOut")
->AsIntermediate();
auto* bn_var_out = VarNode("bn_var_out")
->assert_is_op_output("batch_norm", "VarianceOut")
->AsIntermediate();
auto* bn_saved_mean = VarNode("bn_saved_mean")
->assert_is_op_output("batch_norm", "SavedMean")
->AsIntermediate();
auto* bn_saved_var = VarNode("bn_saved_var")
->assert_is_op_output("batch_norm", "SavedVariance")
->AsIntermediate();
conv->LinksFrom({conv_input, conv_weight}).LinksTo({conv_out});
bn->LinksFrom({conv_out, bn_scale, bn_bias, bn_mean, bn_var})
.LinksTo({bn_out, bn_mean_out, bn_saved_mean, bn_saved_var, bn_var_out});
}
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
auto conv = matched.at("conv2d")->stmt()->op;
auto* scope = conv->scope();
auto& valid_places = conv->valid_places();
auto conv_weight_t = scope->FindVar(matched.at("conv_weight")->arg()->name)
->GetMutable<lite::Tensor>();
auto conv_weight_d = conv_weight_t->mutable_data<float>();
auto conv_weight_dims = conv_weight_t->dims();
size_t weight_num = conv_weight_t->data_size();
auto bn_scale_t = scope->FindVar(matched.at("bn_scale")->arg()->name)
->GetMutable<lite::Tensor>();
size_t bias_size = bn_scale_t->data_size();
auto bn_scale_d = bn_scale_t->mutable_data<float>();
PADDLE_ENFORCE(bias_size == conv_weight_dims[0],
"The BN bias's size should be equal to the size of the first "
"dim size of the conv weights");
auto bn_mean_t = scope->FindVar(matched.at("bn_mean")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_mean_d = bn_mean_t->mutable_data<float>();
auto bn_var_t = scope->FindVar(matched.at("bn_variance")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_var_d = bn_var_t->mutable_data<float>();
auto bn_bias_t = scope->FindVar(matched.at("bn_bias")->arg()->name)
->GetMutable<lite::Tensor>();
auto bn_bias_d = bn_bias_t->mutable_data<float>();
auto eps = matched.at("bn")->stmt()->op_info()->GetAttr<float>("epsilon");
ComputeFusedWeight(bn_scale_d, bn_mean_d, bn_var_d, bn_bias_d, conv_weight_d,
eps, bias_size, weight_num / bias_size);
eltwise_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(eltwise_op, valid_places);
IR_NODE_LINK_TO(matched.at("conv_out"), new_op_node);
IR_NODE_LINK_TO(matched.at("bn_bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("bn_out"));
}
cpp::OpDesc ConvBNFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("elementwise_add");
op_desc.SetInput("X", {matched.at("conv_out")->arg()->name});
op_desc.SetInput("Y", {matched.at("bn_bias")->arg()->name});
op_desc.SetOutput("Out", {matched.at("bn_out")->arg()->name});
op_desc.SetAttr("axis", 1);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvBNFuser : public FuseBase {
public:
explicit ConvBNFuser(const std::string& conv_type) : conv_type_(conv_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
void ComputeFusedWeight(float* scale_d, float* mean_d, float* var_d,
float* bias_d, float* conv_weight_d, float eps, int h,
int w) {
for (int i = 0; i < h; i++) {
var_d[i] = scale_d[i] / std::sqrt(var_d[i] + eps);
}
for (int i = 0; i < h; i++) {
bias_d[i] += (-mean_d[i]) * var_d[i];
}
for (int i = 0; i < h; i++) {
for (int j = 0; j < w; j++) {
conv_weight_d[i * w + j] *= var_d[i];
}
}
}
private:
std::string conv_type_{"conv2d"};
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void ConvElementwiseAddReLUFuser::BuildPattern() {
// create input nodes.
auto* input =
VarNode("input")->assert_is_op_input(conv_type_, "Input")->AsInput();
auto* filter =
VarNode("filter")->assert_is_op_input(conv_type_, "Filter")->AsInput();
auto* bias =
VarNode("bias")->assert_is_op_input("elementwise_add", "Y")->AsInput();
// create op nodes
auto* conv2d =
OpNode("conv2d", conv_type_)->assert_is_op(conv_type_)->AsIntermediate();
auto* add = OpNode("add", "elementwise_add")
->assert_is_op("elementwise_add")
->AsIntermediate();
auto* relu = OpNode("relu", "relu")->assert_is_op("relu")->AsIntermediate();
// create intermediate nodes
auto* conv2d_out = VarNode("conv2d_out")
->assert_is_op_output(conv_type_, "Output")
->assert_is_op_input("elementwise_add", "X")
->AsIntermediate();
auto* add_out = VarNode("add_out")
->assert_is_op_output("elementwise_add", "Out")
->assert_is_op_input("relu", "X")
->AsIntermediate();
// create output node
auto* out = VarNode("output")->assert_is_op_output("relu", "Out")->AsOutput();
// create topology.
std::vector<PMNode*> conv2d_inputs{filter, input};
std::vector<PMNode*> add_inputs{conv2d_out, bias};
conv2d_inputs >> *conv2d >> *conv2d_out;
add_inputs >> *add >> *add_out;
*add_out >> *relu >> *out;
}
void ConvElementwiseAddReLUFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op;
auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places();
conv_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(conv_op, valid_places);
IR_NODE_LINK_TO(matched.at("input"), new_op_node);
IR_NODE_LINK_TO(matched.at("filter"), new_op_node);
IR_NODE_LINK_TO(matched.at("bias"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("output"));
}
cpp::OpDesc ConvElementwiseAddReLUFuser::GenOpDesc(const key2nodes_t& matched) {
auto* desc = matched.at("conv2d")->stmt()->op_info();
cpp::OpDesc op_desc;
op_desc.SetType(conv_type_);
op_desc.SetInput("Input", {matched.at("input")->arg()->name});
op_desc.SetInput("Filter", {matched.at("filter")->arg()->name});
op_desc.SetInput("Bias", {matched.at("bias")->arg()->name});
op_desc.SetOutput("Output", {matched.at("output")->arg()->name});
// Other inputs. See operators/conv_op.h
std::vector<std::string> input_arg_names = desc->InputArgumentNames();
for (auto name : input_arg_names) LOG(INFO) << name;
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
op_desc.SetInput("ResidualData", desc->Input("ResidualData"));
}
// Only consider strides, padding, groups, dilations, fuse_relu for now
op_desc.SetAttr("strides", desc->GetAttr<std::vector<int>>("strides"));
op_desc.SetAttr("paddings", desc->GetAttr<std::vector<int>>("paddings"));
op_desc.SetAttr("groups", desc->GetAttr<int>("groups"));
op_desc.SetAttr("dilations", desc->GetAttr<std::vector<int>>("dilations"));
op_desc.SetAttr("fuse_relu", true);
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class ConvElementwiseAddReLUFuser : public FuseBase {
public:
explicit ConvElementwiseAddReLUFuser(const std::string& conv_type)
: conv_type_(conv_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
std::string conv_type_;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void FcFuser::BuildPattern() {
// create nodes.
auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b");
auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out");
auto* add = OpNode("add", "elementwise_add");
auto* Out = VarNode("Out");
// create topology.
std::vector<PMNode*> mul_inputs{W, x};
std::vector<PMNode*> add_inputs{mul_out, b};
mul_inputs >> *mul >> *mul_out;
add_inputs >> *add >> *Out;
// Some op specialities.
mul_out->AsIntermediate();
mul->AsIntermediate();
add->AsIntermediate();
}
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op;
auto* scope = mul->scope();
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(fc_op, valid_places);
IR_NODE_LINK_TO(matched.at("W"), new_op_node);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("b"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Out"));
}
cpp::OpDesc FcFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
op_desc.SetType("fc");
op_desc.SetInput("Input", {matched.at("x")->arg()->name});
op_desc.SetInput("W", {matched.at("W")->arg()->name});
op_desc.SetInput("Bias", {matched.at("b")->arg()->name});
op_desc.SetOutput("Out", {matched.at("Out")->arg()->name});
op_desc.SetAttr(
"in_num_col_dims",
matched.at("mul")->stmt()->op_info()->GetAttr<int>("x_num_col_dims"));
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class FcFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
......@@ -71,12 +71,20 @@ class Node {
struct Arg {
std::string name;
int id{0};
const Type* type{};
// Weight is a special kind of argument, it is marked as weight explicitly
// so that some weight related optimization can take place.
bool is_weight{false};
};
Arg& AsArg(const std::string& name, int id) {
auto& x = AsArg();
x.name = name;
x.id = id;
return x;
}
Arg& AsArg(const std::string& name) {
auto& x = AsArg();
x.name = name;
......
......@@ -22,6 +22,8 @@ namespace mir {} // namespace mir
} // namespace paddle
USE_MIR_PASS(demo);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass);
USE_MIR_PASS(static_kernel_pick_pass);
USE_MIR_PASS(variable_place_inference_pass);
USE_MIR_PASS(type_target_transform_pass);
......@@ -29,3 +31,5 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
......@@ -45,10 +45,11 @@ PMNode &PMNode::operator>>(std::vector<PMNode *> &nodes) {
return *this;
}
void operator>>(std::vector<PMNode *> &others, PMNode &me) {
PMNode &operator>>(std::vector<PMNode *> &others, PMNode &me) {
for (auto *o : others) {
*o >> me;
}
return me;
}
PMNode *PMPattern::NewNode(const std::string &name) {
......@@ -406,6 +407,67 @@ PMNode *PMNode::assert_is_op_output(const std::string &op_type) {
return this;
}
bool IsNthOutput(const Node *var, const Node *op, const std::string &argument,
size_t nth) {
PADDLE_ENFORCE(var->IsArg());
PADDLE_ENFORCE(op->IsStmt());
auto op_info = op->stmt()->op_info();
if (op_info->Output(argument).size() <= nth) return false;
return var->arg()->name == op_info->Output(argument)[nth];
}
bool IsNthInput(const Node *var, const Node *op, const std::string &argument,
size_t nth) {
PADDLE_ENFORCE(var->IsArg());
PADDLE_ENFORCE(op->IsStmt());
auto op_info = op->stmt()->op_info();
if (op_info->Input(argument).size() <= nth) return false;
return var->arg()->name == op_info->Input(argument)[nth];
}
PMNode *PMNode::assert_is_op_input(const std::string &op_type,
const std::string &argument) {
assert_is_var();
assert_is_op_nth_input(op_type, argument, 0);
return this;
}
PMNode *PMNode::assert_is_op_nth_input(const std::string &op_type,
const std::string &argument, int nth) {
assert_is_var();
assert_is_op_input(op_type);
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->outlinks) {
if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
IsNthInput(x, op, argument, nth))
return true;
}
return false;
});
return this;
}
PMNode *PMNode::assert_is_op_output(const std::string &op_type,
const std::string &argument) {
assert_is_var();
assert_is_op_nth_output(op_type, argument, 0);
return this;
}
PMNode *PMNode::assert_is_op_nth_output(const std::string &op_type,
const std::string &argument, int nth) {
assert_is_var();
asserts_.emplace_back([=](const Node *x) {
for (auto *op : x->inlinks) {
if (op && op->IsStmt() && op->stmt()->op_info()->Type() == op_type &&
IsNthOutput(x, op, argument, nth))
return true;
}
return false;
});
return this;
}
PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
assert_is_var();
asserts_.emplace_back([=](const Node *x) {
......@@ -422,6 +484,14 @@ PMNode *PMNode::assert_is_op_input(const std::string &op_type) {
return this;
}
bool HasInput(const Node &op, const std::string &argument) {
CHECK(op.IsStmt());
auto const &names = op.stmt()->op_info()->input_argnames();
if (std::find(names.begin(), names.end(), argument) == names.end())
return false;
return true;
}
void GraphSafeRemoveNodes(SSAGraph *graph,
const std::unordered_set<const Node *> &nodes) {
for (auto *node : nodes) {
......
......@@ -62,7 +62,7 @@ struct PMNode {
PMNode& operator>>(PMNode& right);
// Link many nodes to this node.
friend void operator>>(std::vector<PMNode*>& others, PMNode& me);
friend PMNode& operator>>(std::vector<PMNode*>& others, PMNode& me);
// Link this to many other nodes.
PMNode& operator>>(std::vector<PMNode*>& nodes);
......@@ -127,6 +127,15 @@ struct PMNode {
PMNode* assert_is_persistable_var();
PMNode* assert_is_op_output(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type);
PMNode* assert_is_op_input(const std::string& op_type,
const std::string& argument);
PMNode* assert_is_op_output(const std::string& op_type,
const std::string& argument);
PMNode* assert_is_op_nth_input(const std::string& op_type,
const std::string& argument, int nth);
PMNode* assert_is_op_nth_output(const std::string& op_type,
const std::string& argument, int nth);
template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
......@@ -297,6 +306,13 @@ class PatternMatcher {
std::unordered_map<const PMNode*, std::unordered_set<Node*>> pmnodes2nodes_;
};
// Check whether a var node is a op node's nth input.
bool IsNthInput(const Node& var, const Node& op, const std::string& argument,
int nth);
// Check whether the op node has input of given name.
bool HasInput(const Node& op, const std::string& argument);
// Graph safely remove some nodes, will automatically clean up the edges.
void GraphSafeRemoveNodes(SSAGraph* graph,
const std::unordered_set<const Node*>& nodes);
......
......@@ -64,7 +64,6 @@ class FuseBase {
// Delete nodes that are marked as Intermediate
void DeleteInterNodes(SSAGraph* graph);
private:
PMNode* GetOrCreateNode(const std::string& key);
protected:
......
......@@ -29,8 +29,8 @@ class FcFuser : public FuseBase {
public:
void BuildPattern() override {
// create nodes.
auto* x = VarNode("x");
auto* W = VarNode("W");
auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b");
auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out");
......@@ -38,12 +38,10 @@ class FcFuser : public FuseBase {
auto* Out = VarNode("Out");
// create topology.
// std::vector<PMNode*>({W, x}) >> *mul >> *mul_out;
// std::vector<PMNode*>({mul_out, b}) >> *add >> *Out;
*W >> *mul;
*x >> *mul >> *mul_out;
*b >> *add;
*mul_out >> *add >> *Out;
std::vector<PMNode*> mul_inputs{W, x};
std::vector<PMNode*> add_inputs{mul_out, b};
mul_inputs >> *mul >> *mul_out;
add_inputs >> *add >> *Out;
// Some op specialities.
mul_out->AsIntermediate();
......@@ -91,14 +89,12 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block->Var("mul_out");
main_block->Var("w");
main_block->Var("out");
main_block->Var("out1");
scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("b")->GetMutable<lite::Tensor>();
scope->Var("mul_out")->GetMutable<lite::Tensor>();
scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
scope->Var("out1")->GetMutable<lite::Tensor>();
mul->SetInput("X", {"x"});
mul->SetInput("Y", {"w"});
......@@ -122,18 +118,17 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
return graph;
}
TEST(pattern_matcher2, graph_test) {
TEST(pattern_matcher_high_api, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
ASSERT_EQ(graph->nodes().size(),
8UL /*real nodes*/ + 2UL /*feed op + fetch op*/);
ASSERT_EQ(graph->nodes().size(), 7UL /*real nodes*/);
Visualize(graph.get());
}
TEST(pattern_matcher2, test) {
TEST(pattern_matcher_high_api, fuse_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
......@@ -143,6 +138,7 @@ TEST(pattern_matcher2, test) {
fuser(graph.get());
ASSERT_EQ(graph->nodes().size(),
num_nodes - 3UL /*nodes removed */ + 1UL /* fused fc node*/);
Visualize(graph.get());
}
} // namespace mir
......
......@@ -16,6 +16,7 @@
#include <algorithm>
#include <memory>
#include <set>
#include <unordered_map>
#include <utility>
namespace paddle {
......@@ -25,6 +26,8 @@ namespace mir {
bool SSAGraph::CheckBidirectionalConnection() {
LOG(INFO) << "node count " << node_storage_.size();
for (auto &node : node_storage_) {
if (node.IsStmt()) LOG(INFO) << node.AsStmt().op_info()->Type();
if (node.IsArg()) LOG(INFO) << node.AsArg().name << " " << node.AsArg().id;
for (auto *in : node.inlinks) {
CHECK(in->outlinks.end() !=
std::find(in->outlinks.begin(), in->outlinks.end(), &node));
......@@ -93,31 +96,6 @@ std::vector<mir::Node *> SSAGraph::StmtTopologicalOrder() {
return res;
}
void SSAGraph::GraphCreateTmpVarNodes(const Program &program) {
for (const auto &name : program.tmp_vars()) {
CHECK(!arguments_.count(name)) << "duplicate creating temp variable: "
<< name;
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArg(name);
arguments_[name] = &new_node;
}
}
void SSAGraph::GraphCreateWeightVarNodes(const Program &program) {
// create weight nodes.
for (const auto &name : program.weights()) {
CHECK(!arguments_.count(name)) << "duplicate creating weight variable: "
<< name;
VLOG(5) << "create arg node " << name;
node_storage_.emplace_back();
auto &new_node = node_storage_.back();
new_node.AsArg(name);
arguments_[name] = &new_node;
}
}
Node *SSAGraph::GraphCreateInstructNode(
const std::shared_ptr<OpLite> &op, const std::vector<Place> &valid_places) {
node_storage_.emplace_back();
......@@ -135,29 +113,50 @@ Node *SSAGraph::GraphCreateInstructNode(
void SSAGraph::Build(const Program &program,
const std::vector<Place> &valid_places) {
CHECK(node_storage_.empty());
GraphCreateTmpVarNodes(program);
GraphCreateWeightVarNodes(program);
CHECK(CheckNodesRoleSet());
auto weights_name = program.weights();
auto is_weights = [&](const std::string &name) -> bool {
auto it = std::find(weights_name.begin(), weights_name.end(), name);
if (it == weights_name.end()) return false;
return true;
};
std::unordered_map<std::string, mir::Node *> arg_update_node_map_;
for (auto &op : program.ops()) {
LOG(INFO) << op->op_info()->Type();
auto *op_node = GraphCreateInstructNode(op, valid_places);
LOG(INFO) << "input:";
for (const std::string &name : op->op_info()->input_names()) {
auto *arg = Argument(name);
CHECK(arg->IsRoleSet());
DirectedLink(arg, op_node);
LOG(INFO) << name;
mir::Node *arg_node = nullptr;
if (arg_update_node_map_.count(name)) {
arg_node = arg_update_node_map_.at(name);
} else {
node_storage_.emplace_back();
arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
}
if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet());
DirectedLink(arg_node, op_node);
}
LOG(INFO) << "output:";
for (const std::string &name : op->op_info()->output_names()) {
if (!arguments_.count(name)) {
NewArgumentNode(name);
}
auto *arg = arguments_.at(name);
CHECK(arg->IsRoleSet());
DirectedLink(op_node, arg);
LOG(INFO) << name;
node_storage_.emplace_back();
auto *arg_node = &node_storage_.back();
arg_node->AsArg(name, node_storage_.size() - 1);
arg_update_node_map_[name] = arg_node;
if (is_weights(name)) arg_node->AsArg().is_weight = true;
CHECK(arg_node->IsRoleSet());
DirectedLink(op_node, arg_node);
}
CHECK(CheckLinksRoleSet());
}
MarkArgumentWeights(program);
CHECK(CheckNodesRoleSet());
CheckValid();
}
......@@ -227,10 +226,9 @@ bool SSAGraph::CheckLinksRoleSet() {
Node *SSAGraph::NewArgumentNode(const std::string &name) {
node_storage_.emplace_back();
CHECK(!arguments_.count(name)) << "duplicate argument called " << name;
arguments_[name] = &node_storage_.back();
node_storage_.back().AsArg(name);
return &node_storage_.back();
auto &arg_node = node_storage_.back();
arg_node.AsArg(name, node_storage_.size() - 1);
return &arg_node;
}
Node *SSAGraph::NewInstructNode() {
......
......@@ -40,8 +40,6 @@ class SSAGraph : GraphBase {
void Build(const Program &program, const std::vector<Place> &valid_places);
void RemoveNode(const mir::Node *node);
mir::Node *Argument(const std::string &name);
std::vector<mir::Node *> StmtTopologicalOrder();
// The inputs of the graph.
......@@ -68,9 +66,7 @@ class SSAGraph : GraphBase {
const std::vector<Place> &valid_places);
private:
void GraphCreateTmpVarNodes(const Program &program);
void GraphCreateWeightVarNodes(const Program &program);
mir::Node *Argument(const std::string &name);
// Check the bidirectional connection.
bool CheckBidirectionalConnection();
bool CheckNodesRoleSet();
......
......@@ -65,20 +65,22 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
<< " for kernel " << inst.op->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in->AsArg().name, graph,
inst_node, valid_places_);
AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
valid_places_);
}
}
void TypeTargetTransformPass::AddIoCopyInst(
const Type& from, const Type& to, const std::string& var, SSAGraph* graph,
const Type& from, const Type& to, Node* in, SSAGraph* graph,
Node* inst_node, const std::vector<Place>& valid_places) {
CHECK(!valid_places.empty()) << "valid_place should be set";
// var -> new_transform_op -> new_var -> inst
// So there will be a new Argument node and a new IoCopy Statement Node.
CHECK(in->IsArg());
auto node_id = [&] { return graph->nodes().size(); };
auto io_copy_output_name = var + "/trans/" + std::to_string(node_id());
auto io_copy_output_name =
in->AsArg().name + "/trans/" + std::to_string(node_id());
auto* io_copy_output_arg = graph->NewArgumentNode(io_copy_output_name);
auto* io_copy_inst = graph->NewInstructNode();
......@@ -92,7 +94,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
// Create IoCopy Instruction.
cpp::OpDesc op_desc;
op_desc.SetType("io_copy");
op_desc.SetInput("Input", {var});
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
......@@ -100,18 +102,18 @@ void TypeTargetTransformPass::AddIoCopyInst(
io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
// Remove the old link
RemoveDirectedLink(graph->Argument(var), inst_node);
RemoveDirectedLink(in, inst_node);
// Update the original instruction OpDesc.
// Update its input to the io_copy_output_name
// Add new link, var -> new_inst, new_inst->newarg, newarg->inst
DirectedLink(graph->Argument(var), io_copy_inst);
DirectedLink(in, io_copy_inst);
DirectedLink(io_copy_inst, io_copy_output_arg);
DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), var,
UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name,
io_copy_output_name);
inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(),
......
......@@ -45,7 +45,7 @@ class TypeTargetTransformPass : public ProgramPass {
void ComplementInputs(SSAGraph* graph, Node* inst_node, Node* in);
void AddIoCopyInst(const Type& from, const Type& to, const std::string& var,
void AddIoCopyInst(const Type& from, const Type& to, Node* in,
SSAGraph* graph, Node* inst_node,
const std::vector<Place>& valid_places);
......
......@@ -13,7 +13,10 @@
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/target_wrapper.h"
......@@ -60,40 +63,44 @@ class VariablePlaceInferencePass : public DebugPass {
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
VLOG(4) << "inferencing op " << inst.op_type;
for (auto& arg_name : inst.op_info()->input_argnames()) {
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto get_argname = [&](
const std::string& node_name,
const std::map<std::string, std::vector<std::string>>& argname_map)
-> std::string {
for (auto& ele : argname_map) {
auto it =
std::find(ele.second.begin(), ele.second.end(), node_name);
if (it != ele.second.end()) return ele.first;
}
return "";
};
for (auto* x_in : x->inlinks) {
std::string node_name = x_in->AsArg().name;
std::string arg_name = get_argname(node_name, inst.op_info()->inputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
VLOG(3) << "-- input arg_name " << arg_name;
// check if inputs's place is set, if not set, update them with the
// kernel's declaration.
auto type = inst.picked_kernel().GetInputDeclType(arg_name);
auto arg_names = inst.op_info()->inputs().at(arg_name);
for (auto& arg_name : arg_names) {
VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArg();
if (!arg_node.type) {
VLOG(4) << "set type " << *type << " " << node;
arg_node.type = type;
}
if (!x_in->AsArg().type) {
VLOG(4) << "set type " << *type << " " << x_in;
x_in->AsArg().type = type;
}
}
for (auto& arg_name : inst.op_info()->output_argnames()) {
for (auto* x_out : x->outlinks) {
std::string node_name = x_out->AsArg().name;
std::string arg_name =
get_argname(node_name, inst.op_info()->outputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
VLOG(3) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
auto arg_names = inst.op_info()->outputs().at(arg_name);
// check if outputs's place is set, if not set, update them with the
// kernel's declaration.
for (auto& arg_name : arg_names) {
VLOG(3) << "--- var " << arg_name;
auto* node = graph->RetrieveArgument(arg_name);
CHECK(node) << "argument " << arg_name << " not exists in the graph";
auto& arg_node = node->AsArg();
if (!arg_node.type) {
node->AsArg().type = type;
VLOG(3) << "set type " << *type;
}
if (!x_out->AsArg().type) {
VLOG(4) << "set type " << *type << " " << x_out;
x_out->AsArg().type = type;
}
}
}
......
......@@ -49,16 +49,19 @@ class Optimizer {
#ifndef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
if (passes.empty()) {
RunPasses(std::vector<std::string>{{
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_target_transform_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
"runtime_context_assign_pass", //
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_act_fuse_pass", //
"lite_fc_fuse_pass", //
"static_kernel_pick_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"type_target_transform_pass", //
"argument_type_display_pass", //
"variable_place_inference_pass", //
"argument_type_display_pass", //
"io_copy_kernel_pick_pass", //
"variable_place_inference_pass", //
"runtime_context_assign_pass", //
}});
} else {
RunPasses(passes);
......
......@@ -152,8 +152,8 @@ class BasicProfiler {
}
record_t *mutable_record(int id) {
CHECK_LT(id, records_.size());
CHECK_GE(id, 0);
CHECK_LT(static_cast<size_t>(id), records_.size());
return &records_[id];
}
......
......@@ -74,6 +74,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
lite::Tensor col_matrix;
if (is_expand) {
col.Resize(col_shape);
col.mutable_data<T>();
col_matrix.ShareDataWith(col);
col_matrix.Resize(col_matrix_shape);
}
......@@ -104,7 +105,7 @@ class Conv2dCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
param.x->raw_tensor().Slice(i, i + 1).Resize(input_shape.data()));
lite::Tensor out_batch;
out_batch.ShareDataWith(param.output->raw_tensor().Slice(i, i + 1).Resize(
input_shape.data()));
output_matrix_shape.data()));
for (int g = 0; g < param.groups; g++) {
lite::Tensor in_slice;
......@@ -155,7 +156,6 @@ REGISTER_LITE_KERNEL(conv2d, kX86, kFloat, kNCHW,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -164,6 +164,5 @@ REGISTER_LITE_KERNEL(depthwise_conv2d, kX86, kFloat, kNCHW,
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -27,8 +27,8 @@ namespace kernels {
namespace x86 {
template <typename T>
void fc_compute_eigen(const T* x, int x_w, int x_h, //
const T* w, int w_w, int w_h, //
void fc_compute_eigen(const T* x, int x_h, int x_w, //
const T* w, int w_h, int w_w, //
const T* b, //
T* out) {
using matrix_t =
......@@ -36,38 +36,31 @@ void fc_compute_eigen(const T* x, int x_w, int x_h, //
Eigen::Map<const matrix_t> X(x, x_h, x_w);
Eigen::Map<const matrix_t> W(w, w_h, w_w);
Eigen::Map<matrix_t> Out(out, x_h, w_h);
Eigen::Map<matrix_t> Out(out, x_h, w_w);
Out = X * W.transpose();
Out = X * W;
if (b) {
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_h);
Eigen::Map<const Eigen::Matrix<T, Eigen::Dynamic, 1>> B(b, w_w);
Out = Out.array().rowwise() + B.transpose().array();
}
}
template <typename T>
__attribute__((optimize("unroll-loops"))) //
T dot(const T* x, const T* y, int dim) {
T out{};
for (int i = 0; i < dim; i++) {
out += x[i] * y[i];
}
return out;
}
template <typename T>
void fc_compute_naive(const T* x, int x_w, int x_h, //
const T* w, int w_w, int w_h, //
void fc_compute_naive(const T* x, int x_h, int x_w, //
const T* w, int w_h, int w_w, //
const T* b, //
T* out) {
CHECK_EQ(x_w, w_w);
CHECK_EQ(x_w, w_h);
// out shape: (x_h, w_w)
memset(out, 0, x_h * w_h * sizeof(T));
for (int r = 0; r < x_h; r++) {
for (int c = 0; c < w_h; c++) {
out[r * w_h + c] = dot(&x[r * x_w], &w[c * w_w], w_w) + b[c];
memset(out, 0, x_h * w_w * sizeof(T));
for (int i = 0; i < x_h; i++) {
for (int j = 0; j < w_w; j++) {
T tmp = static_cast<T>(0);
for (int k = 0; k < x_w; k++) {
tmp += x[i * x_w + k] * w[k * w_w + j];
}
out[i * w_w + j] = tmp + b[j];
}
}
}
......@@ -89,8 +82,8 @@ class FcCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
.Slice(param.in_num_col_dims, param.input->dims().size())
.production(),
param.w->data<T>(), // w
param.w->dims()[1], // w_w
param.w->dims()[0], // w_h
param.w->dims()[1], // w_w
param.bias->data<T>(), // b
param.output->mutable_data<T>());
}
......
......@@ -51,6 +51,6 @@ class ReluCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
REGISTER_LITE_KERNEL(relu, kX86, kFloat, kNCHW,
paddle::lite::kernels::x86::ReluCompute<float>, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
......@@ -56,4 +56,3 @@ lite_cc_test(test_softmax_op_lite SRCS softmax_op_test.cc DEPS softmax_op_lite m
lite_cc_test(test_reshape_op_lite SRCS reshape_op_test.cc DEPS reshape_op_lite memory_lite)
lite_cc_test(test_batch_norm_op_lite SRCS batch_norm_op_test.cc DEPS batch_norm_op_lite memory_lite)
lite_cc_test(test_concat_op_lite SRCS concat_op_test.cc DEPS concat_op_lite memory_lite)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/operators/batch_norm.h"
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool BatchNormOpLite::CheckShape() const { return true; }
bool BatchNormOpLite::InferShape() const { return true; }
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(batch_norm, paddle::lite::operators::BatchNormOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "paddle/fluid/lite/core/kernel.h"
#include "paddle/fluid/lite/core/op_lite.h"
#include "paddle/fluid/lite/core/scope.h"
#include "paddle/fluid/lite/operators/op_params.h"
#include "paddle/fluid/lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class BatchNormOpLite : public OpLite {
public:
BatchNormOpLite() {}
explicit BatchNormOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto bias = op_desc.Input("Bias").front();
auto mean = op_desc.Input("Mean").front();
auto scale = op_desc.Input("Scale").front();
auto variance = op_desc.Input("Variance").front();
auto out = op_desc.Output("Y").front();
auto mean_out = op_desc.Output("MeanOut").front();
auto var_out = op_desc.Output("VarianceOut").front();
auto saved_mean = op_desc.Output("SavedMean").front();
auto saved_var = op_desc.Output("SavedVariance").front();
auto *var = scope->FindVar(x);
param_.x = var->GetMutable<Tensor>();
var = scope->FindVar(bias);
param_.bias = var->GetMutable<Tensor>();
var = scope->FindVar(mean);
param_.mean = var->GetMutable<Tensor>();
var = scope->FindVar(scale);
param_.scale = var->GetMutable<Tensor>();
var = scope->FindVar(variance);
param_.var = var->GetMutable<Tensor>();
var = scope->FindVar(out);
param_.out = var->GetMutable<Tensor>();
var = scope->FindVar(mean_out);
param_.mean_out = var->GetMutable<Tensor>();
var = scope->FindVar(var_out);
param_.var_out = var->GetMutable<Tensor>();
var = scope->FindVar(saved_mean);
param_.saved_mean = var->GetMutable<Tensor>();
var = scope->FindVar(saved_var);
param_.saved_var = var->GetMutable<Tensor>();
param_.eps = op_desc.GetAttr<float>("epsilon");
return true;
}
std::string DebugString() const override { return "batch_norm"; }
private:
mutable BatchNormParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -30,48 +30,53 @@ class ConvOpLite : public OpLite {
public:
ConvOpLite() {}
explicit ConvOpLite(const std::string &type) : OpLite(type) {}
explicit ConvOpLite(const std::string& type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto input = op_desc.Input("Input").front();
auto filter = op_desc.Input("Filter").front();
auto out = op_desc.Output("Out").front();
param_.x = scope->FindVar(input)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(filter)->GetMutable<lite::Tensor>();
CHECK(scope->FindVar(out));
param_.output = scope->FindVar(out)->GetMutable<lite::Tensor>();
bool AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) override {
auto X = op_desc.Input("Input").front();
auto Filter = op_desc.Input("Filter").front();
auto Out = op_desc.Output("Output").front();
param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.strides = op_desc.GetAttr<std::vector<int>>("strides");
param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
param_.groups = op_desc.GetAttr<int>("groups");
param_.dilations = op_desc.GetAttr<std::vector<int>>("dilations");
// optional params
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_var = scope->FindVar(op_desc.Input("Bias").front());
if (bias_var != nullptr) {
param_.bias =
const_cast<lite::Tensor *>(&(bias_var->Get<lite::Tensor>()));
auto bias_arguments = op_desc.Input("Bias");
if (bias_arguments.size() != 0) {
auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) {
param_.bias = bias_var->GetMutable<lite::Tensor>();
}
}
}
if (std::find(input_arg_names.begin(), input_arg_names.end(),
"ResidualData") != input_arg_names.end()) {
auto residual_data_var =
scope->FindVar(op_desc.Input("ResidualData").front());
if (residual_data_var != nullptr) {
param_.residualData = const_cast<lite::Tensor *>(
&(residual_data_var->Get<lite::Tensor>()));
auto res_argument = op_desc.Input("ResidualData");
if (res_argument.size() != 0) {
auto residual_data_var = scope->FindVar(res_argument.front());
if (residual_data_var != nullptr) {
param_.residualData = residual_data_var->GetMutable<lite::Tensor>();
}
}
}
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "conv2d"; }
......
......@@ -32,12 +32,11 @@ bool ReluOp::InferShape() const {
bool ReluOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.input = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("Input").front())->Get<lite::Tensor>());
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.output =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.input);
CHECK(param_.output);
kernel_->SetParam(param_);
return true;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册