提交 ec38041b 编写于 作者: C Chunwei

Merge branch 'chunwei/clear-scale' into 'incubate/lite'

add scale eliminate pass

See merge request inference/paddlelite!18
...@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) { ...@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) {
ASSERT_TRUE(TensorCompareWith(*cxx_out, *light_out)); ASSERT_TRUE(TensorCompareWith(*cxx_out, *light_out));
std::vector<std::string> tensors_with_order({ std::vector<std::string> tensors_with_order({
"a", "fc_0.w_0", "fc_0.tmp_0", "scale_0.tmp_0", "a", "fc_0.w_0", "scale_0.tmp_0",
}); });
for (const auto& tensor_name : tensors_with_order) { for (const auto& tensor_name : tensors_with_order) {
......
...@@ -13,7 +13,7 @@ ...@@ -13,7 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h" #include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono> #include <chrono> // NOLINT
#include "paddle/fluid/lite/core/mir/use_passes.h" #include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
......
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/api/lite_api_test_helper.h" #include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include <vector>
DEFINE_string(model_dir, "", ""); DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", ""); DEFINE_string(optimized_model, "", "");
......
...@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li ...@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite) lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite) lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite)
lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator) lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator)
...@@ -5,21 +5,25 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir ...@@ -5,21 +5,25 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager) cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion) add_subdirectory(fusion)
add_subdirectory(elimination)
cc_library(mir_passes cc_library(mir_passes
SRCS fc_fuse_pass.cc SRCS
conv_elementwise_add_activation_fuse_pass.cc fusion/fc_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc fusion/conv_elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc fusion/conv_bn_fuse_pass.cc
static_kernel_pick_pass.cc fusion/elementwise_add_activation_fuse_pass.cc
variable_place_inference_pass.cc elimination/identity_scale_eliminate_pass.cc
type_target_transform_pass.cc static_kernel_pick_pass.cc
io_copy_kernel_pick_pass.cc variable_place_inference_pass.cc
graph_visualize_pass.cc type_target_transform_pass.cc
generate_program_pass.cc io_copy_kernel_pick_pass.cc
argument_type_display_pass.cc graph_visualize_pass.cc
demo_pass.cc generate_program_pass.cc
runtime_context_assign_pass.cc argument_type_display_pass.cc
DEPS mir_pass types_lite context_lite ${mir_fusers}) demo_pass.cc
runtime_context_assign_pass.cc
DEPS mir_pass types_lite context_lite ${mir_fusers})
#cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS #cc_test(test_ssa_graph SRCS ssa_graph_test.cc DEPS
#mir_ssa_graph scope_lite op_lite #mir_ssa_graph scope_lite op_lite
...@@ -73,7 +77,7 @@ message(STATUS "----> Ops lite: ${ops_lite}") ...@@ -73,7 +77,7 @@ message(STATUS "----> Ops lite: ${ops_lite}")
message(STATUS "----> Host kernels: ${host_kernels}") message(STATUS "----> Host kernels: ${host_kernels}")
message(STATUS "----> X86 kernels: ${x86_kernels}") message(STATUS "----> X86 kernels: ${x86_kernels}")
lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc lite_cc_test(test_lite_fc_fuse SRCS fusion/fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels} ${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model
...@@ -83,11 +87,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz ...@@ -83,11 +87,11 @@ lite_download_and_uncompress(${LITE_MODEL_DIR} ${LITE_URL} "lite_fc_model.tar.gz
add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz) add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test(test_lite_conv_elementwise_add_activation_fuse lite_cc_test(test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc SRCS fusion/conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels}) ${ops_lite} ${host_kernels} ${x86_kernels})
lite_cc_test(test_lite_elementwise_add_activation_fuse lite_cc_test(test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc SRCS fusion/elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels}) ${ops_lite} ${host_kernels} ${x86_kernels})
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_cc_test(test_identity_scale_eliminate_pass_lite
SRCS identity_scale_eliminate_pass_test.cc
DEPS mir_passes program_lite proto_desc cpp_op_desc_lite
${ops_lite}
)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace {
class Eliminator : public FuseBase {
public:
void BuildPattern() override {
auto* pre_op = OpNode("preop"); // the previous op's output need update
// TODO(Superjomn) check has only one output
auto* x = VarNode("x")->assert_is_op_input("scale", "X");
auto* scale_op = OpNode("scale", "scale")
->assert_op_attr<float>("scale", 1.)
->assert_op_attr<float>("bias", 0.);
auto* out = VarNode("out")->assert_is_op_output("scale", "Out");
*pre_op >> *x >> *scale_op >> *out;
// The pre_op will be eliminated, and a new output-updated op will insert.
x->AsIntermediate(); // x is pre_op's output, need to update
}
private:
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto& pre_op = matched.at("preop")->AsStmt();
auto op_info = *pre_op.op_info();
op_info.UpdateAllOutputs(matched.at("x")->AsArg().name,
matched.at("out")->AsArg().name);
pre_op.ResetOp(op_info, graph->valid_places());
GraphSafeRemoveNodes(graph, {matched.at("scale")});
IR_NODE_LINK_TO(matched.at("preop"), matched.at("out"));
}
};
} // namespace
class IdentityScaleEliminatePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
Eliminator eliminator;
eliminator(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(identity_scale_eliminate_pass,
paddle::lite::mir::IdentityScaleEliminatePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
namespace mir {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
// Op list:
// (x)->feed -> (feed) -> scale -> (scale_out) -> fetch->(fetch)
// After pass
// (x)->feed->(scale_out)->fetch->(fetch)
auto* main_block = program_desc->MutableBlock(0);
auto* feed_op = main_block->AppendOp();
auto* scale_op = main_block->AppendOp();
auto* fetch_op = main_block->AppendOp();
main_block->Var("x");
main_block->Var("feed");
main_block->Var("scale_out");
main_block->Var("fetch_out");
scope->Var("x")->GetMutable<lite::Tensor>();
scope->Var("feed")->GetMutable<lite::Tensor>();
scope->Var("scale_out")->GetMutable<lite::Tensor>();
scope->Var("fetch_out")->GetMutable<lite::Tensor>();
feed_op->SetType("feed");
feed_op->SetInput("X", {"x"});
feed_op->SetAttr("col", 1);
feed_op->SetOutput("Out", {"feed"});
scale_op->SetType("scale");
scale_op->SetInput("X", {"feed"});
scale_op->SetOutput("Out", {"scale_out"});
scale_op->SetAttr("scale", 1.f);
scale_op->SetAttr("bias", 0.f);
scale_op->SetAttr("bias_after_scale", true);
fetch_op->SetType("fetch");
fetch_op->SetInput("X", {"scale_out"});
fetch_op->SetOutput("Out", {"fetch"});
fetch_op->SetAttr("col", 1);
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
LOG(INFO) << Visualize(graph.get());
return graph;
}
TEST(identity_test, test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
const int num_nodes = graph->nodes().size();
auto pass = PassManager::Global().LookUp("identity_scale_eliminate_pass");
ASSERT_TRUE(pass);
pass->Apply(graph);
ASSERT_EQ(graph->nodes().size(), num_nodes - 2UL);
}
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(feed)
USE_LITE_OP(fetch)
USE_LITE_OP(scale)
USE_MIR_PASS(identity_scale_eliminate_pass)
...@@ -11,7 +11,7 @@ cc_library(fuse_elementwise_add_activation ...@@ -11,7 +11,7 @@ cc_library(fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
set(mir_fusers set(mir_fusers
fuse_fc fuse_fc
fuse_conv_elementwise_add_activation fuse_conv_elementwise_add_activation
fuse_conv_bn fuse_conv_bn
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h" #include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h" #include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h" #include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h"
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
......
...@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() { ...@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() {
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add"); auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
auto conv = matched.at("conv2d")->stmt()->op; auto conv = matched.at("conv2d")->stmt()->op();
auto* scope = conv->scope(); auto* scope = conv->scope();
auto& valid_places = conv->valid_places(); auto& valid_places = conv->valid_places();
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h" #include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h" #include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h" #include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
......
...@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode( ...@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode(
SSAGraph* graph, const key2nodes_t& matched) { SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_); auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op; auto conv_old = matched.at("conv2d")->stmt()->op();
auto* scope = conv_old->scope(); auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places(); auto& valid_places = conv_old->valid_places();
conv_op->Attach(op_desc, scope); conv_op->Attach(op_desc, scope);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvElementwiseAddReLUFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
depthwise_fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
paddle::lite::mir::ConvElementwiseAddReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddReLUFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "conv_elementwise_add_relu_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv2d_1 = main_block->AppendOp();
auto* conv2d_2 = main_block->AppendOp();
auto* add_1 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp();
main_block->Var("input_1");
main_block->Var("input_2");
main_block->Var("filter_1");
main_block->Var("filter_2");
main_block->Var("conv2d_1_out");
main_block->Var("conv2d_2_out");
main_block->Var("bias_1");
main_block->Var("add_1_out");
main_block->Var("add_2_out");
main_block->Var("relu_1_out");
main_block->Var("out");
scope->Var("input_1")->GetMutable<lite::Tensor>();
scope->Var("input_2")->GetMutable<lite::Tensor>();
scope->Var("filter_1")->GetMutable<lite::Tensor>();
scope->Var("filter_2")->GetMutable<lite::Tensor>();
scope->Var("conv2d_1_out")->GetMutable<lite::Tensor>();
scope->Var("conv2d_2_out")->GetMutable<lite::Tensor>();
scope->Var("bias_1")->GetMutable<lite::Tensor>();
scope->Var("add_1_out")->GetMutable<lite::Tensor>();
scope->Var("add_2_out")->GetMutable<lite::Tensor>();
scope->Var("relu_1_out")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
conv2d_1->SetType("conv2d");
conv2d_1->SetInput("Input", {"input_1"});
conv2d_1->SetInput("Filter", {"filter_1"});
conv2d_1->SetOutput("Output", {"conv2d_1_out"});
conv2d_1->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_1->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_1->SetAttr("groups", 1);
conv2d_1->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_1->SetAttr("fuse_relu", false);
add_1->SetType("elementwise_add");
add_1->SetInput("X", {"conv2d_1_out"});
add_1->SetInput("Y", {"bias_1"});
add_1->SetOutput("Out", {"add_1_out"});
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("X", {"add_1_out"});
relu_1->SetOutput("Out", {"relu_1_out"});
conv2d_2->SetType("conv2d");
conv2d_2->SetInput("Input", {"input_2"});
conv2d_2->SetInput("Filter", {"filter_2"});
conv2d_2->SetOutput("Output", {"conv2d_2_out"});
conv2d_2->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_2->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_2->SetAttr("groups", 1);
conv2d_2->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_2->SetAttr("fuse_relu", false);
add_2->SetType("elementwise_add");
add_2->SetInput("X", {"conv2d_2_out"});
add_2->SetInput("Y", {"relu_1_out"});
add_2->SetOutput("Out", {"add_2_out"});
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("X", {"add_2_out"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/);
Visualize(graph.get());
}
TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvElementwiseAddReLUFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ +
1UL * 2 /* fused fc node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_LITE_OP(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(relu);
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h" #include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h" #include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h" #include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
......
...@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph, ...@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto op = auto op =
LiteOpRegistry::Global().Create("fusion_elementwise_add_activation"); LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
auto old_op = matched.at("add")->stmt()->op; auto old_op = matched.at("add")->stmt()->op();
auto* scope = old_op->scope(); auto* scope = old_op->scope();
auto& valid_places = old_op->valid_places(); auto& valid_places = old_op->valid_places();
op->Attach(op_desc, scope); op->Attach(op_desc, scope);
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" #include "paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h"
#include <memory> #include <memory>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h" #include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h" #include "fc_fuse_pass.h"
#include <gflags/gflags.h> #include <gflags/gflags.h>
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <vector> #include <vector>
......
...@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() { ...@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() {
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) { void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc"); auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op; auto mul = matched.at("mul")->stmt()->op();
auto* scope = mul->scope(); auto* scope = mul->scope();
auto& valid_places = mul->valid_places(); auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope); fc_op->Attach(op_desc, scope);
......
...@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (item->IsStmt()) { if (item->IsStmt()) {
auto& stmt = item->AsStmt(); auto& stmt = item->AsStmt();
VLOG(4) << stmt; VLOG(4) << stmt;
insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front())); insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
} }
} }
} }
......
...@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) { ...@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) {
if (node.IsArg()) { if (node.IsArg()) {
key = node.AsArg().name; key = node.AsArg().name;
} else { } else {
key = node.AsStmt().op_type + std::to_string(id++); key = node.AsStmt().op_type() + std::to_string(id++);
} }
if (node.IsStmt()) { if (node.IsStmt()) {
......
...@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass { ...@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass {
for (auto& node : graph->mutable_nodes()) { for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue; if (!node.IsStmt()) continue;
auto& inst = node.AsStmt(); auto& inst = node.AsStmt();
if (inst.op_type != "io_copy") continue; if (inst.op_type() != "io_copy") continue;
LOG(INFO) << "....> picking a IO COPY kernel"; LOG(INFO) << "....> picking a IO COPY kernel";
auto& kernels = node.AsStmt().valid_kernels; auto& kernels = node.AsStmt().kernels();
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op"; CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
const auto* inty = node.inlinks.front()->AsArg().type; const auto* inty = node.inlinks.front()->AsArg().type;
const auto* outy = node.outlinks.front()->AsArg().type; const auto* outy = node.outlinks.front()->AsArg().type;
......
...@@ -13,3 +13,62 @@ ...@@ -13,3 +13,62 @@
// limitations under the License. // limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h" #include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
const OpInfo *mir::Node::Stmt::op_info() const {
CHECK(op_);
return op_->op_info();
}
Place mir::Node::Stmt::place() const {
CHECK(!valid_kernels_.empty());
return valid_kernels_.front()->place();
}
KernelBase &mir::Node::Stmt::picked_kernel() {
CHECK(!valid_kernels_.empty()) << "no kernel for " << op_type();
return *valid_kernels_.front();
}
OpInfo *mir::Node::Stmt::mutable_op_info() {
CHECK(op_);
return op_->mutable_op_info();
}
void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc,
const std::vector<Place> &valid_places,
lite::Scope *scope) {
CHECK((op_ && op_->scope()) || scope) << "Either scope should be set";
lite::Scope *the_scope = scope ? scope : op_->scope();
op_->Attach(op_desc, the_scope);
// Recreate the kernels with the latest OpInfo.
valid_kernels_.clear();
if (!op_ || op_->op_info()->Type() != op_desc.Type()) {
op_ = LiteOpRegistry::Global().Create(op_desc.Type());
CHECK(op_) << "No op found for " << op_desc.Type();
}
valid_kernels_ = op_->CreateKernels(valid_places);
}
std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) {
os << "Statement " << other.op_type() << " " << other.place();
return os;
}
mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
auto &x = AsArg();
x.name = name;
x.id = id;
return x;
}
mir::Node::Arg &mir::Node::AsArg(const std::string &name) {
auto &x = AsArg();
x.name = name;
return x;
}
} // namespace lite
} // namespace paddle
...@@ -41,32 +41,40 @@ class Node { ...@@ -41,32 +41,40 @@ class Node {
kUnk, kUnk,
}; };
struct Stmt { class Stmt {
std::string op_type;
// The kernel instances this Statement contains. // The kernel instances this Statement contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels; std::vector<std::unique_ptr<KernelBase>> valid_kernels_;
// TODO(Superjomn) make this a shared_ptr for resource safety. // TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape std::shared_ptr<OpLite> op_; // we hold op to run InferShape
const OpInfo* op_info() { public:
CHECK(op); // Refresh the operator and kernels with the latest OpInfo.
return op->op_info(); void ResetOp(const cpp::OpDesc& op_desc,
} const std::vector<Place>& valid_places,
lite::Scope* scope = nullptr);
Place place() const { std::string op_type() const { return op_info()->Type(); }
CHECK(!valid_kernels.empty()); const OpInfo* op_info() const;
return valid_kernels.front()->place(); OpInfo* mutable_op_info();
}
KernelBase& picked_kernel() { void SetKernels(std::vector<std::unique_ptr<KernelBase>>&& kernels) {
CHECK(!valid_kernels.empty()) << "no kernel for " << op_type; valid_kernels_ = std::move(kernels);
return *valid_kernels.front();
} }
std::vector<std::unique_ptr<KernelBase>>& kernels() {
friend std::ostream& operator<<(std::ostream& os, const Stmt& other) { return valid_kernels_;
os << "Statement " << other.op_type << " " << other.place();
return os;
} }
void SetOp(const std::shared_ptr<OpLite>& op) { op_ = op; }
const std::shared_ptr<OpLite> op() const { return op_; }
Place place() const;
KernelBase& picked_kernel();
friend std::ostream& operator<<(std::ostream& os, const Stmt& other);
// Description.
std::string desc;
}; };
struct Arg { struct Arg {
...@@ -78,26 +86,16 @@ class Node { ...@@ -78,26 +86,16 @@ class Node {
bool is_weight{false}; bool is_weight{false};
}; };
Arg& AsArg(const std::string& name, int id) { Arg& AsArg(const std::string& name, int id);
auto& x = AsArg();
x.name = name;
x.id = id;
return x;
}
Arg& AsArg(const std::string& name) { Arg& AsArg(const std::string& name);
auto& x = AsArg();
x.name = name;
return x;
}
Stmt& AsStmt(const std::string& op_type, Stmt& AsStmt(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels, std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<OpLite>& op) { const std::shared_ptr<OpLite>& op) {
auto& x = AsStmt(); auto& x = AsStmt();
x.op_type = op_type; x.SetOp(op);
x.op = op; x.SetKernels(std::move(kernels));
x.valid_kernels = std::move(kernels);
return x; return x;
} }
...@@ -142,7 +140,7 @@ class Node { ...@@ -142,7 +140,7 @@ class Node {
} }
if (other.IsStmt()) { if (other.IsStmt()) {
auto& arg = other.AsStmt(); auto& arg = other.AsStmt();
os << "Statement " << arg.op_type; os << "Statement " << arg.op_type();
} }
return os; return os;
} }
......
...@@ -139,14 +139,13 @@ struct PMNode { ...@@ -139,14 +139,13 @@ struct PMNode {
template <typename T> template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) { PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
asserts_.emplace_back([=](Node* x) { asserts_.push_back([=](const Node* x) {
if (x && x->IsStmt()) { if (x && x->IsStmt()) {
auto* op_info = x->stmt()->op_info(); auto* op_info = x->stmt()->op_info();
return op_info->HasAttr(attr_name) && return op_info->HasAttr(attr_name) &&
op_info->GetAttr<T>(attr_name) == attr; op_info->GetAttr<T>(attr_name) == attr;
} else {
return false;
} }
return false;
}); });
return this; return this;
} }
......
...@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) { ...@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
} }
} }
LOG(INFO) << "keys: " << key2nodes_.size();
std::unordered_set<const Node *> nodes2rm; std::unordered_set<const Node *> nodes2rm;
for (auto &matched : key2nodes_) { for (auto &matched : key2nodes_) {
for (const auto &key : keys) { for (const auto &key : keys) {
......
...@@ -49,7 +49,13 @@ class FuseBase { ...@@ -49,7 +49,13 @@ class FuseBase {
virtual void BuildPattern() = 0; virtual void BuildPattern() = 0;
// Generate an operator desc with a matched subgraph. // Generate an operator desc with a matched subgraph.
virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) = 0; virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) {
return cpp::OpDesc();
}
PMNode* OpNode(const std::string& key) {
return GetOrCreateNode(key)->assert_is_op();
}
PMNode* OpNode(const std::string& key, const std::string& op_type); PMNode* OpNode(const std::string& key, const std::string& op_type);
......
...@@ -52,7 +52,7 @@ class FcFuser : public FuseBase { ...@@ -52,7 +52,7 @@ class FcFuser : public FuseBase {
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override { void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto op_desc = GenOpDesc(matched); auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc"); auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op; auto mul = matched.at("mul")->stmt()->op();
auto* scope = mul->scope(); auto* scope = mul->scope();
auto& valid_places = mul->valid_places(); auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope); fc_op->Attach(op_desc, scope);
...@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc, ...@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block->Var("w"); main_block->Var("w");
main_block->Var("out"); main_block->Var("out");
scope->Var("w")->GetMutable<lite::Tensor>(); scope->Var("x")->GetMutable<lite::Tensor>();
scope->Var("b")->GetMutable<lite::Tensor>(); scope->Var("b")->GetMutable<lite::Tensor>();
scope->Var("mul_out")->GetMutable<lite::Tensor>(); scope->Var("mul_out")->GetMutable<lite::Tensor>();
scope->Var("w")->GetMutable<lite::Tensor>(); scope->Var("w")->GetMutable<lite::Tensor>();
......
...@@ -23,19 +23,19 @@ namespace mir { ...@@ -23,19 +23,19 @@ namespace mir {
void BuildGraph(SSAGraph* g) { void BuildGraph(SSAGraph* g) {
g->mutable_nodes().emplace_back(); g->mutable_nodes().emplace_back();
Node& o1 = g->mutable_nodes().back(); Node& o1 = g->mutable_nodes().back();
o1.AsStmt().op_type = "op1"; o1.AsStmt().desc = "op1";
g->mutable_nodes().emplace_back(); g->mutable_nodes().emplace_back();
Node& o2 = g->mutable_nodes().back(); Node& o2 = g->mutable_nodes().back();
o2.AsStmt().op_type = "op2"; o2.AsStmt().desc = "op2";
g->mutable_nodes().emplace_back(); g->mutable_nodes().emplace_back();
Node& o3 = g->mutable_nodes().back(); Node& o3 = g->mutable_nodes().back();
o3.AsStmt().op_type = "op3"; o3.AsStmt().desc = "op3";
g->mutable_nodes().emplace_back(); g->mutable_nodes().emplace_back();
Node& o4 = g->mutable_nodes().back(); Node& o4 = g->mutable_nodes().back();
o4.AsStmt().op_type = "op4"; o4.AsStmt().desc = "op4";
g->mutable_nodes().emplace_back(); g->mutable_nodes().emplace_back();
Node& o5 = g->mutable_nodes().back(); Node& o5 = g->mutable_nodes().back();
o5.AsStmt().op_type = "op5"; o5.AsStmt().desc = "op5";
g->mutable_nodes().emplace_back(); g->mutable_nodes().emplace_back();
Node& v1 = g->mutable_nodes().back(); Node& v1 = g->mutable_nodes().back();
v1.AsArg("var1"); v1.AsArg("var1");
...@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) { ...@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) {
// v2 -> o3(a node named o3) // v2 -> o3(a node named o3)
auto* o2 = x.pattern_.NewNode([](const Node* node) { auto* o2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape. // The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op2"; return node && node->IsStmt() && node->stmt()->desc == "op2";
}); });
auto* o3 = x.pattern_.NewNode([](const Node* node) { auto* o3 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape. // The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op3"; return node && node->IsStmt() && node->stmt()->desc == "op3";
}); });
auto* v2 = x.pattern_.NewNode([](const Node* node) { auto* v2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape. // The teller can be any condition, such as op type, or variable's shape.
...@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) { ...@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) {
// op -> var // op -> var
auto* any_op = x.mutable_pattern()->NewNode( auto* any_op = x.mutable_pattern()->NewNode(
[](const Node* node) { [](const Node* node) {
return node->IsStmt() && (node->stmt()->op_type == "op2" || return node->IsStmt() &&
node->stmt()->op_type == "op3"); (node->stmt()->desc == "op2" || node->stmt()->desc == "op3");
}, },
"OP0"); "OP0");
auto* any_var = auto* any_var =
...@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) { ...@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) {
int count = 0; int count = 0;
PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s, PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s,
SSAGraph* g) { SSAGraph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> " LOG(INFO) << "Detect " << s.at(any_op)->stmt()->desc << " -> "
<< s.at(any_var)->arg()->name << " -> " << s.at(any_var)->arg()->name << " -> "
<< s.at(any_op1)->stmt()->op_type; << s.at(any_op1)->stmt()->desc;
count++; count++;
}; };
...@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) { ...@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) {
PatternMatcher matcher; PatternMatcher matcher;
auto* op2 = matcher.mutable_pattern()->NewNode( auto* op2 = matcher.mutable_pattern()->NewNode(
[](const Node* x) { [](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op2"; return x && x->IsStmt() && x->stmt()->desc == "op2";
}, },
"op2"); "op2");
auto* op3 = matcher.mutable_pattern()->NewNode( auto* op3 = matcher.mutable_pattern()->NewNode(
[](const Node* x) { [](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op3"; return x && x->IsStmt() && x->stmt()->desc == "op3";
}, },
"op3"); "op3");
auto* v2 = matcher.mutable_pattern() auto* v2 = matcher.mutable_pattern()
......
...@@ -65,6 +65,10 @@ class SSAGraph : GraphBase { ...@@ -65,6 +65,10 @@ class SSAGraph : GraphBase {
Node *GraphCreateInstructNode(const std::shared_ptr<OpLite> &op, Node *GraphCreateInstructNode(const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places); const std::vector<Place> &valid_places);
// Device related attributes
const std::vector<Place> &valid_places() const { return valid_places_; }
void SetValidPlaces(const std::vector<Place> &x) { valid_places_ = x; }
private: private:
mir::Node *Argument(const std::string &name); mir::Node *Argument(const std::string &name);
// Check the bidirectional connection. // Check the bidirectional connection.
...@@ -89,6 +93,7 @@ class SSAGraph : GraphBase { ...@@ -89,6 +93,7 @@ class SSAGraph : GraphBase {
private: private:
std::list<mir::Node> node_storage_; std::list<mir::Node> node_storage_;
std::map<std::string, mir::Node *> arguments_; std::map<std::string, mir::Node *> arguments_;
std::vector<Place> valid_places_;
}; };
// Remove the link between a -> b. // Remove the link between a -> b.
......
...@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (!node.IsStmt()) continue; if (!node.IsStmt()) continue;
auto& instruct = node.AsStmt(); auto& instruct = node.AsStmt();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored; std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
CHECK(!instruct.valid_kernels.empty()) << "No kernels found for " CHECK(!instruct.kernels().empty()) << "No kernels found for "
<< instruct.op_type; << instruct.op_type();
for (auto&& kernel : instruct.valid_kernels) { for (auto&& kernel : instruct.kernels()) {
size_t score = KernelGrade(*kernel); size_t score = KernelGrade(*kernel);
scored.emplace_back(score, std::move(kernel)); scored.emplace_back(score, std::move(kernel));
} }
...@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) { ...@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Move kernel back // Move kernel back
// Just keep a single best kernel. // Just keep a single best kernel.
// TODO(Superjomn) reconsider this. // TODO(Superjomn) reconsider this.
instruct.valid_kernels.clear(); instruct.kernels().clear();
instruct.valid_kernels.emplace_back(std::move(scored.front().second)); instruct.kernels().emplace_back(std::move(scored.front().second));
VLOG(2) << "pick " << instruct.valid_kernels.front()->name(); VLOG(2) << "pick " << instruct.kernels().front()->name();
} }
} }
......
...@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node, ...@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
CHECK(in->AsArg().type); CHECK(in->AsArg().type);
if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) { if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op->DebugString() << " " << " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type; << *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist. // Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node, AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
...@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed"; CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
// CHECK(io_copy_op); // CHECK(io_copy_op);
// Create the new var manually. // Create the new var manually.
inst_node->AsStmt().op->scope()->Var(io_copy_output_name); inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction. // Create IoCopy Instruction.
cpp::OpDesc op_desc; cpp::OpDesc op_desc;
...@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc.SetInput("Input", {in->AsArg().name}); op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {io_copy_output_name}); op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope()); io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = io_copy_op->CreateKernels(valid_places); auto kernels = io_copy_op->CreateKernels(valid_places);
io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op); io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
...@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst( ...@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink(io_copy_output_arg, inst_node); DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information // reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name, UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), in->AsArg().name,
io_copy_output_name); io_copy_output_name);
inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(), inst_node->AsStmt().ResetOp(*inst_node->AsStmt().op_info(),
inst_node->AsStmt().op->scope()); graph->valid_places());
std::string tmp; std::string tmp;
if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) { if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
CHECK(false) << "get old a " << tmp; CHECK(false) << "get old a " << tmp;
} }
for (auto& kernel : inst_node->AsStmt().valid_kernels) { for (auto& kernel : inst_node->AsStmt().kernels()) {
inst_node->AsStmt().op->AttachKernel(kernel.get()); inst_node->AsStmt().op()->AttachKernel(kernel.get());
} }
graph->CheckValid(); graph->CheckValid();
......
...@@ -24,9 +24,11 @@ USE_MIR_PASS(generate_program_pass); ...@@ -24,9 +24,11 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass); USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass); USE_MIR_PASS(argument_type_display_pass);
#endif #endif
USE_MIR_PASS(runtime_context_assign_pass); USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass); USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze); USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(lite_fc_fuse_pass); USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass); USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass); USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
...@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass {
for (const auto& v : graph->inputs()) { for (const auto& v : graph->inputs()) {
// the feed op might in the inputs // the feed op might in the inputs
if (v->IsStmt()) { if (v->IsStmt()) {
LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type; LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type();
continue; continue;
} }
} }
...@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass {
for (auto& x : graph->StmtTopologicalOrder()) { for (auto& x : graph->StmtTopologicalOrder()) {
auto& inst = x->AsStmt(); auto& inst = x->AsStmt();
// The IoCopyOp is a tool operator, it won't support the type inference. // The IoCopyOp is a tool operator, it won't support the type inference.
if (inst.op_type == "io_copy") continue; if (inst.op_type() == "io_copy") continue;
// LOG(INFO) << "- inferencing type " << // LOG(INFO) << "- inferencing type " <<
// deal with inputs // deal with inputs
VLOG(4) << "inferencing op " << inst.op_type; VLOG(4) << "Infering op " << inst.op_info()->Repr();
// TODO(zhaolong): Add check if the node's name in op's arguments. // TODO(zhaolong): Add check if the node's name in op's arguments.
auto get_argname = [&]( auto get_argname = [&](
...@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass { ...@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass {
} }
} }
VLOG(3) << "inst " << inst.op_info()->Repr();
for (auto* x_out : x->outlinks) { for (auto* x_out : x->outlinks) {
std::string node_name = x_out->AsArg().name; std::string node_name = x_out->AsArg().name;
std::string arg_name = std::string arg_name =
get_argname(node_name, inst.op_info()->outputs()); get_argname(node_name, inst.op_info()->outputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node " CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name; << node_name << " in Inst "
<< inst.op_type();
VLOG(3) << "-- output arg_name " << arg_name; VLOG(3) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name); auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
if (!x_out->AsArg().type) { if (!x_out->AsArg().type) {
......
...@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels( ...@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
targets.insert(place.target); targets.insert(place.target);
} }
// CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels"; VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels";
return kernels; return kernels;
} }
...@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) { ...@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
scope_ = scope; scope_ = scope;
op_info_.reset( op_info_.reset(
new OpInfo(opdesc)); // Force clean the out-of-date infomation. new OpInfo(opdesc)); // Force clean the out-of-date infomation.
return AttachImpl(opdesc, scope); return AttachImpl(*op_info(), scope);
} }
const Tensor *OpLite::GetTensor(lite::Scope *scope, const Tensor *OpLite::GetTensor(lite::Scope *scope,
......
...@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc { ...@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc {
} }
return false; return false;
} }
void UpdateAllInputs(const std::string &from, const std::string &to) {
for (auto &item : inputs_) {
for (auto &var : item.second) {
if (var == from) var = to;
}
}
}
void UpdateAllOutputs(const std::string &from, const std::string &to) {
for (auto &item : outputs_) {
for (auto &var : item.second) {
if (var == from) var = to;
}
}
}
}; };
} // namespace lite } // namespace lite
......
...@@ -43,6 +43,8 @@ class Optimizer { ...@@ -43,6 +43,8 @@ class Optimizer {
CHECK(!graph_) << "duplicate optimize found"; CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph); graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places); graph_->Build(program, valid_places);
graph_->SetValidPlaces(valid_places);
SpecifyKernelPickTactic(kernel_pick_factor); SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass(); InitTargetTypeTransformPass();
...@@ -51,6 +53,7 @@ class Optimizer { ...@@ -51,6 +53,7 @@ class Optimizer {
"lite_conv_bn_fuse_pass", // "lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_activation_fuse_pass", // "lite_conv_elementwise_add_activation_fuse_pass", //
"lite_fc_fuse_pass", // "lite_fc_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK #ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass", // "lite_elementwise_add_activation_fuse_pass", //
#endif #endif
......
...@@ -140,7 +140,7 @@ class RuntimeProgram { ...@@ -140,7 +140,7 @@ class RuntimeProgram {
void Run() { void Run() {
for (auto& inst : instructions_) { for (auto& inst : instructions_) {
VLOG(4) << ">> Running kernel: " << inst; VLOG(3) << ">> Running kernel: " << inst.op()->op_info()->Repr();
inst.Run(); inst.Run();
} }
} }
......
...@@ -191,7 +191,6 @@ class TensorBase { ...@@ -191,7 +191,6 @@ class TensorBase {
template <typename TensorT> template <typename TensorT>
bool TensorCompareWith(const TensorT &a, const TensorT &b) { bool TensorCompareWith(const TensorT &a, const TensorT &b) {
if (a.dims() != b.dims()) return false; if (a.dims() != b.dims()) return false;
LOG(INFO) << "data_size: " << a.data_size();
if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false; if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false;
return true; return true;
} }
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h" #include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
......
...@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto* x = &param.x->raw_tensor(); auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor(); auto* y = &param.y->raw_tensor();
const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix( Tensor x_matrix, y_matrix;
*x, param.x_num_col_dims)
: *x; if (x->dims().size() > 2) {
const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix( x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims);
*y, param.y_num_col_dims) } else {
: *y; x_matrix = *x;
}
if (y->dims().size() > 2) {
y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims);
} else {
y_matrix = *y;
}
auto* z = &param.output->raw_tensor(); auto* z = &param.output->raw_tensor();
auto z_dim = z->dims(); auto z_dim = z->dims();
...@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> { ...@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto* x = &param.x->raw_tensor(); auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor(); auto* y = &param.y->raw_tensor();
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, param.x_num_col_dims) Tensor x_matrix, y_matrix;
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2 if (x->dims().size() > 2) {
? framework::ReshapeToMatrix(*y, param.y_num_col_dims) x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims);
: static_cast<const Tensor&>(*y); } else {
x_matrix = *x;
}
if (y->dims().size() > 2) {
y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims);
} else {
y_matrix = *y;
}
auto* dout = &param.output_grad->raw_tensor(); auto* dout = &param.output_grad->raw_tensor();
Tensor dout_mat; Tensor dout_mat;
......
...@@ -15,6 +15,8 @@ ...@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/mul_compute.h" #include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h> #include <gtest/gtest.h>
#include <iostream> #include <iostream>
#include <memory>
#include <utility>
#include <vector> #include <vector>
#include "paddle/fluid/lite/core/op_registry.h" #include "paddle/fluid/lite/core/op_registry.h"
......
...@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE) ...@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE)
endif() endif()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite) cc_library(compatible_pb_lite SRCS compatible_pb.cc
DEPS op_desc_lite framework_proto_lite var_desc_lite cpp_op_desc_lite)
lite_cc_library(model_parser_lite SRCS model_parser.cc DEPS lite_cc_library(model_parser_lite SRCS model_parser.cc DEPS
variable_lite scope_lite ${tensor_lite} scope_lite variable_lite scope_lite ${tensor_lite} scope_lite
......
cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite) cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite)
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#pragma once #pragma once
#include <map> #include <map>
#include <sstream>
#include <string> #include <string>
#include <vector> #include <vector>
...@@ -79,6 +80,27 @@ class OpDescAPI { ...@@ -79,6 +80,27 @@ class OpDescAPI {
/// Get an attribute. /// Get an attribute.
template <typename T> template <typename T>
T GetAttr(const std::string& name) const; T GetAttr(const std::string& name) const;
std::string Repr() const {
std::stringstream ss;
ss << Type();
ss << "(";
for (auto& arg : InputArgumentNames()) {
ss << arg << ":";
for (auto val : Input(arg)) {
ss << val << " ";
}
}
ss << ") -> (";
for (auto& arg : OutputArgumentNames()) {
ss << arg << ":";
for (auto val : Output(arg)) {
ss << val << " ";
}
}
ss << ")";
return ss.str();
}
}; };
} // namespace lite } // namespace lite
......
...@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI { ...@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI {
template <typename T> template <typename T>
T GetAttr(const std::string &name) const; T GetAttr(const std::string &name) const;
std::string DebugString() const { return desc_.DebugString(); }
private: private:
std::vector<std::string> GetArguments( std::vector<std::string> GetArguments(
const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var> const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var>
......
...@@ -38,15 +38,19 @@ class MulOpLite : public OpLite { ...@@ -38,15 +38,19 @@ class MulOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); } void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one. // TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override { bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty());
auto input = op_desc.Input("X").front(); auto input = op_desc.Input("X").front();
auto W = op_desc.Input("Y").front(); auto W = op_desc.Input("Y").front();
auto out = op_desc.Output("Out").front(); auto out = op_desc.Output("Out").front();
auto *var = scope->FindVar(input); auto *var = scope->FindVar(input);
CHECK(var); CHECK(var);
param_.x = var->GetMutable<Tensor>(); param_.x = &var->Get<Tensor>();
var = scope->FindVar(W); var = scope->FindVar(W);
CHECK(var) << "no var called " << W; CHECK(var) << "no var called " << W;
param_.y = var->GetMutable<Tensor>(); param_.y = &var->Get<Tensor>();
var = scope->FindVar(out); var = scope->FindVar(out);
CHECK(var) << "no var called " << out; CHECK(var) << "no var called " << out;
param_.output = var->GetMutable<Tensor>(); param_.output = var->GetMutable<Tensor>();
......
...@@ -67,8 +67,8 @@ struct ReluParam { ...@@ -67,8 +67,8 @@ struct ReluParam {
// For Mul Op // For Mul Op
struct MulParam { struct MulParam {
lite::Tensor* x{}; const lite::Tensor* x{};
lite::Tensor* y{}; const lite::Tensor* y{};
lite::Tensor* output{}; lite::Tensor* output{};
int x_num_col_dims{1}; int x_num_col_dims{1};
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include <typeinfo> #include <typeinfo>
#include <utility> #include <utility>
#include "paddle/fluid/lite/utils/cp_logging.h" #include "paddle/fluid/lite/utils/cp_logging.h"
#include "paddle/fluid/lite/utils/string.h"
// This is an equivalent implementation of boost::any. We implement this to // This is an equivalent implementation of boost::any. We implement this to
// avoid including the whole boost library and keep the inference library small. // avoid including the whole boost library and keep the inference library small.
...@@ -116,9 +117,9 @@ struct variant { ...@@ -116,9 +117,9 @@ struct variant {
if (type_id == typeid(T).hash_code()) if (type_id == typeid(T).hash_code())
return *reinterpret_cast<const T*>(&data); return *reinterpret_cast<const T*>(&data);
else else
throw std::invalid_argument("unmatched type"); throw std::invalid_argument(
// LOG(FATAL) << "unmatched type get, should be " << type_id << " but get " string_format("unmatched type, store as %d, but want to get %s",
// << typeid(T).name(); type_id, typeid(T).name()));
return *reinterpret_cast<const T*>(&data); return *reinterpret_cast<const T*>(&data);
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册