提交 4fe5c8aa 编写于 作者: N nhzlx

Merge branch 'incubate/lite' of http://10.87.145.36/inference/paddlelite into xzl/incubate/lite

fix comments
......@@ -44,9 +44,9 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
# check arch abi
if(NOT DEFINED ARM_TARGET_LANG)
set(ARM_TARGET_LANG "clang" CACHE STRING "Choose ARM Target Language")
set(ARM_TARGET_LANG "gcc" CACHE STRING "Choose ARM Target Language")
endif()
set(ARM_TARGET_LANG_LIST "gcc" "clang")
set(ARM_TARGET_LANG_LIST "gcc" "clang" "")
set_property(CACHE ARM_TARGET_LANG PROPERTY STRINGS ${ARM_TARGET_LANG_LIST})
if (NOT ARM_TARGET_LANG IN_LIST ARM_TARGET_LANG_LIST)
message(FATAL_ERROR "ARM_TARGET_LANG must be in one of ${ARM_TARGET_LANG_LIST}")
......
......@@ -82,7 +82,7 @@ TEST(CXXApi_LightApi, save_and_load_model) {
ASSERT_TRUE(TensorCompareWith(*cxx_out, *light_out));
std::vector<std::string> tensors_with_order({
"a", "fc_0.w_0", "fc_0.tmp_0", "scale_0.tmp_0",
"a", "fc_0.w_0", "scale_0.tmp_0",
});
for (const auto& tensor_name : tensors_with_order) {
......
......@@ -13,7 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/api/cxx_api.h"
#include <chrono>
#include <chrono> // NOLINT
#include "paddle/fluid/lite/core/mir/use_passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
......
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "paddle/fluid/lite/api/lite_api_test_helper.h"
#include <vector>
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
......
......@@ -59,4 +59,3 @@ lite_cc_test(test_type_system SRCS type_system_test.cc DEPS type_system utils_li
lite_cc_test(test_types_lite SRCS types_test.cc DEPS types_lite)
lite_cc_test(test_memory_lite SRCS memory_test.cc DEPS memory_lite)
lite_cc_test(test_context_lite SRCS context_test.cc DEPS context_lite X86_DEPS operator)
......@@ -5,12 +5,16 @@ cc_library(mir_pass_manager SRCS pass_manager.cc DEPS mir_pass mir_ssa_graph mir
cc_library(mir_pass_registry SRCS pass_registry.cc DEPS mir_pass_manager)
add_subdirectory(fusion)
add_subdirectory(elimination)
cc_library(mir_passes
SRCS fc_fuse_pass.cc
conv_elementwise_add_activation_fuse_pass.cc
elementwise_add_activation_fuse_pass.cc
conv_bn_fuse_pass.cc
quant_dequant_fuse_pass.cc
SRCS
fusion/fc_fuse_pass.cc
fusion/conv_elementwise_add_activation_fuse_pass.cc
fusion/conv_bn_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
static_kernel_pick_pass.cc
variable_place_inference_pass.cc
type_target_transform_pass.cc
......@@ -74,7 +78,7 @@ message(STATUS "----> Ops lite: ${ops_lite}")
message(STATUS "----> Host kernels: ${host_kernels}")
message(STATUS "----> X86 kernels: ${x86_kernels}")
lite_cc_test(test_lite_fc_fuse SRCS fc_fuse_pass_test.cc
lite_cc_test(test_lite_fc_fuse SRCS fusion/fc_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels} ${arm_kernels}
ARGS --model_dir=${LITE_MODEL_DIR}/lite_fc_model
......@@ -85,10 +89,10 @@ add_dependencies(test_lite_fc_fuse extern_lite_download_lite_fc_model_tar_gz)
lite_cc_test(test_lite_conv_elementwise_add_activation_fuse
SRCS conv_elementwise_add_activation_fuse_pass_test.cc
SRCS fusion/conv_elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
lite_cc_test(test_lite_elementwise_add_activation_fuse
SRCS elementwise_add_activation_fuse_pass_test.cc
SRCS fusion/elementwise_add_activation_fuse_pass_test.cc
DEPS cxx_api_lite mir_passes
${ops_lite} ${host_kernels} ${x86_kernels})
if (NOT LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
lite_cc_test(test_identity_scale_eliminate_pass_lite
SRCS identity_scale_eliminate_pass_test.cc
DEPS mir_passes program_lite proto_desc cpp_op_desc_lite
${ops_lite}
)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace {
class Eliminator : public FuseBase {
public:
void BuildPattern() override {
auto* pre_op = OpNode("preop"); // the previous op's output need update
// TODO(Superjomn) check has only one output
auto* x = VarNode("x")->assert_is_op_input("scale", "X");
auto* scale_op = OpNode("scale", "scale")
->assert_op_attr<float>("scale", 1.)
->assert_op_attr<float>("bias", 0.);
auto* out = VarNode("out")->assert_is_op_output("scale", "Out");
*pre_op >> *x >> *scale_op >> *out;
// The pre_op will be eliminated, and a new output-updated op will insert.
x->AsIntermediate(); // x is pre_op's output, need to update
}
private:
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto& pre_op = matched.at("preop")->AsStmt();
auto op_info = *pre_op.op_info();
op_info.UpdateAllOutputs(matched.at("x")->AsArg().name,
matched.at("out")->AsArg().name);
pre_op.ResetOp(op_info, graph->valid_places());
GraphSafeRemoveNodes(graph, {matched.at("scale")});
IR_NODE_LINK_TO(matched.at("preop"), matched.at("out"));
}
};
} // namespace
class IdentityScaleEliminatePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override {
Eliminator eliminator;
eliminator(graph.get());
}
};
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(identity_scale_eliminate_pass,
paddle::lite::mir::IdentityScaleEliminatePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
#include "paddle/fluid/lite/core/mir/ssa_graph.h"
namespace paddle {
namespace lite {
namespace mir {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
// Op list:
// (x)->feed -> (feed) -> scale -> (scale_out) -> fetch->(fetch)
// After pass
// (x)->feed->(scale_out)->fetch->(fetch)
auto* main_block = program_desc->MutableBlock(0);
auto* feed_op = main_block->AppendOp();
auto* scale_op = main_block->AppendOp();
auto* fetch_op = main_block->AppendOp();
main_block->Var("x");
main_block->Var("feed");
main_block->Var("scale_out");
main_block->Var("fetch_out");
scope->Var("x")->GetMutable<lite::Tensor>();
scope->Var("feed")->GetMutable<lite::Tensor>();
scope->Var("scale_out")->GetMutable<lite::Tensor>();
scope->Var("fetch_out")->GetMutable<lite::Tensor>();
feed_op->SetType("feed");
feed_op->SetInput("X", {"x"});
feed_op->SetAttr("col", 1);
feed_op->SetOutput("Out", {"feed"});
scale_op->SetType("scale");
scale_op->SetInput("X", {"feed"});
scale_op->SetOutput("Out", {"scale_out"});
scale_op->SetAttr("scale", 1.f);
scale_op->SetAttr("bias", 0.f);
scale_op->SetAttr("bias_after_scale", true);
fetch_op->SetType("fetch");
fetch_op->SetInput("X", {"scale_out"});
fetch_op->SetOutput("Out", {"fetch"});
fetch_op->SetAttr("col", 1);
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
LOG(INFO) << Visualize(graph.get());
return graph;
}
TEST(identity_test, test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
const int num_nodes = graph->nodes().size();
auto pass = PassManager::Global().LookUp("identity_scale_eliminate_pass");
ASSERT_TRUE(pass);
pass->Apply(graph);
ASSERT_EQ(graph->nodes().size(), num_nodes - 2UL);
}
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(feed)
USE_LITE_OP(fetch)
USE_LITE_OP(scale)
USE_MIR_PASS(identity_scale_eliminate_pass)
......@@ -10,7 +10,6 @@ cc_library(fuse_conv_bn
cc_library(fuse_elementwise_add_activation
SRCS elementwise_add_activation_fuser.cc
DEPS pattern_matcher_high_api)
cc_library(fuse_quant_dequant
SRCS quant_dequant_op_fuser.cc
DEPS pattern_matcher_high_api)
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuser.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_bn_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_bn_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -70,7 +70,7 @@ void ConvBNFuser::BuildPattern() {
void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto eltwise_op = LiteOpRegistry::Global().Create("elementwise_add");
auto conv = matched.at("conv2d")->stmt()->op;
auto conv = matched.at("conv2d")->stmt()->op();
auto* scope = conv->scope();
auto& valid_places = conv->valid_places();
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuser.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -65,7 +65,7 @@ void ConvElementwiseAddActivationFuser::InsertNewNode(
SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto conv_op = LiteOpRegistry::Global().Create(conv_type_);
auto conv_old = matched.at("conv2d")->stmt()->op;
auto conv_old = matched.at("conv2d")->stmt()->op();
auto* scope = conv_old->scope();
auto& valid_places = conv_old->valid_places();
conv_op->Attach(op_desc, scope);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_relu_fuser.h"
#include "paddle/fluid/lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void ConvElementwiseAddReLUFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::ConvElementwiseAddReLUFuser fuser("conv2d");
fuser(graph.get());
fusion::ConvElementwiseAddReLUFuser depthwise_fuser("depthwise_conv2d");
depthwise_fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_conv_elementwise_add_act_fuse_pass,
paddle::lite::mir::ConvElementwiseAddReLUFusePass);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "paddle/fluid/lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class ConvElementwiseAddReLUFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/lite/api/cxx_api.h"
#include "paddle/fluid/lite/core/compatible_tensor.h"
#include "paddle/fluid/lite/core/mir/fusion/conv_elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/graph_visualize_pass.h"
#include "paddle/fluid/lite/core/mir/passes.h"
#include "paddle/fluid/lite/core/op_registry.h"
#include "paddle/fluid/lite/core/program.h"
DEFINE_string(model_dir, "", "");
DEFINE_string(optimized_model, "", "");
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
auto* main_block = program_desc->MutableBlock(0);
auto* conv2d_1 = main_block->AppendOp();
auto* conv2d_2 = main_block->AppendOp();
auto* add_1 = main_block->AppendOp();
auto* relu_1 = main_block->AppendOp();
auto* add_2 = main_block->AppendOp();
auto* relu_2 = main_block->AppendOp();
main_block->Var("input_1");
main_block->Var("input_2");
main_block->Var("filter_1");
main_block->Var("filter_2");
main_block->Var("conv2d_1_out");
main_block->Var("conv2d_2_out");
main_block->Var("bias_1");
main_block->Var("add_1_out");
main_block->Var("add_2_out");
main_block->Var("relu_1_out");
main_block->Var("out");
scope->Var("input_1")->GetMutable<lite::Tensor>();
scope->Var("input_2")->GetMutable<lite::Tensor>();
scope->Var("filter_1")->GetMutable<lite::Tensor>();
scope->Var("filter_2")->GetMutable<lite::Tensor>();
scope->Var("conv2d_1_out")->GetMutable<lite::Tensor>();
scope->Var("conv2d_2_out")->GetMutable<lite::Tensor>();
scope->Var("bias_1")->GetMutable<lite::Tensor>();
scope->Var("add_1_out")->GetMutable<lite::Tensor>();
scope->Var("add_2_out")->GetMutable<lite::Tensor>();
scope->Var("relu_1_out")->GetMutable<lite::Tensor>();
scope->Var("out")->GetMutable<lite::Tensor>();
conv2d_1->SetType("conv2d");
conv2d_1->SetInput("Input", {"input_1"});
conv2d_1->SetInput("Filter", {"filter_1"});
conv2d_1->SetOutput("Output", {"conv2d_1_out"});
conv2d_1->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_1->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_1->SetAttr("groups", 1);
conv2d_1->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_1->SetAttr("fuse_relu", false);
add_1->SetType("elementwise_add");
add_1->SetInput("X", {"conv2d_1_out"});
add_1->SetInput("Y", {"bias_1"});
add_1->SetOutput("Out", {"add_1_out"});
add_1->SetAttr("axis", 1);
relu_1->SetType("relu");
relu_1->SetInput("X", {"add_1_out"});
relu_1->SetOutput("Out", {"relu_1_out"});
conv2d_2->SetType("conv2d");
conv2d_2->SetInput("Input", {"input_2"});
conv2d_2->SetInput("Filter", {"filter_2"});
conv2d_2->SetOutput("Output", {"conv2d_2_out"});
conv2d_2->SetAttr("strides", std::vector<int>({1, 1}));
conv2d_2->SetAttr("paddings", std::vector<int>({0, 0}));
conv2d_2->SetAttr("groups", 1);
conv2d_2->SetAttr("dilations", std::vector<int>({1, 1}));
conv2d_2->SetAttr("fuse_relu", false);
add_2->SetType("elementwise_add");
add_2->SetInput("X", {"conv2d_2_out"});
add_2->SetInput("Y", {"relu_1_out"});
add_2->SetOutput("Out", {"add_2_out"});
add_2->SetAttr("axis", 1);
relu_2->SetType("relu");
relu_2->SetInput("X", {"add_2_out"});
relu_2->SetOutput("Out", {"out"});
program_desc->Flush();
lite::Program program(*program_desc->Proto(), scope, valid_places);
auto graph = std::unique_ptr<SSAGraph>(new SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(conv_elementwise_add_relu_fuse_pass, graph_test) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), 11UL /*vars*/ + 6UL /*ops*/);
Visualize(graph.get());
}
TEST(conv_elementwise_add_relu_fuse_pass, fuse_test_op) {
framework::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildGraph(&program_desc, scope, places);
Visualize(graph.get());
const int num_nodes = graph->nodes().size();
auto* fuser = new ConvElementwiseAddReLUFusePass;
fuser->Apply(graph);
Visualize(graph.get());
ASSERT_EQ(graph->nodes().size(), num_nodes - 5UL * 2 /*nodes removed */ +
1UL * 2 /* fused fc node*/);
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
USE_LITE_OP(elementwise_add);
USE_LITE_OP(conv2d);
USE_LITE_OP(depthwise_conv2d);
USE_LITE_OP(relu);
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuser.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/elementwise_add_activation_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/elementwise_add_activation_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -54,7 +54,7 @@ void ElementwiseAddActivationFuser::InsertNewNode(SSAGraph* graph,
auto op_desc = GenOpDesc(matched);
auto op =
LiteOpRegistry::Global().Create("fusion_elementwise_add_activation");
auto old_op = matched.at("add")->stmt()->op;
auto old_op = matched.at("add")->stmt()->op();
auto* scope = old_op->scope();
auto& valid_places = old_op->valid_places();
op->Attach(op_desc, scope);
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/fc_fuser.h"
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/fc_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/fc_fuse_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
......
......@@ -46,7 +46,7 @@ void FcFuser::BuildPattern() {
void FcFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op;
auto mul = matched.at("mul")->stmt()->op();
auto* scope = mul->scope();
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/lite/core/mir/quant_dequant_fuse_pass.h"
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_fuse_pass.h"
#include <memory>
#include <vector>
#include "paddle/fluid/lite/core/mir/fusion/quant_dequant_op_fuser.h"
......
......@@ -115,8 +115,8 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
nodes.push_back(matched.at("dequant_op_out" + std::to_string(i)));
}
int bit_length = quant_op->stmt()->op_info()->GetAttr<int>("bit_length");
auto* scope = quant_op->stmt()->op->scope();
auto& valid_places = quant_op->stmt()->op->valid_places();
auto* scope = quant_op->stmt()->op()->scope();
auto& valid_places = quant_op->stmt()->op()->valid_places();
int range = ((1 << (bit_length - 1)) - 1);
auto input_scale_t = scope->FindVar(quant_op_in_scale->arg()->name)
->GetMutable<lite::Tensor>();
......
......@@ -29,7 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
VLOG(4) << stmt;
insts_.emplace_back(stmt.op, std::move(stmt.valid_kernels.front()));
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
}
}
......
......@@ -39,7 +39,7 @@ std::string Visualize(mir::SSAGraph* graph) {
if (node.IsArg()) {
key = node.AsArg().name;
} else {
key = node.AsStmt().op_type + std::to_string(id++);
key = node.AsStmt().op_type() + std::to_string(id++);
}
if (node.IsStmt()) {
......
......@@ -25,11 +25,11 @@ class IoCopyKernelPickPass : public StmtPass {
for (auto& node : graph->mutable_nodes()) {
if (!node.IsStmt()) continue;
auto& inst = node.AsStmt();
if (inst.op_type != "io_copy") continue;
if (inst.op_type() != "io_copy") continue;
LOG(INFO) << "....> picking a IO COPY kernel";
auto& kernels = node.AsStmt().valid_kernels;
auto& kernels = node.AsStmt().kernels();
CHECK(!kernels.empty()) << "No valid kernels found for IoCopy Op";
const auto* inty = node.inlinks.front()->AsArg().type;
const auto* outy = node.outlinks.front()->AsArg().type;
......
......@@ -13,3 +13,62 @@
// limitations under the License.
#include "paddle/fluid/lite/core/mir/node.h"
#include "paddle/fluid/lite/core/op_registry.h"
namespace paddle {
namespace lite {
const OpInfo *mir::Node::Stmt::op_info() const {
CHECK(op_);
return op_->op_info();
}
Place mir::Node::Stmt::place() const {
CHECK(!valid_kernels_.empty());
return valid_kernels_.front()->place();
}
KernelBase &mir::Node::Stmt::picked_kernel() {
CHECK(!valid_kernels_.empty()) << "no kernel for " << op_type();
return *valid_kernels_.front();
}
OpInfo *mir::Node::Stmt::mutable_op_info() {
CHECK(op_);
return op_->mutable_op_info();
}
void mir::Node::Stmt::ResetOp(const cpp::OpDesc &op_desc,
const std::vector<Place> &valid_places,
lite::Scope *scope) {
CHECK((op_ && op_->scope()) || scope) << "Either scope should be set";
lite::Scope *the_scope = scope ? scope : op_->scope();
op_->Attach(op_desc, the_scope);
// Recreate the kernels with the latest OpInfo.
valid_kernels_.clear();
if (!op_ || op_->op_info()->Type() != op_desc.Type()) {
op_ = LiteOpRegistry::Global().Create(op_desc.Type());
CHECK(op_) << "No op found for " << op_desc.Type();
}
valid_kernels_ = op_->CreateKernels(valid_places);
}
std::ostream &mir::operator<<(std::ostream &os, const mir::Node::Stmt &other) {
os << "Statement " << other.op_type() << " " << other.place();
return os;
}
mir::Node::Arg &mir::Node::AsArg(const std::string &name, int id) {
auto &x = AsArg();
x.name = name;
x.id = id;
return x;
}
mir::Node::Arg &mir::Node::AsArg(const std::string &name) {
auto &x = AsArg();
x.name = name;
return x;
}
} // namespace lite
} // namespace paddle
......@@ -41,32 +41,40 @@ class Node {
kUnk,
};
struct Stmt {
std::string op_type;
class Stmt {
// The kernel instances this Statement contains.
std::vector<std::unique_ptr<KernelBase>> valid_kernels;
std::vector<std::unique_ptr<KernelBase>> valid_kernels_;
// TODO(Superjomn) make this a shared_ptr for resource safety.
std::shared_ptr<OpLite> op; // we hold op to run InferShape
std::shared_ptr<OpLite> op_; // we hold op to run InferShape
const OpInfo* op_info() {
CHECK(op);
return op->op_info();
}
public:
// Refresh the operator and kernels with the latest OpInfo.
void ResetOp(const cpp::OpDesc& op_desc,
const std::vector<Place>& valid_places,
lite::Scope* scope = nullptr);
Place place() const {
CHECK(!valid_kernels.empty());
return valid_kernels.front()->place();
}
std::string op_type() const { return op_info()->Type(); }
const OpInfo* op_info() const;
OpInfo* mutable_op_info();
KernelBase& picked_kernel() {
CHECK(!valid_kernels.empty()) << "no kernel for " << op_type;
return *valid_kernels.front();
void SetKernels(std::vector<std::unique_ptr<KernelBase>>&& kernels) {
valid_kernels_ = std::move(kernels);
}
friend std::ostream& operator<<(std::ostream& os, const Stmt& other) {
os << "Statement " << other.op_type << " " << other.place();
return os;
std::vector<std::unique_ptr<KernelBase>>& kernels() {
return valid_kernels_;
}
void SetOp(const std::shared_ptr<OpLite>& op) { op_ = op; }
const std::shared_ptr<OpLite> op() const { return op_; }
Place place() const;
KernelBase& picked_kernel();
friend std::ostream& operator<<(std::ostream& os, const Stmt& other);
// Description.
std::string desc;
};
struct Arg {
......@@ -78,26 +86,16 @@ class Node {
bool is_weight{false};
};
Arg& AsArg(const std::string& name, int id) {
auto& x = AsArg();
x.name = name;
x.id = id;
return x;
}
Arg& AsArg(const std::string& name, int id);
Arg& AsArg(const std::string& name) {
auto& x = AsArg();
x.name = name;
return x;
}
Arg& AsArg(const std::string& name);
Stmt& AsStmt(const std::string& op_type,
std::vector<std::unique_ptr<KernelBase>>&& kernels,
const std::shared_ptr<OpLite>& op) {
auto& x = AsStmt();
x.op_type = op_type;
x.op = op;
x.valid_kernels = std::move(kernels);
x.SetOp(op);
x.SetKernels(std::move(kernels));
return x;
}
......@@ -142,7 +140,7 @@ class Node {
}
if (other.IsStmt()) {
auto& arg = other.AsStmt();
os << "Statement " << arg.op_type;
os << "Statement " << arg.op_type();
}
return os;
}
......
......@@ -139,14 +139,13 @@ struct PMNode {
template <typename T>
PMNode* assert_op_attr(const std::string& attr_name, const T& attr) {
asserts_.emplace_back([=](Node* x) {
asserts_.push_back([=](const Node* x) {
if (x && x->IsStmt()) {
auto* op_info = x->stmt()->op_info();
return op_info->HasAttr(attr_name) &&
op_info->GetAttr<T>(attr_name) == attr;
} else {
return false;
}
return false;
});
return this;
}
......
......@@ -41,6 +41,7 @@ void FuseBase::DeleteInterNodes(SSAGraph *graph) {
}
}
LOG(INFO) << "keys: " << key2nodes_.size();
std::unordered_set<const Node *> nodes2rm;
for (auto &matched : key2nodes_) {
for (const auto &key : keys) {
......
......@@ -49,7 +49,13 @@ class FuseBase {
virtual void BuildPattern() = 0;
// Generate an operator desc with a matched subgraph.
virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) = 0;
virtual cpp::OpDesc GenOpDesc(const key2nodes_t& matched) {
return cpp::OpDesc();
}
PMNode* OpNode(const std::string& key) {
return GetOrCreateNode(key)->assert_is_op();
}
PMNode* OpNode(const std::string& key, const std::string& op_type);
......
......@@ -52,7 +52,7 @@ class FcFuser : public FuseBase {
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override {
auto op_desc = GenOpDesc(matched);
auto fc_op = LiteOpRegistry::Global().Create("fc");
auto mul = matched.at("mul")->stmt()->op;
auto mul = matched.at("mul")->stmt()->op();
auto* scope = mul->scope();
auto& valid_places = mul->valid_places();
fc_op->Attach(op_desc, scope);
......@@ -90,7 +90,7 @@ std::unique_ptr<SSAGraph> BuildGraph(framework::ProgramDesc* program_desc,
main_block->Var("w");
main_block->Var("out");
scope->Var("w")->GetMutable<lite::Tensor>();
scope->Var("x")->GetMutable<lite::Tensor>();
scope->Var("b")->GetMutable<lite::Tensor>();
scope->Var("mul_out")->GetMutable<lite::Tensor>();
scope->Var("w")->GetMutable<lite::Tensor>();
......
......@@ -23,19 +23,19 @@ namespace mir {
void BuildGraph(SSAGraph* g) {
g->mutable_nodes().emplace_back();
Node& o1 = g->mutable_nodes().back();
o1.AsStmt().op_type = "op1";
o1.AsStmt().desc = "op1";
g->mutable_nodes().emplace_back();
Node& o2 = g->mutable_nodes().back();
o2.AsStmt().op_type = "op2";
o2.AsStmt().desc = "op2";
g->mutable_nodes().emplace_back();
Node& o3 = g->mutable_nodes().back();
o3.AsStmt().op_type = "op3";
o3.AsStmt().desc = "op3";
g->mutable_nodes().emplace_back();
Node& o4 = g->mutable_nodes().back();
o4.AsStmt().op_type = "op4";
o4.AsStmt().desc = "op4";
g->mutable_nodes().emplace_back();
Node& o5 = g->mutable_nodes().back();
o5.AsStmt().op_type = "op5";
o5.AsStmt().desc = "op5";
g->mutable_nodes().emplace_back();
Node& v1 = g->mutable_nodes().back();
v1.AsArg("var1");
......@@ -108,11 +108,11 @@ TEST(PatternMatcher, MarkPMNodesInGraph) {
// v2 -> o3(a node named o3)
auto* o2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op2";
return node && node->IsStmt() && node->stmt()->desc == "op2";
});
auto* o3 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
return node && node->IsStmt() && node->stmt()->op_type == "op3";
return node && node->IsStmt() && node->stmt()->desc == "op3";
});
auto* v2 = x.pattern_.NewNode([](const Node* node) {
// The teller can be any condition, such as op type, or variable's shape.
......@@ -153,8 +153,8 @@ TEST(PatternMatcher, MultiSubgraph) {
// op -> var
auto* any_op = x.mutable_pattern()->NewNode(
[](const Node* node) {
return node->IsStmt() && (node->stmt()->op_type == "op2" ||
node->stmt()->op_type == "op3");
return node->IsStmt() &&
(node->stmt()->desc == "op2" || node->stmt()->desc == "op3");
},
"OP0");
auto* any_var =
......@@ -170,9 +170,9 @@ TEST(PatternMatcher, MultiSubgraph) {
int count = 0;
PatternMatcher::handle_t handle = [&](const PatternMatcher::subgraph_t& s,
SSAGraph* g) {
LOG(INFO) << "Detect " << s.at(any_op)->stmt()->op_type << " -> "
LOG(INFO) << "Detect " << s.at(any_op)->stmt()->desc << " -> "
<< s.at(any_var)->arg()->name << " -> "
<< s.at(any_op1)->stmt()->op_type;
<< s.at(any_op1)->stmt()->desc;
count++;
};
......@@ -197,12 +197,12 @@ TEST(PatternMatcher, IntermediateCheck) {
PatternMatcher matcher;
auto* op2 = matcher.mutable_pattern()->NewNode(
[](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op2";
return x && x->IsStmt() && x->stmt()->desc == "op2";
},
"op2");
auto* op3 = matcher.mutable_pattern()->NewNode(
[](const Node* x) {
return x && x->IsStmt() && x->stmt()->op_type == "op3";
return x && x->IsStmt() && x->stmt()->desc == "op3";
},
"op3");
auto* v2 = matcher.mutable_pattern()
......
......@@ -65,6 +65,10 @@ class SSAGraph : GraphBase {
Node *GraphCreateInstructNode(const std::shared_ptr<OpLite> &op,
const std::vector<Place> &valid_places);
// Device related attributes
const std::vector<Place> &valid_places() const { return valid_places_; }
void SetValidPlaces(const std::vector<Place> &x) { valid_places_ = x; }
private:
mir::Node *Argument(const std::string &name);
// Check the bidirectional connection.
......@@ -89,6 +93,7 @@ class SSAGraph : GraphBase {
private:
std::list<mir::Node> node_storage_;
std::map<std::string, mir::Node *> arguments_;
std::vector<Place> valid_places_;
};
// Remove the link between a -> b.
......
......@@ -37,9 +37,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (!node.IsStmt()) continue;
auto& instruct = node.AsStmt();
std::vector<std::pair<size_t, std::unique_ptr<KernelBase>>> scored;
CHECK(!instruct.valid_kernels.empty()) << "No kernels found for "
<< instruct.op_type;
for (auto&& kernel : instruct.valid_kernels) {
CHECK(!instruct.kernels().empty()) << "No kernels found for "
<< instruct.op_type();
for (auto&& kernel : instruct.kernels()) {
size_t score = KernelGrade(*kernel);
scored.emplace_back(score, std::move(kernel));
}
......@@ -49,9 +49,9 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// Move kernel back
// Just keep a single best kernel.
// TODO(Superjomn) reconsider this.
instruct.valid_kernels.clear();
instruct.valid_kernels.emplace_back(std::move(scored.front().second));
VLOG(2) << "pick " << instruct.valid_kernels.front()->name();
instruct.kernels().clear();
instruct.kernels().emplace_back(std::move(scored.front().second));
VLOG(2) << "pick " << instruct.kernels().front()->name();
}
}
......
......@@ -62,7 +62,7 @@ void TypeTargetTransformPass::ComplementInputs(SSAGraph* graph, Node* inst_node,
CHECK(in->AsArg().type);
if (!TargetCompatibleTo(*in->AsArg().type, *decl_arg_type)) {
LOG(INFO) << "found Target unmatched tensor: " << in->AsArg().name
<< " for kernel " << inst.op->DebugString() << " "
<< " for kernel " << inst.op()->DebugString() << " "
<< *in->AsArg().type << " -> " << *decl_arg_type;
// Add an IoCopy instruction to make the input compatible with other dist.
AddIoCopyInst(*in->AsArg().type, *decl_arg_type, in, graph, inst_node,
......@@ -89,7 +89,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
CHECK(io_copy_op) << "create op [" << io_copy_op << "] failed";
// CHECK(io_copy_op);
// Create the new var manually.
inst_node->AsStmt().op->scope()->Var(io_copy_output_name);
inst_node->AsStmt().op()->scope()->Var(io_copy_output_name);
// Create IoCopy Instruction.
cpp::OpDesc op_desc;
......@@ -97,7 +97,7 @@ void TypeTargetTransformPass::AddIoCopyInst(
op_desc.SetInput("Input", {in->AsArg().name});
op_desc.SetOutput("Out", {io_copy_output_name});
io_copy_op->Attach(op_desc, inst_node->AsStmt().op->scope());
io_copy_op->Attach(op_desc, inst_node->AsStmt().op()->scope());
auto kernels = io_copy_op->CreateKernels(valid_places);
io_copy_inst->AsStmt("io_copy", std::move(kernels), io_copy_op);
......@@ -113,19 +113,19 @@ void TypeTargetTransformPass::AddIoCopyInst(
DirectedLink(io_copy_output_arg, inst_node);
// reset opdesc and update kernel information
UpdateInputTo(inst_node->AsStmt().op->mutable_op_info(), in->AsArg().name,
UpdateInputTo(inst_node->AsStmt().op()->mutable_op_info(), in->AsArg().name,
io_copy_output_name);
inst_node->AsStmt().op->Attach(*inst_node->AsStmt().op->op_info(),
inst_node->AsStmt().op->scope());
inst_node->AsStmt().ResetOp(*inst_node->AsStmt().op_info(),
graph->valid_places());
std::string tmp;
if (inst_node->AsStmt().op_info()->GetInputArgname("a", &tmp)) {
CHECK(false) << "get old a " << tmp;
}
for (auto& kernel : inst_node->AsStmt().valid_kernels) {
inst_node->AsStmt().op->AttachKernel(kernel.get());
for (auto& kernel : inst_node->AsStmt().kernels()) {
inst_node->AsStmt().op()->AttachKernel(kernel.get());
}
graph->CheckValid();
......
......@@ -23,9 +23,11 @@ USE_MIR_PASS(generate_program_pass);
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
USE_MIR_PASS(runtime_context_assign_pass);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(graph_visualze);
USE_MIR_PASS(lite_conv_bn_fuse_pass);
USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_elementwise_add_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass);
......@@ -39,7 +39,7 @@ class VariablePlaceInferencePass : public DebugPass {
for (const auto& v : graph->inputs()) {
// the feed op might in the inputs
if (v->IsStmt()) {
LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type;
LOG(INFO) << "found kernel in inputs " << v->AsStmt().op_type();
continue;
}
}
......@@ -59,10 +59,10 @@ class VariablePlaceInferencePass : public DebugPass {
for (auto& x : graph->StmtTopologicalOrder()) {
auto& inst = x->AsStmt();
// The IoCopyOp is a tool operator, it won't support the type inference.
if (inst.op_type == "io_copy") continue;
if (inst.op_type() == "io_copy") continue;
// LOG(INFO) << "- inferencing type " <<
// deal with inputs
VLOG(4) << "inferencing op " << inst.op_type;
VLOG(4) << "Infering op " << inst.op_info()->Repr();
// TODO(zhaolong): Add check if the node's name in op's arguments.
auto get_argname = [&](
......@@ -90,12 +90,14 @@ class VariablePlaceInferencePass : public DebugPass {
}
}
VLOG(3) << "inst " << inst.op_info()->Repr();
for (auto* x_out : x->outlinks) {
std::string node_name = x_out->AsArg().name;
std::string arg_name =
get_argname(node_name, inst.op_info()->outputs());
CHECK(arg_name.size() > 0) << "can not found op arguments for node "
<< node_name;
<< node_name << " in Inst "
<< inst.op_type();
VLOG(3) << "-- output arg_name " << arg_name;
auto type = inst.picked_kernel().GetOutputDeclType(arg_name);
if (!x_out->AsArg().type) {
......
......@@ -61,7 +61,6 @@ std::vector<std::unique_ptr<KernelBase>> OpLite::CreateKernels(
targets.insert(place.target);
}
// CHECK(!kernels.empty()) << "No kernel found for Op " << op_type_;
VLOG(2) << "op " << op_type_ << " get " << kernels.size() << " kernels";
return kernels;
}
......@@ -83,7 +82,7 @@ bool OpLite::Attach(const cpp::OpDesc &opdesc, lite::Scope *scope) {
scope_ = scope;
op_info_.reset(
new OpInfo(opdesc)); // Force clean the out-of-date infomation.
return AttachImpl(opdesc, scope);
return AttachImpl(*op_info(), scope);
}
const Tensor *OpLite::GetTensor(lite::Scope *scope,
......
......@@ -197,6 +197,22 @@ class OpInfo : public cpp::OpDesc {
}
return false;
}
void UpdateAllInputs(const std::string &from, const std::string &to) {
for (auto &item : inputs_) {
for (auto &var : item.second) {
if (var == from) var = to;
}
}
}
void UpdateAllOutputs(const std::string &from, const std::string &to) {
for (auto &item : outputs_) {
for (auto &var : item.second) {
if (var == from) var = to;
}
}
}
};
} // namespace lite
......
......@@ -43,6 +43,8 @@ class Optimizer {
CHECK(!graph_) << "duplicate optimize found";
graph_.reset(new mir::SSAGraph);
graph_->Build(program, valid_places);
graph_->SetValidPlaces(valid_places);
SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass();
......@@ -51,6 +53,8 @@ class Optimizer {
"lite_quant_dequant_fuse_pass", //
"lite_conv_bn_fuse_pass", //
"lite_conv_elementwise_add_activation_fuse_pass", //
"lite_fc_fuse_pass", //
"identity_scale_eliminate_pass", //
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
"lite_elementwise_add_activation_fuse_pass", //
#endif
......
......@@ -140,7 +140,7 @@ class RuntimeProgram {
void Run() {
for (auto& inst : instructions_) {
VLOG(4) << ">> Running kernel: " << inst;
VLOG(3) << ">> Running kernel: " << inst.op()->op_info()->Repr();
inst.Run();
}
}
......
......@@ -191,7 +191,6 @@ class TensorBase {
template <typename TensorT>
bool TensorCompareWith(const TensorT &a, const TensorT &b) {
if (a.dims() != b.dims()) return false;
LOG(INFO) << "data_size: " << a.data_size();
if (memcmp(a.raw_data(), b.raw_data(), a.data_size()) != 0) return false;
return true;
}
......
......@@ -117,6 +117,19 @@ TEST(elementwise_add, compute) {
operators::ElementwiseParam param;
lite::Tensor x, y, output, output_ref;
#if 1
for (auto n : {1, 3, 4}) {
for (auto c : {1, 3, 4}) {
for (auto h : {1, 3, 4}) {
for (auto w : {1, 3, 4}) {
for (auto axis : {-1, 0, 1, 3}) {
for (auto yd :
{std::vector<int64_t>({n}), std::vector<int64_t>({c}),
std::vector<int64_t>({h}), std::vector<int64_t>({w}),
std::vector<int64_t>({n, c}), std::vector<int64_t>({c, h}),
std::vector<int64_t>({c, h, w}),
std::vector<int64_t>({n, c, h, w})}) {
#else
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 4, 11}) {
for (auto h : {1, 3, 4, 11}) {
......@@ -129,6 +142,7 @@ TEST(elementwise_add, compute) {
std::vector<int64_t>({h, w}), std::vector<int64_t>({n, c, h}),
std::vector<int64_t>({c, h, w}),
std::vector<int64_t>({n, c, h, w})}) {
#endif
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
auto y_dim = DDim(yd);
int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
......@@ -192,6 +206,20 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
operators::FusionElementwiseActivationParam param;
lite::Tensor x, y, output, output_ref;
#if 1
for (auto act_type : {"relu"}) {
for (auto n : {1, 3, 4}) {
for (auto c : {1, 3, 4}) {
for (auto h : {1, 3, 4}) {
for (auto w : {1, 3, 4}) {
for (auto axis : {-1, 0, 1, 3}) {
for (auto yd :
{std::vector<int64_t>({n}), std::vector<int64_t>({c}),
std::vector<int64_t>({h}), std::vector<int64_t>({w}),
std::vector<int64_t>({n, c}), std::vector<int64_t>({h, w}),
std::vector<int64_t>({n, c, h}),
std::vector<int64_t>({n, c, h, w})}) {
#else
for (auto act_type : {"relu"}) {
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 4, 11}) {
......@@ -206,6 +234,7 @@ TEST(fusion_elementwise_add_activation_arm, compute) {
std::vector<int64_t>({n, c, h}),
std::vector<int64_t>({c, h, w}),
std::vector<int64_t>({n, c, h, w})}) {
#endif
auto x_dim = DDim(std::vector<int64_t>({n, c, h, w}));
auto y_dim = DDim(yd);
int axis_t = axis < 0 ? x_dim.size() - y_dim.size() : axis;
......
......@@ -80,12 +80,19 @@ TEST(softmax_arm, compute) {
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_ref;
#if 1
for (auto n : {1, 3}) {
for (auto c : {1, 4}) {
for (auto h : {5, 1}) {
for (auto w : {1, 6}) {
for (auto axis : {-2, -1, 0, 1, 2}) {
#else
for (auto n : {1, 3, 4, 11}) {
for (auto c : {1, 3, 11, 4}) {
for (auto h : {3, 1, 11, 4}) {
for (auto w : {1, 3, 4, 12}) {
for (auto axis : {-4, -3, -2, -1, 0, 1, 2, 3}) {
#endif
x.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/elementwise_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
......
......@@ -40,12 +40,20 @@ class MulCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
const Tensor x_matrix = x->dims().size() > 2 ? framework::ReshapeToMatrix(
*x, param.x_num_col_dims)
: *x;
const Tensor y_matrix = y->dims().size() > 2 ? framework::ReshapeToMatrix(
*y, param.y_num_col_dims)
: *y;
Tensor x_matrix, y_matrix;
if (x->dims().size() > 2) {
x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims);
} else {
x_matrix = *x;
}
if (y->dims().size() > 2) {
y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims);
} else {
y_matrix = *y;
}
auto* z = &param.output->raw_tensor();
auto z_dim = z->dims();
......@@ -75,12 +83,22 @@ class MulGradCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
auto* x = &param.x->raw_tensor();
auto* y = &param.y->raw_tensor();
auto x_matrix = x->dims().size() > 2
? framework::ReshapeToMatrix(*x, param.x_num_col_dims)
: static_cast<const Tensor&>(*x);
auto y_matrix = y->dims().size() > 2
? framework::ReshapeToMatrix(*y, param.y_num_col_dims)
: static_cast<const Tensor&>(*y);
Tensor x_matrix, y_matrix;
if (x->dims().size() > 2) {
x_matrix = framework::ReshapeToMatrix(*x, param.x_num_col_dims);
} else {
x_matrix = *x;
}
if (y->dims().size() > 2) {
y_matrix = framework::ReshapeToMatrix(*y, param.y_num_col_dims);
} else {
y_matrix = *y;
}
auto* dout = &param.output_grad->raw_tensor();
Tensor dout_mat;
......
......@@ -15,6 +15,8 @@
#include "paddle/fluid/lite/kernels/x86/mul_compute.h"
#include <gtest/gtest.h>
#include <iostream>
#include <memory>
#include <utility>
#include <vector>
#include "paddle/fluid/lite/core/op_registry.h"
......
......@@ -11,7 +11,8 @@ if(NOT LITE_ON_MOBILE)
endif()
cc_library(compatible_pb_lite SRCS compatible_pb.cc DEPS op_desc_lite framework_proto_lite var_desc_lite)
cc_library(compatible_pb_lite SRCS compatible_pb.cc
DEPS op_desc_lite framework_proto_lite var_desc_lite cpp_op_desc_lite)
lite_cc_library(model_parser_lite SRCS model_parser.cc DEPS
variable_lite scope_lite ${tensor_lite} scope_lite
......
cc_library(cpp_op_desc_lite SRCS op_desc.cc DEPS any_lite)
......@@ -14,6 +14,7 @@
#pragma once
#include <map>
#include <sstream>
#include <string>
#include <vector>
......@@ -79,6 +80,27 @@ class OpDescAPI {
/// Get an attribute.
template <typename T>
T GetAttr(const std::string& name) const;
std::string Repr() const {
std::stringstream ss;
ss << Type();
ss << "(";
for (auto& arg : InputArgumentNames()) {
ss << arg << ":";
for (auto val : Input(arg)) {
ss << val << " ";
}
}
ss << ") -> (";
for (auto& arg : OutputArgumentNames()) {
ss << arg << ":";
for (auto val : Output(arg)) {
ss << val << " ";
}
}
ss << ")";
return ss.str();
}
};
} // namespace lite
......
......@@ -141,6 +141,8 @@ class OpDesc : public OpDescAPI {
template <typename T>
T GetAttr(const std::string &name) const;
std::string DebugString() const { return desc_.DebugString(); }
private:
std::vector<std::string> GetArguments(
const google::protobuf::RepeatedPtrField<framework::proto::OpDesc_Var>
......
......@@ -38,15 +38,19 @@ class MulOpLite : public OpLite {
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty());
auto input = op_desc.Input("X").front();
auto W = op_desc.Input("Y").front();
auto out = op_desc.Output("Out").front();
auto *var = scope->FindVar(input);
CHECK(var);
param_.x = var->GetMutable<Tensor>();
param_.x = &var->Get<Tensor>();
var = scope->FindVar(W);
CHECK(var) << "no var called " << W;
param_.y = var->GetMutable<Tensor>();
param_.y = &var->Get<Tensor>();
var = scope->FindVar(out);
CHECK(var) << "no var called " << out;
param_.output = var->GetMutable<Tensor>();
......
......@@ -67,8 +67,8 @@ struct ReluParam {
// For Mul Op
struct MulParam {
lite::Tensor* x{};
lite::Tensor* y{};
const lite::Tensor* x{};
const lite::Tensor* y{};
lite::Tensor* output{};
int x_num_col_dims{1};
......
......@@ -54,22 +54,6 @@ function check_style {
fi
}
function cmake_arm {
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
cmake .. \
-DWITH_GPU=OFF \
-DWITH_MKL=OFF \
-DWITH_LITE=ON \
-DLITE_WITH_CUDA=OFF \
-DLITE_WITH_X86=OFF \
-DLITE_WITH_ARM=ON \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DWITH_TESTING=ON \
-DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2 -DARM_TARGET_LANG=$3
}
function build_single {
#make $1 -j$(expr $(nproc) - 2)
make $1 -j$NUM_CORES_FOR_COMPILE
......@@ -153,33 +137,53 @@ function test_arm_model {
adb -s emulator-${port} shell chmod +x "${adb_work_dir}/${test_name}"
local adb_model_path="./${adb_work_dir}/`basename ${model_dir}`"
adb -s emulator-${port} shell "./${adb_work_dir}/${test_name} --eval_model_dir=$adb_model_path"
}
# Build the code and run lite arm tests. This is executed in the CI system.
function build_test_arm {
# 1. Build goes first
function cmake_arm {
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
cmake .. \
-DWITH_GPU=OFF \
-DWITH_MKL=OFF \
-DWITH_LITE=ON \
-DLITE_WITH_CUDA=OFF \
-DLITE_WITH_X86=OFF \
-DLITE_WITH_ARM=ON \
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DWITH_TESTING=ON \
-DARM_TARGET_OS=$1 -DARM_TARGET_ARCH_ABI=$2 -DARM_TARGET_LANG=$3
}
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
function build_arm {
os=$1
abi=$2
lang=$3
cur_dir=$(pwd)
for lang in "gcc" "clang"; do
for os in "android" "armlinux" ; do
if [[ ${os} == "armlinux" && ${lang} == "clang" ]]; then
continue
if [[ ${os} == "armlinux" ]]; then
# TODO(hongming): enable compile armv7 and armv7hf on armlinux, and clang compile
if [[ ${lang} == "clang" ]]; then
echo "clang is not enabled on armlinux yet"
return 0
fi
for abi in "armv8" "armv7" "armv7hf"; do
# TODO(hongming): enable compile armv7 and armv7hf on armlinux
if [[ ${abi} == "armv7hf" ]]; then
echo "armv7hf is not supported on both android and armlinux yet"
continue
echo "armv7hf is not supported on armlinux yet"
return 0
fi
# TODO(hongming): enable armv7 on armlinux
if [[ ${os} == "armlinux" && ${abi} == "armv7" ]]; then
if [[ ${abi} == "armv7" ]]; then
echo "armv7 is not supported on armlinux yet"
continue
return 0
fi
fi
if [[ ${os} == "android" && ${abi} == "armv7hf" ]]; then
echo "android do not need armv7hf"
continue
return 0
fi
build_dir=$cur_dir/build.lite.${os}.${abi}.${lang}
......@@ -188,11 +192,47 @@ function build_test_arm {
cmake_arm ${os} ${abi} ${lang}
build $TESTS_FILE
}
# $1: ARM_TARGET_OS in "android" , "armlinux"
# $2: ARM_TARGET_ARCH_ABI in "armv8", "armv7" ,"armv7hf"
# $3: ARM_TARGET_LANG in "gcc" "clang"
# $4: android test port
# Note: test must be in build dir
function test_arm {
os=$1
abi=$2
lang=$3
port=$4
if [[ ${os} == "armlinux" ]]; then
# TODO(hongming): enable test armlinux on armv8, armv7 and armv7hf
echo "Skip test arm linux yet. armlinux must in another docker"
return 0
fi
if [[ ${os} == "android" && ${abi} == "armv7hf" ]]; then
echo "android do not need armv7hf"
return 0
fi
# TODO(yuanshuai): enable armv7 on android
if [[ ${abi} == "armv7" ]]; then
echo "skip android v7 test yet"
return 0
fi
echo "test file: ${TESTS_FILE}"
for _test in $(cat $TESTS_FILE); do
test_arm_android $_test $port
done
done
done
# TODO(sangoly): refine this
test_arm_model "test_cxx_api_lite" $port "./third_party/install/mobilenet_v2_relu"
}
# 2. Then test
# Build the code and run lite arm tests. This is executed in the CI system.
function build_test_arm {
########################################################################
# job 1-4 must be in one runner
port_armv8=5554
port_armv7=5556
......@@ -206,39 +246,46 @@ function build_test_arm {
echo -ne '\n' | ${ANDROID_HOME}/emulator/emulator -avd paddle-armv7 -noaudio -no-window -gpu off -verbose -port ${port_armv7} &
sleep 1m
# now can only test android.
for lang in "gcc" "clang"; do
for abi in "armv8" "armv7" ; do
# TODO(yuanshuai): enable armv7 on android
if [[ ${abi} == "armv7" ]]; then
continue
fi
# job 1
build_arm "android" "armv8" "gcc"
test_arm "android" "armv8" "gcc" ${port_armv8}
cd -
build_dir=$cur_dir/build.lite.android.${abi}.${lang}
cd $build_dir
# job 2
build_arm "android" "armv8" "clang"
test_arm "android" "armv8" "clang" ${port_armv8}
cd -
local port=
if [[ ${abi} == "armv7" ]]; then
port=${port_armv7}
fi
# job 3
build_arm "android" "armv7" "gcc"
test_arm "android" "armv7" "gcc" ${port_armv7}
cd -
if [[ ${abi} == "armv8" ]]; then
port=${port_armv8}
fi
echo "test file: ${TESTS_FILE}"
for _test in $(cat $TESTS_FILE); do
test_arm_android $_test $port
done
# TODO(sangoly): refine this
test_arm_model "test_cxx_api_lite" $port "./third_party/install/mobilenet_v2_relu"
done
done
# armlinux need in another docker
# TODO(hongming): enable test armlinux on armv8, armv7 and armv7hf
# job 4
build_arm "android" "armv7" "clang"
test_arm "android" "armv7" "clang" ${port_armv7}
cd -
adb devices | grep emulator | cut -f1 | while read line; do adb -s $line emu kill; done
echo "Done"
########################################################################
# job 5
build_arm "armlinux" "armv8"
test_arm "armlinux" "armv8"
cd -
# job 6
build_arm "armlinux" "armv7"
test_arm "armlinux" "armv7"
cd -
# job 7
build_arm "armlinux" "armv7hf"
test_arm "armlinux" "armv7hf"
cd -
echo "Done"
}
############################# MAIN #################################
......@@ -279,6 +326,10 @@ function main {
ARM_ABI="${i#*=}"
shift
;;
--arm_lang=*)
ARM_LANG="${i#*=}"
shift
;;
--arm_port=*)
ARM_PORT="${i#*=}"
shift
......@@ -301,13 +352,21 @@ function main {
shift
;;
cmake_arm)
cmake_arm $ARM_OS $ARM_ABI
cmake_arm $ARM_OS $ARM_ABI $ARM_LANG
shift
;;
build_arm)
build_arm $ARM_OS $ARM_ABI $ARM_LANG
shift
;;
test_server)
test_lite $TESTS_FILE
shift
;;
test_arm)
build_arm $ARM_OS $ARM_ABI $ARM_LANG $ARM_PORT
shift
;;
test_arm_android)
test_arm_android $TEST_NAME $ARM_PORT
shift
......
......@@ -20,6 +20,7 @@
#include <typeinfo>
#include <utility>
#include "paddle/fluid/lite/utils/cp_logging.h"
#include "paddle/fluid/lite/utils/string.h"
// This is an equivalent implementation of boost::any. We implement this to
// avoid including the whole boost library and keep the inference library small.
......@@ -116,9 +117,9 @@ struct variant {
if (type_id == typeid(T).hash_code())
return *reinterpret_cast<const T*>(&data);
else
throw std::invalid_argument("unmatched type");
// LOG(FATAL) << "unmatched type get, should be " << type_id << " but get "
// << typeid(T).name();
throw std::invalid_argument(
string_format("unmatched type, store as %d, but want to get %s",
type_id, typeid(T).name()));
return *reinterpret_cast<const T*>(&data);
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册