未验证 提交 8a1d942a 编写于 作者: W Wilber 提交者: GitHub

add sequence_concat op kernel and test test=develop (#2414)

- add sequence_concat op

- add sequence_concat kernel for x86 and cuda

- add sequence_concat_test for x86 and cuda
上级 518a87ef
......@@ -22,6 +22,7 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
lite_cc_test(calib_compute_cuda_test SRCS calib_compute_cuda_test.cc DEPS calib_compute_cuda)
......@@ -38,6 +39,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
if(LITE_BUILD_EXTRA)
nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
endif()
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_concat_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
const int CUDA_NUM_THREADS = 512;
template <typename Dtype>
__global__ void ker_sequence_concat(Dtype* out_data,
const uint64_t* in_locate_data,
const int* o2i_map,
const int* o2i_w_map,
const int seq_num,
const int emb_size,
const int count) {
int idx = blockIdx.x * blockDim.x + threadIdx.x;
for (int tid = idx; tid < count; tid += blockDim.x * gridDim.x) {
int emb_id = tid % emb_size;
int word_id = tid / emb_size;
int input_id = o2i_map[word_id];
int cur_work_id = o2i_w_map[word_id];
const Dtype* in_data = reinterpret_cast<const Dtype*>(
reinterpret_cast<uintptr_t>(in_locate_data[input_id]));
out_data[tid] = in_data[cur_work_id * emb_size + emb_id];
}
}
void SequenceConcatCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
int seq_num = param.X[0]->lod()[0].size() - 1;
const int emb_size = param.X[0]->numel() / param.X[0]->dims()[0];
std::vector<uint64_t> in_locate_vec;
for (size_t i = 0; i < param.X.size(); ++i) {
in_locate_vec.push_back(
reinterpret_cast<uintptr_t>(param.X[i]->data<float>()));
}
in_locate_tensor.Resize({static_cast<int64_t>(in_locate_vec.size())});
std::vector<int> out2in_map;
std::vector<int> out2in_word_map;
for (int i = 0; i < seq_num; ++i) {
for (int j = 0; j < param.X.size(); ++j) {
auto offset = param.X[j]->lod()[0];
int cur_len = offset[i + 1] - offset[i];
for (int k = 0; k < cur_len; ++k) {
out2in_map.push_back(j);
out2in_word_map.push_back(offset[i] + k);
}
}
}
int word_num = out2in_map.size();
out2in_map_tensor.Resize({word_num});
out2in_word_map_tensor.Resize({word_num});
int* gpu_o2i_map_data = out2in_map_tensor.mutable_data<int>(TARGET(kCUDA));
int* gpu_o2i_w_map_data =
out2in_word_map_tensor.mutable_data<int>(TARGET(kCUDA));
uint64_t* gpu_in_locate_data =
in_locate_tensor.mutable_data<uint64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(gpu_o2i_map_data,
out2in_map.data(),
sizeof(int) * out2in_map.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(gpu_o2i_w_map_data,
out2in_word_map.data(),
sizeof(int) * out2in_word_map.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(gpu_in_locate_data,
in_locate_vec.data(),
sizeof(uint64_t) * in_locate_vec.size(),
IoDirection::HtoD,
stream);
int count = param.X[0]->numel();
for (int i = 1; i < param.X.size(); ++i) {
count += param.X[i]->numel();
}
int blocks = (count + CUDA_NUM_THREADS - 1) / CUDA_NUM_THREADS;
ker_sequence_concat<float><<<blocks, CUDA_NUM_THREADS, 0, stream>>>(
out_data,
gpu_in_locate_data,
gpu_o2i_map_data,
gpu_o2i_w_map_data,
seq_num,
emb_size,
count);
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_concat,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SequenceConcatCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SequenceConcatCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceConcatParam;
void Run() override;
virtual ~SequenceConcatCompute() = default;
private:
lite::Tensor out2in_map_tensor;
lite::Tensor out2in_word_map_tensor;
lite::Tensor in_locate_tensor;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_concat_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
namespace {
inline LoD ConcatLoD(const std::vector<lite::Tensor*>& xs,
std::vector<lite::Tensor>* xs_in_order) {
std::vector<size_t> result;
result.resize(xs[0]->lod()[0].size());
for (size_t i = 1; i < result.size(); ++i) {
size_t sum = 0;
for (size_t j = 0; j < xs.size(); ++j) {
auto& x_lod = xs[j]->lod()[0];
if (x_lod[i - 1] < x_lod[i]) {
xs_in_order->emplace_back(xs[j]->Slice<float>(x_lod[i - 1], x_lod[i]));
}
sum += x_lod[i];
}
result[i] = sum;
}
LoD lod;
lod.emplace_back(result);
return lod;
}
static void sequence_concat_ref(const std::vector<lite::Tensor*>& xs,
lite::Tensor* out) {
std::vector<int64_t> out_dims;
int64_t batch_size = 0;
int64_t feature_size = 0;
for (const auto& tensor : xs) {
const auto x_dims = tensor->dims();
if (out_dims.empty()) {
out_dims = x_dims.Vectorize();
}
batch_size += x_dims[0];
if (feature_size == 0) {
feature_size = x_dims.production() / x_dims[0];
} else {
CHECK_EQ(feature_size, x_dims.production() / x_dims[0])
<< "Inputs of sequence concat must have same feature size";
}
}
out_dims[0] = batch_size;
out->Resize(out_dims);
std::vector<lite::Tensor> x_in_order;
out->set_lod(ConcatLoD(xs, &x_in_order));
int num = x_in_order.size();
std::vector<int64_t> input_cols(num);
for (int i = 0; i < num; ++i) {
input_cols[i] = x_in_order[i].numel();
}
float* out_data = out->mutable_data<float>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = x_in_order[j].data<float>();
memcpy(out_data + col_idx, input_data, sizeof(float) * col_len);
col_idx += col_len;
}
}
#define PREPARE_INPUT_DATA(name) \
name.Resize({name##_lod_len, feature_len}); \
name##_cpu.Resize({name##_lod_len, feature_len}); \
name##_ref.Resize({name##_lod_len, feature_len}); \
name.set_lod(lod_info_##name); \
name##_cpu.set_lod(lod_info_##name); \
name##_ref.set_lod(lod_info_##name); \
float* name##_cpu_data = name##_cpu.mutable_data<float>(); \
float* name##_ref_data = name##_ref.mutable_data<float>(); \
for (int i = 0; i < name##_cpu.numel(); ++i) { \
name##_cpu_data[i] = (i - 2.0) * 1.0; \
name##_ref_data[i] = (i - 2.0) * 1.0; \
} \
name.Assign<float, lite::DDim, TARGET(kCUDA)>(name##_cpu_data, \
name##_cpu.dims());
#define PREPARE_OUTPUT_INFO(name) \
name##_cpu.Resize({y_lod_len, feature_len}); \
name##_ref.Resize({y_lod_len, feature_len}); \
name.Resize({y_lod_len, feature_len}); \
float* name##_cpu_data = name##_cpu.mutable_data<float>();
} // namespace
TEST(sequence_concat_cuda, normal) {
SequenceConcatCompute seq_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::SequenceConcatParam param;
lite::Tensor x1, x2, x3, x1_cpu, x2_cpu, x3_cpu, x1_ref, x2_ref, x3_ref;
lite::Tensor y, y_cpu, y_ref;
int32_t x1_lod_len = 10, feature_len = 4;
int32_t x2_lod_len = 4, x3_lod_len = 8;
int32_t y_lod_len = x1_lod_len + x2_lod_len + x3_lod_len;
LoD lod_info_x1{{0, 3, 5, 6, 10}};
LoD lod_info_x2{{0, 1, 2, 3, 4}};
LoD lod_info_x3{{0, 2, 4, 6, 8}};
LoD lod_info_y{{0, 0, 0, 0, 0}};
for (size_t i = 0; i < lod_info_x1[0].size(); ++i) {
lod_info_y[0][i] =
lod_info_x1[0][i] + lod_info_x2[0][i] + lod_info_x3[0][i];
}
PREPARE_INPUT_DATA(x1);
PREPARE_INPUT_DATA(x2);
PREPARE_INPUT_DATA(x3);
PREPARE_OUTPUT_INFO(y);
param.X = std::vector<lite::Tensor*>({&x1, &x2, &x3});
param.Out = &y;
seq_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
seq_kernel.SetContext(std::move(ctx));
seq_kernel.Run();
cudaDeviceSynchronize();
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
std::vector<lite::Tensor*> input_ref({&x1_ref, &x2_ref, &x3_ref});
sequence_concat_ref(input_ref, &y_ref);
float* y_ref_data = y_ref.mutable_data<float>();
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], y_ref_data[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -39,6 +39,7 @@ add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${li
add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_x86 X86 basic SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
if(NOT LITE_WITH_X86)
return()
......@@ -67,3 +68,4 @@ lite_cc_test(test_matmul_compute_x86 SRCS matmul_compute_test.cc DEPS matmul_com
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_concat_compute.h"
REGISTER_LITE_KERNEL(sequence_concat,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SequenceConcatCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
inline LoD ConcatLoD(const std::vector<lite::Tensor*>& xs,
std::vector<lite::Tensor>* xs_in_order) {
std::vector<size_t> result;
result.resize(xs[0]->lod()[0].size());
for (size_t i = 1; i < result.size(); ++i) {
size_t sum = 0;
for (size_t j = 0; j < xs.size(); ++j) {
auto& x_lod = xs[j]->lod()[0];
if (x_lod[i - 1] < x_lod[i]) {
xs_in_order->emplace_back(xs[j]->Slice<T>(x_lod[i - 1], x_lod[i]));
}
sum += x_lod[i];
}
result[i] = sum;
}
LoD lod;
lod.emplace_back(result);
return lod;
}
template <typename T>
class SequenceConcatCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceConcatParam;
void Run() override {
auto& param = *param_.get_mutable<param_t>();
// auto& param = Param<param_t>();
T* dout = param.Out->mutable_data<T>();
std::vector<lite::Tensor> x_in_order;
param.Out->set_lod(ConcatLoD<T>(param.X, &x_in_order));
int num = x_in_order.size();
int out_rows = 1;
std::vector<int64_t> input_cols(num);
for (int i = 0; i < num; ++i) {
input_cols[i] = x_in_order[i].numel() / out_rows;
}
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = x_in_order[j].data<T>();
memcpy(dout + col_idx, input_data, sizeof(T) * col_len);
col_idx += col_len;
}
}
virtual ~SequenceConcatCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_concat_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
namespace {
inline LoD ConcatLoD(const std::vector<lite::Tensor*>& xs,
std::vector<lite::Tensor>* xs_in_order) {
std::vector<size_t> result;
result.resize(xs[0]->lod()[0].size());
for (size_t i = 1; i < result.size(); ++i) {
size_t sum = 0;
for (size_t j = 0; j < xs.size(); ++j) {
auto& x_lod = xs[j]->lod()[0];
if (x_lod[i - 1] < x_lod[i]) {
xs_in_order->emplace_back(xs[j]->Slice<float>(x_lod[i - 1], x_lod[i]));
}
sum += x_lod[i];
}
result[i] = sum;
}
LoD lod;
lod.emplace_back(result);
return lod;
}
static void sequence_concat_ref(const std::vector<lite::Tensor*>& xs,
lite::Tensor* out) {
std::vector<int64_t> out_dims;
int64_t batch_size = 0;
int64_t feature_size = 0;
for (const auto& tensor : xs) {
const auto x_dims = tensor->dims();
if (out_dims.empty()) {
out_dims = x_dims.Vectorize();
}
batch_size += x_dims[0];
if (feature_size == 0) {
feature_size = x_dims.production() / x_dims[0];
} else {
CHECK_EQ(feature_size, x_dims.production() / x_dims[0])
<< "Inputs of sequence concat must have same feature size";
}
}
out_dims[0] = batch_size;
out->Resize(out_dims);
std::vector<lite::Tensor> x_in_order;
out->set_lod(ConcatLoD(xs, &x_in_order));
int num = x_in_order.size();
std::vector<int64_t> input_cols(num);
for (int i = 0; i < num; ++i) {
input_cols[i] = x_in_order[i].numel();
}
float* out_data = out->mutable_data<float>();
int col_idx = 0;
for (int j = 0; j < num; ++j) {
int col_len = input_cols[j];
auto input_data = x_in_order[j].data<float>();
memcpy(out_data + col_idx, input_data, sizeof(float) * col_len);
col_idx += col_len;
}
}
#define PREPARE_INPUT(name) \
name.Resize({name##_lod_len, feature_len}); \
name.set_lod(lod_info_##name); \
float* name##_data = name.mutable_data<float>(); \
for (int i = 0; i < name.numel(); ++i) { \
name##_data[i] = (i - 2.0) * 1.0; \
}
} // namespace
TEST(sequence_concat_x86, retrive_op) {
auto sequence_concat =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_concat");
ASSERT_FALSE(sequence_concat.empty());
ASSERT_TRUE(sequence_concat.front());
}
TEST(sequence_concat_x86, init) {
SequenceConcatCompute<float> sequence_concat;
ASSERT_EQ(sequence_concat.precision(), PRECISION(kFloat));
ASSERT_EQ(sequence_concat.target(), TARGET(kX86));
}
TEST(sequence_concat_x86, run_test) {
SequenceConcatCompute<float> seq_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
ctx->As<X86Context>();
operators::SequenceConcatParam param;
lite::Tensor x1, x2, x3;
lite::Tensor y, y_ref;
int32_t x1_lod_len = 10, feature_len = 4;
int32_t x2_lod_len = 4, x3_lod_len = 8;
int32_t y_lod_len = x1_lod_len + x2_lod_len + x3_lod_len;
LoD lod_info_x1{{0, 3, 5, 6, 10}};
LoD lod_info_x2{{0, 1, 2, 3, 4}};
LoD lod_info_x3{{0, 2, 4, 6, 8}};
LoD lod_info_y{{0, 0, 0, 0, 0}};
for (size_t i = 0; i < lod_info_x1[0].size(); ++i) {
lod_info_y[0][i] =
lod_info_x1[0][i] + lod_info_x2[0][i] + lod_info_x3[0][i];
}
PREPARE_INPUT(x1);
PREPARE_INPUT(x2);
PREPARE_INPUT(x3);
y_ref.Resize({y_lod_len, feature_len});
y.Resize({y_lod_len, feature_len});
y_ref.set_lod(lod_info_y);
y.set_lod(lod_info_y);
std::vector<lite::Tensor*> xs{&x1, &x2, &x3};
param.X = xs;
param.Out = &y;
seq_kernel.SetParam(param);
seq_kernel.SetContext(std::move(ctx));
seq_kernel.Run();
auto* y_data = y.mutable_data<float>();
sequence_concat_ref(xs, &y_ref);
float* y_ref_data = y_ref.mutable_data<float>();
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(sequence_concat, kX86, kFloat, kNCHW, def);
......@@ -79,6 +79,7 @@ add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quan
add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS})
# for OCR specific
add_operator(while_op extra SRCS while_op.cc DEPS ${op_DEPS})
......
......@@ -751,6 +751,11 @@ struct SequenceExpandAsParam {
lite::Tensor* out{nullptr};
};
struct SequenceConcatParam {
std::vector<lite::Tensor*> X{};
lite::Tensor* Out{};
};
struct ReduceMaxParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_concat_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceConcatOp::CheckShape() const {
CHECK_GT(param_.X.size(), 1)
<< "The number of input sequences is at least two.";
CHECK_OR_FALSE(param_.Out);
size_t lod_size = 0;
for (const auto &t : param_.X) {
CHECK_EQ(t->lod().empty(), false)
<< "Input Tensor of X does not contain LoD information.";
CHECK_EQ(t->lod().size(), 1) << "Only support one level sequence now.";
if (lod_size == 0) {
lod_size = t->lod()[0].size();
} else {
CHECK_EQ(t->lod()[0].size(), lod_size)
<< "The number of sequence must be same between each input";
}
}
CHECK_NE(lod_size, 0) << "Each input must have sequence information";
return true;
}
bool SequenceConcatOp::InferShape() const {
int64_t batch_size = 0;
int64_t feature_size = 0;
std::vector<int64_t> out_dims;
for (const auto &tensor : param_.X) {
const auto x_dims = tensor->dims();
if (out_dims.empty()) {
out_dims = x_dims.Vectorize();
}
batch_size += x_dims[0];
if (feature_size == 0) {
feature_size = x_dims.production() / x_dims[0];
} else {
CHECK_EQ(feature_size, x_dims.production() / x_dims[0])
<< "Inputs of sequence concat must have same feature size";
}
}
if (batch_size < 0) {
batch_size = -1; // Normalize batch size for compile time.
}
out_dims[0] = batch_size;
param_.Out->Resize(out_dims);
// LoD info will be computed in Kernel.
return true;
}
bool SequenceConcatOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
auto input_list = opdesc.Input("X");
param_.X.clear();
for (auto var : input_list) {
param_.X.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.Out) << "Output(Out) of Sequence Concat Op should not be null.";
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_concat, paddle::lite::operators::SequenceConcatOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceConcatOp : public OpLite {
public:
SequenceConcatOp() {}
explicit SequenceConcatOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_concat"; }
private:
mutable SequenceConcatParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册