Unverified commit 1b74fded authored by Wilber, committed by GitHub

add sequence_pool_concat fuse and kernel test=develop (#2645)

add sequence_pool_concat fuse pass

add fuse kernel
Parent 3723451b
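Before the diffs, for reference: a minimal host-side sketch of what the fused op computes, namely applying sequence_pool to each input and concatenating the pooled rows along axis 1. This is an illustrative reimplementation, not code from the commit; it covers only the LAST and MAX pool types that the CUDA kernel below implements, and all names are hypothetical.

#include <algorithm>
#include <string>
#include <vector>

// Reference semantics: out = concat({sequence_pool(x[i], pool_type[i])}, axis=1).
// x[i] is a row-major [rows_i, dim[i]] matrix; offs[i] holds its LoD offsets
// (offs[i].size() == batch + 1).
std::vector<float> sequence_pool_concat_ref(
    const std::vector<std::vector<float>>& x,
    const std::vector<std::vector<int>>& offs,
    const std::vector<std::string>& pool_type,
    const std::vector<int>& dim,
    int batch) {
  int out_dim = 0;
  for (int d : dim) out_dim += d;
  std::vector<float> out(batch * out_dim);
  for (int b = 0; b < batch; ++b) {
    int col = 0;
    for (size_t i = 0; i < x.size(); ++i) {
      const int begin = offs[i][b], end = offs[i][b + 1];
      for (int e = 0; e < dim[i]; ++e) {
        float v = x[i][(end - 1) * dim[i] + e];  // LAST: final row of the sequence
        if (pool_type[i] == "MAX") {
          v = x[i][begin * dim[i] + e];
          for (int r = begin + 1; r < end; ++r)
            v = std::max(v, x[i][r * dim[i] + e]);  // column-wise max over rows
        }
        out[b * out_dim + col + e] = v;
      }
      col += dim[i];
    }
  }
  return out;
}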
@@ -31,6 +31,7 @@ USE_MIR_PASS(lite_fc_fuse_pass);
USE_MIR_PASS(lite_shuffle_channel_fuse_pass);
USE_MIR_PASS(lite_transpose_softmax_transpose_fuse_pass);
USE_MIR_PASS(lite_interpolate_fuse_pass);
USE_MIR_PASS(lite_sequence_pool_concat_fuse_pass);
USE_MIR_PASS(identity_scale_eliminate_pass);
USE_MIR_PASS(lite_conv_elementwise_fuse_pass);
USE_MIR_PASS(lite_conv_activation_fuse_pass);
......
@@ -20,6 +20,7 @@ lite_cc_library(mir_passes
fusion/conv_bn_fuse_pass.cc
fusion/elementwise_add_activation_fuse_pass.cc
fusion/quant_dequant_fuse_pass.cc
fusion/sequence_pool_concat_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc
static_kernel_pick_pass.cc
......
@@ -28,6 +28,9 @@ lite_cc_library(fuse_transpose_softmax_transpose
lite_cc_library(fuse_interpolate
SRCS interpolate_fuser.cc
DEPS pattern_matcher_high_api)
lite_cc_library(fuse_sequence_pool_concat
SRCS sequence_pool_concat_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers
fuse_fc
@@ -40,6 +43,7 @@ set(mir_fusers
fuse_elementwise_add_activation
fuse_transpose_softmax_transpose
fuse_interpolate
fuse_sequence_pool_concat
CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_pool_concat_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/sequence_pool_concat_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void SequencePoolConcatFusePass::Apply(const std::unique_ptr<SSAGraph>& graph) {
fusion::SequencePoolConcatFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_sequence_pool_concat_fuse_pass,
paddle::lite::mir::SequencePoolConcatFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class SequencePoolConcatFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_pool_concat_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
// """
// merge {sequence_pool x 7, concat} => merge_sequence_pool_and_concat
// src1 src2 src7 src1 src2 src7
// | | | | |
// v v | | ... |
// sequence_pool sequence_pool ...(sequence_pool) | | |
// | | | => -------------------
// --------------------------------- |
// | |
// v v
// concat sequence_pool_concat
// """
void SequencePoolConcatFuser::BuildPattern() {
// create nodes.
auto* concat = OpNode("concat", "concat")->AsIntermediate();
#define STR1(R) #R
#define STR2(R) STR1(R)
#define POOL_CONCAT_PATTERN(num) \
auto* x_##num = VarNode(STR2(sequence_pool_x_##num)) \
->assert_is_op_input("sequence_pool", "X") \
->AsInput(); \
auto* sequence_pool_##num = \
OpNode(STR2(sequence_pool_##num), "sequence_pool")->AsIntermediate(); \
auto* sequence_pool_##num##_out = \
VarNode(STR2(sequence_pool_##num##_out)) \
->assert_is_op_output("sequence_pool", "Out") \
->assert_is_op_nth_input("concat", "X", num - 1) \
->AsIntermediate(); \
auto* sequence_pool_##num##_idx = \
VarNode(STR2(sequence_pool_##num##_idx)) \
->assert_is_op_output("sequence_pool", "MaxIndex") \
->AsIntermediate(); \
*sequence_pool_##num >> *sequence_pool_##num##_idx; \
*x_##num >> *sequence_pool_##num >> *sequence_pool_##num##_out >> *concat;
auto* concat_out =
VarNode("concat_out")->assert_is_op_output("concat", "Out");
*concat >> *concat_out;
POOL_CONCAT_PATTERN(1);
POOL_CONCAT_PATTERN(2);
POOL_CONCAT_PATTERN(3);
POOL_CONCAT_PATTERN(4);
POOL_CONCAT_PATTERN(5);
POOL_CONCAT_PATTERN(6);
POOL_CONCAT_PATTERN(7);
#undef POOL_CONCAT_PATTERN
#undef STR1
#undef STR2
}
void SequencePoolConcatFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto sequence_pool_concat_op =
LiteOpRegistry::Global().Create("sequence_pool_concat");
auto concat = matched.at("concat")->stmt()->op();
auto* scope = concat->scope();
auto& valid_places = concat->valid_places();
sequence_pool_concat_op->Attach(op_desc, scope);
auto* new_op_node =
graph->GraphCreateInstructNode(sequence_pool_concat_op, valid_places);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_1"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_2"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_3"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_4"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_5"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_6"), new_op_node);
IR_NODE_LINK_TO(matched.at("sequence_pool_x_7"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("concat_out"));
}
cpp::OpDesc SequencePoolConcatFuser::GenOpDesc(const key2nodes_t& matched) {
cpp::OpDesc op_desc = *matched.at("concat")->stmt()->op_info();
op_desc.SetType("sequence_pool_concat");
op_desc.SetInput("X",
{matched.at("sequence_pool_x_1")->arg()->name,
matched.at("sequence_pool_x_2")->arg()->name,
matched.at("sequence_pool_x_3")->arg()->name,
matched.at("sequence_pool_x_4")->arg()->name,
matched.at("sequence_pool_x_5")->arg()->name,
matched.at("sequence_pool_x_6")->arg()->name,
matched.at("sequence_pool_x_7")->arg()->name});
std::vector<std::string> pooltypes;
pooltypes.push_back(matched.at("sequence_pool_1")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_2")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_3")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_4")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_5")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_6")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
pooltypes.push_back(matched.at("sequence_pool_7")
->stmt()
->op_info()
->GetAttr<std::string>("pooltype"));
op_desc.SetAttr("pooltype", pooltypes);
op_desc.SetOutput("Out", {matched.at("concat_out")->arg()->name});
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
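A side note on the two-level stringizing macros used above: STR2 expands its argument before STR1 stringizes it, the standard preprocessor trick for turning macro-expanded names into string keys for pattern nodes. A minimal standalone sketch, not part of the commit:

#include <cstdio>

#define STR1(R) #R       // stringizes the argument exactly as written
#define STR2(R) STR1(R)  // expands the argument first, then stringizes it

#define NODE sequence_pool_7

int main() {
  std::printf("%s\n", STR1(NODE));  // prints "NODE": # suppresses expansion
  std::printf("%s\n", STR2(NODE));  // prints "sequence_pool_7"
  return 0;
}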
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class SequencePoolConcatFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
@@ -69,6 +69,7 @@ class Optimizer {
"lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", //
"elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", //
#if (defined LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) || (defined LITE_WITH_CUDA) || \
(defined LITE_WITH_ARM)
"lite_elementwise_add_activation_fuse_pass", //
......
@@ -11,6 +11,7 @@ add_kernel(leaky_relu_compute_cuda CUDA basic SRCS leaky_relu_compute.cu DEPS ${
add_kernel(relu_compute_cuda CUDA basic SRCS relu_compute.cu DEPS ${lite_kernel_deps})
add_kernel(yolo_box_compute_cuda CUDA basic SRCS yolo_box_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_compute_cuda CUDA extra SRCS sequence_pool_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pool_concat_compute_cuda CUDA extra SRCS sequence_pool_concat_compute.cu DEPS ${lite_kernel_deps})
add_kernel(transpose_compute_cuda CUDA basic SRCS transpose_compute.cu DEPS ${lite_kernel_deps} ${math_cuda} cuda_transpose)
add_kernel(nearest_interp_compute_cuda CUDA basic SRCS nearest_interp_compute.cu DEPS ${lite_kernel_deps})
add_kernel(conv2d_cuda CUDA basic SRCS conv_compute.cc DEPS ${lite_kernel_deps} ${math_cuda})
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <vector>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_pool_concat_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename Dtype>
__global__ void sequence_pool_concat(const uint64_t* input_locate_data,
const int* pool_type_list,
Dtype* output_data,
const int* offset,
int batch,
int in_num,
int in_dim) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int em_id = tid % in_dim;
int in_id = (tid / in_dim) % in_num;
int seq_id = tid / (in_dim * in_num);
if (seq_id >= batch) {
return;
}
int offset_id = in_id * (batch + 1) + seq_id;
if (pool_type_list[in_id] == 4) { // last
const Dtype* in_data =
reinterpret_cast<const Dtype*>(
reinterpret_cast<uintptr_t>(input_locate_data[in_id])) +
em_id;
output_data[tid] = in_data[(offset[offset_id + 1] - 1) * in_dim];
} else if (pool_type_list[in_id] == 6) { // max
const Dtype* in_data =
reinterpret_cast<const Dtype*>(
reinterpret_cast<uintptr_t>(input_locate_data[in_id])) +
em_id + offset[offset_id] * in_dim;
Dtype max = in_data[0];
for (int i = 1; i < offset[offset_id + 1] - offset[offset_id]; i++) {
Dtype cur_data = in_data[i * in_dim];
max = cur_data > max ? cur_data : max;
}
output_data[tid] = max;
} else {
return;
}
}
template <typename Dtype>
__global__ void sequence_pool_concat(const uint64_t* input_locate_data,
const int* pool_type_list,
Dtype* output_data,
const int* offset,
int batch,
int in_num,
const int* out_offset,
const int* out_id_seq_map_data,
int out_dim) {
int tid = threadIdx.x + blockDim.x * blockIdx.x;
int em_id = tid % out_dim;
int seq_id = tid / out_dim;
int in_id = out_id_seq_map_data[em_id];
em_id = em_id - out_offset[in_id];
int in_dim = out_offset[in_id + 1] - out_offset[in_id];
if (seq_id >= batch) {
return;
}
int offset_id = in_id * (batch + 1) + seq_id;
if (pool_type_list[in_id] == 4) { // last
const Dtype* in_data =
reinterpret_cast<const Dtype*>(
reinterpret_cast<uintptr_t>(input_locate_data[in_id])) +
em_id;
output_data[tid] = in_data[(offset[offset_id + 1] - 1) * in_dim];
} else if (pool_type_list[in_id] == 6) { // max
const Dtype* in_data =
reinterpret_cast<const Dtype*>(
reinterpret_cast<uintptr_t>(input_locate_data[in_id])) +
em_id + offset[offset_id] * in_dim;
Dtype max = in_data[0];
for (int i = 1; i < offset[offset_id + 1] - offset[offset_id]; i++) {
Dtype cur_data = in_data[i * in_dim];
max = cur_data > max ? cur_data : max;
}
output_data[tid] = max;
} else {
return;
}
}
void SequencePoolConcatCompute::PrepareForRun() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
int in_num = param.X.size();
std::vector<int64_t> shape({in_num, 1, 1, 1});
_in_offset_tensor.Resize(shape);
_in_ptr_tensor.Resize(shape);
_in_pool_type_tensor.Resize(shape);
int* in_pool_type_data =
_in_pool_type_tensor.mutable_data<int>(TARGET(kCUDA));
std::vector<int> pool_type_list;
for (auto type : param.pool_type) {
if (type == "AVERAGE") {
pool_type_list.push_back(1);
} else if (type == "SUM") {
pool_type_list.push_back(2);
} else if (type == "SQRT") {
pool_type_list.push_back(3);
} else if (type == "LAST") {
pool_type_list.push_back(4);
} else if (type == "FIRST") {
pool_type_list.push_back(5);
} else if (type == "MAX") {
pool_type_list.push_back(6);
} else {
LOG(ERROR) << "pool type " << type << " is not supported.";
}
}
_is_in_same_len = true;
int in_len = param.X[0]->dims().count(1, param.X[0]->dims().size());
std::vector<int> out_id_seq_map_list;
std::vector<int> out_offset_list;
int total_len = 0;
out_offset_list.push_back(total_len);
for (int i = 0; i < in_num; ++i) {
int cur_len = param.X[i]->dims().count(1, param.X[i]->dims().size());
_is_in_same_len = _is_in_same_len && in_len == cur_len;
for (int k = 0; k < cur_len; ++k) {
out_id_seq_map_list.push_back(i);
}
total_len += cur_len;
out_offset_list.push_back(total_len);
}
std::vector<int64_t> out_id_seq_map_shape({total_len, 1, 1, 1});
std::vector<int64_t> out_offset_shape({in_num + 1, 1, 1, 1});
_out_offset_tensor.Resize(out_offset_shape);
_out_id_seq_map_tensor.Resize(out_id_seq_map_shape);
int* out_offset_data = _out_offset_tensor.mutable_data<int>(TARGET(kCUDA));
int* out_id_seq_map_data =
_out_id_seq_map_tensor.mutable_data<int>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(in_pool_type_data,
&pool_type_list[0],
sizeof(int) * param.X.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(out_offset_data,
&out_offset_list[0],
sizeof(int) * out_offset_list.size(),
IoDirection::HtoD,
stream);
TargetWrapperCuda::MemcpyAsync(out_id_seq_map_data,
&out_id_seq_map_list[0],
sizeof(int) * out_id_seq_map_list.size(),
IoDirection::HtoD,
stream);
cudaStreamSynchronize(stream);
}
void SequencePoolConcatCompute::Run() {
auto& param = this->Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto& inputs = param.X;
auto offset = inputs[0]->lod()[0];
CHECK_GE(offset.size(), 1);
int batch = offset.size() - 1;
std::vector<int> all_offset;
for (int i = 0; i < inputs.size(); ++i) {
auto it = all_offset.end();
auto cur_offset = inputs[i]->lod()[0];
all_offset.insert(it, cur_offset.begin(), cur_offset.end());
}
int total_size = all_offset.size();
std::vector<int64_t> offset_shape({total_size, 1, 1, 1});
_in_offset_tensor.Resize(offset_shape);
int* offset_data = _in_offset_tensor.mutable_data<int>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(offset_data,
&all_offset[0],
sizeof(int) * all_offset.size(),
IoDirection::HtoD,
stream);
std::vector<uint64_t> in_locate_vec;
for (int i = 0; i < inputs.size(); ++i) {
in_locate_vec.push_back(
reinterpret_cast<uintptr_t>(inputs[i]->data<float>()));
}
uint64_t* in_locate_data =
_in_ptr_tensor.mutable_data<uint64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(in_locate_data,
&in_locate_vec[0],
sizeof(uint64_t) * inputs.size(),
IoDirection::HtoD,
stream);
const int* in_pool_type_data = _in_pool_type_tensor.data<int>();
const int* out_id_seq_map_data = _out_id_seq_map_tensor.data<int>();
const int* out_offset_data = _out_offset_tensor.data<int>();
int count = param.Out->numel();
int in_dim = inputs[0]->numel() / inputs[0]->dims()[0];
float* out_data = param.Out->mutable_data<float>(TARGET(kCUDA));
int in_num = inputs.size();
if (_is_in_same_len) {
sequence_pool_concat<
float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
in_locate_data,
in_pool_type_data,
out_data,
offset_data,
batch,
in_num,
in_dim);
} else {
int out_dim = param.Out->numel() / param.Out->dims()[0];
sequence_pool_concat<
float><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(
in_locate_data,
in_pool_type_data,
out_data,
offset_data,
batch,
in_num,
out_offset_data,
out_id_seq_map_data,
out_dim);
}
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
REGISTER_LITE_KERNEL(sequence_pool_concat,
kCUDA,
kFloat,
kNCHW,
paddle::lite::kernels::cuda::SequencePoolConcatCompute,
def)
.BindInput("X", {LiteType::GetTensorTy(TARGET(kCUDA))})
.BindOutput("Out", {LiteType::GetTensorTy(TARGET(kCUDA))})
.Finalize();
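To make the indexing in the first (same-length) kernel variant concrete: each thread owns one output element of a conceptual [batch, in_num, in_dim] cube, decomposed exactly as in the kernel's em_id / in_id / seq_id arithmetic. A host-side illustration with made-up sizes:

#include <cstdio>

int main() {
  const int batch = 2, in_num = 3, in_dim = 4;  // illustrative sizes only
  for (int tid = 0; tid < batch * in_num * in_dim; ++tid) {
    int em_id = tid % in_dim;              // feature index within one input
    int in_id = (tid / in_dim) % in_num;   // which of the fused inputs
    int seq_id = tid / (in_dim * in_num);  // which sequence in the batch
    std::printf("tid=%2d -> seq=%d in=%d em=%d\n", tid, seq_id, in_id, em_id);
  }
  return 0;
}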
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
class SequencePoolConcatCompute
: public KernelLite<TARGET(kCUDA), PRECISION(kFloat)> {
public:
using param_t = operators::SequencePoolConcatParam;
void Run() override;
void PrepareForRun() override;
virtual ~SequencePoolConcatCompute() = default;
private:
lite::Tensor _in_offset_tensor;
lite::Tensor _in_ptr_tensor;
lite::Tensor _in_pool_type_tensor;
lite::Tensor _out_offset_tensor;
lite::Tensor _out_id_seq_map_tensor;
bool _is_in_same_len;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
@@ -90,6 +90,8 @@ add_operator(merge_lod_tensor_op_lite extra SRCS merge_lod_tensor_op.cc DEPS ${o
add_operator(reduce_prod_op_lite extra SRCS reduce_prod_op.cc DEPS ${op_DEPS})
add_operator(sequence_reshape_op_lite extra SRCS sequence_reshape_op.cc DEPS ${op_DEPS})
add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool_concat extra SRCS sequence_pool_concat_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS})
add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS})
@@ -120,7 +122,6 @@ add_operator(greater_than extra SRCS compare_op.cc DEPS ${op_DEPS})
add_operator(greater_equal extra SRCS compare_op.cc DEPS ${op_DEPS})
add_operator(read_from_array_op extra SRCS read_from_array_op.cc DEPS ${op_DEPS})
add_operator(beam_search_op extra SRCS beam_search_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS})
add_operator(lod_reset_op extra SRCS lod_reset_op.cc DEPS ${op_DEPS})
add_operator(is_empty extra SRCS is_empty_op.cc DEPS ${op_DEPS})
add_operator(slice_op_lite basic SRCS slice_op.cc DEPS ${op_DEPS})
......
@@ -769,6 +769,12 @@ struct SequencePoolParam {
#endif
};
struct SequencePoolConcatParam {
std::vector<lite::Tensor*> X{};
lite::Tensor* Out{};
std::vector<std::string> pool_type{};
};
struct SearchGroupPaddingParam {
lite::Tensor* x{};
lite::Tensor* out_emb_padding{};
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_pool_concat_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequencePoolConcatOp::CheckShape() const {
CHECK_GE(param_.X.size(), 1)
<< "The number of input sequences must be at least one.";
CHECK_OR_FALSE(param_.Out);
return true;
}
bool SequencePoolConcatOp::InferShape() const {
int out_dim = 0;
for (int i = 0; i < param_.X.size(); ++i) {
out_dim += param_.X[i]->dims().count(1, param_.X[i]->dims().size());
}
int seq_num = param_.X[0]->lod()[0].size() - 1;
std::vector<std::vector<uint64_t>> lod(1);
for (int i = 0; i < seq_num + 1; ++i) {
lod[0].push_back(i);
}
param_.Out->set_lod(lod);
param_.Out->Resize({seq_num, out_dim});
return true;
}
bool SequencePoolConcatOp::AttachImpl(const cpp::OpDesc &opdesc,
lite::Scope *scope) {
auto input_list = opdesc.Input("X");
param_.X.clear();
for (auto var : input_list) {
param_.X.push_back(scope->FindVar(var)->GetMutable<lite::Tensor>());
}
param_.Out =
scope->FindVar(opdesc.Output("Out").front())->GetMutable<lite::Tensor>();
CHECK(param_.Out) << "Output(Out) of SequencePoolConcatOp should not be null.";
param_.pool_type = opdesc.GetAttr<std::vector<std::string>>("pooltype");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_pool_concat,
paddle::lite::operators::SequencePoolConcatOp);
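A worked example of InferShape above, under assumed sizes: three inputs with per-row widths {4, 8, 16} and a batch of 3 sequences pool down to one row per sequence, so Out is [3, 28] and the output LoD degenerates to {0, 1, 2, 3}:

#include <cstdio>
#include <vector>

int main() {
  std::vector<int> widths = {4, 8, 16};  // hypothetical per-input row widths
  const int seq_num = 3;                 // number of sequences in the batch
  int out_dim = 0;
  for (int w : widths) out_dim += w;
  std::printf("Out dims: [%d, %d]\n", seq_num, out_dim);  // Out dims: [3, 28]
  std::printf("Out lod:");
  for (int i = 0; i <= seq_num; ++i) std::printf(" %d", i);  // 0 1 2 3
  std::printf("\n");
  return 0;
}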
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
#include "lite/utils/all.h"
namespace paddle {
namespace lite {
namespace operators {
class SequencePoolConcatOp : public OpLite {
public:
SequencePoolConcatOp() {}
explicit SequencePoolConcatOp(const std::string &op_type) : OpLite(op_type) {}
bool CheckShape() const override;
bool InferShape() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override { return "sequence_pool_concat"; }
private:
mutable SequencePoolConcatParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle