提交 7f408ee8 编写于 作者: J juncaipeng 提交者: GitHub

Add content-dnn ops (#2429)

* add search_seq_depadding x86 and cuda
* add match_matrix_tensor x86
* add search_grnn x86, no test
上级 cd5d97e3
......@@ -22,6 +22,7 @@ add_kernel(dropout_compute_cuda CUDA basic SRCS dropout_compute.cc DEPS ${lite_k
add_kernel(softmax_compute_cuda CUDA basic SRCS softmax_compute.cu DEPS ${lite_kernel_deps})
add_kernel(pool_compute_cuda CUDA basic SRCS pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_seq_depadding_compute_cuda CUDA basic SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_compute_cuda CUDA basic SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_cuda CUDA basic SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS ${lite_kernel_deps})
......@@ -40,6 +41,7 @@ nv_test(softmax_compute_cuda_test SRCS softmax_compute_test.cc DEPS softmax_comp
nv_test(mul_compute_cuda_test SRCS mul_compute_test.cc DEPS mul_compute_cuda)
nv_test(dropout_compute_cuda_test SRCS dropout_compute_test.cc DEPS dropout_compute_cuda )
nv_test(bilinear_interp_compute_cuda_test SRCS bilinear_interp_compute_test.cc DEPS bilinear_interp_compute_cuda)
nv_test(search_seq_depadding_compute_cuda_test SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_cuda)
nv_test(sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda)
nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
if(LITE_BUILD_EXTRA)
......
/* Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <vector>
#include "lite/core/op_registry.h"
#include "lite/kernels/cuda/search_seq_depadding_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
// Loop header that walks `i` over [0, n): every thread starts at its global
// 1-D index (blockIdx.x * blockDim.x + threadIdx.x) and advances by the total
// number of launched threads (blockDim.x * gridDim.x). The index is a plain
// `int`.
#define CUDA_KERNEL_LOOP(i, n) \
for (int i = blockIdx.x * blockDim.x + threadIdx.x; i < (n); \
i += blockDim.x * gridDim.x)
// Row gather from the padded input into the compact output.
//
// For every flat output element `tid` in [0, count):
//   out_data[tid] = in_data[seq_id_map[tid / emb_size] * emb_size +
//                           tid % emb_size]
// i.e. seq_id_map[row] names the input row that feeds output row `row`, and
// each row holds emb_size values. `seq_num` and `max_len` are accepted but not
// read by the kernel body.
template <typename Dtype>
__global__ void ker_sequence_depadding_fwd(Dtype* out_data,
                                           const Dtype* in_data,
                                           const int* seq_id_map,
                                           const int seq_num,
                                           const int max_len,
                                           const int emb_size,
                                           const int count) {
  CUDA_KERNEL_LOOP(tid, count) {
    // Input row selected for this output row.
    const int src_row = seq_id_map[tid / emb_size];
    // Column inside the row is unchanged by the gather.
    const int col = tid % emb_size;
    out_data[tid] = in_data[src_row * emb_size + col];
  }
}
// Copies the valid (non-padding) rows of `pad` into `out`.
//
// Host side: builds an index map with one entry per output row. For sequence
// i and position j < len_i (len_i taken from src's level-0 LoD) the entry is
// i * max_len + j, where max_len is the length of the first sequence in pad's
// level-0 LoD. The map is uploaded on the execution stream and the gather
// kernel is launched on the same stream, so the two are stream-ordered.
//
// Aborts via LOG(FATAL) when either LoD holds no sequence or when the map
// would not cover every output row the kernel reads. Returns without
// launching anything when the output has no elements.
void SearchSeqDepaddingCompute::Run() {
  auto& param = this->Param<param_t>();
  auto& ctx = this->ctx_->template As<CUDAContext>();
  auto cuda_stream = ctx.exec_stream();

  auto* pad = param.pad;
  auto* src = param.src;
  auto* out = param.out;

  auto* in_data = pad->data<float>();
  auto* out_data = out->mutable_data<float>(TARGET(kCUDA));
  const int count = out->numel();
  if (count <= 0) {
    // A launch with zero blocks is an invalid configuration; there is nothing
    // to copy anyway.
    return;
  }

  const auto& pad_seq_offset = pad->lod()[0];
  const auto& src_seq_offset = src->lod()[0];
  if (pad_seq_offset.size() < 2 || src_seq_offset.size() < 2) {
    LOG(FATAL) << "Pad and Src must each hold at least one sequence.";
  }

  int max_len = pad_seq_offset[1];
  const int pad_seq_num = pad_seq_offset.size() - 1;
  const int src_seq_num = src_seq_offset.size() - 1;
  // Never index src's LoD past its end when pad carries more sequences.
  int seq_num = pad_seq_num < src_seq_num ? pad_seq_num : src_seq_num;
  int emb_size = pad->dims()[1];

  std::vector<int> seq_id_map;
  seq_id_map.reserve(src_seq_offset[seq_num] - src_seq_offset[0]);
  for (int i = 0; i < seq_num; i++) {
    int cur_len = src_seq_offset[i + 1] - src_seq_offset[i];
    for (int j = 0; j < cur_len; j++) {
      seq_id_map.push_back(i * max_len + j);
    }
  }

  int map_size = seq_id_map.size();
  // The kernel reads seq_id_map[tid / emb_size] for tid < count, so the map
  // must have more than (count - 1) / emb_size entries.
  if (emb_size <= 0 || (count - 1) / emb_size >= map_size) {
    LOG(FATAL) << "Output size does not match the source sequence lengths.";
  }

  seq_id_map_tensor.Resize({map_size, 1, 1, 1});
  int* seq_id_map_data = seq_id_map_tensor.mutable_data<int>(TARGET(kCUDA));
  TargetW::MemcpyAsync(seq_id_map_data,
                       seq_id_map.data(),
                       seq_id_map.size() * sizeof(int),
                       IoDirection::HtoD,
                       cuda_stream);

  // One thread per output element (rounded up to whole blocks); the kernel's
  // loop covers any remainder.
  int threads = 512;
  int blocks = (count - 1) / threads + 1;
  ker_sequence_depadding_fwd<<<blocks, threads, 0, cuda_stream>>>(
      out_data, in_data, seq_id_map_data, seq_num, max_len, emb_size, count);

  cudaError_t error = cudaGetLastError();
  if (error != cudaSuccess) LOG(INFO) << cudaGetErrorString(error);
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers SearchSeqDepaddingCompute as the `search_seq_depadding` kernel for
// the CUDA target (float precision, NCHW layout, alias `def`). The "Src" and
// "Pad" inputs and the "Out" output are all bound as CUDA float NCHW tensors.
REGISTER_LITE_KERNEL(search_seq_depadding,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SearchSeqDepaddingCompute,
def)
.BindInput("Src",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindInput("Pad",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA),
PRECISION(kFloat),
DATALAYOUT(kNCHW))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
// CUDA kernel object for the search_seq_depadding operator
// (float precision, NCHW layout).
class SearchSeqDepaddingCompute
    : public KernelLite<TARGET(kCUDA), PRECISION(kFloat), DATALAYOUT(kNCHW)> {
 public:
  // Wrapper used for device memory operations on the CUDA target.
  using TargetW = TargetWrapper<TARGET(kCUDA)>;
  // Parameter bundle consumed by Run().
  using param_t = operators::SearchSeqDepaddingParam;

  virtual ~SearchSeqDepaddingCompute() = default;

  // Executes the kernel; defined in the accompanying .cu file.
  void Run() override;

 private:
  // Tensor owned by the kernel object so its storage outlives a single Run().
  Tensor seq_id_map_tensor;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/search_seq_depadding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
using Tensor = lite::Tensor;
// End-to-end check of the CUDA search_seq_depadding kernel.
//
// pad: 6 x 4 values 0..23 with LoD {0, 4, 6}; src: LoD {0, 2, 3}.
// Expected output rows: pad rows 0 and 1 (first sequence, length 2) followed
// by pad row 4 (second sequence, length 1).
TEST(search_seq_depadding, normal) {
  std::unique_ptr<KernelContext> ctx(new KernelContext);
  auto& context = ctx->As<CUDAContext>();

  SearchSeqDepaddingCompute kernel;
  operators::SearchSeqDepaddingParam param;

  Tensor pad, src, out;
  pad.Resize({2 * 3, 4});
  src.Resize({3, 1});
  out.Resize({3, 4});
  LoD pad_lod{};
  pad_lod.push_back({0, 4, 6});
  pad.set_lod(pad_lod);
  LoD src_lod{};
  src_lod.push_back({0, 2, 3});
  src.set_lod(src_lod);

  Tensor pad_cpu, src_cpu, out_cpu;
  pad_cpu.Resize({2 * 3, 4});
  src_cpu.Resize({3, 1});
  out_cpu.Resize({3, 4});

  auto* pad_cpu_data = pad_cpu.mutable_data<float>();
  auto* src_cpu_data = src_cpu.mutable_data<float>();
  for (int i = 0; i < pad_cpu.numel(); ++i) {
    pad_cpu_data[i] = static_cast<float>(i);
  }
  // Only src's LoD matters to the kernel, but the buffer is still uploaded,
  // so give it defined contents instead of copying indeterminate values.
  for (int i = 0; i < src_cpu.numel(); ++i) {
    src_cpu_data[i] = 0.f;
  }

  pad.Assign<float, lite::DDim, TARGET(kCUDA)>(pad_cpu_data, pad_cpu.dims());
  src.Assign<float, lite::DDim, TARGET(kCUDA)>(src_cpu_data, src_cpu.dims());

  param.pad = &pad;
  param.src = &src;
  param.out = &out;
  kernel.SetParam(param);

  cudaStream_t stream;
  ASSERT_EQ(cudaStreamCreate(&stream), cudaSuccess);
  context.SetExecStream(stream);

  kernel.SetContext(std::move(ctx));
  kernel.Launch();
  // Surfaces asynchronous kernel failures instead of silently comparing
  // garbage below.
  ASSERT_EQ(cudaDeviceSynchronize(), cudaSuccess);

  auto* out_cpu_data = out_cpu.mutable_data<float>();
  auto* out_data = out.mutable_data<float>(TARGET(kCUDA));
  CopySync<TARGET(kCUDA)>(
      out_cpu_data, out_data, sizeof(float) * out.numel(), IoDirection::DtoH);

  std::vector<float> ref_results = {0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19};
  // Guards the indexed comparison below against reading past ref_results.
  ASSERT_EQ(static_cast<size_t>(out.numel()), ref_results.size());
  for (int i = 0; i < out.numel(); i++) {
    EXPECT_NEAR(out_cpu_data[i], ref_results[i], 1e-5);
    // LOG(INFO) << out_cpu_data[i];
  }
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
......@@ -41,6 +41,9 @@ add_kernel(batch_norm_compute_x86 X86 basic SRCS batch_norm_compute.cc DEPS ${li
add_kernel(reduce_sum_compute_x86 X86 basic SRCS reduce_compute.cc DEPS ${lite_kernel_deps})
add_kernel(lookup_table_compute_x86 X86 extra SRCS lookup_table_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_reshape_compute_x86 X86 basic SRCS sequence_reshape_compute.cc DEPS ${lite_kernel_deps})
add_kernel(match_matrix_tensor_compute_x86 X86 basic SRCS match_matrix_tensor_compute.cc DEPS ${lite_kernel_deps} blas math_function)
add_kernel(search_seq_depadding_compute_x86 X86 basic SRCS search_seq_depadding_compute.cc DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_x86 X86 basic SRCS search_grnn_compute.cc DEPS ${lite_kernel_deps})
add_kernel(sequence_concat_compute_x86 X86 basic SRCS sequence_concat_compute.cc DEPS ${lite_kernel_deps})
add_kernel(var_conv_2d_compute_x86 X86 basic SRCS var_conv_2d_compute.cc DEPS ${lite_kernel_deps} blas fluid_data_type)
......@@ -72,9 +75,11 @@ lite_cc_test(test_cast_compute_x86 SRCS cast_compute_test.cc DEPS cast_compute_x
lite_cc_test(test_pool2d_compute_x86 SRCS pool_compute_test.cc DEPS pool_compute_x86)
lite_cc_test(test_dropout_compute_x86 SRCS dropout_compute_test.cc DEPS dropout_compute_x86)
lite_cc_test(test_transpose_compute_x86 SRCS transpose_compute_test.cc DEPS transpose_compute_x86)
lite_cc_test(test_search_seq_depadding_compute_x86 SRCS search_seq_depadding_compute_test.cc DEPS search_seq_depadding_compute_x86)
if(LITE_BUILD_EXTRA)
lite_cc_test(test_lookup_table_compute_x86 SRCS lookup_table_compute_test.cc DEPS lookup_table_compute_x86)
endif()
lite_cc_test(test_sequence_concat_compute_x86 SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_x86)
lite_cc_test(test_match_matrix_compute_x86 SRCS match_matrix_tensor_compute_test.cc DEPS match_matrix_tensor_compute_x86)
lite_cc_test(test_var_conv_2d_compute_x86 SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_x86)
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/match_matrix_tensor_compute.h"
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// Computes, for every sequence pair b and every channel t in [0, dim_t):
//   out_b_t = (X_b * W_t) * Y_b^T        (len_l x len_r block)
// where X_b / Y_b are the rows of x / y selected by their level-0 LoD.
//
// Step 1: one GEMM projects all rows of x through w into `tmp`
//         ([rows(x), dim_t * dim_in]).
// Step 2: per (b, t), a GEMM multiplies the projected slice by Y_b^T and
//         writes the block at out + top_offset[b] + t * len_l * len_r.
// The output's level-0 LoD is set to the per-pair element offsets.
template <typename T>
void MatchMatrixTensorCompute<T>::Run() {
  auto& context = ctx_->As<X86Context>();
  auto& param = this->Param<param_t>();
  auto* x = param.x;
  auto* w = param.w;
  auto* y = param.y;
  auto* out = param.out;
  auto* tmp = param.tmp;
  int dim_t = param.dim_t;
  int dim_in = x->dims()[1];

  const auto& offset_l = x->lod()[0];
  const auto& offset_r = y->lod()[0];
  // size() - 1 on an empty LoD level would wrap around in size_t.
  const size_t batch = offset_l.empty() ? 0 : offset_l.size() - 1;

  // top_offset[b] is the flat start of pair b's dim_t output blocks.
  std::vector<size_t> top_offset;
  top_offset.reserve(batch + 1);
  int top_size = 0;
  top_offset.push_back(top_size);
  for (size_t b = 0; b < batch; b++) {
    int len_l = offset_l[b + 1] - offset_l[b];
    int len_r = offset_r[b + 1] - offset_r[b];
    top_size += dim_t * len_l * len_r;
    top_offset.push_back(top_size);
  }

  auto* bottom_l_data = x->template data<T>();
  auto* bottom_r_data = y->template data<T>();
  auto* t_data = w->template data<T>();
  auto* out_data = out->template mutable_data<T>();
  auto* bottom_l_trans_data = tmp->template mutable_data<T>();
  // memset takes an int fill byte; all-zero bytes encode 0 for float/double.
  memset(out_data, 0, out->dims()[0] * out->dims()[1] * sizeof(T));
  memset(bottom_l_trans_data, 0, tmp->dims()[0] * tmp->dims()[1] * sizeof(T));

  // A single BLAS handle serves both the projection and the per-pair GEMMs.
  auto blas = lite::x86::math::GetBlas<TARGET(kX86), T>(context);
  blas.GEMM(CblasNoTrans,
            CblasNoTrans,
            x->dims()[0],
            dim_t * dim_in,
            dim_in,
            1.0f,
            bottom_l_data,
            dim_in,
            t_data,
            dim_t * dim_in,
            0.0f,
            bottom_l_trans_data,
            dim_t * dim_in);

  for (size_t b = 0; b < batch; b++) {
    // Lengths and the right-hand rows depend only on b, not on t.
    const int len_l = offset_l[b + 1] - offset_l[b];
    const int len_r = offset_r[b + 1] - offset_r[b];
    const auto* r_data = bottom_r_data + offset_r[b] * dim_in;
    for (int t = 0; t < dim_t; t++) {
      auto* top_data = out_data + top_offset[b] + t * len_l * len_r;
      // Channel t of the projected rows; consecutive rows are
      // dim_t * dim_in apart, hence the lda passed below.
      const auto* l_t_data =
          bottom_l_trans_data + offset_l[b] * dim_t * dim_in + t * dim_in;
      blas.GEMM(CblasNoTrans,
                CblasTrans,
                len_l,
                len_r,
                dim_in,
                1.0f,
                l_t_data,
                dim_t * dim_in,
                r_data,
                dim_in,
                0.0f,
                top_data,
                len_r);
    }
  }

  LoD out_lod;
  out_lod.push_back(top_offset);
  out->set_lod(out_lod);
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers MatchMatrixTensorCompute<float> as the `match_matrix_tensor`
// kernel for the x86 target (float precision, NCHW layout, alias `def`).
// Inputs "X", "W", "Y" and outputs "Out", "Tmp" are bound as x86 tensors.
REGISTER_LITE_KERNEL(
match_matrix_tensor,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::MatchMatrixTensorCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("W", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Y", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Tmp", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include "lite/backends/x86/math/blas.h"
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// x86 kernel object for the match_matrix_tensor operator (float precision).
// T is the element type used by Run() when reading and writing tensor data.
template <typename T>
class MatchMatrixTensorCompute
    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  // Parameter bundle consumed by Run().
  using param_t = operators::MatchMatrixTensorParam;

  virtual ~MatchMatrixTensorCompute() = default;

  // Executes the kernel; defined in the accompanying .cc file.
  void Run() override;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/match_matrix_tensor_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// The kernel registry must return at least one non-null x86/float kernel for
// the op name "match_matrix_tensor".
TEST(match_matrix_tensor_x86, retrive_op) {
  auto found_kernels =
      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
          "match_matrix_tensor");
  ASSERT_FALSE(found_kernels.empty());
  ASSERT_TRUE(found_kernels.front());
}
// A default-constructed kernel reports float precision and the x86 target.
TEST(match_matrix_tensor_x86, init) {
  MatchMatrixTensorCompute<float> compute;
  ASSERT_EQ(compute.precision(), PRECISION(kFloat));
  ASSERT_EQ(compute.target(), TARGET(kX86));
}
// Runs the x86 match_matrix_tensor kernel on ramp-filled inputs and compares
// all 18 output values against precomputed references.
//
// x: 5 x 2 with LoD {0, 2, 5}; y: 4 x 2 with LoD {0, 3, 4}; w: 2 x 2 x 2;
// dim_t = 2. Every input holds the values 0, 1, 2, ... in flat order.
TEST(match_matrix_tensor_x86, run_test) {
  int ix = 5, iy = 4, h = 2, dim_t = 2;

  lite::Tensor x, w, y, out, tmp;
  x.Resize({ix, h});
  w.Resize({h, dim_t, h});
  y.Resize({iy, h});
  out.Resize({18, 1});
  tmp.Resize({20, 1});

  LoD x_lod{};
  x_lod.push_back({0, 2, 5});
  x.set_lod(x_lod);
  LoD y_lod{};
  y_lod.push_back({0, 3, 4});
  y.set_lod(y_lod);

  // Fills a tensor with 0, 1, 2, ... in flat order.
  auto fill_ramp = [](lite::Tensor* t) {
    auto* data = t->mutable_data<float>();
    for (int64_t i = 0; i < t->numel(); i++) {
      data[i] = static_cast<float>(i);
    }
  };
  fill_ramp(&x);
  fill_ramp(&y);
  fill_ramp(&w);

  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<X86Context>();
  MatchMatrixTensorCompute<float> mmtc;
  mmtc.SetContext(std::move(ctx));

  operators::MatchMatrixTensorParam param;
  param.x = &x;
  param.w = &w;
  param.y = &y;
  param.dim_t = dim_t;
  param.out = &out;
  param.tmp = &tmp;
  mmtc.SetParam(param);
  mmtc.Run();

  std::vector<float> ref_results = {5,   23,  41,  17,  75,  133,
                                    7,   33,  59,  27,  125, 223,
                                    323, 455, 587, 557, 793, 1029};
  auto* out_data = out.mutable_data<float>();
  const int64_t total = out.dims().production();
  for (int i = 0; i < total; i++) {
    EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
    // LOG(INFO) << out_data[i];
  }
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Declares use of the `match_matrix_tensor` kernel registered for
// kX86 / kFloat / kNCHW under the alias `def`.
// NOTE(review): presumably this keeps the registration linked into the test
// binary so the registry lookup test can find it — confirm in op_registry.h.
USE_LITE_KERNEL(match_matrix_tensor, kX86, kFloat, kNCHW, def);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_grnn_compute.h"
#include <algorithm>
#include <vector>
#include "lite/backends/x86/math/blas.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// Logistic function: returns 1 / (1 + e^(-z)).
template <typename T>
T sigmoid(T z) {
  const auto denominator = 1 + std::exp(-z);
  return 1 / denominator;
}
// Thin GEMM wrapper that derives the leading dimensions from the transpose
// flags and forwards everything to blas.GEMM:
//   lda = K when A is not transposed, otherwise M
//   ldb = N when B is not transposed, otherwise K
//   ldc = N
template <typename T>
void CallGemm(const lite::x86::math::BlasT<TARGET(kX86), T>& blas,
              const CBLAS_TRANSPOSE TransA,
              const CBLAS_TRANSPOSE TransB,
              const int M,
              const int N,
              const int K,
              const T alpha,
              const T* A,
              const T* B,
              const T beta,
              T* C) {
  int lda = M;
  if (TransA == CblasNoTrans) {
    lda = K;
  }
  int ldb = K;
  if (TransB == CblasNoTrans) {
    ldb = N;
  }
  blas.GEMM(TransA, TransB, M, N, K, alpha, A, lda, B, ldb, beta, C, N);
}
// Reorganizes `input_blob` from sequence-major to step-major order.
//
// Writes two of the kernel's parameters:
//  - param.idx_sorted_by_width: the sequence indices sorted by sequence
//    length, longest first.
//  - param.layout_input: a [dim0, dim1] tensor whose level-0 LoD has
//    max_width + 1 entries. Segment i holds row i of every sequence that has
//    more than i rows, ordered as in idx_sorted_by_width.
//
// NOTE(review): with an empty batch, idx_sorted_by_width_data[0] below is read
// out of range — confirm callers guarantee at least one sequence.
template <typename T>
void SearchGrnnCompute<T>::PrepareLayout(const Tensor* input_blob) {
auto& param = this->Param<param_t>();
auto* _idx_sorted_by_width = param.idx_sorted_by_width;
auto* _layout_input = param.layout_input;
auto* _input = input_blob;
// usually total length
int dim0 = _input->dims()[0];
// if it is id only sequence
int dim1 = 1;
// if its a embedding like sequence (dim1 would be embedding_size)
if (_input->dims().size() > 1) {
dim1 = _input->dims()[1];
}
int batch = _input->lod()[0].size() - 1;
auto& offset = _input->lod()[0];
// Scratch tensor holding the length of each sequence.
Tensor _width;
_width.Resize({batch});
_idx_sorted_by_width->Resize({batch});
int* width_data = _width.template mutable_data<int>();
int* idx_sorted_by_width_data =
_idx_sorted_by_width->template mutable_data<int>();
// sort sequence by width (descending) and find the largest width in the
// batch
for (int i = 0; i < batch; i++) {
width_data[i] = offset[i + 1] - offset[i];
idx_sorted_by_width_data[i] = i;
}
std::sort(idx_sorted_by_width_data,
idx_sorted_by_width_data + batch,
[&_width](int a, int b) {
return _width.template data<int>()[a] >
_width.template data<int>()[b];
});
// After the descending sort the first index names the longest sequence.
int max_width = width_data[idx_sorted_by_width_data[0]];
// start of reorganizing the input
std::vector<size_t> new_offset;
new_offset.resize(max_width + 1);
new_offset[0] = 0;
// j: position in the sorted order from which the next scan starts; the scan
// walks from the shortest remaining sequence towards the longest.
int j = batch - 1;
// last_width: number of steps whose offsets have already been filled in.
int last_width = 0;
int sub_row = 0;
int sub_col = 0;
// i is advanced only inside the innermost loop, once per filled step.
for (int i = 1; i <= max_width;) {
for (int k = j; k >= 0; --k) {
if (width_data[idx_sorted_by_width_data[k]] > last_width) {
// Steps last_width+1 .. width(k) each contain the k + 1 longest sequences.
sub_row = width_data[idx_sorted_by_width_data[k]] - last_width;
sub_col = k + 1;
for (int s = 0; s < sub_row; s++) {
new_offset[i] = new_offset[i - 1] + sub_col;
i++;
}
// move on
last_width = width_data[idx_sorted_by_width_data[k]];
j = k - 1;
break;
}
}
}
// copying to the reorganized buffer
if (_input->dims().size() == 1) {
// _layout_input.reshape_batch_sequence({dim0}, new_offset);
LOG(FATAL) << "_input->dims().size() = 1, error.";
} else {
// _layout_input.reshape_batch_sequence({dim0, dim1}, new_offset);
LoD new_lod;
new_lod.push_back(new_offset);
_layout_input->set_lod(new_lod);
_layout_input->Resize({dim0, dim1});
}
auto* new_emb = _layout_input->template mutable_data<T>();
// For step i, copy row i of each of the w longest sequences (dim1 values
// per row) into consecutive rows starting at new_offset[i].
for (int i = 0; i < max_width; i++) {
int w = new_offset[i + 1] - new_offset[i];
auto* emb_start = new_emb + dim1 * new_offset[i];
for (int j = 0; j < w; ++j) {
memcpy(emb_start + dim1 * j,
_input->template data<T>() +
dim1 * offset[idx_sorted_by_width_data[j]] + dim1 * i,
dim1 * sizeof(T));
}
}
}
// Scatters step-major rows in `from` back to sequence-major rows in `to`.
//
// Row j of step t in `from` (at row index layout_lod[t] + j) belongs to the
// j-th sequence in the sorted order, so it is written to row
// input_lod[sorted_idx[j]] + t of `to`. Each row holds `step` elements of T.
template <typename T>
void SearchGrnnCompute<T>::CopyBack(T* from, T* to, int step) {
  auto& param = this->Param<param_t>();
  const auto& seq_offset = param.x->lod()[0];
  const auto& step_offset = param.layout_input->lod()[0];
  const auto* sorted_idx = param.idx_sorted_by_width->template data<int>();

  for (size_t t = 0; t < param.layout_input->lod()[0].size() - 1; ++t) {
    // Number of sequences that are still active at step t.
    const int rows = step_offset[t + 1] - step_offset[t];
    for (int j = 0; j < rows; j++) {
      T* dst_row = to + step * (seq_offset[sorted_idx[j]] + t);
      const T* src_row = from + (step_offset[t] + j) * step;
      memcpy(dst_row, src_row, step * sizeof(T));
    }
  }
}
// Runs a gated recurrent unit over a batch of variable-length sequences.
//
// 1. Sets the output LoD/shape ([total_len, num_hidden]) from the input.
// 2. PrepareLayout() regroups the input rows by time step.
// 3. Three GEMMs project every input row through the candidate, reset and
//    update blocks of `wi` (stored back to back in that order).
// 4. Step 0 is computed directly from the input projections; every later step
//    additionally projects the previous step's hidden rows through the three
//    blocks of `wh` and applies the gate equations below.
// 5. CopyBack() scatters the hidden rows back to the input's row order.
//
// Scratch layout: tmp_buffer is resized to [20, total_len, num_hidden]; slabs
// 0..8 and 19 are used here.
template <typename T>
void SearchGrnnCompute<T>::Run() {
  auto& context = ctx_->As<X86Context>();
  auto& param = this->Param<param_t>();
  auto* input = param.x;
  auto* input_weight = param.wi;
  auto* hidden_weight = param.wh;
  auto* output = param.out;
  auto* scratch = param.tmp_buffer;

  int hidden_dim = param.num_hidden;
  int input_dim = param.num_input;
  int total_len = input->dims()[0];
  int batch = input->lod()[0].size() - 1;

  // The output keeps the input's sequence boundaries.
  const auto& seq_offset = input->lod()[0];
  LoD out_lod;
  out_lod.push_back(seq_offset);
  output->set_lod(out_lod);
  std::vector<int64_t> out_shape{total_len, hidden_dim};
  output->Resize(out_shape);
  auto* out_hidden = output->template mutable_data<T>();

  // Weight blocks in the order: candidate, reset gate, update gate.
  const auto* wi_data = input_weight->template data<T>();
  const auto* wh_data = hidden_weight->template data<T>();
  const T* in_weights[3] = {wi_data,
                            wi_data + 1 * input_dim * hidden_dim,
                            wi_data + 2 * input_dim * hidden_dim};
  const T* rec_weights[3] = {wh_data,
                             wh_data + 1 * hidden_dim * hidden_dim,
                             wh_data + 2 * hidden_dim * hidden_dim};

  PrepareLayout(input);
  auto* layout = param.layout_input;
  auto* layout_data = layout->template mutable_data<T>();
  const auto& step_offset = layout->lod()[0];
  int num_steps = layout->lod()[0].size() - 1;

  // this buffer is used for book keeping info which will be used in bp
  // buffer also needed in bp, so make it larger
  scratch->Resize({20, total_len, hidden_dim});
  auto* slab_base = scratch->template mutable_data<T>();
  auto* w_x_e = slab_base + 0 * total_len * hidden_dim;
  auto* wr_x_e = slab_base + 1 * total_len * hidden_dim;
  auto* wz_x_e = slab_base + 2 * total_len * hidden_dim;
  auto* u_x_h = slab_base + 3 * total_len * hidden_dim;
  auto* ur_x_h = slab_base + 4 * total_len * hidden_dim;
  auto* uz_x_h = slab_base + 5 * total_len * hidden_dim;
  auto* r = slab_base + 6 * total_len * hidden_dim;
  auto* z = slab_base + 7 * total_len * hidden_dim;
  auto* tilde = slab_base + 8 * total_len * hidden_dim;
  // the internal hidden
  auto* hidden = slab_base + 19 * total_len * hidden_dim;

  auto blas = lite::x86::math::GetBlas<TARGET(kX86), T>(context);

  // Input projections for all rows at once, one GEMM per weight block.
  T* in_proj[3] = {w_x_e, wr_x_e, wz_x_e};
  for (int g = 0; g < 3; ++g) {
    CallGemm(blas,
             CblasNoTrans,
             CblasTrans,
             total_len,
             hidden_dim,
             input_dim,
             1.0f,
             layout_data,
             in_weights[g],
             0.0f,
             in_proj[g]);
  }

  // precompute hidden0: no previous hidden state, so only the candidate and
  // the update gate contribute.
  const int first_step_elems = batch * hidden_dim;
  for (int i = 0; i < first_step_elems; i++) {
    tilde[i] = std::tanh(w_x_e[i]);
    z[i] = sigmoid<T>(wz_x_e[i]);
    hidden[i] = (1. - z[i]) * tilde[i];
  }

  // recurrence
  T* rec_proj[3] = {u_x_h, ur_x_h, uz_x_h};
  for (int t = 1; t < num_steps; t++) {
    int prev_rows = step_offset[t] - step_offset[t - 1];
    int rows = step_offset[t + 1] - step_offset[t];

    // Project the previous step's hidden rows (only the first `rows` of them
    // are still active) through each recurrent weight block.
    auto* prev_hidden = hidden + step_offset[t - 1] * hidden_dim;
    for (int g = 0; g < 3; ++g) {
      CallGemm(blas,
               CblasNoTrans,
               CblasTrans,
               rows,
               hidden_dim,
               hidden_dim,
               1.0f,
               prev_hidden,
               rec_weights[g],
               0.0f,
               rec_proj[g] + step_offset[t] * hidden_dim);
    }

    // compute the gate and hidden
    for (size_t j = step_offset[t] * hidden_dim;
         j < (step_offset[t] + rows) * hidden_dim;
         j++) {
      r[j] = sigmoid(wr_x_e[j] + ur_x_h[j]);
      z[j] = sigmoid(wz_x_e[j] + uz_x_h[j]);
      tilde[j] = std::tanh(w_x_e[j] + r[j] * u_x_h[j]);
      // The matching element of the previous step lies prev_rows rows back.
      hidden[j] =
          z[j] * hidden[j - hidden_dim * prev_rows] + (1.0 - z[j]) * tilde[j];
    }
  }

  CopyBack(hidden, out_hidden, hidden_dim);
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers SearchGrnnCompute<float> as the `search_grnn` kernel for the x86
// target (float precision, NCHW layout, alias `def`). Inputs "X", "Wi", "Wh"
// and outputs "Out", "tmp_buffer", "layout_input" are bound as x86 tensors;
// "idx_sorted_by_width" is bound as an x86 int32 tensor.
REGISTER_LITE_KERNEL(search_grnn,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SearchGrnnCompute<float>,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Wi", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Wh", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("tmp_buffer", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("idx_sorted_by_width",
{LiteType::GetTensorTy(TARGET(kX86), PRECISION(kInt32))})
.BindOutput("layout_input", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// x86 kernel object for the search_grnn operator (float precision).
// T is the element type used when reading and writing tensor data.
template <typename T>
class SearchGrnnCompute : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  // Parameter bundle consumed by the kernel.
  using param_t = operators::SearchGrnnParam;

  virtual ~SearchGrnnCompute() = default;

  // Executes the kernel; defined in the accompanying .cc file.
  void Run() override;

 private:
  // Copies rows of `step` elements from `from` to `to`.
  void CopyBack(T* from, T* to, int step);
  // Prepares the reorganized copy of `input` used by Run().
  void PrepareLayout(const Tensor* input);
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_seq_depadding_compute.h"
#include <vector>
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// Copies the leading rows of every padded sequence in `pad` into `out`.
//
// For sequence i, the first len_i rows of pad's i-th LoD segment are copied to
// out's i-th segment, where len_i comes from src's level-0 LoD. `out` takes
// src's level-0 LoD and is resized to [rows(src), cols(pad)].
//
// Aborts via LOG(FATAL) when:
//  - src holds no sequence, pad holds fewer sequences than src, or pad's
//    sequence count is not a multiple of src's;
//  - a padded sequence is shorter than its source sequence.
template <typename T>
void SearchSeqDepaddingCompute<T>::Run() {
  auto& param = this->Param<param_t>();
  auto* pad = param.pad;
  auto* src = param.src;
  auto* out = param.out;

  const int pad_batch = pad->lod()[0].size() - 1;
  const int src_batch = src->lod()[0].size() - 1;
  // src_batch <= 0 would divide by zero in the modulo below, and
  // pad_batch < src_batch would index pad's LoD out of range in the loop.
  if (src_batch <= 0 || pad_batch < src_batch ||
      pad_batch % src_batch != 0) {
    LOG(FATAL) << "Mismatch batch size.";
  }

  const auto& pad_offset = pad->lod()[0];
  const int pad_cap_e = pad->dims()[1];
  const auto& src_offset = src->lod()[0];
  const int src_cap_l = src->dims()[0];

  LoD out_lod;
  out_lod.push_back(src_offset);
  out->set_lod(out_lod);
  out->Resize({src_cap_l, pad_cap_e});

  const auto* pad_data = pad->template data<T>();
  auto* out_data = out->template mutable_data<T>();
  for (int i = 0; i < src_batch; ++i) {
    const int src_i_l = src_offset[i + 1] - src_offset[i];
    const int pad_i_l = pad_offset[i + 1] - pad_offset[i];
    if (pad_i_l < src_i_l) {
      LOG(FATAL)
          << "the length of padding seq input is less than source seq input.";
    }
    // Rows are contiguous, so one memcpy moves the whole valid prefix.
    memcpy(out_data + src_offset[i] * pad_cap_e,
           pad_data + pad_offset[i] * pad_cap_e,
           src_i_l * pad_cap_e * sizeof(T));
  }
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Registers SearchSeqDepaddingCompute<float> as the `search_seq_depadding`
// kernel for the x86 target (float precision, NCHW layout, alias `def`).
// Inputs "Pad", "Src" and output "Out" are bound as x86 tensors.
REGISTER_LITE_KERNEL(
search_seq_depadding,
kX86,
kFloat,
kNCHW,
paddle::lite::kernels::x86::SearchSeqDepaddingCompute<float>,
def)
.BindInput("Pad", {LiteType::GetTensorTy(TARGET(kX86))})
.BindInput("Src", {LiteType::GetTensorTy(TARGET(kX86))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kX86))})
.Finalize();
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
#include "lite/operators/op_params.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// x86 kernel object for the search_seq_depadding operator (float precision).
// T is the element type used by Run() when reading and writing tensor data.
template <typename T>
class SearchSeqDepaddingCompute
    : public KernelLite<TARGET(kX86), PRECISION(kFloat)> {
 public:
  // Parameter bundle consumed by Run().
  using param_t = operators::SearchSeqDepaddingParam;

  virtual ~SearchSeqDepaddingCompute() = default;

  // Executes the kernel; defined in the accompanying .cc file.
  void Run() override;
};
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/x86/search_seq_depadding_compute.h"
#include <gtest/gtest.h>
#include <memory>
#include <utility>
#include <vector>
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace x86 {
// The kernel registry must return at least one non-null x86/float kernel for
// the op name "search_seq_depadding".
TEST(search_seq_depadding_x86, retrive_op) {
  auto found_kernels =
      KernelRegistry::Global().Create<TARGET(kX86), PRECISION(kFloat)>(
          "search_seq_depadding");
  ASSERT_FALSE(found_kernels.empty());
  ASSERT_TRUE(found_kernels.front());
}
// Checks that a default-constructed kernel reports float precision and the
// x86 target.
TEST(search_seq_depadding_x86, init) {
SearchSeqDepaddingCompute<float> ssdc;
ASSERT_EQ(ssdc.precision(), PRECISION(kFloat));
ASSERT_EQ(ssdc.target(), TARGET(kX86));
}
// Runs the x86 kernel on a small hand-built case and compares the output with
// a fixed reference.
//   pad: 6 x 4 tensor filled with 0..23, level-0 LoD {0, 4, 6}
//   src: 3 x 1 tensor,                   level-0 LoD {0, 2, 3}
//   out: expected to hold 3 x 4 = 12 values
TEST(search_seq_depadding_x86, run_test) {
  lite::Tensor pad, src, out;
  pad.Resize({2 * 3, 4});
  src.Resize({3, 1});
  out.Resize({3, 4});

  LoD pad_lod{};
  pad_lod.push_back({0, 4, 6});
  pad.set_lod(pad_lod);

  LoD src_lod{};
  src_lod.push_back({0, 2, 3});
  src.set_lod(src_lod);

  // Fill pad with its own flat index so every output value identifies the
  // pad element it was copied from.
  auto* pad_data = pad.mutable_data<float>();
  for (int64_t i = 0; i < pad.dims().production(); i++) {
    pad_data[i] = static_cast<float>(i);
  }

  SearchSeqDepaddingCompute<float> ssdc;
  operators::SearchSeqDepaddingParam param;

  std::unique_ptr<KernelContext> ctx(new KernelContext);
  ctx->As<X86Context>();
  ssdc.SetContext(std::move(ctx));

  param.pad = &pad;
  param.src = &src;
  param.out = &out;
  ssdc.SetParam(param);
  ssdc.Run();

  const std::vector<float> ref_results = {
      0, 1, 2, 3, 4, 5, 6, 7, 16, 17, 18, 19};

  // The kernel is handed a mutable `out`, so it may have resized it. Make
  // sure the element count matches the reference before indexing either
  // buffer; otherwise the loop below could read past the end of ref_results.
  ASSERT_EQ(out.dims().production(),
            static_cast<int64_t>(ref_results.size()));

  auto* out_data = out.mutable_data<float>();
  for (size_t i = 0; i < ref_results.size(); i++) {
    EXPECT_NEAR(out_data[i], ref_results[i], 1e-3);
  }
}
} // namespace x86
} // namespace kernels
} // namespace lite
} // namespace paddle
// References the search_seq_depadding kernel (kX86, kFloat, kNCHW, alias
// "def"). NOTE(review): presumably this forces the kernel's registration to be
// linked into the test binary -- confirm against the macro definition.
USE_LITE_KERNEL(search_seq_depadding, kX86, kFloat, kNCHW, def);
......@@ -80,6 +80,9 @@ add_operator(fake_channel_wise_dequantize_max_abs_op extra SRCS fake_channel_wis
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS})
add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS})
add_operator(search_grnn_op_lite extra SRCS search_grnn_op.cc DEPS ${op_DEPS})
add_operator(sequence_concat_op_lite extra SRCS sequence_concat_op.cc DEPS ${op_DEPS})
add_operator(var_conv_2d_op_lite extra SRCS var_conv_2d_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/match_matrix_tensor_op.h"
#include <cstdint>
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
// Validates the attached tensors before shape inference.
// Returns false as soon as one CHECK_OR_FALSE condition does not hold:
//   - x, y, w, out and tmp must all be attached,
//   - x and y must be rank 2, w must be rank 3,
//   - w must have extents [x_dims[1], dim_t, y_dims[1]].
bool MatchMatrixTensorOpLite::CheckShape() const {
  // Presence of every input / output slot.
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.y);
  CHECK_OR_FALSE(param_.w);
  CHECK_OR_FALSE(param_.out);
  CHECK_OR_FALSE(param_.tmp);

  // Rank checks. The condition text is kept as-is in case the macro reports it.
  const auto& x_dims = param_.x->dims();
  CHECK_OR_FALSE(x_dims.size() == 2);
  const auto& y_dims = param_.y->dims();
  CHECK_OR_FALSE(y_dims.size() == 2);
  const auto& w_dims = param_.w->dims();
  CHECK_OR_FALSE(w_dims.size() == 3);

  // w has to line up with the trailing extents of x and y, and with dim_t.
  const int dim_t = param_.dim_t;
  CHECK_OR_FALSE(x_dims[1] == w_dims[0] && y_dims[1] == w_dims[2] &&
                 w_dims[1] == dim_t);
  return true;
}
// Computes the shapes of the Out and Tmp tensors from the level-0 LoD of X
// and Y.
//   Out: { dim_t * sum_i(x_len_i * y_len_i), 1 }, where x_len_i / y_len_i are
//        the lengths of the i-th sequence of X / Y,
//   Tmp: { x_dims[0] * dim_t * x_dims[1], 1 }.
// Returns false if either LoD is missing, has fewer than two offsets, does
// not end at the tensor's first extent, or if X and Y have a different number
// of offsets.
bool MatchMatrixTensorOpLite::InferShape() const {
  const Tensor* x = param_.x;
  const Tensor* y = param_.y;
  DDim x_dims = param_.x->dims();
  DDim y_dims = param_.y->dims();
  // Widened to 64 bits so the extent arithmetic below is not truncated.
  const int64_t dim_t = param_.dim_t;

  const auto& x_lod = x->lod();
  CHECK_OR_FALSE(!x_lod.empty());
  const auto& x_lod_0 = x_lod[0];
  CHECK_OR_FALSE(x_lod_0.size() >= 2);
  CHECK_OR_FALSE(x_dims[0] == x_lod_0.back());

  const auto& y_lod = y->lod();
  CHECK_OR_FALSE(!y_lod.empty());
  const auto& y_lod_0 = y_lod[0];
  CHECK_OR_FALSE(y_lod_0.size() >= 2);
  CHECK_OR_FALSE(y_dims[0] == y_lod_0.back());

  // Sequences of X and Y are paired index by index.
  CHECK_OR_FALSE(x_lod_0.size() == y_lod_0.size());

  // Accumulate in 64 bits: the products of sequence lengths (and the Tmp
  // extent below) can exceed the range of int for large batches.
  int64_t out_dim_0 = 0;
  for (size_t i = 1; i < x_lod_0.size(); i++) {
    const int64_t x_len = static_cast<int64_t>(x_lod_0[i]) -
                          static_cast<int64_t>(x_lod_0[i - 1]);
    const int64_t y_len = static_cast<int64_t>(y_lod_0[i]) -
                          static_cast<int64_t>(y_lod_0[i - 1]);
    out_dim_0 += x_len * y_len;
  }
  out_dim_0 *= dim_t;

  const int64_t tmp_dim_0 = x_dims[0] * dim_t * x_dims[1];

  param_.out->Resize({out_dim_0, 1});
  param_.tmp->Resize({tmp_dim_0, 1});
  return true;
}
// Binds the op description to param_.
//   Inputs : "X", "W", "Y"  (first variable name of each slot)
//   Outputs: "Out", "Tmp"   (first variable name of each slot)
//   Attr   : "dim_t"
// Every variable is looked up in `scope` and stored as a lite::Tensor
// pointer. Always returns true.
bool MatchMatrixTensorOpLite::AttachImpl(const cpp::OpDesc& op_desc,
                                         lite::Scope* scope) {
  auto x = op_desc.Input("X").front();
  auto w = op_desc.Input("W").front();
  auto y = op_desc.Input("Y").front();
  auto out = op_desc.Output("Out").front();
  auto tmp = op_desc.Output("Tmp").front();

  param_.x = scope->FindVar(x)->GetMutable<lite::Tensor>();
  param_.w = scope->FindVar(w)->GetMutable<lite::Tensor>();
  param_.y = scope->FindVar(y)->GetMutable<lite::Tensor>();
  param_.out = scope->FindVar(out)->GetMutable<lite::Tensor>();
  param_.tmp = scope->FindVar(tmp)->GetMutable<lite::Tensor>();

  // CheckShape() and InferShape() both read param_.dim_t; without this line
  // it was never assigned and they operated on an indeterminate value.
  param_.dim_t = op_desc.GetAttr<int>("dim_t");
  return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
// Registers MatchMatrixTensorOpLite under the op type "match_matrix_tensor".
REGISTER_LITE_OP(match_matrix_tensor,
paddle::lite::operators::MatchMatrixTensorOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
// Operator class for "match_matrix_tensor". It owns a MatchMatrixTensorParam
// and passes it to the kernel in AttachKernel().
class MatchMatrixTensorOpLite : public OpLite {
 public:
  MatchMatrixTensorOpLite() {}
  explicit MatchMatrixTensorOpLite(const std::string &op_type)
      : OpLite(op_type) {}

  // Fills param_ from the op description and the variables found in `scope`.
  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

  // Validates the attached tensors; false means a check failed.
  bool CheckShape() const override;

  // Derives the output tensor shapes from the inputs.
  bool InferShape() const override;

  // Passes the parameter bundle to the kernel that will execute this op.
  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

  std::string DebugString() const override { return "match_matrix_tensor"; }

 private:
  // Declared mutable, so const member functions are allowed to modify it.
  mutable MatchMatrixTensorParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
......@@ -951,6 +951,38 @@ struct AssignValueParam {
lite::Tensor* Out{};
};
/// --------------------- match_matrix_tensor operators --------------------
// Parameter bundle of the match_matrix_tensor op.
// Every member starts from a defined value (null / zero), so a param that has
// not been fully attached is never read as an indeterminate value.
struct MatchMatrixTensorParam {
  // Inputs.
  const lite::Tensor* x{};
  const lite::Tensor* y{};
  const lite::Tensor* w{};
  // Outputs.
  lite::Tensor* out{};
  lite::Tensor* tmp{};
  // Must equal the middle extent of w (the op's CheckShape compares the two).
  // Previously the only member without an initializer.
  int dim_t{};
};
/// --------------------- search_seq_depadding operators --------------------
// Parameter bundle of the search_seq_depadding op: two read-only input
// tensors and one writable output tensor, all null until attached.
struct SearchSeqDepaddingParam {
  const lite::Tensor* pad{nullptr};  // input
  const lite::Tensor* src{nullptr};  // input
  lite::Tensor* out{nullptr};        // output
};
/// --------------------- search_grnn operators --------------------
// Parameter bundle of the search_grnn op.
// Every member starts from a defined value (null / zero), so a param that has
// not been fully attached is never read as an indeterminate value.
struct SearchGrnnParam {
  // Inputs.
  const lite::Tensor* x{};
  const lite::Tensor* wi{};
  const lite::Tensor* wh{};
  // Attributes. Previously left without an initializer, unlike the pointers.
  int num_input{};
  int num_hidden{};

  // Outputs.
  lite::Tensor* out{};
  lite::Tensor* tmp_buffer{};
  lite::Tensor* idx_sorted_by_width{};
  lite::Tensor* layout_input{};
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/search_grnn_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
// Validates the attached tensors against the num_input / num_hidden
// attributes. Returns false as soon as one CHECK_OR_FALSE condition fails:
//   - all input and output tensors must be attached,
//   - x  must be rank 2 with x_dims[1] == num_input,
//   - wi must have extents [3, num_hidden, num_input],
//   - wh must have extents [3, num_hidden, num_hidden].
bool SearchGrnnOpLite::CheckShape() const {
  // Presence of every input / output slot.
  CHECK_OR_FALSE(param_.x);
  CHECK_OR_FALSE(param_.wi);
  CHECK_OR_FALSE(param_.wh);
  CHECK_OR_FALSE(param_.out);
  CHECK_OR_FALSE(param_.tmp_buffer);
  CHECK_OR_FALSE(param_.idx_sorted_by_width);
  CHECK_OR_FALSE(param_.layout_input);

  const int _cap_h = param_.num_hidden;
  const int _cap_e = param_.num_input;

  // Input X.
  const auto& x_dims = param_.x->dims();
  CHECK_OR_FALSE(x_dims.size() == 2);
  CHECK_OR_FALSE(x_dims[1] == _cap_e);

  // Weight Wi.
  const auto& wi_dims = param_.wi->dims();
  CHECK_OR_FALSE(wi_dims.size() == 3);
  CHECK_OR_FALSE(wi_dims[0] == 3);
  CHECK_OR_FALSE(wi_dims[1] == _cap_h);
  CHECK_OR_FALSE(wi_dims[2] == _cap_e);

  // Weight Wh.
  const auto& wh_dims = param_.wh->dims();
  CHECK_OR_FALSE(wh_dims.size() == 3);
  CHECK_OR_FALSE(wh_dims[0] == 3);
  CHECK_OR_FALSE(wh_dims[1] == _cap_h);
  CHECK_OR_FALSE(wh_dims[2] == _cap_h);

  return true;
}
// Copies the LoD of X onto Out after checking that it is consistent with X.
// Returns false if X has no LoD, if its level-0 LoD is empty, or if the last
// level-0 offset differs from x_dims[0].
bool SearchGrnnOpLite::InferShape() const {
  const auto& x_dims = param_.x->dims();
  const auto& x_lod = param_.x->lod();
  CHECK_OR_FALSE(!x_lod.empty());
  // back() on an empty level is undefined behaviour, so reject that case
  // before reading the last offset.
  CHECK_OR_FALSE(!x_lod[0].empty());
  CHECK_OR_FALSE(x_dims[0] == x_lod[0].back());
  param_.out->set_lod(x_lod);
  return true;
}
// Binds the op description to param_.
//   Inputs : "X", "Wi", "Wh"
//   Attrs  : "num_input", "num_hidden"
//   Outputs: "Out", "tmp_buffer", "idx_sorted_by_width", "layout_input"
// The first variable name of each slot is looked up in `scope` and stored as
// a lite::Tensor pointer. Always returns true.
bool SearchGrnnOpLite::AttachImpl(const cpp::OpDesc& op_desc,
                                  lite::Scope* scope) {
  // Resolves a variable name to its tensor in the scope.
  auto tensor_of = [scope](const std::string& var_name) {
    return scope->FindVar(var_name)->GetMutable<lite::Tensor>();
  };

  // Inputs: collect the names first, then resolve them.
  const std::string x_name = op_desc.Input("X").front();
  const std::string wi_name = op_desc.Input("Wi").front();
  const std::string wh_name = op_desc.Input("Wh").front();
  param_.x = tensor_of(x_name);
  param_.wi = tensor_of(wi_name);
  param_.wh = tensor_of(wh_name);

  // Attributes.
  param_.num_input = op_desc.GetAttr<int>("num_input");
  param_.num_hidden = op_desc.GetAttr<int>("num_hidden");

  // Outputs: same two-step pattern.
  const std::string out_name = op_desc.Output("Out").front();
  const std::string tmp_buffer_name = op_desc.Output("tmp_buffer").front();
  const std::string idx_name = op_desc.Output("idx_sorted_by_width").front();
  const std::string layout_name = op_desc.Output("layout_input").front();
  param_.out = tensor_of(out_name);
  param_.tmp_buffer = tensor_of(tmp_buffer_name);
  param_.idx_sorted_by_width = tensor_of(idx_name);
  param_.layout_input = tensor_of(layout_name);

  return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
// Registers SearchGrnnOpLite under the op type "search_grnn".
REGISTER_LITE_OP(search_grnn, paddle::lite::operators::SearchGrnnOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
// Operator class for "search_grnn". It owns a SearchGrnnParam and passes it
// to the kernel in AttachKernel().
class SearchGrnnOpLite : public OpLite {
 public:
  SearchGrnnOpLite() {}
  explicit SearchGrnnOpLite(const std::string &op_type) : OpLite(op_type) {}

  // Fills param_ from the op description and the variables found in `scope`.
  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

  // Validates the attached tensors; false means a check failed.
  bool CheckShape() const override;

  // Shape / LoD inference step; false means a check failed.
  bool InferShape() const override;

  // Passes the parameter bundle to the kernel that will execute this op.
  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

  std::string DebugString() const override { return "search_grnn"; }

 private:
  // Declared mutable, so const member functions are allowed to modify it.
  mutable SearchGrnnParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/search_seq_depadding_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
// Validates the attached tensors. Returns false as soon as one
// CHECK_OR_FALSE condition fails:
//   - pad, src and out must be attached,
//   - pad and src must be rank 2,
//   - both must carry a level-0 LoD with at least two offsets whose last
//     offset equals the tensor's first extent.
bool SearchSeqDepaddingOpLite::CheckShape() const {
  // Presence of every input / output slot.
  CHECK_OR_FALSE(param_.pad);
  CHECK_OR_FALSE(param_.src);
  CHECK_OR_FALSE(param_.out);

  // Rank checks.
  const auto& pad_dims = param_.pad->dims();
  const auto& src_dims = param_.src->dims();
  CHECK_OR_FALSE(pad_dims.size() == 2);
  CHECK_OR_FALSE(src_dims.size() == 2);

  // LoD of pad.
  const auto& pad_lod = param_.pad->lod();
  CHECK_OR_FALSE(!pad_lod.empty());
  const auto& pad_lod_0 = pad_lod[0];
  CHECK_OR_FALSE(pad_lod_0.size() >= 2);
  CHECK_OR_FALSE(pad_dims[0] == pad_lod_0.back());

  // LoD of src.
  const auto& src_lod = param_.src->lod();
  CHECK_OR_FALSE(!src_lod.empty());
  const auto& src_lod_0 = src_lod[0];
  CHECK_OR_FALSE(src_lod_0.size() >= 2);
  CHECK_OR_FALSE(src_dims[0] == src_lod_0.back());

  return true;
}
// Resizes Out to { -1, pad_dims[1] }: the second extent is taken from pad and
// the first is set to the placeholder -1. Always returns true.
bool SearchSeqDepaddingOpLite::InferShape() const {
  const auto feature_width = param_.pad->dims()[1];
  param_.out->Resize({-1, feature_width});
  return true;
}
// Binds the op description to param_.
//   Inputs : "Pad", "Src"
//   Output : "Out"
// The first variable name of each slot is looked up in `scope` and stored as
// a lite::Tensor pointer. Always returns true.
bool SearchSeqDepaddingOpLite::AttachImpl(const cpp::OpDesc& op_desc,
                                          lite::Scope* scope) {
  // Resolves a variable name to its tensor in the scope.
  auto tensor_of = [scope](const std::string& var_name) {
    return scope->FindVar(var_name)->GetMutable<lite::Tensor>();
  };

  // Collect every name first, then resolve, as the lookups may fail.
  const std::string pad_name = op_desc.Input("Pad").front();
  const std::string src_name = op_desc.Input("Src").front();
  const std::string out_name = op_desc.Output("Out").front();

  param_.pad = tensor_of(pad_name);
  param_.src = tensor_of(src_name);
  param_.out = tensor_of(out_name);
  return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
// Registers SearchSeqDepaddingOpLite under the op type "search_seq_depadding".
REGISTER_LITE_OP(search_seq_depadding,
paddle::lite::operators::SearchSeqDepaddingOpLite);
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/operators/op_params.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
// Operator class for "search_seq_depadding". It owns a
// SearchSeqDepaddingParam and passes it to the kernel in AttachKernel().
class SearchSeqDepaddingOpLite : public OpLite {
 public:
  SearchSeqDepaddingOpLite() {}
  explicit SearchSeqDepaddingOpLite(const std::string &op_type)
      : OpLite(op_type) {}

  // Fills param_ from the op description and the variables found in `scope`.
  bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;

  // Validates the attached tensors; false means a check failed.
  bool CheckShape() const override;

  // Derives the output tensor shape from the inputs.
  bool InferShape() const override;

  // Passes the parameter bundle to the kernel that will execute this op.
  void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }

  std::string DebugString() const override { return "search_seq_depadding"; }

 private:
  // Declared mutable, so const member functions are allowed to modify it.
  mutable SearchSeqDepaddingParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册