Commit dff75d34 authored by cen.li

* subgraph new framework ok

* test=develop
Parent 04a36e78
......@@ -81,6 +81,7 @@ message(STATUS "get ARM kernels ${arm_kernels}")
message(STATUS "get NPU kernels ${npu_kernels}")
message(STATUS "get XPU kernels ${xpu_kernels}")
message(STATUS "get FPGA kernels ${fpga_kernels}")
message(STATUS "get BM kernels ${bm_kernels}")
# for full api
if (NOT LITE_ON_TINY_PUBLISH)
......
......@@ -34,7 +34,9 @@ void TestModel(const std::vector<Place>& valid_places) {
//DeviceInfo::Init();
//DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads);
lite::Predictor predictor;
predictor.Build(FLAGS_model_dir, "", "", valid_places);
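// Request the BM subgraph pass explicitly when building the predictor.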
std::vector<std::string> passes;
passes.push_back("bm_subgraph_pass");
predictor.Build(FLAGS_model_dir, "", "", valid_places, passes);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
......@@ -105,7 +107,8 @@ void TestModel(const std::vector<Place>& valid_places) {
TEST(ResNet50, test_bm) {
std::vector<Place> valid_places({
Place{TARGET(kBM), PRECISION(kFloat)}
Place{TARGET(kBM), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}
});
TestModel(valid_places);
......
......@@ -6,5 +6,5 @@ endif()
lite_cc_library(arena_framework SRCS framework.cc DEPS program gtest)
if((NOT LITE_WITH_OPENCL) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_arena_framework SRCS framework_test.cc DEPS arena_framework ${npu_kernels} ${bm_kernels} ${xpu_kernels} ${x86_kernels} ${fpga_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
......@@ -90,7 +90,7 @@ class Context<TargetType::kBM> {
Context() {}
explicit Context(const BMContext& ctx);
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce();
void InitOnce() {}
void CopySharedTo(BMContext* ctx) {}
std::string name() const { return "BMContext"; }
......
......@@ -29,6 +29,7 @@ void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
VLOG(4) << stmt;
LOG(INFO) << stmt;
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/generate_bm_program_pass.h"
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/kernels/bm/bridges/paddle_use_bm_bridges.h"
#include "lite/kernels/bm/bridges/registry.h"
#include "bmcompiler_if.h"
#include "bmlog.hpp"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::shared_ptr<void*> GenerateBMProgramPass::CvtVarNode(
lite::mir::Node* var_node, const Scope* scope) {
return nullptr;
}
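// Convert each op node to BM IR through its registered bridge function: a BM
// compiler handle is created first, each bridge appends its op to that handle,
// and the whole graph is compiled and finalized at the end.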
void GenerateBMProgramPass::CvtAllOpNodes(
const std::vector<Node*>& nodes2cvt,
lite::kernels::bm::bridges::node_map_type* converted_vars) {
const auto& bridges = lite::kernels::bm::bridges::Factory::Instance();
const auto& cvtfunc_map = bridges.AllFunctions();
lite::kernels::bm::bridges::graph_ctx_type ctx;
ctx.bm_compiler_handle = create_bmcompiler("BM1684");
CHECK(ctx.bm_compiler_handle != nullptr);
//bmlog::init("paddle_bitmain");
//bmlog::set_v(3);
for (auto& node : nodes2cvt) {
lite::kernels::bm::bridges::node_map_type node_inputs;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
// weights are handled in the converter, so skip them here
if (arg.is_weight) {
continue;
}
auto var_name = arg.name;
if (!converted_vars->count(var_name)) {
converted_vars->insert(std::make_pair(var_name, var_name));
}
node_inputs.insert(*converted_vars->find(var_name));
}
auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), &ctx, node_inputs);
converted_vars->insert(node_outputs.begin(), node_outputs.end());
}
std::string net_name = "paddle_bitmain";
__bmcompile_opt(ctx.bm_compiler_handle, const_cast<char*>(net_name.c_str()), 2);
finish_bmcompiler(ctx.bm_compiler_handle);
}
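// Generate one BM subgraph: find its boundary vars, topologically sort the op
// nodes, and convert them. Unlike the NPU/XPU variants below, this does not
// yet insert a fused graph_op node or remove the original nodes.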
void GenerateBMProgramPass::GenSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id) {
std::unordered_set<Node*> in_data_vars;
std::unordered_set<Node*> in_wgt_vars;
std::unordered_set<Node*> out_data_vars;
std::unordered_set<Node*> out_unused_vars;
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto ordered_nodes = GetTopologicalOrder(op_nodes);
lite::kernels::bm::bridges::node_map_type converted_vars;
CvtAllOpNodes(ordered_nodes, &converted_vars);
}
void GenerateBMProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
const auto& bridges = lite::kernels::bm::bridges::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
for (auto& i : op_map) {
//LOG(INFO) << "[BM] Supported type: " << i.first;
supported_op_types.push_back(i.first);
}
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
//LOG(INFO) << "[BM] Converting Subgraph " << id;
GenSubgraph(graph, op_nodes.second, id);
id++;
}
}
std::unique_ptr<RuntimeProgram> GenerateBMProgramPass::GenProgram() {
std::unique_ptr<RuntimeProgram> program(
new RuntimeProgram(std::move(insts_)));
return program;
}
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(generate_bm_program_pass,
paddle::lite::mir::subgraph::GenerateBMProgramPass)
.BindTargets({TARGET(kBM)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/core/context.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/kernels/bm/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
class GenerateBMProgramPass : public SubgraphProgramPass {
public:
using key2nodes_t = std::map<std::string, Node*>;
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
std::unique_ptr<RuntimeProgram> GenProgram();
protected:
// nodes2cvt: op nodes to convert
// return cvted_vars: converted var nodes
void CvtAllOpNodes(const std::vector<Node*>& nodes2cvt,
lite::kernels::bm::bridges::node_map_type* cvted_vars);
std::shared_ptr<void*> CvtVarNode(lite::mir::Node* var_node,
const Scope* scope);
std::string BuildGraph(const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id);
void GenSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
private:
std::vector<Instruction> insts_;
};
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/generate_npu_program_pass.h"
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/paddle_use_npu_bridges.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::shared_ptr<ge::Operator> GenerateNPUProgramPass::CvtVarNode(
lite::mir::Node* var_node, const Scope* scope) {
CHECK(var_node->IsArg());
const auto& arg = var_node->AsArg();
VLOG(4) << "[NPU] Convert var node " << arg.name;
auto* var = scope->FindVar(arg.name);
CHECK(var);
auto* tensor = var->GetMutable<lite::Tensor>();
CHECK(tensor);
auto dims = tensor->dims();
if (arg.is_weight) {
auto wgt = std::make_shared<ge::op::Const>(arg.name);
LOG(INFO) << "[NPU] Convert const var node " << arg.name;
VLOG(4) << dims;
wgt->set_attr_value(lite::npu::CvtTensor(tensor));
return wgt;
} else {
CHECK_EQ(dims.size(), 4);
LOG(INFO) << "[NPU] Convert data var node " << arg.name;
LOG(INFO) << dims;
// TODO(xxx): support more types and dims size
ge::TensorDesc desc(ge::Shape(dims.Vectorize()),
ge::Format::FORMAT_NCHW,
ge::DataType::DT_FLOAT);
// auto size = desc.GetShape().GetShapeSize();
// ge::TensorUtils::SetSize(desc, size*sizeof(float));
// ge::TensorUtils::SetRealDimCnt(desc, 4);
auto data = std::make_shared<ge::op::Data>(arg.name);
data->update_input_desc_x(desc);
return data;
}
return nullptr;
}
void GenerateNPUProgramPass::CvtAllOpNodes(
const std::vector<Node*>& nodes2cvt,
lite::kernels::npu::bridges::node_map_type* converted_vars) {
const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
const auto& cvtfunc_map = bridges.AllFunctions();
// Record all converted vars; every op node's inputs must already be present
// in converted_vars
for (auto& node : nodes2cvt) {
lite::kernels::npu::bridges::node_map_type node_inputs;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
// weights are handled in the converter, so skip them here
if (arg.is_weight) {
continue;
}
auto var_name = arg.name;
if (!converted_vars->count(var_name)) {
converted_vars->insert(
std::make_pair(var_name, CvtVarNode(var_node, stmt.op()->scope())));
}
node_inputs.insert(*converted_vars->find(var_name));
}
auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs);
converted_vars->insert(node_outputs.begin(), node_outputs.end());
}
}
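// Build the NPU graph for one subgraph: convert the topologically ordered op
// nodes, collect the ge::Operator inputs/outputs, and compile the model into
// a persistable weight tensor named "graph<sub_id>_weights".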
std::string GenerateNPUProgramPass::BuildNPUGraph(
const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id) {
auto ordered_nodes = GetTopologicalOrder(op_nodes);
lite::kernels::npu::bridges::node_map_type converted_vars;
CvtAllOpNodes(ordered_nodes, &converted_vars);
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
std::vector<ge::Operator> inputs;
std::vector<ge::Operator> outputs;
for (auto i : in_data_vars) {
auto argname = i->AsArg().name;
in_var_names.push_back(argname);
inputs.push_back(*converted_vars.at(argname));
}
for (auto i : out_data_vars) {
auto argname = i->AsArg().name;
out_var_names.push_back(argname);
outputs.push_back(*converted_vars.at(argname));
}
std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights";
auto any_op = (*op_nodes.begin())->AsStmt().op();
auto weight = any_op->scope()->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
// Compile the IR graph to an NPU model and store the model data in the weight
// tensor with persistable=true, so that the model parser can recognize it and
// save it to the param files
if (!lite::npu::BuildModel(inputs, outputs, weight)) {
LOG(FATAL) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")";
} else {
LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
}
return weight_var_name;
}
void GenerateNPUProgramPass::GenNPUSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id) {
std::unordered_set<Node*> in_data_vars;
std::unordered_set<Node*> in_wgt_vars;
std::unordered_set<Node*> out_data_vars;
std::unordered_set<Node*> out_unused_vars;
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto weight_var_name =
BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
auto any_op = (*op_nodes.begin())->AsStmt().op();
InsertNewNode(graph,
weight_var_name,
any_op->scope(),
any_op->valid_places(),
in_data_vars,
in_wgt_vars,
out_data_vars,
out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "[NPU] Before NPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
for (auto& i : op_map) {
LOG(INFO) << "[NPU] Supported type: " << i.first;
supported_op_types.push_back(i.first);
}
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[NPU] Converting Subgraph " << id;
GenNPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[NPU] After NPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
}
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(generate_npu_program_pass,
paddle::lite::mir::subgraph::GenerateNPUProgramPass)
.BindTargets({TARGET(kNPU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/backends/npu/builder.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
class GenerateNPUProgramPass : public SubgraphProgramPass {
public:
using key2nodes_t = std::map<std::string, Node*>;
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
protected:
// nodes2cvt: op nodes to convert
// return cvted_vars: converted var nodes
void CvtAllOpNodes(const std::vector<Node*>& nodes2cvt,
lite::kernels::npu::bridges::node_map_type* cvted_vars);
std::shared_ptr<ge::Operator> CvtVarNode(lite::mir::Node* var_node,
const Scope* scope);
std::string BuildNPUGraph(const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id);
void GenNPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
};
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(model_file, "", "model file path of combined protobuf model");
DEFINE_string(params_file, "", "params file path of combined protobuf model");
DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors");
DEFINE_int32(output_tensor_num, 1, "number of output tensors");
namespace paddle {
namespace lite {
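// Parse a shape string such as "1,3,224,224:1,80" into
// {{1, 3, 224, 224}, {1, 80}}: ':' separates tensors, ',' separates dims.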
std::vector<std::vector<int64_t>> ParseShape(std::string txt) {
std::vector<std::vector<int64_t>> shape;
while (!txt.empty()) {
size_t idx = txt.find_first_of(":");
std::string dims = txt.substr(0, idx);
std::vector<int64_t> s;
while (!dims.empty()) {
size_t idx = dims.find_first_of(",");
int d = atoi(dims.substr(0, idx).c_str());
VLOG(3) << d;
s.push_back(d);
if (idx == std::string::npos) {
break;
} else {
dims = dims.substr(idx + 1);
}
}
shape.push_back(s);
if (idx == std::string::npos) {
break;
} else {
txt = txt.substr(idx + 1);
}
}
return shape;
}
int64_t ShapeProduction(std::vector<int64_t> shape) {
int64_t s = 1;
for (int64_t dim : shape) {
s *= dim;
}
return s;
}
void FillInputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const float value) {
for (int i = 0; i < input_tensor_shape.size(); i++) {
auto input_tensor = predictor->GetInput(i);
input_tensor->Resize(input_tensor_shape[i]);
auto input_tensor_data = input_tensor->mutable_data<float>();
auto input_tensor_size = ShapeProduction(input_tensor->shape());
for (int j = 0; j < input_tensor_size; j++) {
input_tensor_data[j] = value;
}
}
}
void CompareOutputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
const int output_tensor_num) {
for (int i = 0; i < output_tensor_num; i++) {
auto tar_output_tensor = tar_predictor->GetOutput(i);
auto ref_output_tensor = ref_predictor->GetOutput(i);
auto tar_output_tensor_data = tar_output_tensor->data<float>();
auto ref_output_tensor_data = ref_output_tensor->data<float>();
auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape());
auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
for (size_t j = 0; j < ref_output_tensor_size; j++) {
auto abs_diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]);
auto rel_diff = abs_diff / (std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(3) << "val: " << tar_output_tensor_data[j]
<< " ref: " << ref_output_tensor_data[j]
<< " abs_diff: " << abs_diff << " rel_diff: " << rel_diff;
EXPECT_LT(rel_diff, 0.1);
}
}
}
std::shared_ptr<lite_api::PaddlePredictor> TestModel(
const std::string& model_dir,
const std::string& model_file,
const std::string& params_file,
const std::vector<lite_api::Place>& valid_places,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const std::string& optimized_model_dir) {
// generate optimized model
lite_api::CxxConfig cxx_config;
cxx_config.set_model_dir(model_dir);
cxx_config.set_model_file(model_file);
cxx_config.set_param_file(params_file);
cxx_config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
FillInputTensor(predictor, input_tensor_shape, 1);
predictor->SaveOptimizedModel(optimized_model_dir,
lite_api::LiteModelType::kNaiveBuffer);
// load optimized model
lite_api::MobileConfig mobile_config;
mobile_config.set_model_dir(optimized_model_dir);
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensor(predictor, input_tensor_shape, 1);
// run optimized model
for (int i = 0; i < FLAGS_warmup; i++) {
predictor->Run();
}
for (int i = 0; i < FLAGS_repeats; i++) {
auto start = GetCurrentUS();
predictor->Run();
LOG(INFO) << i << ", " << GetCurrentUS() - start << "us";
}
return predictor;
}
TEST(NPUSubgraph, compare) {
// parsing input tensor shape, supported formats: "1,3,224,224"
// "1,3,224,224:1,80"
std::vector<std::vector<int64_t>> input_tensor_shape =
ParseShape(FLAGS_input_tensor_shape);
// generate and run optimized CPU model
LOG(INFO) << " ================ CPU ================== ";
auto cpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kARM), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/CPU");
// generate and run optimized NPU model
LOG(INFO) << " ================ NPU ================== ";
auto npu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kNPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kARM), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/NPU");
// verify results
CompareOutputTensor(npu_predictor, cpu_predictor, FLAGS_output_tensor_num);
}
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/generate_xpu_program_pass.h"
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/backends/xpu/builder.h"
#include "lite/kernels/xpu/bridges/paddle_use_xpu_bridges.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::shared_ptr<xtcl::xExpr> GenerateXPUProgramPass::CvtVarNode(
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::mir::Node* var_node,
const Scope* scope) {
CHECK(var_node->IsArg());
const auto& arg = var_node->AsArg();
auto var_name = arg.name;
VLOG(4) << "[XPU] Convert var node " << var_name;
auto* var = scope->FindVar(var_name);
CHECK(var);
auto* tensor = var->GetMutable<lite::Tensor>();
CHECK(tensor);
auto dims = tensor->dims();
auto cvted_var_node =
std::make_shared<xtcl::xExpr>(graph_ctx->builder->CreateTensor(
var_name, lite::xpu::CvtShape(dims), ::xtcl::Float(32)));
if (arg.is_weight) {
auto cvted_var_tensor = lite::xpu::CvtTensor(tensor);
graph_ctx->params->emplace(std::make_pair(var_name, *cvted_var_tensor));
}
return cvted_var_node;
}
void GenerateXPUProgramPass::CvtAllOpNodes(
const std::vector<Node*>& op_nodes,
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes) {
const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
const auto& supported_lists = bridges.AllFunctions();
// Record all converted var nodes; every op node's inputs must already be
// present in cvted_var_nodes
for (auto& node : op_nodes) {
lite::kernels::xpu::bridges::node_map_type input_nodes;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
// weights are handled in the converter, so skip them here
if (arg.is_weight) {
continue;
}
auto var_name = arg.name;
if (!cvted_var_nodes->count(var_name)) {
cvted_var_nodes->insert(std::make_pair(
var_name, CvtVarNode(graph_ctx, var_node, stmt.op()->scope())));
}
input_nodes.insert(*cvted_var_nodes->find(var_name));
}
auto output_nodes =
supported_lists.at(stmt.op_type())(stmt.op(), graph_ctx, input_nodes);
cvted_var_nodes->insert(output_nodes.begin(), output_nodes.end());
}
}
std::string GenerateXPUProgramPass::BuildXPUGraph(
const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id) {
auto ordered_op_nodes = GetTopologicalOrder(op_nodes);
lite::kernels::xpu::bridges::graph_ctx_type graph_ctx;
graph_ctx.builder = std::make_shared<xtcl::network::xNetworkBuilder>();
graph_ctx.params =
std::make_shared<xtcl::network::xTensorCompiler::ParamNDArrayMap>();
lite::kernels::xpu::bridges::node_map_type cvted_var_nodes;
CvtAllOpNodes(ordered_op_nodes, &graph_ctx, &cvted_var_nodes);
std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights";
auto any_op = (*op_nodes.begin())->AsStmt().op();
auto weight = any_op->scope()->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
// Compile the graph to an XPU model and store the model data in the weight
// tensor with persistable=true, so that the model parser can recognize it and
// save it to the param files
std::vector<std::shared_ptr<xtcl::xExpr>> ordered_cvted_var_nodes;
for (auto out_data_var : out_data_vars) {
auto var_name = out_data_var->AsArg().name;
ordered_cvted_var_nodes.push_back(cvted_var_nodes[var_name]);
}
if (!lite::xpu::BuildModel(graph_ctx.builder,
graph_ctx.params,
&ordered_cvted_var_nodes,
weight)) {
LOG(FATAL) << "[XPU] Build XPU graph failed (subgraph=" << sub_id << ")";
} else {
LOG(INFO) << "[XPU] Build XPU graph success (subgraph=" << sub_id << ")";
}
return weight_var_name;
}
void GenerateXPUProgramPass::GenXPUSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id) {
std::unordered_set<Node*> in_data_vars;
std::unordered_set<Node*> in_wgt_vars;
std::unordered_set<Node*> out_data_vars;
std::unordered_set<Node*> out_unused_vars;
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto weight_var_name =
BuildXPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
auto any_op = (*op_nodes.begin())->AsStmt().op();
InsertNewNode(graph,
weight_var_name,
any_op->scope(),
any_op->valid_places(),
in_data_vars,
in_wgt_vars,
out_data_vars,
out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
void GenerateXPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
LOG(INFO) << "[XPU] Before XPU Pass \n" << Visualize(graph.get());
const auto& bridges = lite::kernels::xpu::bridges::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
for (auto& i : op_map) {
LOG(INFO) << "[XPU] Supported type: " << i.first;
supported_op_types.push_back(i.first);
}
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
LOG(INFO) << "[XPU] Converting Subgraph " << id;
GenXPUSubgraph(graph, op_nodes.second, id);
LOG(INFO) << "[XPU] After XPU Pass Subgraph " << id << "\n"
<< Visualize(graph.get());
id++;
}
}
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(generate_xpu_program_pass,
paddle::lite::mir::subgraph::GenerateXPUProgramPass)
.BindTargets({TARGET(kXPU)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/backends/xpu/builder.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/kernels/xpu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
class GenerateXPUProgramPass : public SubgraphProgramPass {
public:
using key2nodes_t = std::map<std::string, Node*>;
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
protected:
// nodes2cvt: op nodes to convert
// return cvted_vars: converted var nodes
void CvtAllOpNodes(
const std::vector<Node*>& op_nodes,
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::kernels::xpu::bridges::node_map_type* cvted_var_nodes);
std::shared_ptr<xtcl::xExpr> CvtVarNode(
lite::kernels::xpu::bridges::graph_ctx_type* graph_ctx,
lite::mir::Node* var_node,
const Scope* scope);
std::string BuildXPUGraph(const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id);
void GenXPUSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
};
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <cmath>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/utils/cp_logging.h"
DEFINE_string(model_file, "", "model file path of combined protobuf model");
DEFINE_string(params_file, "", "params file path of combined protobuf model");
DEFINE_string(optimized_model_dir, "", "path of optimized naive buffer model");
DEFINE_string(input_tensor_shape, "1,3,224,224", "shapes of input tensors");
DEFINE_int32(output_tensor_num, 1, "number of output tensors");
namespace paddle {
namespace lite {
std::vector<std::vector<int64_t>> ParseShape(std::string txt) {
std::vector<std::vector<int64_t>> shape;
while (!txt.empty()) {
size_t idx = txt.find_first_of(":");
std::string dims = txt.substr(0, idx);
std::vector<int64_t> s;
while (!dims.empty()) {
size_t idx = dims.find_first_of(",");
int d = atoi(dims.substr(0, idx).c_str());
VLOG(3) << d;
s.push_back(d);
if (idx == std::string::npos) {
break;
} else {
dims = dims.substr(idx + 1);
}
}
shape.push_back(s);
if (idx == std::string::npos) {
break;
} else {
txt = txt.substr(idx + 1);
}
}
return shape;
}
int64_t ShapeProduction(std::vector<int64_t> shape) {
int64_t s = 1;
for (int64_t dim : shape) {
s *= dim;
}
return s;
}
void FillInputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& predictor,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const float value) {
for (int i = 0; i < input_tensor_shape.size(); i++) {
auto input_tensor = predictor->GetInput(i);
input_tensor->Resize(input_tensor_shape[i]);
auto input_tensor_data = input_tensor->mutable_data<float>();
auto input_tensor_size = ShapeProduction(input_tensor->shape());
for (int j = 0; j < input_tensor_size; j++) {
input_tensor_data[j] = value;
}
}
}
void CompareOutputTensor(
const std::shared_ptr<lite_api::PaddlePredictor>& tar_predictor,
const std::shared_ptr<lite_api::PaddlePredictor>& ref_predictor,
const int output_tensor_num) {
for (int i = 0; i < output_tensor_num; i++) {
auto tar_output_tensor = tar_predictor->GetOutput(i);
auto ref_output_tensor = ref_predictor->GetOutput(i);
auto tar_output_tensor_data = tar_output_tensor->data<float>();
auto ref_output_tensor_data = ref_output_tensor->data<float>();
auto tar_output_tensor_size = ShapeProduction(tar_output_tensor->shape());
auto ref_output_tensor_size = ShapeProduction(ref_output_tensor->shape());
EXPECT_EQ(tar_output_tensor_size, ref_output_tensor_size);
for (size_t j = 0; j < ref_output_tensor_size; j++) {
auto diff =
std::fabs(tar_output_tensor_data[j] - ref_output_tensor_data[j]) /
(std::fabs(ref_output_tensor_data[j]) + 1e-6);
VLOG(3) << diff;
EXPECT_LT(diff, 0.1);
}
}
}
std::shared_ptr<lite_api::PaddlePredictor> TestModel(
const std::string& model_dir,
const std::string& model_file,
const std::string& params_file,
const std::vector<lite_api::Place>& valid_places,
const std::vector<std::vector<int64_t>>& input_tensor_shape,
const std::string& optimized_model_dir) {
// generate optimized model
lite_api::CxxConfig cxx_config;
cxx_config.set_model_dir(model_dir);
cxx_config.set_model_file(model_file);
cxx_config.set_param_file(params_file);
cxx_config.set_valid_places(valid_places);
auto predictor = lite_api::CreatePaddlePredictor(cxx_config);
FillInputTensor(predictor, input_tensor_shape, -1);
predictor->SaveOptimizedModel(optimized_model_dir,
lite_api::LiteModelType::kNaiveBuffer);
#if 0 // TODO(hong19860320) supports light api for XPU
// load optimized model
lite_api::MobileConfig mobile_config;
mobile_config.set_model_dir(optimized_model_dir);
mobile_config.set_power_mode(lite_api::PowerMode::LITE_POWER_HIGH);
mobile_config.set_threads(1);
predictor = lite_api::CreatePaddlePredictor(mobile_config);
FillInputTensor(predictor, input_tensor_shape, 1);
#endif
// run optimized model
for (int i = 0; i < FLAGS_warmup; i++) {
predictor->Run();
}
for (int i = 0; i < FLAGS_repeats; i++) {
auto start = GetCurrentUS();
predictor->Run();
LOG(INFO) << i << ", " << GetCurrentUS() - start << "us";
}
return predictor;
}
TEST(XPUSubgraph, compare) {
// parsing input tensor shape, supported formats: "1,3,224,224"
// "1,3,224,224:1,80"
std::vector<std::vector<int64_t>> input_tensor_shape =
ParseShape(FLAGS_input_tensor_shape);
// generate and run optimized CPU model
LOG(INFO) << " ================ CPU ================== ";
auto cpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kX86), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/CPU");
// generate and run optimized XPU model
LOG(INFO) << " ================ XPU ================== ";
auto xpu_predictor =
TestModel(FLAGS_model_dir,
FLAGS_model_file,
FLAGS_params_file,
{lite_api::Place{TARGET(kXPU), PRECISION(kFloat)},
lite_api::Place{TARGET(kX86), PRECISION(kFloat)}},
input_tensor_shape,
FLAGS_optimized_model_dir + "/XPU");
// verify results
CompareOutputTensor(xpu_predictor, cpu_predictor, FLAGS_output_tensor_num);
}
} // namespace lite
} // namespace paddle
......@@ -94,7 +94,7 @@ std::string SubgraphVisualizer::operator()() {
}
auto res = dot.Build();
std::cout << "subgraphs: " << subgraphs_.size() << "\n" << res << std::endl;
//std::cout << "subgraphs: " << subgraphs_.size() << "\n" << res << std::endl;
return res;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include <memory>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::unordered_map<int, std::unordered_set<Node*>>
SubgraphProgramPass::ClassifySubgraph(const std::unique_ptr<SSAGraph>& graph) {
std::unordered_map<int, std::unordered_set<Node*>> op_nodes;
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue;
if (!op_nodes.count(sub_id)) {
op_nodes[sub_id] = std::unordered_set<Node*>();
}
op_nodes.at(sub_id).insert(item);
}
return op_nodes;
}
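// Build the desc of the graph_op that replaces a fused subgraph; its "Weight"
// input refers to the tensor holding the compiled device model.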
cpp::OpDesc SubgraphProgramPass::GenGraphOpDesc(
const std::string& weight_var_name,
const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names) {
cpp::OpDesc op_desc;
op_desc.SetType("graph_op");
op_desc.SetInput("Inputs", in_var_names);
op_desc.SetInput("Weight", {weight_var_name});
op_desc.SetOutput("Outputs", out_var_names);
return op_desc;
}
void SubgraphProgramPass::InsertNewNode(
const std::unique_ptr<SSAGraph>& graph,
const std::string& weight_var_name,
Scope* scope,
const std::vector<Place>& valid_places,
std::unordered_set<Node*> in_data_vars,
std::unordered_set<Node*> in_wgt_vars,
std::unordered_set<Node*> out_data_vars,
std::unordered_set<Node*> out_unused_vars) {
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
for (auto i : in_data_vars) {
in_var_names.push_back(i->AsArg().name);
}
for (auto i : out_data_vars) {
out_var_names.push_back(i->AsArg().name);
}
auto op_desc = GenGraphOpDesc(weight_var_name, in_var_names, out_var_names);
auto graph_op = LiteOpRegistry::Global().Create("graph_op");
graph_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places);
for (auto& in_var : in_data_vars) {
IR_NODE_LINK_TO(in_var, new_op_node);
}
for (auto& in_var : in_wgt_vars) {
IR_NODE_LINK_TO(in_var, new_op_node);
}
for (auto& out_var : out_data_vars) {
IR_OP_VAR_LINK(new_op_node, out_var);
}
for (auto& out_var : out_unused_vars) {
IR_OP_VAR_LINK(new_op_node, out_var);
}
// add a weight node to store the pre-compiled NPU model
auto new_weight_node = graph->NewArgumentNode(weight_var_name);
new_weight_node->AsArg().is_weight = true;
new_weight_node->AsArg().is_persist = true;
DirectedLink(new_weight_node, new_op_node);
// assign context
auto& inst = new_op_node->AsStmt();
inst.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
}
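// Recursive DFS helper: visit each in-subgraph producer before the node
// itself, so ret ends up in topological (execution) order.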
void SubgraphProgramPass::SortHelper(
Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret) {
for (auto& var_node : node->inlinks) {
if (var_node->inlinks.empty()) continue;
auto* op_node = var_node->inlinks.front();
if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) {
SortHelper(op_node, nodes_all, visited_nodes, ret);
}
}
ret->push_back(node);
visited_nodes->insert(node);
}
std::vector<Node*> SubgraphProgramPass::GetTopologicalOrder(
const std::unordered_set<Node*>& nodes) {
std::unordered_set<const Node*> visited;
std::vector<Node*> ret;
for (auto& node : nodes) {
if (!node->IsStmt()) continue;
if (visited.count(node)) continue;
SortHelper(node, nodes, &visited, &ret);
}
return ret;
}
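// Classify the boundary vars of a subgraph: data inputs (produced outside the
// subgraph), weight inputs, data outputs (consumed outside), and outputs with
// no consumer at all.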
void SubgraphProgramPass::FindInputOutputVars(
const std::unordered_set<Node*>& op_nodes,
std::unordered_set<Node*>* in_data_vars,
std::unordered_set<Node*>* in_wgt_vars,
std::unordered_set<Node*>* out_data_vars,
std::unordered_set<Node*>* out_unused_vars) {
for (auto& op_node : op_nodes) {
for (auto& in_var : op_node->inlinks) {
if (in_var->AsArg().is_weight) {
in_wgt_vars->insert(in_var);
continue;
}
if (!in_var->inlinks.empty()) {
// var can only come from one op node, so use front
auto* pre_op_node = in_var->inlinks.front();
if (op_nodes.count(pre_op_node)) {
continue;
}
}
in_data_vars->insert(in_var);
}
for (auto& out_var : op_node->outlinks) {
if (out_var->outlinks.empty()) {
// no consumer op, so this var is actually unused
out_unused_vars->insert(out_var);
continue;
}
// a var can have more than one consumer op node;
// if any of them is inside op_nodes, the var stays internal, so continue
bool next_op_in_nodes = false;
for (auto& next_op_node : out_var->outlinks) {
if (op_nodes.count(next_op_node)) {
next_op_in_nodes = true;
}
}
if (next_op_in_nodes) {
continue;
}
out_data_vars->insert(out_var);
}
}
}
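// Collect the op nodes plus all their linked var nodes for removal, then
// exclude the boundary vars that must survive the fusion.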
std::unordered_set<const Node*> SubgraphProgramPass::GetNode2rm(
const std::unordered_set<Node*>& op_nodes,
const std::vector<std::unordered_set<Node*>>& excluded_nodes) {
std::unordered_set<const Node*> nodes2rm(op_nodes.begin(), op_nodes.end());
for (auto& op_node : op_nodes) {
for (auto& in_var : op_node->inlinks) {
if (!nodes2rm.count(in_var)) {
nodes2rm.insert(in_var);
}
}
for (auto& out_var : op_node->outlinks) {
if (!nodes2rm.count(out_var)) {
nodes2rm.insert(out_var);
}
}
}
// some nodes should not be removed
for (auto& e : excluded_nodes) {
for (auto& i : e) {
if (nodes2rm.count(i)) {
nodes2rm.erase(i);
}
}
}
return nodes2rm;
}
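// Run shape check/inference over the whole graph once (plus a one-off kernel
// Launch(), except on XPU builds) so that converters see concrete tensor
// dimensions.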
void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
auto& op = stmt.op();
auto scope = op->scope();
std::string op_type = op->op_info()->Type();
// check the dimensions of the input variables in the scope; they must not be empty
if (op_type == "feed") {
auto input_var_names = op->op_info()->output_names();
CHECK_GE(input_var_names.size(), 1);
for (auto input_var_name : input_var_names) {
auto input_var = scope->FindVar(input_var_name);
CHECK(input_var) << "No input variable '" << input_var_name
<< "' found in scope " << scope;
auto input = input_var->GetMutable<lite::Tensor>();
CHECK(!input->dims().empty()) << "The dimension of input variable '"
<< input_var_name
<< "' can not be empty.";
}
continue;
}
if (op_type == "fetch") {
continue;
}
op->CheckShape();
op->InferShape();
#ifndef LITE_WITH_XPU
// TODO(xxx): remove Launch() eventually
auto& kernels = stmt.kernels();
if (!kernels.empty()) {
auto& kernel = kernels.front();
if (kernel) {
kernel->Launch();
}
}
#endif
}
}
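// Clear all subgraph ids, then mark every op whose type is supported with
// id 0 as a candidate for fusion.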
void SubgraphProgramPass::InitSubgraphID(
const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types) {
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
stmt.ClearSubgraphID();
if (std::find(supported_op_types.begin(),
supported_op_types.end(),
stmt.op_type()) != supported_op_types.end()) {
stmt.SetSubgraphID(0);
LOG(INFO) << "supported " << stmt.op_type();
} else {
LOG(INFO) << "======= not supported " << stmt.op_type();
}
}
}
// mark the current node and, transitively, all supported downstream nodes
void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node,
int to_id,
int from_id) {
if (!node) return;
if (node->IsStmt()) {
auto& stmt = node->AsStmt();
if (stmt.subgraph_id() == from_id) {
stmt.SetSubgraphID(to_id);
for (auto& i : node->outlinks) {
ChangeAllOutConnectedID(i, to_id, from_id);
}
} else {
LOG(INFO) << "failed op type:" << stmt.op_type();
return;
}
} else {
// this is an arg node: only propagate if every consumer op is supported
bool all_out_op_supported = true;
for (auto& i : node->outlinks) {
if (!i->IsStmt()) return;
auto& stmt = i->AsStmt();
if (stmt.subgraph_id() < from_id) {
all_out_op_supported = false;
}
}
if (!all_out_op_supported) {
return;
}
for (auto& i : node->outlinks) {
CHECK(i->IsStmt());
auto& stmt = i->AsStmt();
if (stmt.subgraph_id() == from_id) {
stmt.SetSubgraphID(to_id);
for (auto& o : i->outlinks) {
ChangeAllOutConnectedID(o, to_id, from_id);
}
}
}
}
}
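// Walk ops in topological order; every supported op still marked 0 seeds a
// flood fill (ChangeAllOutConnectedID) with the next id. Returns the number
// of fused subgraphs.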
int SubgraphProgramPass::FuseSubgraphID(
const std::unique_ptr<SSAGraph>& graph) {
int sub_id = 1; // ids start from 1, not 0
for (auto& item : graph->StmtTopologicalOrder()) {
// bool inputvar = false;
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
/*
if (stmt.subgraph_id() == -1) {
for (auto& i : item->outlinks) {
for (auto& j : i->outlinks) {
if (j->IsStmt()) {
auto& jstmt = j->AsStmt();
if (jstmt.subgraph_id() == 0) inputvar = true;
}
}
}
}
*/
if (stmt.subgraph_id() != 0) continue;
ChangeAllOutConnectedID(item, sub_id);
sub_id++;
}
return sub_id - 1;
}
int SubgraphProgramPass::FuseSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types) {
InitSubgraphID(graph, supported_op_types);
return FuseSubgraphID(graph);
}
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(subgraph_program_pass,
paddle::lite::mir::subgraph::SubgraphProgramPass)
.BindTargets({TARGET(kAny)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
class SubgraphProgramPass : public ProgramPass {
public:
using key2nodes_t = std::map<std::string, Node*>;
// mark all linked ops in a subgraph with the same subgraph_id;
// return the number of fused subgraphs
int FuseSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types);
void Apply(const std::unique_ptr<SSAGraph>& graph) override {}
protected:
void InferOnce(const std::unique_ptr<SSAGraph>& graph);
// clear all subgraph ids and mark every op that could be fused with id zero
void InitSubgraphID(const std::unique_ptr<SSAGraph>& graph,
const std::vector<std::string>& supported_op_types);
// mark all linked ops in a subgraph with the same subgraph_id;
// return the number of fused subgraphs
int FuseSubgraphID(const std::unique_ptr<SSAGraph>& graph);
// // GenerateFusedGraph:
// std::unique_ptr<SSAGraph> GenerateFusedGraph(const
// std::unique_ptr<SSAGraph>& graph, int sub_num);
void ChangeAllOutConnectedID(Node* node, int to_id, int from_id = 0);
// The functions below could be useful in child classes
// classify nodes by subgraph id
std::unordered_map<int, std::unordered_set<Node*>> ClassifySubgraph(
const std::unique_ptr<SSAGraph>& graph);
// generate the graph op desc
cpp::OpDesc GenGraphOpDesc(const std::string& weight_var_name,
const std::vector<std::string>& in_var_names,
const std::vector<std::string>& out_var_names);
// insert a new graph op node
void InsertNewNode(const std::unique_ptr<SSAGraph>& graph,
const std::string& weight_var_name,
Scope* scope,
const std::vector<Place>& valid_places,
std::unordered_set<Node*> in_data_vars,
std::unordered_set<Node*> in_wgt_vars,
std::unordered_set<Node*> out_data_vars,
std::unordered_set<Node*> out_unused_vars);
// Sort and return the topological order of the node set
std::vector<Node*> GetTopologicalOrder(
const std::unordered_set<Node*>& nodes);
// find all input data vars, input weight vars,
// output data vars and unused output vars from the nodes
void FindInputOutputVars(const std::unordered_set<Node*>& op_nodes,
std::unordered_set<Node*>* in_data_vars,
std::unordered_set<Node*>* in_wgt_vars,
std::unordered_set<Node*>* out_data_vars,
std::unordered_set<Node*>* out_unused_vars);
// return the nodes to remove from the subgraph
std::unordered_set<const Node*> GetNode2rm(
const std::unordered_set<Node*>& op_nodes,
const std::vector<std::unordered_set<Node*>>& excluded_nodes);
private:
// sort the nodes into execution (topological) order
void SortHelper(Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret);
};
} // namespace subgraph
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <vector>
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/ssa_graph.h"
#include "lite/core/program.h"
#include "lite/model_parser/cpp/program_desc.h"
#include "lite/model_parser/model_parser.h"
DEFINE_string(model_dir, "", "model_dir");
namespace paddle {
namespace lite {
TEST(SubgraphTest, models) {
cpp::ProgramDesc program_desc;
auto scope = std::make_shared<Scope>();
// LoadModelPb(FLAGS_model_dir,
// FLAGS_model_dir + "/model",
// FLAGS_model_dir + "/params",
// scope.get(),
// &program_desc,
// true);
LoadModelPb(FLAGS_model_dir, "", "", scope.get(), &program_desc);
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
#ifdef LITE_WITH_ARM
Place{TARGET(kARM), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_NPU
Place{TARGET(kNPU), PRECISION(kFloat)},
#endif
#ifdef LITE_WITH_XPU
Place{TARGET(kXPU), PRECISION(kFloat)},
#endif
});
lite::Program program(program_desc, scope, valid_places);
auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
graph->Build(program, valid_places);
std::vector<std::string> supported_op_types{"concat",
"conv2d",
"depthwise_conv2d",
"batch_norm",
"scale",
"pool2d",
"mul",
"elementwise_add",
"softmax",
"split",
"relu",
"reshape2",
"transpose2"};
auto* pass = new mir::subgraph::SubgraphProgramPass;
ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
}
// return output_var_names
std::vector<std::string> AddFCDesc(
cpp::BlockDesc* block_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& input_var_names,
const std::vector<int64_t>& wshape) {
CHECK_EQ(input_var_names.size(), 1);
CHECK_EQ(wshape.size(), 2);
static int id = 0;
std::string prefix = "fc_" + std::to_string(id);
auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
auto* wgt = block_desc->AddVar<cpp::VarDesc>();
auto* bias = block_desc->AddVar<cpp::VarDesc>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
wgt->SetName(prefix + "_W");
bias->SetName(prefix + "_Bias");
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
auto* wtensor = scope->Var(prefix + "_W")->GetMutable<lite::Tensor>();
wtensor->Resize(wshape);
wtensor->mutable_data<float>();
auto* btensor = scope->Var(prefix + "_Bias")->GetMutable<lite::Tensor>();
btensor->Resize({wshape[1]});
btensor->mutable_data<float>();
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
op_desc->SetType("fc");
op_desc->SetInput("Input", input_var_names);
op_desc->SetInput("W", {prefix + "_W"});
op_desc->SetInput("Bias", {prefix + "_Bias"});
op_desc->SetAttr<int>("in_num_col_dims", 1);
op_desc->SetOutput("Out", out_var_names);
id++;
return out_var_names;
}
std::vector<std::string> AddElementwiseAddDesc(
cpp::BlockDesc* block_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& input_X_names,
const std::vector<std::string>& input_Y_names) {
// CHECK_EQ(input_var_names.size(), 2);
static int id = 0;
std::string prefix = "elementwise_add_" + std::to_string(id);
auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
op_desc->SetType("elementwise_add");
op_desc->SetInput("X", input_X_names);
op_desc->SetInput("Y", input_Y_names);
op_desc->SetOutput("Out", out_var_names);
op_desc->SetAttr("axis", -1);
id++;
return out_var_names;
}
std::vector<std::string> AddFeedDesc(
cpp::BlockDesc* block_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& input_X_names) {
// CHECK_EQ(input_var_names.size(), 1);
static int id = 0;
std::string prefix = "feed_" + std::to_string(id);
auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
op_desc->SetType("feed");
op_desc->SetInput("X", input_X_names);
op_desc->SetOutput("Out", out_var_names);
op_desc->SetAttr("col", 1);
id++;
return out_var_names;
}
std::vector<std::string> AddFetchDesc(
cpp::BlockDesc* block_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& input_X_names) {
// CHECK_EQ(input_var_names.size(), 1);
static int id = 0;
std::string prefix = "fetch_" + std::to_string(id);
auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
op_desc->SetType("fetch");
op_desc->SetInput("X", input_X_names);
op_desc->SetOutput("Out", out_var_names);
op_desc->SetAttr("col", 1);
id++;
return out_var_names;
}
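// Build a tiny net: feed_var -> fc_0 -> fc_1 (9 graph nodes in total). With
// only "fc" supported, both fc ops fuse into a single subgraph.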
std::unique_ptr<mir::SSAGraph> BuildSimpleNet(
cpp::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<Place>& valid_places) {
program_desc->ClearBlocks();
auto* block_desc = program_desc->AddBlock<cpp::BlockDesc>();
block_desc->ClearOps();
block_desc->ClearVars();
auto* var_desc = block_desc->AddVar<cpp::VarDesc>();
var_desc->SetName("feed_var");
auto* feed_var = scope->Var("feed_var")->GetMutable<lite::Tensor>();
feed_var->Resize({1, 4});
auto fc1_out = AddFCDesc(block_desc, scope, {"feed_var"}, {4, 5});
auto fc2_out = AddFCDesc(block_desc, scope, fc1_out, {5, 2});
lite::Program program(*program_desc, scope, valid_places);
auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
graph->Build(program, valid_places);
return graph;
}
TEST(SubGraphTest, SimpleNet) {
cpp::ProgramDesc program_desc;
std::vector<Place> places{{TARGET(kHost), PRECISION(kFloat)}};
auto scope = std::make_shared<Scope>();
auto graph = BuildSimpleNet(&program_desc, scope, places);
std::vector<std::string> supported_op_types{"fc"};
auto* pass = new mir::subgraph::SubgraphProgramPass;
ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
ASSERT_EQ(graph->nodes().size(), 9);
// LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
}
} // namespace lite
} // namespace paddle
......@@ -2,6 +2,5 @@ if(NOT LITE_WITH_BM)
return ()
endif()
add_kernel(subgraph_compute_bm BM basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} device_bm subgraph_bridge_engine ${bm_subgraph_bridges})
add_subdirectory(bridges)
add_kernel(subgraph_compute_bm BM basic SRCS subgraph_compute.cc DEPS ${lite_kernel_deps} ${bm_subgraph_bridges})
......@@ -2,11 +2,10 @@ if(NOT LITE_WITH_BM)
return()
endif()
lite_cc_library(subgraph_bridge_utility_bm SRCS utility.cc DEPS)
lite_cc_library(subgraph_bridge_graph_bm SRCS graph.cc DEPS subgraph_bridge_utility_bm)
set(bm_subgraph_bridge_deps subgraph_bridge_registry subgraph_bridge_utility_bm subgraph_bridge_graph_bm)
set(bm_subgraph_bridge_deps subgraph_bridge_registry subgraph_bridge_engine subgraph_bridge_utility_bm subgraph_bridge_graph_bm)
lite_cc_library(subgraph_bridge_act_op_bm SRCS act_op.cc DEPS ${bm_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_conv_op_bm SRCS conv_op.cc DEPS ${bm_subgraph_bridge_deps})
......@@ -17,9 +16,9 @@ lite_cc_library(subgraph_bridge_mul_op_bm SRCS mul_op.cc DEPS ${bm_subgraph_brid
lite_cc_library(subgraph_bridge_batch_norm_op_bm SRCS batch_norm_op.cc DEPS ${bm_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_scale_op_bm SRCS scale_op.cc DEPS ${bm_subgraph_bridge_deps})
set(bm_subgraph_bridges
subgraph_bridge_registry
subgraph_bridge_engine
subgraph_bridge_graph_bm
subgraph_bridge_act_op_bm
subgraph_bridge_conv_op_bm
......
......@@ -25,8 +25,8 @@ int ActConverter(void* ctx, OpLite* op, KernelBase* kernel){
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = act_op->scope();
auto op_info = act_op->op_info();
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto x_var_name = op_info->Input("X").front();
......
......@@ -32,7 +32,6 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto graph = static_cast<Graph*>(ctx);
// input
const int input_num = 2;
......@@ -107,7 +106,7 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} else {
const float* y_data = const_cast<const float*>(y->mutable_data<float>());
const float* x_data = const_cast<const float*>(x->mutable_data<float>());
bm_add_const_tensor(graph_ctx->bm_compiler_handle,
bm_add_const_tensor(graph->GetCompilerHandle(),
name[1],
shape[0],
dim[0],
......
......@@ -21,7 +21,7 @@ namespace subgraph {
namespace bm {
void Graph::AddNode(const std::string& name) {
nodes_.insert(std::make_pair(name, name);
nodes_.insert(std::make_pair(name, name));
}
void Graph::CreateCompilerHandle() {
......
......@@ -25,6 +25,7 @@ namespace bm {
int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
......
......@@ -22,7 +22,7 @@ namespace lite {
namespace subgraph {
namespace bm {
int MulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int PoolConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
......
......@@ -15,6 +15,7 @@
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_op_code.h"
#include "bmcompiler_if.h"
namespace paddle {
......
......@@ -25,11 +25,11 @@ namespace lite {
namespace subgraph {
namespace bm {
std::string UniqueName(const std::string& prefix) {};
std::string UniqueName(const std::string& prefix);
bool HasInputArg(const OpInfo* op_info,
const Scope* scope,
const std::string& argname) {};
const std::string& argname);
} // namespace bm
} // namespace subgraph
......
......@@ -22,13 +22,16 @@
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/paddle_use_bridges.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace bm {
void SubgraphCompute::PrepareForRun() {
int SubgraphEngine::BuildDeviceProgram() {
int status = 0;
subgraph::bm::Graph graph;
const auto& bridges = subgraph::Registry::Instance();
graph.CreateCompilerHandle();
......@@ -40,23 +43,42 @@ void SubgraphCompute::PrepareForRun() {
op->InferShape();
std::string op_type = op->op_info()->Type();
if (!bridges.Exists("BM", op_type)) {
LOG(FATAL) << "[BM] not support op:" << op_type;
return subgraph::FAILED;
}
auto kernel = inst.kernel();
status |= bridges.Select("BM", op_type)(reinterpret_cast<void*>(&graph),
const_cast<OpLite*>(op),
const_cast<KernelBase*>(kernel));
if (subgraph::CHECK_FAILED(status)) {
LOG(FATAL) << "[BM] subgraph CHECK_FAILED";
return subgraph::FAILED;
}
}
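// All ops have been converted through the bridges; compile the accumulated
// graph (the trailing argument is presumably an optimization level) and
// finalize the BM compiler program.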
std::string net_name = "paddle_bitmain";
__bmcompile_opt(graph.GetCompilerHandle(), const_cast<char*>(net_name.c_str()), 2);
finish_bmcompiler(graph.GetCompilerHandle());
return status;
}
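// Device-side execution is not wired up yet; this looks like a placeholder
// until the BM runtime launch path lands.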
int SubgraphEngine::LaunchDeviceProgram() {
return 0;
}
void SubgraphCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
engine_.reset(new SubgraphEngine(ctx_.get(),
param.sub_block_idx,
param.sub_block_desc,
param.input_data_names,
param.output_data_names,
param.scope));
CHECK(engine_);
engine_->Build();
}
void SubgraphCompute::Run() {
CHECK(engine_);
engine_->Launch();
}
} // namespace bm
......
......@@ -20,6 +20,8 @@
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
#include "lite/core/program.h"
#include "lite/kernels/npu/bridges/engine.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
......@@ -27,6 +29,22 @@ namespace lite {
namespace kernels {
namespace bm {
class SubgraphEngine : public subgraph::Engine {
public:
SubgraphEngine(KernelContext *ctx,
int block_idx,
cpp::BlockDesc *block_desc,
const std::vector<std::string> &input_names,
const std::vector<std::string> &output_names,
Scope *scope)
: subgraph::Engine(
ctx, block_idx, block_desc, input_names, output_names, scope) {}
protected:
int BuildDeviceProgram() override;
int LaunchDeviceProgram() override;
};
class SubgraphCompute : public KernelLite<TARGET(kBM), PRECISION(kFloat)> {
public:
using param_t = operators::SubgraphParam;
......@@ -39,6 +57,7 @@ class SubgraphCompute : public KernelLite<TARGET(kBM), PRECISION(kFloat)> {
private:
void* bm_compiler_ht_;
std::unique_ptr<SubgraphEngine> engine_;
};
} // namespace bm
......
if(NOT LITE_WITH_NPU AND NOT LITE_WITH_XPU)
if(NOT LITE_WITH_NPU AND NOT LITE_WITH_XPU AND NOT LITE_WITH_BM)
return()
endif()
lite_cc_library(subgraph_bridge_registry
SRCS registry.cc
DEPS op)
lite_cc_library(subgraph_bridge_engine
SRCS engine.cc
DEPS tensor op scope program)
......
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA AND NOT LITE_WITH_BM) AND (LITE_WITH_X86 OR LITE_WITH_ARM))
lite_cc_test(test_kernel_scale_compute SRCS scale_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_power_compute SRCS power_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shuffle_channel_compute SRCS shuffle_channel_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_yolo_box_compute SRCS yolo_box_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_fc_compute SRCS fc_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_elementwise_compute SRCS elementwise_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_fc_compute SRCS fc_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_elementwise_compute SRCS elementwise_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_lrn_compute SRCS lrn_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_decode_bboxes_compute SRCS decode_bboxes_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_box_coder_compute SRCS box_coder_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_activation_compute SRCS activation_compute_test.cc DEPS arena_framework ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_activation_compute SRCS activation_compute_test.cc DEPS arena_framework ${bm_kernels} ${npu_kernels} ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_argmax_compute SRCS argmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_axpy_compute SRCS axpy_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_conv2d_transpose_compute SRCS conv2d_transpose_compute_test.cc DEPS arena_framework ${bm_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_norm_compute SRCS norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_cast_compute SRCS cast_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_instance_norm_compute SRCS instance_norm_compute_test.cc DEPS arena_framework ${bm_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_grid_sampler_compute SRCS grid_sampler_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_sequence_softmax_compute SRCS sequence_softmax_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_im2sequence_compute SRCS im2sequence_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -25,12 +25,12 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
#lite_cc_test(test_kernel_write_to_array_compute SRCS write_to_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
#lite_cc_test(test_kernel_read_from_array_compute SRCS read_from_array_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_concat_compute SRCS concat_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_transpose_compute SRCS transpose_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reshape_compute SRCS reshape_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_layer_norm_compute SRCS layer_norm_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_dropout_compute SRCS dropout_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_softmax_compute SRCS softmax_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_mul_compute SRCS mul_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
if(LITE_BUILD_EXTRA)
lite_cc_test(test_gru_unit SRCS gru_unit_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
......@@ -53,16 +53,16 @@ if(LITE_BUILD_EXTRA)
lite_cc_test(test_kernel_lookup_table_compute SRCS lookup_table_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_gather_compute SRCS gather_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_bilinear_interp_compute SRCS bilinear_interp_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_nearest_interp_compute SRCS nearest_interp_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shape_compute SRCS shape_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_pad2d_compute SRCS pad2d_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_prior_box_compute SRCS prior_box_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_negative_compute SRCS negative_compute_test.cc DEPS arena_framework ${xpu_kernels} ${bm_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_bilinear_interp_compute SRCS bilinear_interp_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_nearest_interp_compute SRCS nearest_interp_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_shape_compute SRCS shape_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_crop_compute SRCS crop_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_slice_compute SRCS slice_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${bm_kernels} ${xpu_kernels} ${npu_kernels} ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
......@@ -70,9 +70,9 @@ function build_bm {
${CMAKE_COMMON_OPTIONS} \
-DWITH_GPU=OFF \
-DWITH_MKLDNN=OFF \
-DLITE_WITH_X86=OFF \
-DWITH_MKL=OFF \
-DLITE_BUILD_EXTRA=OFF \
-DLITE_WITH_X86=ON \
-DWITH_MKL=ON \
-DLITE_BUILD_EXTRA=ON \
-DLITE_WITH_XPU=OFF \
-DLITE_WITH_BM=ON \
-DWITH_TESTING=${WITH_TESTING} \
......