Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
79918a84
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
79918a84
编写于
8月 22, 2018
作者:
Q
qingqing01
提交者:
sneaxiy
8月 23, 2018
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
add sequence_mask_op for DAM model
上级
77489634
变更
9
隐藏空白更改
内联
并排
Showing
9 changed file
with
278 addition
and
5 deletion
+278
-5
paddle/fluid/API.spec
paddle/fluid/API.spec
+1
-0
paddle/fluid/operators/batch_norm_op.cc
paddle/fluid/operators/batch_norm_op.cc
+1
-1
paddle/fluid/operators/sequence_mask_op.cc
paddle/fluid/operators/sequence_mask_op.cc
+26
-0
paddle/fluid/operators/sequence_mask_op.cu
paddle/fluid/operators/sequence_mask_op.cu
+22
-0
paddle/fluid/operators/sequence_mask_op.h
paddle/fluid/operators/sequence_mask_op.h
+117
-0
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+20
-2
python/paddle/fluid/nets.py
python/paddle/fluid/nets.py
+1
-1
python/paddle/fluid/tests/book/test_image_classification.py
python/paddle/fluid/tests/book/test_image_classification.py
+4
-1
python/paddle/fluid/tests/unittests/test_sequence_mask.py
python/paddle/fluid/tests/unittests/test_sequence_mask.py
+86
-0
未找到文件。
paddle/fluid/API.spec
浏览文件 @
79918a84
...
...
@@ -162,6 +162,7 @@ paddle.fluid.layers.crop ArgSpec(args=['x', 'shape', 'offsets', 'name'], varargs
paddle.fluid.layers.rank_loss ArgSpec(args=['label', 'left', 'right', 'name'], varargs=None, keywords=None, defaults=(None,))
paddle.fluid.layers.prelu ArgSpec(args=['x', 'mode', 'param_attr', 'name'], varargs=None, keywords=None, defaults=(None, None))
paddle.fluid.layers.flatten ArgSpec(args=['x', 'axis', 'name'], varargs=None, keywords=None, defaults=(1, None))
paddle.fluid.layers.sequence_mask ArgSpec(args=['x', 'max_len', 'mask_dtype'], varargs=None, keywords=None, defaults=('int64',))
paddle.fluid.layers.data ArgSpec(args=['name', 'shape', 'append_batch_size', 'dtype', 'lod_level', 'type', 'stop_gradient'], varargs=None, keywords=None, defaults=(True, 'float32', 0, VarType.LOD_TENSOR, True))
paddle.fluid.layers.open_recordio_file ArgSpec(args=['filename', 'shapes', 'lod_levels', 'dtypes', 'pass_num', 'for_parallel'], varargs=None, keywords=None, defaults=(1, True))
paddle.fluid.layers.open_files ArgSpec(args=['filenames', 'shapes', 'lod_levels', 'dtypes', 'thread_num', 'buffer_size', 'pass_num', 'is_test'], varargs=None, keywords=None, defaults=(None, None, 1, None))
...
...
paddle/fluid/operators/batch_norm_op.cc
浏览文件 @
79918a84
...
...
@@ -135,7 +135,7 @@ class BatchNormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput
(
"Variance"
,
"The global variance (for training) "
"or estimated Variance (for testing)"
);
AddOutput
(
"Y"
,
"result after normalization"
)
.
Reuse
(
"X"
)
;
AddOutput
(
"Y"
,
"result after normalization"
);
AddOutput
(
"MeanOut"
,
"Share memory with Mean. "
"Store the global mean when training"
)
...
...
paddle/fluid/operators/sequence_mask_op.cc
0 → 100644
浏览文件 @
79918a84
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_mask_op.h"
REGISTER_OPERATOR
(
sequence_mask
,
paddle
::
operators
::
SequenceMaskOp
,
paddle
::
operators
::
SequenceMaskOpMaker
,
paddle
::
framework
::
EmptyGradOpMaker
);
REGISTER_OP_CPU_KERNEL
(
sequence_mask
,
paddle
::
operators
::
SequenceMaskKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int
>
,
paddle
::
operators
::
SequenceMaskKernel
<
paddle
::
platform
::
CPUDeviceContext
,
int64_t
>
);
paddle/fluid/operators/sequence_mask_op.cu
0 → 100644
浏览文件 @
79918a84
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/sequence_mask_op.h"
REGISTER_OP_CUDA_KERNEL
(
sequence_mask
,
paddle
::
operators
::
SequenceMaskKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int
>
,
paddle
::
operators
::
SequenceMaskKernel
<
paddle
::
platform
::
CUDADeviceContext
,
int64_t
>
);
paddle/fluid/operators/sequence_mask_op.h
0 → 100644
浏览文件 @
79918a84
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/for_range.h"
namespace
paddle
{
namespace
operators
{
class
SequenceMaskOp
:
public
framework
::
OperatorWithKernel
{
public:
using
framework
::
OperatorWithKernel
::
OperatorWithKernel
;
void
InferShape
(
framework
::
InferShapeContext
*
ctx
)
const
override
{
PADDLE_ENFORCE
(
ctx
->
HasInput
(
"X"
),
"Input(X) must exist"
);
auto
max_len
=
ctx
->
Attrs
().
Get
<
int
>
(
"max_len"
);
PADDLE_ENFORCE_GT
(
max_len
,
1
,
"Attr(max_len) must be larger than 1"
);
PADDLE_ENFORCE
(
ctx
->
HasOutput
(
"Y"
),
"Output(Y) must exist"
);
auto
dim
=
framework
::
vectorize2int
(
ctx
->
GetInputDim
(
"X"
));
dim
.
push_back
(
max_len
);
ctx
->
SetOutputDim
(
"Y"
,
framework
::
make_ddim
(
dim
));
}
};
class
SequenceMaskOpMaker
:
public
framework
::
OpProtoAndCheckerMaker
{
public:
void
Make
()
override
{
AddInput
(
"X"
,
"The input of sequence_mask op."
);
AddOutput
(
"Y"
,
"The output mask of sequence_mask op."
);
AddAttr
<
int
>
(
"max_len"
,
"The maximum length of the sequence."
)
.
GreaterThan
(
1
);
AddAttr
<
int
>
(
"out_dtype"
,
"Output data type"
);
AddComment
(
R"DOC(
SequenceMask Operator
This operator outputs a Mask according to Input(X) and Attr(max_len).
Supposing Input(X) is a Tensor with shape [d_1, d_2, ..., d_n], the
Output(Y) is a mask with shape [d_1, d_2, ..., d_n, max_len], where:
Y(i_1, i_2, ..., i_n, j) = (j < X(i_1, i_2, ..., i_n))
)DOC"
);
}
};
template
<
typename
Tx
,
typename
Ty
>
struct
SequenceMaskForRangeFunctor
{
HOSTDEVICE
SequenceMaskForRangeFunctor
(
const
Tx
*
x
,
Ty
*
y
,
int
max_len
)
:
x_
(
x
),
y_
(
y
),
max_len_
(
max_len
)
{}
HOSTDEVICE
void
operator
()(
int
y_idx
)
const
{
int
x_idx
=
y_idx
/
max_len_
;
int
j
=
y_idx
%
max_len_
;
y_
[
y_idx
]
=
static_cast
<
Ty
>
(
j
<
x_
[
x_idx
]
?
1
:
0
);
}
private:
const
Tx
*
x_
;
Ty
*
y_
;
int
max_len_
;
};
template
<
typename
DeviceContext
,
typename
Tx
>
struct
SequenceMaskFunctor
{
using
Tensor
=
framework
::
LoDTensor
;
SequenceMaskFunctor
(
const
DeviceContext
&
ctx
,
const
Tx
*
x
,
Tensor
*
y
,
int
limits
,
int
max_len
)
:
ctx_
(
ctx
),
x_
(
x
),
y_
(
y
),
limits_
(
limits
),
max_len_
(
max_len
)
{}
template
<
typename
Ty
>
void
operator
()()
const
{
auto
*
y_data
=
y_
->
mutable_data
<
Ty
>
(
ctx_
.
GetPlace
());
platform
::
ForRange
<
DeviceContext
>
for_range
(
ctx_
,
limits_
);
for_range
(
SequenceMaskForRangeFunctor
<
Tx
,
Ty
>
(
x_
,
y_data
,
max_len_
));
}
private:
const
DeviceContext
&
ctx_
;
const
Tx
*
x_
;
Tensor
*
y_
;
int
limits_
;
int
max_len_
;
};
template
<
typename
DeviceContext
,
typename
Tx
>
class
SequenceMaskKernel
:
public
framework
::
OpKernel
<
Tx
>
{
using
Tensor
=
framework
::
LoDTensor
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
y
=
ctx
.
Output
<
Tensor
>
(
"Y"
);
auto
max_len
=
ctx
.
Attr
<
int
>
(
"max_len"
);
auto
out_dtype
=
static_cast
<
framework
::
proto
::
VarType
::
Type
>
(
ctx
.
Attr
<
int
>
(
"out_dtype"
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
framework
::
VisitDataType
(
out_dtype
,
SequenceMaskFunctor
<
DeviceContext
,
Tx
>
(
dev_ctx
,
x
->
data
<
Tx
>
(),
y
,
x
->
numel
()
*
max_len
,
max_len
));
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/fluid/layers/nn.py
浏览文件 @
79918a84
...
...
@@ -27,6 +27,7 @@ from . import utils
import
random
from
..
import
unique_name
from
functools
import
reduce
import
warnings
__all__
=
[
'fc'
,
...
...
@@ -103,6 +104,7 @@ __all__ = [
'rank_loss'
,
'prelu'
,
'flatten'
,
'sequence_mask'
,
]
...
...
@@ -2046,7 +2048,7 @@ def batch_norm(input,
param_attr(ParamAttr): The parameter attribute for Parameter `scale`.
bias_attr(ParamAttr): The parameter attribute for Parameter `bias`.
data_layout(string, default NCHW): NCHW|NHWC
in_place(bool, Default False):
Make the input and output of batch norm reuse memory
.
in_place(bool, Default False):
This argument is deprecated since 0.15.0
.
use_mkldnn(bool, Default false): ${use_mkldnn_comment}
name(string, Default None): A name for this layer(optional). If set None, the layer
will be named automatically.
...
...
@@ -2068,6 +2070,10 @@ def batch_norm(input,
helper
=
LayerHelper
(
'batch_norm'
,
**
locals
())
dtype
=
helper
.
input_dtype
()
if
in_place
:
raise
warnings
.
warn
(
"The argument in_place is deprecated since 0.15.0, "
"please do not set it True."
)
input_shape
=
input
.
shape
if
data_layout
==
'NCHW'
:
channel_num
=
input_shape
[
1
]
...
...
@@ -2117,7 +2123,7 @@ def batch_norm(input,
saved_mean
=
helper
.
create_tmp_variable
(
dtype
=
dtype
,
stop_gradient
=
True
)
saved_variance
=
helper
.
create_tmp_variable
(
dtype
=
dtype
,
stop_gradient
=
True
)
batch_norm_out
=
input
if
in_place
else
helper
.
create_tmp_variable
(
dtype
)
batch_norm_out
=
helper
.
create_tmp_variable
(
dtype
)
helper
.
append_op
(
type
=
"batch_norm"
,
...
...
@@ -5517,3 +5523,15 @@ def flatten(x, axis=1, name=None):
outputs
=
{
'Out'
:
out
},
attrs
=
{
"axis"
:
axis
})
return
out
def
sequence_mask
(
x
,
max_len
,
mask_dtype
=
'int64'
):
helper
=
LayerHelper
(
'sequence_mask'
,
**
locals
())
y
=
helper
.
create_tmp_variable
(
dtype
=
mask_dtype
)
helper
.
append_op
(
type
=
'sequence_mask'
,
inputs
=
{
'X'
:
[
x
]},
outputs
=
{
'Y'
:
y
},
attrs
=
{
'max_len'
:
max_len
,
'out_dtype'
:
y
.
dtype
})
return
y
python/paddle/fluid/nets.py
浏览文件 @
79918a84
...
...
@@ -229,7 +229,7 @@ def img_conv_group(input,
use_mkldnn
=
use_mkldnn
)
if
conv_with_batchnorm
[
i
]:
tmp
=
layers
.
batch_norm
(
input
=
tmp
,
act
=
conv_act
,
in_place
=
True
)
tmp
=
layers
.
batch_norm
(
input
=
tmp
,
act
=
conv_act
)
drop_rate
=
conv_batchnorm_drop_rate
[
i
]
if
abs
(
drop_rate
)
>
1e-5
:
tmp
=
layers
.
dropout
(
x
=
tmp
,
dropout_prob
=
drop_rate
)
...
...
python/paddle/fluid/tests/book/test_image_classification.py
浏览文件 @
79918a84
...
...
@@ -256,7 +256,10 @@ def main(net_type, use_cuda, is_local=True):
save_dirname
=
"image_classification_"
+
net_type
+
".inference.model"
train
(
net_type
,
use_cuda
,
save_dirname
,
is_local
)
infer
(
use_cuda
,
save_dirname
)
# There is bug in fluid.InferenceTranspiler for VGG.
if
net_type
==
"resnet"
:
infer
(
use_cuda
,
save_dirname
)
class
TestImageClassification
(
unittest
.
TestCase
):
...
...
python/paddle/fluid/tests/unittests/test_sequence_mask.py
0 → 100644
浏览文件 @
79918a84
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
op_test
import
OpTest
from
paddle.fluid.framework
import
convert_np_dtype_to_dtype_
import
numpy
as
np
import
copy
import
unittest
class
SequenceMaskTestBase
(
OpTest
):
def
initDefaultParameters
(
self
):
self
.
op_type
=
'sequence_mask'
self
.
max_len
=
10
self
.
mask_dtype
=
'int64'
self
.
x
=
[[
0
,
3
,
4
],
[
5
,
7
,
9
]]
def
initParameters
(
self
):
pass
def
setUp
(
self
):
self
.
initDefaultParameters
()
self
.
initParameters
()
if
not
isinstance
(
self
.
x
,
np
.
ndarray
):
self
.
x
=
np
.
array
(
self
.
x
)
self
.
inputs
=
{
'X'
:
self
.
x
}
self
.
outputs
=
{
'Y'
:
self
.
calc_ground_truth_mask
()}
self
.
attrs
=
{
'max_len'
:
self
.
max_len
,
'out_dtype'
:
convert_np_dtype_to_dtype_
(
self
.
mask_dtype
)
}
def
calc_ground_truth_mask
(
self
):
shape
=
self
.
x
.
shape
+
(
self
.
max_len
,
)
index_broadcast
=
np
.
broadcast_to
(
np
.
reshape
(
range
(
self
.
max_len
),
newshape
=
[
1
]
*
self
.
x
.
ndim
+
[
-
1
]),
shape
=
shape
)
x_broadcast
=
np
.
broadcast_to
(
np
.
reshape
(
self
.
x
,
newshape
=
self
.
x
.
shape
+
(
-
1
,
)),
shape
=
shape
)
return
(
index_broadcast
<
x_broadcast
).
astype
(
self
.
mask_dtype
)
def
test_check_output
(
self
):
self
.
check_output
()
class
SequenceMaskTest1
(
SequenceMaskTestBase
):
def
initParameters
(
self
):
self
.
mask_dtype
=
'bool'
class
SequenceMaskTest2
(
SequenceMaskTestBase
):
def
initParameters
(
self
):
self
.
mask_dtype
=
'uint8'
class
SequenceMaskTest3
(
SequenceMaskTestBase
):
def
initParameters
(
self
):
self
.
mask_dtype
=
'int32'
class
SequenceMaskTest4
(
SequenceMaskTestBase
):
def
initParameters
(
self
):
self
.
mask_dtype
=
'float32'
class
SequenceMaskTest5
(
SequenceMaskTestBase
):
def
initParameters
(
self
):
self
.
mask_dtype
=
'float64'
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录