Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle-Lite
提交
bc66d2be
P
Paddle-Lite
项目概览
PaddlePaddle
/
Paddle-Lite
通知
337
Star
4
Fork
1
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
271
列表
看板
标记
里程碑
合并请求
78
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle-Lite
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
271
Issue
271
列表
看板
标记
里程碑
合并请求
78
合并请求
78
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
bc66d2be
编写于
7月 01, 2020
作者:
W
Wilber
提交者:
GitHub
7月 01, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[CUDA] [Kernel] Add sequence_mask kernel. (#3868)
上级
973cca29
变更
8
隐藏空白更改
内联
并排
Showing
8 changed file
with
371 addition
and
0 deletion
+371
-0
lite/kernels/cuda/CMakeLists.txt
lite/kernels/cuda/CMakeLists.txt
+2
-0
lite/kernels/cuda/sequence_mask_compute.cu
lite/kernels/cuda/sequence_mask_compute.cu
+91
-0
lite/kernels/cuda/sequence_mask_compute.h
lite/kernels/cuda/sequence_mask_compute.h
+41
-0
lite/kernels/cuda/sequence_mask_compute_test.cc
lite/kernels/cuda/sequence_mask_compute_test.cc
+131
-0
lite/operators/CMakeLists.txt
lite/operators/CMakeLists.txt
+1
-0
lite/operators/op_params.h
lite/operators/op_params.h
+8
-0
lite/operators/sequence_mask_op.cc
lite/operators/sequence_mask_op.cc
+52
-0
lite/operators/sequence_mask_op.h
lite/operators/sequence_mask_op.h
+45
-0
未找到文件。
lite/kernels/cuda/CMakeLists.txt
浏览文件 @
bc66d2be
...
@@ -37,6 +37,7 @@ add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_comput
...
@@ -37,6 +37,7 @@ add_kernel(sequence_reverse_compute_cuda CUDA extra SRCS sequence_reverse_comput
add_kernel
(
sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
add_kernel
(
sequence_pad_compute_cuda CUDA extra SRCS sequence_pad_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
add_kernel
(
sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
add_kernel
(
sequence_unpad_compute_cuda CUDA extra SRCS sequence_unpad_compute.cu DEPS
${
lite_kernel_deps
}
${
math_cuda
}
)
add_kernel
(
sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_concat_compute_cuda CUDA extra SRCS sequence_concat_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_mask_compute_cuda CUDA extra SRCS sequence_mask_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_arithmetic_compute_cuda CUDA extra SRCS sequence_arithmetic_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
sequence_arithmetic_compute_cuda CUDA extra SRCS sequence_arithmetic_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
lookup_table_compute_cuda CUDA extra SRCS lookup_table_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS
${
lite_kernel_deps
}
)
add_kernel
(
attention_padding_mask_compute_cuda CUDA extra SRCS attention_padding_mask_compute.cu DEPS
${
lite_kernel_deps
}
)
...
@@ -80,6 +81,7 @@ if(LITE_BUILD_EXTRA)
...
@@ -80,6 +81,7 @@ if(LITE_BUILD_EXTRA)
nv_test
(
sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda
)
nv_test
(
sequence_reverse_compute_cuda_test SRCS sequence_reverse_compute_test.cc DEPS sequence_reverse_compute_cuda
)
nv_test
(
sequence_pad_compute_cuda_test SRCS sequence_pad_compute_test.cc DEPS sequence_pad_compute_cuda
)
nv_test
(
sequence_pad_compute_cuda_test SRCS sequence_pad_compute_test.cc DEPS sequence_pad_compute_cuda
)
nv_test
(
sequence_unpad_compute_cuda_test SRCS sequence_unpad_compute_test.cc DEPS sequence_unpad_compute_cuda
)
nv_test
(
sequence_unpad_compute_cuda_test SRCS sequence_unpad_compute_test.cc DEPS sequence_unpad_compute_cuda
)
nv_test
(
sequence_mask_compute_cuda_test SRCS sequence_mask_compute_test.cc DEPS sequence_mask_compute_cuda
)
nv_test
(
var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda
)
nv_test
(
var_conv_2d_compute_cuda_test SRCS var_conv_2d_compute_test.cc DEPS var_conv_2d_compute_cuda
)
#nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
#nv_test(sequence_concat_compute_cuda_test SRCS sequence_concat_compute_test.cc DEPS sequence_concat_compute_cuda)
#nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
#nv_test(attention_padding_mask_compute_cuda_test SRCS attention_padding_mask_compute_test.cc DEPS attention_padding_mask_compute_cuda)
...
...
lite/kernels/cuda/sequence_mask_compute.cu
0 → 100644
浏览文件 @
bc66d2be
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/kernels/cuda/sequence_mask_compute.h"
#include <thrust/device_ptr.h>
#include <thrust/reduce.h>
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
T
>
__global__
void
SequenceMaskKernel
(
T
*
dst
,
const
int64_t
*
src
,
int
count
,
int
maxlen
)
{
CUDA_KERNEL_LOOP
(
index
,
count
)
{
int
src_idx
=
index
/
maxlen
;
int
inner_idx
=
index
%
maxlen
;
dst
[
index
]
=
static_cast
<
T
>
(
inner_idx
<
src
[
src_idx
]
?
1
:
0
);
}
}
template
<
typename
T
,
PrecisionType
Ptype
>
void
SequenceMaskCompute
<
T
,
Ptype
>::
Run
()
{
auto
&
param
=
this
->
template
Param
<
param_t
>();
auto
&
ctx
=
this
->
ctx_
->
template
As
<
CUDAContext
>();
auto
stream
=
ctx
.
exec_stream
();
const
auto
*
x
=
param
.
X
;
auto
*
x_data
=
x
->
template
data
<
int64_t
>();
auto
*
y
=
param
.
Y
;
int
maxlen
=
param
.
maxlen
;
if
(
param
.
MaxLenTensor
)
{
auto
*
len_tensor_data
=
param
.
MaxLenTensor
->
template
data
<
int32_t
>();
int32_t
len_data
{
0
};
TargetWrapperCuda
::
MemcpySync
(
&
len_data
,
len_tensor_data
,
sizeof
(
int32_t
),
IoDirection
::
DtoH
);
maxlen
=
len_data
;
}
if
(
maxlen
<
0
)
{
maxlen
=
thrust
::
reduce
(
x_data
,
x_data
+
x
->
numel
(),
0
,
thrust
::
maximum
<
T
>
());
}
auto
y_dim
=
x
->
dims
().
Vectorize
();
y_dim
.
push_back
(
maxlen
);
y
->
Resize
(
y_dim
);
const
int
count
=
y
->
numel
();
auto
*
dst_data
=
y
->
template
mutable_data
<
float
>(
TARGET
(
kCUDA
));
if
(
param
.
out_dtype
==
5
)
{
SequenceMaskKernel
<
float
><<<
CUDA_GET_BLOCKS
(
count
),
CUDA_NUM_THREADS
,
0
,
stream
>>>
(
dst_data
,
x_data
,
count
,
maxlen
);
}
else
{
LOG
(
FATAL
)
<<
"not supported out_dtype: "
<<
param
.
out_dtype
;
}
CUDA_POST_KERNEL_CHECK
;
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
using
SeqMaskFp32
=
paddle
::
lite
::
kernels
::
cuda
::
SequenceMaskCompute
<
float
,
PRECISION
(
kFloat
)
>
;
REGISTER_LITE_KERNEL
(
sequence_mask
,
kCUDA
,
kFloat
,
kNCHW
,
SeqMaskFp32
,
def
)
.
BindInput
(
"X"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
),
PRECISION
(
kInt64
))})
.
BindInput
(
"MaxLenTensor"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
BindOutput
(
"Y"
,
{
LiteType
::
GetTensorTy
(
TARGET
(
kCUDA
))})
.
Finalize
();
lite/kernels/cuda/sequence_mask_compute.h
0 → 100644
浏览文件 @
bc66d2be
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <vector>
#include "lite/core/kernel.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
template
<
typename
T
,
PrecisionType
Ptype
>
class
SequenceMaskCompute
:
public
KernelLite
<
TARGET
(
kCUDA
),
Ptype
>
{
public:
using
param_t
=
operators
::
SequenceMaskParam
;
void
Run
()
override
;
virtual
~
SequenceMaskCompute
()
=
default
;
// private:
// lite::Tensor seq_offsets_;
// std::vector<int64_t> seq_len_;
// std::vector<size_t> seq_offsets_vec_;
};
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/kernels/cuda/sequence_mask_compute_test.cc
0 → 100644
浏览文件 @
bc66d2be
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <gtest/gtest.h>
#include <memory>
#include <random>
#include <utility>
#include <vector>
#include "lite/api/test_helper.h"
#include "lite/backends/cuda/cuda_utils.h"
#include "lite/kernels/cuda/sequence_mask_compute.h"
// #include "lite/utils/float16.h"
namespace
paddle
{
namespace
lite
{
namespace
kernels
{
namespace
cuda
{
class
SequenceMaskTest
:
public
::
testing
::
Test
{
protected:
SequenceMaskTest
()
:
maxlen
(
4
),
out_dtype
(
5
),
x_data
({
3
,
2
,
1
,
0
}),
out_shape
({
static_cast
<
int64_t
>
(
x_data
.
size
()),
maxlen
})
{
X_ref
.
Resize
(
lite
::
DDim
({
static_cast
<
int64_t
>
(
x_data
.
size
())}));
X_gpu
.
Resize
(
X_ref
.
dims
());
auto
*
x_ref_data
=
X_ref
.
mutable_data
<
int64_t
>
();
// prepare input
for
(
size_t
i
=
0
;
i
<
x_data
.
size
();
i
++
)
{
x_ref_data
[
i
]
=
x_data
[
i
];
}
Out_ref
.
Resize
(
lite
::
DDim
(
out_shape
));
Out_gpu
.
Resize
(
Out_ref
.
dims
());
Out_cpu
.
Resize
(
Out_ref
.
dims
());
cpu_base
(
&
X_ref
,
&
Out_ref
);
device_init
();
}
void
device_init
()
{
ctx
.
reset
(
new
KernelContext
);
cudaStreamCreate
(
&
stream
);
param
.
X
=
&
X_gpu
;
param
.
Y
=
&
Out_gpu
;
param
.
maxlen
=
maxlen
;
param
.
out_dtype
=
out_dtype
;
}
void
float_data_init
()
{
X_gpu
.
Assign
<
int64_t
,
lite
::
DDim
,
TARGET
(
kCUDA
)
>
(
X_ref
.
data
<
int64_t
>
(),
X_gpu
.
dims
());
}
void
half_data_init
()
{}
void
cpu_base
(
const
lite
::
Tensor
*
X
,
lite
::
Tensor
*
Out
)
{
auto
*
out_data
=
Out
->
mutable_data
<
float
>
();
for
(
size_t
i
=
0
;
i
<
x_data
.
size
();
++
i
)
{
for
(
int
j
=
0
;
j
<
maxlen
;
++
j
)
{
out_data
[
i
*
maxlen
+
j
]
=
j
<
x_data
[
i
]
?
1
:
0
;
}
}
}
int
maxlen
,
out_dtype
;
std
::
vector
<
int64_t
>
x_data
,
out_shape
;
lite
::
Tensor
X_ref
,
Out_ref
;
lite
::
Tensor
X_gpu
,
Out_gpu
;
lite
::
Tensor
Out_cpu
;
operators
::
SequenceMaskParam
param
;
std
::
unique_ptr
<
KernelContext
>
ctx
;
cudaStream_t
stream
;
};
TEST_F
(
SequenceMaskTest
,
fp32
)
{
float_data_init
();
auto
&
context
=
ctx
->
As
<
CUDAContext
>
();
context
.
SetExecStream
(
stream
);
SequenceMaskCompute
<
float
,
PRECISION
(
kFloat
)
>
kernel
;
kernel
.
SetParam
(
param
);
kernel
.
SetContext
(
std
::
move
(
ctx
));
for
(
int
i
=
0
;
i
<
FLAGS_warmup
;
++
i
)
{
kernel
.
Launch
();
cudaDeviceSynchronize
();
}
auto
start
=
GetCurrentUS
();
kernel
.
PrepareForRun
();
for
(
int
i
=
0
;
i
<
FLAGS_repeats
;
++
i
)
{
kernel
.
Run
();
}
cudaDeviceSynchronize
();
auto
duration
=
(
GetCurrentUS
()
-
start
)
/
1000.0
;
LOG
(
INFO
)
<<
"fp32, warmup: "
<<
FLAGS_warmup
<<
", repeats: "
<<
FLAGS_repeats
<<
", spend "
<<
duration
/
FLAGS_repeats
<<
" ms in average."
;
CopySync
<
TARGET
(
kCUDA
)
>
(
Out_cpu
.
mutable_data
<
float
>
(),
Out_gpu
.
data
<
float
>
(),
sizeof
(
float
)
*
Out_gpu
.
numel
(),
IoDirection
::
DtoH
);
for
(
int
i
=
0
;
i
<
Out_gpu
.
numel
();
++
i
)
{
EXPECT_NEAR
(
Out_cpu
.
data
<
float
>
()[
i
],
Out_ref
.
data
<
float
>
()[
i
],
1e-5
);
}
}
}
// namespace cuda
}
// namespace kernels
}
// namespace lite
}
// namespace paddle
lite/operators/CMakeLists.txt
浏览文件 @
bc66d2be
...
@@ -78,6 +78,7 @@ add_operator(shape_op_lite extra SRCS shape_op.cc DEPS ${op_DEPS})
...
@@ -78,6 +78,7 @@ add_operator(shape_op_lite extra SRCS shape_op.cc DEPS ${op_DEPS})
add_operator
(
sequence_expand_op_lite extra SRCS sequence_expand_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_expand_op_lite extra SRCS sequence_expand_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_unpad_op_lite extra SRCS sequence_unpad_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_unpad_op_lite extra SRCS sequence_unpad_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_pad_op_lite extra SRCS sequence_pad_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_pad_op_lite extra SRCS sequence_pad_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
sequence_mask_op_lite extra SRCS sequence_mask_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
im2sequence_op extra SRCS im2sequence_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
im2sequence_op extra SRCS im2sequence_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
gather_op extra SRCS gather_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
gather_op extra SRCS gather_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
anchor_generator_op extra SRCS anchor_generator_op.cc DEPS
${
op_DEPS
}
)
add_operator
(
anchor_generator_op extra SRCS anchor_generator_op.cc DEPS
${
op_DEPS
}
)
...
...
lite/operators/op_params.h
浏览文件 @
bc66d2be
...
@@ -1045,6 +1045,14 @@ struct SequenceUnpadParam : ParamBase {
...
@@ -1045,6 +1045,14 @@ struct SequenceUnpadParam : ParamBase {
lite
::
Tensor
*
Out
{};
lite
::
Tensor
*
Out
{};
};
};
struct
SequenceMaskParam
:
ParamBase
{
const
lite
::
Tensor
*
X
{};
const
lite
::
Tensor
*
MaxLenTensor
{
nullptr
};
lite
::
Tensor
*
Y
{};
int
maxlen
{
-
1
};
int
out_dtype
;
};
struct
SequenceExpandAsParam
:
ParamBase
{
struct
SequenceExpandAsParam
:
ParamBase
{
const
lite
::
Tensor
*
x
{
nullptr
};
const
lite
::
Tensor
*
x
{
nullptr
};
const
lite
::
Tensor
*
y
{
nullptr
};
const
lite
::
Tensor
*
y
{
nullptr
};
...
...
lite/operators/sequence_mask_op.cc
0 → 100644
浏览文件 @
bc66d2be
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "lite/operators/sequence_mask_op.h"
#include "lite/core/op_registry.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
bool
SequenceMaskOp
::
CheckShape
()
const
{
CHECK_OR_FALSE
(
param_
.
X
);
CHECK_OR_FALSE
(
param_
.
Y
);
return
true
;
}
bool
SequenceMaskOp
::
InferShapeImpl
()
const
{
return
true
;
}
bool
SequenceMaskOp
::
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
{
param_
.
X
=
const_cast
<
lite
::
Tensor
*>
(
&
scope
->
FindVar
(
opdesc
.
Input
(
"X"
).
front
())
->
Get
<
lite
::
Tensor
>
());
if
(
opdesc
.
HasInput
(
"MaxLenTensor"
)
&&
!
opdesc
.
Input
(
"MaxLenTensor"
).
empty
())
{
auto
var
=
scope
->
FindVar
(
opdesc
.
Input
(
"MaxLenTensor"
).
front
());
if
(
var
!=
nullptr
)
{
param_
.
MaxLenTensor
=
var
->
GetMutable
<
lite
::
Tensor
>
();
}
}
param_
.
Y
=
scope
->
FindVar
(
opdesc
.
Output
(
"Y"
).
front
())
->
GetMutable
<
lite
::
Tensor
>
();
param_
.
maxlen
=
opdesc
.
GetAttr
<
int
>
(
"maxlen"
);
param_
.
out_dtype
=
opdesc
.
GetAttr
<
int
>
(
"out_dtype"
);
return
true
;
}
}
// namespace operators
}
// namespace lite
}
// namespace paddle
REGISTER_LITE_OP
(
sequence_mask
,
paddle
::
lite
::
operators
::
SequenceMaskOp
);
lite/operators/sequence_mask_op.h
0 → 100644
浏览文件 @
bc66d2be
// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include <vector>
#include "lite/core/op_lite.h"
#include "lite/core/scope.h"
namespace
paddle
{
namespace
lite
{
namespace
operators
{
class
SequenceMaskOp
:
public
OpLite
{
public:
SequenceMaskOp
()
{}
explicit
SequenceMaskOp
(
const
std
::
string
&
op_type
)
:
OpLite
(
op_type
)
{}
bool
CheckShape
()
const
override
;
bool
InferShapeImpl
()
const
override
;
bool
AttachImpl
(
const
cpp
::
OpDesc
&
opdesc
,
lite
::
Scope
*
scope
)
override
;
void
AttachKernel
(
KernelBase
*
kernel
)
override
{
kernel
->
SetParam
(
param_
);
}
std
::
string
DebugString
()
const
override
{
return
"sequence_mask"
;
}
private:
mutable
SequenceMaskParam
param_
;
};
}
// namespace operators
}
// namespace lite
}
// namespace paddle
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录