提交 73296cb6 编写于 作者: Z zhupengyang 提交者: GitHub

[X86][CUDA] add sequence_arithmetic op , x86 kernel, cuda kernel and unit test (#2436)

* [X86][CUDA] add sequence_arithmetic op , x86 kernel, cuda kernel and unit test

test=develop

* add sequence_arithmetic cuda kernel unit test

test=develop
上级 188162f4
......@@ -25,6 +25,7 @@ add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.
add_kernel(search_seq_depadding_compute_cuda CUDA basic SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_arithmetic_compute_cuda CUDA basic SRCS sequence_arithmetic_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_cuda CUDA basic SRCS match_matrix_tensor_compute.cu DEPS ${lite_kernel_deps} cuda_gemm)
......@@ -45,7 +46,9 @@ nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc D
nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
nv_test(sequence_arithmetic_compute_cuda_test SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_cuda)
nv_test(match_matrix_tensor_compute_cuda_test SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_cuda)
if(LITE_BUILD_EXTRA)
nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_arithmetic_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
const int CUDA_NUM_THREADS = 512;
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
inline int CUDA_GET_BLOCKS(const int N) {
return (N + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
}
template <typename Dtype>
__global__ void ker_arithmetic_sum(Dtype* out_data,
const Dtype* in_data_0,
const Dtype* in_data_1,
const int* offset_0,
const int* offset_1,
const int* word_id_to_seq_id,
const int seq_num,
const int inner_size,
const int count) {
CUDA_KERNEL_LOOP(tid, count) {
int emb_id = tid % inner_size;
int word_id = tid / inner_size;
int seq_id = word_id_to_seq_id[word_id];
int word_id_in_cur_seq = word_id - offset_0[seq_id];
int seq_len_1 = offset_1[seq_id + 1] - offset_1[seq_id];
if (word_id_in_cur_seq < seq_len_1) {
out_data[tid] =
in_data_0[tid] +
in_data_1[(offset_1[seq_id] + word_id_in_cur_seq) * inner_size +
emb_id];
} else {
out_data[tid] = in_data_0[tid];
}
}
}
template <typename Dtype>
__global__ void ker_arithmetic_sub(Dtype* out_data,
const Dtype* in_data_0,
const Dtype* in_data_1,
const int* offset_0,
const int* offset_1,
const int* word_id_to_seq_id,
const int seq_num,
const int inner_size,
const int count) {
CUDA_KERNEL_LOOP(tid, count) {
int emb_id = tid % inner_size;
int word_id = tid / inner_size;
int seq_id = word_id_to_seq_id[word_id];
int word_id_in_cur_seq = word_id - offset_0[seq_id];
int seq_len_1 = offset_1[seq_id + 1] - offset_1[seq_id];
if (word_id_in_cur_seq < seq_len_1) {
out_data[tid] =
in_data_0[tid] -
in_data_1[(offset_1[seq_id] + word_id_in_cur_seq) * inner_size +
emb_id];
} else {
out_data[tid] = in_data_0[tid];
}
}
}
template <typename Dtype>
__global__ void ker_arithmetic_mul(Dtype* out_data,
const Dtype* in_data_0,
const Dtype* in_data_1,
const int* offset_0,
const int* offset_1,
const int* word_id_to_seq_id,
const int seq_num,
const int inner_size,
const int count) {
CUDA_KERNEL_LOOP(tid, count) {
int emb_id = tid % inner_size;
int word_id = tid / inner_size;
int seq_id = word_id_to_seq_id[word_id];
int word_id_in_cur_seq = word_id - offset_0[seq_id];
int seq_len_1 = offset_1[seq_id + 1] - offset_1[seq_id];
if (word_id_in_cur_seq < seq_len_1) {
out_data[tid] =
in_data_0[tid] *
in_data_1[(offset_1[seq_id] + word_id_in_cur_seq) * inner_size +
emb_id];
} else {
out_data[tid] = in_data_0[tid];
}
}
}
void SequenceArithmeticCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto x_data = param.X->data<float>();
auto x_lod = param.X->lod()[0];
auto y_data = param.X->data<float>();
auto y_lod = param.Y->lod()[0];
auto out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
offset_x.Resize({static_cast<int64_t>(x_lod.size())});
auto offset_x_data = offset_x.mutable_data<int>(TARGET(kCUDA));
offset_y.Resize({static_cast<int64_t>(y_lod.size())});
auto offset_y_data = offset_y.mutable_data<int>(TARGET(kCUDA));
word_id_to_seq_id.Resize({param.X->numel()});
auto word_id_to_seq_id_data =
word_id_to_seq_id.mutable_data<int>(TARGET(kCUDA));
std::vector<int> word_seq_map;
for (int i = 0; i < x_lod.size() - 1; i++) {
for (int j = x_lod[i]; j < x_lod[i + 1]; j++) {
word_seq_map.push_back(i);
}
}
std::vector<int> offset_x_data_cpu(x_lod.size(), 0);
auto x_lod_data = x_lod.data();
for (int i = 0; i < offset_x_data_cpu.size(); i++) {
offset_x_data_cpu[i] = x_lod_data[i];
}
std::vector<int> offset_y_data_cpu(y_lod.size(), 0);
auto y_lod_data = y_lod.data();
for (int i = 0; i < offset_y_data_cpu.size(); i++) {
offset_y_data_cpu[i] = y_lod_data[i];
}
TargetWrapperCuda::MemcpyAsync(offset_x_data,
offset_x_data_cpu.data(),
sizeof(int) * x_lod.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(offset_y_data,
offset_y_data_cpu.data(),
sizeof(int) * y_lod.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(word_id_to_seq_id_data,
word_seq_map.data(),
sizeof(int) * word_seq_map.size(),
IoDirection::HtoD,
stream);
int seq_num = x_lod.size() - 1;
int count = param.X->numel();
int inner_size = param.X->dims()[1];
switch (param.op_type) {
case 1: // sum
ker_arithmetic_sum<
float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
out_data,
x_data,
y_data,
offset_x_data,
offset_y_data,
word_id_to_seq_id_data,
seq_num,
inner_size,
count);
break;
case 2: // sub
ker_arithmetic_sub<
float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
out_data,
x_data,
y_data,
offset_x_data,
offset_y_data,
word_id_to_seq_id_data,
seq_num,
inner_size,
count);
break;
case 3: // mul
ker_arithmetic_mul<
float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
out_data,
x_data,
y_data,
offset_x_data,
offset_y_data,
word_id_to_seq_id_data,
seq_num,
inner_size,
count);
break;
default:
break;
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_arithmetic,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SequenceArithmeticCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
REGISTER_LITE_KERNEL(search_seq_arithmetic,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SequenceArithmeticCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SequenceArithmeticCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceArithmeticParam;
void Run() override;
virtual ~SequenceArithmeticCompute() = default;
private:
lite::Tensor offset_x;
lite::Tensor offset_y;
lite::Tensor word_id_to_seq_id;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_arithmetic_compute.h"
#include <gtest/gtest.h>
#include <algorithm>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
void sequence_arithmetic_compute_ref(const Tensor& x,
const Tensor& y,
Tensor* out,
int op_type) {
auto x_data = x.data<float>();
auto y_data = y.data<float>();
out->Resize(x.dims());
out->set_lod(x.lod());
auto out_data = out->mutable_data<float>();
auto x_seq_offset = x.lod()[0];
auto y_seq_offset = y.lod()[0];
int seq_num = x_seq_offset.size() - 1;
int inner_size = x.numel() / x.dims()[0];
for (int i = 0; i < seq_num; i++) {
int len_x = (x_seq_offset[i + 1] - x_seq_offset[i]) * inner_size;
int len_y = (y_seq_offset[i + 1] - y_seq_offset[i]) * inner_size;
auto input_x = x_data + x_seq_offset[i] * inner_size;
auto input_y = y_data + y_seq_offset[i] * inner_size;
auto t_out = out_data + x_seq_offset[i] * inner_size;
int len = std::min(len_x, len_y);
for (int j = 0; j < len; j++) {
switch (op_type) {
case 1:
t_out[j] = input_x[j] + input_y[j];
break;
case 2:
t_out[j] = input_x[j] - input_y[j];
break;
case 3:
t_out[j] = input_x[j] * input_y[j];
break;
default:
break;
}
}
if (len_x > len) {
memcpy(t_out + len, input_x + len, sizeof(float) * (len_x - len));
}
}
}
void prepare_input(Tensor* x, const LoD& x_lod) {
x->Resize({static_cast<int64_t>(x_lod[0].back()), 3});
x->set_lod(x_lod);
auto x_data = x->mutable_data<float>();
for (int i = 0; i < x->numel(); i++) {
x_data[i] = (i - x->numel() / 2) * 1.1;
}
}
TEST(sequence_arithmetic_cuda, run_test) {
lite::Tensor x, y, x_cpu, y_cpu;
lite::Tensor out, out_cpu, out_ref;
lite::LoD x_lod{{0, 2, 5, 9}}, y_lod{{0, 2, 5, 9}};
prepare_input(&x_cpu, x_lod);
prepare_input(&y_cpu, y_lod);
x.Resize(x_cpu.dims());
x.set_lod(x_cpu.lod());
auto x_cpu_data = x_cpu.mutable_data<float>();
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
y.Resize(y_cpu.dims());
y.set_lod(y_cpu.lod());
auto y_cpu_data = y_cpu.mutable_data<float>();
y.Assign<float, lite::DDim, TARGET(kCUDA)>(y_cpu_data, y_cpu.dims());
operators::SequenceArithmeticParam param;
param.X = &x;
param.Y = &y;
param.Out = &out;
param.op_type = 1;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
SequenceArithmeticCompute sequence_arithmetic;
sequence_arithmetic.SetContext(std::move(ctx));
sequence_arithmetic.SetParam(param);
sequence_arithmetic.Run();
cudaDeviceSynchronize();
auto out_data = out.mutable_data<float>(TARGET(kCUDA));
out_cpu.Resize(out.dims());
auto out_cpu_data = out_cpu.mutable_data<float>();
CopySync<TARGET(kCUDA)>(
out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);
sequence_arithmetic_compute_ref(x_cpu, y_cpu, &out_ref, param.op_type);
auto out_ref_data = out_ref.data<float>();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_cpu_data[i], out_ref_data[i], 1e-3);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -46,6 +46,7 @@ add_kernel(search_seq_depadding_compute_x86 X86 basic SRCS search_seq_depadding_
add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type)
add_kernel(sequence_arithmetic_compute_x86 X86 basic SRCS sequence_arithmetic_compute.cc DEPS ${lite_kernel_deps})
# for content-dnn specific
add_kernel(search_aligned_mat_mul_compute_x86 X86 extra SRCS search_aligned_mat_mul_compute.cc DEPS ${lite_kernel_deps} blas)
......@@ -87,3 +88,4 @@ endif()
lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
lite_cc_test(test_sequence_arithmetic_compute_x86 SRCS sequence_arithmetic_compute_test.cc DEPS sequence_arithmetic_compute_x86)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_arithmetic_compute.h"
REGISTER_LITE_KERNEL(
sequence_arithmetic,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SequenceArithmeticCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
REGISTER_LITE_KERNEL(
search_seq_arithmetic,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SequenceArithmeticCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <cstring>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SequenceArithmeticCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceArithmeticParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
auto x = param.X;
auto y = param.Y;
auto out = param.Out;
int op_type = param.op_type;
out->Resize(x->dims());
out->set_lod(x->lod());
auto x_data = x->data<T>();
auto y_data = y->data<T>();
auto out_data = out->mutable_data<T>();
auto x_seq_offset = x->lod()[0];
auto y_seq_offset = y->lod()[0];
int seq_num = x_seq_offset.size() - 1;
int inner_size = (x->numel()) / (x->dims()[0]);
// sum
if (op_type == 1) {
for (int i = 0; i < seq_num; i++) {
int len_x = (x_seq_offset[i + 1] - x_seq_offset[i]) * inner_size;
int len_y = (y_seq_offset[i + 1] - y_seq_offset[i]) * inner_size;
auto input_x = x_data + x_seq_offset[i] * inner_size;
auto input_y = y_data + y_seq_offset[i] * inner_size;
auto t_out = out_data + x_seq_offset[i] * inner_size;
int len = std::min(len_x, len_y);
for (int j = 0; j < len; j++) {
t_out[j] = input_x[j] + input_y[j];
}
if (len_x > len) {
memcpy(t_out + len, input_x + len, sizeof(T) * (len_x - len));
}
}
}
// sub
if (op_type == 2) {
for (int i = 0; i < seq_num; i++) {
int len_x = (x_seq_offset[i + 1] - x_seq_offset[i]) * inner_size;
int len_y = (y_seq_offset[i + 1] - y_seq_offset[i]) * inner_size;
auto input_x = x_data + x_seq_offset[i] * inner_size;
auto input_y = y_data + y_seq_offset[i] * inner_size;
auto t_out = out_data + x_seq_offset[i] * inner_size;
int len = std::min(len_x, len_y);
for (int j = 0; j < len; j++) {
t_out[j] = input_x[j] - input_y[j];
}
if (len_x > len) {
memcpy(t_out + len, input_x + len, sizeof(T) * (len_x - len));
}
}
}
// mul
if (op_type == 3) {
for (int i = 0; i < seq_num; i++) {
int len_x = (x_seq_offset[i + 1] - x_seq_offset[i]) * inner_size;
int len_y = (y_seq_offset[i + 1] - y_seq_offset[i]) * inner_size;
auto input_x = x_data + x_seq_offset[i] * inner_size;
auto input_y = y_data + y_seq_offset[i] * inner_size;
auto t_out = out_data + x_seq_offset[i] * inner_size;
int len = std::min(len_x, len_y);
for (int j = 0; j < len; j++) {
t_out[j] = input_x[j] * input_y[j];
}
if (len_x > len) {
memcpy(t_out + len, input_x + len, sizeof(T) * (len_x - len));
}
}
}
}
virtual ~SequenceArithmeticCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_arithmetic_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
void sequence_arithmetic_compute_ref(const Tensor& x,
const Tensor& y,
Tensor* out,
int op_type) {
auto x_data = x.data<float>();
auto y_data = y.data<float>();
out->Resize(x.dims());
out->set_lod(x.lod());
auto out_data = out->mutable_data<float>();
auto x_seq_offset = x.lod()[0];
auto y_seq_offset = y.lod()[0];
int seq_num = x_seq_offset.size() - 1;
int inner_size = x.numel() / x.dims()[0];
for (int i = 0; i < seq_num; i++) {
int len_x = (x_seq_offset[i + 1] - x_seq_offset[i]) * inner_size;
int len_y = (y_seq_offset[i + 1] - y_seq_offset[i]) * inner_size;
auto input_x = x_data + x_seq_offset[i] * inner_size;
auto input_y = y_data + y_seq_offset[i] * inner_size;
auto t_out = out_data + x_seq_offset[i] * inner_size;
int len = std::min(len_x, len_y);
for (int j = 0; j < len; j++) {
switch (op_type) {
case 1:
t_out[j] = input_x[j] + input_y[j];
break;
case 2:
t_out[j] = input_x[j] - input_y[j];
break;
case 3:
t_out[j] = input_x[j] * input_y[j];
break;
default:
break;
}
}
if (len_x > len) {
memcpy(t_out + len, input_x + len, sizeof(float) * (len_x - len));
}
}
}
void prepare_input(Tensor* x, const LoD& x_lod) {
x->Resize({static_cast<int64_t>(x_lod[0].back()), 3});
x->set_lod(x_lod);
auto x_data = x->mutable_data<float>();
for (int i = 0; i < x->numel(); i++) {
x_data[i] = (i - x->numel() / 2) * 1.1;
}
}
TEST(sequence_arithmetic_x86, retrive_op) {
auto sequence_arithmetic =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_arithmetic");
ASSERT_FALSE(sequence_arithmetic.empty());
ASSERT_TRUE(sequence_arithmetic.front());
}
TEST(sequence_arithmetic_x86, init) {
SequenceArithmeticCompute<float> sequence_arithmetic;
ASSERT_EQ(sequence_arithmetic.precision(), PRECISION(kFloat));
ASSERT_EQ(sequence_arithmetic.target(), TARGET(kX86));
}
TEST(sequence_arithmetic_x86, run_test) {
SequenceArithmeticCompute<float> sequence_arithmetic;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
lite::Tensor x, y, out, out_ref;
lite::LoD x_lod{{0, 2, 5, 9}}, y_lod{{0, 2, 5, 9}};
prepare_input(&x, x_lod);
prepare_input(&y, y_lod);
operators::SequenceArithmeticParam param;
param.X = &x;
param.Y = &y;
param.Out = &out;
param.op_type = 1;
sequence_arithmetic.SetContext(std::move(ctx));
sequence_arithmetic.SetParam(param);
sequence_arithmetic.Run();
sequence_arithmetic_compute_ref(x, y, &out_ref, param.op_type);
auto out_data = out.data<float>();
auto out_ref_data = out_ref.data<float>();
for (int i = 0; i < out.numel(); i++) {
EXPECT_NEAR(out_data[i], out_ref_data[i], 1e-3);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(sequence_arithmetic, kX86, kFloat, kNCHW, def);
......@@ -85,6 +85,7 @@ add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc
add_operator(search_grnn_op_lite extra SRCS search_grnn_op.cc DEPS ${op_DEPS})
add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS})
add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS})
add_operator(sequence_arithmetic_op_lite extra SRCS sequence_arithmetic_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
......@@ -769,6 +769,13 @@ struct SequenceConcatParam {
lite::Tensor* Out{};
};
struct SequenceArithmeticParam {
const lite::Tensor* X{};
const lite::Tensor* Y{};
int op_type{1};
lite::Tensor* Out{};
};
struct ReduceMaxParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_arithmetic_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceArithmeticOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Y);
CHECK_EQ(param_.X->dims().size(), 2) << "Input X should a 2-D Tensor";
CHECK_EQ(param_.Y->dims().size(), 2) << "Input Y should a 2-D Tensor";
CHECK_OR_FALSE(param_.Out);
return true;
}
bool SequenceArithmeticOp::InferShape() const {
param_.Out->Resize(param_.X->dims());
param_.Out->set_lod(param_.X->lod());
return true;
}
bool SequenceArithmeticOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
param_.X = scope->FindTensor(opdesc.Input("X").front());
param_.Y = scope->FindTensor(opdesc.Input("Y").front());
param_.Out = scope->FindMutableTensor(opdesc.Input("Out").front());
param_.op_type = opdesc.GetAttr<int>("op_type");
CHECK(param_.X);
CHECK(param_.Y);
CHECK(param_.Out);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_arithmetic,
paddle::lite::operators::SequenceArithmeticOp);
REGISTER_LITE_OP(search_seq_arithmetic,
paddle::lite::operators::SequenceArithmeticOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceArithmeticOp : public OpLite {
public:
SequenceArithmeticOp() {}
explicit SequenceArithmeticOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_arithmetic"; }
private:
mutable SequenceArithmeticParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册