未验证 提交 5dcd5637 编写于 作者: W Wilber 提交者: GitHub

add reverse embedding. test=develop (#4106)

上级 bdff2240
...@@ -41,6 +41,7 @@ USE_MIR_PASS(lite_conv_activation_fuse_pass); ...@@ -41,6 +41,7 @@ USE_MIR_PASS(lite_conv_activation_fuse_pass);
USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass); USE_MIR_PASS(lite_var_conv_2d_activation_fuse_pass);
USE_MIR_PASS(lite_match_matrix_activation_fuse_pass); USE_MIR_PASS(lite_match_matrix_activation_fuse_pass);
USE_MIR_PASS(lite_scales_fuse_pass); USE_MIR_PASS(lite_scales_fuse_pass);
USE_MIR_PASS(lite_sequence_reverse_embedding_fuse_pass);
USE_MIR_PASS(lite_elementwise_activation_fuse_pass); USE_MIR_PASS(lite_elementwise_activation_fuse_pass);
USE_MIR_PASS(lite_quant_dequant_fuse_pass); USE_MIR_PASS(lite_quant_dequant_fuse_pass);
USE_MIR_PASS(type_precision_cast_pass); USE_MIR_PASS(type_precision_cast_pass);
......
...@@ -31,6 +31,7 @@ lite_cc_library(mir_passes ...@@ -31,6 +31,7 @@ lite_cc_library(mir_passes
fusion/__xpu__mmdnn_fuse_pass.cc fusion/__xpu__mmdnn_fuse_pass.cc
fusion/match_matrix_activation_fuse_pass.cc fusion/match_matrix_activation_fuse_pass.cc
fusion/scales_fuse_pass.cc fusion/scales_fuse_pass.cc
fusion/sequence_reverse_embedding_fuse_pass.cc
elimination/identity_scale_eliminate_pass.cc elimination/identity_scale_eliminate_pass.cc
elimination/identity_dropout_eliminate_pass.cc elimination/identity_dropout_eliminate_pass.cc
elimination/elementwise_mul_constant_eliminate_pass.cc elimination/elementwise_mul_constant_eliminate_pass.cc
......
...@@ -43,6 +43,9 @@ lite_cc_library(fuse_match_matrix_activation ...@@ -43,6 +43,9 @@ lite_cc_library(fuse_match_matrix_activation
lite_cc_library(fuse_scales lite_cc_library(fuse_scales
SRCS scales_fuser.cc SRCS scales_fuser.cc
DEPS pattern_matcher_high_api) DEPS pattern_matcher_high_api)
lite_cc_library(fuse_sequence_reverse_embedding
SRCS sequence_reverse_embedding_fuser.cc
DEPS pattern_matcher_high_api)
set(mir_fusers set(mir_fusers
fuse_fc fuse_fc
...@@ -60,6 +63,7 @@ set(mir_fusers ...@@ -60,6 +63,7 @@ set(mir_fusers
fuse_scale_activation fuse_scale_activation
fuse_match_matrix_activation fuse_match_matrix_activation
fuse_scales fuse_scales
fuse_sequence_reverse_embedding
CACHE INTERNAL "fusers") CACHE INTERNAL "fusers")
if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK) if (LITE_WITH_LIGHT_WEIGHT_FRAMEWORK)
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_reverse_embedding_fuse_pass.h"
#include <memory>
#include <vector>
#include "lite/core/mir/fusion/sequence_reverse_embedding_fuser.h"
#include "lite/core/mir/pass_registry.h"
namespace paddle {
namespace lite {
namespace mir {
void SequenceReverseEmbeddingFusePass::Apply(
const std::unique_ptr<SSAGraph>& graph) {
fusion::SequenceReverseEmbeddingFuser fuser;
fuser(graph.get());
}
} // namespace mir
} // namespace lite
} // namespace paddle
REGISTER_MIR_PASS(lite_sequence_reverse_embedding_fuse_pass,
paddle::lite::mir::SequenceReverseEmbeddingFusePass)
.BindTargets({TARGET(kCUDA)});
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pass.h"
namespace paddle {
namespace lite {
namespace mir {
class SequenceReverseEmbeddingFusePass : public ProgramPass {
public:
void Apply(const std::unique_ptr<SSAGraph>& graph) override;
};
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/core/mir/fusion/sequence_reverse_embedding_fuser.h"
#include <memory>
#include <vector>
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
void SequenceReverseEmbeddingFuser::BuildPattern() {
// create input nodes.
auto* x =
VarNode("x")->assert_is_op_input("sequence_reverse", "X")->AsInput();
auto* w = VarNode("w")->assert_is_op_input("lookup_table", "W")->AsInput();
// create op nodes
auto* sequence_reverse = OpNode("sequence_reverse", "sequence_reverse")
->assert_is_op("sequence_reverse")
->AsIntermediate();
auto* lookup_table = OpNode("lookup_table", "lookup_table")
->assert_is_op("lookup_table")
->AsIntermediate();
// create intermediate nodes
auto* sequence_reverse_out =
VarNode("sequence_reverse_out")
->assert_is_op_output("sequence_reverse", "Y")
->assert_is_op_input("lookup_table", "Ids")
->AsIntermediate();
// create output node
auto* out =
VarNode("out")->assert_is_op_output("lookup_table", "Out")->AsOutput();
// create topology.
*x >> *sequence_reverse >> *sequence_reverse_out >> *lookup_table >> *out;
*w >> *lookup_table;
}
void SequenceReverseEmbeddingFuser::InsertNewNode(SSAGraph* graph,
const key2nodes_t& matched) {
auto op_desc = GenOpDesc(matched);
auto fuse_op = LiteOpRegistry::Global().Create("sequence_reverse_embedding");
auto lookup_table = matched.at("lookup_table")->stmt()->op();
auto* scope = lookup_table->scope();
auto& valid_places = lookup_table->valid_places();
fuse_op->Attach(op_desc, scope);
auto* new_op_node = graph->GraphCreateInstructNode(fuse_op, valid_places);
IR_NODE_LINK_TO(matched.at("x"), new_op_node);
IR_NODE_LINK_TO(matched.at("w"), new_op_node);
IR_NODE_LINK_TO(new_op_node, matched.at("out"));
}
cpp::OpDesc SequenceReverseEmbeddingFuser::GenOpDesc(
const key2nodes_t& matched) {
auto op_desc = *matched.at("lookup_table")->stmt()->op_info();
op_desc.SetType("sequence_reverse_embedding");
auto& in_name = matched.at("x")->arg()->name;
auto& w_name = matched.at("w")->arg()->name;
auto& out_name = matched.at("out")->arg()->name;
op_desc.SetInput("Ids", {in_name});
op_desc.SetInput("W", {w_name});
op_desc.SetOutput("Out", {out_name});
return op_desc;
}
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include "lite/core/mir/pattern_matcher_high_api.h"
namespace paddle {
namespace lite {
namespace mir {
namespace fusion {
class SequenceReverseEmbeddingFuser : public FuseBase {
public:
void BuildPattern() override;
void InsertNewNode(SSAGraph* graph, const key2nodes_t& matched) override;
private:
cpp::OpDesc GenOpDesc(const key2nodes_t& matched) override;
};
} // namespace fusion
} // namespace mir
} // namespace lite
} // namespace paddle
...@@ -98,6 +98,7 @@ class Optimizer { ...@@ -98,6 +98,7 @@ class Optimizer {
"lite_interpolate_fuse_pass", // "lite_interpolate_fuse_pass", //
"identity_scale_eliminate_pass", // "identity_scale_eliminate_pass", //
"lite_scales_fuse_pass", // "lite_scales_fuse_pass", //
"lite_sequence_reverse_embedding_fuse_pass", //
"elementwise_mul_constant_eliminate_pass", // "elementwise_mul_constant_eliminate_pass", //
"lite_sequence_pool_concat_fuse_pass", // "lite_sequence_pool_concat_fuse_pass", //
"lite_scale_activation_fuse_pass", // "lite_scale_activation_fuse_pass", //
......
...@@ -38,6 +38,7 @@ add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute. ...@@ -38,6 +38,7 @@ add_kernel(bilinear_interp_compute_cuda CUDA basic SRCS bilinear_interp_compute.
add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps}) add_kernel(search_seq_depadding_compute_cuda CUDA extra SRCS search_seq_depadding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda}) add_kernel(search_grnn_compute_cuda CUDA extra SRCS search_grnn_compute.cu DEPS ${lite_kernel_deps} cuda_gemm ${math_cuda})
add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_reverse_embedding_compute_cuda CUDA extra SRCS sequence_reverse_embedding_compute.cu DEPS ${lite_kernel_deps})
add_kernel(sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda}) add_kernel(sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS ${lite_kernel_deps} ${math_cuda})
add_kernel(sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps}) add_kernel(sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS ${lite_kernel_deps})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
#include "lite/core/target_wrapper.h"
#include "lite/kernels/cuda/sequence_reverse_embedding_compute.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T>
__host__ __device__ inline size_t UpperBound(const T* x,
const int num,
const T& val) {
// The following code is from
// https://en.cppreference.com/w/cpp/algorithm/upper_bound
auto* first = x;
int64_t count = static_cast<int64_t>(num);
while (count > 0) {
auto step = (count >> 1);
auto* it = first + step;
if (val < *it) {
count = step;
} else {
first = ++it;
count -= (step + 1);
}
}
return static_cast<size_t>(first - x);
}
template <typename T>
__global__ void SequenceReverseEmbeddingKernel(const int64_t* ids,
const T* table,
T* out,
const int64_t* lod,
const int lod_count,
const int width,
const int count,
const bool padding_flag,
const int64_t padding_idx) {
CUDA_KERNEL_LOOP(tid, count) {
int64_t row = tid / width;
int col = tid % width;
auto lod_idx = UpperBound(lod, lod_count, row);
auto reverse_row = lod[lod_idx - 1] + lod[lod_idx] - 1 - row;
if (padding_flag) {
if (ids[reverse_row] == padding_idx)
out[tid] = 0;
else
out[tid] = table[ids[reverse_row] * width + col];
} else {
out[tid] = table[ids[reverse_row] * width + col];
}
}
}
template <typename T, PrecisionType Ptype>
void SequenceReverseEmbeddingCompute<T, Ptype>::Run() {
auto& param = this->template Param<param_t>();
auto& ctx = this->ctx_->template As<CUDAContext>();
auto stream = ctx.exec_stream();
auto io_stream = ctx.io_stream();
auto* table_data = param.W->template data<T>();
auto* out_data = param.Out->template mutable_data<T>(TARGET(kCUDA));
auto* ids_data = param.Ids->template data<int64_t>();
const auto lod = param.Ids->lod()[param.Ids->lod().size() - 1];
const int lod_count = lod.size();
const int width = param.W->dims()[1];
const int count = param.Out->numel();
lod_info_.Resize({static_cast<int64_t>(lod.size())});
int64_t* lod_data = lod_info_.mutable_data<int64_t>(TARGET(kCUDA));
TargetWrapperCuda::MemcpyAsync(lod_data,
lod.data(),
sizeof(int64_t) * lod.size(),
IoDirection::HtoD,
stream);
int64_t padding_idx = param.padding_idx;
bool padding_flag = padding_idx != -1;
SequenceReverseEmbeddingKernel<
T><<<CUDA_GET_BLOCKS(count), CUDA_NUM_THREADS, 0, stream>>>(ids_data,
table_data,
out_data,
lod_data,
lod_count,
width,
count,
padding_flag,
padding_idx);
CUDA_POST_KERNEL_CHECK;
}
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
using SeqReverseEmbFp32 = paddle::lite::kernels::cuda::
SequenceReverseEmbeddingCompute<float, PRECISION(kFloat)>;
REGISTER_LITE_KERNEL(
sequence_reverse_embedding, kCUDA, kFloat, kNCHW, SeqReverseEmbFp32, def)
.BindInput("W", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.BindInput("Ids", {LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kInt64))})
.BindOutput("Out",
{LiteType::GetTensorTy(TARGET(kCUDA), PRECISION(kFloat))})
.Finalize();
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "lite/core/kernel.h"
namespace paddle {
namespace lite {
namespace kernels {
namespace cuda {
template <typename T, PrecisionType Ptype>
class SequenceReverseEmbeddingCompute
: public KernelLite<TARGET(kCUDA), Ptype> {
public:
using param_t = operators::LookupTableParam;
void Run() override;
virtual ~SequenceReverseEmbeddingCompute() = default;
private:
lite::Tensor lod_info_;
};
} // namespace cuda
} // namespace kernels
} // namespace lite
} // namespace paddle
...@@ -100,6 +100,7 @@ add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${o ...@@ -100,6 +100,7 @@ add_operator(sequence_reverse_op_lite extra SRCS sequence_reverse_op.cc DEPS ${o
add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS}) add_operator(sequence_pool extra SRCS sequence_pool_op.cc DEPS ${op_DEPS})
add_operator(sequence_conv extra SRCS sequence_conv_op.cc DEPS ${op_DEPS}) add_operator(sequence_conv extra SRCS sequence_conv_op.cc DEPS ${op_DEPS})
add_operator(sequence_pool_concat extra SRCS sequence_pool_concat_op.cc DEPS ${op_DEPS}) add_operator(sequence_pool_concat extra SRCS sequence_pool_concat_op.cc DEPS ${op_DEPS})
add_operator(sequence_reverse_embedding_op_lite extra SRCS sequence_reverse_embedding_op.cc DEPS ${op_DEPS})
add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS}) add_operator(reduce_sum_op_lite extra SRCS reduce_ops.cc DEPS ${op_DEPS})
add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS}) add_operator(match_matrix_tensor_op_lite extra SRCS match_matrix_tensor_op.cc DEPS ${op_DEPS})
add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS}) add_operator(search_seq_depadding_op_lite extra SRCS search_seq_depadding_op.cc DEPS ${op_DEPS})
......
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_reverse_embedding_op.h"
#include "lite/core/op_registry.h"
namespace paddle {
namespace lite {
namespace operators {
bool SequenceReverseEmbeddingOp::CheckShape() const {
CHECK_OR_FALSE(param_.W)
CHECK_OR_FALSE(param_.Ids)
CHECK_OR_FALSE(param_.Out)
CHECK_EQ(param_.Ids->lod().empty(), false)
<< "Input(Ids) Tensor of SequenceReverseEmbeddingOp does not contain "
"LoD information.";
const auto& table_dims = param_.W->dims();
const auto& ids_dims = param_.Ids->dims();
int ids_rank = ids_dims.size();
CHECK_EQ_OR_FALSE(table_dims.size(), 2)
CHECK_EQ_OR_FALSE(ids_dims[ids_rank - 1], 1)
return true;
}
bool SequenceReverseEmbeddingOp::InferShapeImpl() const {
const auto& table_dims = param_.W->dims();
const auto& ids_dims = param_.Ids->dims();
auto out_dims = ids_dims;
int ids_rank = ids_dims.size();
out_dims[ids_rank - 1] = table_dims[1];
param_.Out->Resize(out_dims);
param_.Out->set_lod(param_.Ids->lod());
return true;
}
bool SequenceReverseEmbeddingOp::AttachImpl(const cpp::OpDesc& op_desc,
lite::Scope* scope) {
auto input = op_desc.Input("W").front();
auto ids = op_desc.Input("Ids").front();
auto out = op_desc.Output("Out").front();
param_.W = scope->FindTensor(input);
param_.Ids = scope->FindTensor(ids);
param_.Out = scope->FindMutableTensor(out);
param_.padding_idx = op_desc.GetAttr<int64_t>("padding_idx");
return true;
}
} // namespace operators
} // namespace lite
} // namespace paddle
REGISTER_LITE_OP(sequence_reverse_embedding,
paddle::lite::operators::SequenceReverseEmbeddingOp);
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace paddle {
namespace lite {
namespace operators {
class SequenceReverseEmbeddingOp : public OpLite {
public:
SequenceReverseEmbeddingOp() {}
explicit SequenceReverseEmbeddingOp(const std::string &op_type)
: OpLite(op_type) {}
bool CheckShape() const override;
bool InferShapeImpl() const override;
bool AttachImpl(const cpp::OpDesc &opdesc, lite::Scope *scope) override;
void AttachKernel(KernelBase *kernel) override { kernel->SetParam(param_); }
std::string DebugString() const override {
return "sequence_reverse_embedding";
}
private:
mutable LookupTableParam param_;
};
} // namespace operators
} // namespace lite
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册