Commit 1e160622 authored by cen.li

* resnet50 run success

* test=develop
Parent c28971a9
......@@ -34,6 +34,7 @@ include_directories("${BM_SDK_ROOT}/include/bmruntime")
include_directories("${BM_SDK_ROOT}/include/bmlib")
include_directories("${BM_SDK_ROOT}/include/bmcompiler")
include_directories("${BM_SDK_ROOT}/include/bmcpu")
include_directories("${BM_SDK_ROOT}/include/bmlog")
find_library(BM_SDK_RT_LIB NAMES bmrt
PATHS ${BM_SDK_ROOT}/lib/bmnn/pcie)
......
......@@ -86,9 +86,9 @@ if (NOT LITE_ON_TINY_PUBLISH)
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels} ${npu_bridges} npu_pass
XPU_DEPS ${xpu_kernels} ${xpu_bridges} xpu_pass
BM_DEPS ${bm_kernels} ${bm_bridges} bm_pass
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
BM_DEPS ${bm_kernels})
endif()
# for light api
......@@ -107,7 +107,7 @@ lite_cc_library(light_api SRCS light_api.cc
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels})
FPGA_DEPS ${fpga_kernels}
BM_DEPS ${bm_kernels})
include(ExternalProject)
......@@ -162,7 +162,7 @@ if(WITH_TESTING)
add_dependencies(test_step_rnn_lite_x86 extern_lite_download_step_rnn_tar_gz)
lite_cc_test(test_resnet50_lite_bm SRCS test_resnet50_lite_bm.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${bm_kernels}
${ops} ${host_kernels} ${bm_kernels} ${bm_bridges}
ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
endif()
endif()
......
......@@ -141,7 +141,7 @@ std::vector<std::string> Predictor::GetOutputNames() { return output_names_; }
void Predictor::PrepareFeedFetch() {
std::vector<const cpp::OpDesc *> feeds;
std::vector<const cpp::OpDesc *> fetchs;
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU)
#if defined(LITE_WITH_NPU) || defined(LITE_WITH_XPU) || defined(LITE_WITH_BM)
// The shapes of the input tensors must be determined before generating the
// NPU, XPU, and BM programs.
auto current_block = program_desc_.GetBlock<cpp::BlockDesc>(0);
......
......@@ -55,7 +55,8 @@ const std::string& TargetToStr(TargetType target) {
"any",
"fpga",
"npu",
"xpu"};
"xpu",
"bm"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
......@@ -93,7 +94,8 @@ const std::string& TargetRepr(TargetType target) {
"kAny",
"kFPGA",
"kNPU",
"kXPU"};
"kXPU",
"kBM"};
auto x = static_cast<int>(target);
CHECK_LT(x, static_cast<int>(TARGET(NUM)));
return target2string[x];
......@@ -129,6 +131,7 @@ std::set<TargetType> ExpandValidTargets(TargetType target) {
TARGET(kOpenCL),
TARGET(kNPU),
TARGET(kXPU),
TARGET(kBM),
TARGET(kFPGA)});
if (target == TARGET(kAny)) {
return valid_set;
......
......@@ -52,8 +52,9 @@ enum class TargetType : int {
kFPGA = 7,
kNPU = 8,
kXPU = 9,
kBM = 10,
kAny = 6, // any target
NUM = 10, // number of fields.
NUM = 11, // number of fields.
};
enum class PrecisionType : int {
kUnk = 0,
......
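A minimal standalone sketch (not Paddle-Lite code) of the invariant the two string-table hunks above and the enum hunk below maintain: TargetToStr/TargetRepr index the table by the raw enum value, so every new target needs a matching string entry, and NUM must stay one past the largest enumerator for the CHECK_LT bound to actually protect the lookup.

#include <cassert>
#include <string>

enum class TargetTypeSketch : int { kHost = 1, kAny = 6, kBM = 10, NUM = 11 };

const std::string& TargetToStrSketch(TargetTypeSketch target) {
  // Index i holds the name of the target whose enum value is i.
  static const std::string target2string[] = {"unk", "host", "x86", "cuda",
                                              "arm", "opencl", "any", "fpga",
                                              "npu", "xpu", "bm"};
  auto x = static_cast<int>(target);
  assert(x < static_cast<int>(TargetTypeSketch::NUM));  // mirrors CHECK_LT
  return target2string[x];
}

int main() { assert(TargetToStrSketch(TargetTypeSketch::kBM) == "bm"); }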
......@@ -26,6 +26,9 @@ USE_MIR_PASS(generate_npu_program_pass);
#ifdef LITE_WITH_XPU
USE_MIR_PASS(generate_xpu_program_pass);
#endif
#ifdef LITE_WITH_BM
USE_MIR_PASS(generate_bm_program_pass);
#endif
USE_MIR_PASS(io_copy_kernel_pick_pass);
USE_MIR_PASS(argument_type_display_pass);
......
......@@ -3,3 +3,4 @@ if (NOT LITE_WITH_BM)
endif()
lite_cc_library(target_wrapper_bm SRCS target_wrapper.cc bm_context.cc DEPS ${bm_runtime_libs})
lite_cc_library(bm_builder SRCS builder.cc DEPS ${bm_builder_libs})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/bm/builder.h"
#include <algorithm>  // std::find, used by HasInputArg below
#include <mutex>
#include <utility>
namespace paddle {
namespace lite {
namespace bm {
std::string UniqueName(const std::string& prefix) {
static std::mutex counter_mtx;
static std::unordered_map<std::string, int> counter_map;
std::unique_lock<std::mutex> counter_lck(counter_mtx);
int counter = 1;
auto it = counter_map.find(prefix);
if (it == counter_map.end()) {
counter_map[prefix] = counter;
} else {
counter = ++(it->second);
}
return prefix + "_" + std::to_string(counter);
}
bool HasInputArg(const OpInfo* op_info,
const Scope* scope,
const std::string& argname) {
auto iarg_names = op_info->input_argnames();
if (std::find(iarg_names.begin(), iarg_names.end(), argname) !=
iarg_names.end()) {
auto inputs = op_info->Input(argname);
if (inputs.empty()) {
return false;
}
auto var_name = inputs.front();
auto var = scope->FindVar(var_name);
return var != nullptr;
} else {
return false;
}
}
} // namespace bm
} // namespace lite
} // namespace paddle
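A standalone replica of the UniqueName counter logic above, for illustration only; it shows the per-prefix numbering the bridges rely on for layer names.

#include <cassert>
#include <mutex>
#include <string>
#include <unordered_map>

std::string UniqueNameSketch(const std::string& prefix) {
  static std::mutex mtx;
  static std::unordered_map<std::string, int> counters;
  std::lock_guard<std::mutex> lock(mtx);
  // operator[] default-constructs the counter to 0, so the first call yields 1.
  return prefix + "_" + std::to_string(++counters[prefix]);
}

int main() {
  assert(UniqueNameSketch("conv2d") == "conv2d_1");
  assert(UniqueNameSketch("conv2d") == "conv2d_2");  // same prefix advances
  assert(UniqueNameSketch("pool2d") == "pool2d_1");  // prefixes are independent
}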
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace bm {
std::string UniqueName(const std::string& prefix);
bool HasInputArg(const OpInfo* op_info, const Scope* scope, const std::string& argname);
} // namespace bm
} // namespace lite
} // namespace paddle
......@@ -47,5 +47,5 @@ void ConvActivationFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_activation_fuse_pass,
paddle::lite::mir::ConvActivationFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kXPU)})
.ExcludeTargets({TARGET(kXPU), TARGET(kBM)})
.BindKernel("conv2d");
......@@ -45,4 +45,4 @@ void ConvBNFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_bn_fuse_pass, paddle::lite::mir::ConvBNFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kX86), TARGET(kXPU)});
.ExcludeTargets({TARGET(kX86), TARGET(kXPU), TARGET(kBM)});
......@@ -47,4 +47,4 @@ void ConvElementwiseFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_conv_elementwise_fuse_pass,
paddle::lite::mir::ConvElementwiseFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kXPU)});
.ExcludeTargets({TARGET(kXPU), TARGET(kBM)});
......@@ -35,5 +35,5 @@ void ElementwiseAddActivationFusePass::Apply(
REGISTER_MIR_PASS(lite_elementwise_add_activation_fuse_pass,
paddle::lite::mir::ElementwiseAddActivationFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kXPU)})
.ExcludeTargets({TARGET(kXPU), TARGET(kBM)})
.BindKernel("fusion_elementwise_add_activation");
......@@ -33,5 +33,5 @@ void FcFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(lite_fc_fuse_pass, paddle::lite::mir::FcFusePass)
.BindTargets({TARGET(kAny)})
.ExcludeTargets({TARGET(kXPU)})
.ExcludeTargets({TARGET(kXPU), TARGET(kBM)})
.BindKernel("fc");
......@@ -256,4 +256,4 @@ void MemoryOptimizePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
REGISTER_MIR_PASS(memory_optimize_pass, paddle::lite::mir::MemoryOptimizePass)
.BindTargets({TARGET(kARM)})
.ExcludeTargets({TARGET(kOpenCL), TARGET(kNPU), TARGET(kXPU)});
.ExcludeTargets({TARGET(kOpenCL), TARGET(kNPU), TARGET(kXPU), TARGET(kBM)});
......@@ -33,7 +33,6 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
kernel_pick_factors_.ConsiderTarget();
kernel_pick_factors_.ConsiderPrecision();
kernel_pick_factors_.ConsiderDataLayout();
CHECK(kernel_pick_factors_.any_factor_considered())
<< "kernel_pick_factors should be specified first";
CHECK(graph) << "graph not valid";
......@@ -50,7 +49,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
<< instruct.op_type();
VLOG(4) << "instruct.kernels().size():" << instruct.kernels().size();
for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places());
float score = KernelGrade(instruct, *kernel, graph->valid_places());
VLOG(4) << "kernel->summary():" << kernel->summary()
<< " score:" << score;
scored.emplace_back(score, std::move(kernel));
......@@ -100,7 +99,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
instruct.ResetOp(update_desc, graph->valid_places());
scored.clear();
for (auto&& kernel : instruct.kernels()) {
float score = KernelGrade(*kernel, graph->valid_places());
float score = KernelGrade(instruct, *kernel, graph->valid_places());
scored.emplace_back(score, std::move(kernel));
}
std::sort(scored.begin(), scored.end(), KernelScoreCmp);
......@@ -115,6 +114,7 @@ void StaticKernelPickPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
bool all_output_type_match = true;
auto expect_output_type =
out_type_int8 ? PRECISION(kInt8) : PRECISION(kFloat);
for (auto& arg_name : output_arguments) {
const Type* out_arg_ty =
candidate.second->GetOutputDeclType(arg_name);
......
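The KernelGrade signature change above threads the instruction into scoring; the surrounding logic is a plain score-and-sort pick. A standalone sketch of that pattern (GradeSketch is a stand-in, not the real KernelGrade, which also weighs the target, precision, and layout factors considered above):

#include <algorithm>
#include <string>
#include <utility>
#include <vector>

float GradeSketch(const std::string& kernel) {
  return static_cast<float>(kernel.size());  // placeholder scoring rule
}

std::string PickBestKernel(std::vector<std::string> candidates) {
  // candidates is assumed non-empty, as the pass CHECKs upstream.
  std::vector<std::pair<float, std::string>> scored;
  for (auto& k : candidates) scored.emplace_back(GradeSketch(k), std::move(k));
  // Highest score first, mirroring std::sort(..., KernelScoreCmp) above.
  std::sort(scored.begin(), scored.end(),
            [](const auto& a, const auto& b) { return a.first > b.first; });
  return scored.front().second;
}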
......@@ -46,5 +46,11 @@ if(LITE_WITH_XPU)
endif()
endif()
if(LITE_WITH_BM)
lite_cc_library(bm_pass SRCS generate_bm_program_pass.cc
DEPS mir_pass types context ${mir_fusers} ${bm_bridges} ${bm_builder_libs} graph_op subgraph_pass)
list(APPEND subgraph_passes bm_pass)
endif()
set(subgraph_passes ${subgraph_passes} CACHE INTERNAL "subgraph_passes")
message(STATUS "----> subgraph_passes: ${subgraph_passes}")
......@@ -22,57 +22,39 @@
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "lite/kernels/bm/bridges/paddle_use_bm_bridges.h"
#include "lite/kernels/bm/bridges/registry.h"
#include "bmcompiler_if.h"
#include "bmlog.hpp"
namespace paddle {
namespace lite {
namespace mir {
namespace subgraph {
std::shared_ptr<ge::Operator> GenerateBMProgramPass::CvtVarNode(
std::shared_ptr<void*> GenerateBMProgramPass::CvtVarNode(
lite::mir::Node* var_node, const Scope* scope) {
CHECK(var_node->IsArg());
const auto& arg = var_node->AsArg();
VLOG(4) << "Convert var node " << arg.name;
auto* var = scope->FindVar(arg.name);
CHECK(var);
auto* tensor = var->GetMutable<lite::Tensor>();
CHECK(tensor);
auto dims = tensor->dims();
if (arg.is_weight) {
auto wgt = std::make_shared<ge::op::Const>(arg.name);
LOG(INFO) << " Convert const var node " << arg.name;
VLOG(4) << dims;
wgt->set_attr_value(lite::npu::CvtTensor(tensor));
return wgt;
} else {
CHECK_EQ(dims.size(), 4);
LOG(INFO) << "[NPU] Convert data var node " << arg.name;
LOG(INFO) << dims;
// TODO(xxx): support more types and dims size
ge::TensorDesc desc(ge::Shape(dims.Vectorize()),
ge::Format::FORMAT_NCHW,
ge::DataType::DT_FLOAT);
// auto size = desc.GetShape().GetShapeSize();
// ge::TensorUtils::SetSize(desc, size*sizeof(float));
// ge::TensorUtils::SetRealDimCnt(desc, 4);
auto data = std::make_shared<ge::op::Data>(arg.name);
data->update_input_desc_x(desc);
return data;
}
return nullptr;
}
void GenerateNPUProgramPass::CvtAllOpNodes(
void GenerateBMProgramPass::CvtAllOpNodes(
const std::vector<Node*>& nodes2cvt,
lite::kernels::npu::bridges::node_map_type* converted_vars) {
const auto& bridges = lite::kernels::npu::bridges::Factory::Instance();
lite::kernels::bm::bridges::node_map_type* converted_vars) {
const auto& bridges = lite::kernels::bm::bridges::Factory::Instance();
const auto& cvtfunc_map = bridges.AllFunctions();
// Record every converted var; an op node's inputs must already be present
// in converted_vars.
lite::kernels::bm::bridges::graph_ctx_type ctx;
ctx.bm_compiler_handle = create_bmcompiler("BM1684");
CHECK(ctx.bm_compiler_handle != nullptr);
//bmlog::init("paddle_bitmain");
//bmlog::set_v(3);
for (auto& node : nodes2cvt) {
lite::kernels::npu::bridges::node_map_type node_inputs;
lite::kernels::bm::bridges::node_map_type node_inputs;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
// weight should be handled in the converter, so skip here
......@@ -81,61 +63,25 @@ void GenerateNPUProgramPass::CvtAllOpNodes(
}
auto var_name = arg.name;
if (!converted_vars->count(var_name)) {
converted_vars->insert(
std::make_pair(var_name, CvtVarNode(var_node, stmt.op()->scope())));
converted_vars->insert(std::make_pair(var_name, var_name));
}
node_inputs.insert(*converted_vars->find(var_name));
}
auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs);
converted_vars->insert(node_outputs.begin(), node_outputs.end());
}
}
std::string GenerateNPUProgramPass::BuildNPUGraph(
const std::unordered_set<Node*>& op_nodes,
const std::unordered_set<Node*>& in_data_vars,
const std::unordered_set<Node*>& out_data_vars,
int sub_id) {
auto ordered_nodes = GetTopologicalOrder(op_nodes);
lite::kernels::npu::bridges::node_map_type converted_vars;
CvtAllOpNodes(ordered_nodes, &converted_vars);
std::vector<std::string> in_var_names;
std::vector<std::string> out_var_names;
std::vector<ge::Operator> inputs;
std::vector<ge::Operator> outputs;
for (auto i : in_data_vars) {
auto argname = i->AsArg().name;
in_var_names.push_back(argname);
inputs.push_back(*converted_vars.at(argname));
}
for (auto i : out_data_vars) {
auto argname = i->AsArg().name;
out_var_names.push_back(argname);
outputs.push_back(*converted_vars.at(argname));
auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), &ctx, node_inputs);
converted_vars->insert(node_outputs.begin(), node_outputs.end());
}
std::string weight_var_name = "graph" + std::to_string(sub_id) + "_weights";
auto any_op = (*op_nodes.begin())->AsStmt().op();
auto weight = any_op->scope()->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
// Compile the IR graph into an NPU model and store the model data in a weight
// tensor with persistable=true, so that the model parser can recognize it and
// save it to the param files.
if (!lite::npu::BuildModel(inputs, outputs, weight)) {
LOG(WARNING) << "[NPU] Build NPU graph failed (subgraph=" << sub_id << ")";
throw std::runtime_error("Build NPU graph failed.");
}
LOG(INFO) << "[NPU] Build NPU graph success (subgraph=" << sub_id << ")";
return weight_var_name;
std::string net_name = "paddle_bitmain";
__bmcompile_opt(ctx.bm_compiler_handle, const_cast<char*>(net_name.c_str()), 2);
finish_bmcompiler(ctx.bm_compiler_handle);
}
void GenerateBMProgramPass::GenSubgraph(
const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id) {
#if 0
std::unordered_set<Node*> in_data_vars;
std::unordered_set<Node*> in_wgt_vars;
std::unordered_set<Node*> out_data_vars;
......@@ -143,27 +89,31 @@ void GenerateBMProgramPass::GenSubgraph(
FindInputOutputVars(
op_nodes, &in_data_vars, &in_wgt_vars, &out_data_vars, &out_unused_vars);
auto weight_var_name =
BuildNPUGraph(op_nodes, in_data_vars, out_data_vars, sub_id);
auto any_op = (*op_nodes.begin())->AsStmt().op();
InsertNewNode(graph,
weight_var_name,
any_op->scope(),
any_op->valid_places(),
in_data_vars,
in_wgt_vars,
out_data_vars,
out_unused_vars);
auto nodes2rm = GetNode2rm(
op_nodes, {in_data_vars, in_wgt_vars, out_data_vars, out_unused_vars});
GraphSafeRemoveNodes(graph.get(), nodes2rm);
#endif
auto ordered_nodes = GetTopologicalOrder(op_nodes);
lite::kernels::bm::bridges::node_map_type converted_vars;
CvtAllOpNodes(ordered_nodes, &converted_vars);
}
void GenerateBMProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
const auto& bridges = lite::kernels::bm::bridges::Factory::Instance();
const auto& op_map = bridges.AllFunctions();
std::vector<std::string> supported_op_types;
for (auto& i : op_map) {
//LOG(INFO) << "[BM] Supported type: " << i.first;
supported_op_types.push_back(i.first);
}
int num_subgraph = FuseSubgraph(graph, supported_op_types);
InferOnce(graph);
auto op_nodes_all = ClassifySubgraph(graph);
CHECK_EQ(op_nodes_all.size(), num_subgraph);
int id = 1;
for (auto& op_nodes : op_nodes_all) {
//LOG(INFO) << "[BM] Converting Subgraph " << id;
GenSubgraph(graph, op_nodes.second, id);
id++;
}
}
......
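The pass drives conversion through the bridge registry: look up a converter by op type, invoke it over the topologically ordered nodes, and merge the returned vars, as CvtAllOpNodes does above. A minimal standalone sketch of that dispatch pattern (all names illustrative, not the real bridge API):

#include <functional>
#include <stdexcept>
#include <string>
#include <unordered_map>
#include <vector>

using NodeMap = std::unordered_map<std::string, std::string>;
using ConvertFn = std::function<NodeMap(const std::string&, const NodeMap&)>;

std::unordered_map<std::string, ConvertFn>& RegistrySketch() {
  static std::unordered_map<std::string, ConvertFn> r;  // op type -> converter
  return r;
}

void ConvertAllSketch(const std::vector<std::string>& ordered_op_types,
                      NodeMap* converted) {
  for (const auto& op_type : ordered_op_types) {
    auto it = RegistrySketch().find(op_type);
    if (it == RegistrySketch().end())
      throw std::runtime_error("no bridge registered for " + op_type);
    NodeMap outs = it->second(op_type, *converted);
    converted->insert(outs.begin(), outs.end());  // accumulate converted vars
  }
}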
......@@ -24,6 +24,8 @@
#include "lite/core/mir/pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/kernels/bm/bridges/registry.h"
namespace paddle {
namespace lite {
namespace mir {
......@@ -40,9 +42,9 @@ class GenerateBMProgramPass : public SubgraphProgramPass {
// nodes2cvt: op nodes to convert
// return cvted_vars: converted var nodes
void CvtAllOpNodes(const std::vector<Node*>& nodes2cvt,
lite::kernels::npu::bridges::node_map_type* cvted_vars);
lite::kernels::bm::bridges::node_map_type* cvted_vars);
std::shared_ptr<ge::Operator> CvtVarNode(lite::mir::Node* var_node,
std::shared_ptr<void*> CvtVarNode(lite::mir::Node* var_node,
const Scope* scope);
std::string BuildGraph(const std::unordered_set<Node*>& op_nodes,
......@@ -50,6 +52,9 @@ class GenerateBMProgramPass : public SubgraphProgramPass {
const std::unordered_set<Node*>& out_data_vars,
int sub_id);
void GenSubgraph(const std::unique_ptr<SSAGraph>& graph,
const std::unordered_set<Node*>& op_nodes,
int sub_id);
private:
std::vector<Instruction> insts_;
};
......
......@@ -33,6 +33,9 @@
#ifdef LITE_WITH_XPU
#include "lite/core/mir/subgraph/generate_xpu_program_pass.h"
#endif
#ifdef LITE_WITH_BM
#include "lite/core/mir/subgraph/generate_bm_program_pass.h"
#endif
namespace paddle {
namespace lite {
......@@ -59,7 +62,8 @@ class Optimizer {
SpecifyKernelPickTactic(kernel_pick_factor);
InitTargetTypeTransformPass();
if (passes.empty()) {
//if (passes.empty()) {
if (0) {
std::vector<std::string> passes_local{
{"lite_quant_dequant_fuse_pass", //
"lite_conv_elementwise_fuse_pass", // conv-elemwise-bn
......@@ -125,7 +129,9 @@ class Optimizer {
// of input tensors. so GenRuntimeProgram() must be called after the shapes
// of input tensors are determined.
std::vector<std::string> subgraph_passes{"generate_npu_program_pass",
"generate_xpu_program_pass"};
"generate_xpu_program_pass",
"generate_bm_program_pass"};
RunPasses(subgraph_passes);
auto pass = mir::PassManager::Global().LookUp<mir::GenerateProgramPass>(
......
lite_cc_library(bm_bridge_registry SRCS registry.cc)
set(bm_bridge_deps bm_bridge_registry op)
set(bm_bridge_deps bm_bridge_registry bm_builder op)
lite_cc_library(bm_bridge_act_op SRCS act_op.cc DEPS ${bm_bridge_deps})
lite_cc_library(bm_bridge_conv_op SRCS conv_op.cc DEPS ${bm_bridge_deps})
......@@ -9,6 +9,7 @@ lite_cc_library(bm_bridge_pool_op SRCS pool_op.cc DEPS ${bm_bridge_deps})
lite_cc_library(bm_bridge_softmax_op SRCS softmax_op.cc DEPS ${bm_bridge_deps})
lite_cc_library(bm_bridge_mul_op SRCS mul_op.cc DEPS ${bm_bridge_deps})
lite_cc_library(bm_bridge_batch_norm_op SRCS batch_norm_op.cc DEPS ${bm_bridge_deps})
lite_cc_library(bm_bridge_scale_op SRCS scale_op.cc DEPS ${bm_bridge_deps})
set(bm_bridges
bm_bridge_registry
......@@ -19,5 +20,6 @@ set(bm_bridges
bm_bridge_softmax_op
bm_bridge_mul_op
bm_bridge_batch_norm_op
bm_bridge_scale_op
CACHE INTERNAL "bm_bridges")
......@@ -13,6 +13,7 @@
// limitations under the License.
#include "lite/kernels/bm/bridges/registry.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -20,10 +21,49 @@ namespace kernels {
namespace bm {
namespace bridges {
node_map_type ActConverter(const std::shared_ptr<lite::OpLite> op,
node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
// output converted nodes
node_map_type output_nodes;
auto scope = act_op->scope();
auto op_info = act_op->op_info();
auto op_type = op_info->Type();
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
CHECK(op_type == "relu");
add_relu_layer(graph_ctx->bm_compiler_handle,
const_cast<const int*>(i_x_shape_data),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
0.f,
-1.f);
output_nodes[output_var_name] = output_var_name;
return output_nodes;
}
......
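Every converter in this patch repeats the same int64-to-int shape copy into a variable-length array (a GNU extension, not standard C++). A hypothetical helper, not part of the patch, that would centralize the conversion:

#include <cstdint>
#include <vector>

// Copy a DDim-style int64 shape into the int buffer the bmcompiler C API
// expects; the returned vector owns its storage, unlike the stack VLAs above.
std::vector<int> ToInt32Shape(const std::vector<int64_t>& dims) {
  std::vector<int> out(dims.size());
  for (size_t i = 0; i < dims.size(); ++i) out[i] = static_cast<int>(dims[i]);
  return out;
}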
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "lite/kernels/bm/bridges/registry.h"
#include "lite/backends/bm/builder.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -20,10 +22,97 @@ namespace kernels {
namespace bm {
namespace bridges {
node_map_type BatchNormConverter(const std::shared_ptr<lite::OpLite> op,
node_map_type BatchNormConverter(const std::shared_ptr<lite::OpLite> bn_op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
// output converted nodes
node_map_type output_nodes;
auto scope = bn_op->scope();
auto op_info = bn_op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::bm::UniqueName(op_type);
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
auto scale_var_name = op_info->Input("Scale").front();
auto scale = scope->FindVar(scale_var_name)->GetMutable<lite::Tensor>();
auto bias_var_name = op_info->Input("Bias").front();
auto bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto mean_var_name = op_info->Input("Mean").front();
auto mean = scope->FindVar(mean_var_name)->GetMutable<lite::Tensor>();
auto variance_var_name = op_info->Input("Variance").front();
auto variance = scope->FindVar(variance_var_name)->GetMutable<lite::Tensor>();
// output
auto output_var_name = op_info->Output("Y").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
auto epsilon = op_info->GetAttr<float>("epsilon");
auto unique_bn_out_name = lite::bm::UniqueName("batch_norm_out");
add_batchnorm_layer(graph_ctx->bm_compiler_handle,
const_cast<const int*>(i_x_shape_data),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(unique_bn_out_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
static_cast<const float*>(mean->mutable_data<float>()),
static_cast<const float*>(variance->mutable_data<float>()),
1.f,
epsilon,
0,
1);
const int input_num = 1;
int **shape = new int *[input_num];
int *dim = new int[input_num];
const char **name = new const char *[input_num];
name[0] = static_cast<const char*>(unique_bn_out_name.c_str());
dim[0] = output_dims.size();
shape[0] = i_output_shape_data;
auto unique_scale_name = lite::bm::UniqueName("scale");
add_scale_layer(graph_ctx->bm_compiler_handle,
input_num,
shape,
dim,
name,
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_scale_name.c_str()),
static_cast<const float*>(scale->mutable_data<float>()),
static_cast<const float*>(bias->mutable_data<float>()),
1,
1,
0);
delete [] shape;
delete [] name;
delete [] dim;
output_nodes[output_var_name] = output_var_name;
return output_nodes;
}
......
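The two calls above decompose batch norm in the standard way: add_batchnorm_layer normalizes with the running statistics (unit scale factor, the op's epsilon attribute), and add_scale_layer then applies the learned affine pair,

\hat{x} = \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}}, \qquad y = \gamma\,\hat{x} + \beta,

which composes to the usual y = \gamma\,(x - \mu)/\sqrt{\sigma^2 + \epsilon} + \beta.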
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "lite/kernels/bm/bridges/registry.h"
#include "lite/backends/bm/builder.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -20,10 +22,84 @@ namespace kernels {
namespace bm {
namespace bridges {
node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> op,
node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
// output converted nodes
node_map_type output_nodes;
auto scope = conv_op->scope();
auto op_info = conv_op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::bm::UniqueName(op_type);
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_dims = input->dims();
auto output_var_name = op_info->Output("Output").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
auto filter_dims = filter->dims();
CHECK(input_dims.size() == 4);
CHECK(output_dims.size() == 4);
CHECK(filter_dims.size() == 4);
bool has_bias = lite::bm::HasInputArg(op_info, scope, "Bias");
float* bias_data = nullptr;
if (has_bias) {
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
bias_data = static_cast<float*>(bias->mutable_data<float>());
}
const long int* input_shape_data = const_cast<const long int*>(&input_dims.data()[0]);
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_input_shape_data[input_dims.size()];
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < input_dims.size(); i++) {
i_input_shape_data[i] = static_cast<int>(input_shape_data[i]);
}
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
const float* filter_data = const_cast<const float*>(filter->mutable_data<float>());
auto groups = op_info->GetAttr<int>("groups");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
add_conv_layer(graph_ctx->bm_compiler_handle,
const_cast<const int*>(i_input_shape_data),
input_dims.size(),
static_cast<const char*>(input_var_name.c_str()),
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
filter_data,
bias_data,
filter_dims.data()[2],
filter_dims.data()[3],
groups,
paddings[0],
paddings[0],
paddings[1],
paddings[1],
strides[0],
strides[1],
dilations[0],
dilations[1],
static_cast<int>(has_bias));
output_nodes[output_var_name] = output_var_name;
return output_nodes;
}
......
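For reference, with the attributes passed above (symmetric padding, since paddings[0] and paddings[1] are reused for both sides of each axis) the expected spatial output follows the standard dilated-convolution formula,

H_{out} = \left\lfloor \frac{H_{in} + 2 p_h - d_h (k_h - 1) - 1}{s_h} \right\rfloor + 1,

and likewise for W_{out}; the CHECKs above pin all three tensors to 4-D NCHW.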
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "lite/kernels/bm/bridges/registry.h"
#include "bmcompiler_if.h"
#include "bmcompiler_if_lite.h"
namespace paddle {
namespace lite {
......@@ -20,10 +22,117 @@ namespace kernels {
namespace bm {
namespace bridges {
node_map_type ElementwiseConverter(const std::shared_ptr<lite::OpLite> op,
node_map_type ElementwiseConverter(const std::shared_ptr<lite::OpLite> elementwise_op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
// output converted nodes
node_map_type output_nodes;
auto scope = elementwise_op->scope();
auto op_info = elementwise_op->op_info();
auto op_type = op_info->Type();
// input
const int input_num = 2;
int **shape = new int *[input_num];
int *dim = new int[input_num];
const char **name = new const char *[input_num];
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
name[0] = static_cast<const char*>(x_var_name.c_str());
dim[0] = x_dims.size();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
shape[0] = i_x_shape_data;
auto y_var_name = op_info->Input("Y").front();
auto y = scope->FindVar(y_var_name)->GetMutable<lite::Tensor>();
auto y_dims = y->dims();
name[1] = static_cast<const char*>(y_var_name.c_str());
dim[1] = y_dims.size();
const long int* y_shape_data = const_cast<const long int*>(&y_dims.data()[0]);
int i_y_shape_data[y_dims.size()];
for (size_t i = 0; i < y_dims.size(); i++) {
i_y_shape_data[i] = static_cast<int>(y_shape_data[i]);
}
shape[1] = i_y_shape_data;
bool y_is_const = input_nodes.find(y_var_name) == input_nodes.end();
// output
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
if (y_is_const) {
CHECK(op_type == "elementwise_add");
}
int op_code{-1};
float coeff[2] = {1.f, 1.f};
if (op_type == "elementwise_mul") {
op_code = 0;
} else if (op_type == "elementwise_add") {
op_code = 1;
} else if (op_type == "elementwise_sub") {
op_code = 1;
coeff[1] = -1.f;
} else {
LOG(FATAL) << "UNSUPPORTED ELTWISE OPERATION: " << op_type;
}
if (!y_is_const) {
add_eltwise_layer(graph_ctx->bm_compiler_handle,
input_num,
shape,
dim,
name,
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
op_code,
coeff);
} else {
const float* y_data = const_cast<const float*>(y->mutable_data<float>());
// Register the constant operand Y with the compiler, then combine X with it
// in a binary layer (the trailing 0 is assumed to select addition).
bm_add_const_tensor(graph_ctx->bm_compiler_handle,
name[1],
shape[1],
dim[1],
static_cast<bm_data_type_t>(0),
static_cast<const void*>(y_data));
add_binary_layer_v2(graph_ctx->bm_compiler_handle,
name[0],
shape[0],
dim[0],
0,
nullptr,
name[1],
shape[1],
dim[1],
0,
nullptr,
static_cast<const char*>(output_var_name.c_str()),
0);
}
delete [] shape;
delete [] name;
delete [] dim;
output_nodes[output_var_name] = output_var_name;
return output_nodes;
}
......@@ -33,4 +142,4 @@ node_map_type ElementwiseConverter(const std::shared_ptr<lite::OpLite> op,
} // namespace lite
} // namespace paddle
REGISTER_BM_BRIDGE(elementwise, paddle::lite::kernels::bm::bridges::ElementwiseConverter);
REGISTER_BM_BRIDGE(elementwise_add, paddle::lite::kernels::bm::bridges::ElementwiseConverter);
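Assuming the eltwise primitive computes a coefficient-weighted sum, which is what the coeff handling above implies, subtraction reuses the additive op_code with a flipped sign while elementwise_mul takes op_code 0:

\text{out} = c_0\,x + c_1\,y, \qquad (c_0, c_1) = (1, 1)\ \text{for add}, \quad (1, -1)\ \text{for sub}.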
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "lite/kernels/bm/bridges/registry.h"
#include "lite/backends/bm/builder.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -20,10 +22,76 @@ namespace kernels {
namespace bm {
namespace bridges {
node_map_type MulConverter(const std::shared_ptr<lite::OpLite> op,
node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
// output converted nodes
node_map_type output_nodes;
auto scope = mul_op->scope();
auto op_info = mul_op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::bm::UniqueName(op_type);
// Only the case where Y is a constant (weight) tensor is supported.
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// Flatten X to 2-D before the FC layer: keep the batch dim and fold all
// remaining dims together (mul's default x_num_col_dims == 1).
int i_x_reshape_shape_data[2];
i_x_reshape_shape_data[0] = static_cast<int>(x_shape_data[0]);
i_x_reshape_shape_data[1] = 1;
for (size_t i = 1; i < x_dims.size(); i++) {
i_x_reshape_shape_data[1] *= static_cast<int>(x_shape_data[i]);
}
int reshape_param[] = {0, -1};
auto unique_op_reshape_name = lite::bm::UniqueName(op_type + "_reshape");
add_reshape_layer(graph_ctx->bm_compiler_handle,
const_cast<const int*>(i_x_shape_data),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_x_reshape_shape_data),
2,
static_cast<const char*>(unique_op_reshape_name.c_str()),
const_cast<const int*>(reshape_param));
auto y_var_name = op_info->Input("Y").front();
auto y = scope->FindVar(y_var_name)->GetMutable<lite::Tensor>();
auto y_dims = y->dims();
// output
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
add_fc_layer(graph_ctx->bm_compiler_handle,
const_cast<const int*>(i_x_reshape_shape_data),
2,
static_cast<const char*>(unique_op_reshape_name.c_str()),
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
i_x_reshape_shape_data[1],
i_output_shape_data[1],
static_cast<const float*>(y->mutable_data<float>()),
nullptr,
0,
0);
output_nodes[output_var_name] = output_var_name;
return output_nodes;
}
......
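MulConverter lowers mul as flatten-then-FC, and the reshape spec {0, -1} is read here as "0 keeps the corresponding input dim, -1 absorbs the rest". A hypothetical re-implementation of that rule under this assumption:

#include <cassert>
#include <cstdint>
#include <vector>

std::vector<int64_t> InferReshape(const std::vector<int64_t>& in,
                                  const std::vector<int>& spec) {
  int64_t total = 1;
  for (int64_t d : in) total *= d;
  std::vector<int64_t> out(spec.size());
  int64_t known = 1;
  int infer_at = -1;
  for (size_t i = 0; i < spec.size(); ++i) {
    if (spec[i] == 0) {
      out[i] = in[i];  // 0: copy the input dimension
    } else if (spec[i] == -1) {
      infer_at = static_cast<int>(i);  // -1: fill in later
      continue;
    } else {
      out[i] = spec[i];
    }
    known *= out[i];
  }
  if (infer_at >= 0) out[infer_at] = total / known;
  return out;
}

int main() {
  // A 4-D activation flattened to 2-D before the FC layer, as in MulConverter.
  assert((InferReshape({8, 64, 7, 7}, {0, -1}) ==
          std::vector<int64_t>{8, 64 * 7 * 7}));
}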
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/kernels/bm/bridges/registry.h"
USE_BM_BRIDGE(relu);
USE_BM_BRIDGE(conv2d);
USE_BM_BRIDGE(elementwise_add);
USE_BM_BRIDGE(pool2d);
USE_BM_BRIDGE(softmax);
USE_BM_BRIDGE(mul);
USE_BM_BRIDGE(batch_norm);
USE_BM_BRIDGE(scale);
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "lite/kernels/bm/bridges/registry.h"
#include "lite/backends/bm/builder.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -20,10 +22,80 @@ namespace kernels {
namespace bm {
namespace bridges {
node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> op,
node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> pool_op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
// output converted nodes
node_map_type output_nodes;
auto scope = pool_op->scope();
auto op_info = pool_op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::bm::UniqueName(op_type);
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// output
int *shape[1];
int dim[1];
const char *name[1];
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
shape[0] = i_output_shape_data;
name[0] = static_cast<const char*>(output_var_name.c_str());
dim[0] = output_dims.size();
auto pooling_type = op_info->GetAttr<std::string>("pooling_type");
CHECK(pooling_type == "max" || pooling_type == "avg");
auto ksize = op_info->GetAttr<std::vector<int>>("ksize");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto global_pooling = op_info->GetAttr<bool>("global_pooling");
auto ceil_mode = op_info->GetAttr<bool>("ceil_mode");
bool average_exclusive = false;
if (pooling_type == "avg") {
average_exclusive = op_info->GetAttr<bool>("exclusive");
}
add_pooling_layer(graph_ctx->bm_compiler_handle,
const_cast<const int*>(i_x_shape_data),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
1,
shape,
dim,
name,
ksize[0],
ksize[1],
paddings[0],
paddings[0],
paddings[1],
paddings[1],
strides[0],
strides[1],
// Pooling-mode flag (assumed: 0 = max, 1 = average); a 1x1 max pool is
// routed to average, since the two coincide for a single element.
(ksize[0] > 1 && ksize[1] > 1) && pooling_type == "max" ? 0 : 1,
static_cast<int>(average_exclusive),
static_cast<int>(global_pooling),
static_cast<int>(ceil_mode),
static_cast<const char*>(unique_op_name.c_str()),
nullptr);
output_nodes[output_var_name] = output_var_name;
return output_nodes;
}
......
......@@ -28,11 +28,17 @@ namespace kernels {
namespace bm {
namespace bridges {
class graph_ctx_type {
public:
void* bm_compiler_handle{nullptr};
};
// Maps a var name to its BM node name.
using node_map_type =
std::unordered_map<std::string, std::shared_ptr<void*>>;
std::unordered_map<std::string, std::string>;
using func_type = std::function<node_map_type(const std::shared_ptr<OpLite>,
graph_ctx_type*,
const node_map_type&)>;
using cvt_map_type = std::unordered_map<std::string, func_type>;
class Factory {
......
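REGISTER_BM_BRIDGE and USE_BM_BRIDGE (used in paddle_use_bm_bridges.h above) are not defined in this hunk; they presumably follow the usual static-registration pattern, sketched here with illustrative names rather than the macros' actual definitions:

#include <functional>
#include <string>
#include <unordered_map>

using BridgeFnSketch = std::function<void()>;  // stands in for func_type

std::unordered_map<std::string, BridgeFnSketch>& BridgeMapSketch() {
  static std::unordered_map<std::string, BridgeFnSketch> m;
  return m;
}

// A static Registrar runs at load time, so merely linking a bridge's object
// file makes its converter visible to the subgraph pass.
struct RegistrarSketch {
  RegistrarSketch(const std::string& op_type, BridgeFnSketch fn) {
    BridgeMapSketch()[op_type] = fn;
  }
};

#define REGISTER_BRIDGE_SKETCH(op, fn) \
  static RegistrarSketch registrar_##op(#op, fn)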
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "lite/kernels/bm/bridges/registry.h"
#include "lite/backends/bm/builder.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -20,10 +22,72 @@ namespace kernels {
namespace bm {
namespace bridges {
node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> op,
node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
// output converted nodes
node_map_type output_nodes;
auto scope = scale_op->scope();
auto op_info = scale_op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::bm::UniqueName(op_type);
// input
const int input_num = 1;
int **shape = new int *[input_num];
int *dim = new int[input_num];
const char **name = new const char *[input_num];
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
name[0] = static_cast<const char*>(x_var_name.c_str());
dim[0] = x_dims.size();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
shape[0] = i_x_shape_data;
// output
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
auto scale = op_info->GetAttr<float>("scale");
auto bias = op_info->GetAttr<float>("bias");
auto bias_after_scale = op_info->GetAttr<bool>("bias_after_scale");
// Paddle's scale op computes scale * X + bias when bias_after_scale is true
// and scale * (X + bias) otherwise; fold the bias so the BM scale layer
// always receives the scale-then-add form.
if (!bias_after_scale) {
bias *= scale;
}
add_scale_layer(graph_ctx->bm_compiler_handle,
input_num,
shape,
dim,
name,
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(unique_op_name.c_str()),
&scale,
&bias,
1,
1,
0);
delete [] shape;
delete [] dim;
delete [] name;
output_nodes[output_var_name] = output_var_name;
return output_nodes;
}
......
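For reference, Paddle's scale op is defined as

y = \begin{cases} s\,x + b & \text{if bias\_after\_scale,} \\ s\,(x + b) = s\,x + s\,b & \text{otherwise,} \end{cases}

so folding bias *= scale when the flag is false (as corrected above) leaves the single scale-then-add form the BM scale layer consumes.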
......@@ -13,6 +13,8 @@
// limitations under the License.
#include "lite/kernels/bm/bridges/registry.h"
#include "lite/backends/bm/builder.h"
#include "bmcompiler_if.h"
namespace paddle {
namespace lite {
......@@ -20,10 +22,53 @@ namespace kernels {
namespace bm {
namespace bridges {
node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> op,
node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> softmax_op,
graph_ctx_type* graph_ctx,
const node_map_type& input_nodes) {
// output converted nodes
node_map_type output_nodes;
auto scope = softmax_op->scope();
auto op_info = softmax_op->op_info();
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const long int* x_shape_data = const_cast<const long int*>(&x_dims.data()[0]);
int i_x_shape_data[x_dims.size()];
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// output
auto output_var_name = op_info->Output("Out").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
const long int* output_shape_data = const_cast<const long int*>(&output_dims.data()[0]);
int i_output_shape_data[output_dims.size()];
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
auto axis = op_info->GetAttr<int>("axis");
if (axis < 0) {
axis += x_dims.size();
}
int outer_num = x_dims.Slice(0, axis).production();
int inner_num = x_dims.Slice(axis + 1, x_dims.size()).production();
add_softmax_layer(graph_ctx->bm_compiler_handle,
const_cast<const int*>(i_x_shape_data),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_output_shape_data),
output_dims.size(),
static_cast<const char*>(output_var_name.c_str()),
inner_num,
outer_num,
x_dims[axis]);
output_nodes[output_var_name] = output_var_name;
return output_nodes;
}
......
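The reduction geometry above: for input dims d_0, \dots, d_{n-1} and a non-negative axis a,

\text{outer\_num} = \prod_{i=0}^{a-1} d_i, \qquad \text{inner\_num} = \prod_{i=a+1}^{n-1} d_i,

so the tensor is viewed as (outer_num, d_a, inner_num) and softmax runs along the middle dimension, matching the Slice(...).production() calls.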