未验证 提交 e1aab593 编写于 作者: X xiaogang 提交者: GitHub

Develop nlp patch (#3059)

* fix: fix nlp ops input and output type
* fix: add elementwise x_dims>y_dims case
上级 a9d17eef
...@@ -230,6 +230,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING) ...@@ -230,6 +230,7 @@ if(LITE_WITH_LIGHT_WEIGHT_FRAMEWORK AND WITH_TESTING)
ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl ARGS --cl_path=${CMAKE_SOURCE_DIR}/lite/backends/opencl
--model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL) --model_dir=${LITE_MODEL_DIR}/inception_v4 SERIAL)
add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz) add_dependencies(test_inceptionv4 extern_lite_download_inception_v4_simple_tar_gz)
# NOTE: test_ocr_attention is commented out because no OCR model is supplied for testing; the test source is kept as a reference for running inference on NLP models.
# lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc # lite_cc_test(test_ocr_attention SRCS ocr_attention_test.cc
# DEPS ${lite_model_test_DEPS}) # DEPS ${lite_model_test_DEPS})
...@@ -378,6 +379,16 @@ if(NOT IOS) ...@@ -378,6 +379,16 @@ if(NOT IOS)
FPGA_DEPS ${fpga_kernels} FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels} X86_DEPS ${x86_kernels}
CUDA_DEPS ${cuda_kernels}) CUDA_DEPS ${cuda_kernels})
# Build the test_transformer binary from transform_test.cc. It links the full
# and light Paddle APIs, gflags, utils, all ops and host kernels, and adds the
# kernel set of each backend through the per-backend *_DEPS arguments.
lite_cc_binary(test_transformer SRCS transform_test.cc DEPS paddle_api_full paddle_api_light gflags utils
${ops} ${host_kernels}
ARM_DEPS ${arm_kernels}
CV_DEPS paddle_cv_arm
NPU_DEPS ${npu_kernels}
XPU_DEPS ${xpu_kernels}
CL_DEPS ${opencl_kernels}
FPGA_DEPS ${fpga_kernels}
X86_DEPS ${x86_kernels}
CUDA_DEPS ${cuda_kernels})
endif() endif()
#lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc #lite_cc_binary(cxx_api_bin SRCS cxx_api_bin.cc
......
...@@ -32,18 +32,10 @@ void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) { ...@@ -32,18 +32,10 @@ void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) {
predictor.Build(FLAGS_model_dir, "", "", valid_places); predictor.Build(FLAGS_model_dir, "", "", valid_places);
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 1, 48, 512})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
auto* init_scores = predictor.GetInput(2); auto* init_scores = predictor.GetInput(2);
init_scores->Resize(DDim(std::vector<DDim::value_type>({1, 1}))); init_scores->Resize(DDim(std::vector<DDim::value_type>({1, 1})));
auto* data_scores = init_scores->mutable_data<float>(); auto* data_scores = init_scores->mutable_data<float>();
auto scores_size = input_tensor->dims().production(); auto scores_size = init_scores->dims().production();
for (int i = 0; i < scores_size; i++) { for (int i = 0; i < scores_size; i++) {
data_scores[i] = 0; data_scores[i] = 0;
} }
...@@ -53,7 +45,7 @@ void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) { ...@@ -53,7 +45,7 @@ void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) {
auto* init_ids = predictor.GetInput(1); auto* init_ids = predictor.GetInput(1);
init_ids->Resize(DDim(std::vector<DDim::value_type>({1, 1}))); init_ids->Resize(DDim(std::vector<DDim::value_type>({1, 1})));
auto* data_ids = init_ids->mutable_data<float>(); auto* data_ids = init_ids->mutable_data<int64_t>();
auto ids_size = init_ids->dims().production(); auto ids_size = init_ids->dims().production();
for (int i = 0; i < ids_size; i++) { for (int i = 0; i < ids_size; i++) {
data_ids[i] = 0; data_ids[i] = 0;
...@@ -62,6 +54,13 @@ void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) { ...@@ -62,6 +54,13 @@ void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) {
std::vector<std::vector<uint64_t>> lod_i{{0, 1}, {0, 1}}; std::vector<std::vector<uint64_t>> lod_i{{0, 1}, {0, 1}};
*lod_ids = lod_i; *lod_ids = lod_i;
auto* input_tensor = predictor.GetInput(0);
input_tensor->Resize(DDim(std::vector<DDim::value_type>({1, 1, 48, 512})));
auto* data = input_tensor->mutable_data<float>();
auto item_size = input_tensor->dims().production();
for (int i = 0; i < item_size; i++) {
data[i] = 1;
}
for (int i = 0; i < FLAGS_warmup; ++i) { for (int i = 0; i < FLAGS_warmup; ++i) {
predictor.Run(); predictor.Run();
} }
...@@ -102,6 +101,7 @@ void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) { ...@@ -102,6 +101,7 @@ void TestModel(const std::vector<Place>& valid_places, bool use_npu = false) {
TEST(OcrAttention, test_arm) { TEST(OcrAttention, test_arm) {
std::vector<Place> valid_places({ std::vector<Place> valid_places({
Place{TARGET(kARM), PRECISION(kInt64)},
Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)},
}); });
......
...@@ -28,11 +28,10 @@ DEFINE_int32(batch, 1, "batch"); ...@@ -28,11 +28,10 @@ DEFINE_int32(batch, 1, "batch");
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace test_transformer {
namespace test_transformer {
std::vector<std::string> inputed_lines; std::vector<std::string> inputed_lines;
void load_input_lines(const char* filename) {
void LoadInputLines(const char* filename) {
static const int max_line_buf_size = 100 * 1024 * 1024; static const int max_line_buf_size = 100 * 1024 * 1024;
char* line_buffer = (char*)calloc(max_line_buf_size, sizeof(char)); // NOLINT char* line_buffer = (char*)calloc(max_line_buf_size, sizeof(char)); // NOLINT
FILE* input_file = fopen(filename, "r"); FILE* input_file = fopen(filename, "r");
...@@ -49,7 +48,7 @@ void LoadInputLines(const char* filename) { ...@@ -49,7 +48,7 @@ void LoadInputLines(const char* filename) {
line_buffer = NULL; line_buffer = NULL;
fclose(input_file); fclose(input_file);
} }
void Split2(const std::string& main_str, void split2(const std::string& main_str,
std::vector<std::string>& str_list, // NOLINT std::vector<std::string>& str_list, // NOLINT
const std::string& delimiter) { const std::string& delimiter) {
size_t pre_pos = 0; size_t pre_pos = 0;
...@@ -75,19 +74,19 @@ void Split2(const std::string& main_str, ...@@ -75,19 +74,19 @@ void Split2(const std::string& main_str,
} }
} // NOLINT } // NOLINT
void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT void pad_batch_input(std::vector<std::string>& input_lines, // NOLINT
int pad_idx, int pad_idx,
int n_head, int n_head,
Tensor* src_word, Tensor* src_word,
Tensor* src_pos, Tensor* src_pos,
Tensor* src_attn_bias, Tensor* src_attn_bias,
Tensor* trg_word, Tensor* trg_word,
Tensor* init_scores, Tensor* init_scores,
Tensor* init_idx, Tensor* init_idx,
Tensor* trg_bias, Tensor* trg_bias,
int line_start, int line_start,
int batch_size, int batch_size,
int bos_idx) { int bos_idx) {
int max_len = 0; int max_len = 0;
int max_line = input_lines.size(); int max_line = input_lines.size();
...@@ -98,27 +97,27 @@ void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT ...@@ -98,27 +97,27 @@ void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT
std::vector<std::string> split_str; std::vector<std::string> split_str;
test_transformer::Split2(cur_line, split_str, " "); test_transformer::split2(cur_line, split_str, " ");
batch_lines.push_back(split_str); batch_lines.push_back(split_str);
max_len = max_len >= split_str.size() ? max_len : split_str.size(); max_len = max_len >= split_str.size() ? max_len : split_str.size();
} }
src_word->Resize(std::vector<DDim::value_type>({batch_size, max_len, 1})); src_word->Resize(std::vector<DDim::value_type>({batch_size, max_len}));
src_pos->Resize(std::vector<DDim::value_type>({batch_size, max_len, 1})); src_pos->Resize(std::vector<DDim::value_type>({batch_size, max_len}));
src_attn_bias->Resize( src_attn_bias->Resize(
std::vector<DDim::value_type>({batch_size, n_head, max_len, max_len})); std::vector<DDim::value_type>({batch_size, n_head, max_len, max_len}));
trg_bias->Resize( trg_bias->Resize(
std::vector<DDim::value_type>({batch_size, n_head, 1, max_len})); std::vector<DDim::value_type>({batch_size, n_head, max_len, max_len}));
float* src_word_data = src_word->mutable_data<float>(); auto* src_word_data = src_word->mutable_data<int64_t>();
float* src_pos_data = src_pos->mutable_data<float>(); auto* src_pos_data = src_pos->mutable_data<int64_t>();
float* src_bias_data = src_attn_bias->mutable_data<float>(); float* src_bias_data = src_attn_bias->mutable_data<float>();
float* trg_bias_data = trg_bias->mutable_data<float>(); float* trg_bias_data = trg_bias->mutable_data<float>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size; ++i) {
std::vector<std::string> cur_words = batch_lines[i]; std::vector<std::string> cur_words = batch_lines[i];
int fill_len = cur_words.size(); int fill_len = cur_words.size();
int src_bias_start = i * n_head * max_len * max_len; int src_bias_start = i * n_head * max_len * max_len;
int trg_bias_start = i * n_head * max_len; int trg_bias_start = i * n_head * max_len * max_len;
for (int j = 0; j < fill_len; ++j) { for (int j = 0; j < fill_len; ++j) {
src_word_data[i * max_len + j] = (atoi(cur_words[j].c_str())); src_word_data[i * max_len + j] = (atoi(cur_words[j].c_str()));
src_pos_data[i * max_len + j] = j; src_pos_data[i * max_len + j] = j;
...@@ -137,22 +136,24 @@ void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT ...@@ -137,22 +136,24 @@ void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT
int value_ind = j % max_len + src_bias_start; int value_ind = j % max_len + src_bias_start;
src_bias_data[j] = src_bias_data[value_ind]; src_bias_data[j] = src_bias_data[value_ind];
} }
for (int j = trg_bias_start; j < trg_bias_start + n_head * max_len; ++j) { for (int j = trg_bias_start;
j < trg_bias_start + n_head * max_len * max_len;
++j) {
int value_ind = j % max_len + trg_bias_start; int value_ind = j % max_len + trg_bias_start;
trg_bias_data[j] = trg_bias_data[value_ind]; trg_bias_data[j] = trg_bias_data[value_ind];
} }
} }
trg_word->Resize(std::vector<DDim::value_type>({batch_size, 1, 1})); trg_word->Resize(std::vector<DDim::value_type>({batch_size, max_len}));
auto* trg_word_data = trg_word->mutable_data<float>(); auto* trg_word_data = trg_word->mutable_data<int64_t>();
for (int i = 0; i < batch_size; ++i) { for (int i = 0; i < batch_size * max_len; ++i) {
trg_word_data[i] = bos_idx; trg_word_data[i] = bos_idx;
} }
init_scores->Resize(std::vector<DDim::value_type>({batch_size, 1})); init_scores->Resize(std::vector<DDim::value_type>({batch_size, 1}));
init_idx->Resize(std::vector<DDim::value_type>({batch_size})); init_idx->Resize(std::vector<DDim::value_type>({batch_size}));
float* score_data = init_scores->mutable_data<float>(); float* score_data = init_scores->mutable_data<float>();
float* idx_data = init_idx->mutable_data<float>(); auto* idx_data = init_idx->mutable_data<int32_t>();
for (int i = 0; i < init_scores->numel(); ++i) { for (int i = 0; i < init_scores->numel(); ++i) {
score_data[i] = 0; score_data[i] = 0;
} }
...@@ -175,21 +176,25 @@ void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT ...@@ -175,21 +176,25 @@ void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT
void TestModel(const std::vector<Place>& valid_places, void TestModel(const std::vector<Place>& valid_places,
const Place& preferred_place, const Place& preferred_place,
bool use_npu = false) { bool use_npu = false) {
#ifdef LITE_WITH_ARM
DeviceInfo::Init(); DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads); DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
#endif
lite::Predictor predictor; lite::Predictor predictor;
std::string test_data_path = FLAGS_input; std::string test_data_path = FLAGS_input;
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places); predictor.Build("",
FLAGS_model_dir + "/__model__",
FLAGS_model_dir + "/weights",
valid_places);
// predictor.Build(FLAGS_model_dir, "", "", valid_places);
int n_head = 8; int n_head = 8;
int batch_size = FLAGS_batch; int batch_size = FLAGS_batch;
int bos_idx = 0; int bos_idx = 0;
int eos_idx = 1; int eos_idx = 1;
LOG(INFO) << "reading";
test_transformer::LoadInputLines(test_data_path.c_str()); test_transformer::load_input_lines(test_data_path.c_str());
LOG(INFO) << "reading finished";
auto* trg_bias = predictor.GetInput(6); auto* trg_bias = predictor.GetInput(6);
auto* src_word = predictor.GetInput(0); auto* src_word = predictor.GetInput(0);
...@@ -205,28 +210,31 @@ void TestModel(const std::vector<Place>& valid_places, ...@@ -205,28 +210,31 @@ void TestModel(const std::vector<Place>& valid_places,
auto start = GetCurrentUS(); auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) { for (int i = 0; i < FLAGS_repeats; ++i) {
auto start_i = GetCurrentUS(); pad_batch_input(test_transformer::inputed_lines,
PadBatchInput(test_transformer::inputed_lines, eos_idx,
eos_idx, n_head,
n_head, src_word, // src_word
src_word, // src_word src_pos, // src_pos
src_pos, // src_pos src_bias, // src_bias
src_bias, // src_bias trg_word, // trg_word
trg_word, // trg_word init_score, // init_score
init_score, // init_score init_idx, // init_idx
init_idx, // init_idx trg_bias, // trg_bias
trg_bias, // trg_bias i * batch_size,
i * batch_size, batch_size,
batch_size, bos_idx);
bos_idx);
LOG(INFO) << "src_word:" << src_word->dims();
auto start_ii = GetCurrentUS();
LOG(INFO) << i << "->ii:" << (start_ii - start_i) / 1000.0;
predictor.Run(); predictor.Run();
auto start_iii = GetCurrentUS(); auto* outs = predictor.GetOutput(0);
LOG(INFO) << i << "->iii:" << (start_iii - start_ii) / 1000.0; auto o_data = outs->data<int64_t>();
auto* outs = predictor.GetOutputs(); auto lod = outs->lod();
LOG(INFO) << "out:" << (*outs)[0].dims(); for (int i = 0; i < outs->numel(); ++i) {
LOG(INFO) << o_data[i];
}
for (int i = 0; i < lod.size(); ++i) {
for (int j = 0; j < lod[i].size(); ++j) {
LOG(INFO) << lod[i][j];
}
}
} }
LOG(INFO) << "================== Speed Report ==================="; LOG(INFO) << "================== Speed Report ===================";
...@@ -234,25 +242,18 @@ void TestModel(const std::vector<Place>& valid_places, ...@@ -234,25 +242,18 @@ void TestModel(const std::vector<Place>& valid_places,
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats << ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0 << ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average."; << " ms in average.";
auto* outs = predictor.GetOutputs();
for (auto out : *outs) {
LOG(INFO) << "======"
<< "here";
LOG(INFO) << out;
}
LOG(INFO) << "======"
<< "hereggg";
} }
TEST(OcrAttention, test_arm) { } // namespace lite
} // namespace paddle
using namespace paddle::lite; // NOLINT
int main(int argc, char** argv) {
gflags::ParseCommandLineFlags(&argc, &argv, true);
std::vector<Place> valid_places({ std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kInt64)},
Place{TARGET(kARM), PRECISION(kFloat)}, Place{TARGET(kARM), PRECISION(kFloat)},
Place{TARGET(kHost), PRECISION(kFloat)},
}); });
TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)})); TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
} }
} // namespace lite
} // namespace paddle
...@@ -70,7 +70,7 @@ void PruneEndBeams(const Tensor *pre_ids, ...@@ -70,7 +70,7 @@ void PruneEndBeams(const Tensor *pre_ids,
std::vector<std::vector<Item>> *items, std::vector<std::vector<Item>> *items,
size_t lod_level, size_t lod_level,
int end_id) { int end_id) {
auto *pre_ids_data = pre_ids->data<float>(); auto *pre_ids_data = pre_ids->data<int64_t>();
auto &high_level = abs_lod[lod_level]; auto &high_level = abs_lod[lod_level];
for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) { for (size_t src_idx = 0; src_idx < high_level.size() - 1; ++src_idx) {
size_t src_prefix_start = high_level[src_idx]; size_t src_prefix_start = high_level[src_idx];
...@@ -152,10 +152,10 @@ std::vector<std::vector<Item>> SelectTopBeamSizeItems(const Tensor *pre_ids, ...@@ -152,10 +152,10 @@ std::vector<std::vector<Item>> SelectTopBeamSizeItems(const Tensor *pre_ids,
// find the current candidates // find the current candidates
// auto abs_lod = framework::ToAbsOffset(scores->lod()); // auto abs_lod = framework::ToAbsOffset(scores->lod());
auto abs_lod = scores->lod(); auto abs_lod = scores->lod();
auto *pre_ids_data = pre_ids->data<float>(); auto *pre_ids_data = pre_ids->data<int64_t>();
auto *pre_scores_data = pre_scores->data<float>(); auto *pre_scores_data = pre_scores->data<float>();
auto *ids_data = ids ? ids->data<int>() : nullptr; auto *ids_data = ids ? ids->data<int64_t>() : nullptr;
auto *scores_data = scores->data<float>(); auto *scores_data = scores->data<float>();
size_t num_seqs = abs_lod[lod_level].size() - 1; size_t num_seqs = abs_lod[lod_level].size() - 1;
...@@ -236,7 +236,7 @@ void beam_search(const Tensor *pre_ids, ...@@ -236,7 +236,7 @@ void beam_search(const Tensor *pre_ids,
if (parent_idx) { if (parent_idx) {
parent_idx->Resize(dims); parent_idx->Resize(dims);
} }
auto *selected_ids_data = selected_ids->mutable_data<float>(); auto *selected_ids_data = selected_ids->mutable_data<int64_t>();
auto *selected_scores_data = selected_scores->mutable_data<float>(); auto *selected_scores_data = selected_scores->mutable_data<float>();
auto *parent_idx_data = auto *parent_idx_data =
parent_idx ? parent_idx->mutable_data<int>() : nullptr; parent_idx ? parent_idx->mutable_data<int>() : nullptr;
......
...@@ -13,11 +13,161 @@ ...@@ -13,11 +13,161 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include <algorithm>
#include <string>
#include <vector>
#include "lite/operators/op_params.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
// Element-wise "add" / "mul" of x and y with per-dimension broadcasting.
//
// x_real_dim, y_real_dim and out_real_dim must all have the same rank (they
// are indexed with the rank of out_real_dim). A dimension whose extent in
// x_real_dim / y_real_dim is <= 1 is treated as broadcast: it contributes
// nothing to the flat index into x_data / y_data.
//
// x_data, y_data : input buffers laid out row-major according to their dims.
// out_data       : output buffer holding prod(out_real_dim) elements.
// type           : "add" or "mul"; any other value leaves out_data untouched.
// is_xsize_large : accepted for interface compatibility; not used here.
template <typename T>
void elementwise_broadcast_common(T const* x_data,
                                  T const* y_data,
                                  T* out_data,
                                  std::vector<int64_t> x_real_dim,
                                  std::vector<int64_t> y_real_dim,
                                  std::vector<int64_t> out_real_dim,
                                  std::string type,
                                  bool is_xsize_large = false) {
  (void)is_xsize_large;

  const int max_dim = static_cast<int>(out_real_dim.size());
  int64_t out_size = 1;
  for (int i = 0; i < max_dim; ++i) {
    out_size *= out_real_dim[i];
  }

  // Resolve the operation once instead of comparing strings per element.
  const bool is_add = (type == "add");
  const bool is_mul = (type == "mul");

  // Multi-dimensional coordinate of the current output element.
  std::vector<int64_t> index_array(max_dim, 0);

  for (int64_t out_index = 0; out_index < out_size; ++out_index) {
    // Fold the coordinate into flat offsets, skipping broadcast dimensions.
    int64_t x_index = 0;
    int64_t y_index = 0;
    for (int i = 0; i < max_dim; ++i) {
      if (x_real_dim[i] > 1) {
        x_index = x_index * x_real_dim[i] + index_array[i];
      }
      if (y_real_dim[i] > 1) {
        y_index = y_index * y_real_dim[i] + index_array[i];
      }
    }

    if (is_add) {
      out_data[out_index] = x_data[x_index] + y_data[y_index];
    } else if (is_mul) {
      out_data[out_index] = x_data[x_index] * y_data[y_index];
    }

    // Advance the coordinate like an odometer (last dimension fastest).
    // This must happen once per output element; doing it after the loop
    // would leave every element reading x_data[0] and y_data[0].
    for (int i = max_dim - 1; i >= 0; --i) {
      ++index_array[i];
      if (index_array[i] >= out_real_dim[i]) {
        index_array[i] -= out_real_dim[i];
      } else {
        break;
      }
    }
  }
}
// Reference (non-vectorised) element-wise kernel with axis broadcasting.
//
// X is viewed as [batch, channels, num], where
//   batch    = product of X dims before `axis`,
//   channels = product of all Y dims,
//   num      = product of X dims after the span covered by Y.
// Y supplies one value per channel, which is combined with every element of
// the matching X slice. A negative param.axis is replaced by
// x_dims.size() - y_dims.size().
//
// param    : supplies X, Y, Out and axis.
// elt_type : "add", "sub", "mul" or "max"; anything else is LOG(FATAL).
// act_type : "" for no activation or "relu"; anything else is LOG(FATAL).
template <typename dtype>
void elementwise_compute_basic(const operators::ElementwiseParam& param,
                               const std::string elt_type,
                               const std::string act_type) {
  const dtype* x_data = param.X->data<const dtype>();
  const dtype* y_data = param.Y->data<const dtype>();
  dtype* out_data = param.Out->mutable_data<dtype>();
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();

  const int x_rank = static_cast<int>(x_dims.size());
  const int y_rank = static_cast<int>(y_dims.size());
  int axis = param.axis;
  if (axis < 0) {
    axis = x_rank - y_rank;
  }

  int batch = 1;
  int channels = 1;
  int num = 1;
  for (int i = 0; i < axis; ++i) {
    batch *= x_dims[i];
  }
  for (int i = 0; i < y_rank; ++i) {
    channels *= y_dims[i];
  }
  for (int i = y_rank + axis; i < x_rank; ++i) {
    num *= x_dims[i];
  }

  // do elementwise add/sub/max...
  // Every branch indexes X and Out with the same offset, so the input and
  // output positions cannot drift apart.
  if (elt_type == "add") {
    for (int i = 0; i < batch; ++i) {
      for (int j = 0; j < channels; ++j) {
        const int offset = (i * channels + j) * num;
        const dtype y_val = y_data[j];
        for (int k = 0; k < num; ++k) {
          out_data[offset + k] = x_data[offset + k] + y_val;
        }
      }
    }
  } else if (elt_type == "sub") {
    for (int i = 0; i < batch; ++i) {
      for (int j = 0; j < channels; ++j) {
        const int offset = (i * channels + j) * num;
        const dtype y_val = y_data[j];
        for (int k = 0; k < num; ++k) {
          out_data[offset + k] = x_data[offset + k] - y_val;
        }
      }
    }
  } else if (elt_type == "mul") {
    for (int i = 0; i < batch; ++i) {
      for (int j = 0; j < channels; ++j) {
        const int offset = (i * channels + j) * num;
        const dtype y_val = y_data[j];
        for (int k = 0; k < num; ++k) {
          out_data[offset + k] = x_data[offset + k] * y_val;
        }
      }
    }
  } else if (elt_type == "max") {
    for (int i = 0; i < batch; ++i) {
      for (int j = 0; j < channels; ++j) {
        const int offset = (i * channels + j) * num;
        const dtype y_val = y_data[j];
        for (int k = 0; k < num; ++k) {
          out_data[offset + k] = std::max(x_data[offset + k], y_val);
        }
      }
    }
  } else {
    LOG(FATAL) << "unsupported Elementwise type: " << elt_type;
  }

  // do activation relu/sigmod...
  if (act_type.size() > 0) {
    if (act_type == "relu") {
      // Compare in dtype rather than against a float literal so integer
      // outputs are not converted through float.
      const dtype zero = static_cast<dtype>(0);
      const int total = batch * channels * num;
      for (int idx = 0; idx < total; ++idx) {
        out_data[idx] = out_data[idx] > zero ? out_data[idx] : zero;
      }
    } else {
      LOG(FATAL) << "unsupported Activation type: " << act_type;
    }
  }
}
template <typename T> template <typename T>
void elementwise_add(const T* dinx, const T* diny, T* dout, int num); void elementwise_add(const T* dinx, const T* diny, T* dout, int num);
......
...@@ -20,18 +20,7 @@ ...@@ -20,18 +20,7 @@
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {} // namespace math
void increment(const float* input,
const int n,
const float step,
float* out,
Context<TARGET(kARM)>* ctx) {
for (int i = 0; i < n; i++) {
out[i] = input[i] + step;
}
}
} // namespace math
} // namespace arm } // namespace arm
} // namespace lite } // namespace lite
} // namespace paddle } // namespace paddle
...@@ -21,11 +21,16 @@ namespace paddle { ...@@ -21,11 +21,16 @@ namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
void increment(const float* input, template <typename T>
void increment(const T* input,
const int n, const int n,
const float step, const float step,
float* out, T* out,
Context<TARGET(kARM)>* ctx); Context<TARGET(kARM)>* ctx) {
for (int i = 0; i < n; i++) {
out[i] = input[i] + static_cast<T>(step);
}
}
} // namespace math } // namespace math
} // namespace arm } // namespace arm
......
...@@ -13,12 +13,30 @@ ...@@ -13,12 +13,30 @@
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#include "lite/core/tensor.h"
#include "lite/operators/op_params.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace arm { namespace arm {
namespace math { namespace math {
// Reference scale kernel: output[i] = x[i] * scale + bias for every element
// of param.output.
//
// When param.bias_after_scale is false the bias is multiplied by scale first,
// so the result is scale * (x + bias). The arithmetic is carried out with the
// float scale/bias taken from param and the result is cast back to dtype.
template <typename dtype>
void scale_compute_basic(const operators::ScaleParam& param) {
  const dtype* x_data = param.x->data<dtype>();
  dtype* output_data = param.output->mutable_data<dtype>();

  const float scale = param.scale;
  float bias = param.bias;
  if (!param.bias_after_scale) {
    bias *= scale;
  }

  // Element count is loop-invariant; read it once and index with 64 bits.
  const int64_t num = static_cast<int64_t>(param.output->dims().production());
  for (int64_t i = 0; i < num; ++i) {
    output_data[i] = static_cast<dtype>(x_data[i] * scale + bias);
  }
}
template <typename T> template <typename T>
void scale(const T* din, T* dout, int num, T scale, T bias); void scale(const T* din, T* dout, int num, T scale, T bias);
......
...@@ -26,7 +26,7 @@ bool comp_func(std::pair<float, int> a, std::pair<float, int> b) { ...@@ -26,7 +26,7 @@ bool comp_func(std::pair<float, int> a, std::pair<float, int> b) {
void topk(const float* in_data, void topk(const float* in_data,
float* out_val, float* out_val,
int* out_ind, int64_t* out_ind,
int m, int m,
int n, int n,
int k, int k,
...@@ -34,7 +34,7 @@ void topk(const float* in_data, ...@@ -34,7 +34,7 @@ void topk(const float* in_data,
for (int i = 0; i < m; i++) { for (int i = 0; i < m; i++) {
const float* in_tmp = in_data + i * n; const float* in_tmp = in_data + i * n;
float* out_val_tmp = out_val + i * k; float* out_val_tmp = out_val + i * k;
int* out_ind_tmp = out_ind + i * k; int64_t* out_ind_tmp = out_ind + i * k;
std::vector<std::pair<float, int>> vec; std::vector<std::pair<float, int>> vec;
for (int j = 0; j < n; j++) { for (int j = 0; j < n; j++) {
vec.push_back(std::make_pair(in_tmp[j], j)); vec.push_back(std::make_pair(in_tmp[j], j));
......
...@@ -22,7 +22,7 @@ namespace math { ...@@ -22,7 +22,7 @@ namespace math {
void topk(const float* din, void topk(const float* din,
float* out_val, float* out_val,
int* out_ind, int64_t* out_ind,
int m, int m,
int n, int n,
int k, int k,
......
...@@ -158,6 +158,7 @@ KernelRegistry::KernelRegistry() ...@@ -158,6 +158,7 @@ KernelRegistry::KernelRegistry()
INIT_FOR(kARM, kAny, kNCHW); INIT_FOR(kARM, kAny, kNCHW);
INIT_FOR(kARM, kAny, kAny); INIT_FOR(kARM, kAny, kAny);
INIT_FOR(kARM, kInt32, kNCHW); INIT_FOR(kARM, kInt32, kNCHW);
INIT_FOR(kARM, kInt64, kNCHW);
INIT_FOR(kOpenCL, kFloat, kNCHW); INIT_FOR(kOpenCL, kFloat, kNCHW);
INIT_FOR(kOpenCL, kFloat, kNHWC); INIT_FOR(kOpenCL, kFloat, kNHWC);
......
...@@ -147,6 +147,9 @@ class KernelRegistry final { ...@@ -147,6 +147,9 @@ class KernelRegistry final {
KernelRegistryForTarget<TARGET(kARM), KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt8), PRECISION(kInt8),
DATALAYOUT(kNCHW)> *, // DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt64),
DATALAYOUT(kNCHW)> *, //
KernelRegistryForTarget<TARGET(kARM), KernelRegistryForTarget<TARGET(kARM),
PRECISION(kInt32), PRECISION(kInt32),
DATALAYOUT(kNCHW)> *, // DATALAYOUT(kNCHW)> *, //
......
...@@ -82,7 +82,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) { ...@@ -82,7 +82,7 @@ void TensorLite::CopyDataFrom(const TensorLite &other) {
target_ = other.target_; target_ = other.target_;
lod_ = other.lod_; lod_ = other.lod_;
memory_size_ = other.memory_size_; memory_size_ = other.memory_size_;
precision_ = other.precision_; precision_ = other.precision();
buffer_->CopyDataFrom(*other.buffer_, memory_size_); buffer_->CopyDataFrom(*other.buffer_, memory_size_);
} }
......
...@@ -38,7 +38,7 @@ const size_t kSentenceLevel = 1; ...@@ -38,7 +38,7 @@ const size_t kSentenceLevel = 1;
template <typename T> template <typename T>
struct Sentence { struct Sentence {
std::vector<float> word_ids; std::vector<int64_t> word_ids;
std::vector<T> scores; std::vector<T> scores;
}; };
...@@ -73,7 +73,7 @@ struct BeamSearchDecoder { ...@@ -73,7 +73,7 @@ struct BeamSearchDecoder {
std::vector<uint64_t> source_level_lod = {0}; std::vector<uint64_t> source_level_lod = {0};
std::vector<uint64_t> sentence_level_lod = {0}; std::vector<uint64_t> sentence_level_lod = {0};
std::vector<float> id_data; std::vector<int64_t> id_data;
std::vector<T> score_data; std::vector<T> score_data;
for (size_t src_idx = 0; src_idx < src_num; ++src_idx) { for (size_t src_idx = 0; src_idx < src_num; ++src_idx) {
...@@ -117,9 +117,9 @@ struct BeamSearchDecoder { ...@@ -117,9 +117,9 @@ struct BeamSearchDecoder {
*(id_tensor->mutable_lod()) = lod; *(id_tensor->mutable_lod()) = lod;
id_tensor->Resize({static_cast<int64_t>(id_data.size())}); id_tensor->Resize({static_cast<int64_t>(id_data.size())});
auto id_ptr = id_tensor->mutable_data<float>(); auto id_ptr = id_tensor->mutable_data<int64_t>();
TargetCopy( TargetCopy(
TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(float)); TARGET(kARM), id_ptr, id_data.data(), id_data.size() * sizeof(int64_t));
*(score_tensor->mutable_lod()) = lod; *(score_tensor->mutable_lod()) = lod;
score_tensor->Resize({static_cast<int64_t>(score_data.size())}); score_tensor->Resize({static_cast<int64_t>(score_data.size())});
...@@ -169,7 +169,7 @@ struct BeamSearchDecoder { ...@@ -169,7 +169,7 @@ struct BeamSearchDecoder {
++candidate_idx) { ++candidate_idx) {
prefix_idx_vector.push_back(prefix_idx); prefix_idx_vector.push_back(prefix_idx);
size_t idx = prefix_idx_vector.size() - 1; size_t idx = prefix_idx_vector.size() - 1;
auto cur_id = cur_ids.data<float>()[candidate_idx]; auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
auto cur_score = cur_scores.data<T>()[candidate_idx]; auto cur_score = cur_scores.data<T>()[candidate_idx];
sentence_vector.at(idx).word_ids.push_back(cur_id); sentence_vector.at(idx).word_ids.push_back(cur_id);
sentence_vector.at(idx).scores.push_back(cur_score); sentence_vector.at(idx).scores.push_back(cur_score);
...@@ -184,7 +184,7 @@ struct BeamSearchDecoder { ...@@ -184,7 +184,7 @@ struct BeamSearchDecoder {
cur_ids.lod().at(kSentenceLevel)[prefix_idx]; cur_ids.lod().at(kSentenceLevel)[prefix_idx];
for (size_t idx = 0; idx < prefix_idx_vector.size(); ++idx) { for (size_t idx = 0; idx < prefix_idx_vector.size(); ++idx) {
auto candidate_idx = prefix_idx_vector.at(idx); auto candidate_idx = prefix_idx_vector.at(idx);
auto cur_id = cur_ids.data<float>()[candidate_idx]; auto cur_id = cur_ids.data<int64_t>()[candidate_idx];
auto cur_score = cur_scores.data<T>()[candidate_idx]; auto cur_score = cur_scores.data<T>()[candidate_idx];
if (cur_id != end_id_ || sentence_vector.at(idx).word_ids.empty()) { if (cur_id != end_id_ || sentence_vector.at(idx).word_ids.empty()) {
// to skip redundant end tokens // to skip redundant end tokens
......
...@@ -148,6 +148,42 @@ void CompareCompute_int32<Functor>::Run() { ...@@ -148,6 +148,42 @@ void CompareCompute_int32<Functor>::Run() {
} }
} }
// Applies Functor<int64_t> element-wise to X and Y and writes the bool
// results to Out.
//
// If X and Y have the same number of elements they are compared
// position-by-position. Otherwise X is viewed as [outer, mid, inner] using
// get_mid_dims(), and each Y value (one per `mid` index) is compared with
// the matching X slice. An axis of -1 is replaced by
// x_dims.size() - y_dims.size().
template <template <typename T> class Functor>
void CompareCompute_int64<Functor>::Run() {
  auto &param = this->Param<operators::CompareParam>();
  using CompareFunctor = Functor<int64_t>;

  const size_t x_size = param.X->numel();
  const size_t y_size = param.Y->numel();
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  bool *z = param.Out->template mutable_data<bool>();
  const auto *x = param.X->template data<int64_t>();
  const auto *y = param.Y->template data<int64_t>();

  if (x_size == y_size) {
    for (size_t i = 0; i < x_size; ++i) {
      z[i] = CompareFunctor()(x[i], y[i]);
    }
    return;
  }

  const int axis =
      (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis);
  int outer_num, mid_num, inner_num;
  get_mid_dims(x_dims, y_dims, axis, &outer_num, &mid_num, &inner_num);
  for (int outer_id = 0; outer_id < outer_num; ++outer_id) {
    for (int mid_id = 0; mid_id < mid_num; ++mid_id) {
      const auto y_data = y[mid_id];
      for (int inner_id = 0; inner_id < inner_num; ++inner_id) {
        const int index = (outer_id * mid_num + mid_id) * inner_num + inner_id;
        z[index] = CompareFunctor()(x[index], y_data);
      }
    }
  }
}
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
...@@ -164,6 +200,17 @@ REGISTER_LITE_KERNEL(less_than, ...@@ -164,6 +200,17 @@ REGISTER_LITE_KERNEL(less_than,
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize(); .Finalize();
// Register the int64 ARM variant of less_than (kernel alias `def`), backed by
// CompareCompute_int64<_LessThanFunctor>. X and Y are bound as int64 ARM
// tensors and Out as a bool ARM tensor.
REGISTER_LITE_KERNEL(less_than,
kARM,
kInt64,
kNCHW,
paddle::lite::kernels::arm::CompareCompute_int64<
paddle::lite::kernels::arm::_LessThanFunctor>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
.Finalize();
REGISTER_LITE_KERNEL(equal, REGISTER_LITE_KERNEL(equal,
kARM, kARM,
kFloat, kFloat,
......
...@@ -46,6 +46,17 @@ class CompareCompute_int32 ...@@ -46,6 +46,17 @@ class CompareCompute_int32
~CompareCompute_int32() {} ~CompareCompute_int32() {}
}; };
// Comparison kernel declared for TARGET(kARM) with PRECISION(kInt64).
// `Functor` is a class template supplied at registration time (for example
// _LessThanFunctor in the less_than registration). Only the declaration of
// Run() lives here; its body is defined outside the class.
template <template <typename> class Functor>
class CompareCompute_int64
    : public KernelLite<TARGET(kARM), PRECISION(kInt64)> {
 public:
  // Parameter struct type associated with this kernel.
  using param_t = operators::LogicalParam;

  ~CompareCompute_int64() {}

  // Kernel entry point, overriding the base-class Run().
  void Run() override;
};
} // namespace arm } // namespace arm
} // namespace kernels } // namespace kernels
} // namespace lite } // namespace lite
......
...@@ -80,7 +80,11 @@ void ElementwiseAddCompute::Run() { ...@@ -80,7 +80,11 @@ void ElementwiseAddCompute::Run() {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_add_broadcast(
y_data, x_data, out_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_add_broadcast( lite::arm::math::elementwise_add_broadcast(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
} else { } else {
...@@ -99,7 +103,15 @@ void ElementwiseAddActivationCompute::Run() { ...@@ -99,7 +103,15 @@ void ElementwiseAddActivationCompute::Run() {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") {
lite::arm::math::elementwise_add_relu_broadcast(
y_data, x_data, out_data, pre, n, post);
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_add_relu_broadcast( lite::arm::math::elementwise_add_relu_broadcast(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
...@@ -125,6 +137,9 @@ void ElementwiseSubCompute::Run() { ...@@ -125,6 +137,9 @@ void ElementwiseSubCompute::Run() {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
// Unlike the add/mul/max kernels, which handle a higher-rank Y by swapping
// the operands passed to the broadcast routine, subtraction is not
// commutative, so that case is rejected outright. The message names the
// subtraction op (it previously said "div", copied from the div kernel).
if (x_dims.size() < y_dims.size()) {
  LOG(FATAL) << "elewise sub don't support x_dims size < y_dims size";
}
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_sub_broadcast( lite::arm::math::elementwise_sub_broadcast(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
...@@ -143,6 +158,9 @@ void ElementwiseSubActivationCompute::Run() { ...@@ -143,6 +158,9 @@ void ElementwiseSubActivationCompute::Run() {
std::string act_type = param.act_type; std::string act_type = param.act_type;
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
// Subtraction is not commutative, so a Y tensor of higher rank than X cannot
// be handled by swapping the operands; abort instead. The message names the
// subtraction op (it previously said "div", copied from the div kernel).
if (x_dims.size() < y_dims.size()) {
  LOG(FATAL) << "elewise sub don't support x_dims size < y_dims size";
}
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
...@@ -164,19 +182,29 @@ void ElementwiseSubActivationCompute::Run() { ...@@ -164,19 +182,29 @@ void ElementwiseSubActivationCompute::Run() {
template <typename T, PrecisionType PType> template <typename T, PrecisionType PType>
void ElementwiseMulCompute<T, PType>::Run() { void ElementwiseMulCompute<T, PType>::Run() {
auto& param = this->template Param<operators::ElementwiseParam>(); auto& param = this->template Param<operators::ElementwiseParam>();
auto* x_data = param.X->template data<T>(); if (param.X->precision() == PRECISION(kFloat)) {
auto* y_data = param.Y->template data<T>(); auto* x_data = param.X->template data<float>();
auto* out_data = param.Out->template mutable_data<T>(); auto* y_data = param.Y->template data<float>();
int axis = param.axis; auto* out_data = param.Out->template mutable_data<float>();
auto x_dims = param.X->dims(); int axis = param.axis;
auto y_dims = param.Y->dims(); auto x_dims = param.X->dims();
int pre, n, post; auto y_dims = param.Y->dims();
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { int pre, n, post;
lite::arm::math::elementwise_mul_broadcast<T>( if (x_dims.size() < y_dims.size() &&
x_data, y_data, out_data, pre, n, post); is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_mul_broadcast<float>(
y_data, x_data, out_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_mul_broadcast<float>(
x_data, y_data, out_data, pre, n, post);
} else {
lite::arm::math::elementwise_mul<float>(
x_data, y_data, out_data, x_dims.production());
}
} else if (param.X->precision() == PRECISION(kInt64)) {
lite::arm::math::elementwise_compute_basic<int64_t>(param, "mul", "");
} else { } else {
lite::arm::math::elementwise_mul<T>( LOG(FATAL) << "unsupport input type";
x_data, y_data, out_data, x_dims.production());
} }
} }
...@@ -190,7 +218,15 @@ void ElementwiseMulActivationCompute::Run() { ...@@ -190,7 +218,15 @@ void ElementwiseMulActivationCompute::Run() {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") {
lite::arm::math::elementwise_mul_relu_broadcast<float>(
y_data, x_data, out_data, pre, n, post);
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_mul_relu_broadcast( lite::arm::math::elementwise_mul_relu_broadcast(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
...@@ -216,7 +252,11 @@ void ElementwiseMaxCompute::Run() { ...@@ -216,7 +252,11 @@ void ElementwiseMaxCompute::Run() {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_max_broadcast(
y_data, x_data, out_data, pre, n, post);
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_max_broadcast( lite::arm::math::elementwise_max_broadcast(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
} else { } else {
...@@ -235,7 +275,15 @@ void ElementwiseMaxActivationCompute::Run() { ...@@ -235,7 +275,15 @@ void ElementwiseMaxActivationCompute::Run() {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (x_dims.size() < y_dims.size() &&
is_broadcast(y_dims, x_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") {
lite::arm::math::elementwise_max_relu_broadcast<float>(
y_data, x_data, out_data, pre, n, post);
} else {
LOG(FATAL) << "unsupported Activation type: " << act_type;
}
} else if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
lite::arm::math::elementwise_max_relu_broadcast( lite::arm::math::elementwise_max_relu_broadcast(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
...@@ -261,6 +309,9 @@ void ElementwiseDivCompute::Run() { ...@@ -261,6 +309,9 @@ void ElementwiseDivCompute::Run() {
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
int pre, n, post; int pre, n, post;
if (x_dims.size() < y_dims.size()) {
LOG(FATAL) << "elewise div don't support x_dims size < y_dims size";
}
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
lite::arm::math::elementwise_div_broadcast( lite::arm::math::elementwise_div_broadcast(
x_data, y_data, out_data, pre, n, post); x_data, y_data, out_data, pre, n, post);
...@@ -279,6 +330,9 @@ void ElementwiseDivActivationCompute::Run() { ...@@ -279,6 +330,9 @@ void ElementwiseDivActivationCompute::Run() {
std::string act_type = param.act_type; std::string act_type = param.act_type;
auto x_dims = param.X->dims(); auto x_dims = param.X->dims();
auto y_dims = param.Y->dims(); auto y_dims = param.Y->dims();
if (x_dims.size() < y_dims.size()) {
LOG(FATAL) << "elewise div don't support x_dims size < y_dims size";
}
int pre, n, post; int pre, n, post;
if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) { if (is_broadcast(x_dims, y_dims, axis, &pre, &n, &post)) {
if (act_type == "relu") { if (act_type == "relu") {
......
...@@ -39,6 +39,12 @@ void FillConstantBatchSizeLikeCompute::Run() { ...@@ -39,6 +39,12 @@ void FillConstantBatchSizeLikeCompute::Run() {
for (int i = 0; i < param.out->numel(); i++) { for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value; data[i] = param.value;
} }
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT64)) {
auto data = param.out->template mutable_data<int64_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else { } else {
LOG(FATAL) << "not supported dtype " << param.dtype; LOG(FATAL) << "not supported dtype " << param.dtype;
} }
......
...@@ -39,6 +39,12 @@ void FillConstantCompute::Run() { ...@@ -39,6 +39,12 @@ void FillConstantCompute::Run() {
for (int i = 0; i < param.out->numel(); i++) { for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value; data[i] = param.value;
} }
} else if (param.dtype ==
static_cast<int32_t>(lite::core::FluidType::INT64)) {
auto data = param.out->template mutable_data<int64_t>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
} else { } else {
LOG(FATAL) << "not supported dtype " << param.dtype; LOG(FATAL) << "not supported dtype " << param.dtype;
} }
......
...@@ -27,10 +27,22 @@ void IncrementCompute::Run() { ...@@ -27,10 +27,22 @@ void IncrementCompute::Run() {
auto& param = this->Param<operators::IncrementParam>(); auto& param = this->Param<operators::IncrementParam>();
int total_num = param.X->dims().production(); int total_num = param.X->dims().production();
if (param.X->precision() == PRECISION(kFloat)) {
const auto* x_data = param.X->data<float>(); const auto* x_data = param.X->data<float>();
auto* o_data = param.Out->mutable_data<float>(); auto* o_data = param.Out->mutable_data<float>();
lite::arm::math::increment(x_data, total_num, param.step, o_data, &ctx); lite::arm::math::increment(x_data, total_num, param.step, o_data, &ctx);
} else if (param.X->precision() == PRECISION(kInt64)) {
const auto* x_data = param.X->data<int64_t>();
auto* o_data = param.Out->mutable_data<int64_t>();
lite::arm::math::increment(x_data, total_num, param.step, o_data, &ctx);
} else if (param.X->precision() == PRECISION(kInt32)) {
const auto* x_data = param.X->data<int32_t>();
auto* o_data = param.Out->mutable_data<int32_t>();
lite::arm::math::increment(x_data, total_num, param.step, o_data, &ctx);
} else {
LOG(FATAL) << "unsupport input type "
<< PrecisionToStr(param.X->precision());
}
} }
} // namespace arm } // namespace arm
......
...@@ -72,7 +72,7 @@ REGISTER_LITE_KERNEL(lookup_table, ...@@ -72,7 +72,7 @@ REGISTER_LITE_KERNEL(lookup_table,
paddle::lite::kernels::arm::LookupTableCompute, paddle::lite::kernels::arm::LookupTableCompute,
def) def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -83,6 +83,6 @@ REGISTER_LITE_KERNEL(lookup_table_v2, ...@@ -83,6 +83,6 @@ REGISTER_LITE_KERNEL(lookup_table_v2,
paddle::lite::kernels::arm::LookupTableCompute, paddle::lite::kernels::arm::LookupTableCompute,
def) def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))}) .BindInput("W", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kAny))}) .BindInput("Ids", {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt64))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))}) .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize(); .Finalize();
...@@ -25,7 +25,7 @@ void TopkCompute::Run() { ...@@ -25,7 +25,7 @@ void TopkCompute::Run() {
auto& param = Param<operators::TopkParam>(); auto& param = Param<operators::TopkParam>();
const float* x_data = param.X->data<float>(); const float* x_data = param.X->data<float>();
float* out_val = param.Out->mutable_data<float>(); float* out_val = param.Out->mutable_data<float>();
int* out_ind = param.Indices->mutable_data<int>(); auto out_ind = param.Indices->mutable_data<int64_t>();
DDim x_dims = param.X->dims(); DDim x_dims = param.X->dims();
int K = param.K; int K = param.K;
int dim_size = x_dims.size(); int dim_size = x_dims.size();
......
...@@ -13,8 +13,9 @@ ...@@ -13,8 +13,9 @@
// limitations under the License. // limitations under the License.
#include "lite/operators/elementwise_ops.h" #include "lite/operators/elementwise_ops.h"
#include <algorithm>
#include <cmath>
#include "lite/core/op_registry.h" #include "lite/core/op_registry.h"
namespace paddle { namespace paddle {
namespace lite { namespace lite {
namespace operators { namespace operators {
...@@ -27,10 +28,59 @@ bool ElementwiseOp::CheckShape() const { ...@@ -27,10 +28,59 @@ bool ElementwiseOp::CheckShape() const {
} }
bool ElementwiseOp::InferShape() const { bool ElementwiseOp::InferShape() const {
CHECK_OR_FALSE(param_.X->dims().size() >= param_.Y->dims().size()); auto x_dim = param_.X->dims();
param_.Out->Resize(param_.X->dims()); auto y_dim = param_.Y->dims();
auto out_lod = param_.Out->mutable_lod(); if (x_dim == y_dim) {
*out_lod = param_.X->lod(); param_.Out->Resize(x_dim);
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
} else {
int max_dim = (x_dim.size() > y_dim.size() ? x_dim.size() : y_dim.size());
int axis = param_.axis;
axis = (axis == -1 ? std::abs(static_cast<int>(x_dim.size() - y_dim.size()))
: axis);
std::vector<int64_t> x_dims_array(max_dim);
std::vector<int64_t> y_dims_array(max_dim);
std::vector<int64_t> out_dims_array(max_dim);
if (x_dim.size() > y_dim.size()) {
for (int i = 0; i < axis; ++i) {
y_dims_array[i] = 1;
}
if (axis + y_dim.size() < max_dim) {
for (int i = axis + y_dim.size(); i < max_dim; ++i) {
y_dims_array[i] = 1;
}
}
x_dims_array = x_dim.Vectorize();
for (int i = 0; i < y_dim.size(); ++i) {
y_dims_array[i + axis] = y_dim[i];
}
} else {
for (int i = 0; i < axis; ++i) {
x_dims_array[i] = 1;
}
if (axis + x_dim.size() < max_dim) {
for (int i = axis + x_dim.size(); i < max_dim; ++i) {
x_dims_array[i] = 1;
}
}
y_dims_array = y_dim.Vectorize();
for (int i = 0; i < x_dim.size(); ++i) {
x_dims_array[i + axis] = x_dim[i];
}
}
for (int i = 0; i < max_dim; i++) {
if (x_dims_array[i] == -1 || y_dims_array[i] == -1) {
out_dims_array[i] = -1;
} else {
out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
}
}
param_.Out->Resize(DDim(out_dims_array));
auto out_lod = param_.Out->mutable_lod();
*out_lod = param_.X->lod();
}
return true; return true;
} }
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册