Commit 92f0b3c5 authored by hong19860320, committed by GitHub

[LITE][NPU] Add support for Huawei official DDK (#2262)

* Add support for Huawei official DDK
* Fix the param of the graph op in the NPU graph computing kernel
Parent f2b9ca34
......@@ -50,9 +50,6 @@ find_library(NPU_DDK_IR_FILE NAMES hiai_ir
find_library(NPU_DDK_IR_BUILD_FILE NAMES hiai_ir_build
PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH})
find_library(NPU_DDK_PROTO_FILE NAMES protobuf-lite
PATHS ${NPU_DDK_ROOT}/${NPU_SUB_LIB_PATH})
if(NOT NPU_DDK_HIAI_FILE)
message(FATAL_ERROR "Can not find NPU_DDK_HIAI_FILE in ${NPU_DDK_ROOT}")
else()
......@@ -77,14 +74,8 @@ else()
set_property(TARGET npu_ddk_ir_build PROPERTY IMPORTED_LOCATION ${NPU_DDK_IR_BUILD_FILE})
endif()
if(NOT NPU_DDK_PROTO_FILE)
message(FATAL_ERROR "Can not find NPU_DDK_PROTO_FILE in ${NPU_DDK_ROOT}")
else()
message(STATUS "Found NPU_DDK Protobuf Library: ${NPU_DDK_PROTO_FILE}")
add_library(npu_ddk_proto SHARED IMPORTED GLOBAL)
set_property(TARGET npu_ddk_proto PROPERTY IMPORTED_LOCATION ${NPU_DDK_PROTO_FILE})
endif()
set(npu_runtime_libs npu_ddk_hiai CACHE INTERNAL "npu ddk runtime libs")
set(npu_builder_libs npu_ddk_ir npu_ddk_ir_build CACHE INTERNAL "npu ddk builder libs")
set(npu_ddk_libs npu_ddk_hiai npu_ddk_ir npu_ddk_ir_build npu_ddk_proto CACHE INTERNAL "npu ddk libs")
......@@ -27,11 +27,21 @@ if ((NOT LITE_ON_TINY_PUBLISH) AND (LITE_WITH_X86 OR ARM_TARGET_OS STREQUAL "and
DEPS ${light_lib_DEPS}
ARM_DEPS ${arm_kernels} NPU_DEPS ${npu_kernels})
target_link_libraries(paddle_light_api_shared ${light_lib_DEPS} ${arm_kernels} ${npu_kernels})
if (LITE_WITH_NPU)
# Strip the symbols of our protobuf functions to avoid conflicts when
# loading the HIAI builder libs (libhiai_ir.so and libhiai_ir_build.so)
set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
set_target_properties(paddle_light_api_shared PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
else()
if ((ARM_TARGET_OS STREQUAL "android") OR (ARM_TARGET_OS STREQUAL "armlinux"))
add_library(paddle_light_api_shared SHARED "")
target_sources(paddle_light_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc light_api_impl.cc)
add_dependencies(paddle_light_api_shared op_list_h kernel_list_h)
if (LITE_WITH_NPU)
# Add the HIAI runtime libs (libhiai.so) as a dependency
target_link_libraries(paddle_light_api_shared ${npu_runtime_libs})
endif()
endif()
endif()
......
......@@ -17,10 +17,20 @@ if (NOT LITE_ON_TINY_PUBLISH)
# Unlike a static library, a module library has to link its targets to be
# able to work as a single .so lib.
target_link_libraries(paddle_lite_jni ${lib_DEPS} ${arm_kernels} ${npu_kernels})
if (LITE_WITH_NPU)
# Strip the symbols of our protobuf functions to avoid conflicts when
# loading the HIAI builder libs (libhiai_ir.so and libhiai_ir_build.so)
set(LINK_FLAGS "-Wl,--version-script ${PADDLE_SOURCE_DIR}/lite/core/lite.map")
set_target_properties(paddle_lite_jni PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
else()
add_library(paddle_lite_jni SHARED "")
target_sources(paddle_lite_jni PUBLIC ${__lite_cc_files} paddle_lite_jni.cc tensor_jni.cc)
add_dependencies(paddle_lite_jni op_list_h kernel_list_h)
if (LITE_WITH_NPU)
# Add the HIAI runtime libs (libhiai.so) as a dependency
target_link_libraries(paddle_lite_jni ${npu_runtime_libs})
endif()
endif()
if (APPLE)
......
......@@ -2,4 +2,5 @@ if(NOT LITE_WITH_NPU)
return()
endif()
lite_cc_library(npu_runtime SRCS runtime.cc DEPS npu_ddk_hiai)
lite_cc_library(npu_runtime SRCS runtime.cc DEPS ${npu_runtime_libs})
lite_cc_library(npu_builder SRCS builder.cc DEPS ${npu_builder_libs} npu_runtime tensor op scope)
......@@ -12,21 +12,14 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/utils.h"
#include "lite/backends/npu/builder.h"
#include <mutex> // NOLINT
#include <utility>
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h" // for ge::op::Data
#include "ai_ddk_lib/include/graph/tensor.h" // for ge::TensorUtils
#include "ai_ddk_lib/include/hiai_ir_build.h"
#include "lite/backends/npu/runtime.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
// Build the HIAI IR graph into an om model, and store the om model data in a lite tensor
bool BuildModel(std::vector<ge::Operator>& inputs, // NOLINT
......@@ -165,8 +158,6 @@ bool HasInputArg(const OpInfo* op_info,
}
}
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
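For reference, a minimal sketch of how this helper is driven, mirroring the call sites in test_helper.cc and the NPU program pass later in this diff (node and variable names are illustrative):

// Sketch: build the HIAI IR graph into an om model and keep the serialized
// model data in a persistable lite tensor.
std::vector<ge::Operator> graph_inputs{*input_node};
std::vector<ge::Operator> graph_outputs{*output_node};
auto* weight = scope->Var("graph_weight")->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
CHECK(lite::npu::BuildModel(graph_inputs, graph_outputs, weight));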
......@@ -18,16 +18,147 @@
#include <string>
#include <unordered_map>
#include <vector>
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "ai_ddk_lib/include/hiai_ir_build.h"
#include "lite/core/op_lite.h"
#include "lite/core/target_wrapper.h"
#include "lite/core/tensor.h"
// Extended Ops of HIAI DDK
namespace ge {
REG_OP(MatMul)
.INPUT(x, TensorType({DT_FLOAT}))
.INPUT(w, TensorType({DT_FLOAT}))
.OPTIONAL_INPUT(bias, TensorType({DT_FLOAT})) // bias must be const input
.OUTPUT(y, TensorType({DT_FLOAT}))
.ATTR(has_bias, AttrValue::BOOL{false}) // set to true when the bias input is present
.OP_END()
/**
* Computes the gradients of convolution with respect to the input.
* <Input>
*    input_sizes : An integer vector representing the shape of the input,
* where the input is a 4-D [batch, height, width, channels] tensor.
*    filter : The filter tensor, with shape [H, W, filter_channel,
* filter_number]; filter_channel must be the same as the channel of x.
*    x : The input tensor.
* <Output>
*    y : The output tensor.
* <Attr>
*    format : 0: NCHW, 1: NHWC
*    group : default is 1
*    num_output : default is 0; num_output must be equal to
* (filter_channel * group)
*    pad : Padding for the beginning and ending along each axis.
*    stride : Stride along each axis.
*    dilation : Dilation value along each axis of the filter.
*    pad_mode : 0: NOTSET, 5: VALID, 6: SAME; the default value is 0: NOTSET
*    bias_term : default is 0
*    kernel : The shape of the convolution kernel.
*/
REG_OP(Deconvolution)
.INPUT(input_sizes, TensorType({DT_UINT8}))
.INPUT(filter, TensorType({DT_FLOAT}))
.INPUT(x, TensorType({DT_FLOAT}))
.OPTIONAL_INPUT(b, TensorType({DT_FLOAT}))
.OUTPUT(y, TensorType({DT_FLOAT}))
.ATTR(mode, AttrValue::INT{1})
.ATTR(format, AttrValue::INT{1})
.ATTR(group, AttrValue::INT{1})
.ATTR(num_output, AttrValue::INT{0})
.ATTR(pad, AttrValue::LIST_INT({0, 0, 0, 0}))
.ATTR(stride, AttrValue::LIST_INT({1, 1}))
.ATTR(dilation, AttrValue::LIST_INT({1, 1}))
.ATTR(pad_mode, AttrValue::INT{0})
.ATTR(bias_term, AttrValue::INT{0})
.ATTR(kernel, AttrValue::LIST_INT({0, 0}))
.OP_END()
/**
* Resize images to size using bilinear interpolation.
* <Input>
*    x : A 4-D tensor.
*    w : An int32 tensor of 2 elements: [height, width].
* <Output>
*    y : The output tensor.
* <Attr>
*    align_corners : If true, the centers of the 4 corner pixels of the
* input and output tensors are aligned, preserving the values at the corner
* pixels.
*    output_dim_mode : Defaults to 2; one of 0: zoom_factor, 1:
* shrink_factor, 2: height/width. When output_dim_mode=2, the output dim is
* controlled by the [height, width] of w.
*    shrink_factor : The shrink factor.
*    zoom_factor : The zoom factor.
*    pad_begin : The beginning of the pad.
*    pad_end : The end of the pad.
*/
REG_OP(ResizeBilinear)
.INPUT(x, TensorType({DT_FLOAT, DT_INT32}))
.INPUT(w, TensorType({DT_FLOAT, DT_INT32}))
.OUTPUT(y, TensorType({DT_FLOAT, DT_INT32}))
.ATTR(align_corners, AttrValue::BOOL{false})
.ATTR(output_dim_mode, AttrValue::INT{2})
.ATTR(shrink_factor, AttrValue::INT{1})
.ATTR(zoom_factor, AttrValue::INT{1})
.ATTR(pad_begin, AttrValue::INT{0})
.ATTR(pad_end, AttrValue::INT{0})
.OP_END()
/**
* Resize images to size using nearest neighbor interpolation.
* <Input>
*    image : The input 4-D image tensor.
*    size : A 1-D int32 tensor with two elements: [height, width].
* <Output>
*    output : The output tensor.
* <Attr>
* align_corners : If true, the centers of the 4 corner pixels of the
* input and output tensors are aligned, preserving the values at the corner
* pixels. Defaults to false
*/
REG_OP(ResizeNearestNeighbor)
.INPUT(image, TensorType({DT_FLOAT, DT_INT32, DT_UINT8, DT_BOOL}))
.INPUT(size, TensorType({DT_INT32}))
.OUTPUT(output, TensorType({DT_FLOAT, DT_INT32, DT_UINT8, DT_BOOL}))
.ATTR(align_corners, AttrValue::BOOL{false})
.OP_END()
/**
* Pads a tensor.
* <Input>
*    x : The input tensor.
*    padding : The paddings tensor; it must be 2-D.
*    constant_values : The constant values; must be a scalar.
* <Output>
*    output : The output tensor.
* <Attr>
*    t_paddings : Defaults to DT_INT32; t_paddings must have the same data
* type as padding.
*    mode : 0: CONSTANT, 1: REFLECT, 2: SYMMETRIC
*    T : The data type of constant_values; DT_INT32: 3, DT_FLOAT: 0
*/
REG_OP(Pad)
.INPUT(x, TensorType({DT_FLOAT, DT_INT32}))
.INPUT(padding, TensorType({DT_INT32}))
.OPTIONAL_INPUT(constant_values, TensorType({DT_INT32, DT_FLOAT}))
.OUTPUT(output, TensorType({DT_FLOAT, DT_INT32}))
.ATTR(t_paddings, AttrValue::INT{3})
.ATTR(mode, AttrValue::INT{0})
.REQUIRED_ATTR(T, AttrValue::INT)
.OP_END()
} // namespace ge
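As a usage sketch, each op registered above is driven through the setters that REG_OP generates; the Pad2dConverter later in this diff does essentially the following (node names are illustrative):

// Sketch: configure the extended Pad op via its generated setters.
auto pad_node = std::make_shared<ge::op::Pad>("pad2d_0");
pad_node->set_input_x(*input_node);           // input tensor
pad_node->set_input_padding(*padding_const);  // 2-D paddings const node
pad_node->set_attr_mode(0);                   // 0: CONSTANT
pad_node->set_attr_T(0);                      // constant_values type, 0: float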
namespace paddle {
namespace lite {
namespace kernels {
namespace npu {
namespace bridges {
class OpList {
public:
......@@ -106,8 +237,6 @@ bool HasInputArg(const OpInfo* op_info,
const Scope* scope,
const std::string& argname);
} // namespace bridges
} // namespace npu
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -16,7 +16,7 @@ set(subgraph_passes subgraph_pass)
if(LITE_WITH_NPU)
lite_cc_library(npu_pass SRCS generate_npu_program_pass.cc
DEPS mir_pass types context ${mir_fusers} ${npu_bridges} ${npu_ddk_libs} graph_op subgraph_pass)
DEPS mir_pass types context ${mir_fusers} ${npu_bridges} graph_op subgraph_pass)
list(APPEND subgraph_passes npu_pass)
lite_cc_test(test_npu_pass SRCS generate_npu_program_pass_test.cc
DEPS npu_pass mir_passes paddle_api_full paddle_api_light gflags
......
......@@ -22,14 +22,9 @@
#include "lite/core/mir/pass_registry.h"
#include "lite/core/mir/pattern_matcher.h"
#include "ai_ddk_lib/include/HiAiModelManagerService.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h" // for ge::op::Data
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/paddle_use_npu_bridges.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -51,7 +46,7 @@ std::shared_ptr<ge::Operator> GenerateNPUProgramPass::CvtVarNode(
auto wgt = std::make_shared<ge::op::Const>(arg.name);
LOG(INFO) << "in convert const:" << arg.name;
VLOG(4) << dims;
wgt->set_attr_value(lite::kernels::npu::bridges::CvtFromLiteTensor(tensor));
wgt->set_attr_value(lite::npu::CvtFromLiteTensor(tensor));
return wgt;
} else {
CHECK_EQ(dims.size(), 4);
......@@ -132,7 +127,7 @@ std::string GenerateNPUProgramPass::BuildNPUGraph(
// Compile the IR graph to an NPU model and store the model data in a weight
// tensor with persistable=true, so that the model parser can recognize it
// and save it to the param files
if (!lite::kernels::npu::bridges::BuildModel(inputs, outputs, weight)) {
if (!lite::npu::BuildModel(inputs, outputs, weight)) {
LOG(WARNING) << "Build NPU failed subgraph " << sub_id;
throw std::runtime_error("Build NPU failed subgraph.");
}
......
......@@ -20,10 +20,10 @@
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "lite/backends/npu/builder.h"
#include "lite/core/mir/pass.h"
#include "lite/core/mir/subgraph/subgraph_program_pass.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......
lite_cc_library(npu_bridge_registry SRCS registry.cc DEPS ${npu_ddk_libs})
lite_cc_library(npu_bridge_utils SRCS utils.cc DEPS ${npu_ddk_libs} npu_runtime tensor op scope)
lite_cc_library(npu_bridge_registry SRCS registry.cc)
set(npu_bridge_deps npu_bridge_registry npu_bridge_utils op)
set(npu_bridge_deps npu_bridge_registry npu_builder op)
lite_cc_library(npu_bridge_fc_op SRCS fc_op.cc DEPS ${npu_bridge_deps})
lite_cc_library(npu_bridge_conv_op SRCS conv_op.cc DEPS ${npu_bridge_deps})
......@@ -23,7 +22,6 @@ lite_cc_library(npu_bridge_pad2d_op SRCS pad2d_op.cc DEPS ${npu_bridge_deps})
set(npu_bridges
npu_bridge_registry
npu_bridge_utils
npu_bridge_fc_op
npu_bridge_conv_op
npu_bridge_mul_op
......@@ -43,7 +41,7 @@ set(npu_bridges
npu_bridge_pad2d_op
CACHE INTERNAL "npu_bridges")
set(npu_bridge_test_deps ${npu_ddk_libs} ${npu_bridges} ${npu_kernels} ${ops})
set(npu_bridge_test_deps ${npu_bridges} ${npu_kernels} ${ops})
lite_cc_test(test_npu_bridge_fc_op SRCS fc_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
lite_cc_test(test_npu_bridge_conv_op SRCS conv_op_test.cc test_helper.cc DEPS ${npu_bridge_test_deps})
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,7 +26,7 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
auto scope = act_op->scope();
auto op_info = act_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
// create act node and set input node from inputs_map
......@@ -40,8 +34,8 @@ node_map_type ActConverter(const std::shared_ptr<lite::OpLite> act_op,
auto act_node = std::make_shared<ge::op::Activation>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
act_node->set_input_x(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(act_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(act_node);
// parse and set activation type
int act_mode = 1;
......
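After the namespace cleanup, every bridge in this diff follows the same converter contract; a minimal sketch, assuming the signature used by these bridges (node_map_type maps variable names to shared ge::Operator nodes):

// Sketch of the common NPU bridge converter pattern (op type illustrative).
node_map_type MyActConverter(const std::shared_ptr<lite::OpLite> act_op,
                             const node_map_type& inputs_map) {
  auto op_info = act_op->op_info();
  auto unique_op_type = lite::npu::UniqueName(op_info->Type());
  auto node = std::make_shared<ge::op::Activation>(unique_op_type);
  auto x_var_name = op_info->Input("X").front();
  node->set_input_x(*inputs_map.at(x_var_name));
  // keep the nodes alive until the whole graph has been built
  lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
  lite::npu::OpList::Global().add(node);
  node_map_type outputs_map;
  outputs_map[op_info->Output("Out").front()] = node;
  return outputs_map;
}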
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -33,7 +27,7 @@ node_map_type BatchNormConverter(
auto scope = batch_norm_op->scope();
auto op_info = batch_norm_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
std::shared_ptr<ge::op::BatchNorm> batch_norm_node =
......@@ -43,27 +37,27 @@ node_map_type BatchNormConverter(
auto scale_var_name = op_info->Input("Scale").front();
lite::Tensor* scale = scope->FindVar(scale_var_name)->GetMutable<Tensor>();
auto npu_scale = std::make_shared<ge::op::Const>(scale_var_name);
npu_scale->set_attr_value(CvtFromLiteTensor(scale));
OpList::Global().add(npu_scale);
npu_scale->set_attr_value(lite::npu::CvtFromLiteTensor(scale));
lite::npu::OpList::Global().add(npu_scale);
auto bias_var_name = op_info->Input("Bias").front();
lite::Tensor* bias = scope->FindVar(bias_var_name)->GetMutable<Tensor>();
auto npu_bias = std::make_shared<ge::op::Const>(bias_var_name);
npu_bias->set_attr_value(CvtFromLiteTensor(bias));
OpList::Global().add(npu_bias);
npu_bias->set_attr_value(lite::npu::CvtFromLiteTensor(bias));
lite::npu::OpList::Global().add(npu_bias);
auto mean_var_name = op_info->Input("Mean").front();
lite::Tensor* mean = scope->FindVar(mean_var_name)->GetMutable<Tensor>();
auto npu_mean = std::make_shared<ge::op::Const>(mean_var_name);
npu_mean->set_attr_value(CvtFromLiteTensor(mean));
OpList::Global().add(npu_mean);
npu_mean->set_attr_value(lite::npu::CvtFromLiteTensor(mean));
lite::npu::OpList::Global().add(npu_mean);
auto variance_var_name = op_info->Input("Variance").front();
lite::Tensor* variance =
scope->FindVar(variance_var_name)->GetMutable<Tensor>();
auto npu_variance = std::make_shared<ge::op::Const>(variance_var_name);
npu_variance->set_attr_value(CvtFromLiteTensor(variance));
OpList::Global().add(npu_variance);
npu_variance->set_attr_value(lite::npu::CvtFromLiteTensor(variance));
lite::npu::OpList::Global().add(npu_variance);
float npu_momentum = op_info->GetAttr<float>("momentum");
float npu_epsilon = op_info->GetAttr<float>("epsilon");
......@@ -80,8 +74,8 @@ node_map_type BatchNormConverter(
batch_norm_node->set_attr_mode(npu_mode);
batch_norm_node->set_attr_use_global_stats(npu_use_global_stats);
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(batch_norm_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(batch_norm_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Y").front()] = batch_norm_node;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,7 +26,7 @@ node_map_type ConcatConverter(const std::shared_ptr<lite::OpLite> concat_op,
lite::Scope* scope = concat_op->scope();
const lite::OpInfo* op_info = concat_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "converting " << op_type << " ... ";
auto x_var_names = op_info->Input("X");
......@@ -48,17 +42,17 @@ node_map_type ConcatConverter(const std::shared_ptr<lite::OpLite> concat_op,
for (auto x_var_name : x_var_names) {
if (inputs_map.find(x_var_name) != inputs_map.end()) {
output_node->set_dynamic_input_x(index + 1, *inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
} else {
auto consty = std::make_shared<ge::op::Const>(x_var_name);
auto* x = scope->FindVar(x_var_name)->GetMutable<Tensor>();
consty->set_attr_value(CvtFromLiteTensor(x));
consty->set_attr_value(lite::npu::CvtFromLiteTensor(x));
output_node->set_dynamic_input_x(index + 1, *consty);
OpList::Global().add(consty);
lite::npu::OpList::Global().add(consty);
}
index++;
}
OpList::Global().add(output_node);
lite::npu::OpList::Global().add(output_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = output_node;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,7 +26,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
auto scope = conv_op->scope();
auto op_info = conv_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << "... ";
// get input, filter and op attributes
......@@ -78,13 +72,13 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
// check input
CHECK(inputs_map.count(input_var_name));
OpList::Global().add(inputs_map.at(input_var_name));
lite::npu::OpList::Global().add(inputs_map.at(input_var_name));
// create filter node
CHECK(!inputs_map.count(filter_var_name));
auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
filter_const_node->set_attr_value(CvtFromLiteTensor(filter));
OpList::Global().add(filter_const_node);
filter_const_node->set_attr_value(lite::npu::CvtFromLiteTensor(filter));
lite::npu::OpList::Global().add(filter_const_node);
// create a bias node if the op has a bias
// supported bias dimensions:
......@@ -93,7 +87,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
// 2: {n, oc, oh, ow}
std::shared_ptr<ge::Operator> bias_node = nullptr;
bool is_channel_bias = false;
if (HasInputArg(op_info, scope, "Bias")) {
if (lite::npu::HasInputArg(op_info, scope, "Bias")) {
auto bias_var_name = op_info->Input("Bias").front();
auto* bias = scope->FindVar(bias_var_name)->GetMutable<lite::Tensor>();
auto bias_dims = bias->dims();
......@@ -121,10 +115,11 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
} else {
// bias node with const data
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(CvtFromLiteTensor(bias, bias_shape));
bias_const_node->set_attr_value(
lite::npu::CvtFromLiteTensor(bias, bias_shape));
bias_node = bias_const_node;
}
OpList::Global().add(bias_node);
lite::npu::OpList::Global().add(bias_node);
}
// create conv node and set input, filter, bias nodes and attributes
......@@ -147,7 +142,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
depthwise_conv_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
OpList::Global().add(depthwise_conv_node);
lite::npu::OpList::Global().add(depthwise_conv_node);
conv_node = depthwise_conv_node;
// ConvolutionDepthwise op doesn't support bias, so append an Add node to
// apply the bias
......@@ -155,7 +150,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
add_node->set_input_x1(*depthwise_conv_node);
add_node->set_input_x2(*bias_node);
OpList::Global().add(add_node);
lite::npu::OpList::Global().add(add_node);
conv_node = add_node;
}
} else {
......@@ -174,7 +169,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
common_conv_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_dims[2], filter_dims[3]}));
OpList::Global().add(common_conv_node);
lite::npu::OpList::Global().add(common_conv_node);
conv_node = common_conv_node;
// Convolution op only supports a bias with dimension {1, oc, 1, 1},
// so append an Add node if the dimension is {1, oc, oh, ow} or {n, oc, oh, ow}
......@@ -185,7 +180,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
add_node->set_input_x1(*common_conv_node);
add_node->set_input_x2(*bias_node);
OpList::Global().add(add_node);
lite::npu::OpList::Global().add(add_node);
conv_node = add_node;
}
}
......@@ -199,7 +194,7 @@ node_map_type ConvConverter(const std::shared_ptr<lite::OpLite> conv_op,
std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
relu_node->set_input_x(*conv_node);
relu_node->set_attr_mode(1);
OpList::Global().add(relu_node);
lite::npu::OpList::Global().add(relu_node);
outputs_map[op_info->Output("Output").front()] = relu_node;
} else {
outputs_map[op_info->Output("Output").front()] = conv_node;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -33,7 +27,7 @@ node_map_type ConvTransposeConverter(
auto scope = conv_transpose_op->scope();
auto op_info = conv_transpose_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << "... ";
// get input, output and op attributes
......@@ -70,21 +64,22 @@ node_map_type ConvTransposeConverter(
}
auto input_sizes_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/input_size");
input_sizes_const_node->set_attr_value(CreateTensorAndFillData(output_shape));
input_sizes_const_node->set_attr_value(
lite::npu::CreateTensorAndFillData(output_shape));
conv_transpose_node->set_input_input_sizes(*input_sizes_const_node);
OpList::Global().add(input_sizes_const_node);
lite::npu::OpList::Global().add(input_sizes_const_node);
// create filter node
CHECK(!inputs_map.count(filter_var_name));
auto filter_const_node = std::make_shared<ge::op::Const>(filter_var_name);
filter_const_node->set_attr_value(CvtFromLiteTensor(filter));
filter_const_node->set_attr_value(lite::npu::CvtFromLiteTensor(filter));
conv_transpose_node->set_input_filter(*filter_const_node);
OpList::Global().add(filter_const_node);
lite::npu::OpList::Global().add(filter_const_node);
// set input node
CHECK(inputs_map.count(input_var_name));
conv_transpose_node->set_input_x(*inputs_map.at(input_var_name));
OpList::Global().add(inputs_map.at(input_var_name));
lite::npu::OpList::Global().add(inputs_map.at(input_var_name));
// set attributes
conv_transpose_node->set_attr_mode(1);
......@@ -99,11 +94,11 @@ node_map_type ConvTransposeConverter(
ge::AttrValue::LIST_INT({strides[0], strides[1]}));
conv_transpose_node->set_attr_kernel(
ge::AttrValue::LIST_INT({filter_shape[2], filter_shape[3]}));
OpList::Global().add(conv_transpose_node);
lite::npu::OpList::Global().add(conv_transpose_node);
// append an Add node to apply the bias if the op has one
std::shared_ptr<ge::Operator> output_node = conv_transpose_node;
if (HasInputArg(op_info, scope, "Bias")) {
if (lite::npu::HasInputArg(op_info, scope, "Bias")) {
// create bias node
auto bias_var_name = op_info->Input("Bias").front();
CHECK(!inputs_map.count(bias_var_name));
......@@ -112,13 +107,13 @@ node_map_type ConvTransposeConverter(
CHECK_EQ(channel_size, filter_shape[1] * groups);
auto bias_const_node = std::make_shared<ge::op::Const>(bias_var_name);
bias_const_node->set_attr_value(
CvtFromLiteTensor(bias, {1, channel_size, 1, 1}));
OpList::Global().add(bias_const_node);
lite::npu::CvtFromLiteTensor(bias, {1, channel_size, 1, 1}));
lite::npu::OpList::Global().add(bias_const_node);
// append an Add node to apply the bias
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add");
add_node->set_input_x1(*conv_transpose_node);
add_node->set_input_x2(*bias_const_node);
OpList::Global().add(add_node);
lite::npu::OpList::Global().add(add_node);
output_node = add_node;
}
......@@ -129,7 +124,7 @@ node_map_type ConvTransposeConverter(
std::make_shared<ge::op::Activation>(unique_op_type + "/relu");
relu_node->set_input_x(*output_node);
relu_node->set_attr_mode(1);
OpList::Global().add(relu_node);
lite::npu::OpList::Global().add(relu_node);
outputs_map[op_info->Output("Output").front()] = relu_node;
} else {
outputs_map[op_info->Output("Output").front()] = output_node;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -33,7 +27,7 @@ node_map_type ElementwiseConverter(
auto scope = elementwise_op->scope();
auto op_info = elementwise_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "converting elementwise...";
std::shared_ptr<ge::op::Eltwise> elementwise_node =
......@@ -47,20 +41,20 @@ node_map_type ElementwiseConverter(
CHECK(inputs_map.find(x_var_name) != inputs_map.end());
elementwise_node->set_input_x1(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
if (inputs_map.find(y_var_name) != inputs_map.end()) {
elementwise_node->set_input_x2(*inputs_map.at(y_var_name));
OpList::Global().add(inputs_map.at(y_var_name));
lite::npu::OpList::Global().add(inputs_map.at(y_var_name));
} else {
auto consty = std::make_shared<ge::op::Const>(y_var_name);
auto* y = scope->FindVar(y_var_name)->GetMutable<Tensor>();
consty->set_attr_value(CvtFromLiteTensor(y));
consty->set_attr_value(lite::npu::CvtFromLiteTensor(y));
elementwise_node->set_input_x2(*consty);
OpList::Global().add(consty);
lite::npu::OpList::Global().add(consty);
}
OpList::Global().add(elementwise_node);
lite::npu::OpList::Global().add(elementwise_node);
// PaddleLite supports the sum mode only
elementwise_node->set_attr_mode(1);
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,7 +26,8 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
LOG(INFO) << "Converting fc...";
lite::Scope* scope = fc_op->scope();
const lite::OpInfo* op_info = fc_op->op_info();
auto output_node = std::make_shared<ge::op::MatMul>(UniqueName("fc"));
auto output_node =
std::make_shared<ge::op::MatMul>(lite::npu::UniqueName("fc"));
auto x_var_name = op_info->Input("Input").front();
auto w_var_name = op_info->Input("W").front();
......@@ -64,8 +59,8 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
reshapex->set_input_tensor(*xsrc);
reshapex->set_attr_shape({m, k});
reshapex->set_attr_axis(0);
OpList::Global().add(xsrc);
OpList::Global().add(reshapex);
lite::npu::OpList::Global().add(xsrc);
lite::npu::OpList::Global().add(reshapex);
output_node->set_input_x(*reshapex);
auto wconst = std::make_shared<ge::op::Const>(w_var_name);
......@@ -77,10 +72,10 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
auto* pdata = reinterpret_cast<uint8_t*>(wtensor->mutable_data<float>());
ptensor->SetData(pdata, size * sizeof(float));
wconst->set_attr_value(ptensor);
OpList::Global().add(wconst);
lite::npu::OpList::Global().add(wconst);
output_node->set_input_w(*wconst);
if (HasInputArg(op_info, scope, "Bias")) {
if (lite::npu::HasInputArg(op_info, scope, "Bias")) {
auto b_var_name = op_info->Input("Bias").front();
auto* btensor = scope->FindVar(b_var_name)->GetMutable<lite::Tensor>();
......@@ -99,12 +94,12 @@ node_map_type FCConverter(const std::shared_ptr<lite::OpLite> fc_op,
auto* pdata = reinterpret_cast<uint8_t*>(btensor->mutable_data<float>());
ptensor->SetData(pdata, size * sizeof(float));
bconst->set_attr_value(ptensor);
OpList::Global().add(bconst);
lite::npu::OpList::Global().add(bconst);
output_node->set_input_bias(*bconst);
output_node->set_attr_has_bias(ge::AttrValue::BOOL{true});
}
OpList::Global().add(output_node);
lite::npu::OpList::Global().add(output_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = output_node;
......
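For orientation, the FC bridge above lowers fc to the extended MatMul op registered in builder.h; stripped to its essentials (node names illustrative):

// Sketch: the FC bridge lowers to MatMul as y{m, n} = x{m, k} * w{k, n}.
auto matmul = std::make_shared<ge::op::MatMul>(lite::npu::UniqueName("fc"));
matmul->set_input_x(*reshaped_x);  // input reshaped to {m, k}
matmul->set_input_w(*w_const);     // weight const node, {k, n}
matmul->set_input_bias(*b_const);  // optional bias const node, {n}
matmul->set_attr_has_bias(ge::AttrValue::BOOL{true});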
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -33,13 +27,13 @@ node_map_type InterpolateConverter(
auto scope = interpolate_op->scope();
auto op_info = interpolate_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
// get input, output and attributes from lite op
auto x_var_name = op_info->Input("X").front();
CHECK(inputs_map.count(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
......@@ -64,7 +58,7 @@ node_map_type InterpolateConverter(
// update out_h and out_w if the op has an OutSize input
bool inputs_map_has_w = false;
if (HasInputArg(op_info, scope, "OutSize")) {
if (lite::npu::HasInputArg(op_info, scope, "OutSize")) {
auto out_size_var_name = op_info->Input("OutSize").front();
if (inputs_map.count(out_size_var_name)) {
inputs_map_has_w = true;
......@@ -83,12 +77,12 @@ node_map_type InterpolateConverter(
auto interp_method = op_info->GetAttr<std::string>("interp_method");
if (interp_method == "bilinear") {
auto interp_node = std::make_shared<ge::op::ResizeBilinear>(unique_op_type);
OpList::Global().add(interp_node);
lite::npu::OpList::Global().add(interp_node);
interp_node->set_input_x(*inputs_map.at(x_var_name));
if (inputs_map_has_w) {
auto out_size_var_name = op_info->Input("OutSize").front();
interp_node->set_input_w(*inputs_map.at(out_size_var_name));
OpList::Global().add(inputs_map.at(out_size_var_name));
lite::npu::OpList::Global().add(inputs_map.at(out_size_var_name));
} else {
const float largest_multiple = 7.0f;
float multiple = static_cast<float>(x_h * x_w) / (out_h * out_w);
......@@ -99,9 +93,9 @@ node_map_type InterpolateConverter(
auto w_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/w");
w_const_node->set_attr_value(
CreateTensorAndFillData(std::vector<int>({out_h, out_w})));
lite::npu::CreateTensorAndFillData(std::vector<int>({out_h, out_w})));
interp_node->set_input_w(*w_const_node);
OpList::Global().add(w_const_node);
lite::npu::OpList::Global().add(w_const_node);
}
interp_node->set_attr_output_dim_mode(
2); // 0: zoom_factor, 1: shrink_factor, 2: height/width
......@@ -110,19 +104,19 @@ node_map_type InterpolateConverter(
} else if (interp_method == "nearest") {
auto interp_node =
std::make_shared<ge::op::ResizeNearestNeighbor>(unique_op_type);
OpList::Global().add(interp_node);
lite::npu::OpList::Global().add(interp_node);
interp_node->set_input_image(*inputs_map.at(x_var_name));
if (inputs_map_has_w) {
auto out_size_var_name = op_info->Input("OutSize").front();
interp_node->set_input_size(*inputs_map.at(out_size_var_name));
OpList::Global().add(inputs_map.at(out_size_var_name));
lite::npu::OpList::Global().add(inputs_map.at(out_size_var_name));
} else {
auto w_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/w");
w_const_node->set_attr_value(
CreateTensorAndFillData(std::vector<int>({out_h, out_w})));
lite::npu::CreateTensorAndFillData(std::vector<int>({out_h, out_w})));
interp_node->set_input_size(*w_const_node);
OpList::Global().add(w_const_node);
lite::npu::OpList::Global().add(w_const_node);
}
interp_node->set_attr_align_corners(align_corners);
outputs_map[op_info->Output("Out").front()] = interp_node;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -34,7 +28,8 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
LOG(INFO) << "converting mul...";
lite::Scope* scope = mul_op->scope();
const lite::OpInfo* op_info = mul_op->op_info();
auto output_node = std::make_shared<ge::op::MatMul>(UniqueName("mul"));
auto output_node =
std::make_shared<ge::op::MatMul>(lite::npu::UniqueName("mul"));
auto x_var_name = op_info->Input("X").front();
auto y_var_name = op_info->Input("Y").front();
......@@ -66,8 +61,8 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
reshapex->set_input_tensor(*xsrc);
reshapex->set_attr_shape({m, k});
reshapex->set_attr_axis(0);
OpList::Global().add(xsrc);
OpList::Global().add(reshapex);
lite::npu::OpList::Global().add(xsrc);
lite::npu::OpList::Global().add(reshapex);
output_node->set_input_x(*reshapex);
} else {
auto constx = std::make_shared<ge::op::Const>(x_var_name);
......@@ -79,7 +74,7 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
auto* pdata = reinterpret_cast<uint8_t*>(xtensor->mutable_data<float>());
ptensor->SetData(pdata, size * sizeof(float));
constx->set_attr_value(ptensor);
OpList::Global().add(constx);
lite::npu::OpList::Global().add(constx);
output_node->set_input_x(*constx);
}
......@@ -89,8 +84,8 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
reshapey->set_input_tensor(*ysrc);
reshapey->set_attr_shape({k, n});
reshapey->set_attr_axis(0);
OpList::Global().add(ysrc);
OpList::Global().add(reshapey);
lite::npu::OpList::Global().add(ysrc);
lite::npu::OpList::Global().add(reshapey);
output_node->set_input_w(*reshapey);
} else {
auto consty = std::make_shared<ge::op::Const>(y_var_name);
......@@ -102,11 +97,11 @@ node_map_type MulConverter(const std::shared_ptr<lite::OpLite> mul_op,
auto* pdata = reinterpret_cast<uint8_t*>(ytensor->mutable_data<float>());
ptensor->SetData(pdata, size * sizeof(float));
consty->set_attr_value(ptensor);
OpList::Global().add(consty);
lite::npu::OpList::Global().add(consty);
output_node->set_input_w(*consty);
}
OpList::Global().add(output_node);
lite::npu::OpList::Global().add(output_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = output_node;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,15 +26,15 @@ node_map_type Pad2dConverter(const std::shared_ptr<lite::OpLite> pad2d_op,
auto scope = pad2d_op->scope();
auto op_info = pad2d_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
std::shared_ptr<ge::op::Pad> pad2d_node =
std::make_shared<ge::op::Pad>(unique_op_type);
auto x_var_name = op_info->Input("X").front();
pad2d_node->set_input_x(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(pad2d_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(pad2d_node);
auto mode = op_info->GetAttr<std::string>("mode");
if (mode == "constant") {
......@@ -59,17 +53,19 @@ node_map_type Pad2dConverter(const std::shared_ptr<lite::OpLite> pad2d_op,
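  // prepend zeros so that 'padding' holds two values for each of the xds input dims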
padding.insert(padding.begin(), xds * 2 - 4, 0);
auto npu_padding =
std::make_shared<ge::op::Const>(unique_op_type + "/padding");
npu_padding->set_attr_value(CreateTensorAndFillData<int>(padding, {xds, 2}));
npu_padding->set_attr_value(
lite::npu::CreateTensorAndFillData<int>(padding, {xds, 2}));
pad2d_node->set_input_padding(*npu_padding);
OpList::Global().add(npu_padding);
lite::npu::OpList::Global().add(npu_padding);
if (mode == "constant") {
auto pad_value = op_info->GetAttr<float>("pad_value");
auto npu_pad_value =
std::make_shared<ge::op::Const>(unique_op_type + "/pad_value");
npu_pad_value->set_attr_value(CreateTensorAndFillData<float>({pad_value}));
npu_pad_value->set_attr_value(
lite::npu::CreateTensorAndFillData<float>({pad_value}));
pad2d_node->set_input_constant_values(*npu_pad_value);
OpList::Global().add(npu_pad_value);
lite::npu::OpList::Global().add(npu_pad_value);
pad2d_node->set_attr_T(0); // type of pad_value: 0:float 3:int32
}
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,7 +26,7 @@ node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> pool_op,
auto scope = pool_op->scope();
auto op_info = pool_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
std::shared_ptr<ge::op::Pooling> pool_node =
......@@ -73,8 +67,8 @@ node_map_type PoolConverter(const std::shared_ptr<lite::OpLite> pool_op,
pool_node->set_attr_ceil_mode(npu_ceil_mode);
// output_node->set_attr_data_mode(npu_data_mode);
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(pool_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(pool_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = pool_node;
......
......@@ -13,14 +13,8 @@
// limitations under the License.
#include "lite/operators/reshape_op.h"
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -33,7 +27,7 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
auto scope = reshape_op->scope();
auto op_info = reshape_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
// get input, output and op attributes
......@@ -45,10 +39,10 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
auto reshape_node = std::make_shared<ge::op::Reshape>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
reshape_node->set_input_tensor(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
// read shape from actual shape tensor as input "w" if 'Shape' is found
if (HasInputArg(op_info, scope, "Shape")) {
if (lite::npu::HasInputArg(op_info, scope, "Shape")) {
auto actual_shape_var_name = op_info->Input("Shape").front();
if (!inputs_map.count(actual_shape_var_name)) {
auto actual_shape =
......@@ -67,13 +61,14 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
}
auto actual_shape_const_node =
std::make_shared<ge::op::Const>(actual_shape_var_name);
actual_shape_const_node->set_attr_value(CreateTensorAndFillData(
std::vector<int>(out_shape.begin(), out_shape.end())));
actual_shape_const_node->set_attr_value(
lite::npu::CreateTensorAndFillData(
std::vector<int>(out_shape.begin(), out_shape.end())));
reshape_node->set_input_w(*actual_shape_const_node);
OpList::Global().add(actual_shape_const_node);
lite::npu::OpList::Global().add(actual_shape_const_node);
} else {
reshape_node->set_input_w(*inputs_map.at(actual_shape_var_name));
OpList::Global().add(inputs_map.at(actual_shape_var_name));
lite::npu::OpList::Global().add(inputs_map.at(actual_shape_var_name));
}
} else {
auto shape = op_info->GetAttr<std::vector<int>>("shape");
......@@ -87,7 +82,7 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
reshape_node->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
}
OpList::Global().add(reshape_node);
lite::npu::OpList::Global().add(reshape_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = reshape_node;
......@@ -107,7 +102,7 @@ node_map_type ReshapeConverter(const std::shared_ptr<lite::OpLite> reshape_op,
xshape_node->set_input_tensor(*inputs_map.at(x_var_name));
xshape_node->set_attr_shape(
ge::AttrValue::LIST_INT(xshape_dims.begin(), xshape_dims.end()));
OpList::Global().add(xshape_node);
lite::npu::OpList::Global().add(xshape_node);
outputs_map[op_info->Output("XShape").front()] = xshape_node;
}
return outputs_map;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,7 +26,7 @@ node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
auto scope = scale_op->scope();
auto op_info = scale_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
// get input, output and op attributes
......@@ -52,26 +46,26 @@ node_map_type ScaleConverter(const std::shared_ptr<lite::OpLite> scale_op,
auto scale_node = std::make_shared<ge::op::Scale>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
scale_node->set_input_x(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(scale_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(scale_node);
// add filter node(fill with scale)
auto filter_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/filter");
filter_const_node->set_attr_value(
CreateTensorAndFillData(scale, scale_bias_shape));
lite::npu::CreateTensorAndFillData(scale, scale_bias_shape));
scale_node->set_input_filter(*filter_const_node);
OpList::Global().add(filter_const_node);
lite::npu::OpList::Global().add(filter_const_node);
// add bias node(fill with bias)
if (fabs(bias) > 1e-6f) {
auto bias_const_node =
std::make_shared<ge::op::Const>(unique_op_type + "/bias");
bias_const_node->set_attr_value(
CreateTensorAndFillData(bias, scale_bias_shape));
lite::npu::CreateTensorAndFillData(bias, scale_bias_shape));
scale_node->set_input_bias(*bias_const_node);
scale_node->set_attr_has_bias_value(true);
OpList::Global().add(bias_const_node);
lite::npu::OpList::Global().add(bias_const_node);
}
scale_node->set_attr_axis(1);
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -33,7 +27,7 @@ node_map_type ShuffleChannelConverter(
auto scope = shuffle_channel_op->scope();
auto op_info = shuffle_channel_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
std::shared_ptr<ge::op::ShuffleChannel> shuffle_channel_node =
......@@ -43,8 +37,8 @@ node_map_type ShuffleChannelConverter(
shuffle_channel_node->set_input_x(*inputs_map.at(x_var_name));
shuffle_channel_node->set_attr_group(op_info->GetAttr<int>("group"));
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(shuffle_channel_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(shuffle_channel_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = shuffle_channel_node;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,7 +26,7 @@ node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> softmax_op,
auto scope = softmax_op->scope();
auto op_info = softmax_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
std::shared_ptr<ge::op::Softmax> softmax_node =
......@@ -51,8 +45,8 @@ node_map_type SoftmaxConverter(const std::shared_ptr<lite::OpLite> softmax_op,
softmax_node->set_input_x(*inputs_map.at(x_var_name));
softmax_node->set_attr_axis(axis);
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(softmax_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(softmax_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = softmax_node;
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -32,7 +26,7 @@ node_map_type SplitConverter(const std::shared_ptr<lite::OpLite> split_op,
lite::Scope* scope = split_op->scope();
const lite::OpInfo* op_info = split_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " << op_type << " ... ";
auto x_var_name = op_info->Input("X").front();
......@@ -45,7 +39,7 @@ node_map_type SplitConverter(const std::shared_ptr<lite::OpLite> split_op,
std::make_shared<ge::op::Split>(unique_op_type);
CHECK(inputs_map.count(x_var_name));
output_node->set_input_x(*inputs_map.at(x_var_name));
OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
output_node->set_attr_axis(static_cast<int64_t>(axis));
if (num > 0) {
......@@ -63,18 +57,18 @@ node_map_type SplitConverter(const std::shared_ptr<lite::OpLite> split_op,
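  // expose each split output through a zero Add node so it can be looked up by name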
for (auto out_var_name : out_var_names) {
auto const_node = std::make_shared<ge::op::Const>(
unique_op_type + "/const_zero" + std::to_string(index));
const_node->set_attr_value(CreateTensorAndFillData(0));
OpList::Global().add(const_node);
const_node->set_attr_value(lite::npu::CreateTensorAndFillData(0));
lite::npu::OpList::Global().add(const_node);
auto add_node = std::make_shared<ge::op::Add>(unique_op_type + "/add" +
std::to_string(index));
add_node->set_input_x1(*output_node, "y" + std::to_string(index));
add_node->set_input_x2(*const_node);
outputs_map[out_var_name] = add_node;
OpList::Global().add(add_node);
lite::npu::OpList::Global().add(add_node);
index++;
}
OpList::Global().add(output_node);
lite::npu::OpList::Global().add(output_node);
return outputs_map;
}
......
......@@ -14,10 +14,9 @@
#include "lite/kernels/npu/bridges/test_helper.h"
#include <utility>
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "lite/backends/npu/builder.h"
#include "lite/core/op_registry.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
#include "lite/operators/graph_op.h"
namespace paddle {
......@@ -63,7 +62,7 @@ void LauchOp(const std::shared_ptr<lite::OpLite> op,
auto weight = scope->Var(weight_var_name)->GetMutable<Tensor>();
weight->set_persistable(true);
weight->set_precision(PRECISION(kInt8));
CHECK(BuildModel(graph_inputs, graph_outputs, weight));
CHECK(lite::npu::BuildModel(graph_inputs, graph_outputs, weight));
CHECK_GT(weight->numel(), 0);
CHECK_NE(weight->data<uint8_t>(), 0);
......@@ -94,7 +93,7 @@ void LauchOp(const std::shared_ptr<lite::OpLite> op,
graph_kernel->Launch();
// release all of resources of generated model
OpList::Global().clear();
lite::npu::OpList::Global().clear();
}
} // namespace bridges
......
......@@ -12,14 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ai_ddk_lib/include/graph/buffer.h"
#include "ai_ddk_lib/include/graph/graph.h"
#include "ai_ddk_lib/include/graph/model.h"
#include "ai_ddk_lib/include/graph/op/all_ops.h"
#include "ai_ddk_lib/include/graph/operator.h"
#include "ai_ddk_lib/include/graph/operator_reg.h"
#include "lite/backends/npu/builder.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utils.h"
namespace paddle {
namespace lite {
......@@ -33,7 +27,7 @@ node_map_type TransposeConverter(
auto scope = transpose_op->scope();
auto op_info = transpose_op->op_info();
auto op_type = op_info->Type();
auto unique_op_type = UniqueName(op_type);
auto unique_op_type = lite::npu::UniqueName(op_type);
LOG(INFO) << "Converting " + op_type + "...";
std::shared_ptr<ge::op::Permute> transpose_node =
......@@ -50,8 +44,8 @@ node_map_type TransposeConverter(
w_data[i] = 1.f;
}
auto npu_w = std::make_shared<ge::op::Const>(w_var_name);
npu_w->set_attr_value(CvtFromLiteTensor(w));
OpList::Global().add(npu_w);
npu_w->set_attr_value(lite::npu::CvtFromLiteTensor(w));
lite::npu::OpList::Global().add(npu_w);
auto axis = op_info->GetAttr<std::vector<int>>("axis");
auto npu_axis = ge::AttrValue::LIST_INT(axis.begin(), axis.end());
......@@ -61,8 +55,8 @@ node_map_type TransposeConverter(
transpose_node->set_input_w(*npu_w);
transpose_node->set_attr_order(npu_axis);
OpList::Global().add(inputs_map.at(x_var_name));
OpList::Global().add(transpose_node);
lite::npu::OpList::Global().add(inputs_map.at(x_var_name));
lite::npu::OpList::Global().add(transpose_node);
node_map_type outputs_map;
outputs_map[op_info->Output("Out").front()] = transpose_node;
......
......@@ -49,8 +49,8 @@ void GraphCompute::PrepareForRun() {
VLOG(3) << "npu_idims[" << i << "]: " << npu_idims_[i].GetNumber() << ","
<< npu_idims_[i].GetChannel() << "," << npu_idims_[i].GetHeight()
<< "," << npu_idims_[i].GetWidth();
VLOG(3) << "lite_idims[" << i << "]: " << param.inputs[i]->dims();
CHECK_EQ(param.inputs[i]->dims().production(),
VLOG(3) << "lite_idims[" << i << "]: " << param.inputs[i].second->dims();
CHECK_EQ(param.inputs[i].second->dims().production(),
npu_idims_[i].GetNumber() * npu_idims_[i].GetChannel() *
npu_idims_[i].GetHeight() * npu_idims_[i].GetWidth());
npu_itensors_[i].reset(new hiai::AiTensor);
......@@ -61,16 +61,16 @@ void GraphCompute::PrepareForRun() {
VLOG(3) << "npu_odims[" << i << "]: " << npu_odims_[i].GetNumber() << ","
<< npu_odims_[i].GetChannel() << "," << npu_odims_[i].GetHeight()
<< "," << npu_odims_[i].GetWidth();
VLOG(3) << "lite_odims[" << i << "]: " << param.outputs[i]->dims();
VLOG(3) << "lite_odims[" << i << "]: " << param.outputs[i].second->dims();
auto out_size = npu_odims_[i].GetNumber() * npu_odims_[i].GetChannel() *
npu_odims_[i].GetHeight() * npu_odims_[i].GetWidth();
if (param.outputs[i]->dims().production() != out_size) {
param.outputs[i]->Resize({npu_odims_[i].GetNumber(),
npu_odims_[i].GetChannel(),
npu_odims_[i].GetHeight(),
npu_odims_[i].GetWidth()});
if (param.outputs[i].second->dims().production() != out_size) {
param.outputs[i].second->Resize({npu_odims_[i].GetNumber(),
npu_odims_[i].GetChannel(),
npu_odims_[i].GetHeight(),
npu_odims_[i].GetWidth()});
}
LOG(INFO) << param.outputs[i]->dims();
LOG(INFO) << param.outputs[i].second->dims();
npu_otensors_[i].reset(new hiai::AiTensor);
npu_otensors_[i]->Init(&(npu_odims_[i]));
}
......@@ -80,7 +80,7 @@ bool GraphCompute::input_dims_changed() const {
auto& param = this->Param<param_t>();
CHECK_EQ(param.inputs.size(), npu_idims_.size());
for (size_t i = 0; i < param.inputs.size(); ++i) {
auto param_idims = param.inputs[i]->dims();
auto param_idims = param.inputs[i].second->dims();
CHECK(!param_idims.empty());
CHECK_EQ(param_idims.size(), 4);
std::vector<int> idims{static_cast<int>(npu_idims_[i].GetNumber()),
......@@ -105,7 +105,7 @@ void GraphCompute::Run() {
CHECK_EQ(param.outputs.size(), npu_otensors_.size());
for (size_t i = 0; i < param.inputs.size(); ++i) {
auto* itensor = param.inputs[i];
auto* itensor = param.inputs[i].second;
CHECK(itensor);
const auto* i_data = itensor->data<float>();
std::memcpy(
......@@ -126,10 +126,10 @@ void GraphCompute::Run() {
CHECK_EQ(hiai::AI_SUCCESS,
model_client_->Process(
model_context_, npu_itensors_, npu_otensors_, 1000, istamp));
LOG(INFO) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
VLOG(3) << "[NPU] Process cost " << GetCurrentUS() - start_time << " us";
for (size_t i = 0; i < param.outputs.size(); ++i) {
auto* otensor = param.outputs[i];
auto* otensor = param.outputs[i].second;
CHECK(otensor);
auto* o_data = otensor->mutable_data<float>();
auto* npu_obuffer = static_cast<float*>(npu_otensors_[i]->GetBuffer());
......
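The .second accesses above reflect the second bullet of the commit message: the graph op param now pairs each variable name with its tensor instead of holding bare tensor pointers. A hedged sketch of the assumed layout (field and type names illustrative, not the actual definition):

// Assumed shape of the NPU graph op param after this change (illustrative):
struct GraphParam {
  std::vector<std::pair<std::string, lite::Tensor*>> inputs;   // name, tensor
  std::vector<std::pair<std::string, lite::Tensor*>> outputs;  // name, tensor
  lite::Tensor* weight{nullptr};  // om model data produced by BuildModel
};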
......@@ -5,12 +5,13 @@ set -ex
ARM_OS="android" # android only for now
ARM_ABI="armv8" # armv8, armv7
ARM_LANG="gcc" # gcc only for now
ANDROID_STL="c++_shared" # c++_shared, c++_static
ANDROID_STL="c++_static" # c++_shared, c++_static
DDK_ROOT="$(pwd)/ai_ddk_lib/" # HIAI SDK from https://developer.huawei.com/consumer/cn/hiai/
TARGET_NAME="test_npu_pass" # default target
BUILD_EXTRA=OFF # ON(with sequence ops)/OFF
WITH_JAVA=ON # ON(build jar and jni so)/OFF
WITH_TESTING=ON # ON/OFF
SHUTDOWN_LOG=OFF # ON(disable logging)/OFF
ON_TINY_PUBLISH=OFF # ON(tiny publish)/OFF(full publish)
function print_usage {
......@@ -75,6 +76,7 @@ function build_npu {
fi
if [[ "${ON_TINY_PUBLISH}" == "ON" ]]; then
WITH_TESTING=OFF
SHUTDOWN_LOG=ON
publish_dir="tiny_publish"
else
publish_dir="full_publish"
......@@ -97,6 +99,7 @@ function build_npu {
-DLITE_WITH_LIGHT_WEIGHT_FRAMEWORK=ON \
-DWITH_TESTING=${WITH_TESTING} \
-DLITE_WITH_JAVA=${WITH_JAVA} \
-DLITE_SHUTDOWN_LOG=${SHUTDOWN_LOG} \
-DLITE_WITH_NPU=ON \
-DLITE_ON_TINY_PUBLISH=${ON_TINY_PUBLISH} \
-DANDROID_API_LEVEL=24 \
......