Unverified · Commit f0a6c1eb authored by TianXiaogang and committed by GitHub

Transformer PR (#2214)

* feat: add beam_search_special function to support NLP models

* fix: add beam_search_compute kernel input and output

* feat: add assign op & copy_compute kernel

* feat: add fill_const_batch_size_like op & kernel

* feat: add layer_norm op and kernel and ut

* fix: fix several bugs
    fix mul_op infer_shape bug when x_num_col_dims = 2, x_dims.size() = 3 & y_num_col_dims = 1, y_dims.size() = 2
    fix elementwise_compute bug when the y dims along the axis are all 1
    fix beam_search choosing the wrong math_func
    fix layer_norm attribute-parsing bug
    fix fill_constant_batch_size_like shape-setting bug

* feat: add gather op and kernel & transformer ut

* feat: add ops and fix bugs to support the transformer model
       fix type_cast passes to skip `while`
       fix elementwise infer_shape bug when x.dims = 3 and y.dims = {1} & axis = 0
       fix lookup_table compute bug
       fix read_from_array/beam_search/increment/compare/gather ops data_type problems

* fix:
    transformer ut: add a word-reading interface
    fix copy/gather/norm/layer_norm include path problems

* fix: debug info

* fix: fix input reshape bug

* fix: fix norm bug

* style: style fix & test=develop

* style: fix operators cmakelist

* style: fix operators cmakelist; test=develop

* fix and test=develop

* fix and test=develop

* style: style fix; test=develop
Parent 7c69b6b4
...@@ -21,6 +21,7 @@
USE_LITE_OP(mul);
USE_LITE_OP(matmul);
USE_LITE_OP(fc);
+USE_LITE_OP(assign);
USE_LITE_OP(relu);
USE_LITE_OP(relu6);
USE_LITE_OP(scale);
......
...@@ -28,7 +28,7 @@ namespace lite {
TEST(model, test) {
#ifdef LITE_WITH_ARM
  DeviceInfo::Init();
-  DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
+  DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_NO_BIND, FLAGS_threads);
  lite::Predictor predictor;
  std::vector<Place> valid_places({Place{TARGET(kARM), PRECISION(kFloat)},
                                   Place{TARGET(kARM), PRECISION(kInt8)}});
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gflags/gflags.h>
#include <gtest/gtest.h>
#include <string>
#include <vector>
#include "lite/api/cxx_api.h"
#include "lite/api/paddle_use_kernels.h"
#include "lite/api/paddle_use_ops.h"
#include "lite/api/paddle_use_passes.h"
#include "lite/api/test_helper.h"
#include "lite/core/op_registry.h"
DEFINE_string(input, "", "input_data");
DEFINE_int32(batch, 1, "batch");
namespace paddle {
namespace lite {
namespace test_transformer {
std::vector<std::string> inputed_lines;
void LoadInputLines(const char* filename) {
static const int max_line_buf_size = 100 * 1024 * 1024;
char* line_buffer = (char*)calloc(max_line_buf_size, sizeof(char)); // NOLINT
FILE* input_file = fopen(filename, "r");
while (fgets(line_buffer, max_line_buf_size, input_file)) {
// trim newline at end
char* pos = NULL;
if ((pos = strchr(line_buffer, '\n')) != NULL) {
*pos = 0;
}
inputed_lines.push_back(line_buffer);
}
free(line_buffer);
line_buffer = NULL;
fclose(input_file);
}
void Split2(const std::string& main_str,
std::vector<std::string>& str_list, // NOLINT
const std::string& delimiter) {
size_t pre_pos = 0;
size_t position = 0;
std::string tmp_str;
str_list.clear();
if (main_str.empty()) {
return;
}
while ((position = main_str.find(delimiter, pre_pos)) != std::string::npos) {
tmp_str.assign(main_str, pre_pos, position - pre_pos);
str_list.push_back(tmp_str);
pre_pos = position + 1;
}
tmp_str.assign(main_str, pre_pos, main_str.length() - pre_pos);
if (!tmp_str.empty()) {
str_list.push_back(tmp_str);
}
}
}  // namespace test_transformer
void PadBatchInput(std::vector<std::string>& input_lines, // NOLINT
int pad_idx,
int n_head,
Tensor* src_word,
Tensor* src_pos,
Tensor* src_attn_bias,
Tensor* trg_word,
Tensor* init_scores,
Tensor* init_idx,
Tensor* trg_bias,
int line_start,
int batch_size,
int bos_idx) {
int max_len = 0;
int max_line = input_lines.size();
std::vector<std::vector<std::string>> batch_lines;
for (int i = line_start; i < line_start + batch_size; ++i) {
int i_index = i % max_line;
std::string cur_line = input_lines[i_index];
std::vector<std::string> split_str;
test_transformer::Split2(cur_line, split_str, " ");
batch_lines.push_back(split_str);
max_len = max_len >= split_str.size() ? max_len : split_str.size();
}
src_word->Resize(std::vector<DDim::value_type>({batch_size, max_len, 1}));
src_pos->Resize(std::vector<DDim::value_type>({batch_size, max_len, 1}));
src_attn_bias->Resize(
std::vector<DDim::value_type>({batch_size, n_head, max_len, max_len}));
trg_bias->Resize(
std::vector<DDim::value_type>({batch_size, n_head, 1, max_len}));
float* src_word_data = src_word->mutable_data<float>();
float* src_pos_data = src_pos->mutable_data<float>();
float* src_bias_data = src_attn_bias->mutable_data<float>();
float* trg_bias_data = trg_bias->mutable_data<float>();
for (int i = 0; i < batch_size; ++i) {
std::vector<std::string> cur_words = batch_lines[i];
int fill_len = cur_words.size();
int src_bias_start = i * n_head * max_len * max_len;
int trg_bias_start = i * n_head * max_len;
for (int j = 0; j < fill_len; ++j) {
src_word_data[i * max_len + j] = (atoi(cur_words[j].c_str()));
src_pos_data[i * max_len + j] = j;
src_bias_data[src_bias_start + j] = 0;
trg_bias_data[trg_bias_start + j] = 0;
}
for (int j = fill_len; j < max_len; ++j) {
src_word_data[i * max_len + j] = pad_idx;
src_pos_data[i * max_len + j] = 0;
src_bias_data[src_bias_start + j] = -1000000000;
trg_bias_data[trg_bias_start + j] = -1000000000;
}
for (int j = src_bias_start;
j < src_bias_start + n_head * max_len * max_len;
++j) {
int value_ind = j % max_len + src_bias_start;
src_bias_data[j] = src_bias_data[value_ind];
}
for (int j = trg_bias_start; j < trg_bias_start + n_head * max_len; ++j) {
int value_ind = j % max_len + trg_bias_start;
trg_bias_data[j] = trg_bias_data[value_ind];
}
}
trg_word->Resize(std::vector<DDim::value_type>({batch_size, 1, 1}));
auto* trg_word_data = trg_word->mutable_data<float>();
for (int i = 0; i < batch_size; ++i) {
trg_word_data[i] = bos_idx;
}
init_scores->Resize(std::vector<DDim::value_type>({batch_size, 1}));
init_idx->Resize(std::vector<DDim::value_type>({batch_size}));
float* score_data = init_scores->mutable_data<float>();
float* idx_data = init_idx->mutable_data<float>();
for (int i = 0; i < init_scores->numel(); ++i) {
score_data[i] = 0;
}
std::vector<std::vector<uint64_t>> lod_s;
lod_s.resize(2);
for (int i = 0; i < batch_size; ++i) {
lod_s[0].push_back(i);
lod_s[1].push_back(i);
idx_data[i] = i;
}
lod_s[0].push_back(batch_size);
lod_s[1].push_back(batch_size);
auto score_lod = init_scores->mutable_lod();
*score_lod = lod_s;
auto trg_word_lod = trg_word->mutable_lod();
*trg_word_lod = lod_s;
}
void TestModel(const std::vector<Place>& valid_places,
const Place& preferred_place,
bool use_npu = false) {
DeviceInfo::Init();
DeviceInfo::Global().SetRunMode(lite_api::LITE_POWER_HIGH, FLAGS_threads);
lite::Predictor predictor;
std::string test_data_path = FLAGS_input;
predictor.Build(FLAGS_model_dir, "", "", preferred_place, valid_places);
int n_head = 8;
int batch_size = FLAGS_batch;
int bos_idx = 0;
int eos_idx = 1;
LOG(INFO) << "reading";
test_transformer::LoadInputLines(test_data_path.c_str());
LOG(INFO) << "reading finished";
auto* trg_bias = predictor.GetInput(6);
auto* src_word = predictor.GetInput(0);
auto* src_pos = predictor.GetInput(1);
auto* src_bias = predictor.GetInput(2);
auto* trg_word = predictor.GetInput(3);
auto* init_score = predictor.GetInput(4);
auto* init_idx = predictor.GetInput(5);
for (int i = 0; i < FLAGS_warmup; ++i) {
predictor.Run();
}
auto start = GetCurrentUS();
for (int i = 0; i < FLAGS_repeats; ++i) {
auto start_i = GetCurrentUS();
PadBatchInput(test_transformer::inputed_lines,
eos_idx,
n_head,
src_word, // src_word
src_pos, // src_pos
src_bias, // src_bias
trg_word, // trg_word
init_score, // init_score
init_idx, // init_idx
trg_bias, // trg_bias
i * batch_size,
batch_size,
bos_idx);
LOG(INFO) << "src_word:" << src_word->dims();
auto start_ii = GetCurrentUS();
LOG(INFO) << i << "->ii:" << (start_ii - start_i) / 1000.0;
predictor.Run();
auto start_iii = GetCurrentUS();
LOG(INFO) << i << "->iii:" << (start_iii - start_ii) / 1000.0;
auto* outs = predictor.GetOutputs();
LOG(INFO) << "out:" << (*outs)[0].dims();
}
LOG(INFO) << "================== Speed Report ===================";
LOG(INFO) << "Model: " << FLAGS_model_dir << ", threads num " << FLAGS_threads
<< ", warmup: " << FLAGS_warmup << ", repeats: " << FLAGS_repeats
<< ", spend " << (GetCurrentUS() - start) / FLAGS_repeats / 1000.0
<< " ms in average.";
auto* outs = predictor.GetOutputs();
for (auto out : *outs) {
LOG(INFO) << "======"
<< "here";
LOG(INFO) << out;
}
LOG(INFO) << "======"
<< "hereggg";
}
TEST(OcrAttention, test_arm) {
std::vector<Place> valid_places({
Place{TARGET(kHost), PRECISION(kFloat)},
Place{TARGET(kARM), PRECISION(kFloat)},
});
TestModel(valid_places, Place({TARGET(kARM), PRECISION(kFloat)}));
}
} // namespace lite
} // namespace paddle
...@@ -21,10 +21,10 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
-void increment(const int* input,
+void increment(const float* input,
               const int n,
               const float step,
-               int* out,
+               float* out,
               Context<TARGET(kARM)>* ctx) {
  for (int i = 0; i < n; i++) {
    out[i] = input[i] + step;
......
...@@ -21,10 +21,10 @@ namespace paddle {
namespace lite {
namespace arm {
namespace math {
-void increment(const int* input,
+void increment(const float* input,
               const int n,
               const float step,
-               int* out,
+               float* out,
               Context<TARGET(kARM)>* ctx);
}  // namespace math
......
...@@ -15,6 +15,7 @@
#include "lite/backends/arm/math/norm.h"
#include <arm_neon.h>
#include <cmath>
+#include "lite/backends/arm/math/funcs.h"
#include "lite/utils/cp_logging.h"
namespace paddle {
...@@ -43,7 +44,143 @@ void norm(const float* input,
    }
  }
-  LOG(INFO) << "norm math finished";
}
void matrix_norm_row(const float* x_data,
const float* scale_data,
const float* bias_data,
float* out_data,
float* mean_out,
float* var_out,
float epsilon,
int batch_size,
int feature_size) {
int cnt = feature_size >> 4;
int remain = feature_size & 0xf;
#pragma omp parallel for
for (int bi = 0; bi < batch_size; ++bi) {
int offset = bi * feature_size;
const float* x_ptr = x_data + offset;
float mean = 0.f;
float variance = 0.f;
// get mean and variance
float32x4_t mean_v = vdupq_n_f32(0);
float32x4_t var_v = vdupq_n_f32(0);
for (int oi = 0; oi < cnt; ++oi) {
float32x4_t odim1 = vld1q_f32(x_ptr);
float32x4_t odim2 = vld1q_f32(x_ptr + 4);
float32x4_t odim3 = vld1q_f32(x_ptr + 8);
float32x4_t odim4 = vld1q_f32(x_ptr + 12);
mean_v = vaddq_f32(mean_v, odim1);
mean_v = vaddq_f32(mean_v, odim2);
mean_v = vaddq_f32(mean_v, odim3);
mean_v = vaddq_f32(mean_v, odim4);
var_v = vmlaq_f32(var_v, odim1, odim1);
var_v = vmlaq_f32(var_v, odim2, odim2);
var_v = vmlaq_f32(var_v, odim3, odim3);
var_v = vmlaq_f32(var_v, odim4, odim4);
x_ptr += 16;
}
mean = vgetq_lane_f32(mean_v, 0) + vgetq_lane_f32(mean_v, 1) +
vgetq_lane_f32(mean_v, 2) + vgetq_lane_f32(mean_v, 3);
variance = vgetq_lane_f32(var_v, 0) + vgetq_lane_f32(var_v, 1) +
vgetq_lane_f32(var_v, 2) + vgetq_lane_f32(var_v, 3);
for (int i = 0; i < remain; ++i) {
mean += *x_ptr;
variance += (*x_ptr) * (*x_ptr);
++x_ptr;
}
mean /= feature_size;
variance = variance / feature_size - mean * mean;
mean_out[bi] = mean;
var_out[bi] = variance;
variance = sqrtf(variance + epsilon);
float rvar = 1 / variance;
// compute norm_out
float* out_ptr = out_data + offset;
x_ptr = x_data + offset;
auto* scale_ptr = scale_data;
auto* bias_ptr = bias_data;
float32x4_t vneg = vdupq_n_f32(-1);
float32x4_t scale1 = vdupq_n_f32(1);
float32x4_t scale2 = vdupq_n_f32(1);
float32x4_t scale3 = vdupq_n_f32(1);
float32x4_t scale4 = vdupq_n_f32(1);
float32x4_t bias1 = vdupq_n_f32(0);
float32x4_t bias2 = vdupq_n_f32(0);
float32x4_t bias3 = vdupq_n_f32(0);
float32x4_t bias4 = vdupq_n_f32(0);
for (int oi = 0; oi < cnt; ++oi) {
float32x4_t odim1 = vld1q_f32(x_ptr);
float32x4_t odim2 = vld1q_f32(x_ptr + 4);
float32x4_t odim3 = vld1q_f32(x_ptr + 8);
float32x4_t odim4 = vld1q_f32(x_ptr + 12);
odim1 = vmlaq_n_f32(odim1, vneg, mean);
odim2 = vmlaq_n_f32(odim2, vneg, mean);
odim3 = vmlaq_n_f32(odim3, vneg, mean);
odim4 = vmlaq_n_f32(odim4, vneg, mean);
if (scale_data) {
scale1 = vld1q_f32(scale_ptr);
scale2 = vld1q_f32(scale_ptr + 4);
scale3 = vld1q_f32(scale_ptr + 8);
scale4 = vld1q_f32(scale_ptr + 12);
scale_ptr += 16;
}
if (bias_data) {
bias1 = vld1q_f32(bias_ptr);
bias2 = vld1q_f32(bias_ptr + 4);
bias3 = vld1q_f32(bias_ptr + 8);
bias4 = vld1q_f32(bias_ptr + 12);
bias_ptr += 16;
}
float32x4_t os1 = vmulq_n_f32(scale1, rvar);
float32x4_t os2 = vmulq_n_f32(scale2, rvar);
float32x4_t os3 = vmulq_n_f32(scale3, rvar);
float32x4_t os4 = vmulq_n_f32(scale4, rvar);
odim1 = vmlaq_f32(bias1, odim1, os1);
odim2 = vmlaq_f32(bias2, odim2, os2);
odim3 = vmlaq_f32(bias3, odim3, os3);
odim4 = vmlaq_f32(bias4, odim4, os4);
vst1q_f32(out_ptr, odim1);
vst1q_f32(out_ptr + 4, odim2);
vst1q_f32(out_ptr + 8, odim3);
vst1q_f32(out_ptr + 12, odim4);
x_ptr += 16;
out_ptr += 16;
}
for (int i = 0; i < remain; ++i) {
auto out_value = (*x_ptr - mean) / variance;
if (scale_data) {
out_value = out_value * (*scale_ptr);
++scale_ptr;
}
if (bias_data) {
out_value = out_value + *bias_ptr;
++bias_ptr;
}
*out_ptr = out_value;
++out_ptr;
++x_ptr;
}
} // for bi
}
}  // namespace math
......
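For reference, the row-wise normalization implemented by matrix_norm_row above is the standard layer-norm computation (read directly off the kernel, with D = feature_size; the scale and bias factors apply only when the corresponding pointers are non-null):

mu_i = (1/D) * sum_{j=1..D} x_{ij},
sigma_i^2 = (1/D) * sum_{j=1..D} x_{ij}^2 - mu_i^2,
y_{ij} = scale_j * (x_{ij} - mu_i) / sqrt(sigma_i^2 + epsilon) + bias_j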
...@@ -29,6 +29,15 @@ void norm(const float* input,
           float* out,
           Context<TARGET(kARM)>* ctx);
void matrix_norm_row(const float* x_data,
const float* scale_data,
const float* bias_data,
float* out_data,
float* mean_out,
float* var_out,
float epsilon,
int batch_size,
int feature_size);
}  // namespace math
}  // namespace arm
}  // namespace lite
......
...@@ -25,7 +25,7 @@ void FcFuser::BuildPattern() {
  // create nodes.
  auto* x = VarNode("x")->assert_is_op_input("mul", "X");
  auto* W = VarNode("W")->assert_is_op_input("mul", "Y");
-  auto* b = VarNode("b");
+  auto* b = VarNode("b")->assert_is_persistable_var();
  auto* mul = OpNode("mul", "mul");
  auto* mul_out = VarNode("mul_out");
  auto* add = OpNode("add", "elementwise_add");
......
...@@ -37,7 +37,7 @@ void TypeLayoutTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  VLOG(4) << "nodes.size():" << nodes.size();
  for (auto& node : nodes) {
    VLOG(4) << "!node->IsStmt():" << !node->IsStmt();
-    if (!node->IsStmt()) continue;
+    if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
    auto inlinks = node->inlinks;
    VLOG(4) << "node->AsStmt().desc:" << node->AsStmt().desc
            << " inlinks.size():" << inlinks.size();
......
...@@ -33,7 +33,7 @@ void PrecisionCastPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  }
  for (auto& node : nodes) {
-    if (!node->IsStmt()) continue;
+    if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
    auto inlinks = node->inlinks;
    for (auto* in : inlinks) {
      ComplementInputs(graph.get(), node, in);
......
...@@ -36,7 +36,7 @@ void TypeTargetTransformPass::Apply(const std::unique_ptr<SSAGraph>& graph) {
  CHECK(!valid_places_.empty());
  for (auto& node : nodes) {
-    if (!node->IsStmt()) continue;
+    if (!node->IsStmt() || node->AsStmt().op_type() == "while") continue;
    auto inlinks = node->inlinks;
    for (auto* in : inlinks) {
      ComplementInputs(graph.get(), node, in);
......
...@@ -46,6 +46,8 @@ add_kernel(reduce_max_compute_arm ARM basic SRCS reduce_max_compute.cc DEPS ${li
add_kernel(sequence_expand_compute_arm ARM basic SRCS sequence_expand_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(im2sequence_compute_arm ARM basic SRCS im2sequence_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(sequence_pool_compute_arm ARM basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(layer_norm_compute_arm ARM basic SRCS layer_norm_compute.cc DEPS ${lite_kernel_deps} math_arm)
+add_kernel(gather_compute_arm ARM basic SRCS gather_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(reduce_mean_compute_arm ARM basic SRCS reduce_mean_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(stack_compute_arm ARM basic SRCS stack_compute.cc DEPS ${lite_kernel_deps} math_arm)
add_kernel(assign_compute_arm ARM basic SRCS assign_compute.cc DEPS ${lite_kernel_deps} math_arm)
...@@ -98,4 +100,5 @@ lite_cc_test(test_dropout_compute_arm SRCS dropout_compute_test.cc DEPS dropout_
lite_cc_test(test_transpose_compute_arm SRCS transpose_compute_test.cc DEPS transpose_compute_arm COMPILE_LEVEL extra)
lite_cc_test(test_argmax_compute_arm SRCS argmax_compute_test.cc DEPS argmax_compute_arm)
lite_cc_test(test_axpy_compute_arm SRCS axpy_compute_test.cc DEPS axpy_compute_arm)
+lite_cc_test(test_layer_norm_compute_arm SRCS layer_norm_compute_test.cc DEPS layer_norm_compute_arm)
lite_cc_test(test_conv_transpose_compute_arm SRCS conv_transpose_compute_test.cc DEPS conv_transpose_compute_arm)
...@@ -276,6 +276,10 @@ void BeamSearchDecodeCompute::Run() {
                param.end_id);
  func.apply<float>();
+  // when decoding finishes, clear ids and scores
+  param.ids->clear();
+  param.scores->clear();
}
}  // namespace arm
......
...@@ -87,14 +87,13 @@ void CompareCompute<Functor>::Run() {
  auto x_dims = param.X->dims();
  auto y_dims = param.Y->dims();
  bool *z = param.Out->template mutable_data<bool>();
-  const auto *x = param.X->template data<int>();
+  const auto *x = param.X->template data<float>();
  const auto *y = param.Y->template data<float>();
  auto axis = param.axis;
  bool force_cpu = param.force_cpu;
  if (x_size == y_size) {
    for (int i = 0; i < x_size; ++i) {
      z[i] = CompareFunctor()(x[i], y[i]);
-      // z[i] = x[i] < y[i];
    }
  } else {
    int axis = (param.axis == -1 ? x_dims.size() - y_dims.size() : param.axis);
......
...@@ -38,6 +38,31 @@ class FillConstantCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
  virtual ~FillConstantCompute() = default;
};
template <typename T>
class FillConstantBatchLikeCompute
: public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::FillConstantBatchLikeParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto& context = ctx_->As<ARMContext>();
if (param.input->lod().size() && param.input_dim_idx == 0) {
auto odims = param.out->dims();
odims[param.output_dim_idx] = param.input->lod().back().size() - 1;
param.out->Resize(odims);
}
auto data = param.out->template mutable_data<T>();
for (int i = 0; i < param.out->numel(); i++) {
data[i] = param.value;
}
}
virtual ~FillConstantBatchLikeCompute() = default;
};
}  // namespace arm
}  // namespace kernels
}  // namespace lite
...@@ -52,3 +77,13 @@ REGISTER_LITE_KERNEL(fill_constant,
    def)
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
    .Finalize();
REGISTER_LITE_KERNEL(
fill_constant_batch_size_like,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::FillConstantBatchLikeCompute<float>,
def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
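A minimal stand-alone sketch of the resize-and-fill behaviour registered above (illustrative names and values, not part of the patch; it assumes output_dim_idx == 0 and that the input carries a LoD whose last level is {0, 3, 7}, i.e. two sequences):

#include <cstdint>
#include <vector>

// Mirrors FillConstantBatchLikeCompute::Run: the batch dimension is taken
// from the input's LoD (number of sequences), then every output element is
// filled with the constant value.
void fill_constant_batch_like(const std::vector<size_t>& last_lod_level,
                              std::vector<int64_t>* out_dims,
                              std::vector<float>* out_data,
                              float value) {
  if (!last_lod_level.empty()) {
    (*out_dims)[0] = static_cast<int64_t>(last_lod_level.size()) - 1;
  }
  int64_t numel = 1;
  for (int64_t d : *out_dims) numel *= d;
  out_data->assign(static_cast<size_t>(numel), value);
}
// With out_dims = {-1, 1, 8} and last_lod_level = {0, 3, 7}, the output is
// resized to {2, 1, 8} and all 16 elements are set to `value`.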
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/gather_compute.h"
#include <vector>
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void GatherCompute::PrepareForRun() {}
void GatherCompute::Run() {
auto& param = this->Param<operators::GatherParam>();
auto* p_output = param.Out->mutable_data<float>();
auto index_size = param.Index->dims()[0];
auto src_dims = param.X->dims();
const float* p_src = param.X->data<float>();
const float* p_index = param.Index->data<float>();
int slice_size = 1;
for (int i = 1; i < src_dims.size(); ++i) {
slice_size *= src_dims[i];
}
for (int i = 0; i < index_size; ++i) {
int index_ = p_index[i];
memcpy(p_output + i * slice_size,
p_src + index_ * slice_size,
slice_size * sizeof(float));
}
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(
gather, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::GatherCompute, def)
.BindInput("Input", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
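A minimal stand-alone sketch of the copy loop in GatherCompute::Run above (plain C++, independent of the lite::Tensor API; the helper name is illustrative): for each i, the slice X[Index[i], :] is copied into Out[i, :]; note that in this kernel the indices arrive as float and are truncated to int.

#include <cstring>

// Illustrative mirror of the gather row-copy loop.
void gather_rows(const float* src, const float* index, int index_size,
                 int slice_size, float* out) {
  for (int i = 0; i < index_size; ++i) {
    int row = static_cast<int>(index[i]);  // indices are stored as float here
    std::memcpy(out + i * slice_size, src + row * slice_size,
                slice_size * sizeof(float));
  }
}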
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class GatherCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::GatherParam;
void PrepareForRun() override;
void Run() override;
~GatherCompute() {}
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -28,8 +28,8 @@ void IncrementCompute::Run() {
  int total_num = param.X->dims().production();
-  const auto* x_data = param.X->data<int>();
-  auto* o_data = param.Out->mutable_data<int>();
+  const auto* x_data = param.X->data<float>();
+  auto* o_data = param.Out->mutable_data<float>();
  lite::arm::math::increment(x_data, total_num, param.step, o_data, &ctx);
}
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/layer_norm_compute.h"
#include "lite/backends/arm/math/funcs.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void LayerNormCompute::PrepareForRun() {}
void LayerNormCompute::Run() {
auto& param = this->Param<operators::LayerNormParam>();
auto input_dims = param.X->dims();
const auto* x_data = param.X->data<float>();
const auto* scale = param.Scale ? param.Scale->data<float>() : nullptr;
const auto* bias = param.Bias ? param.Bias->data<float>() : nullptr;
auto* o_data = param.Y->mutable_data<float>();
auto* mean = param.Mean->mutable_data<float>();
auto* var = param.Variance->mutable_data<float>();
int axis = param.begin_norm_axis;
auto matrix_dim = param.X->dims().Flatten2D(axis);
int left = matrix_dim[0];
int right = matrix_dim[1];
lite::arm::math::matrix_norm_row(
x_data, scale, bias, o_data, mean, var, param.epsilon, left, right);
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(layer_norm,
kARM,
kFloat,
kNCHW,
paddle::lite::kernels::arm::LayerNormCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Scale", {LiteType::GetTensorTy(TARGET(kARM))})
.BindInput("Bias", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Mean", {LiteType::GetTensorTy(TARGET(kARM))})
.BindOutput("Variance", {LiteType::GetTensorTy(TARGET(kARM))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <stdint.h>
#include "lite/backends/arm/math/type_trans.h"
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
class LayerNormCompute : public KernelLite<TARGET(kARM), PRECISION(kFloat)> {
public:
using param_t = operators::LayerNormParam;
void PrepareForRun() override;
void Run() override;
~LayerNormCompute() {}
};
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/arm/layer_norm_compute.h"
#include <gtest/gtest.h>
#include <limits>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace arm {
void LayerNormComputeRef(const operators::LayerNormParam& param) {
auto* x = param.X;
auto* y = param.Y;
auto* scale_tensor = param.Scale;
auto* bias_tensor = param.Bias;
auto* mean_tensor = param.Mean;
auto* var_tensor = param.Variance;
int begin_norm_axis = param.begin_norm_axis;
float epsilon = param.epsilon;
auto* x_data = x->data<float>();
auto* scale_data =
(scale_tensor == nullptr ? nullptr : scale_tensor->data<float>());
auto* bias_data =
(bias_tensor == nullptr ? nullptr : bias_tensor->data<float>());
auto* out_data = y->mutable_data<float>();
auto* mean_data = mean_tensor->mutable_data<float>();
auto* var_data = var_tensor->mutable_data<float>();
auto matrix_dim = x->dims().Flatten2D(begin_norm_axis);
int batch_size = matrix_dim[0];
int feature_size = matrix_dim[1];
for (int i = 0; i < batch_size; ++i) {
int start = i * feature_size;
int end = start + feature_size;
float mean = 0;
float var = 0;
for (int j = start; j < end; ++j) {
mean += x_data[j];
var += x_data[j] * x_data[j];
}
mean /= feature_size;
var = var / feature_size - mean * mean;
mean_data[i] = mean;
var_data[i] = var;
var = sqrt(var + epsilon);
for (int j = start; j < end; ++j) {
out_data[j] = (x_data[j] - mean) / var;
if (scale_data) {
out_data[j] *= scale_data[j - start];
}
if (bias_data) {
out_data[j] += bias_data[j - start];
}
}
}
}
TEST(layer_norm_arm, init) {
LayerNormCompute layer_norm;
ASSERT_EQ(layer_norm.precision(), PRECISION(kFloat));
ASSERT_EQ(layer_norm.target(), TARGET(kARM));
}
TEST(layer_norm_arm, compute) {
LayerNormCompute layer_norm;
operators::LayerNormParam param;
lite::Tensor x;
lite::Tensor output;
lite::Tensor output_mean;
lite::Tensor output_var;
lite::Tensor output_ref;
lite::Tensor output_mean_ref;
lite::Tensor output_var_ref;
lite::Tensor bias;
lite::Tensor scale;
lite::Tensor* bias_ptr;
lite::Tensor* scale_ptr;
for (auto n : {1, 3}) {
for (auto c : {1, 3, 5}) {
for (auto h : {3, 16, 20, 32}) {
for (auto w : {3, 16, 20, 32}) {
for (auto axis : {0, 1, 2}) {
for (auto has_bias : {true, false}) {
for (auto has_scale : {true, false}) {
auto dims = DDim(std::vector<int64_t>({n, c, h, w}));
auto out_size = dims.Flatten2D(axis)[0];
auto inner_size = dims.Flatten2D(axis)[1];
bias_ptr = nullptr;
scale_ptr = nullptr;
if (has_bias) {
bias.Resize(std::vector<int64_t>({inner_size, 1, 1, 1}));
float* bias_data = bias.mutable_data<float>();
for (int i = 0; i < inner_size; ++i) {
bias_data[i] = 0.01;
}
bias_ptr = &bias;
}
if (has_scale) {
scale.Resize(std::vector<int64_t>({inner_size, 1, 1, 1}));
float* scale_data = scale.mutable_data<float>();
for (int i = 0; i < inner_size; ++i) {
scale_data[i] = 0.2;
}
scale_ptr = &scale;
}
x.Resize(dims);
output.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_ref.Resize(DDim(std::vector<int64_t>({n, c, h, w})));
output_mean.Resize(std::vector<int64_t>({out_size, 1, 1, 1}));
output_mean_ref.Resize(
std::vector<int64_t>({out_size, 1, 1, 1}));
output_var.Resize(std::vector<int64_t>({out_size, 1, 1, 1}));
output_var_ref.Resize(
std::vector<int64_t>({out_size, 1, 1, 1}));
auto* x_data = x.mutable_data<float>();
auto* output_data = output.mutable_data<float>();
auto* output_mean_data = output_mean.mutable_data<float>();
auto* output_var_data = output_var.mutable_data<float>();
auto* output_data_ref = output_ref.mutable_data<float>();
auto* output_mean_data_ref =
output_mean_ref.mutable_data<float>();
auto* output_var_data_ref =
output_var_ref.mutable_data<float>();
for (int i = 0; i < x.dims().production(); i++) {
x_data[i] = i % 255 * 0.001;
}
param.X = &x;
param.Y = &output;
param.begin_norm_axis = axis;
param.Bias = bias_ptr;
param.Scale = scale_ptr;
param.Mean = &output_mean;
param.Variance = &output_var;
param.epsilon = 0.00001;
layer_norm.SetParam(param);
layer_norm.Run();
param.Y = &output_ref;
param.Mean = &output_mean_ref;
param.Variance = &output_var_ref;
LayerNormComputeRef(param);
for (int i = 0; i < output.dims().production(); i++) {
EXPECT_NEAR(output_data[i], output_data_ref[i], 1e-4);
}
for (int i = 0; i < output_mean.dims().production(); ++i) {
EXPECT_NEAR(
output_mean_data[i], output_mean_data_ref[i], 1e-5);
EXPECT_NEAR(output_var_data[i], output_var_data_ref[i], 1e-5);
}
}
}
}
}
}
}
}
}
TEST(layer_norm, retrive_op) {
auto layer_norm =
KernelRegistry::Global().Create<TARGET(kARM), PRECISION(kFloat)>(
"layer_norm");
ASSERT_FALSE(layer_norm.empty());
ASSERT_TRUE(layer_norm.front());
}
} // namespace arm
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(layer_norm, kARM, kFloat, kNCHW, def);
...@@ -38,13 +38,14 @@ void LookupTableCompute::Run() {
  auto table_dim = w->dims();
  int64_t ids_numel = ids->numel();
  auto ids_data = ids->data<float>();
-  int ids_int = ids_data[0];
  int64_t row_number = table_dim[0];
  int64_t row_width = table_dim[1];
  auto table_data = w->data<float>();
  auto dout = out->mutable_data<float>();
  for (int64_t i = 0; i < ids_numel; ++i) {
+    int ids_int = ids_data[i];
    if (param.padding_idx != -1 && ids_data[i] == param.padding_idx) {
      memset(dout + i * row_width, 0, row_width * sizeof(float));
    } else {
......
...@@ -28,14 +28,13 @@ void ReadFromArrayCompute::Run() {
  int in_num = param.X->size();
  CHECK_EQ(param.I->numel(), 1) << "I should have only one element";
-  int id = param.I->data<int>()[0];
+  int id = param.I->data<float>()[0];
  CHECK_LE(id, in_num) << "id is not valid";
  int input_size = (*param.X)[id].numel();
  param.Out->Resize((*param.X)[id].dims());
-  auto* o_data = param.Out->mutable_data<float>();
-  const auto* x_data = (*param.X)[id].data<float>();
-  memcpy(o_data, x_data, sizeof(float) * input_size);
+  param.Out->CopyDataFrom((*param.X)[id]);
  auto out_lod = param.Out->mutable_lod();
  *out_lod = (*param.X)[id].lod();
}
......
...@@ -43,5 +43,6 @@ REGISTER_LITE_KERNEL(
    top_k, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::TopkCompute, def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Indices", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Indices",
+                {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kInt32))})
    .Finalize();
...@@ -46,7 +46,7 @@ void WhileCompute::Run() {
REGISTER_LITE_KERNEL(
    while, kARM, kFloat, kNCHW, paddle::lite::kernels::arm::WhileCompute, def)
-    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindInput("X", {LiteType::GetTensorListTy(TARGET(kARM))})
    .BindInput("Condition",
               {LiteType::GetTensorTy(TARGET(kARM), PRECISION(kBool))})
    .BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
......
...@@ -28,7 +28,7 @@ void WriteToArrayCompute::Run() {
  CHECK_EQ(param.I->numel(), 1) << "input2 should have only one element";
  const auto* x_data = param.X->data<float>();
-  int id = param.I->data<int>()[0];
+  int id = param.I->data<float>()[0];
  int id_test = param.I->data<int64_t>()[0];
  if (id >= param.Out->size()) {
    for (int i = param.Out->size(); i < id + 1; i++) {
...@@ -57,5 +57,5 @@ REGISTER_LITE_KERNEL(write_to_array,
    def)
    .BindInput("X", {LiteType::GetTensorTy(TARGET(kARM))})
    .BindInput("I", {LiteType::GetTensorTy(TARGET(kARM))})
-    .BindOutput("Out", {LiteType::GetTensorTy(TARGET(kARM))})
+    .BindOutput("Out", {LiteType::GetTensorListTy(TARGET(kARM))})
    .Finalize();
...@@ -5,6 +5,7 @@ lite_cc_library(op_params SRCS op_params.cc DEPS tensor any)
add_operator(conv_op basic SRCS conv_op.cc DEPS ${op_DEPS})
add_operator(pool_op basic SRCS pool_op.cc DEPS ${op_DEPS})
add_operator(fc_op basic SRCS fc_op.cc DEPS ${op_DEPS})
+add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS})
add_operator(relu_op basic SRCS relu_op.cc DEPS ${op_DEPS})
add_operator(mul_op basic SRCS mul_op.cc DEPS ${op_DEPS})
add_operator(matmul_op basic SRCS matmul_op.cc DEPS ${op_DEPS})
...@@ -25,7 +26,6 @@ add_operator(multiclass_nms_op_lite basic SRCS multiclass_nms_op.cc DEPS ${op_DE
add_operator(fusion_elementwise_activation_ops basic SRCS fusion_elementwise_activation_ops.cc DEPS elementwise_ops ${op_DEPS})
add_operator(mean_op basic SRCS mean_op.cc DEPS ${op_DEPS})
add_operator(fill_constant_op basic SRCS fill_constant_op.cc DEPS ${op_DEPS})
-add_operator(fill_constant_batch_size_like_op basic SRCS fill_constant_batch_size_like_op.cc DEPS ${op_DEPS})
#add_operator(sgd_op basic SRCS sgd_op.cc DEPS ${op_DEPS})
add_operator(uniform_random_op basic SRCS uniform_random_op.cc DEPS ${op_DEPS})
add_operator(power_op basic SRCS power_op.cc DEPS ${op_DEPS})
...@@ -61,10 +61,10 @@ add_operator(sequence_expand_op_lite basic SRCS sequence_expand_op.cc DEPS ${op_
add_operator(squeeze_op_lite basic SRCS squeeze_op.cc DEPS ${op_DEPS})
add_operator(unsqueeze_op_lite basic SRCS unsqueeze_op.cc DEPS ${op_DEPS})
add_operator(im2sequence_op basic SRCS im2sequence_op.cc DEPS ${op_DEPS})
+add_operator(gather_op basic SRCS gather_op.cc DEPS ${op_DEPS})
add_operator(reduce_mean_op basic SRCS reduce_mean_op.cc DEPS ${op_DEPS})
add_operator(stack_op basic SRCS stack_op.cc DEPS ${op_DEPS})
add_operator(cast_op_lite basic SRCS cast_op.cc DEPS ${op_DEPS})
-add_operator(assign_op basic SRCS assign_op.cc DEPS ${op_DEPS})
add_operator(affine_channel_op basic SRCS affine_channel_op.cc DEPS ${op_DEPS})
add_operator(anchor_generator_op basic SRCS anchor_generator_op.cc DEPS ${op_DEPS})
add_operator(generate_proposals_op basic SRCS generate_proposals_op.cc DEPS ${op_DEPS})
...@@ -100,6 +100,7 @@ add_operator(slice_op_lite basic SRCS slice_op.cc DEPS ${op_DEPS})
add_operator(write_to_array_op extra SRCS write_to_array_op.cc DEPS ${op_DEPS})
add_operator(topk_op extra SRCS topk_op.cc DEPS ${op_DEPS})
add_operator(increment_op extra SRCS increment_op.cc DEPS ${op_DEPS})
+add_operator(layer_norm_op extra SRCS layer_norm_op.cc DEPS ${op_DEPS})
add_operator(sequence_softmax_op extra SRCS sequence_softmax_op.cc DEPS ${op_DEPS})
......
...@@ -52,8 +52,67 @@ class FillConstantOp : public OpLite {
  mutable operators::FillConstantParam param_;
};
class FillConstantBatchLikeOp : public OpLite {
public:
explicit FillConstantBatchLikeOp(const std::string& type) : OpLite(type) {}
bool CheckShape() const override {
CHECK_OR_FALSE(param_.out);
CHECK_OR_FALSE(param_.input);
CHECK_GT_OR_FALSE(param_.shape.size(), 0);
CHECK_GE_OR_FALSE(param_.input_dim_idx, 0);
CHECK_GE_OR_FALSE(param_.output_dim_idx, 0);
return true;
}
bool InferShape() const override {
auto output_dim = param_.shape;
output_dim[param_.output_dim_idx] =
param_.input->dims()[param_.input_dim_idx];
param_.out->Resize(output_dim);
return true;
}
bool AttachImpl(const cpp::OpDesc& opdesc, lite::Scope* scope) override {
auto Out_name = opdesc.Output("Out").front();
auto In_name = opdesc.Input("Input").front();
param_.out = GetMutableVar<lite::Tensor>(scope, Out_name);
param_.input = GetMutableVar<lite::Tensor>(scope, In_name);
param_.dtype = opdesc.GetAttr<int>("dtype");
auto shape = opdesc.GetAttr<std::vector<int>>("shape");
std::vector<int64_t> outshape;
for (auto i : shape) {
outshape.push_back(i);
}
param_.shape = outshape;
if (opdesc.HasAttr("value")) {
param_.value = opdesc.GetAttr<float>("value");
}
if (opdesc.HasAttr("input_dim_idx")) {
param_.input_dim_idx = opdesc.GetAttr<int>("input_dim_idx");
}
if (opdesc.HasAttr("output_dim_idx")) {
param_.output_dim_idx = opdesc.GetAttr<int>("output_dim_idx");
}
return true;
}
void AttachKernel(KernelBase* kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "fill_constant_batch_size_like";
}
private:
mutable operators::FillConstantBatchLikeParam param_;
};
}  // namespace operators
}  // namespace lite
}  // namespace paddle
REGISTER_LITE_OP(fill_constant, paddle::lite::operators::FillConstantOp);
REGISTER_LITE_OP(fill_constant_batch_size_like,
paddle::lite::operators::FillConstantBatchLikeOp);
...@@ -11,59 +11,48 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
-#include "lite/operators/fill_constant_batch_size_like_op.h"
+#include "lite/operators/gather_op.h"
#include <algorithm>
-#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
-#include "lite/core/tensor.h"
namespace paddle {
namespace lite {
namespace operators {
-bool FillConstantBatchSizeLikeOp::CheckShape() const {
-  CHECK_OR_FALSE(param_.Input);
+bool GatherOp::CheckShape() const {
+  CHECK_OR_FALSE(param_.X);
+  CHECK_OR_FALSE(param_.Index);
  CHECK_OR_FALSE(param_.Out);
  return true;
}
-bool FillConstantBatchSizeLikeOp::InferShape() const {
-  auto shape = param_.shape;
-  std::vector<int64_t> shape_int64(shape.size(), 0);
-  std::transform(shape.begin(), shape.end(), shape_int64.begin(), [](int a) {
-    return static_cast<int64_t>(a);
-  });
-  lite::DDim output_dim(shape_int64);
-  int input_dim_idx = param_.input_dim_idx;
-  int output_dim_idx = param_.output_dim_idx;
-  output_dim[output_dim_idx] = param_.Input->dims()[input_dim_idx];
-  param_.Out->Resize(output_dim);
+bool GatherOp::InferShape() const {
+  auto index_dims = param_.Index->dims();
+  CHECK(index_dims.size() == 1 ||
+        (index_dims.size() == 2 && index_dims[1] == 1))
+      << "index dims unmatch";
+  int batch_size = index_dims[0];
+  auto out_dims = param_.X->dims();
+  out_dims[0] = batch_size;
+  param_.Out->Resize(out_dims);
  return true;
}
-bool FillConstantBatchSizeLikeOp::AttachImpl(const cpp::OpDesc &op_desc,
-                                             lite::Scope *scope) {
-  auto Input = op_desc.Input("Input").front();
-  auto Out = op_desc.Output("Out").front();
-  param_.Input = scope->FindVar(Input)->GetMutable<lite::Tensor>();
-  param_.Out = scope->FindVar(Out)->GetMutable<lite::Tensor>();
-  param_.shape = op_desc.GetAttr<std::vector<int>>("shape");
-  param_.input_dim_idx = op_desc.GetAttr<int>("input_dim_idx");
-  param_.output_dim_idx = op_desc.GetAttr<int>("output_dim_idx");
-  param_.dtype = op_desc.GetAttr<int>("dtype");
-  param_.value = op_desc.GetAttr<float>("value");
-  CHECK(param_.Input);
-  CHECK(param_.Out);
+bool GatherOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
+  param_.X =
+      scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
+  param_.Out =
+      scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
+  param_.Index =
+      scope->FindVar(opdesc.Input("Index").front())->GetMutable<lite::Tensor>();
+  CHECK(param_.X) << "X is null";
+  CHECK(param_.Out) << "out is null";
+  CHECK(param_.Index) << "index is null";
  return true;
}
-} /* namespace operators */
-} /* namespace lite */
-} /* namespace paddle */
-REGISTER_LITE_OP(fill_constant_batch_size_like,
-                 paddle::lite::operators::FillConstantBatchSizeLikeOp);
+}  // namespace operators
+}  // namespace lite
+}  // namespace paddle
+REGISTER_LITE_OP(gather, paddle::lite::operators::GatherOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class GatherOp : public OpLite {
public:
GatherOp() {}
explicit GatherOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "gather"; }
private:
mutable GatherParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/layer_norm_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool LayerNormOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK_OR_FALSE(param_.Mean);
CHECK_OR_FALSE(param_.Variance);
return true;
}
bool LayerNormOp::InferShape() const {
auto out_dims = param_.X->dims();
param_.Y->Resize(out_dims);
auto inner_size = out_dims.Flatten2D(param_.begin_norm_axis)[1];
param_.Mean->Resize(std::vector<int64_t>({inner_size}));
param_.Variance->Resize(std::vector<int64_t>({inner_size}));
auto out_lod = param_.Y->mutable_lod();
*out_lod = param_.X->lod();
return true;
}
bool LayerNormOp::AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) {
param_.X =
scope->FindVar(opdesc.Input("X").front())->GetMutable<lite::Tensor>();
param_.Y =
scope->FindVar(opdesc.Output("Y").front())->GetMutable<lite::Tensor>();
param_.Mean =
scope->FindVar(opdesc.Output("Mean").front())->GetMutable<lite::Tensor>();
param_.Variance = scope->FindVar(opdesc.Output("Variance").front())
->GetMutable<lite::Tensor>();
CHECK(param_.X);
CHECK(param_.Y);
CHECK(param_.Mean);
CHECK(param_.Variance);
if (opdesc.HasInput("Scale")) {
param_.Scale = scope->FindVar(opdesc.Input("Scale").front())
->GetMutable<lite::Tensor>();
}
if (opdesc.HasInput("Bias")) {
param_.Bias = scope->FindVar(opdesc.Input("Bias").front())
->GetMutable<lite::Tensor>();
}
param_.begin_norm_axis = opdesc.GetAttr<int>("begin_norm_axis");
param_.epsilon = opdesc.GetAttr<float>("epsilon");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(layer_norm, paddle::lite::operators::LayerNormOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class LayerNormOp : public OpLite {
public:
LayerNormOp() {}
explicit LayerNormOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "layer_norm"; }
private:
mutable LayerNormParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
...@@ -23,6 +23,7 @@ bool MulOpLite::CheckShape() const {
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.y);
  CHECK_OR_FALSE(param_.output);
  // bias is optional.
  const auto x_dims = param_.x->dims();
...@@ -54,17 +55,15 @@ bool MulOpLite::InferShape() const {
  const auto y_dims = param_.y->dims();
  // Set output dims
-  std::vector<int64_t> out_dims(
-      param_.x_num_col_dims + y_dims.size() - param_.y_num_col_dims, 0);
+  std::vector<int64_t> out_dims;
  for (int i = 0; i < param_.x_num_col_dims; ++i) {
-    out_dims[i] = x_dims[i];
+    out_dims.push_back(x_dims[i]);
  }
  for (auto i = static_cast<size_t>(param_.y_num_col_dims); i < y_dims.size();
       ++i) {
-    out_dims[i] = y_dims[i];
+    out_dims.push_back(y_dims[i]);
  }
  param_.output->Resize(lite::DDim(out_dims));
  auto out_lod = param_.output->mutable_lod();
  *out_lod = param_.x->lod();
......
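A worked example of the InferShape fix above, using the shapes from the commit message (stand-alone helper, illustrative only): with x_dims = {2, 3, 4}, x_num_col_dims = 2 and y_dims = {4, 5}, y_num_col_dims = 1, the expected output shape is {2, 3, 5}. The old code allocated out_dims with x_num_col_dims + y_dims.size() - y_num_col_dims = 3 zeros and then wrote y_dims[1] at index 1, clobbering x_dims[1] and leaving {2, 5, 0}; appending with push_back yields the correct shape.

#include <cstdint>
#include <vector>

// Stand-alone sketch of the corrected output-shape computation.
std::vector<int64_t> mul_out_dims(const std::vector<int64_t>& x_dims,
                                  const std::vector<int64_t>& y_dims,
                                  int x_num_col_dims, int y_num_col_dims) {
  std::vector<int64_t> out_dims;
  for (int i = 0; i < x_num_col_dims; ++i) out_dims.push_back(x_dims[i]);
  for (size_t i = static_cast<size_t>(y_num_col_dims); i < y_dims.size(); ++i)
    out_dims.push_back(y_dims[i]);
  return out_dims;
}
// mul_out_dims({2, 3, 4}, {4, 5}, 2, 1) returns {2, 3, 5}.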
...@@ -373,6 +373,17 @@ struct FillConstantParam {
  bool force_cpu{false};
  lite::Tensor* Out{};
};
struct FillConstantBatchLikeParam {
int dtype{static_cast<int>(VarDescAPI::VarDataType::FP32)};
std::vector<int64_t> shape{};
float value{0.0f};
// useless for x86, keep it for compatibility
bool force_cpu{false};
lite::Tensor* out{};
const lite::Tensor* input{};
int input_dim_idx{0};
int output_dim_idx{0};
};
struct FillConstantBatchSizeLikeParam {
  lite::Tensor* Input;
...@@ -619,6 +630,16 @@ struct NormParam {
  int axis{1};
  float epsilon{1e-10};
};
struct LayerNormParam {
const lite::Tensor* X{};
const lite::Tensor* Scale{};
const lite::Tensor* Bias{};
lite::Tensor* Y{};
lite::Tensor* Mean{};
lite::Tensor* Variance{};
int begin_norm_axis{1};
float epsilon{1e-5};
};
struct LogicalParam {
  const lite::Tensor* X{};
...@@ -816,6 +837,12 @@ struct MatMulParam {
  float alpha{1.0f};
};
struct GatherParam {
const lite::Tensor* X{};
const lite::Tensor* Index{};
lite::Tensor* Out{};
};
/// ----------------------- assign operators -----------------------
struct AssignParam {
  const lite::Tensor* X{};
......