未验证 提交 acf09294 编写于 作者: W Wilber 提交者: GitHub

add sequence_reverse op and kerenl for arm and cuda test=develop (#2397)

- add sequence_reverse op

- add sequence_reverse kernel for x86 and cuda

- add sequence_reverse_test for x86 and cuda
上级 694f7517
...@@ -22,6 +22,7 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k ...@@ -22,6 +22,7 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps}) add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps}) add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps}) add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps}) add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
...@@ -39,6 +40,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp ...@@ -39,6 +40,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda) nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda ) nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda) nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda) nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
if(LITE_BUILD_EXTRA) if(LITE_BUILD_EXTRA)
nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda) nv_test(lookup_table_compute_cuda_test SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_cuda)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_reverse_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__host__ __device__ inline size_t UpperBound(const T* x,
size_t num,
const T& val) {
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/upper_bound
auto* first = x;
int64_t count = static_cast<int64_t>(num);
while (count > 0) {
auto step = (count >> 1);
auto* it = first + step;
if (val < *it) {
count = step;
} else {
first = ++it;
count -= (step + 1);
}
}
return static_cast<size_t>(first - x);
}
__global__ void SequenceReverseKernelGridIsOne(const float* x,
float* y,
const int64_t* lod,
size_t lod_count,
int64_t row_numel) {
int64_t idx = static_cast<int64_t>(threadIdx.x);
auto row_idx_x = idx / row_numel;
auto lod_idx = UpperBound(lod, lod_count, row_idx_x);
auto row_idx_y = lod[lod_idx - 1] + (lod[lod_idx] - 1 - row_idx_x);
auto idx_y = row_idx_y * row_numel + idx % row_numel;
y[idx_y] = x[idx];
}
__global__ void SequenceReverseKernel(const float* x,
float* y,
const int64_t* lod,
size_t lod_count,
int64_t row_numel,
size_t limit) {
int64_t idx = static_cast<int64_t>(blockIdx.x * blockDim.x + threadIdx.x);
if (idx < limit) {
auto row_idx_x = idx / row_numel;
auto lod_idx = UpperBound(lod, lod_count, row_idx_x);
auto row_idx_y = lod[lod_idx - 1] + (lod[lod_idx] - 1 - row_idx_x);
auto idx_y = row_idx_y * row_numel + idx % row_numel;
y[idx_y] = x[idx];
}
}
void SequenceReverseCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
size_t limit = static_cast<size_t>(param.X->numel());
int64_t row_numel = static_cast<int64_t>(limit / param.X->dims()[0]);
const auto* x_data = param.X->data<float>();
auto y_data = param.Out->mutable_data<float>(TARGET(kCUDA));
CHECK_NE(x_data, y_data)
<< "SequenceReverse Op does not support in-place operation";
const auto lod = param.X->lod()[param.X->lod().size() - 1];
const size_t lod_count = lod.size();
lod_cuda.Resize({static_cast<int64_t>(lod.size())});
int64_t* lod_data = lod_cuda.mutable_data<int64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(lod_data,
lod.data(),
sizeof(int64_t) * lod.size(),
IoDirection::HtoD,
stream);
constexpr int num_threads = 1024;
int block_size = limit <= num_threads ? limit : num_threads;
int grid_size = (limit + num_threads - 1) / num_threads;
if (grid_size == 1) {
SequenceReverseKernelGridIsOne<<<1, block_size, 0, stream>>>(
x_data, y_data, lod_data, lod_count, row_numel);
} else {
SequenceReverseKernel<<<grid_size, block_size, 0, stream>>>(
x_data, y_data, lod_data, lod_count, row_numel, limit);
}
cudaError_t error = cudaGetLastError();
if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_reverse,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SequenceReverseCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SequenceReverseCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceReverseParam;
void Run() override;
virtual ~SequenceReverseCompute() = default;
private:
lite::Tensor lod_cuda;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_reverse_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) {
const auto* x_data = x->data<float>();
auto seq_offset = x->lod()[x->lod().size() - 1];
int width = x->numel() / x->dims()[0];
auto* y_data = y->mutable_data<float>();
for (int i = 0; i < static_cast<int>(seq_offset.size()) - 1; ++i) {
auto start_pos = seq_offset[i];
auto end_pos = seq_offset[i + 1];
for (auto pos = start_pos; pos < end_pos; ++pos) {
auto cur_pos = end_pos - pos - 1 + start_pos;
std::memcpy(y_data + pos * width,
x_data + cur_pos * width,
width * sizeof(float));
}
}
}
TEST(sequence_reverse_cuda, normal) {
SequenceReverseCompute seq_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
auto& context = ctx->As<CUDAContext>();
operators::SequenceReverseParam param;
lite::Tensor x, x_cpu, x_ref;
lite::Tensor y, y_cpu, y_ref;
int32_t lod_len = 10, feature_len = 4;
LoD lod_info{{0, 2, 4}, {0, 3, 5, 6, 10}};
x.Resize({lod_len, feature_len});
x_cpu.Resize({lod_len, feature_len});
x_ref.Resize({lod_len, feature_len});
y.Resize({lod_len, feature_len});
y_cpu.Resize({lod_len, feature_len});
y_ref.Resize({lod_len, feature_len});
x.set_lod(lod_info);
x_cpu.set_lod(lod_info);
x_ref.set_lod(lod_info);
y.set_lod(lod_info);
y_cpu.set_lod(lod_info);
y_ref.set_lod(lod_info);
auto* y_data = y.mutable_data<float>(TARGET(kCUDA));
float* x_cpu_data = x_cpu.mutable_data<float>();
float* x_ref_data = x_ref.mutable_data<float>();
float* y_cpu_data = y_cpu.mutable_data<float>();
float* y_ref_data = y_ref.mutable_data<float>();
for (int i = 0; i < x_cpu.numel(); ++i) {
x_cpu_data[i] = (i - 2.0) * 1.0;
x_ref_data[i] = (i - 2.0) * 1.0;
}
x.Assign<float, lite::DDim, TARGET(kCUDA)>(x_cpu_data, x_cpu.dims());
param.X = &x;
param.Out = &y;
seq_kernel.SetParam(param);
cudaStream_t stream;
cudaStreamCreate(&stream);
context.SetExecStream(stream);
seq_kernel.SetContext(std::move(ctx));
seq_kernel.Run();
cudaDeviceSynchronize();
CopySync<TARGET(kCUDA)>(
y_cpu_data, y_data, sizeof(float) * y.numel(), IoDirection::DtoH);
sequence_reverse_ref(&x_ref, &y_ref);
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_cpu_data[i], y_ref_data[i], 1e-5);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -34,6 +34,7 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps ...@@ -34,6 +34,7 @@ add_kernel(mul_compute_x86 X86 basic SRCS mul_compute.cc DEPS ${lite_kernel_deps
add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps}) add_kernel(concat_compute_x86 X86 basic SRCS concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps}) add_kernel(shape_compute_x86 X86 basic SRCS shape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling) add_kernel(sequence_pool_compute_x86 X86 basic SRCS sequence_pool_compute.cc DEPS ${lite_kernel_deps} sequence_pooling)
add_kernel(sequence_reverse_compute_x86 X86 basic SRCS sequence_reverse_compute.cc DEPS ${lite_kernel_deps})
add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax) add_kernel(softmax_compute_x86 X86 basic SRCS softmax_compute.cc DEPS ${lite_kernel_deps} softmax)
add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps}) add_kernel(elementwise_compute_x86 X86 basic SRCS elementwise_compute.cc DEPS ${lite_kernel_deps})
add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps}) add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${lite_kernel_deps})
...@@ -55,6 +56,7 @@ lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_b ...@@ -55,6 +56,7 @@ lite_cc_test(test_fill_constant_batch_size_like_compute_x86 SRCS fill_constant_b
lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86) lite_cc_test(test_reshape_compute_x86 SRCS reshape_compute_test.cc DEPS reshape_compute_x86)
lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86) lite_cc_test(test_concat_compute_x86 SRCS concat_compute_test.cc DEPS concat_compute_x86)
lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86) lite_cc_test(test_sequence_pool_compute_x86 SRCS sequence_pool_compute_test.cc DEPS sequence_pool_compute_x86)
lite_cc_test(test_sequence_reverse_compute_x86 SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_x86)
lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86) lite_cc_test(test_shape_compute_x86 SRCS shape_compute_test.cc DEPS shape_compute_x86)
lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86) lite_cc_test(test_batch_norm_compute_x86 SRCS batch_norm_compute_test.cc DEPS batch_norm_compute_x86)
lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86) lite_cc_test(test_softmax_compute_x86 SRCS softmax_compute_test.cc DEPS softmax_compute_x86)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_reverse_compute.h"
REGISTER_LITE_KERNEL(sequence_reverse,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SequenceReverseCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
template <typename T>
class SequenceReverseCompute
: public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
public:
using param_t = operators::SequenceReverseParam;
void Run() override {
auto& param = *param_.get_mutable<operators::SequenceReverseParam>();
auto* output = param.Out;
const auto* din = param.X->data<T>();
T* dout = output->mutable_data<T>();
CHECK_NE(din, dout)
<< "SequenceReverse Op does not support in-place operation";
const auto lod = param.X->lod()[param.X->lod().size() - 1];
const size_t lod_count = lod.size();
size_t limit = static_cast<size_t>(param.X->numel());
size_t row_numel = static_cast<size_t>(limit / param.X->dims()[0]);
for (size_t idx = 0; idx < lod_count - 1; ++idx) {
auto start_pos = lod[idx];
auto end_pos = lod[idx + 1];
for (auto pos = start_pos; pos < end_pos; ++pos) {
auto cur_pos = end_pos - pos - 1 + start_pos;
std::memcpy(dout + pos * row_numel,
din + cur_pos * row_numel,
row_numel * sizeof(T));
}
}
output->set_lod(param.X->lod());
}
virtual ~SequenceReverseCompute() = default;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/sequence_reverse_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
namespace {
static void sequence_reverse_ref(const lite::Tensor* x, lite::Tensor* y) {
const auto* x_data = x->data<float>();
auto seq_offset = x->lod()[x->lod().size() - 1];
int width = x->numel() / x->dims()[0];
auto* y_data = y->mutable_data<float>();
for (int i = 0; i < seq_offset.size() - 1; ++i) {
auto start_pos = seq_offset[i];
auto end_pos = seq_offset[i + 1];
for (auto pos = start_pos; pos < end_pos; ++pos) {
auto cur_pos = end_pos - pos - 1 + start_pos;
std::memcpy(y_data + pos * width,
x_data + cur_pos * width,
width * sizeof(float));
}
}
}
} // namespace
TEST(sequence_reverse_x86, retrive_op) {
auto sequence_reverse =
KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
"sequence_reverse");
ASSERT_FALSE(sequence_reverse.empty());
ASSERT_TRUE(sequence_reverse.front());
}
TEST(sequence_reverse_x86, init) {
SequenceReverseCompute<float> sequence_reverse;
ASSERT_EQ(sequence_reverse.precision(), PRECISION(kFloat));
ASSERT_EQ(sequence_reverse.target(), TARGET(kX86));
}
TEST(sequence_reverse_x86, run_test) {
SequenceReverseCompute<float> seq_kernel;
std::unique_ptr<KernelContext> ctx(new KernelContext);
operators::SequenceReverseParam param;
lite::Tensor x, x_ref;
lite::Tensor y, y_ref;
int32_t lod_len = 10, feature_len = 4;
LoD lod_info{{0, 2, 4}, {0, 3, 5, 6, 10}};
x.Resize({lod_len, feature_len});
x_ref.Resize({lod_len, feature_len});
y.Resize({lod_len, feature_len});
y_ref.Resize({lod_len, feature_len});
x.set_lod(lod_info);
x_ref.set_lod(lod_info);
y.set_lod(lod_info);
y_ref.set_lod(lod_info);
auto* y_data = y.mutable_data<float>();
float* x_data = x.mutable_data<float>();
float* x_ref_data = x_ref.mutable_data<float>();
float* y_ref_data = y_ref.mutable_data<float>();
for (int i = 0; i < x.numel(); ++i) {
x_ref_data[i] = (i - 2.0) * 1.0;
x_data[i] = (i - 2.0) * 1.0;
}
param.X = &x;
param.Out = &y;
seq_kernel.SetParam(param);
seq_kernel.SetContext(std::move(ctx));
seq_kernel.Run();
sequence_reverse_ref(&x_ref, &y_ref);
for (int i = 0; i < y.numel(); i++) {
EXPECT_NEAR(y_data[i], y_ref_data[i], 1e-5);
}
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
USE_LITE_KERNEL(sequence_reverse, kX86, kFloat, kNCHW, def);
...@@ -78,6 +78,7 @@ add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS}) ...@@ -78,6 +78,7 @@ add_operator(assign_value_op extra SRCS assign_value_op.cc DEPS ${op_DEPS})
add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS}) add_operator(fake_quantize_dequantize_moving_avg_abs_max_op extra SRCS fake_quantize_dequantize_moving_avg_max_abs.cc DEPS ${op_DEPS})
add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS}) add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wise_dequantize_max_abs.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS}) add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS}) add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS})
......
...@@ -751,6 +751,11 @@ struct SequenceExpandAsParam { ...@@ -751,6 +751,11 @@ struct SequenceExpandAsParam {
lite::Tensor* out{nullptr}; lite::Tensor* out{nullptr};
}; };
struct SequenceReverseParam {
const lite::Tensor* X{};
lite::Tensor* Out{};
};
struct SequenceConcatParam { struct SequenceConcatParam {
std::vector<lite::Tensor*> X{}; std::vector<lite::Tensor*> X{};
lite::Tensor* Out{}; lite::Tensor* Out{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_reverse_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceReverseOp::CheckShape() const {
CHECK_OR_FALSE(param_.X);
CHECK_OR_FALSE(param_.Out);
CHECK_EQ(param_.X->lod().empty(), false)
<< "Input(X) Tensor of SequenceReverseOp does not contain "
"LoD information.";
CHECK_GE(param_.X->dims().size(), 2)
<< "Rank of Input(X) must be not less than 2.";
return true;
}
bool SequenceReverseOp::InferShape() const {
const auto *input = param_.X;
auto out_dims = input->dims();
param_.Out->Resize(out_dims);
return true;
}
bool SequenceReverseOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
param_.X = const_cast<lite::Tensor *>(
&scope->FindVar(opdesc.Input("X").front())->Get<lite::Tensor>());
param_.Out =
scope->FindVar(opdesc.Output("Y").front())->GetMutable<lite::Tensor>();
CHECK(param_.X);
CHECK(param_.Out);
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_reverse, paddle::lite::operators::SequenceReverseOp);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceReverseOp : public OpLite {
public:
SequenceReverseOp() {}
explicit SequenceReverseOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_reverse"; }
private:
mutable SequenceReverseParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册