Unverified commit f92ccf5b authored by HappyAngel, committed by GitHub

Merge pull request #96 from PaddlePaddle/develop

pull code
......@@ -181,3 +181,6 @@ if (LITE_ON_MODEL_OPTIMIZE_TOOL)
add_definitions("-DLITE_ON_MODEL_OPTIMIZE_TOOL")
endif(LITE_ON_MODEL_OPTIMIZE_TOOL)
if (LITE_WITH_PYTHON)
add_definitions("-DLITE_WITH_PYTHON")
endif(LITE_WITH_PYTHON)
......@@ -36,7 +36,16 @@ else()
# eigen on cuda9.1 missing header of math_functions.hpp
# https://stackoverflow.com/questions/43113508/math-functions-hpp-not-found-when-using-cuda-with-eigen
GIT_TAG
URL http://paddle-inference-dist.bj.bcebos.com/PaddleLite_ThirdParty%2Feigen-git-mirror-master-9ab917e9db99f5907d086aa73d5f9103.zip
######################################################################################################
# url address of eigen before v2.3.0
# URL http://paddle-inference-dist.bj.bcebos.com/PaddleLite_ThirdParty%2Feigen-git-mirror-master-9ab917e9db99f5907d086aa73d5f9103.zip
######################################################################################################
# url address of eigen since v2.6.0
# github address: https://github.com/eigenteam/eigen-git-mirror
# we changed the source code to adapt it for Windows compilation
# git diffs : (1) unsupported/Eigen/CXX11/src/Tensor/TensorBlockV2.h
######################################################################################################
URL https://paddlelite-data.bj.bcebos.com/third_party_libs/eigen-git-mirror-master-9ab917e9db99f5907d086aa73d5f9103.zip
DOWNLOAD_DIR ${EIGEN_SOURCECODE_DIR}
DOWNLOAD_NO_PROGRESS 1
PREFIX ${EIGEN_SOURCE_DIR}
......
......@@ -381,6 +381,9 @@ function(add_kernel TARGET device level)
endif()
if ("${device}" STREQUAL "MLU")
if (NOT LITE_WITH_MLU)
foreach(src ${args_SRCS})
file(APPEND ${fake_kernels_src_list} "${CMAKE_CURRENT_SOURCE_DIR}/${src}\n")
endforeach()
return()
endif()
set(mlu_kernels "${mlu_kernels};${TARGET}" CACHE INTERNAL "")
......
......@@ -92,7 +92,7 @@ if (LITE_WITH_PYTHON)
COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/python/pybind/liblite_pybind.so" "${INFER_LITE_PUBLISH_ROOT}/python/install/lite/lite.so"
COMMAND cp "${CMAKE_BINARY_DIR}/lite/api/python/pybind/liblite_pybind.so" "${INFER_LITE_PUBLISH_ROOT}/python/lib/lite.so")
add_custom_target(publish_inference_python_installer ${TARGET}
COMMAND python setup.py bdist_wheel
COMMAND ${PYTHON_EXECUTABLE} setup.py bdist_wheel
WORKING_DIRECTORY ${INFER_LITE_PUBLISH_ROOT}/python/install/
DEPENDS publish_inference_python_lib)
add_custom_target(publish_inference_python_light_demo ${TARGET}
......
......@@ -190,7 +190,11 @@ if(WITH_TESTING)
lite_cc_test(test_classify_lite_bm SRCS test_classify_lite_bm.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${bm_kernels} ${bm_bridges}
ARGS --model_dir=${LITE_MODEL_DIR}/resnet50)
ARGS --model_dir=${LITE_MODEL_DIR}/classify)
lite_cc_test(test_yolov3_lite_bm SRCS test_yolov3_lite_bm.cc
DEPS mir_passes lite_api_test_helper paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels} ${bm_kernels} ${bm_bridges}
ARGS --model_dir=${LITE_MODEL_DIR}/yolov3)
endif()
endif()
endif()
......
......@@ -63,6 +63,7 @@ USE_LITE_OP(swish)
USE_LITE_OP(log)
USE_LITE_OP(exp)
USE_LITE_OP(conv2d_transpose)
USE_LITE_OP(depthwise_conv2d_transpose)
USE_LITE_OP(negative)
USE_LITE_OP(pad2d)
USE_LITE_OP(power)
......
......@@ -27,6 +27,9 @@
#include "lite/utils/cp_logging.h"
#include "lite/utils/string.h"
DEFINE_string(optimized_model_path,
"",
"the path of the model that is optimized by opt.");
DEFINE_string(model_dir,
"",
"the path of the model, the model and param files is under "
......@@ -61,10 +64,7 @@ DEFINE_int32(threads, 1, "threads num");
DEFINE_string(result_filename,
"result.txt",
"save the inference time to the file.");
DEFINE_bool(run_model_optimize,
false,
"if set true, apply model_optimize_tool to "
"model and use optimized model to test. ");
DEFINE_bool(show_output, false, "Whether to show the output in the shell.");
namespace paddle {
namespace lite_api {
......@@ -100,15 +100,23 @@ void OutputOptModel(const std::string& save_optimized_model_dir) {
LOG(INFO) << "Save optimized model to " << save_optimized_model_dir;
}
int64_t ShapeProduction(const std::vector<int64_t>& shape) {
int64_t num = 1;
for (auto i : shape) {
num *= i;
}
return num;
}
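// Editor's note (hedged, not part of the original commit): ShapeProduction
// simply multiplies all dimensions, so for the default input shape the
// element count used below would be, with illustrative values only,
//   std::vector<int64_t> shape{1, 3, 224, 224};
//   int64_t input_num = ShapeProduction(shape);  // 1 * 3 * 224 * 224 == 150528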
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
void Run(const std::vector<int64_t>& input_shape,
const std::string& model_dir,
const std::string& model_path,
const std::string model_name) {
// set config and create predictor
lite_api::MobileConfig config;
config.set_threads(FLAGS_threads);
config.set_power_mode(static_cast<PowerMode>(FLAGS_power_mode));
config.set_model_from_file(model_dir + ".nb");
config.set_model_from_file(model_path);
auto predictor = lite_api::CreatePaddlePredictor(config);
......@@ -116,10 +124,7 @@ void Run(const std::vector<int64_t>& input_shape,
auto input_tensor = predictor->GetInput(0);
input_tensor->Resize(input_shape);
auto input_data = input_tensor->mutable_data<float>();
int input_num = 1;
for (size_t i = 0; i < input_shape.size(); ++i) {
input_num *= input_shape[i];
}
int64_t input_num = ShapeProduction(input_shape);
if (FLAGS_input_img_path.empty()) {
for (int i = 0; i < input_num; ++i) {
input_data[i] = 1.f;
......@@ -167,26 +172,78 @@ void Run(const std::vector<int64_t>& input_shape,
ofs << "average = " << std::setw(12) << avg_res;
ofs << std::endl;
ofs.close();
if (FLAGS_show_output) {
auto out_tensor = predictor->GetOutput(0);
auto* out_data = out_tensor->data<float>();
int64_t output_num = ShapeProduction(out_tensor->shape());
float max_value = out_data[0];
int max_index = 0;
for (int i = 0; i < output_num; i++) {
if (max_value < out_data[i]) {
max_value = out_data[i];
max_index = i;
}
}
LOG(INFO) << "max_value:" << max_value;
LOG(INFO) << "max_index:" << max_index;
LOG(INFO) << "output data[0:10]:";
for (int i = 0; i < 10; i++) {
LOG(INFO) << out_data[i];
}
}
}
#endif
} // namespace lite_api
} // namespace paddle
void print_usage() {
std::string help_info =
"Usage: \n"
"./benchmark_bin \n"
" --optimized_model_path (The path of the model that is optimized\n"
" by opt. If the model is optimized, please set the param.) \n"
" type: string \n"
" --model_dir (The path of the model that is not optimized by opt,\n"
" the model and param files is under model_dir.) type: string \n"
" --model_filename (The filename of model file. When the model is\n "
" combined formate, please set model_file. Otherwise, it is not\n"
" necessary to set it.) type: string \n"
" --param_filename (The filename of param file, set param_file when\n"
" the model is combined formate. Otherwise, it is not necessary\n"
" to set it.) type: string \n"
" --input_shape (Set input shapes according to the model, separated by\n"
" colon and comma, such as 1,3,244,244) type: string\n"
" default: 1,3,224,224 \n"
" --input_img_path (The path of input image, if not set\n"
" input_img_path, the input will be 1.0.) type: string \n "
" --power_mode (Arm power mode: 0 for big cluster, 1 for little\n"
" cluster, 2 for all cores, 3 for no bind) type: int32 default: 3\n"
" --repeats (Repeats times) type: int32 default: 1 \n"
" --result_filename (Save the inference time to the file.) type: \n"
" string default: result.txt \n"
" --threads (Threads num) type: int32 default: 1 \n"
" --warmup (Warmup times) type: int32 default: 0 \n"
"Note that: \n"
" If load the optimized model, set optimized_model_path. Otherwise, \n"
" set model_dir, model_filename and param_filename according to \n"
" the model. \n";
LOG(INFO) << help_info;
}
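// Editor's sketch (hypothetical, not from the original commit): the body of
// the get_shape lambda in main() is collapsed in this diff. A standalone
// helper that parses one comma-separated shape string such as the default
// "1,3,224,224" could look roughly like the following; the name ParseShape
// and its placement here are illustrative only.
// Requires: <cstdint>, <sstream>, <string>, <vector>.
std::vector<int64_t> ParseShape(const std::string& str_shape) {
  std::vector<int64_t> shape;
  std::stringstream ss(str_shape);
  std::string field;
  while (std::getline(ss, field, ',')) {  // split on ','
    shape.push_back(std::stoll(field));   // convert each field to int64_t
  }
  return shape;  // ParseShape("1,3,224,224") -> {1, 3, 224, 224}
}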
int main(int argc, char** argv) {
// Check inputs
gflags::ParseCommandLineFlags(&argc, &argv, true);
if (FLAGS_model_dir == "") {
LOG(INFO) << "Please run ./benchmark_bin --help to obtain usage.";
bool is_opt_model = (FLAGS_optimized_model_path != "");
bool is_origin_model = (FLAGS_model_dir != "");
if (!is_origin_model && !is_opt_model) {
LOG(INFO) << "Input error, the model path should not be empty.\n";
print_usage();
exit(0);
}
if (FLAGS_model_dir.back() == '/') {
FLAGS_model_dir.pop_back();
}
std::size_t found = FLAGS_model_dir.find_last_of("/");
std::string model_name = FLAGS_model_dir.substr(found + 1);
std::string save_optimized_model_dir = FLAGS_model_dir + "_opt2";
// Get input shape
auto get_shape = [](const std::string& str_shape) -> std::vector<int64_t> {
std::vector<int64_t> shape;
std::string tmp_str = str_shape;
......@@ -202,19 +259,31 @@ int main(int argc, char** argv) {
}
return shape;
};
std::vector<int64_t> input_shape = get_shape(FLAGS_input_shape);
// Output optimized model if needed
if (FLAGS_run_model_optimize) {
paddle::lite_api::OutputOptModel(save_optimized_model_dir);
// Get model_name and run_model_path
std::string model_name;
std::string run_model_path;
if (is_origin_model) {
if (FLAGS_model_dir.back() == '/') {
FLAGS_model_dir.pop_back();
}
std::size_t found = FLAGS_model_dir.find_last_of("/");
model_name = FLAGS_model_dir.substr(found + 1);
std::string optimized_model_path = FLAGS_model_dir + "_opt2";
paddle::lite_api::OutputOptModel(optimized_model_path);
run_model_path = optimized_model_path + ".nb";
} else {
size_t found1 = FLAGS_optimized_model_path.find_last_of("/");
size_t found2 = FLAGS_optimized_model_path.find_last_of(".");
size_t len = found2 - found1 - 1;
model_name = FLAGS_optimized_model_path.substr(found1 + 1, len);
run_model_path = FLAGS_optimized_model_path;
}
#ifdef LITE_WITH_LIGHT_WEIGHT_FRAMEWORK
// Run inference using optimized model
std::string run_model_dir =
FLAGS_run_model_optimize ? save_optimized_model_dir : FLAGS_model_dir;
paddle::lite_api::Run(input_shape, run_model_dir, model_name);
// Run test
paddle::lite_api::Run(input_shape, run_model_path, model_name);
#endif
return 0;
}
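As a hedged illustration of how model_name and run_model_path are derived in main() above, the following standalone snippet mirrors the substring arithmetic for --optimized_model_path; the path is made up and not taken from the commit.

#include <iostream>
#include <string>

// Sketch only: extract the file stem between the last '/' and the last '.'.
int main() {
  std::string optimized_model_path = "/data/local/tmp/mobilenet_v1.nb";
  size_t found1 = optimized_model_path.find_last_of("/");
  size_t found2 = optimized_model_path.find_last_of(".");
  std::string model_name =
      optimized_model_path.substr(found1 + 1, found2 - found1 - 1);
  std::cout << model_name << std::endl;  // prints "mobilenet_v1"
  return 0;
}

In the --model_dir branch the trailing path component is used as model_name in the same way, and the benchmark runs the <model_dir>_opt2.nb file produced by OutputOptModel.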
......@@ -292,9 +292,10 @@ void Predictor::Build(const cpp::ProgramDesc &desc,
program_desc_ = desc;
// `inner_places` is used to optimize passes
std::vector<Place> inner_places = valid_places;
inner_places.emplace_back(TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny));
inner_places.emplace_back(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kNCHW));
for (auto &valid_place : valid_places) {
inner_places.emplace_back(
Place(TARGET(kHost), valid_place.precision, valid_place.layout));
}
// Analyze whether the model is quantized.
// For quantized model, add place(arm, int8) to inner_places
......
......@@ -32,12 +32,17 @@ namespace lite {
void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config_ = config;
auto places = config.valid_places();
std::vector<std::string> passes{};
#ifdef LITE_WITH_CUDA
// if kCUDA is included in valid places, it should be initialized first,
// otherwise skip this step.
for (auto &p : places) {
if (p.target == TARGET(kCUDA)) {
Env<TARGET(kCUDA)>::Init();
if (config_.multi_stream()) {
passes = {"multi_stream_analysis_pass"};
VLOG(3) << "add pass: " << passes[0];
}
break;
}
}
......@@ -51,7 +56,6 @@ void CxxPaddleApiImpl::Init(const lite_api::CxxConfig &config) {
config.mlu_first_conv_std(),
config.mlu_input_layout());
#endif // LITE_WITH_MLU
std::vector<std::string> passes{};
auto use_layout_preprocess_pass =
config.model_dir().find("OPENCL_PRE_PRECESS");
VLOG(1) << "use_layout_preprocess_pass:" << use_layout_preprocess_pass;
......
......@@ -29,7 +29,10 @@ void LightPredictor::Build(const std::string& lite_model_file,
LoadModelNaiveFromFile(lite_model_file, scope_.get(), &cpp_program_desc_);
}
// For post-training weight quantization, load the int8/16 weights of the
// optimized model and dequantize them to fp32.
DequantizeWeight();
BuildRuntimeProgram(cpp_program_desc_);
PrepareFeedFetch();
}
......@@ -138,9 +141,6 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
// 2. Create Instructs
#ifdef LITE_WITH_OPENCL
using WaitListType =
std::unordered_map<decltype(static_cast<const void*>(nullptr)),
std::shared_ptr<cl::Event>>;
using OpenCLContext = Context<TargetType::kOpenCL>;
std::unique_ptr<KernelContext> local_ctx(new KernelContext());
local_ctx->As<OpenCLContext>().InitOnce();
......@@ -182,58 +182,76 @@ void LightPredictor::BuildRuntimeProgram(const cpp::ProgramDesc& prog) {
}
void LightPredictor::DequantizeWeight() {
#define PROCESS_CONV2D_DATA() \
for (int64_t i = 0; i < h; ++i) { \
for (int64_t j = 0; j < w; ++j) { \
fp_data[i * w + j] = scale_list[i] * int_data[i * w + j]; \
} \
#define PROCESS_CONV2D_DATA() \
for (int64_t i = 0; i < ch; ++i) { \
for (int64_t j = 0; j < offset; ++j) { \
fp_data[i * offset + j] = scale_list[i] * int_data[i * offset + j]; \
} \
}
#define PROCESS_FC_DATA() \
for (int i = 0; i < input_tensor->numel(); i++) { \
*fp_data = scale_list[0] * (*int_data); \
++fp_data; \
++int_data; \
#define PROCESS_FC_DATA() \
for (int64_t i = 0; i < chin; i++) { \
for (int64_t j = 0; j < chout; j++) { \
fp_data[i * chout + j] = scale_list[j] * int_data[i * chout + j]; \
} \
}
auto is_weight_quantized_op = [](const cpp::OpDesc* op_desc) {
bool result = false;
if (op_desc->HasAttr("quantization_type")) {
std::string type = op_desc->GetAttr<std::string>("quantization_type");
result = (type == "post_weight_abs_max") ||
(type == "post_weight_channel_wise_abs_max");
} else {
result = op_desc->HasAttr("quantize_weight_bits");
}
return result;
};
Tensor tmp_tensor;
CHECK(cpp_program_desc_.BlocksSize());
auto* main_block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(0);
for (size_t k = 0; k < main_block->OpsSize(); ++k) {
auto* op_desc = main_block->GetOp<cpp::OpDesc>(k);
if (op_desc->HasAttr("quantize_weight_bits")) { // weight quantized op
auto input_names = op_desc->input_vars();
for (auto& input_name : input_names) {
std::string input_scale_name = input_name + "_quant_scale";
if (op_desc->HasAttr(input_scale_name)) { // the input is quantized
auto input_tensor =
scope_->FindVar(input_name)->GetMutable<lite::Tensor>();
tmp_tensor.CopyDataFrom(*input_tensor);
auto scale_list =
op_desc->GetAttr<std::vector<float>>(input_scale_name);
int quantize_weight_bits =
op_desc->GetAttr<int>("quantize_weight_bits");
float* fp_data = input_tensor->mutable_data<float>();
std::string op_type = op_desc->Type();
if (op_type == "conv2d" || op_type == "depthwise_conv2d") {
int64_t h = input_tensor->dims()[0];
int64_t w = input_tensor->numel() / h;
CHECK_EQ(scale_list.size(), h);
if (quantize_weight_bits == 8) {
const int8_t* int_data = tmp_tensor.data<int8_t>();
PROCESS_CONV2D_DATA()
} else {
const int16_t* int_data = tmp_tensor.data<int16_t>();
PROCESS_CONV2D_DATA()
}
} else if (op_type == "fc" || op_type == "mul") {
if (quantize_weight_bits == 8) {
const int8_t* int_data = tmp_tensor.data<int8_t>();
PROCESS_FC_DATA()
} else {
const int16_t* int_data = tmp_tensor.data<int16_t>();
PROCESS_FC_DATA()
for (size_t i = 0; i < cpp_program_desc_.BlocksSize(); i++) {
auto* block = cpp_program_desc_.GetBlock<cpp::BlockDesc>(i);
for (size_t k = 0; k < block->OpsSize(); ++k) {
auto* op_desc = block->GetOp<cpp::OpDesc>(k);
if (is_weight_quantized_op(op_desc)) {
auto input_names = op_desc->input_vars();
for (auto& input_name : input_names) {
std::string input_scale_name = input_name + "_quant_scale";
if (op_desc->HasAttr(input_scale_name)) { // the input is quantized
auto input_tensor =
scope_->FindVar(input_name)->GetMutable<lite::Tensor>();
tmp_tensor.CopyDataFrom(*input_tensor);
auto scale_list =
op_desc->GetAttr<std::vector<float>>(input_scale_name);
int quantize_weight_bits =
op_desc->GetAttr<int>("quantize_weight_bits");
CHECK(quantize_weight_bits == 8 || quantize_weight_bits == 16);
float* fp_data = input_tensor->mutable_data<float>();
std::string op_type = op_desc->Type();
if (op_type == "conv2d" || op_type == "depthwise_conv2d") {
int64_t ch = input_tensor->dims()[0];
int64_t offset = input_tensor->numel() / ch;
CHECK_EQ(scale_list.size(), ch);
if (quantize_weight_bits == 8) {
const int8_t* int_data = tmp_tensor.data<int8_t>();
PROCESS_CONV2D_DATA()
} else {
const int16_t* int_data = tmp_tensor.data<int16_t>();
PROCESS_CONV2D_DATA()
}
} else if (op_type == "fc" || op_type == "mul") {
int64_t chin = input_tensor->dims()[0];
int64_t chout = input_tensor->dims()[1];
CHECK_EQ(scale_list.size(), chout);
if (quantize_weight_bits == 8) {
const int8_t* int_data = tmp_tensor.data<int8_t>();
PROCESS_FC_DATA()
} else {
const int16_t* int_data = tmp_tensor.data<int16_t>();
PROCESS_FC_DATA()
}
}
}
}
......
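The PROCESS_CONV2D_DATA and PROCESS_FC_DATA macros above apply one scale per output channel (conv2d/depthwise_conv2d) or per output column (fc/mul). A minimal standalone sketch of the same arithmetic, using plain arrays and made-up sizes instead of lite::Tensor (illustrative only, not the predictor's code):

#include <cstdint>
#include <vector>

// Hedged illustration of channel-wise weight dequantization:
// fp_data[i * offset + j] = scale_list[i] * int_data[i * offset + j],
// matching PROCESS_CONV2D_DATA with `ch` output channels and `offset`
// weights per channel.
std::vector<float> DequantConvWeight(const std::vector<int8_t>& int_data,
                                     const std::vector<float>& scale_list,
                                     int64_t ch, int64_t offset) {
  std::vector<float> fp_data(ch * offset);
  for (int64_t i = 0; i < ch; ++i) {
    for (int64_t j = 0; j < offset; ++j) {
      fp_data[i * offset + j] = scale_list[i] * int_data[i * offset + j];
    }
  }
  return fp_data;
}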
......@@ -136,6 +136,9 @@ class LITE_API CxxConfig : public ConfigBase {
#ifdef LITE_WITH_X86
int x86_math_library_math_threads_ = 1;
#endif
#ifdef LITE_WITH_CUDA
bool multi_stream_{false};
#endif
#ifdef LITE_WITH_MLU
lite_api::MLUCoreVersion mlu_core_version_{lite_api::MLUCoreVersion::MLU_270};
int mlu_core_number_{1};
......@@ -171,6 +174,10 @@ class LITE_API CxxConfig : public ConfigBase {
return x86_math_library_math_threads_;
}
#endif
#ifdef LITE_WITH_CUDA
void set_multi_stream(bool multi_stream) { multi_stream_ = multi_stream; }
bool multi_stream() const { return multi_stream_; }
#endif
#ifdef LITE_WITH_MLU
// set MLU core version, which is used when compiling MLU kernels
......
......@@ -42,6 +42,7 @@ USE_MIR_PASS(type_precision_cast_pass);
USE_MIR_PASS(type_layout_cast_pass);
USE_MIR_PASS(type_layout_cast_preprocess_pass);
USE_MIR_PASS(memory_optimize_pass);
USE_MIR_PASS(multi_stream_analysis_pass);
USE_MIR_PASS(elementwise_mul_constant_eliminate_pass)
USE_MIR_PASS(npu_subgraph_pass);
USE_MIR_PASS(xpu_subgraph_pass);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <fstream>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
DEFINE_string(input_img_txt_path,
"",
"if set input_img_txt_path, read the img filename as input.");
namespace paddle {
namespace lite {
void TestModel(const std::vector<Place>& valid_places) {
lite::Predictor predictor;
std::vector<std::string> passes;
predictor.Build(FLAGS_model_dir,
FLAGS_model_dir + "/model",
FLAGS_model_dir + "/params",
valid_places,
passes);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(
std::vector<DDim::value_type>({1, 3, FLAGS_im_height, FLAGS_im_width})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
if (FLAGS_input_img_txt_path.empty()) {
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
} else {
std::fstream fs(FLAGS_input_img_txt_path, std::ios::in);
if (!fs.is_open()) {
LOG(FATAL) << "open input_img_txt error.";
}
for (int i = 0; i < item_size; i++) {
fs >> data[i];
}
}
auto* image_tensor = predictor.GetInput(1);
image_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 2})));
data = image_tensor->mutable_data<float>();
data[0] = FLAGS_im_height;
data[1] = FLAGS_im_width;
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor.Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) {
predictor.Run();
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
auto out = predictor.GetOutputs();
FILE* fp = fopen("result.txt", "wb");
for (int i = 0; i < out.size(); i++) {
auto* out_data = out[i]->data<float>();
for (int j = 0; j < out[i]->numel(); j++) {
fprintf(fp, "%f\n", out_data[j]);
}
}
fclose(fp);
}
TEST(Yolov3, test_bm) {
std::vector<Place> valid_places({Place{TARGET(kBM), PRECISION(kFloat)},
Place{TARGET(kX86), PRECISION(kFloat)}});
TestModel(valid_places);
}
} // namespace lite
} // namespace paddle
......@@ -744,6 +744,15 @@ void act_reciprocal<float>(const float* din,
}
}
template <>
void act_abs<float>(const float* din, float* dout, int size, int threads) {
for (int i = 0; i < size; ++i) {
dout[0] = (din[0] > 0 ? din[0] : -din[0]);
din++;
dout++;
}
}
#ifdef LITE_WITH_TRAIN
template <>
void act_square_grad(const float* din,
......
......@@ -83,6 +83,9 @@ void act_hard_swish(const T* din,
template <typename T>
void act_reciprocal(const T* din, T* dout, int size, int threads);
template <typename T>
void act_abs(const T* din, T* dout, int size, int threads);
#ifdef LITE_WITH_TRAIN
template <typename T>
void act_square_grad(
......
......@@ -16,46 +16,3 @@
#include <algorithm>
#include <limits>
#include <memory>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace arm {
namespace math {
void concat_func(const std::vector<lite::Tensor *> &input,
const int axis,
lite::Tensor *output) {
int64_t concat_input_size = 1;
int64_t num_cancats = 1;
auto dim_0 = input[0]->dims();
size_t num = input.size();
for (int i = axis + 1; i < dim_0.size(); i++) {
concat_input_size *= dim_0[i];
}
for (int i = 0; i < axis; i++) {
num_cancats *= dim_0[i];
}
float *dst_ptr = output->mutable_data<float>();
const int out_concat_axis = output->dims()[axis];
int64_t offset_concat_axis = 0;
int64_t out_sum = out_concat_axis * concat_input_size;
for (int n = 0; n < num; n++) {
auto dims = input[n]->dims();
const float *src_ptr = input[n]->data<float>();
int64_t in_concat_axis = dims[axis];
float *dout_ptr = dst_ptr + offset_concat_axis * concat_input_size;
int64_t in_sum = in_concat_axis * concat_input_size;
for (int i = 0; i < num_cancats; i++) {
std::memcpy(dout_ptr, src_ptr, sizeof(float) * in_sum);
dout_ptr += out_sum;
src_ptr += in_sum;
}
offset_concat_axis += in_concat_axis;
}
}
} // namespace math
} // namespace arm
} // namespace lite
} // namespace paddle
......@@ -25,9 +25,39 @@ namespace lite {
namespace arm {
namespace math {
void concat_func(const std::vector<lite::Tensor *> &input,
template <typename T>
void concat_func(const std::vector<lite::Tensor*>& input,
const int axis,
lite::Tensor *output);
lite::Tensor* output) {
size_t num = input.size();
auto dim_0 = input[0]->dims();
int64_t concat_input_size = 1;
int64_t num_cancats = 1;
for (int i = axis + 1; i < dim_0.size(); i++) {
concat_input_size *= dim_0[i];
}
for (int i = 0; i < axis; i++) {
num_cancats *= dim_0[i];
}
auto* dst_ptr = output->mutable_data<T>();
const int out_concat_axis = output->dims()[axis];
int64_t offset_concat_axis = 0;
int64_t out_sum = out_concat_axis * concat_input_size;
for (int n = 0; n < num; n++) {
auto dims = input[n]->dims();
auto* src_ptr = input[n]->data<T>();
int64_t in_concat_axis = dims[axis];
auto* dout_ptr = dst_ptr + offset_concat_axis * concat_input_size;
int64_t in_sum = in_concat_axis * concat_input_size;
for (int i = 0; i < num_cancats; i++) {
std::memcpy(dout_ptr, src_ptr, sizeof(T) * in_sum);
dout_ptr += out_sum;
src_ptr += in_sum;
}
offset_concat_axis += in_concat_axis;
}
}
} // namespace math
} // namespace arm
......
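To make the indexing in the templated concat_func easier to follow, here is a hedged standalone illustration over raw float buffers (shapes and values invented): two 2x2 matrices are concatenated along axis 1 into a 2x4 output, mirroring the concat_input_size / num_cancats bookkeeping above.

#include <cstdint>
#include <cstring>
#include <vector>

// Toy example only: inputs are 2x2, axis = 1, output is 2x4.
int main() {
  std::vector<float> a = {1, 2, 3, 4};  // rows {1,2} and {3,4}
  std::vector<float> b = {5, 6, 7, 8};  // rows {5,6} and {7,8}
  std::vector<float> out(8);
  const int64_t concat_input_size = 1;  // product of dims after axis 1
  const int64_t num_concats = 2;        // product of dims before axis 1
  const int64_t out_concat_axis = 4;    // 2 + 2
  const int64_t out_sum = out_concat_axis * concat_input_size;
  int64_t offset_concat_axis = 0;
  for (const auto* src : {&a, &b}) {
    const int64_t in_concat_axis = 2;
    const int64_t in_sum = in_concat_axis * concat_input_size;
    float* dout = out.data() + offset_concat_axis * concat_input_size;
    const float* sptr = src->data();
    for (int64_t i = 0; i < num_concats; ++i) {
      std::memcpy(dout, sptr, sizeof(float) * in_sum);
      dout += out_sum;
      sptr += in_sum;
    }
    offset_concat_axis += in_concat_axis;
  }
  // out is now {1, 2, 5, 6, 3, 4, 7, 8}.
  return 0;
}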
......@@ -5,5 +5,7 @@ get_property(cuda_deps GLOBAL PROPERTY CUDA_MODULES)
nv_library(target_wrapper_cuda SRCS target_wrapper.cc DEPS ${cuda_deps})
nv_library(cuda_blas SRCS blas.cc DEPS ${cuda_deps})
lite_cc_library(cuda_context SRCS context.cc DEPS device_info)
add_subdirectory(math)
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/context.h"
namespace paddle {
namespace lite {} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "lite/backends/cuda/blas.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/target_wrapper.h"
#include "lite/core/device_info.h"
namespace paddle {
namespace lite {
template <TargetType Type>
class Context;
using CUDAContext = Context<TargetType::kCUDA>;
// Only works with CUDA kernels.
template <>
class Context<TargetType::kCUDA> {
public:
typename Env<TargetType::kCUDA>::Devs& devs =
Env<TargetType::kCUDA>::Global();
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {
if (devs.size() > 0) {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
} else {
LOG(INFO) << "No cuda device(s) found, CUDAContext init failed.";
}
}
void Init(int dev_id, int exec_stream_id = 0, int io_stream_id = 0) {
CHECK_GT(devs.size(), 0UL)
<< "Env is not initialized or current target is not exit!";
if (dev_id >= static_cast<int>(devs.size())) {
LOG(WARNING) << "device index exceeds the number of devices, set to "
"default device(0)!";
device_id_ = 0;
} else {
device_id_ = dev_id;
}
if (io_stream_id >= devs[dev_id].max_stream()) {
LOG(WARNING) << "data stream index exceeds the maximum stream number, "
"set to default stream(0)!";
io_stream_id = 0;
}
if (exec_stream_id >= devs[dev_id].max_stream()) {
LOG(WARNING) << "exec stream index exceeds the maximum stream number, "
"set to default stream(0)!";
exec_stream_id = 0;
}
exec_stream_ = devs[dev_id].exec_streams()[exec_stream_id];
io_stream_ = devs[dev_id].io_streams()[io_stream_id];
exec_stream_id_ = exec_stream_id;
io_stream_id_ = io_stream_id;
need_sync_ = false;
}
void CopySharedTo(CUDAContext* ctx) {
CHECK(ctx);
CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
ctx->cublas_fp32_ = cublas_fp32_;
}
const cudaStream_t& exec_stream() const { return exec_stream_; }
void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; }
const cudaStream_t& io_stream() const { return io_stream_; }
void SetIoStream(cudaStream_t stream) { io_stream_ = stream; }
std::shared_ptr<cuda::Blas<float>> cublas_fp32() { return cublas_fp32_; }
void SetCuBlasFP32(std::shared_ptr<cuda::Blas<float>> cublas_fp32) {
cublas_fp32_ = cublas_fp32;
}
const std::vector<cudaEvent_t>& input_events() { return input_events_; }
void SetInputEvents(const std::vector<cudaEvent_t>& input_events) {
input_events_.clear();
input_events_.assign(input_events.begin(), input_events.end());
}
const std::vector<cudaEvent_t>& output_events() { return output_events_; }
void SetOutputEvents(const std::vector<cudaEvent_t>& output_events) {
output_events_.clear();
output_events_.assign(output_events.begin(), output_events.end());
}
std::vector<cudaStream_t> all_exec_streams() {
int dev_id = TargetWrapper<TargetType::kCUDA>::GetCurDevice();
return devs[dev_id].exec_streams();
}
void SetSyncStreams(const std::vector<int>& nums) {
sync_streams_.clear();
std::vector<cudaStream_t> exec_streams = all_exec_streams();
for (size_t i = 0; i < nums.size(); ++i) {
CHECK(nums[i] >= 0 && nums[i] < static_cast<int>(exec_streams.size()))
<< "streams id is not valid";
sync_streams_.push_back(exec_streams[nums[i]]);
}
InitSyncEvents(nums.size());
}
void InitSyncEvents(const int num) {
sync_events_.clear();
for (int i = 0; i < num; ++i) {
cudaEvent_t eve;
TargetWrapperCuda::CreateEventWithFlags(&eve);
sync_events_.push_back(eve);
}
}
void SetNeedSync(bool sync) { need_sync_ = sync; }
bool need_sync() const { return need_sync_; }
void Sync() {
CHECK_EQ(sync_streams_.size(), sync_events_.size());
for (size_t i = 0; i < sync_events_.size(); ++i) {
TargetWrapperCuda::RecordEvent(sync_events_[i], sync_streams_[i]);
TargetWrapperCuda::StreamSync(exec_stream_, sync_events_[i]);
}
}
std::string name() const { return "CUDAContext"; }
CUDAContext& operator=(const CUDAContext& context) {
this->Init(
context.device_id_, context.exec_stream_id_, context.io_stream_id_);
cublas_fp32_ = const_cast<CUDAContext&>(context).cublas_fp32();
return *this;
}
private:
int device_id_;
// overall information
int exec_stream_id_;
int io_stream_id_;
cudaStream_t exec_stream_;
cudaStream_t io_stream_;
// not thread-safe, should allocate for each thread.
std::shared_ptr<cuda::Blas<float>> cublas_fp32_;
// kernel information
std::vector<cudaEvent_t> input_events_;
std::vector<cudaEvent_t> output_events_;
// multi stream sync.
std::vector<cudaStream_t> sync_streams_;
std::vector<cudaEvent_t> sync_events_;
bool need_sync_;
};
} // namespace lite
} // namespace paddle
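CUDAContext::Sync() above records an event on each stream the op depends on and makes the exec stream wait for it. A hedged standalone sketch of that pattern with the raw CUDA runtime API (assuming TargetWrapperCuda::RecordEvent / StreamSync are thin wrappers over cudaEventRecord / cudaStreamWaitEvent, which this diff does not show):

#include <cuda_runtime.h>
#include <vector>

// Minimal illustration of cross-stream synchronization: for each stream this
// kernel depends on, record an event and make the exec stream wait for it
// before new work is launched on it.
void SyncDependencies(cudaStream_t exec_stream,
                      const std::vector<cudaStream_t>& dep_streams) {
  for (cudaStream_t dep : dep_streams) {
    cudaEvent_t eve;
    cudaEventCreateWithFlags(&eve, cudaEventDisableTiming);
    cudaEventRecord(eve, dep);                 // mark the dependency's tail
    cudaStreamWaitEvent(exec_stream, eve, 0);  // exec stream waits on it
    cudaEventDestroy(eve);  // safe: the wait is already enqueued
  }
}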
......@@ -58,7 +58,7 @@ void CLContext::AddKernel(const std::string &kernel_name,
auto program = GetProgram(file_name, options);
VLOG(3) << " --- end get program --- ";
VLOG(3) << " --- to create kernel: " << kernel_name << " --- ";
std::unique_ptr<cl::Kernel> kernel(
std::shared_ptr<cl::Kernel> kernel(
new cl::Kernel(program, kernel_name.c_str(), &status));
CL_CHECK_FATAL(status);
VLOG(3) << " --- end create kernel --- ";
......
......@@ -29,13 +29,14 @@ class CLContext {
public:
~CLContext() {
for (size_t kidx = 0; kidx < kernels_.size(); ++kidx) {
clReleaseKernel(kernels_[kidx]->get());
// Note(ysh329): Don't need `clReleaseKernel`
kernels_[kidx].reset();
}
kernels_.clear();
kernel_offset_.clear();
for (auto &p : programs_) {
clReleaseProgram(p.second->get());
// Note(ysh329): Don't need `clReleaseProgram`
p.second.reset();
}
programs_.clear();
LOG(INFO) << "release cl::Program, cl::Kernel finished.";
......@@ -66,9 +67,10 @@ class CLContext {
int divitor = 2);
// cl::NDRange LocalWorkSizeConv1x1(cl::NDRange global_work_size,
// size_t max_work_size);
private:
std::unordered_map<std::string, std::unique_ptr<cl::Program>> programs_;
std::vector<std::unique_ptr<cl::Kernel>> kernels_;
std::vector<std::shared_ptr<cl::Kernel>> kernels_;
std::map<std::string, int> kernel_offset_;
};
......
......@@ -29,12 +29,12 @@ CLRuntime::~CLRuntime() {
command_queue_->flush();
command_queue_->finish();
}
// For controlling the destruction order:
// For controlling the destruction order
command_queue_.reset();
context_.reset();
device_.reset();
platform_.reset();
LOG(INFO) << "release ~CLRuntime() ";
device_info_.clear();
}
bool CLRuntime::Init() {
......
......@@ -55,7 +55,7 @@ class CLRuntime {
std::map<std::string, size_t>& GetDeviceInfo();
private:
CLRuntime() = default;
CLRuntime() { Init(); }
~CLRuntime();
......
......@@ -38,7 +38,7 @@ lite_cc_library(device_info SRCS device_info.cc DEPS tensor)
if (LITE_WITH_ARM)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info CL_DEPS cl_context)
else()
lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context)
lite_cc_library(context SRCS context.cc DEPS tensor any device_info eigen3 CL_DEPS cl_context CUDA_DEPS cuda_context)
endif()
#-------------------------------------------- GET CODE META INFO ------------------------------------------
......
......@@ -16,8 +16,7 @@
#include "lite/utils/any.h"
#ifdef LITE_WITH_CUDA
#include "lite/backends/cuda/blas.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/backends/cuda/context.h"
#endif
#ifdef LITE_WITH_OPENCL
#include <unordered_map>
......@@ -53,7 +52,6 @@ class Context;
using HostContext = Context<TargetType::kHost>;
using X86Context = Context<TargetType::kX86>;
using CUDAContext = Context<TargetType::kCUDA>;
using ARMContext = Context<TargetType::kARM>;
using NPUContext = Context<TargetType::kNPU>;
using XPUContext = Context<TargetType::kXPU>;
......@@ -286,103 +284,6 @@ class Context<TargetType::kMLU> {
};
#endif // LITE_WITH_MLU
#ifdef LITE_WITH_CUDA
// Only works with CUDA kernels.
template <>
class Context<TargetType::kCUDA> {
public:
typename Env<TargetType::kCUDA>::Devs& devs =
Env<TargetType::kCUDA>::Global();
// NOTE: InitOnce should only be used by ContextScheduler
void InitOnce() {
if (devs.size() > 0) {
cublas_fp32_ = std::make_shared<lite::cuda::Blas<float>>();
} else {
LOG(INFO) << "No cuda device(s) found, CUDAContext init failed.";
}
}
void Init(int dev_id, int exec_stream_id = 0, int io_stream_id = 0) {
CHECK_GT(devs.size(), 0UL)
<< "Env is not initialized or current target is not exit!";
if (dev_id >= static_cast<int>(devs.size())) {
LOG(WARNING) << "device index exceeds the number of devices, set to "
"default device(0)!";
device_id_ = 0;
} else {
device_id_ = dev_id;
}
if (io_stream_id >= devs[dev_id].max_stream()) {
LOG(WARNING) << "data stream index exceeds the maximum stream number, "
"set to default stream(0)!";
io_stream_id = 0;
}
if (exec_stream_id >= devs[dev_id].max_stream()) {
LOG(WARNING) << "exec stream index exceeds the maximum stream number, "
"set to default stream(0)!";
exec_stream_id = 0;
}
exec_stream_ = devs[dev_id].exec_streams()[exec_stream_id];
io_stream_ = devs[dev_id].io_streams()[io_stream_id];
exec_stream_id_ = exec_stream_id;
io_stream_id_ = io_stream_id;
}
void CopySharedTo(CUDAContext* ctx) {
CHECK(ctx);
CHECK(cublas_fp32_) << "cublas_fp32 should be set first";
ctx->cublas_fp32_ = cublas_fp32_;
}
const cudaStream_t& exec_stream() const { return exec_stream_; }
void SetExecStream(cudaStream_t stream) { exec_stream_ = stream; }
const cudaStream_t& io_stream() const { return io_stream_; }
void SetIoStream(cudaStream_t stream) { io_stream_ = stream; }
std::shared_ptr<cuda::Blas<float>> cublas_fp32() { return cublas_fp32_; }
void SetCuBlasFP32(std::shared_ptr<cuda::Blas<float>> cublas_fp32) {
cublas_fp32_ = cublas_fp32;
}
const std::vector<cudaEvent_t>& input_events() { return input_events_; }
void SetInputEvents(const std::vector<cudaEvent_t>& input_events) {
input_events_.clear();
input_events_.assign(input_events.begin(), input_events.end());
}
const std::vector<cudaEvent_t>& output_events() { return output_events_; }
void SetOutputEvents(const std::vector<cudaEvent_t>& output_events) {
output_events_.clear();
output_events_.assign(output_events.begin(), output_events.end());
}
std::string name() const { return "CUDAContext"; }
CUDAContext& operator=(const CUDAContext& context) {
this->Init(
context.device_id_, context.exec_stream_id_, context.io_stream_id_);
cublas_fp32_ = const_cast<CUDAContext&>(context).cublas_fp32();
return *this;
}
private:
int device_id_;
// overall information
int exec_stream_id_;
int io_stream_id_;
cudaStream_t exec_stream_;
cudaStream_t io_stream_;
// not thread-safe, should allocate for each thread.
std::shared_ptr<cuda::Blas<float>> cublas_fp32_;
// kernel information
std::vector<cudaEvent_t> input_events_;
std::vector<cudaEvent_t> output_events_;
};
#endif
#ifdef LITE_WITH_X86
template <>
class Context<TargetType::kX86> {
......@@ -455,7 +356,9 @@ class ContextScheduler {
return *x;
}
std::unique_ptr<KernelContext> NewContext(TargetType target) {
std::unique_ptr<KernelContext> NewContext(
TargetType target,
/*only used for cuda context*/ int exec_stream_id = 0) {
std::unique_ptr<KernelContext> ctx(new KernelContext);
switch (target) {
case TARGET(kHost):
......@@ -472,7 +375,7 @@ class ContextScheduler {
case TARGET(kCUDA): {
int dev_id = TargetWrapper<TargetType::kCUDA>::GetCurDevice();
auto& context = ctx->As<CUDAContext>();
context.Init(dev_id);
context.Init(dev_id, exec_stream_id);
kernel_contexts_[TargetType::kCUDA].As<CUDAContext>().CopySharedTo(
&context);
} break;
......
......@@ -159,7 +159,7 @@ class Env {
static Devs* devs = new Devs();
return *devs;
}
static void Init(int max_stream = 4) {
static void Init(int max_stream = 6) {
#ifdef LITE_WITH_MLU
CNRT_CALL(cnrtInit(0));
#endif
......@@ -175,6 +175,7 @@ class Env {
} else {
LOG(INFO) << "Found " << count << " device(s)";
}
CHECK_GT(max_stream, 0) << "max_stream must be greater than 0.";
// create all device
for (int i = 0; i < count; i++) {
auto dev = Device<Type>(i, max_stream);
......@@ -234,8 +235,8 @@ class Device<TARGET(kCUDA)> {
std::string name() { return device_prop_.name; }
int core_num() { return device_prop_.multiProcessorCount; }
float max_memory() { return device_prop_.totalGlobalMem / 1048576.; }
std::vector<cudaStream_t> exec_streams() { return exec_stream_; }
std::vector<cudaStream_t> io_streams() { return io_stream_; }
const std::vector<cudaStream_t>& exec_streams() { return exec_stream_; }
const std::vector<cudaStream_t>& io_streams() { return io_stream_; }
int sm_version() { return sm_version_; }
bool has_fp16() { return has_fp16_; }
......
......@@ -37,6 +37,7 @@ lite_cc_library(mir_passes
demo_pass.cc
runtime_context_assign_pass.cc
memory_optimize_pass.cc
multi_stream_analysis_pass.cc
mlu_postprocess_pass.cc
weight_quantization_preprocess_pass.cc
quantized_op_attributes_inference_pass.cc
......
......@@ -116,8 +116,7 @@ void ConvBNFuser::InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) {
}
size_t weight_num = conv_weight_t->data_size();
bool enable_int8 = conv_op_desc->HasAttr("enable_int8") ? true : false;
bool is_weight_quantization =
conv_op_desc->HasAttr("quantize_weight_bits") ? true : false;
bool is_weight_quantization = conv_op_desc->HasAttr("quantize_weight_bits");
// compute BN alpha and beta
Tensor alpha_tensor, beta_tensor;
......
......@@ -14,6 +14,7 @@
#include "lite/core/mir/generate_program_pass.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/mir/graph_visualize_pass.h"
......@@ -25,10 +26,37 @@ namespace mir {
void GenerateProgramPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
VLOG(4) << "final program \n" << Visualize(graph.get());
for (auto& item : graph->StmtTopologicalOrder()) {
std::vector<Node*> nodes_in_order;
#ifdef LITE_WITH_CUDA
const std::string depend_pass = "multi_stream_analysis_pass";
const std::string attr_name = "nodes_in_order";
mir::Pass* pass = mir::PassManager::Global().LookUp(depend_pass);
if (pass->HasAttr(attr_name)) {
nodes_in_order = pass->GetAttr<std::vector<Node*>>(attr_name);
}
#endif
if (nodes_in_order.empty()) {
nodes_in_order = graph->StmtTopologicalOrder();
}
for (auto& item : nodes_in_order) {
if (item->IsStmt()) {
auto& stmt = item->AsStmt();
VLOG(4) << stmt;
#ifdef LITE_WITH_CUDA
if (stmt.kernels().front()->target() == TargetType::kCUDA) {
stmt.kernels()
.front()
->mutable_context()
->As<CUDAContext>()
.SetNeedSync(stmt.need_sync_);
stmt.kernels()
.front()
->mutable_context()
->As<CUDAContext>()
.SetSyncStreams(stmt.sync_streams_);
}
#endif
insts_.emplace_back(stmt.op(), std::move(stmt.kernels().front()));
}
}
......
......@@ -85,7 +85,23 @@ std::string Visualize(mir::SSAGraph* graph) {
if (!node->IsStmt()) continue;
auto op_info = node->AsStmt().op_info();
auto op_type = op_info->Type();
std::string op_name = string_format("%s%d", op_type.c_str(), op_idx++);
std::string op_name;
if (node->AsStmt().need_sync_) {
std::ostringstream oss;
for (size_t i = 0; i < node->AsStmt().sync_streams_.size(); ++i) {
oss << std::to_string(node->AsStmt().sync_streams_[i]);
if (i != node->AsStmt().sync_streams_.size() - 1) {
oss << ",";
}
}
op_name = string_format("%s%d, stream=%d, sync_streams={%s}",
op_type.c_str(),
op_idx++,
node->AsStmt().stream_id_,
oss.str().c_str());
} else {
op_name = string_format("%s%d", op_type.c_str(), op_idx++);
}
// Add its input&output variables as the Dot nodes
dot.AddNode(op_name,
{Dot::Attr("shape", "box"),
......@@ -93,7 +109,13 @@ std::string Visualize(mir::SSAGraph* graph) {
Dot::Attr("color", "black"),
Dot::Attr("fillcolor", "yellow")});
for (auto& x : node->inlinks) {
auto var_name = x->AsArg().name;
std::string var_name;
if (x->AsArg().lane != -1) {
var_name = string_format(
"%s, lane=%d", x->AsArg().name.c_str(), x->AsArg().lane);
} else {
var_name = x->AsArg().name;
}
if (!exists_var_names.count(var_name)) {
dot.AddNode(var_name, {});
exists_var_names.insert(var_name);
......@@ -101,7 +123,13 @@ std::string Visualize(mir::SSAGraph* graph) {
dot.AddEdge(var_name, op_name, {});
}
for (auto& x : node->outlinks) {
auto var_name = x->AsArg().name;
std::string var_name;
if (x->AsArg().lane != -1) {
var_name = string_format(
"%s, lane=%d", x->AsArg().name.c_str(), x->AsArg().lane);
} else {
var_name = x->AsArg().name;
}
if (!exists_var_names.count(var_name)) {
dot.AddNode(var_name, {});
exists_var_names.insert(var_name);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/multi_stream_analysis_pass.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include "lite/core/device_info.h"
#include "lite/core/mir/graph_visualize_pass.h"
#include "lite/core/mir/pass_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace mir {
void MultiStreamAnalysisPass::CleanUp() {
exec_ops_.clear();
wait_que_.clear();
wait_que_cpu_.clear();
std::queue<int> empty_queue;
while (!exec_que_.empty()) {
exec_que_.pop();
}
ops_in_streams_.clear();
resources_.clear();
map_arg_to_lane_.clear();
op_types_set_.clear();
io_copy_once_num_ = 0;
}
void MultiStreamAnalysisPass::Init(SSAGraph* graph) {
// If not cleaned, the clone will overlay the previous state
CleanUp();
for (auto& op_node : graph->StmtTopologicalOrder()) {
if (op_node->IsStmt()) {
// Set all outputs of op to inaccessible state.
auto outputs = op_node->outlinks;
for (Node* node : outputs) {
CHECK(node->IsArg());
auto& arg = node->AsArg();
if (!resources_.count(arg.name)) {
resources_[arg.name] = false;
}
}
// Set the weight input of op to be accessible.
auto inputs = op_node->inlinks;
for (Node* node : inputs) {
CHECK(node->IsArg());
auto& arg = node->AsArg();
if (arg.is_weight || arg.is_persist) {
resources_[arg.name] = true;
}
}
// feed and io_copy_once ops have no dependencies and can be launched
// directly. Other ops are put into the waiting queue.
if (op_node->AsStmt().op_type() == "feed" ||
op_node->AsStmt().op_type() == "io_copy_once") {
exec_que_.push(op_node);
} else {
auto tgt = op_node->AsStmt().kernels().front()->target();
if (tgt == TargetType::kCUDA) {
wait_que_.push_back(op_node);
} else {
wait_que_cpu_.push_back(op_node);
}
}
op_types_set_.insert(op_node->AsStmt().op_type());
}
}
// Set the stream id according to the number of feed ops, and set the output
// of the feed op to be accessible.
int lane = 0;
auto nodes = graph->inputs();
ops_in_streams_.resize(max_stream_);
for (auto& node : nodes) {
std::string::size_type idx = node->AsArg().name.find("feed");
if (idx != std::string::npos) {
for (auto& feed_ops : node->outlinks) {
if (feed_ops->AsStmt().op_type() == "feed") {
// feed op doesn't need to wait sync.
feed_ops->AsStmt().need_sync_ = false;
CHECK_EQ(static_cast<int>(feed_ops->outlinks.size()), 1)
<< "feed op must have one output.";
for (auto& var : feed_ops->outlinks) {
var->AsArg().lane = lane;
map_arg_to_lane_[var->AsArg().name] = lane;
resources_[var->AsArg().name] = true;
}
feed_ops->AsStmt().stream_id_ = lane;
ops_in_streams_[lane].push_back(feed_ops);
++lane;
if (lane >= max_stream_) {
lane = 0;
}
}
}
}
// set all io_copy_once ops in the first stream
for (auto& io_copy_once_ops : node->outlinks) {
if (io_copy_once_ops->AsStmt().op_type() == "io_copy_once") {
ops_in_streams_[0].push_back(io_copy_once_ops);
io_copy_once_ops->AsStmt().stream_id_ = 0;
io_copy_once_ops->AsStmt().need_sync_ = false;
++io_copy_once_num_;
}
}
}
}
bool MultiStreamAnalysisPass::CheckOpSupport() {
std::unordered_set<std::string> invalid_op = {
"while", "conditional_block", "conditional_block_infer", "graph_op"};
for (auto& op_type : op_types_set_) {
if (invalid_op.count(op_type)) {
LOG(INFO) << "multi_stream_analysis_pass don't support " << op_type
<< ", just return.";
return false;
}
}
return true;
}
bool MultiStreamAnalysisPass::IsPrepared(Node* stmt_node) {
// feed ops are prepared during Init().
std::string op_name = stmt_node->AsStmt().op_type();
if (op_name == "feed") {
return true;
}
// Check if the op's inputs are all accessible.
std::vector<std::string> args;
for (auto* ins : stmt_node->inlinks) {
args.push_back(ins->AsArg().name);
}
return CheckAccess(args);
}
bool MultiStreamAnalysisPass::CheckAccess(
const std::vector<std::string>& args) {
if (args.size() == 0) {
return true;
}
for (auto& name : args) {
if (resources_[name]) {
continue;
} else {
return false;
}
}
return true;
}
int MultiStreamAnalysisPass::SelectStreamId(const std::vector<int>& lanes) {
if (lanes.size() == 0) {
return 0;
}
int res = lanes[0];
int exclude_io_copy_once_num = ops_in_streams_[0].size() - io_copy_once_num_;
int min_num = lanes[0] == 0 ? exclude_io_copy_once_num
: ops_in_streams_[lanes[0]].size();
for (size_t i = 1; i < lanes.size(); ++i) {
int ith_num = lanes[i] == 0 ? exclude_io_copy_once_num
: ops_in_streams_[lanes[i]].size();
if (ith_num < min_num) {
res = lanes[i];
min_num = ith_num;
}
}
return res;
}
void MultiStreamAnalysisPass::Launch(Node* stmt_node) {
// record ops launch order.
exec_que_.push(stmt_node);
std::vector<int> lanes;
for (auto& in_arg : stmt_node->inlinks) {
// Weight parameter does not involve stream id, so just skip it.
if (in_arg->AsArg().is_weight || in_arg->AsArg().is_persist) {
continue;
}
if (std::find(lanes.begin(), lanes.end(), in_arg->AsArg().lane) ==
lanes.end()) {
lanes.push_back(in_arg->AsArg().lane);
}
}
int stream_id = SelectStreamId(lanes);
// If all inputs of the op are on multiple streams, they need to be
// synchronized
if (lanes.size() > 1) {
for (size_t i = 0; i < lanes.size(); ++i) {
if (lanes[i] != stream_id) {
stmt_node->AsStmt().sync_streams_.push_back(lanes[i]);
}
}
stmt_node->AsStmt().need_sync_ = true;
}
// io_copy nodes are inserted across devices and need to be synced.
if (stmt_node->AsStmt().op_type() == "io_copy") {
stmt_node->AsStmt().need_sync_ = true;
}
stmt_node->AsStmt().stream_id_ = stream_id;
// set output lane and set the output of op to be accessible.
for (auto& out_arg : stmt_node->outlinks) {
out_arg->AsArg().lane = stream_id;
resources_[out_arg->AsArg().name] = true;
}
ops_in_streams_[stream_id].push_back(stmt_node);
}
void MultiStreamAnalysisPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
#ifdef LITE_WITH_CUDA
typename Env<TargetType::kCUDA>::Devs& devs =
Env<TargetType::kCUDA>::Global();
int dev_id = TargetWrapper<TargetType::kCUDA>::GetCurDevice();
max_stream_ = devs[dev_id].max_stream();
#else
LOG(FATAL) << "Please re-compile by setting the cmake flag LITE_WITH_CUDA=ON";
#endif
// Find the correct startup sequence for op.
Init(graph.get());
bool is_valid = CheckOpSupport();
if (!is_valid) {
return;
}
size_t prev_size;
while (!(this->wait_que_.empty() && this->wait_que_cpu_.empty())) {
prev_size = this->wait_que_.size() + this->wait_que_cpu_.size();
// launch the accessible cuda kernels and remove them from the wait queue.
for (auto it = this->wait_que_.begin(); it != this->wait_que_.end();) {
if (IsPrepared(*it)) {
Launch(*it);
it = wait_que_.erase(it);
} else {
++it;
}
}
// launch the accessible cpu kernels and remove them from the wait queue.
for (auto cpu_it = this->wait_que_cpu_.begin();
cpu_it != this->wait_que_cpu_.end();) {
if (IsPrepared(*cpu_it)) {
Launch(*cpu_it);
cpu_it = wait_que_cpu_.erase(cpu_it);
} else {
++cpu_it;
}
}
if (this->wait_que_.size() + this->wait_que_cpu_.size() == prev_size) {
LOG(FATAL) << "network topo error!";
}
}
// Get exec ops order.
while (!exec_que_.empty()) {
auto* node = exec_que_.front();
exec_ops_.push_back(node);
VLOG(4) << node->AsStmt().op_type()
<< " stream: " << node->AsStmt().stream_id_
<< ", sync: " << node->AsStmt().need_sync_;
if (node->AsStmt().need_sync_) {
for (size_t i = 0; i < node->AsStmt().sync_streams_.size(); ++i) {
VLOG(4) << " " << node->AsStmt().sync_streams_[i];
}
}
exec_que_.pop();
}
// Set the pass attribute, used for passing parameters between passes.
const std::string attr_name{"nodes_in_order"};
SetAttr<std::vector<Node*>>(attr_name, &exec_ops_);
LOG(INFO) << "stream " << 0 << " has "
<< ops_in_streams_[0].size() - io_copy_once_num_
<< " ops. (exclude io_copy_once).";
for (size_t i = 1; i < ops_in_streams_.size(); ++i) {
LOG(INFO) << "stream " << i << " has " << ops_in_streams_[i].size()
<< " ops.";
}
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(multi_stream_analysis_pass,
paddle::lite::mir::MultiStreamAnalysisPass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <list>
#include <memory>
#include <queue>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
/*
* MultiStreamAnalysisPass will find the correct launch sequence for all ops.
* Ideally, the order should be multiple asynchronous ops and a small number of
* synchronous ops.
*/
class MultiStreamAnalysisPass : public StmtPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
private:
// Init resource list. Set all ops except feed to inaccessible state and set
// stream id according to the number of inputs.
void Init(SSAGraph* graph);
// Clean state information of all member variables.
void CleanUp();
// After launching, unlock the output resources of op.
void Launch(Node* stmt_node);
// If all inputs of an op are accessible, the op is considered to be in the
// prepared state
bool IsPrepared(Node* stmt_node);
// Determine if all inputs of op are accessible.
bool CheckAccess(const std::vector<std::string>& args);
// The logic of selecting a stream:
// 1. Make the number of ops on each stream as close as possible.
// 2. The selected stream must be one of the streams contained in the input
// arg
int SelectStreamId(const std::vector<int>& lanes);
// Check if the model's ops are all supported. If you encounter unsupported
// ops, exit
bool CheckOpSupport();
private:
std::list<Node*> wait_que_;
std::list<Node*> wait_que_cpu_;
std::queue<Node*> exec_que_;
std::vector<Node*> exec_ops_;
std::vector<std::vector<Node*>> ops_in_streams_;
std::unordered_map<std::string, bool> resources_;
std::unordered_map<std::string, int> map_arg_to_lane_;
int max_stream_;
int io_copy_once_num_;
std::unordered_set<std::string> op_types_set_;
};
} // namespace mir
} // namespace lite
} // namespace paddle
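The stream-selection heuristic in SelectStreamId picks, among the lanes feeding an op, the stream that currently carries the fewest ops. A toy standalone sketch of that choice (names hypothetical, and without the io_copy_once adjustment the real pass applies to stream 0):

#include <vector>

// Hedged illustration only: return the candidate lane whose stream currently
// holds the fewest ops, mirroring MultiStreamAnalysisPass::SelectStreamId.
int PickLeastLoadedLane(const std::vector<int>& lanes,
                        const std::vector<std::vector<int>>& ops_in_streams) {
  if (lanes.empty()) return 0;
  int best = lanes[0];
  size_t min_num = ops_in_streams[lanes[0]].size();
  for (size_t i = 1; i < lanes.size(); ++i) {
    size_t num = ops_in_streams[lanes[i]].size();
    if (num < min_num) {
      best = lanes[i];
      min_num = num;
    }
  }
  return best;
}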
......@@ -80,6 +80,12 @@ class Node {
// Description.
std::string desc;
// for cuda multi stream
bool need_sync_{false};
int stream_id_{0};
// streams which need to be synced, excluding stream_id_
std::vector<int> sync_streams_{};
};
struct Arg {
......@@ -93,6 +99,7 @@ class Node {
// if more than one tool operator is needed (e.g. io_copy, layout, calib), the
// argument between them should be persistent to make sure it's only run once
bool is_persist{false};
int lane{-1};
};
Arg& AsArg(const std::string& name, int id);
......
......@@ -17,9 +17,11 @@
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "lite/core/mir/node.h"
#include "lite/core/mir/ssa_graph.h"
#include "lite/utils/varient.h"
namespace paddle {
namespace lite {
......@@ -121,6 +123,27 @@ class Pass {
virtual ~Pass() = default;
bool HasAttr(const std::string& attr_name) const {
return pass_attrs_.count(attr_name) > 0;
}
// Set a pointer to the attribute. The specific pass itself takes ownership of
// the attribute.
template <typename AttrType>
void SetAttr(const std::string& attr_name, const AttrType* attr) {
VLOG(4) << "Setting the attribute " << attr_name << " for the pass "
<< name_;
pass_attrs_[attr_name].set<const AttrType>(*attr);
}
// Get a reference to the attribute previously set.
template <typename AttrType>
const AttrType& GetAttr(const std::string& attr_name) const {
CHECK(pass_attrs_.count(attr_name))
<< attr_name << " attr not register for pass " << name_;
return pass_attrs_.at(attr_name).get<const AttrType>();
}
private:
const Kind kind_;
std::string name_;
......@@ -128,6 +151,8 @@ class Pass {
std::set<TargetType> bound_targets_;
std::set<TargetType> excluded_targets_;
std::unordered_map<std::string, std::set<lite_api::Place>> bound_kernels_;
std::unordered_map<std::string, variant<Node, std::vector<Node*>>>
pass_attrs_;
};
// Different kinds.
......
......@@ -45,9 +45,10 @@ class RuntimeContextAssignPass : public StmtPass {
inst.picked_kernel().target()));
}
#else
inst.picked_kernel().SetContext(
ContextScheduler::Global().NewContext(inst.picked_kernel().target()));
int stream_id = inst.stream_id_;
inst.picked_kernel().SetContext(ContextScheduler::Global().NewContext(
inst.picked_kernel().target(), stream_id));
#endif
}
}
......
......@@ -22,9 +22,29 @@ namespace paddle {
namespace lite {
namespace mir {
bool IsAbsMaxQuantizedOp(const OpInfo& op_info) {
bool result = false;
if (op_info.HasAttr("quantization_type") &&
op_info.GetAttr<std::string>("quantization_type") ==
"post_weight_abs_max") {
result = true;
} else if (!op_info.HasAttr("quantization_type") &&
op_info.HasAttr("quantize_weight_bits")) { // Support older model,
// save this for now
result = true;
}
return result;
}
/*
* For abs_max method in WeightQuantization, this pass obtains the scale value
* of conv2d, depthwise_conv2d and mul, expands the scale list, and saves the
* list in the quantized ops.
*/
void WeightQuantizationPreprocessPass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
std::vector<std::string> weight_quantized_op = {"conv2d", "depthwise_conv2d"};
std::vector<std::string> weight_quantized_op = {
"conv2d", "depthwise_conv2d", "mul"};
for (auto& node : graph->StmtTopologicalOrder()) {
if (node->IsStmt() &&
std::find(weight_quantized_op.begin(),
......@@ -32,14 +52,20 @@ void WeightQuantizationPreprocessPass::Apply(
node->AsStmt().op_type()) != weight_quantized_op.end()) {
auto* scope = node->stmt()->op()->scope();
auto* op_desc = node->stmt()->mutable_op_info();
if (op_desc->HasAttr("quantize_weight_bits")) {
if (IsAbsMaxQuantizedOp(*op_desc)) {
for (auto& input_name : op_desc->input_vars()) {
std::string scale_name = input_name + "_quant_scale";
if (op_desc->HasAttr(scale_name)) {
VLOG(5) << "op:" << op_desc->Type() << " input_name:" << input_name;
VLOG(0) << " WeightQuantizationPreprocessPass op:"
<< op_desc->Type() << " input_name:" << input_name;
auto input_tensor =
scope->FindVar(input_name)->GetMutable<lite::Tensor>();
int weight_out_channel = static_cast<int>(input_tensor->dims()[0]);
int weight_out_channel;
if (op_desc->Type() == "mul") {
weight_out_channel = static_cast<int>(input_tensor->dims()[1]);
} else {
weight_out_channel = static_cast<int>(input_tensor->dims()[0]);
}
auto input_scale = op_desc->GetAttr<std::vector<float>>(scale_name);
// scale length is equal to weight out channel
std::vector<float> scale_list(weight_out_channel, input_scale[0]);
......
......@@ -25,8 +25,9 @@ namespace mir {
* If the model is quantized by WeightQuantization in PostTrainingQuantization,
* the data type of the weight in quantized ops (conv2d, depthwise_conv2d) is
* int, and the scale is saved in the quantized ops.
* WeightQuantizationPreprocessPass obtains the scale value, expands the
* scale value to a list, and save the list in the quantized ops.
 * For the abs_max method in WeightQuantization, WeightQuantizationPreprocessPass
 * obtains the scale value of conv2d, depthwise_conv2d and mul, expands it into
 * a scale list, and saves the list in the quantized ops.
*/
class WeightQuantizationPreprocessPass : public ProgramPass {
public:
......
......@@ -151,16 +151,30 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kMLU, kInt16, kNHWC);
INIT_FOR(kMLU, kInt16, kNCHW);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kInt32, kNCHW);
INIT_FOR(kHost, kInt64, kNCHW);
INIT_FOR(kHost, kAny, kNCHW);
INIT_FOR(kHost, kFloat, kNHWC);
INIT_FOR(kHost, kFloat, kAny);
INIT_FOR(kHost, kAny, kNHWC);
INIT_FOR(kHost, kAny, kAny);
INIT_FOR(kHost, kAny, kNHWC);
INIT_FOR(kHost, kAny, kAny);
INIT_FOR(kHost, kBool, kNCHW);
INIT_FOR(kHost, kBool, kNHWC);
INIT_FOR(kHost, kBool, kAny);
INIT_FOR(kHost, kFloat, kNCHW);
INIT_FOR(kHost, kFloat, kNHWC);
INIT_FOR(kHost, kFloat, kAny);
INIT_FOR(kHost, kFP16, kNCHW);
INIT_FOR(kHost, kFP16, kNHWC);
INIT_FOR(kHost, kFP16, kAny);
INIT_FOR(kHost, kInt8, kNCHW);
INIT_FOR(kHost, kInt8, kNHWC);
INIT_FOR(kHost, kInt8, kAny);
INIT_FOR(kHost, kInt16, kNCHW);
INIT_FOR(kHost, kInt16, kNHWC);
INIT_FOR(kHost, kInt16, kAny);
INIT_FOR(kHost, kInt32, kNCHW);
INIT_FOR(kHost, kInt32, kNHWC);
INIT_FOR(kHost, kInt32, kAny);
INIT_FOR(kHost, kInt64, kNCHW);
INIT_FOR(kHost, kInt64, kNHWC);
INIT_FOR(kHost, kInt64, kAny);
INIT_FOR(kX86, kFloat, kNCHW);
INIT_FOR(kX86, kAny, kNCHW);
......
......@@ -127,7 +127,21 @@ class Optimizer {
"memory_optimize_pass"}};
if (passes.size() == 1) {
passes_local.push_back(passes[0]);
      // multi_stream_analysis_pass must run before
      // runtime_context_assign_pass.
const std::string msa_pass{"multi_stream_analysis_pass"};
const std::string depend_pass{"runtime_context_assign_pass"};
if (passes[0] == msa_pass) {
auto iter =
std::find(passes_local.begin(), passes_local.end(), depend_pass);
if (iter != passes_local.end()) {
passes_local.insert(iter, msa_pass);
} else {
        CHECK(false) << "Cannot find " << depend_pass;
}
} else {
passes_local.push_back(passes[0]);
}
}
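    // Resulting order (illustrative): if the caller passes exactly
    // {"multi_stream_analysis_pass"}, it is spliced in just before
    // "runtime_context_assign_pass" instead of being appended, so stream ids
    // are assigned before each kernel's runtime context is created.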
RunPasses(passes_local);
} else {
......
......@@ -178,6 +178,13 @@ class PrecisionProfiler {
write_result_to_file&& write_tensorfile<int32_t>(in, name);
return;
}
case PRECISION(kInt64): {
auto ptr = in->data<int64_t>();
*mean = compute_mean<int64_t>(ptr, in->numel());
*std_dev = compute_standard_deviation<int64_t>(
ptr, in->numel(), true, *mean);
return;
}
default:
*mean = -333333333333;
*std_dev = -33333333333;
......
......@@ -145,6 +145,11 @@ void RuntimeProgram::Run() {
for (auto& inst : instructions_) {
#ifndef LITE_WITH_FPGA
if (inst.is_feed_fetch_op()) continue;
#endif
#ifdef LITE_WITH_CUDA
if (inst.need_sync()) {
inst.Sync();
}
#endif
inst.Run();
#ifdef LITE_WITH_PRECISION_PROFILE
......
......@@ -108,6 +108,18 @@ struct Instruction {
bool is_feed_fetch_op() const { return is_feed_fetch_op_; }
#ifdef LITE_WITH_CUDA
bool need_sync() const {
if (kernel_->target() == TargetType::kCUDA) {
return kernel_->mutable_context()->As<CUDAContext>().need_sync();
} else {
      // The io_copy kernel has already synced, so CPU kernels do not need to sync.
return false;
}
}
void Sync() const { kernel_->mutable_context()->As<CUDAContext>().Sync(); }
#endif
#ifdef LITE_WITH_PROFILE
void set_profiler(profile::Profiler* profiler) {
profiler_ = profiler;
......
......@@ -67,31 +67,31 @@ STL::ostream& operator<<(STL::ostream& os, const KernelPickFactor& k) {
template <>
Type StdTypeToRepr<int32_t>() {
return Type::_int32;
return Type::INT32;
}
template <>
Type StdTypeToRepr<int64_t>() {
return Type::_int64;
return Type::INT64;
}
template <>
Type StdTypeToRepr<float>() {
return Type::_float32;
return Type::FLOAT32;
}
template <>
Type StdTypeToRepr<double>() {
return Type::_float64;
return Type::Float64;
}
template <>
Type StdTypeToRepr<std::vector<char>>() {
return Type::_char_list;
return Type::CHARLIST;
}
template <>
Type StdTypeToRepr<std::string>() {
return Type::_string;
return Type::STRING;
}
template <>
Type StdTypeToRepr<bool>() {
return Type::_bool;
return Type::BOOL;
}
} // namespace core
......
......@@ -29,23 +29,23 @@ namespace core {
*/
// TODO(Superjomn) unify all the type representation across the lite framework.
enum class Type {
_unk = -1,
UNK = -1,
// primary types
_int32,
_int64,
_float32,
_float64,
_bool,
_string,
INT32,
INT64,
FLOAT32,
Float64,
BOOL,
STRING,
// primary list type
_char_list,
CHARLIST,
// list types
_list,
LIST,
// enum type
_enum,
_float16,
ENUM,
FLOAT16,
// number of types
__num__,
NUM,
};
enum class FluidType {
......@@ -81,7 +81,7 @@ enum class FluidType {
template <typename T>
Type StdTypeToRepr() {
return Type::_unk;
return Type::UNK;
}
template <>
Type StdTypeToRepr<int32_t>();
......
......@@ -63,7 +63,6 @@ add_kernel(lrn_compute_arm ARM extra SRCS lrn_compute.cc DEPS ${lite_kernel_deps
add_kernel(decode_bboxes_compute_arm ARM extra SRCS decode_bboxes_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(density_prior_box_compute_arm ARM basic SRCS density_prior_box_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(axpy_compute_arm ARM extra SRCS axpy_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(shape_compute_arm ARM extra SRCS shape_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_max_compute_arm ARM extra SRCS reduce_max_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_expand_compute_arm ARM extra SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(im2sequence_compute_arm ARM extra SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -92,7 +91,6 @@ add_kernel(lookup_table_dequant_compute_arm ARM extra SRCS lookup_table_dequant_
add_kernel(logical_compute_arm ARM extra SRCS logical_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_softmax_compute_arm ARM extra SRCS sequence_softmax_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(while_compute_arm ARM extra SRCS while_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(compare_compute_arm ARM extra SRCS compare_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(topk_compute_arm ARM extra SRCS topk_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(increment_compute_arm ARM extra SRCS increment_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(write_to_array_compute_arm ARM extra SRCS write_to_array_compute.cc DEPS ${lite_kernel_deps} math_arm)
......
......@@ -207,6 +207,16 @@ void ReciprocalCompute::Run() {
x_data, output_data, x_dims.production(), ctx.threads());
}
void AbsCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<ARMContext>();
auto x_dims = param.X->dims();
auto x_data = param.X->data<float>();
auto output_data = param.Out->mutable_data<float>();
lite::arm::math::act_abs<float>(
x_data, output_data, x_dims.production(), ctx.threads());
}
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -321,3 +331,8 @@ REGISTER_LITE_KERNEL(reciprocal,
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
abs, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::AbsCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
......@@ -166,6 +166,15 @@ class ReciprocalCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~ReciprocalCompute() = default;
};
class AbsCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::ActivationParam;
void Run() override;
virtual ~AbsCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/compare_compute.h"
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/core/op_registry.h"
#include "lite/core/type_system.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
#define COMPARE_FUNCTOR(name, op) \
template <typename T> \
struct _##name##Functor { \
inline bool operator()(const T &a, const T &b) const { return a op b; } \
};
COMPARE_FUNCTOR(Equal, ==);
COMPARE_FUNCTOR(NotEqual, !=);
COMPARE_FUNCTOR(LessThan, <);
COMPARE_FUNCTOR(LessEqual, <=);
COMPARE_FUNCTOR(GreaterThan, >);
COMPARE_FUNCTOR(GreaterEqual, >=);
template <>
struct _EqualFunctor<float> {
inline bool operator()(const float &a, const float &b) const {
// It is safe to cast a and b to double.
return fabs(static_cast<double>(a - b)) < 1e-8;
}
};
template <>
struct _NotEqualFunctor<float> {
inline bool operator()(const float &a, const float &b) const {
return !_EqualFunctor<float>()(a, b);
}
};
inline void get_mid_dims(const lite::DDim &x_dims,
const lite::DDim &y_dims,
const int axis,
int *pre,
int *n,
int *post) {
*pre = 1;
*n = 1;
*post = 1;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
}
template <template <typename T> class Functor>
void CompareCompute<Functor>::Run() {
auto &param = this->Param<operators::CompareParam>();
using CompareFunctor = Functor<float>;
const size_t x_size = param.X->numel();
const size_t y_size = param.Y->numel();
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
bool *z = param.Out->template mutable_data<bool>();
const auto *x = param.X->template data<float>();
const auto *y = param.Y->template data<float>();
auto axis = param.axis;
bool force_cpu = param.force_cpu;
if (x_size == y_size) {
for (int i = 0; i < x_size; ++i) {
z[i] = CompareFunctor()(x[i], y[i]);
}
} else {
int axis = (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis);
int outer_num, mid_num, inner_num;
get_mid_dims(x_dims, y_dims, axis, &outer_num, &mid_num, &inner_num);
for (int outer_id = 0; outer_id < outer_num; ++outer_id) {
for (int mid_id = 0; mid_id < mid_num; ++mid_id) {
auto y_data = y[mid_id];
for (int inner_id = 0; inner_id < inner_num; ++inner_id) {
int index = (outer_id * mid_num + mid_id) * inner_num + inner_id;
z[index] = CompareFunctor()(x[index], y_data);
// z[index] = x[index] < y_data;
}
}
}
}
}
template <template <typename T> class Functor>
void CompareCompute_int32<Functor>::Run() {
auto &param = this->Param<operators::CompareParam>();
using CompareFunctor = Functor<int>;
const size_t x_size = param.X->numel();
const size_t y_size = param.Y->numel();
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
bool *z = param.Out->template mutable_data<bool>();
const auto *x = param.X->template data<int>();
const auto *y = param.Y->template data<int>();
auto axis = param.axis;
bool force_cpu = param.force_cpu;
if (x_size == y_size) {
for (int i = 0; i < x_size; ++i) {
z[i] = CompareFunctor()(x[i], y[i]);
}
} else {
int axis = (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis);
int outer_num, mid_num, inner_num;
get_mid_dims(x_dims, y_dims, axis, &outer_num, &mid_num, &inner_num);
for (int outer_id = 0; outer_id < outer_num; ++outer_id) {
for (int mid_id = 0; mid_id < mid_num; ++mid_id) {
auto y_data = y[mid_id];
for (int inner_id = 0; inner_id < inner_num; ++inner_id) {
int index = (outer_id * mid_num + mid_id) * inner_num + inner_id;
z[index] = CompareFunctor()(x[index], y_data);
// z[index] = x[index] < y_data;
}
}
}
}
}
template <template <typename T> class Functor>
void CompareCompute_int64<Functor>::Run() {
auto &param = this->Param<operators::CompareParam>();
using CompareFunctor = Functor<int64_t>;
const size_t x_size = param.X->numel();
const size_t y_size = param.Y->numel();
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
bool *z = param.Out->template mutable_data<bool>();
const auto *x = param.X->template data<int64_t>();
const auto *y = param.Y->template data<int64_t>();
auto axis = param.axis;
bool force_cpu = param.force_cpu;
if (x_size == y_size) {
for (int i = 0; i < x_size; ++i) {
z[i] = CompareFunctor()(x[i], y[i]);
}
} else {
int axis = (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis);
int outer_num, mid_num, inner_num;
get_mid_dims(x_dims, y_dims, axis, &outer_num, &mid_num, &inner_num);
for (int outer_id = 0; outer_id < outer_num; ++outer_id) {
for (int mid_id = 0; mid_id < mid_num; ++mid_id) {
auto y_data = y[mid_id];
for (int inner_id = 0; inner_id < inner_num; ++inner_id) {
int index = (outer_id * mid_num + mid_id) * inner_num + inner_id;
z[index] = CompareFunctor()(x[index], y_data);
}
}
}
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(equal,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::CompareCompute<
paddle::lite::kernels::arm::_EqualFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(equal,
kARM,
kInt32,
kNCHW,
paddle::lite::kernels::arm::CompareCompute_int32<
paddle::lite::kernels::arm::_EqualFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(not_equal,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::CompareCompute<
paddle::lite::kernels::arm::_NotEqualFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(less_than,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::CompareCompute<
paddle::lite::kernels::arm::_LessThanFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(less_than,
kARM,
kInt32,
kNCHW,
paddle::lite::kernels::arm::CompareCompute_int32<
paddle::lite::kernels::arm::_LessThanFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(less_than,
kARM,
kInt64,
kNCHW,
paddle::lite::kernels::arm::CompareCompute_int64<
paddle::lite::kernels::arm::_LessThanFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(less_equal,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::CompareCompute<
paddle::lite::kernels::arm::_LessEqualFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(greater_than,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::CompareCompute<
paddle::lite::kernels::arm::_GreaterThanFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(greater_equal,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::CompareCompute<
paddle::lite::kernels::arm::_GreaterEqualFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
......@@ -34,40 +34,21 @@ std::vector<size_t> stride_numel(const DDim& ddim) {
return strides;
}
void ConcatCompute::Run() {
auto& param = Param<operators::ConcatParam>();
std::vector<lite::Tensor*> inputs = param.x;
auto* out = param.output;
int axis = param.axis;
auto* axis_tensor = param.axis_tensor;
if (axis_tensor != nullptr) {
auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0];
}
out->mutable_data<float>();
  /// Sometimes direct copies will be faster; this may need deeper analysis.
template <typename T>
void ConcatFunc(const std::vector<lite::Tensor*> inputs,
int axis,
lite::Tensor* out) {
  // Sometimes direct copies will be faster; this may need deeper analysis.
if (axis == 0 && inputs.size() < 10) {
size_t output_offset = 0;
for (auto* in : inputs) {
auto in_stride = stride_numel(in->dims());
auto out_stride = stride_numel(out->dims());
void* dst = out->mutable_data<float>() + output_offset;
const void* src = in->data<float>();
#if 0
LOG(INFO) << "out_stride.size():" << out_stride.size();
LOG(INFO) << "out_stride[0]" << out_stride[0];
for (int i=0; i < out_stride.size(); ++i) {
LOG(INFO) << "out_stride[" << i << "]:" << out_stride[i];
}
LOG(INFO) << "in_stride.size():" << in_stride.size();
for (int i=0; i < in_stride.size(); ++i) {
LOG(INFO) << "in_stride[" << i << "]:" << in_stride[i];
}
#endif
void* dst = out->mutable_data<T>() + output_offset;
const void* src = in->data<T>();
      // The src and dst tensors should have the same number of dimensions.
CHECK(in_stride.size() == out_stride.size());
std::memcpy(dst, src, sizeof(float) * in_stride[0]);
std::memcpy(dst, src, sizeof(T) * in_stride[0]);
output_offset += in_stride[0];
}
} else {
......@@ -75,9 +56,37 @@ void ConcatCompute::Run() {
for (int j = 0; j < inputs.size(); ++j) {
inputs_concat[j] = inputs[j];
}
lite::arm::math::concat_func(inputs_concat, axis, out);
lite::arm::math::concat_func<T>(inputs_concat, axis, out);
}
}
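// Illustrative note on the fast path above: when axis == 0 each input's data
// forms one contiguous block of the output, so a single memcpy per input is
// enough. E.g. concatenating A[2, 3] and B[4, 3] into Out[6, 3]:
//   stride_numel(A) = {6, 3}, stride_numel(B) = {12, 3}
// so the loop copies 6 elements of A at offset 0, then 12 elements of B at
// offset 6.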
void ConcatCompute::Run() {
auto& param = Param<operators::ConcatParam>();
std::vector<lite::Tensor*> inputs = param.x;
CHECK_GE(inputs.size(), 1);
auto* out = param.output;
int axis = param.axis;
auto* axis_tensor = param.axis_tensor;
if (axis_tensor != nullptr) {
auto* axis_tensor_data = axis_tensor->data<int>();
axis = axis_tensor_data[0];
}
switch (inputs.front()->precision()) {
case PRECISION(kFloat):
ConcatFunc<float>(inputs, axis, out);
break;
case PRECISION(kInt32):
ConcatFunc<int32_t>(inputs, axis, out);
break;
case PRECISION(kInt64):
ConcatFunc<int64_t>(inputs, axis, out);
break;
default:
LOG(FATAL) << "Concat does not implement for the "
<< "input type:"
<< static_cast<int>(inputs.front()->precision());
}
return;
}
} // namespace arm
......@@ -86,9 +95,9 @@ void ConcatCompute::Run() {
} // namespace paddle
REGISTER_LITE_KERNEL(
concat, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ConcatCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
concat, kARM, kAny, kNCHW, paddle::lite::kernels::arm::ConcatCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("AxisTensor",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
......@@ -22,7 +22,7 @@ namespace lite {
namespace kernels {
namespace arm {
class ConcatCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class ConcatCompute : public KernelLite<TARGET(kARM), PRECISION(kAny)> {
public:
using param_t = operators::ConcatParam;
......
......@@ -95,7 +95,7 @@ void concat_compute_ref(const operators::ConcatParam& param) {
TEST(concat_arm, init) {
ConcatCompute concat;
ASSERT_EQ(concat.precision(), PRECISION(kFloat));
ASSERT_EQ(concat.precision(), PRECISION(kAny));
ASSERT_EQ(concat.target(), TARGET(kARM));
}
......@@ -222,8 +222,7 @@ TEST(concat_arm, compute_input_multi) {
TEST(concat, retrive_op) {
auto concat =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"concat");
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kAny)>("concat");
ASSERT_FALSE(concat.empty());
ASSERT_TRUE(concat.front());
}
......@@ -233,4 +232,4 @@ TEST(concat, retrive_op) {
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(concat, kARM, kFloat, kNCHW, def);
USE_LITE_KERNEL(concat, kARM, kAny, kNCHW, def);
......@@ -20,24 +20,48 @@ namespace lite {
namespace kernels {
namespace arm {
void GatherCompute::Run() {
auto& param = this->Param<operators::GatherParam>();
auto* p_output = param.Out->mutable_data<float>();
auto index_size = param.Index->dims()[0];
template <typename T>
void GatherFunc(const operators::GatherParam& param) {
auto src_dims = param.X->dims();
const float* p_src = param.X->data<float>();
auto index_size = param.Index->dims()[0];
auto* p_src = param.X->data<T>();
const int* p_index = param.Index->data<int>();
auto* p_output = param.Out->mutable_data<T>();
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
for (size_t i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
for (int i = 0; i < index_size; ++i) {
int index_ = p_index[i];
memcpy(p_output + i * slice_size,
p_src + index_ * slice_size,
slice_size * sizeof(float));
slice_size * sizeof(T));
}
}
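// Worked example (illustrative): X with dims [5, 4] and Index = {2, 0} give
// slice_size = 4, so the loop copies row 2 and then row 0 of X into Out,
// which the op has shaped as [2, 4]; the template parameter T lets the same
// path serve every precision dispatched below.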
void GatherCompute::Run() {
auto& param = this->Param<operators::GatherParam>();
switch (param.X->precision()) {
case PRECISION(kFloat):
GatherFunc<float>(param);
break;
case PRECISION(kInt8):
GatherFunc<int8_t>(param);
break;
case PRECISION(kInt16):
GatherFunc<int16_t>(param);
break;
case PRECISION(kInt32):
GatherFunc<int32_t>(param);
break;
case PRECISION(kInt64):
GatherFunc<int64_t>(param);
break;
default:
LOG(FATAL) << "Gather does not implement for the "
<< "input type:" << static_cast<int>(param.X->precision());
}
}
......@@ -48,8 +72,8 @@ void GatherCompute::Run() {
REGISTER_LITE_KERNEL(
gather, kARM, kAny, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.BindInput("Index",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))})
.Finalize();
......@@ -45,32 +45,13 @@ void MatMulCompute::Run() {
operators::ActivationParam act_param;
act_param.has_active = false;
if (x_dims.size() > 2 && y_dims.size() >= 2) {
if ((x_dims.size() >= 2 && y_dims.size() >= 2) &&
(x_dims.size() != 2 || y_dims.size() != 2)) {
// x: [B, ..., M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [B, M, K], y: [K, N], out: [B, M, N]
if (!x_transpose && !y_transpose) {
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 2])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ") x_transpose is " << x_transpose << "y_transpose is "
<< y_transpose;
} else if (!x_transpose && y_transpose) {
CHECK_EQ(x_dims[x_dims.size() - 1], y_dims[y_dims.size() - 1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ") x_transpose is " << x_transpose << "y_transpose is "
<< y_transpose;
} else if (x_transpose && !y_transpose) {
CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 2])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ") x_transpose is " << x_transpose << "y_transpose is "
<< y_transpose;
} else {
CHECK_EQ(x_dims[x_dims.size() - 2], y_dims[y_dims.size() - 1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< ") x_transpose is " << x_transpose << "y_transpose is "
<< y_transpose;
}
// or
// x: [M, K], y: [B, ..., K, N], out: [B, ..., M, N]
// x: [M, K], y: [B, K, N], out: [B, M, N]
int lda, ldb, ldc;
if (!x_transpose) {
m_ = x_dims[x_dims.size() - 2];
......@@ -96,11 +77,7 @@ void MatMulCompute::Run() {
int y_inner = y_dims[y_dims.size() - 2] * y_dims[y_dims.size() - 1];
int out_inner = o_dims[o_dims.size() - 2] * o_dims[o_dims.size() - 1];
float* x_data_trans = nullptr;
if (x_transpose) {
x_data_trans = static_cast<float*>(malloc(sizeof(float) * x_inner));
}
if (y_dims.size() > 2) {
if (x_dims.size() > 2 && y_dims.size() > 2) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
lite::arm::math::sgemm(x_transpose,
y_transpose,
......@@ -120,7 +97,7 @@ void MatMulCompute::Run() {
act_param,
&ctx);
}
} else {
} else if (x_dims.size() > 2 && y_dims.size() == 2) {
for (size_t i = 0; i < x_dims.count(0, x_dims.size() - 2); ++i) {
lite::arm::math::sgemm(x_transpose,
y_transpose,
......@@ -140,34 +117,29 @@ void MatMulCompute::Run() {
act_param,
&ctx);
}
}
if (x_data_trans) {
free(x_data_trans);
} else if (x_dims.size() == 2 && y_dims.size() > 2) {
for (size_t i = 0; i < y_dims.count(0, y_dims.size() - 2); ++i) {
lite::arm::math::sgemm(x_transpose,
y_transpose,
m_,
n_,
k_,
alpha,
x_data,
lda,
y_data + i * y_inner,
ldb,
0.f,
o_data + i * out_inner,
ldc,
nullptr,
false,
act_param,
&ctx);
}
}
} else if (x_dims.size() == 2 && y_dims.size() == 2) {
// x: [M, K], y: [K, N], out: [M, N]
if (!x_transpose && !y_transpose) {
CHECK_EQ(x_dims[1], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (!x_transpose && y_transpose) {
CHECK_EQ(x_dims[1], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else if (x_transpose && !y_transpose) {
CHECK_EQ(x_dims[0], y_dims[0])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
} else {
CHECK_EQ(x_dims[0], y_dims[1])
<< "not supported x_dims(" << x_dims << ") and y_dims(" << y_dims
<< "), x_transpose is " << x_transpose << ", y_transpose is "
<< y_transpose;
}
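    // Illustrative check: with x [3, 4], y [5, 4] and y_transpose == true,
    // the branch above verifies x_dims[1] == y_dims[1] (4 == 4) and the
    // result has shape [3, 5].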
int lda, ldb, ldc;
if (!x_transpose) {
m_ = x_dims[0];
......
......@@ -32,6 +32,9 @@ lite_cc_library(subgraph_bridge_squeeze_op_bm SRCS squeeze_op.cc DEPS ${bm_subgr
lite_cc_library(subgraph_bridge_cast_op_bm SRCS cast_op.cc DEPS ${bm_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_fill_constant_op_bm SRCS fill_constant_op.cc DEPS ${bm_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_assign_value_op_bm SRCS assign_value_op.cc DEPS ${bm_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_shape_op_bm SRCS shape_op.cc DEPS ${bm_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_split_op_bm SRCS split_op.cc DEPS ${bm_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_matmul_op_bm SRCS matmul_op.cc DEPS ${bm_subgraph_bridge_deps})
set(bm_subgraph_bridges
subgraph_bridge_registry
......@@ -62,4 +65,7 @@ set(bm_subgraph_bridges
subgraph_bridge_cast_op_bm
subgraph_bridge_fill_constant_op_bm
subgraph_bridge_assign_value_op_bm
subgraph_bridge_shape_op_bm
subgraph_bridge_split_op_bm
subgraph_bridge_matmul_op_bm
CACHE INTERNAL "bm_subgraph_bridges")
......@@ -40,17 +40,31 @@ int AssignValueConverter(void* ctx, OpLite* op, KernelBase* kernel) {
i_output_shape_data[i] = static_cast<int>(output_dims[i]);
buffer_size *= i_output_shape_data[i];
}
auto fp32_values = op_info->GetAttr<std::vector<float>>("fp32_values");
std::vector<float> fp32_values;
std::vector<int> int32_values;
float* assign_data =
reinterpret_cast<float*>(malloc(buffer_size * sizeof(float)));
CHECK(assign_data != nullptr);
CHECK_EQ(buffer_size, fp32_values.size());
bm_data_type_t data_type = static_cast<bm_data_type_t>(DTYPE_FP32);
fp32_values = op_info->GetAttr<std::vector<float>>("fp32_values");
if (0 != fp32_values.size()) {
for (int i = 0; i < fp32_values.size(); i++) {
assign_data[i] = fp32_values[i];
}
} else {
int32_values = op_info->GetAttr<std::vector<int>>("int32_values");
data_type = static_cast<bm_data_type_t>(DTYPE_INT32);
CHECK_EQ(buffer_size, int32_values.size());
for (int i = 0; i < int32_values.size(); i++) {
assign_data[i] = int32_values[i];
}
}
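  // Note (illustrative): when "fp32_values" is empty the converter falls back
  // to "int32_values", switches data_type to DTYPE_INT32, and stages the
  // integer values through the same float buffer before bm_add_const_tensor
  // registers them as a constant tensor.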
bm_add_const_tensor(graph->GetCompilerHandle(),
static_cast<const char*>(output_var_name.c_str()),
const_cast<const int*>(i_output_shape_data.data()),
output_dims.size(),
static_cast<bm_data_type_t>(DTYPE_FP32),
data_type,
reinterpret_cast<const void*>(assign_data));
graph->AddNode(output_var_name);
return SUCCESS;
......
......@@ -91,7 +91,6 @@ int ConvConverter(void* ctx, OpLite* op, KernelBase* kernel) {
dilations[1],
static_cast<int>(has_bias));
graph->AddNode(output_var_name);
LOG(INFO) << output_var_name << input_dims << " " << output_dims;
return SUCCESS;
}
......
......@@ -108,3 +108,6 @@ int ConvTransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
REGISTER_SUBGRAPH_BRIDGE(conv2d_transpose,
kBM,
paddle::lite::subgraph::bm::ConvTransposeConverter);
REGISTER_SUBGRAPH_BRIDGE(depthwise_conv2d_transpose,
kBM,
paddle::lite::subgraph::bm::ConvTransposeConverter);
......@@ -65,7 +65,6 @@ int ElementwiseConverter(void* ctx, OpLite* op, KernelBase* kernel) {
auto output_dims = output->dims();
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
LOG(INFO) << x_dims << " " << output_dims;
std::vector<int32_t> i_output_shape_data(output_dims.size());
for (size_t i = 0; i < output_dims.size(); i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
......
......@@ -54,6 +54,7 @@ int InterpolateConverter(void* ctx, OpLite* op, KernelBase* kernel) {
} else {
type = 0;
}
is_int = false;
if (type == 2 && is_int) {
add_upsample_layer(graph->GetCompilerHandle(),
const_cast<const int*>(&i_x_shape_data[0]),
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include <bmcompiler_op_code.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int MatMulConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto unique_op_name = lite::subgraph::bm::UniqueName(op_type);
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const int64_t* x_shape_data = const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
auto y_var_name = op_info->Input("Y").front();
auto y = scope->FindVar(y_var_name)->GetMutable<lite::Tensor>();
auto y_dims = y->dims();
const int64_t* y_shape_data = const_cast<const int64_t*>(&y_dims.data()[0]);
std::vector<int32_t> i_y_shape_data(y_dims.size());
for (size_t i = 0; i < y_dims.size(); i++) {
i_y_shape_data[i] = static_cast<int>(y_shape_data[i]);
}
// output
auto output_var_name = op_info->Output("Out").front();
bool transpose_x = op_info->GetAttr<bool>("transpose_X");
bool transpose_y = op_info->GetAttr<bool>("transpose_Y");
float alpha = op_info->GetAttr<float>("alpha");
LOG(INFO) << x_dims << " " << y_dims << " " << alpha << " " << transpose_x
<< " " << transpose_y;
#if 0
add_const_binary_layer(graph->GetCompilerHandle(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
scale,
static_cast<const char*>(unique_op_scale_name.c_str()),
BINARY_MUL,
0);
add_const_binary_layer(graph->GetCompilerHandle(),
static_cast<const char*>(unique_op_scale_name.c_str()),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
bias,
static_cast<const char*>(output_var_name.c_str()),
BINARY_ADD,
0);
#endif
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(matmul,
kBM,
paddle::lite::subgraph::bm::MatMulConverter);
......@@ -45,14 +45,6 @@ int MultiClassNMSConverter(void* ctx, OpLite* op, KernelBase* kernel) {
i_score_shape_data[i] = static_cast<int32_t>(score_dims[i]);
}
auto out_var_name = op_info->Output("Out").front();
auto out = scope->FindVar(out_var_name)->GetMutable<lite::Tensor>();
auto out_dims = out->dims();
std::vector<int32_t> i_out_shape_data(out_dims.size());
for (size_t i = 0; i < out_dims.size(); i++) {
i_out_shape_data[i] = static_cast<int32_t>(out_dims[i]);
}
auto background_label = op_info->GetAttr<int>("background_label");
auto keep_top_k = op_info->GetAttr<int>("keep_top_k");
auto nms_top_k = op_info->GetAttr<int>("nms_top_k");
......@@ -64,6 +56,26 @@ int MultiClassNMSConverter(void* ctx, OpLite* op, KernelBase* kernel) {
normalized = op_info->GetAttr<bool>("normalized");
}
auto out_var_name = op_info->Output("Out").front();
auto out = scope->FindVar(out_var_name)->GetMutable<lite::Tensor>();
std::vector<int64_t> vec_out_dim(score_dims.size());
if (3 == score_dims.size()) {
vec_out_dim[0] = score_dims[0]; // batch_size
vec_out_dim[1] = keep_top_k;
vec_out_dim[2] = 6;
} else {
vec_out_dim[0] = keep_top_k;
vec_out_dim[1] = 6;
}
DDimLite out_dims(vec_out_dim);
out->Resize(out_dims);
out->mutable_data<float>();
std::vector<int32_t> i_out_shape_data(out_dims.size());
for (size_t i = 0; i < out_dims.size(); i++) {
i_out_shape_data[i] = static_cast<int32_t>(out_dims[i]);
}
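  // Note (illustrative): for 3-D scores [N, C, M] the output is resized above
  // to [N, keep_top_k, 6]; with 2-D scores it becomes [keep_top_k, 6]. Each
  // row of 6 is assumed to hold [label, score, xmin, ymin, xmax, ymax].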
user_cpu_param_t bm_param;
bm_param.op_type = USER_PADDLE_MULTICLASS_NMS;
bm_param.u.multiclass_nms_param.background_label = background_label;
......@@ -88,12 +100,9 @@ int MultiClassNMSConverter(void* ctx, OpLite* op, KernelBase* kernel) {
int32_t* out_shape[1];
int32_t out_dim[1];
const char* out_name[1];
i_out_shape_data[0] = keep_top_k;
i_out_shape_data[1] = 6;
out_shape[0] = &i_out_shape_data[0];
out_dim[0] = 2;
out_dim[0] = out_dims.size();
out_name[0] = static_cast<const char*>(out_var_name.c_str());
add_user_cpu_layer(graph->GetCompilerHandle(),
input_num,
in_shape,
......
......@@ -48,8 +48,13 @@ USE_SUBGRAPH_BRIDGE(slice, kBM);
USE_SUBGRAPH_BRIDGE(conv2d_transpose, kBM);
USE_SUBGRAPH_BRIDGE(reduce_sum, kBM);
USE_SUBGRAPH_BRIDGE(reduce_mean, kBM);
USE_SUBGRAPH_BRIDGE(reduce_max, kBM);
USE_SUBGRAPH_BRIDGE(squeeze, kBM);
USE_SUBGRAPH_BRIDGE(squeeze2, kBM);
USE_SUBGRAPH_BRIDGE(cast, kBM);
USE_SUBGRAPH_BRIDGE(fill_constant, kBM);
USE_SUBGRAPH_BRIDGE(assign_value, kBM);
USE_SUBGRAPH_BRIDGE(depthwise_conv2d_transpose, kBM);
USE_SUBGRAPH_BRIDGE(shape, kBM);
USE_SUBGRAPH_BRIDGE(split, kBM);
USE_SUBGRAPH_BRIDGE(matmul, kBM);
......@@ -49,6 +49,8 @@ int ReduceFullConverter(void* ctx, OpLite* op, KernelBase* kernel) {
op_code = REDUCE_SUM;
} else if (op_type == "reduce_mean") {
op_code = REDUCE_MEAN;
} else if (op_type == "reduce_max") {
op_code = REDUCE_MAX;
}
add_reduce_full_layer(graph->GetCompilerHandle(),
......@@ -75,3 +77,6 @@ REGISTER_SUBGRAPH_BRIDGE(reduce_sum,
REGISTER_SUBGRAPH_BRIDGE(reduce_mean,
kBM,
paddle::lite::subgraph::bm::ReduceFullConverter);
REGISTER_SUBGRAPH_BRIDGE(reduce_max,
kBM,
paddle::lite::subgraph::bm::ReduceFullConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_defs.h>
#include <bmcompiler_if.h>
#include <bmcompiler_if_lite.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int ShapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
// input
auto x_var_name = op_info->Input("Input").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
// output
auto output_var_name = op_info->Output("Out").front();
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int32_t>(x_dims[i]);
}
add_shape_ref_layer(graph->GetCompilerHandle(),
static_cast<const char*>(x_var_name.c_str()),
const_cast<const int*>(i_x_shape_data.data()),
x_dims.size(),
static_cast<const char*>(output_var_name.c_str()));
graph->AddNode(output_var_name);
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(shape,
kBM,
paddle::lite::subgraph::bm::ShapeConverter);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <bmcompiler_if.h>
#include <bmcompiler_op_code.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace bm {
int SplitConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto scope = op->scope();
auto op_info = op->op_info();
auto op_type = op_info->Type();
// input
auto x_var_name = op_info->Input("X").front();
auto x = scope->FindVar(x_var_name)->GetMutable<lite::Tensor>();
auto x_dims = x->dims();
const int64_t* x_shape_data = const_cast<const int64_t*>(&x_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
// output
auto output_names = op_info->Output("Out");
auto axis = op_info->GetAttr<int>("axis");
auto num = op_info->GetAttr<int>("num");
auto sections = op_info->GetAttr<std::vector<int>>("sections");
if (0 == num) {
num = sections.size();
}
if (0 == sections.size()) {
for (size_t i = 0; i < num; i++) {
sections.push_back(x_dims[axis] / num);
}
}
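  // Illustrative example: with x_dims[axis] == 6, num == 3 and an empty
  // "sections" attribute, the loop above fills sections = {2, 2, 2}; if
  // sections = {1, 2, 3} is given and num == 0, num becomes 3 instead.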
int** shape = new int*[num];
int* dim = new int[num];
const char** name = new const char*[num];
for (size_t i = 0; i < num; i++) {
auto out = scope->FindVar(output_names[i])->GetMutable<lite::Tensor>();
name[i] = static_cast<const char*>(output_names[i].c_str());
auto out_dims = out->dims();
shape[i] = new int[out_dims.size()];
for (size_t j = 0; j < out_dims.size(); j++) {
shape[i][j] = out_dims[j];
}
dim[i] = out_dims.size();
}
add_tf_split_layer(graph->GetCompilerHandle(),
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
static_cast<const char*>(x_var_name.c_str()),
num,
shape,
dim,
name,
x_dims.size(),
axis,
const_cast<const int*>(&sections[0]),
num);
for (size_t i = 0; i < num; i++) {
graph->AddNode(output_names[i]);
delete[] shape[i];
}
delete[] shape;
delete[] name;
delete[] dim;
return SUCCESS;
}
} // namespace bm
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(split,
kBM,
paddle::lite::subgraph::bm::SplitConverter);
......@@ -15,6 +15,7 @@
#include <bmcompiler_defs.h>
#include <bmcompiler_if.h>
#include "lite/kernels/bm/bridges/graph.h"
#include "lite/kernels/bm/bridges/utility.h"
#include "lite/kernels/npu/bridges/registry.h"
namespace paddle {
......@@ -39,11 +40,20 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
const int64_t* output_shape_data =
const_cast<const int64_t*>(&output_dims.data()[0]);
std::vector<int32_t> i_x_shape_data(x_dims.size());
std::vector<int32_t> i_output_shape_data(output_dims.size());
std::vector<int32_t> i_output_shape_data(x_dims.size());
for (size_t i = 0; i < x_dims.size(); i++) {
i_x_shape_data[i] = static_cast<int>(x_shape_data[i]);
}
for (size_t i = 0; i < output_dims.size(); i++) {
auto out_name = output_var_name;
if (x_dims.size() > output_dims.size()) {
for (size_t i = 0; i < (x_dims.size() - output_dims.size()); i++) {
i_output_shape_data[i] = 1;
}
out_name = lite::subgraph::bm::UniqueName(op_type);
}
for (size_t i = (x_dims.size() - output_dims.size()); i < output_dims.size();
i++) {
i_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
auto axis = op_info->GetAttr<std::vector<int>>("axis");
......@@ -53,9 +63,22 @@ int TransposeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
const_cast<const int*>(&i_x_shape_data[0]),
x_dims.size(),
DTYPE_FP32,
static_cast<const char*>(output_var_name.c_str()),
static_cast<const char*>(out_name.c_str()),
NULL,
const_cast<const int*>(&axis[0]));
if (x_dims.size() > output_dims.size()) {
std::vector<int32_t> i_real_output_shape_data(output_dims.size());
for (size_t i = 0; i < output_dims.size(); i++) {
i_real_output_shape_data[i] = static_cast<int>(output_shape_data[i]);
}
add_reshape_layer_v2(graph->GetCompilerHandle(),
static_cast<const char*>(out_name.c_str()),
const_cast<const int*>(&i_output_shape_data[0]),
i_output_shape_data.size(),
static_cast<const char*>(output_var_name.c_str()),
const_cast<const int*>(&i_real_output_shape_data[0]),
output_dims.size());
}
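  // Sketch of the flow above (illustrative): when the output rank is smaller
  // than the input rank, the transpose layer first writes to a uniquely named
  // temporary whose shape is padded with leading 1s, and add_reshape_layer_v2
  // then reshapes that temporary to the real output shape under
  // output_var_name.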
graph->AddNode(output_var_name);
return SUCCESS;
}
......
......@@ -88,18 +88,27 @@ int SubgraphEngine::BuildDeviceProgram() {
// output
origin_odims_.resize(output_names_.size());
origin_otensors_.resize(output_names_.size());
device_outputs_.resize(output_names_.size());
for (size_t i = 0; i < output_names_.size(); i++) {
origin_otensors_[i] = scope_->FindMutableTensor(net_info_->output_names[i]);
CHECK(origin_otensors_[i]);
origin_odims_[i] = origin_otensors_[i]->dims();
origin_otensors_[i]->mutable_data<float>();
device_outputs_.resize(net_info_->output_num);
int out_index = 0;
for (int i = 0; i < output_names_.size(); i++) {
outname_map_.insert(std::pair<std::string, int>(output_names_[i], i));
}
for (int i = 0; i < net_info_->output_num; i++) {
Tensor* t_cur = scope_->FindMutableTensor(net_info_->output_names[i]);
CHECK(t_cur != nullptr);
bm_device_mem_t* p_mem =
static_cast<bm_device_mem_t*>(malloc(sizeof(bm_device_mem_t)));
CHECK(p_mem != nullptr);
CHECK_EQ(bm_malloc_device_byte(
bm_hd_, p_mem, origin_otensors_[i]->memory_size()),
BM_SUCCESS);
if (outname_map_.find(net_info_->output_names[i]) != outname_map_.end()) {
origin_otensors_[out_index] = t_cur;
origin_odims_[out_index] = origin_otensors_[out_index]->dims();
origin_otensors_[out_index]->mutable_data<float>();
out_index += 1;
}
CHECK_EQ(
bm_malloc_device_byte(bm_hd_, p_mem, net_info_->max_output_bytes[i]),
BM_SUCCESS);
bmrt_tensor_with_device(&device_outputs_[i],
*p_mem,
net_info_->output_dtypes[i],
......@@ -123,10 +132,14 @@ int SubgraphEngine::LaunchDeviceProgram() {
true,
false);
bm_thread_sync(bm_hd_);
int out_index = 0;
for (size_t i = 0; i < device_outputs_.size(); i++) {
bm_memcpy_d2s(bm_hd_,
const_cast<void*>(origin_otensors_[i]->raw_data()),
device_outputs_[i].device_mem);
if (outname_map_.find(net_info_->output_names[i]) != outname_map_.end()) {
bm_memcpy_d2s(bm_hd_,
const_cast<void*>(origin_otensors_[out_index]->raw_data()),
device_outputs_[i].device_mem);
out_index++;
}
}
return 0;
}
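// Note on the two hunks above (illustrative): device_outputs_ now covers every
// output reported by the BM runtime (net_info_->output_num), while
// outname_map_ records which of those names the subgraph actually exposes.
// BuildDeviceProgram allocates max_output_bytes per device tensor, and
// LaunchDeviceProgram copies back only the mapped outputs, in order.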
......
......@@ -51,6 +51,7 @@ class SubgraphEngine : public subgraph::Engine {
void *bmrt_hd_;
std::vector<bm_tensor_t> device_inputs_;
std::vector<bm_tensor_t> device_outputs_;
std::map<std::string, int> outname_map_;
const char **net_names_;
const bm_net_info_t *net_info_;
bm_handle_t bm_hd_;
......
......@@ -2,7 +2,9 @@ message(STATUS "compile with lite host kernels")
add_kernel(feed_compute_host Host basic SRCS feed_compute.cc DEPS ${lite_kernel_deps})
add_kernel(fetch_compute_host Host basic SRCS fetch_compute.cc DEPS ${lite_kernel_deps})
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps} reshape_op)
add_kernel(reshape_compute_host Host basic SRCS reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(multiclass_nms_compute_host Host basic SRCS multiclass_nms_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shape_compute_host Host extra SRCS shape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(crf_decoding_compute_host Host extra SRCS crf_decoding_compute.cc DEPS ${lite_kernel_deps})
add_kernel(compare_compute_host Host extra SRCS compare_compute.cc DEPS ${lite_kernel_deps})
add_kernel(ctc_align_compute_host Host extra SRCS ctc_align_compute.cc DEPS ${lite_kernel_deps})
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/compare_compute.h"
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
#define COMPARE_FUNCTOR(name, op) \
template <typename T> \
struct _##name##Functor { \
using TYPE = T; \
inline bool operator()(const T &a, const T &b) const { return a op b; } \
};
COMPARE_FUNCTOR(Equal, ==);
COMPARE_FUNCTOR(NotEqual, !=);
COMPARE_FUNCTOR(LessThan, <);
COMPARE_FUNCTOR(LessEqual, <=);
COMPARE_FUNCTOR(GreaterThan, >);
COMPARE_FUNCTOR(GreaterEqual, >=);
template <>
struct _EqualFunctor<float> {
using TYPE = float;
inline bool operator()(const float &a, const float &b) const {
// It is safe to cast a and b to double.
return fabs(static_cast<double>(a - b)) < 1e-8;
}
};
template <>
struct _NotEqualFunctor<float> {
using TYPE = float;
inline bool operator()(const float &a, const float &b) const {
return !_EqualFunctor<float>()(a, b);
}
};
inline void get_mid_dims(const lite::DDim &x_dims,
const lite::DDim &y_dims,
const int axis,
int *pre,
int *n,
int *post) {
*pre = 1;
*n = 1;
*post = 1;
for (int i = 0; i < axis; ++i) {
(*pre) *= x_dims[i];
}
for (int i = 0; i < y_dims.size(); ++i) {
(*n) *= y_dims[i];
}
for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
(*post) *= x_dims[i];
}
}
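// Broadcasting sketch (illustrative): with x_dims = [2, 3, 4, 5],
// y_dims = [3, 4] and axis = 1, get_mid_dims yields pre = 2, n = 12 and
// post = 5; in Run() each of the 12 Y values is then compared against 5
// consecutive X elements inside each of the 2 outer slices.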
template <PrecisionType PType, typename CompareFunctor>
void CompareCompute<PType, CompareFunctor>::Run() {
auto &param = this->template Param<operators::CompareParam>();
using DType = typename CompareFunctor::TYPE;
const size_t x_size = param.X->numel();
const size_t y_size = param.Y->numel();
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
bool *z = param.Out->template mutable_data<bool>();
const auto *x = param.X->template data<DType>();
const auto *y = param.Y->template data<DType>();
if (x_size == y_size) {
for (int i = 0; i < x_size; ++i) {
z[i] = CompareFunctor()(x[i], y[i]);
}
} else {
int axis = (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis);
int outer_num, mid_num, inner_num;
get_mid_dims(x_dims, y_dims, axis, &outer_num, &mid_num, &inner_num);
for (int outer_id = 0; outer_id < outer_num; ++outer_id) {
for (int mid_id = 0; mid_id < mid_num; ++mid_id) {
auto y_data = y[mid_id];
for (int inner_id = 0; inner_id < inner_num; ++inner_id) {
int index = (outer_id * mid_num + mid_id) * inner_num + inner_id;
z[index] = CompareFunctor()(x[index], y_data);
}
}
}
}
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
using equal_float = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_EqualFunctor<float>>;
REGISTER_LITE_KERNEL(equal, kHost, kFloat, kAny, equal_float, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using equal_int32 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt32),
paddle::lite::kernels::host::_EqualFunctor<int32_t>>;
REGISTER_LITE_KERNEL(equal, kHost, kInt32, kAny, equal_int32, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using not_equal_float = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_NotEqualFunctor<float>>;
REGISTER_LITE_KERNEL(not_equal, kHost, kFloat, kAny, not_equal_float, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using less_than_float = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_LessThanFunctor<float>>;
REGISTER_LITE_KERNEL(less_than, kHost, kFloat, kAny, less_than_float, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using less_than_int32 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt32),
paddle::lite::kernels::host::_LessThanFunctor<int32_t>>;
REGISTER_LITE_KERNEL(less_than, kHost, kInt32, kAny, less_than_int32, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using less_than_int64 = paddle::lite::kernels::host::CompareCompute<
PRECISION(kInt64),
paddle::lite::kernels::host::_LessThanFunctor<int64_t>>;
REGISTER_LITE_KERNEL(less_than, kHost, kInt64, kAny, less_than_int64, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt64), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using less_equal_float = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_LessEqualFunctor<float>>;
REGISTER_LITE_KERNEL(less_equal, kHost, kFloat, kAny, less_equal_float, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using greater_than_float = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterThanFunctor<float>>;
REGISTER_LITE_KERNEL(greater_than, kHost, kFloat, kAny, greater_than_float, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
using greater_equal_float = paddle::lite::kernels::host::CompareCompute<
PRECISION(kFloat),
paddle::lite::kernels::host::_GreaterEqualFunctor<float>>;
REGISTER_LITE_KERNEL(
greater_equal, kHost, kFloat, kAny, greater_equal_float, def)
.BindInput("X",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindInput("Y",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kFloat), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kBool), DATALAYOUT(kAny), -1)})
.Finalize();
......@@ -13,43 +13,24 @@
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/operators/compare_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
namespace host {
template <template <typename T> class Functor>
class CompareCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
template <PrecisionType PType, typename CompareFunctor>
class CompareCompute
: public KernelLite<TARGET(kHost), PType, DATALAYOUT(kAny)> {
public:
void Run() override;
~CompareCompute() {}
virtual ~CompareCompute() = default;
};
template <template <typename T> class Functor>
class CompareCompute_int32
: public KernelLite<TARGET(kARM), PRECISION(kInt32)> {
public:
void Run() override;
~CompareCompute_int32() {}
};
template <template <typename T> class Functor>
class CompareCompute_int64
: public KernelLite<TARGET(kARM), PRECISION(kInt64)> {
public:
void Run() override;
~CompareCompute_int64() {}
};
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/host/reshape_compute.h"
#include <gtest/gtest.h>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace host {
TEST(reshape_host, init) {
ReshapeCompute reshape;
ASSERT_EQ(reshape.precision(), PRECISION(kAny));
ASSERT_EQ(reshape.target(), TARGET(kHost));
}
TEST(reshape_host, compute) {
ReshapeCompute reshape;
operators::ReshapeParam param;
Tensor input;
Tensor output;
input.Resize({1, 2, 4, 6});
auto* input_data = input.mutable_data<float>();
for (int i = 0; i < input.numel(); i++) {
input_data[i] = i;
}
Tensor shape_tensor;
shape_tensor.Resize({2});
auto* shape_tensor_data = shape_tensor.mutable_data<int>();
shape_tensor_data[0] = 6;
shape_tensor_data[1] = 8;
// set param and run
param.x = &input;
param.shape_tensor = &shape_tensor; // use shape_tensor
param.inplace = false;
param.output = &output;
reshape.SetParam(param);
reshape.Run();
// check output dims
CHECK_EQ(shape_tensor.numel(), output.numel());
for (int i = 0; i < output.dims().size(); i++) {
CHECK_EQ(output.dims()[i], shape_tensor_data[i]);
}
// check output data
auto* output_data = output.mutable_data<float>();
CHECK_NE(output_data, input_data);
for (int i = 0; i < output.numel(); i++) {
EXPECT_NEAR(output_data[i], input_data[i], 1e-6);
}
// use shape, set param and run
param.shape_tensor = nullptr;
param.shape_vct = {-1, 0, 3, 2, 1};
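  // in shape_vct, -1 infers the dimension from the remaining element count
  // and 0 copies the corresponding input dimension, so {1, 2, 4, 6}
  // resolves to {4, 2, 3, 2, 1}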
reshape.SetParam(param);
reshape.Run();
  // check output dims against the shape implied by shape_vct
  std::vector<int64_t> expect_dims{4, 2, 3, 2, 1};
  CHECK_EQ(output.dims().size(), expect_dims.size());
  for (size_t i = 0; i < output.dims().size(); i++) {
    CHECK_EQ(output.dims()[i], expect_dims[i]);
  }
// check output data
output_data = output.mutable_data<float>();
CHECK_NE(output_data, input_data);
for (int i = 0; i < output.numel(); i++) {
EXPECT_NEAR(output_data[i], input_data[i], 1e-6);
}
// check output data if inplace = true;
param.inplace = true;
reshape.SetParam(param);
reshape.Run();
output_data = output.mutable_data<float>();
CHECK_EQ(output_data, input_data);
}
TEST(reshape, retrieve_op) {
auto reshape =
KernelRegistry::Global()
.Create<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)>("reshape");
ASSERT_FALSE(reshape.empty());
ASSERT_TRUE(reshape.front());
}
TEST(reshape2, retrieve_op) {
auto reshape2 =
KernelRegistry::Global()
.Create<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)>("reshape2");
ASSERT_FALSE(reshape2.empty());
ASSERT_TRUE(reshape2.front());
}
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(reshape, kHost, kAny, kAny, def);
USE_LITE_KERNEL(reshape2, kHost, kAny, kAny, def);
......@@ -12,13 +12,12 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/shape_compute.h"
#include "lite/backends/arm/math/funcs.h"
#include "lite/kernels/host/shape_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
namespace host {
void ShapeCompute::Run() {
auto& param = Param<operators::ShapeParam>();
......@@ -29,13 +28,17 @@ void ShapeCompute::Run() {
}
}
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
shape, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::ShapeCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
shape, kHost, kAny, kAny, paddle::lite::kernels::host::ShapeCompute, def)
.BindInput("Input",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny), -1)})
.BindOutput("Out",
{LiteType::GetTensorTy(
TARGET(kHost), PRECISION(kInt32), DATALAYOUT(kAny), -1)})
.Finalize();
......@@ -19,16 +19,17 @@
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
namespace host {
class ShapeCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
class ShapeCompute
: public KernelLite<TARGET(kHost), PRECISION(kAny), DATALAYOUT(kAny)> {
public:
void Run() override;
virtual ~ShapeCompute() = default;
};
} // namespace arm
} // namespace host
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -38,6 +38,8 @@ lite_cc_library(subgraph_bridge_shuffle_channel_op_npu SRCS shuffle_channel_op.c
lite_cc_library(subgraph_bridge_pad2d_op_npu SRCS pad2d_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_reduce_mean_op_npu SRCS reduce_mean_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_unsqueeze_op_npu SRCS unsqueeze_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_gather_op_npu SRCS gather_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_lookup_table_op_npu SRCS lookup_table_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_argmax_op_npu SRCS argmax_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_instance_norm_op_npu SRCS instance_norm_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_dropout_op_npu SRCS dropout_op.cc DEPS ${npu_subgraph_bridge_deps})
......@@ -47,6 +49,7 @@ lite_cc_library(subgraph_bridge_fill_constant_op_npu SRCS fill_constant_op.cc DE
lite_cc_library(subgraph_bridge_fill_constant_batch_size_like_op_npu SRCS fill_constant_batch_size_like_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_increment_op_npu SRCS increment_op.cc DEPS ${npu_subgraph_bridge_deps})
lite_cc_library(subgraph_bridge_compare_op_npu SRCS compare_op.cc DEPS ${npu_subgraph_bridge_deps})
#lite_cc_library(subgraph_bridge_shape_op_npu SRCS shape_op.cc DEPS ${npu_subgraph_bridge_deps})
set(npu_subgraph_bridges
......@@ -73,6 +76,8 @@ set(npu_subgraph_bridges
subgraph_bridge_pad2d_op_npu
subgraph_bridge_reduce_mean_op_npu
subgraph_bridge_unsqueeze_op_npu
subgraph_bridge_gather_op_npu
subgraph_bridge_lookup_table_op_npu
subgraph_bridge_argmax_op_npu
subgraph_bridge_instance_norm_op_npu
subgraph_bridge_dropout_op_npu
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {
int GatherConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto x_name = op_info->Input("X").front();
auto x = scope->FindTensor(x_name);
auto index_name = op_info->Input("Index").front();
auto index = scope->FindTensor(index_name);
auto index_dims = index->dims();
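  // only 1-D indices, or 2-D indices of shape [n, 1], are accepted here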
CHECK(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1))
<< "index dims unmatch";
auto out_name = op_info->Output("Out").front();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Index node
std::shared_ptr<Node> index_node = nullptr;
if (graph->Has(index_name)) {
index_node = graph->Get(index_name);
} else {
index_node = graph->Add(index_name, *index);
}
// Gather node
auto gather_node = graph->Add<ge::op::Gather>(out_name);
auto gather_op = gather_node->data<ge::op::Gather>();
gather_op->set_input_params(*x_node->data());
gather_op->set_input_indices(*index_node->data());
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(gather,
kNPU,
paddle::lite::subgraph::npu::GatherConverter);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {
int LookupTableConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto w_name = op_info->Input("W").front();
auto w = scope->FindTensor(w_name);
auto index_name = op_info->Input("Ids").front();
auto index = scope->FindTensor(index_name);
auto out_name = op_info->Output("Out").front();
auto out = scope->FindTensor(out_name);
auto out_shape = out->dims().Vectorize();
// W node
std::shared_ptr<Node> w_node = nullptr;
if (graph->Has(w_name)) {
w_node = graph->Get(w_name);
} else {
w_node = graph->Add(w_name, *w);
}
// Index node
std::shared_ptr<Node> index_node = nullptr;
if (graph->Has(index_name)) {
index_node = graph->Get(index_name);
} else {
index_node = graph->Add(index_name, *index);
}
// reshape ids
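  // flatten Ids to a 1-D tensor of length index->numel() so it can feed the
  // Gather indices input; the gather output is reshaped back to the
  // lookup_table output shape below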
auto reshaped_index_node =
graph->Add<ge::op::Reshape>(index_name + "/reshape");
auto reshaped_index_op = reshaped_index_node->data<ge::op::Reshape>();
reshaped_index_op->set_input_tensor(*index_node->data());
reshaped_index_op->set_attr_shape(ge::AttrValue::LIST_INT({index->numel()}));
reshaped_index_op->set_attr_axis(0);
index_node = reshaped_index_node;
// Gather node
auto gather_node = graph->Add<ge::op::Gather>(out_name);
auto gather_op = gather_node->data<ge::op::Gather>();
gather_op->set_input_params(*w_node->data());
gather_op->set_input_indices(*index_node->data());
// reshape out
auto reshaped_gather_node = graph->Add<ge::op::Reshape>(out_name);
auto reshaped_gather_op = reshaped_gather_node->data<ge::op::Reshape>();
reshaped_gather_op->set_input_tensor(*gather_node->data());
reshaped_gather_op->set_attr_shape(
ge::AttrValue::LIST_INT(out_shape.begin(), out_shape.end()));
reshaped_gather_op->set_attr_axis(0);
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(lookup_table,
kNPU,
paddle::lite::subgraph::npu::LookupTableConverter);
......@@ -45,6 +45,8 @@ USE_SUBGRAPH_BRIDGE(fusion_elementwise_div_activation, kNPU);
USE_SUBGRAPH_BRIDGE(fill_constant, kNPU)
USE_SUBGRAPH_BRIDGE(fill_constant_batch_size_like, kNPU)
// USE_SUBGRAPH_BRIDGE(gather, kNPU);
// USE_SUBGRAPH_BRIDGE(lookup_table, kNPU);
USE_SUBGRAPH_BRIDGE(increment, kNPU);
USE_SUBGRAPH_BRIDGE(instance_norm, kNPU);
USE_SUBGRAPH_BRIDGE(fc, kNPU);
......@@ -59,6 +61,7 @@ USE_SUBGRAPH_BRIDGE(reduce_mean, kNPU);
USE_SUBGRAPH_BRIDGE(reshape, kNPU);
USE_SUBGRAPH_BRIDGE(reshape2, kNPU);
USE_SUBGRAPH_BRIDGE(scale, kNPU);
// USE_SUBGRAPH_BRIDGE(shape, kNPU);
USE_SUBGRAPH_BRIDGE(shuffle_channel, kNPU);
USE_SUBGRAPH_BRIDGE(softmax, kNPU);
USE_SUBGRAPH_BRIDGE(split, kNPU);
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/npu/bridges/graph.h"
#include "lite/kernels/npu/bridges/registry.h"
#include "lite/kernels/npu/bridges/utility.h"
namespace paddle {
namespace lite {
namespace subgraph {
namespace npu {
int ShapeConverter(void* ctx, OpLite* op, KernelBase* kernel) {
CHECK(ctx != nullptr);
CHECK(op != nullptr);
auto graph = static_cast<Graph*>(ctx);
auto op_info = op->op_info();
auto op_type = op_info->Type();
auto scope = op->scope();
VLOG(3) << "[NPU] Converting " + op_type + "...";
// Get input, output and op attributes
auto x_name = op_info->Input("Input").front();
auto x = scope->FindTensor(x_name);
auto out_name = op_info->Output("Out").front();
// X node
std::shared_ptr<Node> x_node = nullptr;
if (graph->Has(x_name)) {
x_node = graph->Get(x_name);
} else {
x_node = graph->Add(x_name, *x);
}
// Shape node
auto shape_node = graph->Add<ge::op::Shape>(out_name);
auto shape_op = shape_node->data<ge::op::Shape>();
shape_op->set_input_x(*x_node->data());
return REBUILD_WHEN_SHAPE_CHANGED;
}
} // namespace npu
} // namespace subgraph
} // namespace lite
} // namespace paddle
REGISTER_SUBGRAPH_BRIDGE(shape,
kNPU,
paddle::lite::subgraph::npu::ShapeConverter);
......@@ -62,6 +62,7 @@ class ReluCompute
CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{count};
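  // create the cl::Event for this launch right before enqueueing the kernel
  // (previously a member initialized once at construction)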
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -77,7 +78,7 @@ class ReluCompute
std::string kernel_func_name_{"relu"};
std::string build_options_{"-DCL_DTYPE_float -DRELU"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
class SigmoidCompute
......@@ -120,6 +121,7 @@ class SigmoidCompute
CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{count};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -135,7 +137,7 @@ class SigmoidCompute
std::string kernel_func_name_{"sigmoid"};
std::string build_options_{"-DCL_DTYPE_float -DSIGMOID"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -147,6 +147,7 @@ class ActivationComputeImageDefault
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -174,7 +175,7 @@ class ActivationComputeImageDefault
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
} // namespace kernels
......
......@@ -142,6 +142,7 @@ class BilinearInterpImageCompute
static_cast<cl::size_type>(default_work_size[1]),
static_cast<cl::size_type>(default_work_size[2])};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -162,7 +163,7 @@ class BilinearInterpImageCompute
std::string kernel_func_name_{"bilinear_interp"};
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -120,6 +120,7 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
cl::NDRange{static_cast<cl::size_type>(default_work_size[0]),
static_cast<cl::size_type>(default_work_size[2])};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -141,7 +142,7 @@ class BoxCoderComputeImage : public KernelLite<TARGET(kOpenCL),
param_t* boxcoder_param_{nullptr};
std::string kernel_func_name_{};
std::string build_options_{" -DCL_DTYPE_half"};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -123,6 +123,7 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total1);
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -156,6 +157,7 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, total0);
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -180,7 +182,7 @@ class ConcatCompute : public KernelLite<TARGET(kOpenCL),
std::string kernel_func_name_{};
std::string build_options_{"-DCL_DTYPE_float"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -187,6 +187,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
CL_CHECK_FATAL(status);
status = kernel.setArg(++arg_idx, width_);
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -230,6 +231,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
status = kernel.setArg(++arg_idx, width_);
CL_CHECK_FATAL(status);
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -254,7 +256,7 @@ class ConcatComputeImage : public KernelLite<TARGET(kOpenCL),
std::string kernel_func_name_{};
std::string build_options_{" -DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -205,6 +205,7 @@ void ConvCompute::GemmlikeConv2d() {
CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{static_cast<size_t>(out_stride)};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
img2col_kernel,
cl::NullRange,
......@@ -300,6 +301,7 @@ void ConvCompute::GemmBatched(cl::Kernel& kernel,
status = kernel.setArg(++arg_idx, batch_size);
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......
......@@ -57,7 +57,7 @@ class ConvCompute
std::vector<std::string> kernel_func_paths_{};
std::vector<std::string> build_options_{};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -38,6 +38,7 @@ void ConvImageCompute::PrepareForRun() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
filter_gpu_image_ = std::unique_ptr<Tensor>(new Tensor);
int bs = x_dims[0];
int c_in = x_dims[1];
int h_out = output_dims[2];
......@@ -113,7 +114,7 @@ void ConvImageCompute::PrepareForRun() {
std::vector<half_t> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d1x1opt;
......@@ -174,7 +175,7 @@ void ConvImageCompute::PrepareForRun() {
std::vector<half_t> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
#endif
......@@ -194,7 +195,7 @@ void ConvImageCompute::PrepareForRun() {
std::vector<half_t> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::DepthwiseConv2d;
......@@ -209,7 +210,7 @@ void ConvImageCompute::PrepareForRun() {
std::vector<half_t> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d3x3opt;
......@@ -241,7 +242,7 @@ void ConvImageCompute::PrepareForRun() {
std::vector<half_t> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d5x5;
......@@ -257,7 +258,7 @@ void ConvImageCompute::PrepareForRun() {
std::vector<half_t> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d5x5opt;
......@@ -290,7 +291,7 @@ void ConvImageCompute::PrepareForRun() {
std::vector<half_t> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
this->filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
this->filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d7x7;
......@@ -306,7 +307,7 @@ void ConvImageCompute::PrepareForRun() {
std::vector<half_t> filter_image_v(filter_image_dims[0] *
filter_image_dims[1] * 4); // 4 : RGBA
converter.NCHWToImage(filter_cpu, filter_image_v.data(), filter_dims);
this->filter_gpu_image_.mutable_data<half_t, cl::Image2D>(
this->filter_gpu_image_->mutable_data<half_t, cl::Image2D>(
filter_image_dims[0], filter_image_dims[1], filter_image_v.data());
impl_ = &ConvImageCompute::Conv2d7x7opt;
......@@ -349,6 +350,7 @@ void ConvImageCompute::PrepareForRun() {
const bool is_element_wise_bias =
has_bias && param.output->dims() == param.bias->dims();
if (has_bias) {
bias_gpu_image_ = std::unique_ptr<Tensor>(new Tensor);
build_options_single +=
is_element_wise_bias ? " -DBIASE_ELE" : " -DBIASE_CH";
......@@ -361,7 +363,7 @@ void ConvImageCompute::PrepareForRun() {
float* bias_cpu_data = param.bias->mutable_data<float>();
bias_converter.NCHWToImage(
bias_cpu_data, bias_image_v.data(), param.bias->dims());
this->bias_gpu_image_.mutable_data<half_t, cl::Image2D>(
this->bias_gpu_image_->mutable_data<half_t, cl::Image2D>(
bias_image_dims[0], bias_image_dims[1], bias_image_v.data());
// convert cpu buffer bias --> gpu image --- end ----
}
......@@ -434,7 +436,7 @@ void ConvImageCompute::Conv2d1x1opt(bool is_turn) {
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
......@@ -498,7 +500,7 @@ void ConvImageCompute::Conv2d1x1opt(bool is_turn) {
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
......@@ -542,6 +544,7 @@ void ConvImageCompute::Conv2d1x1opt(bool is_turn) {
status = kernel.setArg(++arg_idx, default_w_blk_);
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -565,7 +568,7 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) {
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
......@@ -647,7 +650,7 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) {
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
......@@ -707,6 +710,7 @@ void ConvImageCompute::Conv2d3x3(bool is_turn) {
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -732,7 +736,7 @@ void ConvImageCompute::Conv2d3x3opt(bool is_turn) {
auto dilations = *param.dilations;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
......@@ -781,7 +785,7 @@ void ConvImageCompute::Conv2d3x3opt(bool is_turn) {
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
......@@ -834,6 +838,7 @@ void ConvImageCompute::Conv2d3x3opt(bool is_turn) {
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -856,7 +861,7 @@ void ConvImageCompute::Conv2d5x5(bool is_turn) {
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
......@@ -914,7 +919,7 @@ void ConvImageCompute::Conv2d5x5(bool is_turn) {
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
......@@ -965,6 +970,7 @@ void ConvImageCompute::Conv2d5x5(bool is_turn) {
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -989,7 +995,7 @@ void ConvImageCompute::Conv2d5x5opt(bool is_turn) {
auto dilations = *param.dilations;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
......@@ -1039,7 +1045,7 @@ void ConvImageCompute::Conv2d5x5opt(bool is_turn) {
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
......@@ -1084,6 +1090,7 @@ void ConvImageCompute::Conv2d5x5opt(bool is_turn) {
// VLOG(4) << "out_image: " << out_image;
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -1106,7 +1113,7 @@ void ConvImageCompute::Conv2d7x7(bool is_turn) {
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
......@@ -1164,7 +1171,7 @@ void ConvImageCompute::Conv2d7x7(bool is_turn) {
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
......@@ -1215,6 +1222,7 @@ void ConvImageCompute::Conv2d7x7(bool is_turn) {
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -1239,7 +1247,7 @@ void ConvImageCompute::Conv2d7x7opt(bool is_turn) {
auto dilations = *param.dilations;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
......@@ -1287,7 +1295,7 @@ void ConvImageCompute::Conv2d7x7opt(bool is_turn) {
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
......@@ -1331,6 +1339,7 @@ void ConvImageCompute::Conv2d7x7opt(bool is_turn) {
status = kernel.setArg(++arg_idx, output_height);
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -1357,11 +1366,11 @@ void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) {
auto dilations = *param.dilations;
auto* input_img = param.x->data<half_t, cl::Image2D>();
auto* filter_img = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_img = filter_gpu_image_->data<half_t, cl::Image2D>();
const cl::Image2D* bias_img = nullptr;
if (param.bias) {
bias_img = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_img = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto image_shape = InitImageDimInfoWith(output_dims);
......@@ -1389,7 +1398,7 @@ void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) {
has_bias && param.output->dims() == param.bias->dims();
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: ";
#endif
......@@ -1415,6 +1424,7 @@ void ConvImageCompute::DepthwiseConv2d3x3s1(bool is_turn) {
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -1444,11 +1454,11 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) {
int input_c_block = (x_dims[1] + 3) / 4;
auto* input_img = param.x->data<half_t, cl::Image2D>();
auto* filter_img = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_img = filter_gpu_image_->data<half_t, cl::Image2D>();
const cl::Image2D* bias_img = nullptr;
if (param.bias) {
bias_img = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_img = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto image_shape = InitImageDimInfoWith(output_dims);
......@@ -1487,7 +1497,7 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) {
has_bias && param.output->dims() == param.bias->dims();
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
#ifndef LITE_SHUTDOWN_LOG
VLOG(4) << "set bias_image: ";
#endif
......@@ -1513,6 +1523,7 @@ void ConvImageCompute::DepthwiseConv2d3x3(bool is_turn) {
status = kernel.setArg(++arg_idx, static_cast<const int>(output_dims[2]));
CL_CHECK_FATAL(status);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -1536,7 +1547,7 @@ void ConvImageCompute::DepthwiseConv2d(bool is_turn) {
auto paddings = *param.paddings;
auto strides = param.strides;
auto* input_image = param.x->data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_.data<half_t, cl::Image2D>();
auto* filter_image = filter_gpu_image_->data<half_t, cl::Image2D>();
auto filter_dims = param.filter->dims();
auto output_dims = param.output->dims();
......@@ -1595,7 +1606,7 @@ void ConvImageCompute::DepthwiseConv2d(bool is_turn) {
const cl::Buffer* bias_buf = nullptr;
const cl::Image2D* bias_image = nullptr;
if (has_bias) {
bias_image = bias_gpu_image_.data<half_t, cl::Image2D>();
bias_image = bias_gpu_image_->data<half_t, cl::Image2D>();
}
auto kernel = kernel_;
......@@ -1650,6 +1661,7 @@ void ConvImageCompute::DepthwiseConv2d(bool is_turn) {
<< global_work_size_[1] << "," << global_work_size_[2] << "}";
#endif
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......
......@@ -58,9 +58,9 @@ class ConvImageCompute : public KernelLite<TARGET(kOpenCL),
std::vector<std::string> kernel_func_paths_{};
std::vector<std::string> build_options_{};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
Tensor filter_gpu_image_;
Tensor bias_gpu_image_;
std::shared_ptr<cl::Event> event_{nullptr};
std::unique_ptr<Tensor> filter_gpu_image_{nullptr};
std::unique_ptr<Tensor> bias_gpu_image_{nullptr};
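  // the filter/bias GPU image tensors are now heap-allocated in
  // PrepareForRun() (bias only when the op actually has a bias), hence the
  // unique_ptr members above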
cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
int c_blk_ = 1;
......
......@@ -108,6 +108,7 @@ class DepthwiseConv2dCompute
status = kernel.setArg(++arg_idx, *bias_buf);
CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange(static_cast<size_t>(numel));
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -123,7 +124,7 @@ class DepthwiseConv2dCompute
std::string kernel_func_name_{"depthwise_conv2d"};
std::string build_options_{"-DCL_DTYPE_float"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -89,6 +89,7 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL),
static_cast<cl::size_type>(default_work_size.data()[1]),
static_cast<cl::size_type>(default_work_size.data()[2])};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -104,7 +105,7 @@ class DropoutComputeImage2D : public KernelLite<TARGET(kOpenCL),
std::string kernel_func_name_{"dropout"};
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -63,6 +63,7 @@ void ElementwiseAddCompute::Run() {
CL_CHECK_FATAL(status);
auto global_work_size = cl::NDRange{channels_, batch_};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......
......@@ -48,7 +48,7 @@ class ElementwiseAddCompute
std::string kernel_func_name_{"elementwise_add"};
std::string build_options_{"-DCL_DTYPE_float"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -153,6 +153,7 @@ void ElementwiseAddImageCompute::Run() {
auto& context = ctx_->As<OpenCLContext>();
CHECK(context.cl_context() != nullptr);
event_ = std::shared_ptr<cl::Event>(new cl::Event);
status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......
......@@ -63,7 +63,7 @@ class ElementwiseAddImageCompute
cl::Kernel kernel_;
cl::NDRange global_work_size_ = cl::NDRange{
static_cast<size_t>(1), static_cast<size_t>(1), static_cast<size_t>(1)};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -150,7 +150,8 @@ void ElementwiseMulFloatImageCompute::Run() {
auto global_work_size = cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)};
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
event_ = std::shared_ptr<cl::Event>(new cl::Event);
  auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
global_work_size,
......
......@@ -185,6 +185,7 @@ class ElementwiseMulImageCompute
auto global_work_size =
cl::NDRange{static_cast<cl::size_type>(x_img_width),
static_cast<cl::size_type>(x_img_height)};
event_ = std::shared_ptr<cl::Event>(new cl::Event);
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......@@ -204,7 +205,7 @@ class ElementwiseMulImageCompute
std::string kernel_func_name_{"elementwise_mul"};
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
......@@ -138,6 +138,7 @@ void ElementwiseSubImageCompute::Run() {
VLOG(4) << "global_work_size:[2D]:" << x_img_width << " " << x_img_height;
#endif
event_ = std::shared_ptr<cl::Event>(new cl::Event);
auto status = context.cl_context()->GetCommandQueue().enqueueNDRangeKernel(
kernel,
cl::NullRange,
......
......@@ -46,7 +46,7 @@ class ElementwiseSubImageCompute
std::string kernel_func_name_{"elementwise_sub"};
std::string build_options_{"-DCL_DTYPE_half"};
std::string time_stamp_{GetTimeStamp()};
std::shared_ptr<cl::Event> event_{new cl::Event};
std::shared_ptr<cl::Event> event_{nullptr};
};
} // namespace opencl
......
(6 additional file diffs collapsed; contents not shown.)