未验证 提交 30c273de 编写于 作者: Y Yan Chunwei 提交者: GitHub

port lite code (#1819)

上级 ca334444
......@@ -173,7 +173,6 @@ include(ccache) # set ccache for compilation
include(util) # set unittest and link libs
include(version) # set PADDLE_VERSION
set(CMAKE_CXX_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
set(CMAKE_C_FLAGS_RELWITHDEBINFO "-O3 -g -DNDEBUG")
......
......@@ -211,6 +211,13 @@ if(NOT IOS)
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels})
lite_cc_binary(benchmark_bin SRCS benchmark.cc DEPS paddle_api_full paddle_api_light gflags
${ops}
ARM_DEPS ${arm_kernels}
NPU_DEPS ${npu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels})
endif()
#lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <cstdio>
#include <fstream>
#include <string>
#include <vector>
#include "lite/api/paddle_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/cpu_info.h"
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
DEFINE_string(input_shape,
"1,3,224,224",
"input shapes, separated by colon and comma");
DEFINE_string(result_filename, "", "save test result");
namespace paddle {
namespace lite_api {
void OutputOptModel(const std::string& load_model_dir,
const std::string& save_optimized_model_dir,
const std::vector<std::vector<int64_t>>& input_shapes) {
lite_api::CxxConfig config;
config.set_model_dir(load_model_dir);
config.set_preferred_place(Place{TARGET(kX86), PRECISION(kFloat)});
config.set_valid_places({
Place{TARGET(kX86), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
});
auto predictor = lite_api::CreatePaddlePredictor(config);
// delete old optimized model
int ret = system(
paddle::lite::string_format("rm -rf %s", save_optimized_model_dir.c_str())
.c_str());
if (ret == 0) {
LOG(INFO) << "delete old optimized model " << save_optimized_model_dir;
}
predictor->SaveOptimizedModel(save_optimized_model_dir,
LiteModelType::kNaiveBuffer);
LOG(INFO) << "Load model from " << load_model_dir;
LOG(INFO) << "Save optimized model to " << save_optimized_model_dir;
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
void Run(const std::vector<std::vector<int64_t>>& input_shapes,
const std::string& model_dir,
const int repeat,
const int thread_num,
const int warmup_times,
const std::string model_name) {
#ifdef LITE_WITH_ARM
lite::DeviceInfo::Init();
if (thread_num == 1) {
lite::DeviceInfo::Global().SetRunMode(lite::LITE_POWER_HIGH, thread_num);
LOG(INFO) << "LITE_POWER_HIGH";
} else {
lite::DeviceInfo::Global().SetRunMode(lite::LITE_POWER_NO_BIND, thread_num);
LOG(INFO) << "LITE_POWER_NO_BIND";
}
#endif
lite_api::MobileConfig config;
config.set_model_dir(model_dir);
auto predictor = lite_api::CreatePaddlePredictor(config);
for (int j = 0; j < input_shapes.size(); ++j) {
auto input_tensor = predictor->GetInput(j);
input_tensor->Resize(input_shapes[j]);
auto input_data = input_tensor->mutable_data<float>();
int input_num = 1;
for (int i = 0; i < input_shapes[j].size(); ++i) {
input_num *= input_shapes[j][i];
}
for (int i = 0; i < input_num; ++i) {
input_data[i] = 1.f;
}
}
for (int i = 0; i < warmup_times; ++i) {
predictor->Run();
}
auto start = lite::GetCurrentUS();
for (int i = 0; i < repeat; ++i) {
predictor->Run();
}
auto end = lite::GetCurrentUS();
std::FILE* pf = std::fopen(FLAGS_result_filename.c_str(), "a");
if (nullptr == pf) {
LOG(INFO) << "create result file error";
exit(0);
}
fprintf(pf,
"-- %-18s avg = %5.4f ms\n",
model_name.c_str(),
(end - start) / repeat / 1000.0);
std::fclose(pf);
}
#endif
} // namespace lite_api
} // namespace paddle
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir == "" || FLAGS_result_filename == "") {
LOG(INFO) << "usage: "
<< "--model_dir /path/to/your/model --result_filename "
"/path/to/resultfile";
exit(0);
}
std::size_t found = FLAGS_model_dir.find_last_of("/");
std::string model_name = FLAGS_model_dir.substr(found + 1);
std::string save_optimized_model_dir = FLAGS_model_dir + "opt2";
auto split_string =
[](const std::string& str_in) -> std::vector<std::string> {
std::vector<std::string> str_out;
std::string tmp_str = str_in;
while (!tmp_str.empty()) {
size_t next_offset = tmp_str.find(":");
str_out.push_back(tmp_str.substr(0, next_offset));
if (next_offset == std::string::npos) {
break;
} else {
tmp_str = tmp_str.substr(next_offset + 1);
}
}
return str_out;
};
auto get_shape = [](const std::string& str_shape) -> std::vector<int64_t> {
std::vector<int64_t> shape;
std::string tmp_str = str_shape;
while (!tmp_str.empty()) {
int dim = atoi(tmp_str.data());
shape.push_back(dim);
size_t next_offset = tmp_str.find(",");
if (next_offset == std::string::npos) {
break;
} else {
tmp_str = tmp_str.substr(next_offset + 1);
}
}
return shape;
};
std::vector<std::string> str_input_shapes = split_string(FLAGS_input_shape);
std::vector<std::vector<int64_t>> input_shapes;
for (int i = 0; i < str_input_shapes.size(); ++i) {
input_shapes.push_back(get_shape(str_input_shapes[i]));
}
// Output optimized model
paddle::lite_api::OutputOptModel(
FLAGS_model_dir, save_optimized_model_dir, input_shapes);
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// Run inference using optimized model
paddle::lite_api::Run(input_shapes,
save_optimized_model_dir,
FLAGS_repeats,
FLAGS_threads,
FLAGS_warmup,
model_name);
#endif
return 0;
}
......@@ -21,16 +21,19 @@
#ifndef LITE_WITH_FPGA
USE_LITE_KERNEL(feed, kHost, kAny, kAny, def);
USE_LITE_KERNEL(fetch, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
#else
USE_LITE_KERNEL(feed, kFPGA, kFP16, kNHWC, def);
USE_LITE_KERNEL(fetch, kFPGA, kFP16, kNHWC, def);
#endif
// host kernels
USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
#ifdef LITE_WITH_ARM
USE_LITE_KERNEL(fc, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(mul, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(matmul, kARM, kFloat, kNCHW, def); // for x2paddle
USE_LITE_KERNEL(scale, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(softmax, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(lrn, kARM, kFloat, kNCHW, def);
......@@ -49,6 +52,7 @@ USE_LITE_KERNEL(dropout, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(concat, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(pool2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(relu, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(relu6, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(transpose, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(transpose2, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(batch_norm, kARM, kFloat, kNCHW, def);
......@@ -64,6 +68,7 @@ USE_LITE_KERNEL(sigmoid, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(tanh, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(swish, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(log, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(exp, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(conv2d_transpose, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(pad2d, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(prior_box, kARM, kFloat, kNCHW, def);
......@@ -91,6 +96,9 @@ USE_LITE_KERNEL(shape, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(fill_constant, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(cast, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(slice, kARM, kFloat, kNCHW, def)
USE_LITE_KERNEL(squeeze, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(squeeze2, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(expand, kARM, kFloat, kNCHW, def) // for x2paddle
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, fp32_to_int8);
USE_LITE_KERNEL(calib, kARM, kInt8, kNCHW, int8_to_fp32);
......
......@@ -19,8 +19,10 @@
#include "paddle_lite_factory_helper.h" // NOLINT
USE_LITE_OP(mul);
USE_LITE_OP(matmul); // for x2paddle
USE_LITE_OP(fc);
USE_LITE_OP(relu);
USE_LITE_OP(relu6);
USE_LITE_OP(scale);
USE_LITE_OP(feed);
USE_LITE_OP(lrn);
......@@ -56,6 +58,7 @@ USE_LITE_OP(sigmoid)
USE_LITE_OP(tanh)
USE_LITE_OP(swish)
USE_LITE_OP(log)
USE_LITE_OP(exp)
USE_LITE_OP(conv2d_transpose)
USE_LITE_OP(negative)
USE_LITE_OP(pad2d)
......@@ -104,3 +107,6 @@ USE_LITE_OP(is_empty)
USE_LITE_OP(shape)
USE_LITE_OP(slice)
USE_LITE_OP(cast)
USE_LITE_OP(squeeze) // for x2paddle
USE_LITE_OP(squeeze2) // for x2paddle
USE_LITE_OP(expand) // for x2paddle
......@@ -632,6 +632,40 @@ void act_log(const float* din, float* dout, int size, int threads) {
}
}
template <>
void act_exp(const float* din, float* dout, int size, int threads) {
int nums_per_thread = size / threads;
int remain = size - threads * nums_per_thread;
int neon_loop_cnt_dim4 = nums_per_thread >> 2;
int neon_loop_remain_dim4 = nums_per_thread - (neon_loop_cnt_dim4 << 2);
float32x4_t vzero = vdupq_n_f32(0.f);
#pragma omp parallel for
for (int i = 0; i < threads; ++i) {
float32x4_t exp_vec = vdupq_n_f32(0.0f);
const float* ptr_in_thread = din + i * nums_per_thread;
float* ptr_out_thread = dout + i * nums_per_thread;
for (int k = 0; k < neon_loop_cnt_dim4; ++k) {
exp_vec = exp_ps(vld1q_f32(ptr_in_thread));
vst1q_f32(ptr_out_thread, exp_vec);
ptr_out_thread += 4;
ptr_in_thread += 4;
}
for (int j = 0; j < neon_loop_remain_dim4; ++j) {
ptr_out_thread[0] = expf(ptr_in_thread[0]);
ptr_in_thread++;
ptr_out_thread++;
}
}
float* ptr_out = dout + threads * nums_per_thread;
const float* ptr_in = din + threads * nums_per_thread;
for (int j = 0; j < remain; ++j) {
ptr_out[0] = expf(ptr_in[0]);
ptr_in++;
ptr_out++;
}
}
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -51,6 +51,10 @@ void act_swish(const T* din, T* dout, int size, float coef, int threads);
template <typename T>
void act_log(const T* din, T* dout, int size, int threads);
template <typename T>
void act_exp(const T* din, T* dout, int size, int threads);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -13,9 +13,11 @@
// limitations under the License.
#include "lite/core/mir/graph_visualize_pass.h"
#include <map>
#include <memory>
#include <set>
#include <string>
#include <utility>
#include "lite/core/mir/pass_registry.h"
#include "lite/utils/string.h"
......@@ -34,7 +36,15 @@ std::string Visualize(mir::SSAGraph* graph) {
int id = 0;
std::set<std::string> exists_args;
std::map<int, std::string> graph_col; // Different colors of subgraphs
graph_col.insert({{1, "red"},
{2, "green"},
{3, "cyan"},
{4, "bisque3"},
{5, "coral"},
{6, "darkseagreen1"},
{7, "goldenrod1"},
{8, "darkorchid"}});
for (auto& node : graph->mutable_nodes()) {
std::string key;
if (node.IsArg()) {
......@@ -44,7 +54,22 @@ std::string Visualize(mir::SSAGraph* graph) {
}
if (node.IsStmt()) {
dot.AddNode(key, {Dot::Attr("shape", "box")});
auto& stmt = node.AsStmt();
auto sub_id = stmt.subgraph_id();
auto it = graph_col.find(sub_id);
if (sub_id > 0 && it != graph_col.end()) {
dot.AddNode(key,
{Dot::Attr("shape", "box"),
Dot::Attr("style", "filled"),
Dot::Attr("color", "black"),
Dot::Attr("fillcolor", it->second)});
} else {
dot.AddNode(key,
{Dot::Attr("shape", "box"),
Dot::Attr("style", "filled"),
Dot::Attr("color", "black"),
Dot::Attr("fillcolor", "yellow")});
}
for (auto& x : node.inlinks) {
auto name = x->AsArg().name;
if (!exists_args.count(name)) {
......
......@@ -7,6 +7,7 @@ lite_cc_test(test_subgraph_pass SRCS subgraph_program_pass_test.cc
ARGS --model_dir=${LITE_MODEL_DIR}/mobilenet_v1 SERIAL)
if (WITH_TESTING)
add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v1_tar_gz)
add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz)
set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
set_target_properties(test_subgraph_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
......@@ -23,6 +24,7 @@ if(LITE_WITH_NPU)
--optimized_model=${LITE_MODEL_DIR}/lite_npu_model_opt SERIAL)
if (WITH_TESTING)
add_dependencies(test_npu_pass extern_lite_download_mobilenet_v1_tar_gz)
add_dependencies(test_subgraph_pass extern_lite_download_mobilenet_v2_relu_tar_gz)
set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
set_target_properties(test_npu_pass PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
......
......@@ -14,6 +14,7 @@
#include "lite/core/mir/subgraph/generate_npu_program_pass.h"
#include <memory>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
......@@ -36,182 +37,143 @@ namespace lite {
namespace mir {
namespace subgraph {
// call convert function from start node
// return if convert success and the nodes to remove
// return the output npu op
lite::npu::bridge::node_map_type GenerateNPUProgramPass::CvtOpNodes(
const lite::npu::bridge::cvt_map_type& cvtfunc_map,
const Node* op_node,
const lite::npu::bridge::node_map_type& inputs_map,
int sub_id,
std::unordered_set<const Node*>* nodes2rm,
key2nodes_t* matched) {
lite::npu::bridge::node_map_type failed;
if (!op_node->IsStmt()) {
LOG(INFO) << "stop return failed";
return failed;
}
auto* stmt = op_node->stmt();
auto op_type = stmt->op_type();
LOG(INFO) << "cvt op type: " << op_type;
if (stmt->subgraph_id() != sub_id) {
LOG(INFO) << "return as subgraph_id(" << stmt->subgraph_id()
<< ") != sub_id(" << sub_id << ")";
return failed;
} else {
CHECK(cvtfunc_map.count(op_type)) << "Should be supported " << op_type
<< ", with subgraph_id: " << sub_id;
}
auto outputs_map = cvtfunc_map.at(op_type)(stmt->op(), inputs_map);
if (outputs_map.empty()) {
return outputs_map;
}
nodes2rm->insert(op_node);
for (auto& var_node : op_node->outlinks) {
for (auto& next_op_node : var_node->outlinks) {
LOG(INFO) << "next op type: " << next_op_node->AsStmt().op_type();
if (next_op_node->AsStmt().subgraph_id() != sub_id) {
// this is the end condition
// TODO(TJ): when enable more inputs and outputs this is bugy
LOG(INFO) << "--- should return once ---";
// TODO(TJ): matched output could be vector
matched->insert(std::make_pair("Output", var_node));
return outputs_map;
} else {
// LOG(INFO) << "argnames: ";
// for (auto sss : next_op_node->AsStmt().op_info()->input_argnames()) {
// LOG(INFO) << sss;
// }
// LOG(INFO) << "input argnames: ";
// for (auto sss : next_op_node->AsStmt().op_info()->input_names()) {
// LOG(INFO) << sss;
// }
for (auto& i_node : next_op_node->inlinks) {
CHECK(i_node->IsArg());
auto& arg = i_node->AsArg();
LOG(INFO) << arg.name;
if (outputs_map.count(arg.name)) continue;
if (!arg.is_weight) {
LOG(INFO) << "Data arg name:" << arg.name;
outputs_map.insert(std::make_pair(
arg.name,
lite::npu::bridge::CvtNode(
i_node, next_op_node->AsStmt().op()->scope())));
}
}
nodes2rm->insert(var_node);
return CvtOpNodes(
cvtfunc_map, next_op_node, outputs_map, sub_id, nodes2rm, matched);
}
void GenerateNPUProgramPass::NPUSortHelper(
Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret) {
for (auto& var_node : node->inlinks) {
if (var_node->inlinks.empty()) continue;
auto* op_node = var_node->inlinks.front();
if (nodes_all.count(op_node) && !visited_nodes->count(op_node)) {
NPUSortHelper(op_node, nodes_all, visited_nodes, ret);
}
}
ret->push_back(node);
visited_nodes->insert(node);
}
void GenerateNPUProgramPass::ConvertSubgraph(
const std::unique_ptr<SSAGraph>& graph, int sub_num) {
void GenerateNPUProgramPass::CvtOpNodes(
const std::vector<Node*>& nodes2cvt,
std::vector<std::string>* in_vars_name,
std::vector<std::string>* out_vars_name,
lite::npu::bridge::node_map_type* cvted_vars,
std::unordered_set<const Node*>* nodes2rm) {
const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& cvtfunc_map = bridges.AllFunctions();
std::unordered_set<const Node*> nodes2rm_all;
auto items = graph->StmtTopologicalOrder();
for (int id = 1; id <= sub_num; ++id) {
LOG(INFO) << "Converting subgraph_id:" << id;
for (auto& op_node : items) {
std::unordered_set<const Node*> nodes2rm;
if (!op_node->IsStmt()) continue;
auto& stmt = op_node->AsStmt();
if (stmt.subgraph_id() != id) continue;
CHECK(bridges.HasType(stmt.op_type()));
key2nodes_t matched;
matched["target_op"] = op_node;
auto& op = stmt.op();
auto* scope = op->scope();
// prepare inputs data.
std::string data_name = "data_subgraph_" + std::to_string(id);
lite::npu::bridge::node_map_type npu_inputs_map;
int name_id = 0;
LOG(INFO) << "op_type: " << stmt.op_type();
std::vector<std::string> actual_input_argnames;
for (auto& arg_node : op_node->inlinks) {
CHECK(arg_node->IsArg());
const auto& arg = arg_node->AsArg();
if (!arg_node->AsArg().is_weight) {
LOG(INFO) << "Input arg name: " << arg.name;
npu_inputs_map.insert(std::make_pair(
arg.name, lite::npu::bridge::CvtNode(arg_node, scope)));
// TODO(TJ): Here matched inputs should also be input vector
matched["Input"] = arg_node;
name_id++;
}
for (auto& node : nodes2cvt) {
lite::npu::bridge::node_map_type node_inputs;
auto& stmt = node->AsStmt();
for (auto& var_node : node->inlinks) {
auto& arg = var_node->AsArg();
auto var_name = arg.name;
if (!cvted_vars->count(var_name)) {
if (arg.is_weight) continue;
cvted_vars->insert(std::make_pair(
var_name,
lite::npu::bridge::CvtNode(var_node, stmt.op()->scope())));
in_vars_name->push_back(var_name);
}
CHECK_EQ(name_id, 1) << "mobilenetv1 only have one input data!";
auto npu_outputs_map = CvtOpNodes(
cvtfunc_map, op_node, npu_inputs_map, id, &nodes2rm, &matched);
if (!npu_outputs_map.empty()) {
LOG(INFO) << "[NPU] subgraph " << id << ": output not empty ";
std::vector<ge::Operator> inputs;
std::vector<ge::Operator> outputs;
for (auto& i : npu_inputs_map) {
LOG(INFO) << "input data argname:" << i.first
<< ", ptr: " << i.second;
inputs.emplace_back(*(i.second));
}
for (auto& i : npu_outputs_map) {
LOG(INFO) << "output data argname:" << i.first
<< ", ptr: " << i.second;
outputs.emplace_back(*(i.second));
}
std::string model_name("hiai_npu_client_" + std::to_string(id) + ".om");
if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
// build failed, so this subgraph is abandoned
nodes2rm.clear();
LOG(WARNING) << "Build NPU failed subgraph " << id;
node_inputs.insert(*cvted_vars->find(var_name));
}
auto node_outputs = cvtfunc_map.at(stmt.op_type())(stmt.op(), node_inputs);
cvted_vars->insert(node_outputs.begin(), node_outputs.end());
nodes2rm->insert(node);
for (auto& var_node : node->outlinks) {
for (auto& next_op_node : var_node->outlinks) {
if (std::find(nodes2cvt.begin(), nodes2cvt.end(), next_op_node) ==
nodes2cvt.end()) {
out_vars_name->push_back(var_node->AsArg().name);
break;
}
LOG(INFO) << "[NPU] Build NPU Client success subgraph " << id;
}
}
}
}
void GenerateNPUProgramPass::GenNPUGraphOpNode(
const std::unique_ptr<SSAGraph>& graph,
int sub_id,
const std::unordered_set<Node*>& nodes_all) {
std::unordered_set<const Node*> visited_nodes;
std::vector<Node*> ret;
for (auto& node : nodes_all) {
if (!node->IsStmt()) continue;
if (visited_nodes.count(node)) continue;
NPUSortHelper(node, nodes_all, &visited_nodes, &ret);
}
// Then InsertNewNode(graph, matched); make one function
cpp::OpDesc op_desc;
op_desc.SetType("graph_op");
// change to vectors
op_desc.SetInput("Inputs", {matched.at("Input")->arg()->name});
op_desc.SetOutput("Outputs", {matched.at("Output")->arg()->name});
op_desc.SetAttr("model_name", model_name);
auto graph_op = LiteOpRegistry::Global().Create("graph_op");
auto target_op = matched.at("target_op")->stmt()->op();
auto* scope = target_op->scope();
CHECK(scope);
CHECK(graph_op);
graph_op->Attach(op_desc, scope);
std::vector<std::string> in_vars_name;
std::vector<std::string> out_vars_name;
lite::npu::bridge::node_map_type cvted_vars;
std::unordered_set<const Node*> nodes2rm;
CvtOpNodes(ret, &in_vars_name, &out_vars_name, &cvted_vars, &nodes2rm);
// insert new graph op node
std::vector<ge::Operator> inputs;
std::vector<ge::Operator> outputs;
for (auto i : in_vars_name) {
inputs.push_back(*cvted_vars.at(i));
}
for (auto i : out_vars_name) {
outputs.push_back(*cvted_vars.at(i));
}
std::string model_name("hiai_npu_client_" + std::to_string(sub_id) + ".om");
if (!npu::BuildNPUClient(inputs, outputs, model_name)) {
LOG(FATAL) << "Build NPU failed subgraph " << sub_id;
}
LOG(INFO) << "[NPU] Build NPU Client success subgraph " << sub_id;
cpp::OpDesc op_desc;
op_desc.SetType("graph_op");
op_desc.SetInput("Inputs", in_vars_name);
op_desc.SetOutput("Outputs", out_vars_name);
op_desc.SetAttr("model_name", model_name);
auto graph_op = LiteOpRegistry::Global().Create("graph_op");
// TODO(zpy): support multi inputs op
auto start_op = ret.front()->AsStmt().op();
auto* scope = start_op->scope();
graph_op->Attach(op_desc, scope);
auto valid_places = start_op->valid_places();
auto* new_op_node = graph->GraphCreateInstructNode(graph_op, valid_places);
for (auto& var_node : ret.front()->inlinks) {
auto& arg = var_node->AsArg();
if (arg.is_weight) continue;
IR_NODE_LINK_TO(var_node, new_op_node);
}
for (auto& var_node : ret.back()->outlinks) {
auto& arg = var_node->AsArg();
if (arg.is_weight) continue;
IR_NODE_LINK_TO(var_node, new_op_node);
}
auto valid_places =
target_op->valid_places(); // TODO(TJ): add npu place?
auto* new_op_node =
graph->GraphCreateInstructNode(graph_op, valid_places);
// assign context
auto& inst = new_op_node->AsStmt();
inst.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
IR_NODE_LINK_TO(matched.at("Input"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("Output"));
GraphSafeRemoveNodes(graph.get(), nodes2rm);
}
// assign context
auto& inst = new_op_node->AsStmt();
inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
inst.picked_kernel().target()));
void GenerateNPUProgramPass::ConvertSubgraph(
const std::unique_ptr<SSAGraph>& graph, int sub_num) {
std::unordered_map<int, std::unordered_set<Node*>> nodes_all;
for (auto& item : graph->StmtTopologicalOrder()) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
int sub_id = stmt.subgraph_id();
if (sub_id < 1) continue;
if (nodes_all.count(sub_id) == 0) {
nodes_all[sub_id] = std::unordered_set<Node*>();
}
nodes_all.at(sub_id).insert(item);
}
if (!nodes2rm.empty()) {
nodes2rm_all.insert(nodes2rm.begin(), nodes2rm.end());
}
break;
} // if npu output success
} // for op_nodes
} // for subgraph id
// remove all unused node once
GraphSafeRemoveNodes(graph.get(), nodes2rm_all);
// clear all npu ops
npu::OpList::Global().clear();
for (int id = 1; id <= sub_num; ++id) {
LOG(INFO) << "Converting subgraph_id:" << id;
GenNPUGraphOpNode(graph, id, nodes_all.at(id));
}
}
void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
......@@ -228,8 +190,6 @@ void GenerateNPUProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
InferOnce(graph);
ConvertSubgraph(graph, num_subgraph);
// auto graph1 = GenerateFusedGraph(std::move(graph));
// GraphSafeRemoveNodes(graph, nodes2rm);
LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
for (auto& item : graph->StmtTopologicalOrder()) {
......
......@@ -17,6 +17,7 @@
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/core/mir/pass.h"
......@@ -37,23 +38,27 @@ class GenerateNPUProgramPass : public SubgraphProgramPass {
std::unique_ptr<RuntimeProgram> GenProgram();
protected:
// TODO(TJ): maybe change a name
// convert all fused subgraphs to npu clients
// 1. if some subgraph failed, then skip.
// 2. add new graph nodes, kernels and context
// 3. remove unused nodes
void ConvertSubgraph(const std::unique_ptr<SSAGraph>& graph, int sub_num);
void NPUSortHelper(Node* node,
const std::unordered_set<Node*>& nodes_all,
std::unordered_set<const Node*>* visited_nodes,
std::vector<Node*>* ret);
// nodes2cvt: op nodes to convert
// in_vars_name: graph op's inputs var name
// out_vars_name: graph op's outputs var name
// vcted_vars:
// nodes2rm: op nodes and var nodes that need to be removed
void CvtOpNodes(const std::vector<Node*>& nodes2cvt,
std::vector<std::string>* in_vars_name,
std::vector<std::string>* out_vars_name,
lite::npu::bridge::node_map_type* cvted_vars,
std::unordered_set<const Node*>* nodes2rm);
// call convert function from start node
// return if convert success and the nodes to remove
// return the output(arg.name, npu op)
lite::npu::bridge::node_map_type CvtOpNodes(
const lite::npu::bridge::cvt_map_type& cvtfunc_map,
const Node* op_node,
const lite::npu::bridge::node_map_type& inputs_map,
int sub_id,
std::unordered_set<const Node*>* nodes2rm,
key2nodes_t* matched);
void GenNPUGraphOpNode(const std::unique_ptr<SSAGraph>& graph,
int sub_id,
const std::unordered_set<Node*>& nodes_all);
void ConvertSubgraph(const std::unique_ptr<SSAGraph>& graph, int sub_num);
private:
std::vector<Instruction> insts_;
......
......@@ -85,21 +85,31 @@ void SubgraphProgramPass::ChangeAllOutConnectedID(Node* node,
for (auto& i : node->outlinks) {
if (!i->IsStmt()) return;
auto& stmt = i->AsStmt();
if (stmt.subgraph_id() != from_id) {
if (stmt.subgraph_id() < from_id) {
all_out_op_supported = false;
}
}
if (!all_out_op_supported) {
return;
}
nodes2rm_[to_id].insert(node);
for (auto& i : node->outlinks) {
CHECK(i->IsStmt());
auto& stmt = i->AsStmt();
CHECK_EQ(stmt.subgraph_id(), from_id);
stmt.SetSubgraphID(to_id);
nodes2rm_[to_id].insert(i);
for (auto& o : i->outlinks) {
ChangeAllOutConnectedID(o, to_id, from_id);
if (stmt.subgraph_id() == from_id) {
stmt.SetSubgraphID(to_id);
nodes2rm_[to_id].insert(i);
for (auto& o : i->outlinks) {
for (auto& j : o->outlinks) {
if (j->IsStmt()) {
auto& Nstmt = j->AsStmt();
if (Nstmt.subgraph_id() < from_id) {
o_nodes_[to_id].insert(o);
}
}
}
ChangeAllOutConnectedID(o, to_id, from_id);
}
}
}
}
......@@ -109,12 +119,62 @@ int SubgraphProgramPass::FuseSubgraphID(
const std::unique_ptr<SSAGraph>& graph) {
int sub_id = 1; // id start from 1 not 0
for (auto& item : graph->StmtTopologicalOrder()) {
bool inputvar = 0;
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
if (stmt.subgraph_id() == -1) {
for (auto& i : item->outlinks) {
for (auto& j : i->outlinks) {
if (j->IsStmt()) {
auto& jstmt = j->AsStmt();
// LOG(INFO) << "initial: "<<jstmt.op_type()<<"
// :"<<jstmt.subgraph_id();
if (jstmt.subgraph_id() == 0) inputvar = 1;
}
}
}
// LOG(INFO) << "initial: "<<stmt.op_type()<<" :"<<stmt.subgraph_id();
if (inputvar == 1) {
for (auto& i : item->outlinks) i_nodes_[sub_id].insert(i);
}
}
if (stmt.subgraph_id() != 0) continue;
ChangeAllOutConnectedID(item, sub_id);
sub_id++;
}
for (auto& i : nodes2rm_) {
for (auto& item : i.second) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << "nodes2rm_:" << stmt.op_type();
} else if (item->IsArg()) {
auto& arg = item->AsArg();
LOG(INFO) << "nodes2rm_:" << arg.name;
}
}
}
for (auto& i : i_nodes_) {
for (auto& item : i.second) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << "i_nodes_: " << i.first << " " << stmt.op_type();
} else if (item->IsArg()) {
auto& arg = item->AsArg();
LOG(INFO) << "i_nodes_: " << i.first << " " << arg.name;
}
}
}
for (auto& i : o_nodes_) {
for (auto& item : i.second) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
LOG(INFO) << "o_nodes_:" << i.first << " " << stmt.op_type();
} else if (item->IsArg()) {
auto& arg = item->AsArg();
LOG(INFO) << "o_nodes_: " << i.first << " " << arg.name;
}
}
}
return sub_id - 1;
}
......@@ -129,7 +189,6 @@ int SubgraphProgramPass::FuseSubgraph(
LOG(INFO) << "detected " << num_subgraph << " subgraph";
return num_subgraph;
}
} // namespace subgraph
} // namespace mir
} // namespace lite
......
......@@ -57,11 +57,15 @@ class SubgraphProgramPass : public ProgramPass {
private:
// {1: {nodes2rm_in_subgraph1, ...},
// 2: {nodes2rm_in_subgraph2, ...}}
std::unordered_map<int, std::unordered_set<const Node*>> nodes2rm_;
// delete nodes
std::unordered_map<int, std::unordered_set<Node*>> nodes2rm_;
// std::unordered_map<int, std::unordered_set<const Node*>> nodes2rm_;
// inputs nodes
std::unordered_map<int, std::unordered_set<const Node*>> i_nodes_;
std::unordered_map<int, std::unordered_set<Node*>> i_nodes_;
// std::unordered_map<int, std::unordered_set<const Node*>> i_nodes_;
// outputs nodes
std::unordered_map<int, std::unordered_set<const Node*>> o_nodes_;
std::unordered_map<int, std::unordered_set<Node*>> o_nodes_;
// std::unordered_map<int, std::unordered_set<const Node*>> o_nodes_;
};
} // namespace subgraph
......
......@@ -29,7 +29,7 @@ DEFINE_string(model_dir, "", "model_dir");
namespace paddle {
namespace lite {
TEST(SubgraphTest, mobilenetv1) {
TEST(SubgraphTest, mobilenetv2) {
cpp::ProgramDesc program_desc;
auto scope = std::make_shared<Scope>();
LoadModelPb(FLAGS_model_dir, scope.get(), &program_desc);
......@@ -46,7 +46,8 @@ TEST(SubgraphTest, mobilenetv1) {
auto graph = std::unique_ptr<mir::SSAGraph>(new mir::SSAGraph());
graph->Build(program, valid_places);
std::vector<std::string> supported_op_types{"conv2d",
std::vector<std::string> supported_op_types{"concat",
"conv2d",
"depthwise_conv2d",
"batch_norm",
"scale",
......@@ -54,9 +55,13 @@ TEST(SubgraphTest, mobilenetv1) {
"mul",
"elementwise_add",
"softmax",
"relu"};
"split",
"relu",
"reshape2",
"transpose2"};
auto* pass = new mir::subgraph::SubgraphProgramPass;
ASSERT_EQ(pass->FuseSubgraph(graph, supported_op_types), 1);
LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
}
// return output_var_names
......@@ -99,6 +104,77 @@ std::vector<std::string> AddFCDesc(
return out_var_names;
}
std::vector<std::string> AddElementwiseAddDesc(
cpp::BlockDesc* block_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& input_X_names,
const std::vector<std::string>& input_Y_names) {
// CHECK_EQ(input_var_names.size(), 2);
static int id = 0;
std::string prefix = "elementwise_add_" + std::to_string(id);
auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
op_desc->SetType("elementwise_add");
op_desc->SetInput("X", input_X_names);
op_desc->SetInput("Y", input_Y_names);
op_desc->SetOutput("Out", out_var_names);
op_desc->SetAttr("axis", -1);
id++;
return out_var_names;
}
std::vector<std::string> AddFeedDesc(
cpp::BlockDesc* block_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& input_X_names) {
// CHECK_EQ(input_var_names.size(), 1);
static int id = 0;
std::string prefix = "feed_" + std::to_string(id);
auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
op_desc->SetType("feed");
op_desc->SetInput("X", input_X_names);
op_desc->SetOutput("Out", out_var_names);
op_desc->SetAttr("col", 1);
id++;
return out_var_names;
}
std::vector<std::string> AddFetchDesc(
cpp::BlockDesc* block_desc,
const std::shared_ptr<Scope>& scope,
const std::vector<std::string>& input_X_names) {
// CHECK_EQ(input_var_names.size(), 1);
static int id = 0;
std::string prefix = "fetch_" + std::to_string(id);
auto* op_desc = block_desc->AddOp<cpp::OpDesc>();
auto* out = block_desc->AddVar<cpp::VarDesc>();
out->SetName(prefix + "_Out");
std::vector<std::string> out_var_names{prefix + "_Out"};
scope->Var(prefix + "_Out")->GetMutable<lite::Tensor>();
op_desc->SetType("fetch");
op_desc->SetInput("X", input_X_names);
op_desc->SetOutput("Out", out_var_names);
op_desc->SetAttr("col", 1);
id++;
return out_var_names;
}
std::unique_ptr<mir::SSAGraph> BuildSimpleNet(
cpp::ProgramDesc* program_desc,
const std::shared_ptr<Scope>& scope,
......@@ -134,6 +210,7 @@ TEST(SubGraphTest, SimpleNet) {
const int num_nodes = graph->nodes().size();
ASSERT_EQ(graph->nodes().size(), 9);
// LOG(INFO) << "After NPU Pass \n" << Visualize(graph.get());
}
} // namespace lite
......
......@@ -7,6 +7,7 @@ message(STATUS "compile with lite ARM kernels")
lite_cc_library(fc_compute_arm SRCS fc_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(activation_compute_arm SRCS activation_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(mul_compute_arm SRCS mul_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(matmul_compute_arm SRCS matmul_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(scale_compute_arm SRCS scale_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(softmax_compute_arm SRCS softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(conv_compute_arm SRCS conv_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -59,6 +60,8 @@ lite_cc_library(is_empty_compute_arm SRCS is_empty_compute.cc DEPS ${lite_kernel
lite_cc_library(shape_compute_arm SRCS shape_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(slice_compute_arm SRCS slice_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(cast_compute_arm SRCS cast_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(squeeze_compute_arm SRCS squeeze_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_library(expand_compute_arm SRCS expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
lite_cc_test(test_fc_compute_arm SRCS fc_compute_test.cc DEPS fc_compute_arm math_arm)
lite_cc_test(test_scale_compute_arm SRCS scale_compute_test.cc DEPS scale_compute_arm)
......@@ -84,6 +87,7 @@ set(arm_kernels
fc_compute_arm
activation_compute_arm
mul_compute_arm
matmul_compute_arm
scale_compute_arm
softmax_compute_arm
conv_compute_arm
......@@ -136,6 +140,8 @@ set(arm_kernels
shape_compute_arm
slice_compute_arm
cast_compute_arm
squeeze_compute_arm
expand_compute_arm
)
set(arm_kernels "${arm_kernels}" CACHE INTERNAL "arm kernels")
......@@ -127,6 +127,16 @@ void LogCompute::Run() {
x_data, output_data, x_dims.production(), ctx.threads());
}
void ExpCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
lite::arm::math::act_exp<float>(
x_data, output_data, x_dims.production(), ctx.threads());
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -185,7 +195,7 @@ REGISTER_LITE_KERNEL(
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
relu6, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ReluCompute, def)
relu6, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::Relu6Compute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -194,3 +204,8 @@ REGISTER_LITE_KERNEL(
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
exp, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ExpCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -102,6 +102,16 @@ class LogCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~LogCompute() = default;
};
class ExpCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~ExpCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
......@@ -157,8 +157,8 @@ REGISTER_LITE_KERNEL(conv2d_transpose,
kNCHW,
paddle::lite::kernels::arm::Conv2DTransposeCompute,
def)
.BindInput("x", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("output", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Filter", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Output", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/expand_compute.h"
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void ExpandCompute::Run() {
auto& param = Param<operators::ExpandParam>();
const auto* x = param.X;
auto* out = param.Out;
std::vector<int> expand_times = param.expand_times;
const float* src = x->data<float>();
float* dst = out->mutable_data<float>();
int dims = expand_times.size();
DDim in_shape = x->dims();
int inner_num = 1;
int i = dims - 1;
int outer_num = in_shape.count(0, i);
inner_num *= in_shape[i];
for (int j = 0; j < outer_num; ++j) {
for (int k = 0; k < expand_times[i]; ++k) {
memcpy(dst + (j * expand_times[i] + k) * inner_num,
src + j * inner_num,
sizeof(float) * inner_num);
}
}
inner_num *= expand_times[i];
for (int i = dims - 2; i >= 0; --i) {
int outer_num = in_shape.count(0, i);
inner_num *= in_shape[i];
for (int j = outer_num - 1; j >= 0; --j) {
for (int k = expand_times[i] - 1; k >= 0; --k) {
memcpy(dst + (j * expand_times[i] + k) * inner_num,
dst + j * inner_num,
sizeof(float) * inner_num);
}
}
inner_num *= expand_times[i];
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
expand, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ExpandCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class ExpandCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~ExpandCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/matmul_compute.h"
#include <vector>
#include "lite/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void MatMulCompute::PrepareForRun() {
auto& ctx = this->ctx_->template As<ARMContext>();
}
void MatMulCompute::Run() {
auto& param = Param<param_t>();
const auto* x_data = param.X->data<float>();
const auto* y_data = param.Y->data<float>();
auto* o_data = param.Out->mutable_data<float>();
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
bool x_transpose = param.transpose_X;
bool y_transpose = param.transpose_Y;
float alpha = param.alpha;
auto& ctx = this->ctx_->template As<ARMContext>();
if (x_dims.size() > 2 && y_dims.size() >= 2) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N]
if (x_transpose || y_transpose) {
LOG(FATAL) << "not supported transpose for x or y.";
}
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
if (y_dims.size() > 2) {
m_ = x_dims[x_dims.size() - 2];
k_ = y_dims[y_dims.size() - 2];
n_ = y_dims[y_dims.size() - 1];
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = 0;
m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
int out_inner = x_dims[x_dims.size() - 2] * y_dims[y_dims.size() - 1];
if (n_ == 1) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
lite::arm::math::sgemv(x_data + i * x_inner,
y_data + i * y_inner,
o_data + i * out_inner,
false,
m_,
k_,
false,
nullptr,
false);
}
if (fabsf(alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= alpha;
}
}
} else {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
lite::arm::math::prepackA(packed_x,
x_data + i * x_inner,
alpha,
k_,
0,
m_,
0,
k_,
false,
&ctx);
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
lite::arm::math::sgemm_prepack(y_transpose,
m_,
n_,
k_,
packed_x,
y_data + i * y_inner,
ldb,
0.f,
o_data + i * out_inner,
n_,
nullptr,
false,
false,
&ctx);
}
}
} else {
m_ = x_dims[x_dims.size() - 2];
k_ = y_dims[0];
n_ = y_dims[1];
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = 0;
m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
int x_inner = x_dims[x_dims.size() - 2] * x_dims[x_dims.size() - 1];
int out_inner = x_dims[x_dims.size() - 2] * y_dims[1];
if (n_ == 1) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
lite::arm::math::sgemv(x_data + i * x_inner,
y_data,
o_data + i * out_inner,
false,
m_,
k_,
false,
nullptr,
false);
}
if (fabsf(param.alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= param.alpha;
}
}
} else {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
lite::arm::math::prepackA(packed_x,
x_data + i * x_inner,
alpha,
k_,
0,
m_,
0,
k_,
false,
&ctx);
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
lite::arm::math::sgemm_prepack(y_transpose,
m_,
n_,
k_,
packed_x,
y_data,
ldb,
0.f,
o_data + i * out_inner,
n_,
nullptr,
false,
false,
&ctx);
}
}
}
} else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N]
if (!x_transpose && !y_transpose) {
CHECK_EQ(x_dims[1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (!x_transpose && y_transpose) {
CHECK_EQ(x_dims[1], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (x_transpose && !y_transpose) {
CHECK_EQ(x_dims[0], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else {
CHECK_EQ(x_dims[0], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
}
// not supported transpose
if (x_transpose || y_transpose) {
LOG(FATAL) << "not supported transpose for x and y.";
}
m_ = x_dims[0];
k_ = x_dims[1];
n_ = y_dims[1];
int hblock = lite::arm::math::get_hblock(ctx.arch());
int m_round = 0;
m_round = hblock * ((m_ + hblock - 1) / hblock);
ctx.ExtendWorkspace(m_round * k_ * sizeof(float));
if (n_ == 1) {
lite::arm::math::sgemv(
x_data, y_data, o_data, x_transpose, m_, k_, false, nullptr, false);
if (fabsf(param.alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= param.alpha;
}
}
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
lite::arm::math::prepackA(
packed_x, x_data, alpha, k_, 0, m_, 0, k_, x_transpose, &ctx);
int ldb = n_;
if (y_transpose) {
ldb = k_;
}
lite::arm::math::sgemm_prepack(y_transpose,
m_,
n_,
k_,
packed_x,
y_data,
ldb,
0.f,
o_data,
n_,
nullptr,
false,
false,
&ctx);
}
} else if (x_dims.size() > 2 && y_dims.size() == 1) {
// x: [B, M, K], y: [K], out: [B, M]
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 1); ++i) {
o_data[i] = 0;
for (size_t j = 0; j < y_dims[0]; ++j) {
o_data[i] += x_data[i * y_dims[0] + j] * y_data[j] * alpha;
}
}
} else if (x_dims.size() == 1 && y_dims.size() == 1) {
// x: [K], y: [K], out: [1]
if (x_dims[0] == y_dims[0] && x_transpose == false &&
y_transpose == false) {
o_data[0] = 0.;
for (size_t i = 0; i < x_dims[0]; ++i) {
o_data[0] += x_data[i] * y_data[i] * alpha;
}
}
// x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N]
if (x_transpose == true && y_transpose == true) {
m_ = x_dims[0];
k_ = 1;
n_ = y_dims[0];
if (n_ == 1) {
lite::arm::math::sgemv(
x_data, y_data, o_data, false, m_, k_, false, nullptr, false);
if (fabsf(alpha - 1.f) > 1e-8f) {
for (size_t i = 0; i < param.Out->dims().production(); ++i) {
o_data[i] *= alpha;
}
}
} else {
float* packed_x = static_cast<float*>(ctx.workspace_data<float>()) +
ctx.llc_size() / sizeof(float);
lite::arm::math::prepackA(
packed_x, x_data, alpha, k_, 0, m_, 0, k_, false, &ctx);
int ldb = n_;
lite::arm::math::sgemm_prepack(false,
m_,
n_,
k_,
packed_x,
y_data,
ldb,
0.f,
o_data,
n_,
nullptr,
false,
false,
&ctx);
}
}
} else {
LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
matmul, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::MatMulCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
#include "lite/core/types.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class MatMulCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::MatMulParam;
void PrepareForRun() override;
void Run() override;
virtual ~MatMulCompute() = default;
private:
int m_, n_, k_;
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -78,12 +78,19 @@ void MulticlassNmsCompute::Run() {
}
}
lod_info.push_back(num);
(*lod).push_back(lod_info);
param.out->Resize({static_cast<int64_t>(result_corrected.size() / 6), 6});
float* out = param.out->mutable_data<float>();
std::memcpy(
out, result_corrected.data(), sizeof(float) * result_corrected.size());
if (result_corrected.empty()) {
(*lod).clear();
(*lod).push_back(std::vector<uint64_t>({0, 1}));
param.out->Resize({static_cast<int64_t>(1)});
param.out->mutable_data<float>()[0] = -1.;
} else {
param.out->Resize({static_cast<int64_t>(result_corrected.size() / 6), 6});
float* out = param.out->mutable_data<float>();
std::memcpy(
out, result_corrected.data(), sizeof(float) * result_corrected.size());
}
}
} // namespace arm
......
......@@ -235,6 +235,8 @@ void multiclass_nms_compute_ref(const operators::MulticlassNmsParam& param,
if (num_kept == 0) {
(*result).clear();
(*result).resize(1);
(*result)[0] = -1;
return;
} else {
(*result).resize(num_kept * 6);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/squeeze_compute.h"
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
void SqueezeCompute::Run() {
auto& param = Param<operators::SqueezeParam>();
auto x = param.X;
auto output = param.Out;
auto x_dims = x->dims();
auto* x_data = x->data<float>();
auto* out_data = output->mutable_data<float>();
memcpy(out_data, x_data, x_dims.production() * sizeof(float));
}
void Squeeze2Compute::Run() {
auto& param = Param<operators::SqueezeParam>();
auto x = param.X;
auto output = param.Out;
auto xshape = param.XShape;
auto x_dims = x->dims();
auto* x_data = x->data<float>();
auto* out_data = output->mutable_data<float>();
auto* xshape_data = xshape->mutable_data<float>();
memcpy(out_data, x_data, x_dims.production() * sizeof(float));
memcpy(xshape_data, x_data, x_dims.production() * sizeof(float));
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(squeeze,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::host::SqueezeCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(squeeze2,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::host::Squeeze2Compute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("XShape", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
class SqueezeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~SqueezeCompute() = default;
};
class Squeeze2Compute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
void Run() override;
virtual ~Squeeze2Compute() = default;
};
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -40,6 +40,10 @@ void GraphCompute::PrepareForRun() {
npu_otensors_.resize(npu_odims_.size());
for (size_t i = 0; i < npu_idims_.size(); ++i) {
VLOG(3) << "npu_idims[" << i << "]: " << npu_idims_[i].GetNumber() << ","
<< npu_idims_[i].GetChannel() << "," << npu_idims_[i].GetHeight()
<< "," << npu_idims_[i].GetWidth();
VLOG(3) << "lite_idims[" << i << "]: " << param.inputs[i]->dims();
CHECK_EQ(param.inputs[i]->dims().production(),
npu_idims_[i].GetNumber() * npu_idims_[i].GetChannel() *
npu_idims_[i].GetHeight() * npu_idims_[i].GetWidth());
......@@ -48,6 +52,10 @@ void GraphCompute::PrepareForRun() {
}
for (size_t i = 0; i < npu_odims_.size(); ++i) {
VLOG(3) << "npu_odims[" << i << "]: " << npu_odims_[i].GetNumber() << ","
<< npu_odims_[i].GetChannel() << "," << npu_odims_[i].GetHeight()
<< "," << npu_odims_[i].GetWidth();
VLOG(3) << "lite_odims[" << i << "]: " << param.outputs[i]->dims();
auto out_size = npu_odims_[i].GetNumber() * npu_odims_[i].GetChannel() *
npu_odims_[i].GetHeight() * npu_odims_[i].GetWidth();
if (param.outputs[i]->dims().production() != out_size) {
......
......@@ -16,10 +16,12 @@
#include <algorithm>
#include <fstream>
#include <limits>
#include <set>
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/core/variable.h"
#include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/naive_buffer/combined_params_desc.h"
#include "lite/model_parser/naive_buffer/param_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h"
#include "lite/model_parser/naive_buffer/var_desc.h"
......@@ -316,18 +318,19 @@ void SerializeTensor(std::ostream &os,
}
/// For navie buffer
void SaveParamNaive(const std::string &path,
const lite::Scope &scope,
const std::string &var_name) {
void SetParamInfoNaive(naive_buffer::ParamDesc *param_desc,
const lite::Scope &scope,
const std::string &var_name) {
CHECK(param_desc);
auto &desc = *param_desc;
// the 1st field, uint32_t version
constexpr uint32_t version = 0;
auto *var = scope.FindVar(var_name);
const auto &tensor = var->Get<lite::Tensor>();
naive_buffer::BinaryTable table;
naive_buffer::proto::ParamDesc pt_desc(&table);
naive_buffer::ParamDesc desc(&pt_desc);
desc.SetName(var_name);
desc.SetModelVersion(version);
desc.SetTensorVersion(version);
......@@ -355,18 +358,50 @@ void SaveParamNaive(const std::string &path,
{
desc.SetData<float>(tensor.data<float>(), tensor.data_size());
}
}
void SaveParamNaive(const std::string &path,
const lite::Scope &scope,
const std::string &var_name) {
naive_buffer::BinaryTable table;
naive_buffer::proto::ParamDesc pt_desc(&table);
naive_buffer::ParamDesc desc(&pt_desc);
SetParamInfoNaive(&desc, scope, var_name);
// Save param
pt_desc.Save();
table.SaveToFile(path);
}
void SaveCombinedParamsNaive(const std::string &path,
const lite::Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
naive_buffer::BinaryTable table;
naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
naive_buffer::CombinedParamsDesc desc(&pt_desc);
auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable())
continue;
naive_buffer::ParamDesc param_desc(desc.AddParam());
SetParamInfoNaive(&param_desc, exec_scope, var.Name());
}
pt_desc.Save();
table.SaveToFile(path);
}
void SaveModelNaive(const std::string &model_dir,
const Scope &exec_scope,
const cpp::ProgramDesc &cpp_prog) {
const cpp::ProgramDesc &cpp_prog,
bool combined) {
MkDirRecur(model_dir);
// Save program
const std::string prog_path = model_dir + "/__model__";
const std::string prog_path = model_dir + "/__model__.nb";
naive_buffer::BinaryTable table;
naive_buffer::proto::ProgramDesc nb_proto_prog(&table);
naive_buffer::ProgramDesc nb_prog(&nb_proto_prog);
......@@ -376,14 +411,19 @@ void SaveModelNaive(const std::string &model_dir,
// Save Params
// NOTE: Only main block be used now.
auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable())
continue;
const std::string path = model_dir + "/" + var.Name();
SaveParamNaive(path, exec_scope, var.Name());
if (combined) {
const std::string combined_params_path = model_dir + "/param.nb";
SaveCombinedParamsNaive(combined_params_path, exec_scope, cpp_prog);
} else {
auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable())
continue;
const std::string path = model_dir + "/" + var.Name() + ".nb";
SaveParamNaive(path, exec_scope, var.Name());
}
}
VLOG(4) << "Save naive buffer model in '" << model_dir << "'' successfully";
}
......@@ -398,18 +438,15 @@ void SetTensorDataNaive(T *out, size_t size, const std::vector<T> &src) {
}
}
void LoadParamNaive(const std::string &path,
lite::Scope *scope,
const std::string &name) {
void GetParamInfoNaive(const naive_buffer::ParamDesc &desc,
lite::Scope *scope,
const std::string &name) {
CHECK(scope);
auto *tensor = scope->Var(name)->GetMutable<lite::Tensor>();
CHECK_EQ(desc.Name(), name)
<< "Var name not equal: ParamDesc.name=" << desc.Name()
<< "vs filename=" << name;
// Load param
naive_buffer::BinaryTable table;
table.LoadFromFile(path);
naive_buffer::proto::ParamDesc pt_desc(&table);
pt_desc.Load();
naive_buffer::ParamDesc desc(&pt_desc);
auto *tensor = scope->Var(name)->GetMutable<lite::Tensor>();
VLOG(3) << "model version " << desc.ModelVersion();
CHECK_EQ(desc.TensorVersion(), 0U) << "Only version 0 is supported";
......@@ -442,15 +479,56 @@ void LoadParamNaive(const std::string &path,
}
}
void LoadParamNaive(const std::string &path,
lite::Scope *scope,
const std::string &name) {
// Load param
naive_buffer::BinaryTable table;
table.LoadFromFile(path);
naive_buffer::proto::ParamDesc pt_desc(&table);
pt_desc.Load();
naive_buffer::ParamDesc desc(&pt_desc);
GetParamInfoNaive(desc, scope, name);
}
void LoadCombinedParamsNaive(const std::string &path,
lite::Scope *scope,
const cpp::ProgramDesc &cpp_prog) {
naive_buffer::BinaryTable table;
table.LoadFromFile(path);
naive_buffer::proto::CombinedParamsDesc pt_desc(&table);
pt_desc.Load();
naive_buffer::CombinedParamsDesc desc(&pt_desc);
std::set<std::string> param_names;
for (size_t i = 0; i < desc.ParamsSize(); ++i) {
naive_buffer::ParamDesc param_desc(desc.GetParam(i));
GetParamInfoNaive(param_desc, scope, param_desc.Name());
param_names.insert(param_desc.Name());
}
// Check all params loaded
auto prog = cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable())
continue;
CHECK(param_names.count(var.Name())) << "Persistable var[" << var.Name()
<< "] not found";
}
}
void LoadModelNaive(const std::string &model_dir,
Scope *scope,
cpp::ProgramDesc *cpp_prog) {
cpp::ProgramDesc *cpp_prog,
bool combined) {
CHECK(cpp_prog);
CHECK(scope);
cpp_prog->ClearBlocks();
// Load model
const std::string prog_path = model_dir + "/__model__";
const std::string prog_path = model_dir + "/__model__.nb";
naive_buffer::BinaryTable table;
table.LoadFromFile(prog_path);
naive_buffer::proto::ProgramDesc nb_proto_prog(&table);
......@@ -462,26 +540,33 @@ void LoadModelNaive(const std::string &model_dir,
// Load Params
// NOTE: Only main block be used now.
auto &prog = *cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable())
continue;
std::string file_path = model_dir + "/" + var.Name();
VLOG(4) << "reading weight " << var.Name();
switch (var.GetType()) {
case VarDescAPI::Type::LOD_TENSOR:
LoadParamNaive(file_path, scope, var.Name());
break;
default:
CHECK(false) << "unknown weight type";
if (combined) {
const std::string combined_params_path = model_dir + "/param.nb";
LoadCombinedParamsNaive(combined_params_path, scope, *cpp_prog);
} else {
auto &prog = *cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.VarsSize(); ++i) {
auto &var = *main_block_desc.GetVar<cpp::VarDesc>(i);
if (var.Name() == "feed" || var.Name() == "fetch" || !var.Persistable())
continue;
std::string file_path = model_dir + "/" + var.Name() + ".nb";
VLOG(4) << "reading weight " << var.Name();
switch (var.GetType()) {
case VarDescAPI::Type::LOD_TENSOR:
LoadParamNaive(file_path, scope, var.Name());
break;
default:
CHECK(false) << "unknown weight type";
}
}
}
#ifdef LITE_WITH_NPU
auto &prog = *cpp_prog;
auto &main_block_desc = *prog.GetBlock<cpp::BlockDesc>(0);
for (size_t i = 0; i < main_block_desc.OpsSize(); ++i) {
auto &op = *main_block_desc.GetOp<cpp::OpDesc>(i);
if (op.Type() != "graph_op") {
......
......@@ -66,18 +66,28 @@ void SaveParamNaive(const std::string& path,
const lite::Scope& exec_scope,
const std::string& var_name);
void SaveCombinedParamsNaive(const std::string& path,
const lite::Scope& exec_scope,
const cpp::ProgramDesc& cpp_prog);
void SaveModelNaive(const std::string& model_dir,
const Scope& exec_scope,
const cpp::ProgramDesc& cpp_prog);
const cpp::ProgramDesc& cpp_prog,
bool combined = true);
#endif
void LoadParamNaive(const std::string& path,
lite::Scope* scope,
const std::string& name);
void LoadCombinedParamsNaive(const std::string& path,
lite::Scope* scope,
const cpp::ProgramDesc& cpp_prog);
void LoadModelNaive(const std::string& model_dir,
lite::Scope* scope,
cpp::ProgramDesc* prog);
cpp::ProgramDesc* prog,
bool combined = true);
} // namespace lite
} // namespace paddle
......@@ -5,14 +5,15 @@ add_subdirectory(proto)
lite_cc_library(nb_op_desc SRCS op_desc.cc DEPS framework_nb)
lite_cc_library(nb_var_desc SRCS var_desc.cc DEPS framework_nb)
lite_cc_library(nb_param_desc SRCS param_desc.cc DEPS framework_nb)
lite_cc_library(nb_combined_params_desc SRCS combined_params_desc.cc DEPS nb_param_desc framework_nb)
lite_cc_library(nb_block_desc SRCS block_desc.cc DEPS framework_nb)
lite_cc_library(nb_program_desc SRCS program_desc.cc DEPS framework_nb)
set(naive_wrapper
nb_op_desc nb_var_desc nb_param_desc
nb_op_desc nb_var_desc nb_param_desc nb_combined_params_desc
nb_block_desc nb_program_desc PARENT_SCOPE)
lite_cc_test(test_naive_buffer SRCS naive_buffer_test.cc DEPS naive_buffer)
lite_cc_test(test_naive_buffer_wrapper SRCS naive_buffer_wrapper_test.cc
DEPS nb_op_desc nb_var_desc nb_param_desc nb_block_desc
nb_program_desc)
DEPS nb_op_desc nb_var_desc nb_param_desc nb_combined_params_desc
nb_block_desc nb_program_desc)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/model_parser/naive_buffer/combined_params_desc.h"
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "lite/model_parser/desc_apis.h"
#include "lite/model_parser/naive_buffer/param_desc.h"
#include "lite/model_parser/naive_buffer/proto/framework.nb.h"
namespace paddle {
namespace lite {
namespace naive_buffer {
class CombinedParamsDesc {
public:
CombinedParamsDesc() = delete;
explicit CombinedParamsDesc(proto::CombinedParamsDesc *desc) : desc_(desc) {
CHECK(desc_);
}
void CopyFrom(CombinedParamsDesc &combined_params_desc) { // NOLINT
CHECK(combined_params_desc.Proto())
<< "Source proto::CombinedParamsDesc pointer can't be null";
desc_ = combined_params_desc.Proto();
}
proto::CombinedParamsDesc *Proto() { return desc_; }
const proto::CombinedParamsDesc &ReadonlyProto() const { return *desc_; }
size_t ParamsSize() const { return desc_->size(); }
void ClearParams() { desc_->Clear(); }
proto::ParamDesc *GetParam(int32_t idx) {
CHECK_LT(idx, ParamsSize()) << "idx >= params.size()";
return desc_->GetMutable(idx);
}
proto::ParamDesc *AddParam() { return desc_->New(); }
private:
proto::CombinedParamsDesc *desc_;
};
} // namespace naive_buffer
} // namespace lite
} // namespace paddle
......@@ -14,6 +14,7 @@
#include <gtest/gtest.h>
#include "lite/model_parser/naive_buffer/block_desc.h"
#include "lite/model_parser/naive_buffer/combined_params_desc.h"
#include "lite/model_parser/naive_buffer/op_desc.h"
#include "lite/model_parser/naive_buffer/param_desc.h"
#include "lite/model_parser/naive_buffer/program_desc.h"
......@@ -97,6 +98,7 @@ TEST(NaiveBufferWrapper, ParamDesc) {
ParamDesc nb_desc0(&pt_desc0);
// Set ParamDesc
nb_desc0.SetName("fc_w.0");
nb_desc0.SetModelVersion(0);
nb_desc0.SetTensorVersion(1);
std::vector<std::vector<uint64_t>> lod({{1, 2, 3}, {4, 5}});
......@@ -122,6 +124,7 @@ TEST(NaiveBufferWrapper, ParamDesc) {
pt_desc1.Load();
ParamDesc nb_desc1(&pt_desc1);
ASSERT_EQ(nb_desc1.Name(), "fc_w.0");
ASSERT_EQ(nb_desc1.ModelVersion(), 0);
ASSERT_EQ(nb_desc1.TensorVersion(), 1);
ASSERT_EQ(nb_desc1.LoDLevel(), 2);
......@@ -134,6 +137,84 @@ TEST(NaiveBufferWrapper, ParamDesc) {
}
}
TEST(NaiveBufferWrapper, CombinedParamsDesc) {
BinaryTable table0;
proto::CombinedParamsDesc pt_desc0(&table0);
CombinedParamsDesc nb_desc0(&pt_desc0);
// Set ParamDesc
ParamDesc param_desc0_0(nb_desc0.AddParam());
param_desc0_0.SetName("fc_w.0");
param_desc0_0.SetModelVersion(0);
param_desc0_0.SetTensorVersion(1);
std::vector<std::vector<uint64_t>> param_desc0_0_lod({{1, 2, 3}, {4, 5}});
param_desc0_0.SetLoDLevel(2);
param_desc0_0.SetLoD(param_desc0_0_lod);
std::vector<int64_t> param_desc0_0_dim({1, 2, 5});
param_desc0_0.SetDim(param_desc0_0_dim);
param_desc0_0.SetDataType(VarDescAPI::VarDataType::FP32);
std::vector<float> param_desc0_0_data;
for (int i = 0; i < 10; ++i) {
param_desc0_0_data.push_back(i / 10.0);
}
param_desc0_0.SetData(param_desc0_0_data);
ParamDesc param_desc0_1(nb_desc0.AddParam());
param_desc0_1.SetName("fc_b.0");
param_desc0_1.SetModelVersion(0);
param_desc0_1.SetTensorVersion(1);
std::vector<std::vector<uint64_t>> param_desc0_1_lod({{1}, {2, 3}, {4, 5}});
param_desc0_1.SetLoDLevel(3);
param_desc0_1.SetLoD(param_desc0_1_lod);
std::vector<int64_t> param_desc0_1_dim({1, 2, 2, 5});
param_desc0_1.SetDim(param_desc0_1_dim);
param_desc0_1.SetDataType(VarDescAPI::VarDataType::FP32);
std::vector<float> param_desc0_1_data;
for (int i = 0; i < 20; ++i) {
param_desc0_1_data.push_back((i - 10) / 10.0);
}
param_desc0_1.SetData(param_desc0_1_data);
// Save model
pt_desc0.Save();
table0.SaveToFile("4.bf");
// Load model
BinaryTable table1;
table1.LoadFromFile("4.bf");
proto::CombinedParamsDesc pt_desc1(&table1);
pt_desc1.Load();
CombinedParamsDesc nb_desc1(&pt_desc1);
ASSERT_EQ(nb_desc1.ParamsSize(), 2);
ParamDesc param_desc1_0(nb_desc1.GetParam(0));
ASSERT_EQ(param_desc1_0.Name(), "fc_w.0");
ASSERT_EQ(param_desc1_0.ModelVersion(), 0);
ASSERT_EQ(param_desc1_0.TensorVersion(), 1);
ASSERT_EQ(param_desc1_0.LoDLevel(), 2);
ASSERT_EQ(param_desc1_0.LoD(), param_desc0_0_lod);
ASSERT_EQ(param_desc1_0.Dim(), param_desc0_0_dim);
auto param_desc1_0_data = param_desc1_0.Data<float>();
ASSERT_EQ(param_desc1_0_data.size(), param_desc0_0_data.size());
for (size_t i = 0; i < param_desc1_0_data.size(); ++i) {
EXPECT_NEAR(param_desc1_0_data[i], param_desc0_0_data[i], 1e-6);
}
ParamDesc param_desc1_1(nb_desc1.GetParam(1));
ASSERT_EQ(param_desc1_1.Name(), "fc_b.0");
ASSERT_EQ(param_desc1_1.ModelVersion(), 0);
ASSERT_EQ(param_desc1_1.TensorVersion(), 1);
ASSERT_EQ(param_desc1_1.LoDLevel(), 3);
ASSERT_EQ(param_desc1_1.LoD(), param_desc0_1_lod);
ASSERT_EQ(param_desc1_1.Dim(), param_desc0_1_dim);
auto param_desc1_1_data = param_desc1_1.Data<float>();
ASSERT_EQ(param_desc1_1_data.size(), param_desc0_1_data.size());
for (size_t i = 0; i < param_desc1_1_data.size(); ++i) {
EXPECT_NEAR(param_desc1_1_data[i], param_desc0_1_data[i], 1e-6);
}
}
TEST(NaiveBufferWrapper, BlockDesc) {
BinaryTable table0;
proto::BlockDesc pt_desc0(&table0);
......@@ -161,11 +242,11 @@ TEST(NaiveBufferWrapper, BlockDesc) {
// Save model
pt_desc0.Save();
table0.SaveToFile("4.bf");
table0.SaveToFile("5.bf");
// Load model
BinaryTable table1;
table1.LoadFromFile("4.bf");
table1.LoadFromFile("5.bf");
proto::BlockDesc pt_desc1(&table1);
pt_desc1.Load();
BlockDesc nb_desc1(&pt_desc1);
......@@ -217,11 +298,11 @@ TEST(NaiveBufferWrapper, ProgramDesc) {
// Save model
pt_desc0.Save();
table0.SaveToFile("5.bf");
table0.SaveToFile("6.bf");
// Load model
BinaryTable table1;
table1.LoadFromFile("5.bf");
table1.LoadFromFile("6.bf");
proto::ProgramDesc pt_desc1(&table1);
pt_desc1.Load();
ProgramDesc nb_desc1(&pt_desc1);
......
......@@ -21,6 +21,16 @@ namespace paddle {
namespace lite {
namespace naive_buffer {
std::string ParamDesc::Name() const {
return desc_->GetField<StringBuilder>("name").data();
}
void ParamDesc::SetName(const std::string& name) {
auto* build = desc_->GetMutableField<StringBuilder>("name");
CHECK(build);
build->set(name);
}
uint32_t ParamDesc::ModelVersion() const { return Version("model_version"); }
void ParamDesc::SetModelVersion(uint32_t version) {
......
......@@ -40,6 +40,10 @@ class ParamDesc {
const proto::ParamDesc &ReadonlyProto() const { return *desc_; }
std::string Name() const;
void SetName(const std::string &name);
uint32_t ModelVersion() const;
void SetModelVersion(uint32_t version);
......
......@@ -185,6 +185,7 @@ class ParamDesc : public StructBuilder {
public:
using lod_type = ListBuilder<ListBuilder<UInt64Builder>>;
explicit ParamDesc(BinaryTable* table) : StructBuilder(table) {
NewStr("name");
NewUInt32("model_version");
NewUInt64("lod_level");
New<lod_type>("lod");
......@@ -194,6 +195,8 @@ class ParamDesc : public StructBuilder {
}
};
using CombinedParamsDesc = ListBuilder<ParamDesc>;
} // namespace proto
} // namespace naive_buffer
} // namespace lite
......
......@@ -13,7 +13,13 @@ lite_cc_library(npu_bridge_softmax_op SRCS softmax_op.cc DEPS ${npu_bridge_deps}
lite_cc_library(npu_bridge_pool_op SRCS pool_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_batch_norm_op SRCS batch_norm_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_elementwise_op SRCS elementwise_ops.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_reshape_op SRCS reshape_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_conv_transpose_op SRCS conv_transpose_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_bilinear_interp_op SRCS bilinear_interp_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_transpose_op SRCS transpose_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_split_op SRCS split_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_concat_op SRCS concat_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_shuffle_channel_op SRCS shuffle_channel_op.cc DEPS ${npu_bridge_deps})
set(npu_bridges
npu_bridge_registry
......@@ -27,7 +33,13 @@ set(npu_bridges
npu_bridge_pool_op
npu_bridge_batch_norm_op
npu_bridge_elementwise_op
npu_bridge_reshape_op
npu_bridge_conv_transpose_op
npu_bridge_bilinear_interp_op
npu_bridge_transpose_op
npu_bridge_split_op
npu_bridge_concat_op
npu_bridge_shuffle_channel_op
CACHE INTERNAL "npu_bridges")
lite_cc_library(npu_test_helper SRCS test_helper.cc DEPS npu_helper ${npu_ddk_libs} ${npu_bridges} ${npu_kernels} ${ops})
......@@ -41,6 +53,12 @@ lite_cc_test(test_npu_bridge_softmax_op SRCS softmax_op_test.cc DEPS npu_test_he
lite_cc_test(test_npu_bridge_pool_op SRCS pool_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_batch_norm_op SRCS batch_norm_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_elementwise_op SRCS elementwise_ops_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_reshape_op SRCS reshape_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_conv_transpose_op SRCS conv_transpose_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_bilinear_interp_op SRCS bilinear_interp_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_transpose_op SRCS transpose_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_split_op SRCS split_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_concat_op SRCS concat_op_test.cc DEPS npu_test_helper)
lite_cc_test(test_npu_bridge_shuffle_channel_op SRCS shuffle_channel_op_test.cc DEPS npu_test_helper)
message(STATUS "+++++ npu_bridges: ${npu_bridges}")
......@@ -29,14 +29,15 @@ namespace bridge {
node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
const node_map_type& inputs_map) {
VLOG(3) << "invoking ActConverter...";
auto scope = act_op->scope();
auto op_info = act_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
// create act node and set input node from inputs_map
auto x_var_name = op_info->Input("X").front();
auto act_node = std::make_shared<ge::op::Activation>(UniqueName(op_type));
auto act_node = std::make_shared<ge::op::Activation>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
act_node->set_input_x(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/utils.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
node_map_type BilinearInterpConverter(
const std::shared_ptr<lite::OpLite> interp_op,
const node_map_type& inputs_map) {
auto scope = interp_op->scope();
auto op_info = interp_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
// get input, output and attributes from lite op
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
auto x_h = x_dims[2];
auto x_w = x_dims[3];
CHECK_EQ(x_dims.size(), 4);
auto scale = op_info->GetAttr<float>("scale");
auto out_w = op_info->GetAttr<int>("out_w");
auto out_h = op_info->GetAttr<int>("out_h");
auto align_corners = op_info->GetAttr<bool>("align_corners");
auto interp_method = op_info->GetAttr<std::string>("interp_method");
int align_mode = op_info->GetAttr<int>("align_mode");
CHECK(!(align_mode == 0 && !align_corners))
<< "align_mode = 0 && align_corners = false isn't supported in NPU DDK";
// priority: OutSize > scale > out_h/out_w
if (scale > 0) {
out_h = static_cast<int>(x_h * scale);
out_w = static_cast<int>(x_w * scale);
out_h = out_h > 0 ? out_h : -1;
out_w = out_w > 0 ? out_w : -1;
}
// create interp node and set input node from inputs_map
auto interp_node = std::make_shared<ge::op::ResizeBilinear>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
interp_node->set_input_x(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(interp_node);
// update out_h and out_w if has OutSize
bool is_dyn_out_size = false;
if (HasInputArg(op_info, scope, "OutSize")) {
auto out_size_var_name = op_info->Input("OutSize").front();
if (!inputs_map.count(out_size_var_name)) {
auto out_size =
scope->FindVar(out_size_var_name)->GetMutable<lite::Tensor>();
auto out_size_dims = out_size->dims();
CHECK_EQ(out_size_dims.size(), 1);
CHECK_EQ(out_size_dims.production(), 2);
auto out_size_data = out_size->mutable_data<int>();
// update out_h and out_w if has OutSize
out_h = out_size_data[0];
out_w = out_size_data[1];
} else {
interp_node->set_input_w(*inputs_map.at(out_size_var_name));
OpList::Global().add(inputs_map.at(out_size_var_name));
is_dyn_out_size = true; // using dynamic output size
}
}
if (!is_dyn_out_size) {
CHECK_GT(out_h, 0);
CHECK_GT(out_w, 0);
const float largest_multiple = 7.0f;
float multiple = static_cast<float>(x_h * x_w) / (out_h * out_w);
CHECK_LT(multiple, largest_multiple)
<< "multiple=(ih*iw)/(oh*ow)=" << multiple
<< " is too large, should not exceed " << largest_multiple
<< " in NPU DDK";
auto w_const_node = std::make_shared<ge::op::Const>(unique_op_type + "/w");
w_const_node->set_attr_value(
CreateTensorAndFillData(std::vector<int>({out_h, out_w})));
interp_node->set_input_w(*w_const_node);
OpList::Global().add(w_const_node);
}
// set attributes
interp_node->set_attr_output_dim_mode(
2); // 0: zoom_factor, 1: shrink_factor, 2: height/width
interp_node->set_attr_align_corners(align_corners);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = interp_node;
return outputs_map;
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(bilinear_interp,
paddle::lite::npu::bridge::BilinearInterpConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/test_helper.h"
#include "lite/operators/interpolate_op.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
template <typename DType>
void bilinear_interp_ref(const std::shared_ptr<operators::InterpolateOp> op) {
auto scope = op->scope();
auto op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_dims = x->dims();
int batch_size = x_dims[0];
int channel_size = x_dims[1];
auto x_h = x_dims[2];
auto x_w = x_dims[3];
CHECK_EQ(x_dims.size(), 4);
auto scale = op_info->GetAttr<float>("scale");
auto out_w = op_info->GetAttr<int>("out_w");
auto out_h = op_info->GetAttr<int>("out_h");
auto align_corners = op_info->GetAttr<bool>("align_corners");
int align_mode = op_info->GetAttr<int>("align_mode");
auto interp_method = op_info->GetAttr<std::string>("interp_method");
// calc real out_h and out_w
if (scale > 0) {
out_h = static_cast<int>(x_h * scale);
out_w = static_cast<int>(x_w * scale);
}
if (op_info->HasInput("OutSize")) {
auto out_size_var_names = op_info->Input("OutSize");
if (out_size_var_names.size() > 0) {
auto out_size_var_name = out_size_var_names.front();
auto out_size =
scope->FindVar(out_size_var_name)->GetMutable<lite::Tensor>();
auto out_size_dims = out_size->dims();
CHECK_EQ(out_size_dims.size(), 1);
CHECK_EQ(out_size_dims.production(), 2);
auto out_size_data = out_size->mutable_data<int>();
out_h = out_size_data[0];
out_w = out_size_data[1];
}
}
CHECK_GT(out_h, 0);
CHECK_GT(out_w, 0);
out->Resize({batch_size, channel_size, out_h, out_w});
// copy from x if no change
if (x_h == out_h && x_w == out_w) {
out->CopyDataFrom(*x);
return;
}
float ratio_h = 0.f;
float ratio_w = 0.f;
if (out_h > 1) {
ratio_h = (align_corners) ? static_cast<float>(x_h - 1) / (out_h - 1)
: static_cast<float>(x_h) / out_h;
}
if (out_w > 1) {
ratio_w = (align_corners) ? static_cast<float>(x_w - 1) / (out_w - 1)
: static_cast<float>(x_w) / out_w;
}
// naive bilinear interpolation
auto x_data = x->mutable_data<DType>();
auto out_data = out->mutable_data<DType>();
bool align_flag = (align_mode == 0 && !align_corners);
std::vector<int> vy_n, vy_s;
std::vector<float> vd_n, vd_s;
vy_n.reserve(out_h);
vy_s.reserve(out_h);
vd_n.reserve(out_h);
vd_s.reserve(out_h);
for (int k = 0; k < out_h; k++) {
int yn = align_flag ? static_cast<int>(ratio_h * (k + 0.5) - 0.5)
: static_cast<int>(ratio_h * k);
yn = (yn > 0) ? yn : 0;
int ys = (yn + 1) < (x_h - 1) ? (yn + 1) : (x_h - 1);
float idx_src_y = ratio_h * (k + 0.5) - 0.5;
idx_src_y = (idx_src_y > 0) ? idx_src_y : 0;
float dn = align_flag ? idx_src_y - yn : ratio_h * k - yn;
float ds = 1.f - dn;
{
vy_n[k] = yn;
vy_s[k] = ys;
vd_n[k] = dn;
vd_s[k] = ds;
}
}
std::vector<int> vx_w, vx_e;
std::vector<float> vd_w, vd_e;
vx_w.reserve(out_w);
vx_e.reserve(out_w);
vd_w.reserve(out_w);
vd_e.reserve(out_w);
for (int l = 0; l < out_w; l++) {
int xw = (align_mode == 0 && !align_corners)
? static_cast<int>(ratio_w * (l + 0.5) - 0.5)
: static_cast<int>(ratio_w * l);
xw = (xw > 0) ? xw : 0;
int xe = (xw + 1) < (x_w - 1) ? (xw + 1) : (x_w - 1);
float idx_src_x = ratio_w * (l + 0.5) - 0.5;
idx_src_x = (idx_src_x > 0) ? idx_src_x : 0;
float dw = align_flag ? idx_src_x - xw : ratio_w * l - xw;
float de = 1.f - dw;
{
vx_w[l] = xw;
vx_e[l] = xe;
vd_w[l] = dw;
vd_e[l] = de;
}
}
std::vector<int64_t> x_strides(x_dims.size(), 1);
for (int idx = x_strides.size() - 2; idx >= 0; idx--) {
x_strides[idx] = x_strides[idx + 1] * x_dims[idx + 1];
}
for (int i = 0; i < batch_size; i++) {
for (int j = 0; j < channel_size; j++) {
for (int k = 0; k < out_h; k++) {
for (int l = 0; l < out_w; l++) {
DType x0 = x_data[i * x_strides[0] + j * x_strides[1] +
vy_n[k] * x_strides[2] + vx_w[l] * x_strides[3]];
DType x1 = x_data[i * x_strides[0] + j * x_strides[1] +
vy_s[k] * x_strides[2] + vx_w[l] * x_strides[3]];
DType x2 = x_data[i * x_strides[0] + j * x_strides[1] +
vy_n[k] * x_strides[2] + vx_e[l] * x_strides[3]];
DType x3 = x_data[i * x_strides[0] + j * x_strides[1] +
vy_s[k] * x_strides[2] + vx_e[l] * x_strides[3]];
*out_data = x0 * vd_s[k] * vd_e[l] + x1 * vd_n[k] * vd_e[l] +
x2 * vd_s[k] * vd_w[l] + x3 * vd_n[k] * vd_w[l];
out_data++;
}
}
}
}
}
void test_bilinear_interp(int bs,
int ic,
int ih,
int iw,
int oh,
int ow,
float scale,
int out_size_h,
int out_size_w,
bool align_corners,
int align_mode) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string out_size_var_name("out_size");
std::string out_var_name("out");
std::string out_ref_var_name("out_ref");
auto x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto out_size = scope.Var(out_size_var_name)->GetMutable<Tensor>();
auto out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
out_size->Resize({2});
// initialize input&output data
FillTensor<float, int>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("bilinear_interp");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("out_h", oh);
opdesc.SetAttr("out_w", ow);
opdesc.SetAttr("scale", scale);
opdesc.SetAttr("align_corners", static_cast<bool>(align_corners));
opdesc.SetAttr("align_mode", static_cast<int>(align_mode));
opdesc.SetAttr("interp_method", std::string("bilinear"));
if (out_size_h > 0 && out_size_w > 0) {
auto out_size_dims = out_size->dims();
CHECK_EQ(out_size_dims.size(), 1);
CHECK_EQ(out_size_dims.production(), 2);
auto out_size_data = out_size->mutable_data<int>();
out_size_data[0] = out_size_h;
out_size_data[1] = out_size_w;
opdesc.SetInput("OutSize", {out_size_var_name});
}
// create op and execute reference implementation
auto op = CreateOp<operators::InterpolateOp>(opdesc, &scope);
bilinear_interp_ref<float>(op);
out_ref->CopyDataFrom(*out);
// convert op to NPU model, then run it on NPU
LauchOp(op, {x_var_name}, {out_var_name});
// compare results
auto out_dims = out->dims();
auto out_ref_dims = out_ref->dims();
CHECK_EQ(out_dims.size(), out_ref_dims.size());
for (int i = 0; i < out_dims.size(); i++) {
CHECK_EQ(out_dims[i], out_ref_dims[i]);
}
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2f);
}
}
TEST(NPUBridges, bilinear_interp) {
#if 1
for (auto bs : {1, 3}) {
for (auto ic : {3, 4}) {
for (auto ih : {4, 5}) {
for (auto iw : {3, 6}) {
for (auto oh : {0, 3, 8}) {
for (auto ow : {0, 4, 9}) {
for (auto scale : {0.f, 0.5f, 0.6f, 2.0f, 2.2f}) {
for (auto out_size_h : {0, 3, 11}) {
for (auto out_size_w : {0, 2, 12}) {
for (auto align_corners : {true, false}) {
for (auto align_mode : {0, 1}) {
int act_oh = 0, act_ow = 0;
if (out_size_h > 0 && out_size_w > 0) {
act_oh = out_size_h;
act_ow = out_size_w;
} else if (scale > 1e-5) {
act_oh = static_cast<int>(ih * scale);
act_ow = static_cast<int>(iw * scale);
} else if (oh > 0 && ow > 0) {
act_oh = oh;
act_ow = ow;
}
if (act_oh <= 0 || act_ow <= 0) {
continue;
}
// TODO(hong19860320) multiple=(ih*iw)/(oh*ow) should
// not exceed 7.0 in NPU DDK, delete the following lines
// if the limination is removed.
const float largest_multiple = 7.0f;
float multiple =
static_cast<float>(ih * iw) / (act_oh * act_ow);
if (multiple > largest_multiple) {
continue;
}
if (align_mode == 0 && !align_corners) {
continue;
}
VLOG(3)
<< "bs: " << bs << " ic: " << ic << " ih: " << ih
<< " iw: " << iw << " oh: " << oh << " ow: " << ow
<< " scale: " << scale
<< " out_size: " << out_size_h << "," << out_size_w
<< " align_corners: " << align_corners
<< " align_mode: " << align_mode;
test_bilinear_interp(bs,
ic,
ih,
iw,
oh,
ow,
scale,
out_size_h,
out_size_w,
align_corners,
align_mode);
}
}
}
}
}
}
}
}
}
}
}
#else
test_bilinear_interp(3, 4, 5, 3, 8, 4, 0.6f, 3, 0, true, 0);
#endif
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
USE_LITE_OP(bilinear_interp);
USE_NPU_BRIDGE(bilinear_interp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/concat_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/utils.h"
#include "lite/npu/npu_helper.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
node_map_type ConcatConverter(const std::shared_ptr<lite::OpLite> concat_op,
const node_map_type& inputs_map) {
lite::Scope* scope = concat_op->scope();
const lite::OpInfo* op_info = concat_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
LOG(INFO) << "converting " << op_type << " ... ";
auto x_var_names = op_info->Input("X");
auto axis = op_info->GetAttr<int>("axis");
int num = x_var_names.size();
int index = 0;
std::shared_ptr<ge::op::Concat> output_node =
std::make_shared<ge::op::Concat>(unique_op_type);
output_node->set_attr_axis(axis);
output_node->set_attr_N(num);
output_node->create_dynamic_input_x(num);
for (auto x_var_name : x_var_names) {
if (inputs_map.find(x_var_name) != inputs_map.end()) {
output_node->set_dynamic_input_x(index + 1, *inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
} else {
auto consty = std::make_shared<ge::op::Const>(x_var_name);
auto* x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
consty->set_attr_value(CvtFromLiteTensor(x));
output_node->set_dynamic_input_x(index + 1, *consty);
OpList::Global().add(consty);
}
index++;
}
OpList::Global().add(output_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = output_node;
return outputs_map;
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(concat, paddle::lite::npu::bridge::ConcatConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/concat_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/test_helper.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
std::vector<size_t> stride_numel(const DDim& ddim) {
std::vector<size_t> strides(ddim.size());
strides[ddim.size() - 1] = ddim[ddim.size() - 1];
for (int i = ddim.size() - 2; i >= 0; --i) {
strides[i] = strides[i + 1] * ddim[i];
}
return strides;
}
void concat_ref(const std::shared_ptr<operators::ConcatOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = op_info->Input("X");
std::vector<lite::Tensor*> inputs;
for (auto var : x) {
inputs.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
int axis = op_info->GetAttr<int>("axis");
std::vector<lite::Tensor*> inputs_concat(inputs.size());
for (int j = 0; j < inputs.size(); ++j) {
inputs_concat[j] = inputs[j];
}
size_t num = inputs.size();
int rows = 1;
auto dim_0 = inputs[0]->dims();
for (int i = 0; i < axis; ++i) {
rows *= dim_0[i];
}
int out_rows = rows, out_cols = 0;
std::vector<int64_t> inputs_cols(inputs.size());
for (int i = 0; i < num; ++i) {
int t_cols = inputs[i]->numel() / rows;
out_cols += t_cols;
inputs_cols[i] = t_cols;
}
for (int k = 0; k < out_rows; ++k) {
float* dst_ptr = out->mutable_data<float>() + k * out_cols;
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = inputs_cols[j];
const float* src_prt = inputs[j]->data<float>() + k * col_len;
std::memcpy(dst_ptr + col_idx, src_prt, sizeof(float) * col_len);
col_idx += col_len;
}
}
}
void test_concat(std::vector<vector<int64_t>> input, int axis) {
std::string x_var_name = "x";
std::string y_var_name = "y";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
// prepare input&output variables
Scope scope;
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* y = scope.Var(y_var_name)->GetMutable<Tensor>();
x->Resize(DDim(input[0]));
y->Resize(DDim(input[1]));
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
CHECK_EQ(out->dims(), out_ref->dims());
// initialize input&output data
FillTensor<float>(x);
FillTensor<float>(y);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("concat");
opdesc.SetInput("X", {x_var_name, y_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("axis", axis);
auto op = CreateOp<operators::ConcatOpLite>(opdesc, &scope);
LauchOp(op, {x_var_name, y_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
concat_ref(op);
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 5e-4);
}
}
TEST(NPUBridges, concat) {
test_concat({{3, 3, 5, 2}, {2, 3, 5, 2}}, 0);
test_concat({{3, 5, 5, 2}, {3, 1, 5, 2}}, 1);
test_concat({{3, 3, 2, 2}, {3, 3, 4, 2}}, 2);
test_concat({{3, 3, 5, 2}, {3, 3, 5, 6}}, 3);
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
USE_LITE_OP(concat);
USE_NPU_BRIDGE(concat);
......@@ -33,20 +33,16 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
auto op_info = conv_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << " ... ";
LOG(INFO) << "Converting " << op_type << "... ";
// get input, output and op attributes
// get input, filter and op attributes
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_dims = input->dims();
auto output_var_name = op_info->Output("Output").front();
auto output = scope->FindVar(output_var_name)->GetMutable<lite::Tensor>();
auto output_dims = output->dims();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
auto filter_dims = filter->dims();
CHECK_EQ(input_dims.size(), 4);
CHECK_EQ(output_dims.size(), 4);
CHECK_EQ(filter_dims.size(), 4);
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
......@@ -89,33 +85,9 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto channel_size = bias->dims().production();
CHECK_EQ(channel_size, filter_dims[0]);
CHECK_EQ(channel_size, output_dims[1]);
bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
if (use_depthwise_conv && is_depthwise_mode) {
// broadcast bias(1, oc, 1, 1) to (n, oc, oh, ow)
ge::TensorDesc bias_desc(
ge::Shape(output_dims.Vectorize()), ge::FORMAT_NCHW, ge::DT_FLOAT);
ge::TensorPtr bias_tensor = std::make_shared<ge::Tensor>();
bias_tensor->SetTensorDesc(bias_desc);
auto old_bias_data = bias->mutable_data<float>();
std::vector<float> new_bias_data(output_dims.production());
int batch_size = output_dims[0];
int inner_size = output_dims[2] * output_dims[3];
for (int k = 0; k < batch_size; k++) {
for (int j = 0; j < channel_size; j++) {
for (int i = 0; i < inner_size; i++) {
new_bias_data[i + j * inner_size + k * channel_size * inner_size] =
old_bias_data[j];
}
}
}
bias_tensor->SetData(reinterpret_cast<uint8_t*>(new_bias_data.data()),
new_bias_data.size() * sizeof(float));
bias_const_node->set_attr_value(bias_tensor);
} else {
bias_const_node->set_attr_value(
CvtFromLiteTensor(bias, {1, channel_size, 1, 1}));
}
bias_const_node->set_attr_value(
CvtFromLiteTensor(bias, {1, channel_size, 1, 1}));
OpList::Global().add(bias_const_node);
}
......@@ -142,13 +114,11 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
OpList::Global().add(depthwise_conv_node);
conv_node = depthwise_conv_node;
if (bias_const_node != nullptr) {
auto eltwise_add_node =
std::make_shared<ge::op::Eltwise>(unique_op_type + "/eltwise_add");
eltwise_add_node->set_input_x1(*depthwise_conv_node);
eltwise_add_node->set_input_x2(*bias_const_node);
eltwise_add_node->set_attr_mode(1); // 0:product, 1:sum, 2:max
OpList::Global().add(eltwise_add_node);
conv_node = eltwise_add_node;
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
add_node->set_input_x1(*depthwise_conv_node);
add_node->set_input_x2(*bias_const_node);
OpList::Global().add(add_node);
conv_node = add_node;
}
} else {
auto common_conv_node =
......@@ -182,9 +152,9 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
relu_node->set_input_x(*conv_node);
relu_node->set_attr_mode(1);
OpList::Global().add(relu_node);
outputs_map[output_var_name] = relu_node;
outputs_map[op_info->Output("Output").front()] = relu_node;
} else {
outputs_map[output_var_name] = conv_node;
outputs_map[op_info->Output("Output").front()] = conv_node;
}
return outputs_map;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/conv_transpose_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/utils.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
node_map_type ConvTransposeConverter(
const std::shared_ptr<lite::OpLite> conv_transpose_op,
const node_map_type& inputs_map) {
auto scope = conv_transpose_op->scope();
auto op_info = conv_transpose_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << "... ";
// get input, output and op attributes
auto input_var_name = op_info->Input("Input").front();
auto input = scope->FindVar(input_var_name)->GetMutable<lite::Tensor>();
auto input_shape = input->dims().Vectorize();
auto filter_var_name = op_info->Input("Filter").front();
auto filter = scope->FindVar(filter_var_name)->GetMutable<lite::Tensor>();
auto filter_shape = filter->dims().Vectorize();
CHECK_EQ(input_shape.size(), 4);
CHECK_EQ(filter_shape.size(), 4);
auto strides = op_info->GetAttr<std::vector<int>>("strides");
auto paddings = op_info->GetAttr<std::vector<int>>("paddings");
auto groups = op_info->GetAttr<int>("groups");
auto dilations = op_info->GetAttr<std::vector<int>>("dilations");
auto fuse_relu = op_info->GetAttr<bool>("fuse_relu");
CHECK_EQ(strides.size(), 2);
CHECK_EQ(paddings.size(), 2);
CHECK_EQ(dilations.size(), 2);
// create deconv node
auto conv_transpose_node =
std::make_shared<ge::op::Deconvolution>(unique_op_type);
// create input sizes node to describe the dimensions of input tensor
std::vector<int32_t> output_shape;
output_shape.push_back(input_shape[0]);
output_shape.push_back(filter_shape[1] * groups);
for (int i = 0; i < strides.size(); i++) {
int kernel_ext = dilations[i] * (filter_shape[i + 2] - 1) + 1;
int output_size =
(input_shape[i + 2] - 1) * strides[i] + kernel_ext - 2 * paddings[i];
output_shape.push_back(output_size);
}
auto input_sizes_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/input_size");
input_sizes_const_node->set_attr_value(CreateTensorAndFillData(output_shape));
conv_transpose_node->set_input_input_sizes(*input_sizes_const_node);
OpList::Global().add(input_sizes_const_node);
// create filter node
CHECK(!inputs_map.count(filter_var_name));
auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
filter_const_node->set_attr_value(CvtFromLiteTensor(filter));
conv_transpose_node->set_input_filter(*filter_const_node);
OpList::Global().add(filter_const_node);
// set input node
CHECK(inputs_map.count(input_var_name));
conv_transpose_node->set_input_x(*inputs_map.at(input_var_name));
OpList::Global().add(inputs_map.at(input_var_name));
// set attributes
conv_transpose_node->set_attr_mode(1);
conv_transpose_node->set_attr_format(0); // NCHW
conv_transpose_node->set_attr_pad_mode(0); // NOTSET
conv_transpose_node->set_attr_group(groups);
conv_transpose_node->set_attr_pad(ge::AttrValue::LIST_INT(
{paddings[0], paddings[0], paddings[1], paddings[1]}));
conv_transpose_node->set_attr_dilation(
ge::AttrValue::LIST_INT({dilations[0], dilations[1]}));
conv_transpose_node->set_attr_stride(
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
conv_transpose_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_shape[2], filter_shape[3]}));
OpList::Global().add(conv_transpose_node);
// append add node to add bias if has bias
std::shared_ptr<ge::Operator> output_node = conv_transpose_node;
if (HasInputArg(op_info, scope, "Bias")) {
// create bias node
auto bias_var_name = op_info->Input("Bias").front();
CHECK(!inputs_map.count(bias_var_name));
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto channel_size = bias->dims().production();
CHECK_EQ(channel_size, filter_shape[1] * groups);
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(
CvtFromLiteTensor(bias, {1, channel_size, 1, 1}));
OpList::Global().add(bias_const_node);
// append add node to add bias node
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
add_node->set_input_x1(*conv_transpose_node);
add_node->set_input_x2(*bias_const_node);
OpList::Global().add(add_node);
output_node = add_node;
}
node_map_type outputs_map;
if (fuse_relu) {
// append relu node if fuse_relu is true
auto relu_node =
std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
relu_node->set_input_x(*output_node);
relu_node->set_attr_mode(1);
OpList::Global().add(relu_node);
outputs_map[op_info->Output("Output").front()] = relu_node;
} else {
outputs_map[op_info->Output("Output").front()] = output_node;
}
return outputs_map;
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(conv2d_transpose,
paddle::lite::npu::bridge::ConvTransposeConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/conv_transpose_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/test_helper.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
template <typename DType>
void add_bias_with_relu(DType* data,
const DType* bias,
int channel_size,
int inner_size,
bool has_relu) {
for (int c = 0; c < channel_size; ++c) {
DType bias_val = bias != nullptr ? bias[c] : 0;
for (int i = 0; i < inner_size; i++) {
DType data_val = data[i];
data_val += bias_val;
if (has_relu) {
data_val = data_val > 0 ? data_val : 0.f;
}
data[i] = data_val;
}
data += inner_size;
}
}
template <typename DType>
void col2im(const DType* data_col,
const int channel_size,
const int height,
const int width,
const int kernel_h,
const int kernel_w,
const int pad_h,
const int pad_w,
const int stride_h,
const int stride_w,
const int dilation_h,
const int dilation_w,
DType* data_im) {
memset(data_im, 0, height * width * channel_size * sizeof(DType));
const int output_h =
(height + 2 * pad_h - (dilation_h * (kernel_h - 1) + 1)) / stride_h + 1;
const int output_w =
(width + 2 * pad_w - (dilation_w * (kernel_w - 1) + 1)) / stride_w + 1;
const int inner_size = height * width;
for (int c = channel_size; c--; data_im += inner_size) {
for (int kernel_row = 0; kernel_row < kernel_h; kernel_row++) {
for (int kernel_col = 0; kernel_col < kernel_w; kernel_col++) {
int input_row = -pad_h + kernel_row * dilation_h;
for (int output_rows = output_h; output_rows; output_rows--) {
if (input_row < 0 || input_row >= height) {
data_col += output_w;
} else {
int input_col = -pad_w + kernel_col * dilation_w;
for (int output_col = output_w; output_col; output_col--) {
if (input_col >= 0 && input_col < width) {
data_im[input_row * width + input_col] += *data_col;
}
data_col++;
input_col += stride_w;
}
}
input_row += stride_h;
}
}
}
}
}
template <typename IType, typename OType>
void gemm(int M,
int N,
int K,
const IType* A,
const IType* B,
OType* C,
OType alpha,
OType beta,
bool is_trans_A = false,
bool is_trans_B = false) {
for (int m = 0; m < M; ++m) {
for (int n = 0; n < N; ++n) {
OType sum = static_cast<OType>(0);
for (int k = 0; k < K; ++k) {
IType a;
IType b;
if (is_trans_A) {
a = A[k * M + m];
} else {
a = A[m * K + k];
}
if (is_trans_B) {
b = B[n * K + k];
} else {
b = B[k * N + n];
}
sum += a * b;
}
C[m * N + n] = alpha * sum + beta * C[m * N + n];
}
}
}
template <typename IType, typename OType>
void conv_transpose_ref(
const std::shared_ptr<operators::ConvTransposeOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto input =
scope->FindVar(op_info->Input("Input").front())->GetMutable<Tensor>();
auto filter =
scope->FindVar(op_info->Input("Filter").front())->GetMutable<Tensor>();
auto output =
scope->FindVar(op_info->Output("Output").front())->GetMutable<Tensor>();
std::vector<int32_t> strides =
op_info->GetAttr<std::vector<int32_t>>("strides");
std::vector<int32_t> paddings =
op_info->GetAttr<std::vector<int32_t>>("paddings");
int32_t groups = op_info->GetAttr<int32_t>("groups");
std::vector<int32_t> dilations =
op_info->GetAttr<std::vector<int32_t>>("dilations");
bool fuse_relu = op_info->GetAttr<bool>("fuse_relu");
Tensor* bias = nullptr;
OType* bias_data = nullptr;
if (op_info->HasInput("Bias")) {
auto bias_var_names = op_info->Input("Bias");
if (bias_var_names.size() > 0) {
auto bias_var_name = bias_var_names.front();
bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
bias_data = bias->mutable_data<OType>();
}
}
auto input_dims = input->dims();
auto filter_dims = filter->dims();
auto output_dims = output->dims();
auto input_data = input->mutable_data<IType>();
auto filter_data = filter->mutable_data<IType>();
auto output_data = output->mutable_data<OType>();
int kernel_w = filter_dims[3];
int kernel_h = filter_dims[2];
int stride_w = strides[1];
int stride_h = strides[0];
int dila_w = dilations[1];
int dila_h = dilations[0];
int pad_w = paddings[1];
int pad_h = paddings[0];
int batch_size = input_dims[0];
int in_ch_size = input_dims[1];
int in_h = input_dims[2];
int in_w = input_dims[3];
int out_ch_size = output_dims[1];
int out_h = output_dims[2];
int out_w = output_dims[3];
int M = out_ch_size * kernel_w * kernel_h / groups;
int N = in_h * in_w;
int K = in_ch_size / groups;
if (in_ch_size != out_ch_size || groups != in_ch_size) {
CHECK_EQ(in_ch_size % groups, 0);
CHECK_EQ(out_ch_size % groups, 0);
}
auto workspace = std::vector<OType>(groups * M * N);
int group_input_size = in_w * in_h * in_ch_size / groups;
int group_output_size = out_w * out_h * out_ch_size / groups;
int group_col_size = M * N;
int group_filter_size =
in_ch_size * out_ch_size * kernel_w * kernel_h / (groups * groups);
bool flag_1x1s1p1 = (kernel_w == 1) && (kernel_h == 1) && (stride_h == 1) &&
(stride_w == 1) && (pad_w == 1) && (pad_h == 1) &&
(dila_w == 1) && (dila_h == 1);
for (int n = 0; n < batch_size; ++n) {
input_data += n * in_ch_size * in_h * in_w;
output_data += n * out_ch_size * out_h * out_w;
auto col_data = workspace.data();
if (flag_1x1s1p1) {
col_data = output_data;
}
memset(col_data, 0, sizeof(OType) * group_col_size);
for (int g = 0; g < groups; ++g) {
auto input_group_data = input_data + g * group_input_size;
auto filter_group_data = filter_data + g * group_filter_size;
auto col_group_data = col_data + g * group_col_size;
gemm<IType, OType>(M,
N,
K,
filter_group_data,
input_group_data,
col_group_data,
static_cast<OType>(1),
static_cast<OType>(0),
true,
false);
}
if (!flag_1x1s1p1) {
col2im(col_data,
out_ch_size,
out_h,
out_w,
kernel_h,
kernel_w,
pad_h,
pad_w,
stride_h,
stride_w,
dila_h,
dila_w,
output_data);
}
add_bias_with_relu(
output_data, bias_data, out_ch_size, out_w * out_h, fuse_relu);
}
}
void test_conv_transpose(int bs,
int ic,
int ih,
int iw,
bool has_bias,
bool fuse_relu,
int filters,
int groups,
int dilation,
int stride,
int padding,
int kernel) {
// prepare input&output variables
Scope scope;
std::string input_var_name("input");
std::string filter_var_name("filter");
std::string bias_var_name("bias");
std::string output_var_name("output");
std::string output_ref_var_name("output_ref");
auto* input = scope.Var(input_var_name)->GetMutable<Tensor>();
auto* filter = scope.Var(filter_var_name)->GetMutable<Tensor>();
auto* bias = scope.Var(bias_var_name)->GetMutable<Tensor>();
auto* output = scope.Var(output_var_name)->GetMutable<Tensor>();
auto* output_ref = scope.Var(output_ref_var_name)->GetMutable<Tensor>();
// get group size and input&filter shape
std::vector<int64_t> input_shape = {bs, ic, ih, iw};
std::vector<int64_t> filter_shape = {ic, filters, kernel, kernel};
input->Resize(input_shape);
filter->Resize(filter_shape);
// initialize input&output data
FillTensor<float, int>(input);
FillTensor<float, int>(filter);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("conv2d_transpose");
opdesc.SetInput("Input", {input_var_name});
opdesc.SetInput("Filter", {filter_var_name});
opdesc.SetOutput("Output", {output_var_name});
opdesc.SetAttr("dilations", std::vector<int32_t>({dilation, dilation}));
opdesc.SetAttr("strides", std::vector<int32_t>({stride, stride}));
opdesc.SetAttr("paddings", std::vector<int32_t>({padding, padding}));
opdesc.SetAttr("groups", groups);
opdesc.SetAttr("fuse_relu", static_cast<bool>(fuse_relu));
if (has_bias) {
bias->Resize({1, filters * groups, 1, 1});
FillTensor<float, int>(bias);
opdesc.SetInput("Bias", {bias_var_name});
}
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ConvTransposeOpLite>(opdesc, &scope);
LauchOp(op, {input_var_name}, {output_var_name});
output_ref->CopyDataFrom(*output);
// execute reference implementation and save to output tensor('out')
conv_transpose_ref<float, float>(op);
// compare results
auto* output_data = output->mutable_data<float>();
auto* output_ref_data = output_ref->mutable_data<float>();
for (int i = 0; i < output->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(output_data[i], output_ref_data[i], 1e-5);
}
}
TEST(NPUBridges, conv_transpose) {
#if 1
for (auto bs : {1, 2}) {
for (auto ic : {3, 6}) {
for (auto ih : {14, 28}) {
for (auto iw : {14, 28}) {
for (auto has_bias : {false, true}) {
for (auto fuse_relu : {false, true}) {
for (auto filters : {1, 2, 5}) {
for (auto groups : {1 /* , 2, 5*/}) {
for (auto dilation : {1, 2}) {
for (auto stride : {1, 2}) {
for (auto kernel : {1, 3, 5}) {
std::vector<int> paddings = {kernel / 2};
if (kernel / 2 != 0) {
paddings.push_back(0);
}
for (auto padding : paddings) {
VLOG(3) << "bs: " << bs << " ic: " << ic
<< " ih: " << ih << " iw: " << iw
<< " has_bias: " << has_bias
<< " fuse_relu: " << fuse_relu
<< " filters: " << filters
<< " groups: " << groups
<< " dilation: " << dilation
<< " stride: " << stride
<< " padding: " << padding
<< " kernel: " << kernel;
test_conv_transpose(bs,
ic,
ih,
iw,
has_bias,
fuse_relu,
filters,
groups,
dilation,
stride,
padding,
kernel);
}
}
}
}
}
}
}
}
}
}
}
}
#else
test_conv_transpose(1, 6, 8, 8, false, false, 5, 2, 1, 1, 1, 3);
#endif
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
USE_LITE_OP(conv2d_transpose);
USE_NPU_BRIDGE(conv2d_transpose);
......@@ -25,3 +25,8 @@ USE_NPU_BRIDGE(relu);
USE_NPU_BRIDGE(elementwise_add);
USE_NPU_BRIDGE(scale);
USE_NPU_BRIDGE(softmax);
USE_NPU_BRIDGE(concat);
USE_NPU_BRIDGE(split);
USE_NPU_BRIDGE(transpose);
USE_NPU_BRIDGE(transpose2);
USE_NPU_BRIDGE(shuffle_channel);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reshape_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/utils.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
const node_map_type& inputs_map) {
auto scope = reshape_op->scope();
auto op_info = reshape_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
// get input, output and op attributes
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
// create reshape node and set input node from inputs_map
auto reshape_node = std::make_shared<ge::op::Reshape>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
reshape_node->set_input_tensor(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
// read shape from actual shape tensor as input "w" if 'Shape' is found
if (HasInputArg(op_info, scope, "Shape")) {
auto actual_shape_var_name = op_info->Input("Shape").front();
if (!inputs_map.count(actual_shape_var_name)) {
auto actual_shape =
scope->FindVar(actual_shape_var_name)->GetMutable<lite::Tensor>();
auto actual_shape_dims = actual_shape->dims();
auto actual_shape_data = actual_shape->mutable_data<int>();
auto shape =
std::vector<int>(actual_shape_data,
actual_shape_data + actual_shape_dims.production());
auto out_dims = operators::ValidateShape(shape, x_dims);
auto out_shape = out_dims.Vectorize();
if (out_shape.size() > 4) {
LOG(WARNING)
<< "NPU DDK only supports less than 4 dimensions, but Shape has "
<< out_shape.size();
}
auto actual_shape_const_node =
std::make_shared<ge::op::Const>(actual_shape_var_name);
actual_shape_const_node->set_attr_value(CreateTensorAndFillData(
std::vector<int>(out_shape.begin(), out_shape.end())));
reshape_node->set_input_w(*actual_shape_const_node);
OpList::Global().add(actual_shape_const_node);
} else {
reshape_node->set_input_w(*inputs_map.at(actual_shape_var_name));
OpList::Global().add(inputs_map.at(actual_shape_var_name));
}
} else {
auto shape = op_info->GetAttr<std::vector<int>>("shape");
auto out_dims = operators::ValidateShape(shape, x_dims);
auto out_shape = out_dims.Vectorize();
if (out_shape.size() > 4) {
LOG(WARNING)
<< "NPU DDK only supports less than 4 dimensions, but shape has "
<< out_shape.size();
}
reshape_node->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
}
OpList::Global().add(reshape_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = reshape_node;
if (op_type == "reshape2") {
// append an extra reshape node to calc XShape
std::vector<int64_t> xshape_dims(x_dims.size() + 1, 1);
for (size_t i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i];
}
if (xshape_dims.size() > 4) {
LOG(WARNING)
<< "NPU DDK only supports less than 4 dimensions, but XShape has "
<< xshape_dims.size();
}
auto xshape_node =
std::make_shared<ge::op::Reshape>(unique_op_type + "/xshape");
xshape_node->set_input_tensor(*inputs_map.at(x_var_name));
xshape_node->set_attr_shape(
ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end()));
OpList::Global().add(xshape_node);
outputs_map[op_info->Output("XShape").front()] = xshape_node;
}
return outputs_map;
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(reshape, paddle::lite::npu::bridge::ReshapeConverter);
REGISTER_NPU_BRIDGE(reshape2, paddle::lite::npu::bridge::ReshapeConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/reshape_op.h"
#include <gtest/gtest.h>
#include <random>
#include "lite/core/op_registry.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/test_helper.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
void reshape_ref(const std::shared_ptr<lite::OpLite> op) {
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_dims = x->dims();
auto shape = op_info->GetAttr<std::vector<int>>("shape");
auto inplace = op_info->GetAttr<bool>("inplace");
if (op_info->HasInput("Shape")) {
auto actual_shape_var_names = op_info->Input("Shape");
if (actual_shape_var_names.size() > 0) {
auto actual_shape = scope->FindVar(actual_shape_var_names.front())
->GetMutable<lite::Tensor>();
auto actual_shape_dims = actual_shape->dims();
auto* actual_shape_data = actual_shape->data<int>();
shape =
std::vector<int>(actual_shape_data,
actual_shape_data + actual_shape_dims.production());
}
}
if (inplace) {
out->ShareDataWith(*x);
} else {
out->CopyDataFrom(*x);
}
auto out_dims = operators::ValidateShape(shape, x_dims);
out->Resize(out_dims);
}
void test_reshape(const std::vector<int64_t>& x_shape,
const std::vector<int>& shape,
const std::vector<int>& act_shape,
bool inplace,
bool reshape2) {
// prepare input&output variables
Scope scope;
std::string x_var_name("x");
std::string actual_shape_var_name("actual_shape");
std::string out_var_name("out");
std::string out_ref_var_name("out_ref");
std::string xshape_var_name("xshape");
std::string xshape_ref_var_name("xshape_ref");
auto x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto actual_shape = scope.Var(actual_shape_var_name)->GetMutable<Tensor>();
auto out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
auto xshape = scope.Var(xshape_var_name)->GetMutable<Tensor>();
auto xshape_ref = scope.Var(xshape_ref_var_name)->GetMutable<Tensor>();
x->Resize(x_shape);
// initialize input&output data
FillTensor<float, int>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType(reshape2 ? "reshape2" : "reshape");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("shape", shape);
opdesc.SetAttr("inplace", inplace);
if (!act_shape.empty()) {
int64_t act_shape_size = act_shape.size();
actual_shape->Resize({act_shape_size});
memcpy(actual_shape->mutable_data<int>(),
act_shape.data(),
act_shape_size * sizeof(int));
opdesc.SetInput("Shape", {actual_shape_var_name});
}
if (reshape2) {
opdesc.SetOutput("XShape", {xshape_var_name});
}
// create op and execute reference implementation
auto op = reshape2 ? CreateOp<operators::Reshape2Op>(opdesc, &scope)
: CreateOp<operators::ReshapeOp>(opdesc, &scope);
reshape_ref(op);
out_ref->CopyDataFrom(*out);
if (reshape2) {
xshape_ref->CopyDataFrom(*xshape);
}
// convert op to NPU model, then run it on NPU
LauchOp(op,
{x_var_name},
{out_var_name}); // TODO(hong19860320) support XShape for reshape2
// compare results
auto out_dims = out->dims();
auto out_ref_dims = out_ref->dims();
CHECK_EQ(out_dims.size(), out_ref_dims.size());
for (int i = 0; i < out_dims.size(); i++) {
CHECK_EQ(out_dims[i], out_ref_dims[i]);
}
auto out_data = out->mutable_data<float>();
auto out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-5);
}
// if (reshape2) {
// auto xshape_dims = xshape->dims();
// auto xshape_ref_dims = xshape_ref->dims();
// CHECK_EQ(xshape_dims.size(), xshape_ref_dims.size());
// for (size_t i = 0; i < xshape_dims.size(); i++) {
// CHECK_EQ(xshape_dims[i], xshape_ref_dims[i]);
// }
// }
}
TEST(NPUBridges, reshape) {
#if 1
std::map<std::vector<int64_t>, std::vector<std::vector<int>>> tests = {
{{1, 2, 4, 6},
{{},
{-1},
{48},
{-1, 48},
{1, 48},
{0, 48},
{48, -1},
{48, 1},
{-1, 24},
{2, 24},
{24, 0},
{-1, 0, 3, 2},
{4, 2, 3, 2},
{0, -1, 3, 2},
{1, 8, 3, 2}}}};
for (auto& i : tests) {
for (auto& shape : i.second) {
if (shape.empty()) {
continue;
}
for (auto& act_shape : i.second) {
for (auto& inplace : {true, false}) {
for (auto& reshape2 : {true, false}) {
std::stringstream ss;
ss << "x:{ ";
for (auto s : i.first) {
ss << s << " ";
}
ss << "} shape:{ ";
for (auto s : shape) {
ss << s << " ";
}
ss << "} act_shape:{ ";
for (auto s : act_shape) {
ss << s << " ";
}
VLOG(3) << ss.str() << "} inplace:" << inplace
<< " reshape2:" << reshape2;
test_reshape(i.first, shape, act_shape, inplace, reshape2);
}
}
}
}
}
#else
test_reshape({2, 4, 6}, {-1, 0, 4, 3}, {}, true, true);
test_reshape({1, 232, 14, 14}, {-1, 2, 116, 14, 14}, {}, true, true);
#endif
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
USE_LITE_OP(reshape);
USE_NPU_BRIDGE(reshape);
USE_LITE_OP(reshape2);
USE_NPU_BRIDGE(reshape2);
......@@ -29,10 +29,11 @@ namespace bridge {
node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
const node_map_type& inputs_map) {
VLOG(3) << "invoking ScaleConverter...";
auto scope = scale_op->scope();
auto op_info = scale_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
// get input, output and op attributes
auto x_var_name = op_info->Input("X").front();
......@@ -48,7 +49,7 @@ node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
}
// create scale node and set input node from inputs_map
auto scale_node = std::make_shared<ge::op::Scale>(UniqueName(op_type));
auto scale_node = std::make_shared<ge::op::Scale>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
scale_node->set_input_x(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
......@@ -56,7 +57,7 @@ node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
// add filter node(fill with scale)
auto filter_const_node =
std::make_shared<ge::op::Const>(UniqueName(op_type + "/filter"));
std::make_shared<ge::op::Const>(unique_op_type + "/filter");
filter_const_node->set_attr_value(
CreateTensorAndFillData(scale, scale_bias_shape));
scale_node->set_input_filter(*filter_const_node);
......@@ -65,7 +66,7 @@ node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
// add bias node(fill with bias)
if (fabs(bias) > 1e-6f) {
auto bias_const_node =
std::make_shared<ge::op::Const>(UniqueName(op_type + "/bias"));
std::make_shared<ge::op::Const>(unique_op_type + "/bias");
bias_const_node->set_attr_value(
CreateTensorAndFillData(bias, scale_bias_shape));
scale_node->set_input_bias(*bias_const_node);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/shuffle_channel_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/utils.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
node_map_type ShuffleChannelConverter(
const std::shared_ptr<lite::OpLite> shuffle_channel_op,
const node_map_type& inputs_map) {
LOG(INFO) << "converting shuffle_channel...";
lite::Scope* scope = shuffle_channel_op->scope();
const lite::OpInfo* op_info = shuffle_channel_op->op_info();
std::shared_ptr<ge::op::ShuffleChannel> output_node =
std::make_shared<ge::op::ShuffleChannel>(UniqueName("shuffle_channel"));
auto x_var_name = op_info->Input("X").front();
output_node->set_input_x(*inputs_map.at(x_var_name));
output_node->set_attr_group(op_info->GetAttr<int>("group"));
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(output_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = output_node;
return outputs_map;
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(shuffle_channel,
paddle::lite::npu::bridge::ShuffleChannelConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/shuffle_channel_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/test_helper.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
void shuffle_channel_ref(
const std::shared_ptr<operators::ShuffleChannelOpLite> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
auto out =
scope->FindVar(op_info->Output("Out").front())->GetMutable<Tensor>();
auto x_data = x->mutable_data<float>();
auto out_data = out->mutable_data<float>();
int group = op_info->GetAttr<int>("group");
auto x_dims = x->dims();
int n_size = x_dims.production() / x_dims[0];
int c_size = n_size / x_dims[1];
for (int n = 0; n < x_dims[0]; n++) {
int g_num = x_dims[1] / group;
auto tmp_out_data = out_data;
for (int g = 0; g < g_num; g++) {
auto tmp_x_data = x_data + g * c_size;
for (int i = 0; i < group; i++) {
std::memcpy(tmp_out_data,
tmp_x_data + i * g_num * c_size,
c_size * sizeof(float));
tmp_out_data += c_size;
}
}
x_data += n_size;
out_data += n_size;
}
}
void test_shuffle_channel(int bs, int ic, int ih, int iw, int group) {
// prepare input&output variables
Scope scope;
std::string x_var_name = "x";
std::string out_var_name = "out";
std::string out_ref_var_name = "out_ref";
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out = scope.Var(out_var_name)->GetMutable<Tensor>();
auto* out_ref = scope.Var(out_ref_var_name)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("shuffle_channel");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name});
opdesc.SetAttr("group", group);
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::ShuffleChannelOpLite>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name});
out_ref->CopyDataFrom(*out);
// execute reference implementation and save to output tensor
shuffle_channel_ref(op);
// compare results
auto* out_data = out->mutable_data<float>();
auto* out_ref_data = out_ref->mutable_data<float>();
for (int i = 0; i < out->dims().production(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-2);
}
}
TEST(NPUBridges, softmax) {
for (auto bs : {1, 4}) {
for (auto ic : {1, 24, 35}) {
for (auto ih : {1, 4}) {
for (auto iw : {1, 4}) {
for (auto group : {1, 3, 7, 24, 35}) {
if (ic % group != 0) continue;
test_shuffle_channel(bs, ic, ih, iw, group);
}
}
}
}
}
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
USE_LITE_OP(shuffle_channel);
USE_NPU_BRIDGE(shuffle_channel);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/split_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/utils.h"
#include "lite/npu/npu_helper.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
node_map_type SplitConverter(const std::shared_ptr<lite::OpLite> split_op,
const node_map_type& inputs_map) {
lite::Scope* scope = split_op->scope();
const lite::OpInfo* op_info = split_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << " ... ";
auto x_var_name = op_info->Input("X").front();
auto axis = op_info->GetAttr<int>("axis");
auto num = op_info->GetAttr<int>("num");
auto sections = op_info->GetAttr<std::vector<int>>("sections");
int64_t sections_num = static_cast<int64_t>(sections.size());
std::shared_ptr<ge::op::Split> output_node =
std::make_shared<ge::op::Split>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
output_node->set_input_x(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
output_node->set_attr_axis(static_cast<int64_t>(axis));
if (num > 0) {
output_node->set_attr_output_num(static_cast<int64_t>(num));
} else {
output_node->set_attr_output_num(sections_num);
auto size_split = ge::AttrValue::LIST_INT(sections.begin(), sections.end());
output_node->set_attr_size_split(size_split);
}
node_map_type outputs_map;
auto out_var_names = op_info->Output("Out");
output_node->create_dynamic_output_y(out_var_names.size());
int index = 1;
for (auto out_var_name : out_var_names) {
auto const_node = std::make_shared<ge::op::Const>(
unique_op_type + "/const_zero" + std::to_string(index));
const_node->set_attr_value(CreateTensorAndFillData(0));
OpList::Global().add(const_node);
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add" +
std::to_string(index));
add_node->set_input_x1(*output_node, "y" + std::to_string(index));
add_node->set_input_x2(*const_node);
outputs_map[out_var_name] = add_node;
OpList::Global().add(add_node);
index++;
}
OpList::Global().add(output_node);
return outputs_map;
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
REGISTER_NPU_BRIDGE(split, paddle::lite::npu::bridge::SplitConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/split_op.h"
#include <gtest/gtest.h>
#include "lite/core/op_registry.h"
#include "lite/npu/bridge/registry.h"
#include "lite/npu/bridge/test_helper.h"
namespace paddle {
namespace lite {
namespace npu {
namespace bridge {
template <typename dtype>
void split_ref(const std::shared_ptr<operators::SplitOp> op) {
Scope* scope = op->scope();
const OpInfo* op_info = op->op_info();
auto x = scope->FindVar(op_info->Input("X").front())->GetMutable<Tensor>();
int num = op_info->GetAttr<int>("num");
int axis = op_info->GetAttr<int>("axis");
std::vector<int> sections = op_info->GetAttr<std::vector<int>>("sections");
std::vector<lite::Tensor*> output_vec;
auto output = op_info->Output("Out");
for (auto out_var : output) {
output_vec.push_back(scope->Var(out_var)->GetMutable<Tensor>());
}
auto in_dims = x->dims();
auto rank = in_dims.size();
int outs_number = output_vec.size();
std::vector<lite::DDimLite> outs_dims;
outs_dims.reserve(outs_number);
if (axis < 0) {
axis += rank;
}
if (num > 0) {
int out_axis_dim = in_dims[axis] / num;
for (int i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = out_axis_dim;
outs_dims.push_back(dim);
}
} else if (sections.size() > 0) {
for (size_t i = 0; i < outs_number; ++i) {
auto dim = in_dims;
dim[axis] = sections[i];
outs_dims.push_back(dim);
}
}
for (int j = 0; j < outs_dims.size(); ++j) {
output_vec[j]->Resize(outs_dims[j]);
}
const dtype* din = x->mutable_data<const dtype>();
std::vector<int> in_strides(in_dims.size());
in_strides[in_dims.size() - 1] = in_dims[in_dims.size() - 1];
for (int i = in_dims.size() - 2; i >= 0; --i) {
in_strides[i] = in_strides[i + 1] * in_dims[i];
}
int input_offset = 0;
for (auto out : output_vec) {
auto out_dim = out->dims();
std::vector<int> out_strides(out_dim.size());
out_strides[out_dim.size() - 1] = out_dim[out_dim.size() - 1];
for (int i = out_dim.size() - 2; i >= 0; --i) {
out_strides[i] = out_strides[i + 1] * out_dim[i];
}
dtype* out_data = out->mutable_data<dtype>();
int before = out_strides[0] / out_strides[axis];
int in_after = in_strides[axis];
int out_after = out_strides[axis];
for (int i = 0; i < before; ++i) {
std::memcpy(out_data + i * out_after,
din + input_offset + i * in_after,
sizeof(dtype) * out_after);
}
input_offset += out_strides[axis];
}
}
void test_split(int bs,
int ic,
int ih,
int iw,
int axis,
int num,
std::vector<int> sections) {
const auto& bridges = lite::npu::bridge::Factory::Instance();
const auto& supported_lists = bridges.AllFunctions();
CHECK(bridges.HasType("split"));
// prepare input&output variables
std::string x_var_name = "x";
std::string out_var_name_1 = "out_1";
std::string out_var_name_2 = "out_2";
std::string out_ref_var_name_1 = "out_ref_1";
std::string out_ref_var_name_2 = "out_ref_2";
Scope scope;
auto* x = scope.Var(x_var_name)->GetMutable<Tensor>();
auto* out_1 = scope.Var(out_var_name_1)->GetMutable<Tensor>();
auto* out_2 = scope.Var(out_var_name_2)->GetMutable<Tensor>();
auto* out_ref_1 = scope.Var(out_ref_var_name_1)->GetMutable<Tensor>();
auto* out_ref_2 = scope.Var(out_ref_var_name_2)->GetMutable<Tensor>();
x->Resize({bs, ic, ih, iw});
// initialize input&output data
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
opdesc.SetType("split");
opdesc.SetInput("X", {x_var_name});
opdesc.SetOutput("Out", {out_var_name_1, out_var_name_2});
opdesc.SetAttr("axis", axis);
opdesc.SetAttr("sections", sections);
opdesc.SetAttr("num", num);
// create and convert op to NPU model, then run it on NPU
auto op = CreateOp<operators::SplitOp>(opdesc, &scope);
LauchOp(op, {x_var_name}, {out_var_name_1, out_var_name_2});
out_ref_1->CopyDataFrom(*out_1);
out_ref_2->CopyDataFrom(*out_2);
// execute reference implementation and save to output tensor
split_ref<float>(op);
// compare results
auto* out_data_1 = out_1->mutable_data<float>();
auto* out_data_2 = out_2->mutable_data<float>();
auto* out_ref_data_1 = out_ref_1->mutable_data<float>();
auto* out_ref_data_2 = out_ref_2->mutable_data<float>();
for (int i = 0; i < out_1->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data_1[i], out_ref_data_1[i], 5e-4);
}
for (int i = 0; i < out_2->dims().production(); i++) {
VLOG(5) << i;
EXPECT_NEAR(out_data_2[i], out_ref_data_2[i], 5e-4);
}
}
TEST(NPUBridges, split) {
test_split(4, 2, 3, 1, 0, 2, {});
test_split(4, 2, 3, 1, 0, 0, {3, 1});
test_split(4, 6, 3, 1, 1, 2, {});
test_split(4, 6, 3, 1, 1, 0, {2, 4});
test_split(4, 2, 2, 1, 2, 2, {});
test_split(4, 2, 6, 1, 2, 0, {3, 3});
test_split(4, 2, 3, 4, 3, 2, {});
test_split(4, 2, 3, 6, 3, 0, {5, 1});
}
} // namespace bridge
} // namespace npu
} // namespace lite
} // namespace paddle
USE_LITE_OP(split);
USE_NPU_BRIDGE(split);
......@@ -33,7 +33,7 @@ int data_index(std::vector<int> pos, DDimLite dims) {
std::vector<int> pos_trans(std::vector<int> in_pos, std::vector<int> axis) {
std::vector<int> out_pos(in_pos.size());
for (int i = 0; i < axis.size(); i++) {
out_pos[axis[i]] = in_pos[1];
out_pos[axis[i]] = in_pos[i];
}
return out_pos;
}
......@@ -88,11 +88,7 @@ void test_transpose(int bs, int ic, int ih, int iw, std::vector<int> axis) {
x->Resize({bs, ic, ih, iw});
// initialize input&output data
// FillTensor<float>(x);
auto* x_data = x->mutable_data<float>();
for (int i = 0; i < x->numel(); i++) {
x_data[i] = i;
}
FillTensor<float>(x);
// initialize op desc
cpp::OpDesc opdesc;
......@@ -123,7 +119,12 @@ TEST(NPUBridges, transpose) {
for (auto ic : {1, 4, 7}) {
for (auto ih : {1, 4, 7}) {
for (auto iw : {1, 4, 7}) {
for (auto axis : {std::vector<int>{0, 1, 2, 3}}) {
for (auto axis : {std::vector<int>{0, 1, 2, 3},
std::vector<int>{0, 1, 3, 2},
std::vector<int>{0, 3, 1, 2},
std::vector<int>{1, 2, 3, 0},
std::vector<int>{3, 2, 1, 0},
std::vector<int>{2, 3, 1, 0}}) {
test_transpose(bs, ic, ih, iw, axis);
}
}
......@@ -131,8 +132,8 @@ TEST(NPUBridges, transpose) {
}
}
#endif
// test_transpose(2, 3, 4, 5, std::vector<int>{0,1,3,2});
test_transpose(2, 3, 4, 5, std::vector<int>{0, 1, 2, 3});
test_transpose(2, 3, 4, 5, std::vector<int>{0, 1, 3, 2});
// test_transpose(2, 3, 4, 5, std::vector<int>{0, 1, 2, 3});
// test_transpose(2, 2, 2, 2, std::vector<int>{0,1,3,2});
// test_transpose(1, 1, 2, 2, std::vector<int>{0,1,3,2});
// test_transpose(1, 1, 1, 2, std::vector<int>{0,1,2,3});
......
......@@ -41,8 +41,8 @@ ge::TensorPtr CvtFromLiteTensor(Tensor* in_tensor,
DataLayoutType in_ltype = DATALAYOUT(kNCHW));
template <typename T>
ge::TensorPtr CreateTensorAndFillData(T value,
std::vector<int64_t> shape = {1},
ge::TensorPtr CreateTensorAndFillData(std::vector<T> data,
std::vector<int64_t> shape = {},
ge::Format format = ge::FORMAT_NCHW) {
const std::type_info& info = typeid(T);
ge::DataType type = ge::DT_FLOAT;
......@@ -55,17 +55,33 @@ ge::TensorPtr CreateTensorAndFillData(T value,
} else {
LOG(FATAL) << "Unknow value type " << info.name();
}
if (shape.empty()) {
shape = {static_cast<int64_t>(data.size())};
} else {
int size = 1;
for (auto i : shape) {
size *= i;
}
CHECK_EQ(data.size(), size);
}
ge::TensorDesc desc(ge::Shape(shape), format, type);
ge::TensorPtr tensor = std::make_shared<ge::Tensor>();
tensor->SetTensorDesc(desc);
int64_t data_num = 1;
tensor->SetData(reinterpret_cast<uint8_t*>(data.data()),
data.size() * sizeof(T));
return tensor;
}
template <typename T>
ge::TensorPtr CreateTensorAndFillData(T value,
std::vector<int64_t> shape = {1},
ge::Format format = ge::FORMAT_NCHW) {
int64_t size = 1;
for (auto i : shape) {
data_num *= i;
size *= i;
}
std::vector<T> data_value(data_num, value);
tensor->SetData(reinterpret_cast<uint8_t*>(data_value.data()),
data_num * sizeof(T));
return tensor;
std::vector<T> data(size, value);
return CreateTensorAndFillData(data, shape, format);
}
std::shared_ptr<ge::Operator> CvtNode2Tensor(const lite::mir::Node* arg_node);
......
......@@ -5,6 +5,7 @@ lite_cc_library(pool_op SRCS pool_op.cc DEPS ${op_DEPS})
lite_cc_library(fc_op SRCS fc_op.cc DEPS ${op_DEPS})
lite_cc_library(relu_op SRCS relu_op.cc DEPS ${op_DEPS})
lite_cc_library(mul_op SRCS mul_op.cc DEPS ${op_DEPS})
lite_cc_library(matmul_op SRCS matmul_op.cc DEPS ${op_DEPS})
lite_cc_library(scale_op SRCS scale_op.cc DEPS ${op_DEPS})
lite_cc_library(softmax_op SRCS softmax_op.cc DEPS ${op_DEPS})
lite_cc_library(reshape_op SRCS reshape_op.cc DEPS ${op_DEPS} )
......@@ -81,6 +82,8 @@ lite_cc_library(is_empty SRCS is_empty_op.cc DEPS ${op_DEPS})
lite_cc_library(shape_op_lite SRCS shape_op.cc DEPS ${op_DEPS})
lite_cc_library(cast_op_lite SRCS cast_op.cc DEPS ${op_DEPS})
lite_cc_library(slice_op_lite SRCS slice_op.cc DEPS ${op_DEPS})
lite_cc_library(squeeze_op_lite SRCS squeeze_op.cc DEPS ${op_DEPS})
lite_cc_library(expand_op_lite SRCS expand_op.cc DEPS ${op_DEPS})
set(ops
......@@ -89,6 +92,7 @@ set(ops
fc_op
relu_op
mul_op
matmul_op
scale_op
softmax_op
reshape_op
......@@ -164,6 +168,8 @@ set(ops
shape_op_lite
cast_op_lite
slice_op_lite
squeeze_op_lite
expand_op_lite
CACHE INTERNAL "ops lite")
if (NOT LITE_WITH_X86)
......
......@@ -109,6 +109,7 @@ REGISTER_LITE_OP(tanh, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(swish, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(relu6, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(log, paddle::lite::operators::ActivationOp);
REGISTER_LITE_OP(exp, paddle::lite::operators::ActivationOp);
#ifdef LITE_WITH_TRAIN
REGISTER_LITE_OP(square_grad, paddle::lite::operators::ActivationGradOp);
......
......@@ -44,7 +44,7 @@ bool ConvTransposeOpLite::InferShape() const {
std::vector<int64_t> output_shape;
output_shape.push_back(in_dims[0]);
output_shape.push_back(filter_dims[0] * param_.groups);
output_shape.push_back(filter_dims[1] * param_.groups);
for (int i = 0; i < param_.strides.size(); i++) {
int kernel_extent = param_.dilations[i] * (filter_dims[i + 2] - 1) + 1;
int output_len = (in_dims[i + 2] - 1) * param_.strides[i] + kernel_extent -
......@@ -60,10 +60,9 @@ bool ConvTransposeOpLite::InferShape() const {
// TODO(Superjomn) replace framework::OpDesc with a lite one.
bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto X = op_desc.Input("x").front();
auto Filter = op_desc.Input("filter").front();
auto Out = op_desc.Output("output").front();
auto X = op_desc.Input("Input").front();
auto Filter = op_desc.Input("Filter").front();
auto Out = op_desc.Output("Output").front();
param_.x = scope->FindVar(X)->GetMutable<lite::Tensor>();
param_.filter = scope->FindVar(Filter)->GetMutable<lite::Tensor>();
param_.output = scope->FindVar(Out)->GetMutable<lite::Tensor>();
......@@ -75,9 +74,9 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
// optional params
std::vector<std::string> input_arg_names = op_desc.InputArgumentNames();
if (std::find(input_arg_names.begin(), input_arg_names.end(), "bias") !=
if (std::find(input_arg_names.begin(), input_arg_names.end(), "Bias") !=
input_arg_names.end()) {
auto bias_arguments = op_desc.Input("bias");
auto bias_arguments = op_desc.Input("Bias");
if (bias_arguments.size() > 0) {
auto bias_var = scope->FindVar(bias_arguments.front());
if (bias_var != nullptr) {
......@@ -87,6 +86,7 @@ bool ConvTransposeOpLite::AttachImpl(const cpp::OpDesc &op_desc,
}
}
param_.fuse_relu = op_desc.GetAttr<bool>("fuse_relu");
return true;
}
} // namespace operators
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/expand_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool ExpandOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
int expand_size = param_.expand_times.size();
int x_dims_size = param_.X->dims().size();
CHECK_EQ(expand_size, x_dims_size)
<< "The number of expand_times size must be qual to the rank of "
"Input(X).";
CHECK_LE(param_.X->dims().size(), 6)
<< "The rank of Input(X) must not be greater than 6.";
return true;
}
bool ExpandOpLite::InferShape() const {
DDim out_dims(param_.X->dims());
for (size_t i = 0; i < param_.expand_times.size(); ++i) {
out_dims[i] *= param_.expand_times[i];
}
param_.Out->Resize(out_dims);
return true;
}
bool ExpandOpLite::AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) {
auto X_name = opdesc.Input("X").front();
auto Out_name = opdesc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X_name);
param_.Out = GetMutableVar<lite::Tensor>(scope, Out_name);
param_.expand_times = opdesc.GetAttr<std::vector<int>>("expand_times");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(expand, paddle::lite::operators::ExpandOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
namespace paddle {
namespace lite {
namespace operators {
class ExpandOpLite : public OpLite {
public:
ExpandOpLite() {}
explicit ExpandOpLite(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "expand"; }
private:
mutable ExpandParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -67,9 +67,12 @@ bool InterpolateOp::InferShape() const {
bool InterpolateOp::AttachImpl(const cpp::OpDesc& op_desc, lite::Scope* scope) {
auto X = op_desc.Input("X").front();
if (op_desc.Input("OutSize").size() > 0) {
auto OutSize = op_desc.Input("OutSize").front();
param_.OutSize = scope->FindVar(OutSize)->GetMutable<lite::Tensor>();
if (op_desc.HasInput("OutSize")) {
auto out_size_var_names = op_desc.Input("OutSize");
if (out_size_var_names.size() > 0) {
param_.OutSize = scope->FindVar(out_size_var_names.front())
->GetMutable<lite::Tensor>();
}
} else {
param_.OutSize = nullptr;
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/matmul_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool MatMulOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool MatMulOpLite::InferShape() const {
const auto x_dims = param_.X->dims();
const auto y_dims = param_.Y->dims();
bool x_transpose = param_.transpose_X;
bool y_transpose = param_.transpose_Y;
std::vector<int64_t> dim_out_vec;
if (x_dims.size() > 2 && y_dims.size() >= 2) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N]
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
dim_out_vec.resize(x_dims.size());
for (size_t i = 0; i < x_dims.size() - 1; ++i) {
dim_out_vec[i] = x_dims[i];
}
dim_out_vec[x_dims.size() - 1] = y_dims[y_dims.size() - 1];
} else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N]
// x: [M, K], y: [K, N], out: [M, N]
if (!x_transpose && !y_transpose) {
CHECK_EQ(x_dims[1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (!x_transpose && y_transpose) {
CHECK_EQ(x_dims[1], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (x_transpose && !y_transpose) {
CHECK_EQ(x_dims[0], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else {
CHECK_EQ(x_dims[0], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
}
dim_out_vec.resize(x_dims.size());
if (x_transpose) {
dim_out_vec[0] = x_dims[1];
} else {
dim_out_vec[0] = x_dims[0];
}
if (y_transpose) {
dim_out_vec[1] = y_dims[0];
} else {
dim_out_vec[1] = y_dims[1];
}
} else if (x_dims.size() > 2 && y_dims.size() == 1) {
// x: [B, M, K], y: [K], out: [B, M]
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
dim_out_vec.resize(x_dims.size() - 1);
for (size_t i = 0; i < dim_out_vec.size(); ++i) {
dim_out_vec[i] = x_dims[i];
}
} else if (x_dims.size() == 1 && y_dims.size() == 1) { // todo
// x: [K], y: [K], out: [1]
if (x_dims[0] == y_dims[0] && x_transpose == false &&
y_transpose == false) {
dim_out_vec.resize(1);
dim_out_vec[0] = 1;
}
// x: [M], y: [N], x_transpose: true, y_transpose: true, out: [M, N]
if (x_transpose == true && y_transpose == true) {
dim_out_vec.resize(2);
dim_out_vec[0] = x_dims[0];
dim_out_vec[1] = y_dims[0];
}
} else {
LOG(FATAL) << "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ")";
}
DDim dim_out(dim_out_vec);
param_.Out->Resize(dim_out);
return true;
}
bool MatMulOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
CHECK(!op_desc.Input("X").empty());
CHECK(!op_desc.Input("Y").empty());
CHECK(!op_desc.Output("Out").empty());
auto X = op_desc.Input("X").front();
auto Y = op_desc.Input("Y").front();
auto Out = op_desc.Output("Out").front();
param_.X = GetVar<lite::Tensor>(scope, X);
param_.Y = GetVar<lite::Tensor>(scope, Y);
param_.Out = GetMutableVar<lite::Tensor>(scope, Out);
param_.transpose_X = op_desc.GetAttr<bool>("transpose_X");
param_.transpose_Y = op_desc.GetAttr<bool>("transpose_Y");
param_.alpha = op_desc.GetAttr<float>("alpha");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(matmul, paddle::lite::operators::MatMulOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class MatMulOpLite : public OpLite {
public:
MatMulOpLite() {}
explicit MatMulOpLite(const std::string &type) : OpLite(type) {}
bool CheckShape() const override;
bool InferShape() const override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override;
std::string DebugString() const override { return "matmul"; }
private:
mutable MatMulParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -694,6 +694,32 @@ struct SliceParam {
std::vector<int> ends{};
std::vector<int> decrease_axis{};
};
/// ----------------------- shape operators ----------------------
struct SqueezeParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
lite::Tensor* XShape{};
std::vector<int> axes{};
};
/// ----------------------- expand operators ----------------------
struct ExpandParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
std::vector<int> expand_times{};
};
/// ----------------------- matmul operators ----------------------
struct MatMulParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
lite::Tensor* Out{};
bool transpose_X{false};
bool transpose_Y{false};
float alpha{1.0f};
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/squeeze_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
static DDim GetOutputShape(const std::vector<int> &squeeze_dims,
const DDim &in_dims,
bool is_runtime) {
size_t num_squeeze_dims = squeeze_dims.size();
int cnt_squeezed_dims = 0;
bool should_squeeze[9] = {false};
// Determines number of dimensions of output tensor after squeeze.
// Mark and count the dimensions need to be squeezed
if (num_squeeze_dims == 0) {
for (int idx = 0; idx < in_dims.size(); ++idx) {
if (in_dims[idx] == 1) {
should_squeeze[idx] = true;
++cnt_squeezed_dims;
}
}
} else {
for (size_t idx = 0; idx < num_squeeze_dims; ++idx) {
int current = squeeze_dims[idx] < 0 ? squeeze_dims[idx] + in_dims.size()
: squeeze_dims[idx];
// Check current index, the upper limit has been checked.
CHECK_GE(current, 0)
<< "Invalid axis, the negative axis is out of range.";
if (is_runtime) {
CHECK_EQ(in_dims[current], 1) << "Invalid axis index, the axis that "
"will be squeezed should be equal "
"to 1.";
}
if (!(should_squeeze[current])) {
++cnt_squeezed_dims;
}
should_squeeze[current] = true;
}
}
// Make output dimensions
std::vector<int64_t> output_shape(in_dims.size() - cnt_squeezed_dims, 0);
for (int in_idx = 0, out_idx = 0; in_idx < in_dims.size(); ++in_idx) {
if (!should_squeeze[in_idx]) {
output_shape[out_idx++] = in_dims[in_idx];
}
}
return DDim(output_shape);
}
bool SqueezeOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
for (int a : param_.axes) {
CHECK_LT(a, static_cast<int>(param_.X->dims().size()))
<< "The squeeze axis should be less than input tensor's rank.";
}
return true;
}
bool SqueezeOp::InferShape() const {
std::vector<int> squeeze_dims = param_.axes;
DDim in_dims = param_.X->dims();
DDim out_dim = GetOutputShape(squeeze_dims, in_dims, true);
param_.Out->Resize(out_dim);
return true;
}
bool SqueezeOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
auto x_var = scope->FindVar(opdesc.Input("X").front());
auto output_var = scope->FindVar(opdesc.Output("Out").front());
CHECK(x_var);
CHECK(output_var);
param_.X = const_cast<lite::Tensor *>(&(x_var->Get<lite::Tensor>()));
param_.Out = output_var->GetMutable<lite::Tensor>();
if (opdesc.HasAttr("axes")) {
param_.axes = opdesc.GetAttr<std::vector<int>>("axes");
}
CHECK(param_.X) << "Input(X) of SqueezeOp should not be null.";
CHECK(param_.Out) << "Output(Out) of SqueezeOp should not be null.";
return true;
}
bool Squeeze2Op::CheckShape() const {
SqueezeOp::CheckShape();
CHECK_OR_FALSE(param_.XShape);
return true;
}
bool Squeeze2Op::InferShape() const {
SqueezeOp::InferShape();
auto x_dims = param_.X->dims();
std::vector<DDim::value_type> xshape_dims(x_dims.size() + 1, 1);
for (size_t i = 0; i < x_dims.size(); i++) {
xshape_dims[i + 1] = x_dims[i];
}
param_.XShape->Resize(DDim(xshape_dims));
return true;
}
bool Squeeze2Op::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
SqueezeOp::AttachImpl(opdesc, scope);
auto xshape_var = scope->FindVar(opdesc.Output("XShape").front());
CHECK(xshape_var);
param_.XShape = xshape_var->GetMutable<lite::Tensor>();
CHECK(param_.XShape) << "Output(XShape) of ReshapeOp should not be null.";
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(squeeze, paddle::lite::operators::SqueezeOp);
REGISTER_LITE_OP(squeeze2, paddle::lite::operators::Squeeze2Op);
此差异已折叠。
......@@ -34,4 +34,7 @@ if((NOT LITE_WITH_OPENCL AND NOT LITE_WITH_FPGA) AND (LITE_WITH_X86 OR LITE_WITH
lite_cc_test(test_kernel_sequence_pool_compute SRCS sequence_pool_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_sequence_expand_compute SRCS sequence_expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_reduce_max_compute SRCS reduce_max_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_squeeze_compute SRCS squeeze_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_expand_compute SRCS expand_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
lite_cc_test(test_kernel_matmul_compute SRCS matmul_compute_test.cc DEPS arena_framework ${x86_kernels} ${arm_kernels} ${lite_ops} ${host_kernels})
endif()
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
此差异已折叠。
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册