提交 46baf05f 编写于 作者: J juncaipeng

Merge branch 'gpu' of https://github.com/yongqiangma/Paddle-Lite into gpu

......@@ -174,6 +174,16 @@ if(NOT WITH_DSO)
endif(WIN32)
endif(NOT WITH_DSO)
get_filename_component(CUDA_LIB_PATH ${CUDA_curand_LIBRARY} DIRECTORY)
function(import_static_library alias path)
add_library(${alias} STATIC IMPORTED GLOBAL)
set_property(TARGET ${alias} PROPERTY IMPORTED_LOCATION ${path})
endfunction()
import_static_library(cudart_static ${CUDA_LIB_PATH}/libcudart_static.a)
import_static_library(cublas_static ${CUDA_LIB_PATH}/libcublas_static.a)
import_static_library(curand_static ${CUDA_LIB_PATH}/libcurand_static.a)
import_static_library(culibos_static ${CUDA_LIB_PATH}/libculibos.a)
# setting nvcc arch flags
select_nvcc_arch_flags(NVCC_FLAGS_EXTRA)
list(APPEND CUDA_NVCC_FLAGS ${NVCC_FLAGS_EXTRA})
......
......@@ -53,11 +53,10 @@ if(APPLE)
set(CUDNN_LIB_NAME "libcudnn.dylib" "libcudnn.so")
endif(APPLE)
find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME} # libcudnn_static.a
find_library(CUDNN_LIBRARY NAMES ${CUDNN_LIB_NAME}
PATHS ${CUDNN_CHECK_LIBRARY_DIRS} ${CUDNN_INCLUDE_DIR} ${__libpath_hist}
NO_DEFAULT_PATH
DOC "Path to cuDNN library.")
DOC "Path to cuDNN dynamic library.")
if(CUDNN_INCLUDE_DIR AND CUDNN_LIBRARY)
set(CUDNN_FOUND ON)
......@@ -69,6 +68,9 @@ if(CUDNN_FOUND)
file(READ ${CUDNN_INCLUDE_DIR}/cudnn.h CUDNN_VERSION_FILE_CONTENTS)
get_filename_component(CUDNN_LIB_PATH ${CUDNN_LIBRARY} DIRECTORY)
add_library(cudnn_static STATIC IMPORTED GLOBAL)
set_property(TARGET cudnn_static PROPERTY IMPORTED_LOCATION
"${CUDNN_LIB_PATH}/libcudnn_static.a")
string(REGEX MATCH "define CUDNN_VERSION +([0-9]+)"
CUDNN_VERSION "${CUDNN_VERSION_FILE_CONTENTS}")
......
......@@ -511,6 +511,7 @@ function(nv_test TARGET_NAME)
get_property(os_dependency_modules GLOBAL PROPERTY OS_DEPENDENCY_MODULES)
target_link_libraries(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest
gflags glog ${os_dependency_modules} ${CUDNN_LIBRARY} ${CUBLAS_LIBRARIES} )
add_dependencies(${TARGET_NAME} ${nv_test_DEPS} lite_gtest_main gtest gflags glog)
common_link(${TARGET_NAME})
add_test(${TARGET_NAME} ${TARGET_NAME})
......
......@@ -14,8 +14,12 @@ if (NOT LITE_ON_TINY_PUBLISH)
#full api dynamic library
add_library(paddle_full_api_shared SHARED "")
target_sources(paddle_full_api_shared PUBLIC ${__lite_cc_files} paddle_api.cc light_api.cc cxx_api.cc cxx_api_impl.cc light_api_impl.cc)
add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto xxhash )
target_link_libraries(paddle_full_api_shared framework_proto xxhash)# ${cuda_kernels})
add_dependencies(paddle_full_api_shared op_list_h kernel_list_h framework_proto)
target_link_libraries(paddle_full_api_shared framework_proto)
if(LITE_WITH_X86)
add_dependencies(paddle_full_api_shared xxhash)
target_link_libraries(paddle_full_api_shared xxhash)
endif(LITE_WITH_X86)
if(LITE_WITH_CUDA)
target_link_libraries(paddle_full_api_shared ${math_cuda} "-Wl,--whole-archive" ${cuda_kernels} "-Wl,--no-whole-archive")
endif(LITE_WITH_CUDA)
......@@ -73,6 +77,8 @@ set(light_api_deps
scope target_wrapper_host model_parser program)
if(LITE_WITH_CUDA)
set(light_api_deps ${light_api_deps} target_wrapper_cuda)
set(cuda_static_deps cudart_static cublas_static curand_static
cudnn_static culibos_static)
endif()
lite_cc_library(light_api SRCS light_api.cc
DEPS scope target_wrapper_host model_parser
......@@ -200,7 +206,7 @@ if (NOT LITE_ON_TINY_PUBLISH)
# The final inference library for just MobileConfig.
bundle_static_library(paddle_api_full paddle_api_full_bundled bundle_full_api)
get_property(fluid_modules GLOBAL PROPERTY FLUID_MODULES)
cc_library(api_full_static SRCS DEPS paddle_api_full cxx_api paddle_api light_api ${cxx_api_deps} ${ops} ${host_kernels} ${cuda_kernels} program tensor memory naive_buffer types ${fluid_modules} protobuf)
cc_library(api_full_static SRCS DEPS paddle_api_full cxx_api paddle_api light_api ${cxx_api_deps} ${ops} ${host_kernels} ${cuda_kernels} program tensor memory naive_buffer types ${fluid_modules} protobuf ${cuda_static_deps})
endif()
bundle_static_library(paddle_api_light paddle_api_light_bundled bundle_light_api)
#-----------------------------------------------------------------------------------------------------
......
......@@ -21,6 +21,7 @@
USE_LITE_OP(mul);
USE_LITE_OP(matmul);
USE_LITE_OP(fc);
USE_LITE_OP(assign);
USE_LITE_OP(relu);
USE_LITE_OP(relu6);
USE_LITE_OP(scale);
......
......@@ -14,6 +14,7 @@
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <fstream>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
......@@ -22,6 +23,10 @@
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
DEFINE_string(input_img_txt_path,
"",
"if set input_img_txt_path, read the img filename as input.");
namespace paddle {
namespace lite {
......@@ -36,8 +41,18 @@ void TestModel(const std::vector<Place>& valid_places) {
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 3, 224, 224})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
if (FLAGS_input_img_txt_path.empty()) {
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
} else {
std::fstream fs(FLAGS_input_img_txt_path, std::ios::in);
if (!fs.is_open()) {
LOG(FATAL) << "open input_img_txt error.";
}
for (int i = 0; i < item_size; i++) {
fs >> data[i];
}
}
for (int i = 0; i < FLAGS_warmup; ++i) {
......
......@@ -28,7 +28,7 @@ namespace lite {
TEST(model, test) {
#ifdef LITE_WITH_ARM
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads);
lite::Predictor predictor;
std::vector<Place> valid_places({Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kInt8)}});
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
DEFINE_string(input, "", "input_data");
DEFINE_int32(batch, 1, "batch");
namespace paddle {
namespace lite {
namespace test_transformer {
std::vector<std::string> inputed_lines;
void LoadInputLines(const char* filename) {
static const int max_line_buf_size = 100 * 1024 * 1024;
char* line_buffer = (char*)calloc(max_line_buf_size, sizeof(char)); // NOLINT
FILE* input_file = fopen(filename, "r");
while (fgets(line_buffer, max_line_buf_size, input_file)) {
// trim newline at end
char* pos = NULL;
if ((pos = strchr(line_buffer, '\n')) != NULL) {
*pos = 0;
}
inputed_lines.push_back(line_buffer);
}
free(line_buffer);
line_buffer = NULL;
fclose(input_file);
}
void Split2(const std::string& main_str,
std::vector<std::string>& str_list, // NOLINT
const std::string& delimiter) {
size_t pre_pos = 0;
size_t position = 0;
std::string tmp_str;
str_list.clear();
if (main_str.empty()) {
return;
}
while ((position = main_str.find(delimiter, pre_pos)) != std::string::npos) {
tmp_str.assign(main_str, pre_pos, position - pre_pos);
str_list.push_back(tmp_str);
pre_pos = position + 1;
}
tmp_str.assign(main_str, pre_pos, main_str.length() - pre_pos);
if (!tmp_str.empty()) {
str_list.push_back(tmp_str);
}
}
} // NOLINT
void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT
int pad_idx,
int n_head,
Tensor* src_word,
Tensor* src_pos,
Tensor* src_attn_bias,
Tensor* trg_word,
Tensor* init_scores,
Tensor* init_idx,
Tensor* trg_bias,
int line_start,
int batch_size,
int bos_idx) {
int max_len = 0;
int max_line = input_lines.size();
std::vector<std::vector<std::string>> batch_lines;
for (int i = line_start; i < line_start + batch_size; ++i) {
int i_index = i % max_line;
std::string cur_line = input_lines[i_index];
std::vector<std::string> split_str;
test_transformer::Split2(cur_line, split_str, " ");
batch_lines.push_back(split_str);
max_len = max_len >= split_str.size() ? max_len : split_str.size();
}
src_word->Resize(std::vector<DDim::value_type>({batch_size, max_len, 1}));
src_pos->Resize(std::vector<DDim::value_type>({batch_size, max_len, 1}));
src_attn_bias->Resize(
std::vector<DDim::value_type>({batch_size, n_head, max_len, max_len}));
trg_bias->Resize(
std::vector<DDim::value_type>({batch_size, n_head, 1, max_len}));
float* src_word_data = src_word->mutable_data<float>();
float* src_pos_data = src_pos->mutable_data<float>();
float* src_bias_data = src_attn_bias->mutable_data<float>();
float* trg_bias_data = trg_bias->mutable_data<float>();
for (int i = 0; i < batch_size; ++i) {
std::vector<std::string> cur_words = batch_lines[i];
int fill_len = cur_words.size();
int src_bias_start = i * n_head * max_len * max_len;
int trg_bias_start = i * n_head * max_len;
for (int j = 0; j < fill_len; ++j) {
src_word_data[i * max_len + j] = (atoi(cur_words[j].c_str()));
src_pos_data[i * max_len + j] = j;
src_bias_data[src_bias_start + j] = 0;
trg_bias_data[trg_bias_start + j] = 0;
}
for (int j = fill_len; j < max_len; ++j) {
src_word_data[i * max_len + j] = pad_idx;
src_pos_data[i * max_len + j] = 0;
src_bias_data[src_bias_start + j] = -1000000000;
trg_bias_data[trg_bias_start + j] = -1000000000;
}
for (int j = src_bias_start;
j < src_bias_start + n_head * max_len * max_len;
++j) {
int value_ind = j % max_len + src_bias_start;
src_bias_data[j] = src_bias_data[value_ind];
}
for (int j = trg_bias_start; j < trg_bias_start + n_head * max_len; ++j) {
int value_ind = j % max_len + trg_bias_start;
trg_bias_data[j] = trg_bias_data[value_ind];
}
}
trg_word->Resize(std::vector<DDim::value_type>({batch_size, 1, 1}));
auto* trg_word_data = trg_word->mutable_data<float>();
for (int i = 0; i < batch_size; ++i) {
trg_word_data[i] = bos_idx;
}
init_scores->Resize(std::vector<DDim::value_type>({batch_size, 1}));
init_idx->Resize(std::vector<DDim::value_type>({batch_size}));
float* score_data = init_scores->mutable_data<float>();
float* idx_data = init_idx->mutable_data<float>();
for (int i = 0; i < init_scores->numel(); ++i) {
score_data[i] = 0;
}
std::vector<std::vector<uint64_t>> lod_s;
lod_s.resize(2);
for (int i = 0; i < batch_size; ++i) {
lod_s[0].push_back(i);
lod_s[1].push_back(i);
idx_data[i] = i;
}
lod_s[0].push_back(batch_size);
lod_s[1].push_back(batch_size);
auto score_lod = init_scores->mutable_lod();
*score_lod = lod_s;
auto trg_word_lod = trg_word->mutable_lod();
*trg_word_lod = lod_s;
}
void TestModel(const std::vector<Place>& valid_places,
const Place& preferred_place,
bool use_npu = false) {
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
std::string test_data_path = FLAGS_input;
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
int n_head = 8;
int batch_size = FLAGS_batch;
int bos_idx = 0;
int eos_idx = 1;
LOG(INFO) << "reading";
test_transformer::LoadInputLines(test_data_path.c_str());
LOG(INFO) << "reading finished";
auto* trg_bias = predictor.GetInput(6);
auto* src_word = predictor.GetInput(0);
auto* src_pos = predictor.GetInput(1);
auto* src_bias = predictor.GetInput(2);
auto* trg_word = predictor.GetInput(3);
auto* init_score = predictor.GetInput(4);
auto* init_idx = predictor.GetInput(5);
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor.Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) {
auto start_i = GetCurrentUS();
PadBatchInput(test_transformer::inputed_lines,
eos_idx,
n_head,
src_word, // src_word
src_pos, // src_pos
src_bias, // src_bias
trg_word, // trg_word
init_score, // init_score
init_idx, // init_idx
trg_bias, // trg_bias
i * batch_size,
batch_size,
bos_idx);
LOG(INFO) << "src_word:" << src_word->dims();
auto start_ii = GetCurrentUS();
LOG(INFO) << i << "->ii:" << (start_ii - start_i) / 1000.0;
predictor.Run();
auto start_iii = GetCurrentUS();
LOG(INFO) << i << "->iii:" << (start_iii - start_ii) / 1000.0;
auto* outs = predictor.GetOutputs();
LOG(INFO) << "out:" << (*outs)[0].dims();
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
auto* outs = predictor.GetOutputs();
for (auto out : *outs) {
LOG(INFO) << "======"
<< "here";
LOG(INFO) << out;
}
LOG(INFO) << "======"
<< "hereggg";
}
TEST(OcrAttention, test_arm) {
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
});
TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
}
} // namespace lite
} // namespace paddle
......@@ -21,10 +21,10 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
void increment(const int* input,
void increment(const float* input,
const int n,
const float step,
int* out,
float* out,
Context<TARGET(kARM)>* ctx) {
for (int i = 0; i < n; i++) {
out[i] = input[i] + step;
......
......@@ -21,10 +21,10 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
void increment(const int* input,
void increment(const float* input,
const int n,
const float step,
int* out,
float* out,
Context<TARGET(kARM)>* ctx);
} // namespace math
......
......@@ -15,6 +15,7 @@
#include "lite/backends/arm/math/norm.h"
#include <arm_neon.h>
#include <cmath>
#include "lite/backends/arm/math/funcs.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
......@@ -43,7 +44,143 @@ void norm(const float* input,
}
}
}
LOG(INFO) << "norm math finished";
}
void matrix_norm_row(const float* x_data,
const float* scale_data,
const float* bias_data,
float* out_data,
float* mean_out,
float* var_out,
float epsilon,
int batch_size,
int feature_size) {
int cnt = feature_size >> 4;
int remain = feature_size & 0xf;
#pragma omp parallel for
for (int bi = 0; bi < batch_size; ++bi) {
int offset = bi * feature_size;
const float* x_ptr = x_data + offset;
float mean = 0.f;
float variance = 0.f;
// get mean and variance
float32x4_t mean_v = vdupq_n_f32(0);
float32x4_t var_v = vdupq_n_f32(0);
for (int oi = 0; oi < cnt; ++oi) {
float32x4_t odim1 = vld1q_f32(x_ptr);
float32x4_t odim2 = vld1q_f32(x_ptr + 4);
float32x4_t odim3 = vld1q_f32(x_ptr + 8);
float32x4_t odim4 = vld1q_f32(x_ptr + 12);
mean_v = vaddq_f32(mean_v, odim1);
mean_v = vaddq_f32(mean_v, odim2);
mean_v = vaddq_f32(mean_v, odim3);
mean_v = vaddq_f32(mean_v, odim4);
var_v = vmlaq_f32(var_v, odim1, odim1);
var_v = vmlaq_f32(var_v, odim2, odim2);
var_v = vmlaq_f32(var_v, odim3, odim3);
var_v = vmlaq_f32(var_v, odim4, odim4);
x_ptr += 16;
}
mean = vgetq_lane_f32(mean_v, 0) + vgetq_lane_f32(mean_v, 1) +
vgetq_lane_f32(mean_v, 2) + vgetq_lane_f32(mean_v, 3);
variance = vgetq_lane_f32(var_v, 0) + vgetq_lane_f32(var_v, 1) +
vgetq_lane_f32(var_v, 2) + vgetq_lane_f32(var_v, 3);
for (int i = 0; i < remain; ++i) {
mean += *x_ptr;
variance += (*x_ptr) * (*x_ptr);
++x_ptr;
}
mean /= feature_size;
variance = variance / feature_size - mean * mean;
mean_out[bi] = mean;
var_out[bi] = variance;
variance = sqrtf(variance + epsilon);
float rvar = 1 / variance;
// compute norm_out
float* out_ptr = out_data + offset;
x_ptr = x_data + offset;
auto* scale_ptr = scale_data;
auto* bias_ptr = bias_data;
float32x4_t vneg = vdupq_n_f32(-1);
float32x4_t scale1 = vdupq_n_f32(1);
float32x4_t scale2 = vdupq_n_f32(1);
float32x4_t scale3 = vdupq_n_f32(1);
float32x4_t scale4 = vdupq_n_f32(1);
float32x4_t bias1 = vdupq_n_f32(0);
float32x4_t bias2 = vdupq_n_f32(0);
float32x4_t bias3 = vdupq_n_f32(0);
float32x4_t bias4 = vdupq_n_f32(0);
for (int oi = 0; oi < cnt; ++oi) {
float32x4_t odim1 = vld1q_f32(x_ptr);
float32x4_t odim2 = vld1q_f32(x_ptr + 4);
float32x4_t odim3 = vld1q_f32(x_ptr + 8);
float32x4_t odim4 = vld1q_f32(x_ptr + 12);
odim1 = vmlaq_n_f32(odim1, vneg, mean);
odim2 = vmlaq_n_f32(odim2, vneg, mean);
odim3 = vmlaq_n_f32(odim3, vneg, mean);
odim4 = vmlaq_n_f32(odim4, vneg, mean);
if (scale_data) {
scale1 = vld1q_f32(scale_ptr);
scale2 = vld1q_f32(scale_ptr + 4);
scale3 = vld1q_f32(scale_ptr + 8);
scale4 = vld1q_f32(scale_ptr + 12);
scale_ptr += 16;
}
if (bias_data) {
bias1 = vld1q_f32(bias_ptr);
bias2 = vld1q_f32(bias_ptr + 4);
bias3 = vld1q_f32(bias_ptr + 8);
bias4 = vld1q_f32(bias_ptr + 12);
bias_ptr += 16;
}
float32x4_t os1 = vmulq_n_f32(scale1, rvar);
float32x4_t os2 = vmulq_n_f32(scale2, rvar);
float32x4_t os3 = vmulq_n_f32(scale3, rvar);
float32x4_t os4 = vmulq_n_f32(scale4, rvar);
odim1 = vmlaq_f32(bias1, odim1, os1);
odim2 = vmlaq_f32(bias2, odim2, os2);
odim3 = vmlaq_f32(bias3, odim3, os3);
odim4 = vmlaq_f32(bias4, odim4, os4);
vst1q_f32(out_ptr, odim1);
vst1q_f32(out_ptr + 4, odim2);
vst1q_f32(out_ptr + 8, odim3);
vst1q_f32(out_ptr + 12, odim4);
x_ptr += 16;
out_ptr += 16;
}
for (int i = 0; i < remain; ++i) {
auto out_value = (*x_ptr - mean) / variance;
if (scale_data) {
out_value = out_value * (*scale_ptr);
++scale_ptr;
}
if (bias_data) {
out_value = out_value + *bias_ptr;
++bias_ptr;
}
*out_ptr = out_value;
++out_ptr;
++x_ptr;
}
} // for bi
}
} // namespace math
......
......@@ -29,6 +29,15 @@ void norm(const float* input,
float* out,
Context<TARGET(kARM)>* ctx);
void matrix_norm_row(const float* x_data,
const float* scale_data,
const float* bias_data,
float* out_data,
float* mean_out,
float* var_out,
float epsilon,
int batch_size,
int feature_size);
} // namespace math
} // namespace arm
} // namespace lite
......
......@@ -2,13 +2,16 @@ if(NOT LITE_WITH_CUDA)
return()
endif()
nv_library(cuda_activation SRCS activation.cu)
nv_library(cuda_scale SRCS scale.cu)
nv_library(cuda_type_trans SRCS type_trans.cu)
nv_library(cuda_transpose SRCS transpose.cu )
set(cuda_static_deps cudnn_static cublas_static curand_static
culibos_static cudart_static)
nv_library(cuda_activation SRCS activation.cu DEPS ${cuda_static_deps})
nv_library(cuda_scale SRCS scale.cu DEPS ${cuda_static_deps})
nv_library(cuda_type_trans SRCS type_trans.cu DEPS ${cuda_static_deps})
nv_library(cuda_transpose SRCS transpose.cu DEPS ${cuda_static_deps})
nv_library(cudnn_conv SRCS cudnn_conv.cc DEPS cuda_activation cuda_scale
cuda_type_trans)
nv_library(cuda_elementwise SRCS elementwise.cu )
cuda_type_trans ${cuda_static_deps})
nv_library(cuda_elementwise SRCS elementwise.cu DEPS ${cuda_static_deps})
set (
math_cuda
......
......@@ -25,7 +25,7 @@ void FcFuser::BuildPattern() {
// create nodes.
auto* x = VarNode("x")->assert_is_op_input("mul", "X");
auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
auto* b = VarNode("b");
auto* b = VarNode("b")->assert_is_persistable_var();
auto* mul = OpNode("mul", "mul");
auto* mul_out = VarNode("mul_out");
auto* add = OpNode("add", "elementwise_add");
......
......@@ -15,7 +15,6 @@
#include "lite/core/mir/fusion/quant_dequant_fuse_pass.h"
#include <list>
#include <memory>
#include <unordered_set>
#include <vector>
#include "lite/api/paddle_place.h"
#include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
......@@ -26,63 +25,25 @@ namespace lite {
namespace mir {
void QuantDequantFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
// obtain useful values and save to quantized_node, remove quant_nodes and
// releated nodes
std::unordered_set<std::string> quant_types = {
// delete quant node
std::vector<std::string> quant_op_types = {
"fake_quantize_range_abs_max", "fake_quantize_moving_average_abs_max"};
std::vector<Node*> quant_nodes;
for (auto& cur_node : graph->mutable_nodes()) {
if (cur_node.IsStmt() && quant_types.count(cur_node.stmt()->op_type())) {
quant_nodes.push_back(&cur_node);
}
}
for (auto quant_node : quant_nodes) {
// find input nodes and output nodes
std::list<Node*> input_nodes = quant_node->inlinks;
std::list<Node*> output_nodes = quant_node->outlinks;
CHECK_EQ(input_nodes.size(), 2);
CHECK_EQ(output_nodes.size(), 2);
bool front_is_scale = input_nodes.front()->arg()->is_weight;
Node* input_scale_node =
front_is_scale ? input_nodes.front() : input_nodes.back();
Node* input_act_node =
front_is_scale ? input_nodes.back() : input_nodes.front();
front_is_scale = output_nodes.front()->arg()->is_weight;
Node* output_scale_node =
front_is_scale ? output_nodes.front() : output_nodes.back();
Node* output_act_node =
front_is_scale ? output_nodes.back() : output_nodes.front();
// relink nodes and save value to quantized_node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
auto scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto outlinks = output_act_node->outlinks;
for (auto* quantized_node_ptr : outlinks) {
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<int>("bit_length",
bit_length);
quantized_node_ptr->stmt()->mutable_op_info()->SetAttr<float>(
"input_scale", scale_value);
IR_NODE_LINK_TO(input_act_node, quantized_node_ptr)
RemoveDirectedLink(output_act_node, quantized_node_ptr);
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, quant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph.get(), nodes2rm);
for (auto& op_type : quant_op_types) {
fusion::DeleteQuantOpFuser fuser(op_type);
fuser(graph.get());
}
// fuse quantized node and dequant node
std::unordered_set<std::string> quantized_op_types = {
std::vector<std::string> quantized_op_types = {
"conv2d", "mul", "depthwise_conv2d"};
for (auto& op_type : quantized_op_types) {
fusion::QuantDequantOpFuser fuser(op_type);
fusion::DequantOpFuser fuser(op_type);
fuser(graph.get());
}
// delete quant_dequant_node
for (auto op_type : {"pool2d", "elementwise_add"}) {
fusion::DeleteQuantDequantOpFuser fuser(op_type);
fuser(graph.get());
}
}
......
......@@ -14,6 +14,7 @@
#include "lite/core/mir/fusion/quant_dequant_op_fuser.h"
#include <memory>
#include <unordered_set>
#include <vector>
#include "lite/utils/string.h"
......@@ -22,7 +23,61 @@ namespace lite {
namespace mir {
namespace fusion {
void QuantDequantOpFuser::BuildPattern() {
void DeleteQuantOpFuser::BuildPattern() {
auto* input_scale_node = VarNode("input_scale_node")
->assert_is_op_input(quant_op_type_, "InScale");
auto* input_act_node =
VarNode("input_act_node")->assert_is_op_input(quant_op_type_, "X");
auto* quant_node =
OpNode("quant_node", quant_op_type_)->assert_is_op(quant_op_type_);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_op_type_, "OutScale");
auto* output_act_node =
VarNode("output_act_node")->assert_is_op_output(quant_op_type_, "Out");
quant_node->LinksFrom({input_scale_node, input_act_node});
output_scale_node->LinksFrom({quant_node});
output_act_node->LinksFrom({quant_node});
VLOG(4) << "DeleteQuantOpFuser BuildPattern quant_op_type:" << quant_op_type_;
}
void DeleteQuantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* input_scale_node = matched.at("input_scale_node");
auto* input_act_node = matched.at("input_act_node");
auto* quant_node = matched.at("quant_node");
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
// obtain values, save values and relink node
int bit_length = quant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto outlinks = output_act_node->outlinks;
for (auto* quantized_node : outlinks) {
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("input_scale", scale_value);
IR_NODE_LINK_TO(input_act_node, quantized_node)
}
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {
input_scale_node, quant_node, output_scale_node, output_act_node};
GraphSafeRemoveNodes(graph, nodes2rm);
}
cpp::OpDesc DeleteQuantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
void DequantOpFuser::BuildPattern() {
std::string weight_name = "";
if (op_type_ == "conv2d" || op_type_ == "depthwise_conv2d") {
weight_name = "Filter";
......@@ -55,10 +110,11 @@ void QuantDequantOpFuser::BuildPattern() {
quantized_op_out->LinksFrom({quantized_op});
dequant_op->LinksFrom({quantized_op_out});
dequant_op_out->LinksFrom({dequant_op});
VLOG(4) << "DeQuantOpFuser BuildPattern op_type:" << op_type_;
}
void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
void DequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto* quant_op_input = matched.at("quantized_op_input");
auto* quantized_op_weight = matched.at("quantized_op_weight");
auto* quantized_op = matched.at("quantized_op");
......@@ -127,7 +183,174 @@ void QuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
IR_NODE_LINK_TO(new_quantized_op_node, dequant_op_out);
}
cpp::OpDesc QuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc DequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
void DeleteQuantDequantOpFuser::BuildPattern() {
std::string quant_dequant_op_type =
"fake_quantize_dequantize_moving_average_abs_max";
if (quantized_op_type_ == "pool2d") {
auto* input_scale_node =
VarNode("input_scale_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_node = VarNode("input_act_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_node =
OpNode("quant_dequant_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_node =
VarNode("output_scale_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_node =
VarNode("output_act_node")
->assert_is_op_output(quant_dequant_op_type, "Out");
auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
->assert_is_op(quantized_op_type_);
quant_dequant_node->LinksFrom({input_scale_node, input_act_node});
output_scale_node->LinksFrom({quant_dequant_node});
output_act_node->LinksFrom({quant_dequant_node});
quantized_node->LinksFrom({output_act_node});
} else if (quantized_op_type_ == "elementwise_add") {
auto* input_scale_left_node =
VarNode("input_scale_left_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_left_node =
VarNode("input_act_left_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_left_node =
OpNode("quant_dequant_left_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_left_node =
VarNode("output_scale_left_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_left_node =
VarNode("output_act_left_node")
->assert_is_op_output(quant_dequant_op_type, "Out")
->assert_is_op_input(quantized_op_type_, "X");
quant_dequant_left_node->LinksFrom(
{input_scale_left_node, input_act_left_node});
output_scale_left_node->LinksFrom({quant_dequant_left_node});
output_act_left_node->LinksFrom({quant_dequant_left_node});
auto* input_scale_right_node =
VarNode("input_scale_right_node")
->assert_is_op_input(quant_dequant_op_type, "InScale");
auto* input_act_right_node =
VarNode("input_act_right_node")
->assert_is_op_input(quant_dequant_op_type, "X");
auto* quant_dequant_right_node =
OpNode("quant_dequant_right_node", quant_dequant_op_type)
->assert_is_op(quant_dequant_op_type);
auto* output_scale_right_node =
VarNode("output_scale_right_node")
->assert_is_op_output(quant_dequant_op_type, "OutScale");
auto* output_act_right_node =
VarNode("output_act_right_node")
->assert_is_op_output(quant_dequant_op_type, "Out")
->assert_is_op_input(quantized_op_type_, "Y");
quant_dequant_right_node->LinksFrom(
{input_scale_right_node, input_act_right_node});
output_scale_right_node->LinksFrom({quant_dequant_right_node});
output_act_right_node->LinksFrom({quant_dequant_right_node});
auto* quantized_node = OpNode("quantized_node", quantized_op_type_)
->assert_is_op(quantized_op_type_);
quantized_node->LinksFrom({output_act_left_node, output_act_right_node});
} else {
LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
}
VLOG(4) << "DeleteQuantDequantOpFuser BuildPattern op_type:"
<< quantized_op_type_;
}
void DeleteQuantDequantOpFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
if (quantized_op_type_ == "pool2d") {
auto* input_scale_node = matched.at("input_scale_node");
auto* input_act_node = matched.at("input_act_node");
auto* quant_dequant_node = matched.at("quant_dequant_node");
auto* output_scale_node = matched.at("output_scale_node");
auto* output_act_node = matched.at("output_act_node");
auto* quantized_node = matched.at("quantized_node");
// obtain values, save values and relink node
int bit_length =
quant_dequant_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_dequant_node->stmt()->op()->scope();
auto* scale_tensor = scope->FindVar(output_scale_node->arg()->name)
->GetMutable<lite::Tensor>();
float scale_value = scale_tensor->data<float>()[0] / range;
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("input_scale", scale_value);
op_desc->SetInput("X", {input_act_node->arg()->name});
IR_NODE_LINK_TO(input_act_node, quantized_node)
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale_node,
quant_dequant_node,
output_scale_node,
output_act_node};
GraphSafeRemoveNodes(graph, nodes2rm);
} else if (quantized_op_type_ == "elementwise_add") {
auto* input_scale_left_node = matched.at("input_scale_left_node");
auto* input_act_left_node = matched.at("input_act_left_node");
auto* quant_dequant_left_node = matched.at("quant_dequant_left_node");
auto* output_scale_left_node = matched.at("output_scale_left_node");
auto* output_act_left_node = matched.at("output_act_left_node");
auto* input_scale_right_node = matched.at("input_scale_right_node");
auto* input_act_right_node = matched.at("input_act_right_node");
auto* quant_dequant_right_node = matched.at("quant_dequant_right_node");
auto* output_scale_right_node = matched.at("output_scale_right_node");
auto* output_act_right_node = matched.at("output_act_right_node");
auto* quantized_node = matched.at("quantized_node");
// obtain values, save values and relink node
int bit_length =
quant_dequant_left_node->stmt()->op_info()->GetAttr<int>("bit_length");
int range = ((1 << (bit_length - 1)) - 1);
auto* scope = quant_dequant_left_node->stmt()->op()->scope();
auto* left_scale_tensor =
scope->FindVar(output_scale_left_node->arg()->name)
->GetMutable<lite::Tensor>();
float left_scale_value = left_scale_tensor->data<float>()[0] / range;
auto* right_scale_tensor =
scope->FindVar(output_scale_right_node->arg()->name)
->GetMutable<lite::Tensor>();
float right_scale_value = right_scale_tensor->data<float>()[0] / range;
auto* op_desc = quantized_node->stmt()->mutable_op_info();
op_desc->SetAttr<int>("bit_length", bit_length);
op_desc->SetAttr<float>("x_input_scale", left_scale_value);
op_desc->SetAttr<float>("y_input_scale", right_scale_value);
op_desc->SetInput("X", {input_act_left_node->arg()->name});
op_desc->SetInput("Y", {input_act_right_node->arg()->name});
IR_NODE_LINK_TO(input_act_left_node, quantized_node)
IR_NODE_LINK_TO(input_act_right_node, quantized_node)
// delete nodes and edges
std::unordered_set<const Node*> nodes2rm = {input_scale_left_node,
quant_dequant_left_node,
output_scale_left_node,
output_act_left_node,
input_scale_right_node,
quant_dequant_right_node,
output_scale_right_node,
output_act_right_node};
GraphSafeRemoveNodes(graph, nodes2rm);
} else {
LOG(FATAL) << "No support quantized_op_type:" << quantized_op_type_;
}
}
cpp::OpDesc DeleteQuantDequantOpFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc;
return op_desc;
}
......
......@@ -34,11 +34,25 @@ namespace fusion {
* the quantized_op.
* In addition, the fuser delete fake_quant and fake_dequant op in the graph at
* the last.
*/
class QuantDequantOpFuser : public FuseBase {
*/
class DeleteQuantOpFuser : public FuseBase {
public:
explicit DeleteQuantOpFuser(const std::string& quant_op_type)
: quant_op_type_(quant_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string quant_op_type_{};
};
class DequantOpFuser : public FuseBase {
public:
explicit QuantDequantOpFuser(const std::string& op_type)
: op_type_(op_type) {}
explicit DequantOpFuser(const std::string& op_type) : op_type_(op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
......@@ -49,6 +63,27 @@ class QuantDequantOpFuser : public FuseBase {
std::string op_type_{};
};
/* The pattern like "fake_quantize_dequantize_moving_average_abs_max +
* pooled/elementwise_add" can be deteted by this fuser. The fuser
* extract the input_scale form fake_quant_dequant_op and save into
* the quantized_op. Besides, the fuser delete fake_quant_dequant_op in
* the graph.
*/
class DeleteQuantDequantOpFuser : public FuseBase {
public:
explicit DeleteQuantDequantOpFuser(const std::string& quantized_op_type)
: quantized_op_type_(quantized_op_type) {}
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
private:
std::string quantized_op_type_{};
};
} // namespace fusion
} // namespace mir
} // namespace lite
......
......@@ -207,6 +207,8 @@ void SubgraphProgramPass::InferOnce(const std::unique_ptr<SSAGraph>& graph) {
if (!item->IsStmt()) continue;
auto& stmt = item->AsStmt();
auto& op = stmt.op();
std::string op_type = op->op_info()->Type();
if (op_type == "feed" || op_type == "fetch") continue;
op->CheckShape();
op->InferShape();
// TOOD(xxx): remove Launch() at last
......
......@@ -37,7 +37,7 @@ void TypeLayoutTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
VLOG(4) << "nodes.size():" << nodes.size();
for (auto& node : nodes) {
VLOG(4) << "!node->IsStmt():" << !node->IsStmt();
if (!node->IsStmt()) continue;
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks;
VLOG(4) << "node->AsStmt().desc:" << node->AsStmt().desc
<< " inlinks.size():" << inlinks.size();
......
......@@ -33,7 +33,7 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
}
for (auto& node : nodes) {
if (!node->IsStmt()) continue;
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
......
......@@ -36,7 +36,7 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
CHECK(!valid_places_.empty());
for (auto& node : nodes) {
if (!node->IsStmt()) continue;
if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
auto inlinks = node->inlinks;
for (auto* in : inlinks) {
ComplementInputs(graph.get(), node, in);
......
......@@ -126,7 +126,6 @@ class Optimizer {
valid_places_.end(),
Place{TARGET(kNPU), PRECISION(kFloat)}) !=
valid_places_.end()) {
CheckInputDimsNotEmpty(exec_scope_);
auto pass = mir::PassManager::Global()
.LookUp<mir::subgraph::GenerateNPUProgramPass>(
"generate_npu_program_pass");
......@@ -150,19 +149,6 @@ class Optimizer {
return program;
}
// check the input dims in the scope, must not be empty
void CheckInputDimsNotEmpty(const lite::Scope* scope) {
CHECK(scope);
auto* feed_var = scope->FindVar("feed");
CHECK(feed_var) << "no feed variable in exec_scope: " << scope;
auto* feed_tensor_list = feed_var->GetMutable<std::vector<lite::Tensor>>();
CHECK_GE(feed_tensor_list->size(), 1);
for (size_t i = 0; i < feed_tensor_list->size(); ++i) {
CHECK(!feed_tensor_list->at(i).dims().empty())
<< "Input " << i << " dims can not be empty.";
}
}
void InitTargetTypeTransformPass() {
auto* pass =
mir::PassManager::Global().LookUp<mir::TypeTargetTransformPass>(
......
......@@ -46,6 +46,8 @@ add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${li
add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(layer_norm_compute_arm ARM basic SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(gather_compute_arm ARM basic SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_compute_arm ARM basic SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)
......@@ -98,4 +100,5 @@ lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_
lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra)
lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm)
lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm)
lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm)
lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm)
......@@ -276,6 +276,10 @@ void BeamSearchDecodeCompute::Run() {
param.end_id);
func.apply<float>();
// when decode finish, we clear ids and scores
param.ids->clear();
param.scores->clear();
}
} // namespace arm
......
......@@ -87,14 +87,13 @@ void CompareCompute<Functor>::Run() {
auto x_dims = param.X->dims();
auto y_dims = param.Y->dims();
bool *z = param.Out->template mutable_data<bool>();
const auto *x = param.X->template data<int>();
const auto *x = param.X->template data<float>();
const auto *y = param.Y->template data<float>();
auto axis = param.axis;
bool force_cpu = param.force_cpu;
if (x_size == y_size) {
for (int i = 0; i < x_size; ++i) {
z[i] = CompareFunctor()(x[i], y[i]);
// z[i] = x[i] < y[i];
}
} else {
int axis = (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis);
......
......@@ -38,6 +38,31 @@ class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
virtual ~FillConstantCompute() = default;
};
template <typename T>
class FillConstantBatchLikeCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::FillConstantBatchLikeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>();
if (param.input->lod().size() && param.input_dim_idx == 0) {
auto odims = param.out->dims();
odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
param.out->Resize(odims);
}
auto data = param.out->template mutable_data<T>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
}
virtual ~FillConstantBatchLikeCompute() = default;
};
} // namespace arm
} // namespace kernels
} // namespace lite
......@@ -52,3 +77,13 @@ REGISTER_LITE_KERNEL(fill_constant,
def)
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
REGISTER_LITE_KERNEL(
fill_constant_batch_size_like,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::FillConstantBatchLikeCompute<float>,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/gather_compute.h"
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void GatherCompute::PrepareForRun() {}
void GatherCompute::Run() {
auto& param = this->Param<operators::GatherParam>();
auto* p_output = param.Out->mutable_data<float>();
auto index_size = param.Index->dims()[0];
auto src_dims = param.X->dims();
const float* p_src = param.X->data<float>();
const float* p_index = param.Index->data<float>();
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
for (int i = 0; i < index_size; ++i) {
int index_ = p_index[i];
memcpy(p_output + i * slice_size,
p_src + index_ * slice_size,
slice_size * sizeof(float));
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
gather, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::GatherParam;
void PrepareForRun() override;
void Run() override;
~GatherCompute() {}
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -28,8 +28,8 @@ void IncrementCompute::Run() {
int total_num = param.X->dims().production();
const auto* x_data = param.X->data<int>();
auto* o_data = param.Out->mutable_data<int>();
const auto* x_data = param.X->data<float>();
auto* o_data = param.Out->mutable_data<float>();
lite::arm::math::increment(x_data, total_num, param.step, o_data, &ctx);
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/layer_norm_compute.h"
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void LayerNormCompute::PrepareForRun() {}
void LayerNormCompute::Run() {
auto& param = this->Param<operators::LayerNormParam>();
auto input_dims = param.X->dims();
const auto* x_data = param.X->data<float>();
const auto* scale = param.Scale ? param.Scale->data<float>() : nullptr;
const auto* bias = param.Bias ? param.Bias->data<float>() : nullptr;
auto* o_data = param.Y->mutable_data<float>();
auto* mean = param.Mean->mutable_data<float>();
auto* var = param.Variance->mutable_data<float>();
int axis = param.begin_norm_axis;
auto matrix_dim = param.X->dims().Flatten2D(axis);
int left = matrix_dim[0];
int right = matrix_dim[1];
lite::arm::math::matrix_norm_row(
x_data, scale, bias, o_data, mean, var, param.epsilon, left, right);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(layer_norm,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::LayerNormCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Variance", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class LayerNormCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::LayerNormParam;
void PrepareForRun() override;
void Run() override;
~LayerNormCompute() {}
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/layer_norm_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void LayerNormComputeRef(const operators::LayerNormParam& param) {
auto* x = param.X;
auto* y = param.Y;
auto* scale_tensor = param.Scale;
auto* bias_tensor = param.Bias;
auto* mean_tensor = param.Mean;
auto* var_tensor = param.Variance;
int begin_norm_axis = param.begin_norm_axis;
float epsilon = param.epsilon;
auto* x_data = x->data<float>();
auto* scale_data =
(scale_tensor == nullptr ? nullptr : scale_tensor->data<float>());
auto* bias_data =
(bias_tensor == nullptr ? nullptr : bias_tensor->data<float>());
auto* out_data = y->mutable_data<float>();
auto* mean_data = mean_tensor->mutable_data<float>();
auto* var_data = var_tensor->mutable_data<float>();
auto matrix_dim = x->dims().Flatten2D(begin_norm_axis);
int batch_size = matrix_dim[0];
int feature_size = matrix_dim[1];
for (int i = 0; i < batch_size; ++i) {
int start = i * feature_size;
int end = start + feature_size;
float mean = 0;
float var = 0;
for (int j = start; j < end; ++j) {
mean += x_data[j];
var += x_data[j] * x_data[j];
}
mean /= feature_size;
var = var / feature_size - mean * mean;
mean_data[i] = mean;
var_data[i] = var;
var = sqrt(var + epsilon);
for (int j = start; j < end; ++j) {
out_data[j] = (x_data[j] - mean) / var;
if (scale_data) {
out_data[j] *= scale_data[j - start];
}
if (bias_data) {
out_data[j] += bias_data[j - start];
}
}
}
}
TEST(layer_norm_arm, init) {
LayerNormCompute layer_norm;
ASSERT_EQ(layer_norm.precision(), PRECISION(kFloat));
ASSERT_EQ(layer_norm.target(), TARGET(kARM));
}
TEST(layer_norm_arm, compute) {
LayerNormCompute layer_norm;
operators::LayerNormParam param;
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_mean;
lite::Tensor output_var;
lite::Tensor output_ref;
lite::Tensor output_mean_ref;
lite::Tensor output_var_ref;
lite::Tensor bias;
lite::Tensor scale;
lite::Tensor* bias_ptr;
lite::Tensor* scale_ptr;
for (auto n : {1, 3}) {
for (auto c : {1, 3, 5}) {
for (auto h : {3, 16, 20, 32}) {
for (auto w : {3, 16, 20, 32}) {
for (auto axis : {0, 1, 2}) {
for (auto has_bias : {true, false}) {
for (auto has_scale : {true, false}) {
auto dims = DDim(std::vector<int64_t>({n, c, h, w}));
auto out_size = dims.Flatten2D(axis)[0];
auto inner_size = dims.Flatten2D(axis)[1];
bias_ptr = nullptr;
scale_ptr = nullptr;
if (has_bias) {
bias.Resize(std::vector<int64_t>({inner_size, 1, 1, 1}));
float* bias_data = bias.mutable_data<float>();
for (int i = 0; i < inner_size; ++i) {
bias_data[i] = 0.01;
}
bias_ptr = &bias;
}
if (has_scale) {
scale.Resize(std::vector<int64_t>({inner_size, 1, 1, 1}));
float* scale_data = scale.mutable_data<float>();
for (int i = 0; i < inner_size; ++i) {
scale_data[i] = 0.2;
}
scale_ptr = &scale;
}
x.Resize(dims);
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_mean.Resize(std::vector<int64_t>({out_size, 1, 1, 1}));
output_mean_ref.Resize(
std::vector<int64_t>({out_size, 1, 1, 1}));
output_var.Resize(std::vector<int64_t>({out_size, 1, 1, 1}));
output_var_ref.Resize(
std::vector<int64_t>({out_size, 1, 1, 1}));
auto* x_data = x.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_mean_data = output_mean.mutable_data<float>();
auto* output_var_data = output_var.mutable_data<float>();
auto* output_data_ref = output_ref.mutable_data<float>();
auto* output_mean_data_ref =
output_mean_ref.mutable_data<float>();
auto* output_var_data_ref =
output_var_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i % 255 * 0.001;
}
param.X = &x;
param.Y = &output;
param.begin_norm_axis = axis;
param.Bias = bias_ptr;
param.Scale = scale_ptr;
param.Mean = &output_mean;
param.Variance = &output_var;
param.epsilon = 0.00001;
layer_norm.SetParam(param);
layer_norm.Run();
param.Y = &output_ref;
param.Mean = &output_mean_ref;
param.Variance = &output_var_ref;
LayerNormComputeRef(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_data_ref[i], 1e-4);
}
for (int i = 0; i < output_mean.dims().production(); ++i) {
EXPECT_NEAR(
output_mean_data[i], output_mean_data_ref[i], 1e-5);
EXPECT_NEAR(output_var_data[i], output_var_data_ref[i], 1e-5);
}
}
}
}
}
}
}
}
}
TEST(layer_norm, retrive_op) {
auto layer_norm =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"layer_norm");
ASSERT_FALSE(layer_norm.empty());
ASSERT_TRUE(layer_norm.front());
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(layer_norm, kARM, kFloat, kNCHW, def);
......@@ -38,13 +38,14 @@ void LookupTableCompute::Run() {
auto table_dim = w->dims();
int64_t ids_numel = ids->numel();
auto ids_data = ids->data<float>();
int ids_int = ids_data[0];
int64_t row_number = table_dim[0];
int64_t row_width = table_dim[1];
auto table_data = w->data<float>();
auto dout = out->mutable_data<float>();
for (int64_t i = 0; i < ids_numel; ++i) {
int ids_int = ids_data[i];
if (param.padding_idx != -1 && ids_data[i] == param.padding_idx) {
memset(dout + i * row_width, 0, row_width * sizeof(float));
} else {
......
......@@ -28,14 +28,13 @@ void ReadFromArrayCompute::Run() {
int in_num = param.X->size();
CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
int id = param.I->data<int>()[0];
int id = param.I->data<float>()[0];
CHECK_LE(id, in_num) << "id is not valid";
int input_size = (*param.X)[id].numel();
param.Out->Resize((*param.X)[id].dims());
auto* o_data = param.Out->mutable_data<float>();
const auto* x_data = (*param.X)[id].data<float>();
memcpy(o_data, x_data, sizeof(float) * input_size);
param.Out->CopyDataFrom((*param.X)[id]);
auto out_lod = param.Out->mutable_lod();
*out_lod = (*param.X)[id].lod();
}
......
......@@ -43,5 +43,6 @@ REGISTER_LITE_KERNEL(
top_k, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::TopkCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Indices", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Indices",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
.Finalize();
......@@ -46,7 +46,7 @@ void WhileCompute::Run() {
REGISTER_LITE_KERNEL(
while, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::WhileCompute, def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM))})
.BindInput("Condition",
{LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
......
......@@ -28,7 +28,7 @@ void WriteToArrayCompute::Run() {
CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
const auto* x_data = param.X->data<float>();
int id = param.I->data<int>()[0];
int id = param.I->data<float>()[0];
int id_test = param.I->data<int64_t>()[0];
if (id >= param.Out->size()) {
for (int i = param.Out->size(); i < id + 1; i++) {
......@@ -57,5 +57,5 @@ REGISTER_LITE_KERNEL(write_to_array,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
.Finalize();
......@@ -5,6 +5,7 @@ lite_cc_library(op_params SRCS op_params.cc DEPS tensor any)
add_operator(conv_op basic SRCS conv_op.cc DEPS ${op_DEPS})
add_operator(pool_op basic SRCS pool_op.cc DEPS ${op_DEPS})
add_operator(fc_op basic SRCS fc_op.cc DEPS ${op_DEPS})
add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS})
add_operator(relu_op basic SRCS relu_op.cc DEPS ${op_DEPS})
add_operator(mul_op basic SRCS mul_op.cc DEPS ${op_DEPS})
add_operator(matmul_op basic SRCS matmul_op.cc DEPS ${op_DEPS})
......@@ -25,7 +26,6 @@ add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DE
add_operator(fusion_elementwise_activation_ops basic SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops ${op_DEPS})
add_operator(mean_op basic SRCS mean_op.cc DEPS ${op_DEPS})
add_operator(fill_constant_op basic SRCS fill_constant_op.cc DEPS ${op_DEPS})
add_operator(fill_constant_batch_size_like_op basic SRCS fill_constant_batch_size_like_op.cc DEPS ${op_DEPS})
#add_operator(sgd_op basic SRCS sgd_op.cc DEPS ${op_DEPS})
add_operator(uniform_random_op basic SRCS uniform_random_op.cc DEPS ${op_DEPS})
add_operator(power_op basic SRCS power_op.cc DEPS ${op_DEPS})
......@@ -61,10 +61,10 @@ add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_
add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS})
add_operator(unsqueeze_op_lite basic SRCS unsqueeze_op.cc DEPS ${op_DEPS})
add_operator(im2sequence_op basic SRCS im2sequence_op.cc DEPS ${op_DEPS})
add_operator(gather_op basic SRCS gather_op.cc DEPS ${op_DEPS})
add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS})
add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS})
add_operator(cast_op_lite basic SRCS cast_op.cc DEPS ${op_DEPS})
add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS})
add_operator(affine_channel_op basic SRCS affine_channel_op.cc DEPS ${op_DEPS})
add_operator(anchor_generator_op basic SRCS anchor_generator_op.cc DEPS ${op_DEPS})
add_operator(generate_proposals_op basic SRCS generate_proposals_op.cc DEPS ${op_DEPS})
......@@ -75,6 +75,7 @@ add_operator(fake_quantize_range_abs_max_op basic SRCS fake_quantize_range_abs_m
add_operator(sequence_expand_as_op_lite basic SRCS sequence_expand_as_op.cc DEPS ${op_DEPS})
add_operator(range_op basic SRCS range_op.cc DEPS ${op_DEPS})
add_operator(assign_value_op basic SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op basic SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......@@ -100,6 +101,7 @@ add_operator(slice_op_lite basic SRCS slice_op.cc DEPS ${op_DEPS})
add_operator(write_to_array_op extra SRCS write_to_array_op.cc DEPS ${op_DEPS})
add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS})
add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
......
......@@ -34,7 +34,7 @@ bool ConvOpLite::CheckShape() const {
CHECK_EQ_OR_FALSE(in_dims.size(), filter_dims.size());
CHECK_OR_FALSE(in_dims.size() - param_.strides.size() == 2U);
CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size());
// CHECK_EQ_OR_FALSE(param_.paddings.size(), param_.strides.size());
// CHECK_EQ_OR_FALSE(in_dims[1], filter_dims[1] * param_.groups);
// CHECK_EQ_OR_FALSE(filter_dims[0] % param_.groups, 0);
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fake_quantize_dequantize_moving_avg_max_abs.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(
fake_quantize_dequantize_moving_average_abs_max,
paddle::lite::operators::FakeQuantizeDequantizeMovingAvgMaxAbsOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class FakeQuantizeDequantizeMovingAvgMaxAbsOpLite : public OpLite {
public:
FakeQuantizeDequantizeMovingAvgMaxAbsOpLite() {}
explicit FakeQuantizeDequantizeMovingAvgMaxAbsOpLite(const std::string &type)
: OpLite(type) {}
bool CheckShape() const override { return true; }
bool InferShape() const override { return true; }
bool AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) override {
auto x = op_desc.Input("X").front();
auto in_scale = op_desc.Input("InScale").front();
auto out = op_desc.Output("Out").front();
auto out_scale = op_desc.Output("OutScale").front();
param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
param_.in_scale = scope->FindVar(in_scale)->GetMutable<lite::Tensor>();
param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
param_.out_scale = scope->FindVar(out_scale)->GetMutable<lite::Tensor>();
param_.bit_length = op_desc.GetAttr<int>("bit_length");
return true;
}
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fake_quantize_dequantize_moving_avg_max_abs";
}
private:
mutable FakeQuantizeMovingAvgMaxAbsParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -52,8 +52,67 @@ class FillConstantOp : public OpLite {
mutable operators::FillConstantParam param_;
};
class FillConstantBatchLikeOp : public OpLite {
public:
explicit FillConstantBatchLikeOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.out);
CHECK_OR_FALSE(param_.input);
CHECK_GT_OR_FALSE(param_.shape.size(), 0);
CHECK_GE_OR_FALSE(param_.input_dim_idx, 0);
CHECK_GE_OR_FALSE(param_.output_dim_idx, 0);
return true;
}
bool InferShape() const override {
auto output_dim = param_.shape;
output_dim[param_.output_dim_idx] =
param_.input->dims()[param_.input_dim_idx];
param_.out->Resize(output_dim);
return true;
}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
auto Out_name = opdesc.Output("Out").front();
auto In_name = opdesc.Input("Input").front();
param_.out = GetMutableVar<lite::Tensor>(scope, Out_name);
param_.input = GetMutableVar<lite::Tensor>(scope, In_name);
param_.dtype = opdesc.GetAttr<int>("dtype");
auto shape = opdesc.GetAttr<std::vector<int>>("shape");
std::vector<int64_t> outshape;
for (auto i : shape) {
outshape.push_back(i);
}
param_.shape = outshape;
if (opdesc.HasAttr("value")) {
param_.value = opdesc.GetAttr<float>("value");
}
if (opdesc.HasAttr("input_dim_idx")) {
param_.input_dim_idx = opdesc.GetAttr<int>("input_dim_idx");
}
if (opdesc.HasAttr("output_dim_idx")) {
param_.output_dim_idx = opdesc.GetAttr<int>("output_dim_idx");
}
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fill_constant_batch_size_like";
}
private:
mutable operators::FillConstantBatchLikeParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fill_constant, paddle::lite::operators::FillConstantOp);
REGISTER_LITE_OP(fill_constant_batch_size_like,
paddle::lite::operators::FillConstantBatchLikeOp);
......@@ -11,59 +11,48 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/fill_constant_batch_size_like_op.h"
#include "lite/operators/gather_op.h"
#include <algorithm>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace operators {
bool FillConstantBatchSizeLikeOp::CheckShape() const {
CHECK_OR_FALSE(param_.Input);
bool GatherOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Index);
CHECK_OR_FALSE(param_.Out);
return true;
}
bool FillConstantBatchSizeLikeOp::InferShape() const {
auto shape = param_.shape;
std::vector<int64_t> shape_int64(shape.size(), 0);
std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) {
return static_cast<int64_t>(a);
});
lite::DDim output_dim(shape_int64);
int input_dim_idx = param_.input_dim_idx;
int output_dim_idx = param_.output_dim_idx;
output_dim[output_dim_idx] = param_.Input->dims()[input_dim_idx];
param_.Out->Resize(output_dim);
bool GatherOp::InferShape() const {
auto index_dims = param_.Index->dims();
CHECK(index_dims.size() == 1 ||
(index_dims.size() == 2 && index_dims[1] == 1))
<< "index dims unmatch";
int batch_size = index_dims[0];
auto out_dims = param_.X->dims();
out_dims[0] = batch_size;
param_.Out->Resize(out_dims);
return true;
}
bool FillConstantBatchSizeLikeOp::AttachImpl(const cpp::OpDesc &op_desc,
lite::Scope *scope) {
auto Input = op_desc.Input("Input").front();
auto Out = op_desc.Output("Out").front();
param_.Input = scope->FindVar(Input)->GetMutable<lite::Tensor>();
param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
param_.shape = op_desc.GetAttr<std::vector<int>>("shape");
param_.input_dim_idx = op_desc.GetAttr<int>("input_dim_idx");
param_.output_dim_idx = op_desc.GetAttr<int>("output_dim_idx");
param_.dtype = op_desc.GetAttr<int>("dtype");
param_.value = op_desc.GetAttr<float>("value");
CHECK(param_.Input);
CHECK(param_.Out);
bool GatherOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
param_.Index =
scope->FindVar(opdesc.Input("Index").front())->GetMutable<lite::Tensor>();
CHECK(param_.X) << "X is null";
CHECK(param_.Out) << "out is null";
CHECK(param_.Index) << "index is null";
return true;
}
} /* namespace operators */
} /* namespace lite */
} /* namespace paddle */
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(fill_constant_batch_size_like,
paddle::lite::operators::FillConstantBatchSizeLikeOp);
REGISTER_LITE_OP(gather, paddle::lite::operators::GatherOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class GatherOp : public OpLite {
public:
GatherOp() {}
explicit GatherOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "gather"; }
private:
mutable GatherParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/layer_norm_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool LayerNormOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.Mean);
CHECK_OR_FALSE(param_.Variance);
return true;
}
bool LayerNormOp::InferShape() const {
auto out_dims = param_.X->dims();
param_.Y->Resize(out_dims);
auto inner_size = out_dims.Flatten2D(param_.begin_norm_axis)[1];
param_.Mean->Resize(std::vector<int64_t>({inner_size}));
param_.Variance->Resize(std::vector<int64_t>({inner_size}));
auto out_lod = param_.Y->mutable_lod();
*out_lod = param_.X->lod();
return true;
}
bool LayerNormOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Y =
scope->FindVar(opdesc.Output("Y").front())->GetMutable<lite::Tensor>();
param_.Mean =
scope->FindVar(opdesc.Output("Mean").front())->GetMutable<lite::Tensor>();
param_.Variance = scope->FindVar(opdesc.Output("Variance").front())
->GetMutable<lite::Tensor>();
CHECK(param_.X);
CHECK(param_.Y);
CHECK(param_.Mean);
CHECK(param_.Variance);
if (opdesc.HasInput("Scale")) {
param_.Scale = scope->FindVar(opdesc.Input("Scale").front())
->GetMutable<lite::Tensor>();
}
if (opdesc.HasInput("Bias")) {
param_.Bias = scope->FindVar(opdesc.Input("Bias").front())
->GetMutable<lite::Tensor>();
}
param_.begin_norm_axis = opdesc.GetAttr<int>("begin_norm_axis");
param_.epsilon = opdesc.GetAttr<float>("epsilon");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(layer_norm, paddle::lite::operators::LayerNormOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class LayerNormOp : public OpLite {
public:
LayerNormOp() {}
explicit LayerNormOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "layer_norm"; }
private:
mutable LayerNormParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -23,6 +23,7 @@ bool MulOpLite::CheckShape() const {
CHECK_OR_FALSE(param_.x);
CHECK_OR_FALSE(param_.y);
CHECK_OR_FALSE(param_.output);
// bias is optional.
const auto x_dims = param_.x->dims();
......@@ -54,17 +55,15 @@ bool MulOpLite::InferShape() const {
const auto y_dims = param_.y->dims();
// Set output dims
std::vector<int64_t> out_dims(
param_.x_num_col_dims + y_dims.size() - param_.y_num_col_dims, 0);
std::vector<int64_t> out_dims;
for (int i = 0; i < param_.x_num_col_dims; ++i) {
out_dims[i] = x_dims[i];
out_dims.push_back(x_dims[i]);
}
for (auto i = static_cast<size_t>(param_.y_num_col_dims); i < y_dims.size();
++i) {
out_dims[i] = y_dims[i];
out_dims.push_back(y_dims[i]);
}
param_.output->Resize(lite::DDim(out_dims));
auto out_lod = param_.output->mutable_lod();
*out_lod = param_.x->lod();
......
......@@ -294,6 +294,8 @@ struct PoolParam {
bool ceil_mode{false};
bool use_quantizer{false};
std::string data_format{"AnyLayout"};
// for int8
WITH_INT8_CONFIG
};
// For Dropout op
......@@ -332,7 +334,10 @@ struct ElementwiseParam {
const lite::Tensor* Y{};
lite::Tensor* Out{};
int axis{-1}; // for broadcasting.
// for int8
WITH_INT8_CONFIG
float x_input_scale{1.0};
float y_input_scale{1.0};
};
struct ElementwiseGradParam {
......@@ -373,6 +378,17 @@ struct FillConstantParam {
bool force_cpu{false};
lite::Tensor* Out{};
};
struct FillConstantBatchLikeParam {
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
std::vector<int64_t> shape{};
float value{0.0f};
// useless for x86, keep it for compatibility
bool force_cpu{false};
lite::Tensor* out{};
const lite::Tensor* input{};
int input_dim_idx{0};
int output_dim_idx{0};
};
struct FillConstantBatchSizeLikeParam {
lite::Tensor* Input;
......@@ -619,6 +635,16 @@ struct NormParam {
int axis{1};
float epsilon{1e-10};
};
struct LayerNormParam {
const lite::Tensor* X{};
const lite::Tensor* Scale{};
const lite::Tensor* Bias{};
lite::Tensor* Y{};
lite::Tensor* Mean{};
lite::Tensor* Variance{};
int begin_norm_axis{1};
float epsilon{1e-5};
};
struct LogicalParam {
const lite::Tensor* X{};
......@@ -816,6 +842,12 @@ struct MatMulParam {
float alpha{1.0f};
};
struct GatherParam {
const lite::Tensor* X{};
const lite::Tensor* Index{};
lite::Tensor* Out{};
};
/// ----------------------- assign operators -----------------------
struct AssignParam {
const lite::Tensor* X{};
......
......@@ -310,6 +310,43 @@ class RegisterLiteKernelParser(SyntaxParser):
break
class RegisterLiteOpParser(SyntaxParser):
KEYWORD = 'REGISTER_LITE_OP'
def __init__(self, str):
super(RegisterLiteOpParser, self).__init__(str)
self.ops = []
def parse(self):
while self.cur_pos < len(self.str):
start = self.str.find(self.KEYWORD, self.cur_pos)
if start != -1:
#print 'str ', start, self.str[start-2: start]
if start != 0 and '/' in self.str[start-2: start]:
'''
skip commented code
'''
self.cur_pos = start + 1
continue
self.cur_pos = start
self.ops.append(self.__parse_register())
else:
break
return self.ops
def __parse_register(self):
self.eat_word()
assert self.token == self.KEYWORD
self.eat_spaces()
self.eat_left_parentheses()
self.eat_spaces()
self.eat_word()
return self.token
if __name__ == '__main__':
with open('/home/chunwei/project2/Paddle-Lite/lite/kernels/arm/activation_compute.cc') as f:
c = f.read()
......
......@@ -15,6 +15,7 @@
import sys
import logging
from ast import RegisterLiteOpParser
ops_list_path = sys.argv[1]
dest_path = sys.argv[2]
......@@ -25,24 +26,19 @@ out_lines = [
'',
]
lines = set()
with open(ops_list_path) as f:
for line in f:
lines.add(line.strip())
paths = set()
for line in open(ops_list_path):
paths.add(line.strip())
for line in lines:
path = line.strip()
with open(path) as g:
for line in g:
key = 'REGISTER_LITE_OP'
if line.startswith(key):
end = line.find(',')
op = line[len(key) + 1:end]
if not op: continue
if "_grad" in op: continue
out = "USE_LITE_OP(%s);" % op
out_lines.append(out)
for path in paths:
str_info = open(path.strip()).read()
op_parser = RegisterLiteOpParser(str_info)
ops = op_parser.parse()
for op in ops:
if "_grad" in op:
continue
out = "USE_LITE_OP(%s);" % op
out_lines.append(out)
with open(dest_path, 'w') as f:
logging.info("write op list to %s" % dest_path)
......
......@@ -111,10 +111,14 @@ bool PaddleMobilePredictor<Device, T>::Run(
if (input.dtype == UINT8) {
framework::Tensor input_tensor(static_cast<uint8_t *>(input.data.data()),
ddim);
paddle_mobile_->Predict(input_tensor);
if (paddle_mobile_->Predict(input_tensor) != PMStatus::PMSuccess) {
return false;
}
} else {
framework::Tensor input_tensor(static_cast<T *>(input.data.data()), ddim);
paddle_mobile_->Predict(input_tensor);
if (paddle_mobile_->Predict(input_tensor) != PMStatus::PMSuccess) {
return false;
}
}
}
......@@ -153,6 +157,11 @@ bool PaddleMobilePredictor<Device, T>::Run(
return true;
}
template <typename Device, typename T>
std::string PaddleMobilePredictor<Device, T>::GetExceptionMsg() {
return paddle_mobile_->GetExceptionMsg();
}
#ifdef PADDLE_MOBILE_FPGA
void ConvertPaddleTensors(const PaddleTensor &src, framework::Tensor *des) {
des->Resize(framework::make_ddim(src.shape));
......
......@@ -32,6 +32,7 @@ class PaddleMobilePredictor : public PaddlePredictor {
bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) override;
std::string GetExceptionMsg();
#ifdef PADDLE_MOBILE_FPGA
void Predict_From_To(int start, int end) override;
void FeedPaddleTensors(const std::vector<PaddleTensor>& inputs) override;
......
......@@ -174,6 +174,7 @@ class PaddlePredictor {
virtual bool Run(const std::vector<PaddleTensor>& inputs,
std::vector<PaddleTensor>* output_data,
int batch_size = -1) = 0;
virtual std::string GetExceptionMsg() { return ""; }
// Destroy the Predictor.
virtual ~PaddlePredictor() = default;
......
......@@ -540,6 +540,12 @@ else()
ADD_EXECUTABLE(test-net net/test_net.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-net paddle-mobile)
ADD_EXECUTABLE(test-super net/test_super.cpp test_helper.h test_include.h executor_for_test.h)
target_link_libraries(test-super paddle-mobile)
ADD_EXECUTABLE(test-inference-pre-post net/test_inference_pre_post.cpp)
target_link_libraries(test-inference-pre-post paddle-mobile)
ADD_EXECUTABLE(test-inference-super net/test_inference_super.cpp)
target_link_libraries(test-inference-super paddle-mobile)
endif()
......@@ -559,7 +559,7 @@ def check_mobile_results(args, fuse, mem_opt):
for i in range(len(values1)):
v1 = values1[i]
v2 = values2[len(shape) + i]
if abs(v1 - v2) > diff_threshold:
if ((not math.isnan(v1)) and math.isnan(v2)) or abs(v1 - v2) > diff_threshold:
error_index = index
break
checked_names.append(op_output_var_name)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册