Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
deb510d4
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
deb510d4
编写于
4月 29, 2019
作者:
T
tangwei12
提交者:
GitHub
4月 29, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
cvm op feature (#17081)
cvm without LoD.
上级
554d3a71
变更
2
隐藏空白更改
内联
并排
Showing
2 changed file
with
157 addition
and
51 deletion
+157
-51
paddle/fluid/operators/cvm_op.h
paddle/fluid/operators/cvm_op.h
+70
-49
python/paddle/fluid/tests/unittests/test_cvm_op.py
python/paddle/fluid/tests/unittests/test_cvm_op.py
+87
-2
未找到文件。
paddle/fluid/operators/cvm_op.h
浏览文件 @
deb510d4
...
...
@@ -22,36 +22,60 @@ namespace operators {
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
template
<
typename
T
>
void
CvmComputeKernel
(
const
bool
use_cvm
,
const
int64_t
item_width
,
const
T
**
X
,
T
**
Y
)
{
const
auto
cvm_offset
=
use_cvm
?
0
:
2
;
std
::
memcpy
(
*
Y
,
*
X
+
cvm_offset
,
(
item_width
-
cvm_offset
)
*
sizeof
(
T
));
if
(
use_cvm
)
{
(
*
Y
)[
0
]
=
log
((
*
Y
)[
0
]
+
1
);
(
*
Y
)[
1
]
=
log
((
*
Y
)[
1
]
+
1
)
-
(
*
Y
)[
0
];
}
(
*
X
)
+=
item_width
;
(
*
Y
)
+=
item_width
-
cvm_offset
;
}
template
<
typename
T
>
void
CvmGradComputeKernel
(
const
bool
use_cvm
,
const
int64_t
item_width
,
const
T
&
CVM
,
const
T
**
DY
,
T
**
DX
)
{
const
auto
cvm_offset
=
use_cvm
?
0
:
2
;
std
::
memcpy
(
*
DX
+
cvm_offset
,
*
DY
,
(
item_width
-
cvm_offset
)
*
sizeof
(
T
));
(
*
DX
)[
0
]
=
(
&
CVM
)[
0
];
(
*
DX
)[
1
]
=
(
&
CVM
)[
1
];
(
*
DX
)
+=
item_width
;
(
*
DY
)
+=
item_width
-
cvm_offset
;
}
template
<
typename
T
>
class
CVMOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
const
LoDTensor
*
x
=
context
.
Input
<
LoDTensor
>
(
"X"
);
const
auto
*
x
=
context
.
Input
<
LoDTensor
>
(
"X"
);
const
T
*
x_data
=
x
->
data
<
T
>
();
auto
lod
=
x
->
lod
()[
0
];
int64_t
item_size
=
x
->
numel
()
/
x
->
dims
()[
0
];
int
offset
=
2
;
if
(
!
context
.
Attr
<
bool
>
(
"use_cvm"
))
{
item_size
-=
offset
;
}
LoDTensor
*
y
=
context
.
Output
<
LoDTensor
>
(
"Y"
);
auto
batch_size
=
x
->
dims
()[
0
];
auto
item_size
=
x
->
numel
()
/
batch_size
;
auto
use_cvm
=
context
.
Attr
<
bool
>
(
"use_cvm"
);
auto
*
y
=
context
.
Output
<
LoDTensor
>
(
"Y"
);
T
*
y_data
=
y
->
mutable_data
<
T
>
(
context
.
GetPlace
());
int
seq_num
=
static_cast
<
int
>
(
lod
.
size
())
-
1
;
for
(
int
i
=
0
;
i
<
seq_num
;
++
i
)
{
int64_t
seq_len
=
static_cast
<
int64_t
>
(
lod
[
i
+
1
]
-
lod
[
i
]);
for
(
int
j
=
0
;
j
<
seq_len
;
++
j
)
{
if
(
context
.
Attr
<
bool
>
(
"use_cvm"
))
{
std
::
memcpy
(
y_data
,
x_data
,
item_size
*
sizeof
(
T
));
y_data
[
0
]
=
log
(
y_data
[
0
]
+
1
);
y_data
[
1
]
=
log
(
y_data
[
1
]
+
1
)
-
y_data
[
0
];
x_data
+=
item_size
;
y_data
+=
item_size
;
}
else
{
std
::
memcpy
(
y_data
,
x_data
+
offset
,
item_size
*
sizeof
(
T
));
x_data
+=
item_size
+
offset
;
y_data
+=
item_size
;
// for Input X do not have Lod Information.
if
(
x
->
NumLevels
()
==
0
)
{
for
(
int
i
=
0
;
i
<
batch_size
;
i
++
)
{
CvmComputeKernel
(
use_cvm
,
item_size
,
&
x_data
,
&
y_data
);
}
}
else
{
auto
lod
=
x
->
lod
()[
0
];
for
(
int
i
=
0
;
i
<
lod
.
size
()
-
1
;
++
i
)
{
for
(
int
j
=
0
;
j
<
lod
[
i
+
1
]
-
lod
[
i
];
++
j
)
{
CvmComputeKernel
(
use_cvm
,
item_size
,
&
x_data
,
&
y_data
);
}
}
}
...
...
@@ -62,42 +86,39 @@ template <typename T>
class
CVMGradOpKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
LoDTensor
*
dx
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dx
=
context
.
Output
<
LoDTensor
>
(
framework
::
GradVarName
(
"X"
));
T
*
dx_data
=
dx
->
mutable_data
<
T
>
(
context
.
GetPlace
());
const
Tensor
*
cvm
=
context
.
Input
<
Tensor
>
(
"CVM"
);
const
T
*
cvm_data
=
cvm
->
data
<
T
>
();
int
offset
=
2
;
const
framework
::
LoDTensor
*
dOut
=
const
auto
*
dOut
=
context
.
Input
<
framework
::
LoDTensor
>
(
framework
::
GradVarName
(
"Y"
));
const
T
*
dout_data
=
dOut
->
data
<
T
>
();
auto
lod
=
dx
->
lod
()[
0
];
int64_t
item_size
=
dx
->
numel
()
/
dx
->
dims
()[
0
];
if
(
!
context
.
Attr
<
bool
>
(
"use_cvm"
))
{
item_size
-=
offset
;
}
auto
use_cvm
=
context
.
Attr
<
bool
>
(
"use_cvm"
);
int
seq_num
=
static_cast
<
int
>
(
lod
.
size
())
-
1
;
for
(
int
i
=
0
;
i
<
seq_num
;
++
i
)
{
int64_t
seq_len
=
static_cast
<
int64_t
>
(
lod
[
i
+
1
]
-
lod
[
i
]);
for
(
int
j
=
0
;
j
<
seq_len
;
++
j
)
{
if
(
context
.
Attr
<
bool
>
(
"use_cvm"
))
{
std
::
memcpy
(
dx_data
,
dout_data
,
item_size
*
sizeof
(
T
));
dx_data
[
0
]
=
cvm_data
[
0
];
dx_data
[
1
]
=
cvm_data
[
1
];
dx_data
+=
item_size
;
dout_data
+=
item_size
;
}
else
{
std
::
memcpy
(
dx_data
+
offset
,
dout_data
,
item_size
*
sizeof
(
T
));
dx_data
[
0
]
=
cvm_data
[
0
];
dx_data
[
1
]
=
cvm_data
[
1
];
dx_data
+=
item_size
+
offset
;
dout_data
+=
item_size
;
auto
offset
=
2
;
auto
batch_size
=
dx
->
dims
()[
0
];
auto
item_size
=
dx
->
numel
()
/
batch_size
;
// for Input X do not have Lod Information.
if
(
dx
->
NumLevels
()
==
0
)
{
for
(
int
x
=
0
;
x
<
batch_size
;
++
x
)
{
CvmGradComputeKernel
(
use_cvm
,
item_size
,
*
cvm_data
,
&
dout_data
,
&
dx_data
);
cvm_data
+=
offset
;
}
}
else
{
auto
lod
=
dx
->
lod
()[
0
];
int
seq_num
=
static_cast
<
int
>
(
lod
.
size
())
-
1
;
for
(
int
i
=
0
;
i
<
seq_num
;
++
i
)
{
for
(
int
j
=
0
;
j
<
lod
[
i
+
1
]
-
lod
[
i
];
++
j
)
{
CvmGradComputeKernel
(
use_cvm
,
item_size
,
*
cvm_data
,
&
dout_data
,
&
dx_data
);
}
cvm_data
+=
offset
;
}
cvm_data
+=
offset
;
}
}
};
...
...
python/paddle/fluid/tests/unittests/test_cvm_op.py
浏览文件 @
deb510d4
...
...
@@ -19,15 +19,50 @@ from op_test import OpTest
import
unittest
class
TestCVMOp
(
OpTest
):
def
cvm_compute
(
X
,
item_width
,
use_cvm
):
cvm_offset
=
0
if
use_cvm
else
2
batch_size
=
X
.
shape
[
0
]
Y
=
np
.
ones
([
batch_size
,
item_width
-
cvm_offset
],
np
.
float32
)
for
idx
in
range
(
batch_size
):
if
use_cvm
:
Y
[
idx
]
=
X
[
idx
]
Y
[
idx
][
0
]
=
log
(
Y
[
idx
][
0
]
+
1
)
Y
[
idx
][
1
]
=
log
(
Y
[
idx
][
1
]
+
1
)
-
Y
[
idx
][
0
]
else
:
Y
[
idx
]
=
X
[
idx
][
2
:]
return
Y
def
cvm_grad_compute
(
DY
,
CVM
,
item_width
,
use_cvm
):
batch_size
=
DY
.
shape
[
0
]
DX
=
np
.
ones
([
batch_size
,
item_width
],
np
.
float32
)
for
idx
in
range
(
batch_size
):
DX
[
idx
][
0
]
=
CVM
[
idx
][
0
]
DX
[
idx
][
1
]
=
CVM
[
idx
][
1
]
if
use_cvm
:
DX
[
idx
][
2
:]
=
DY
[
idx
][
2
:]
else
:
DX
[
idx
][
2
:]
=
DY
[
idx
]
return
DX
class
TestCVMOpWithLodTensor
(
OpTest
):
"""
Test cvm op with discrete one-hot labels.
"""
def
setUp
(
self
):
self
.
op_type
=
"cvm"
batch_size
=
4
self
.
use_cvm
=
True
batch_size
=
8
dims
=
11
lod
=
[[
1
]]
self
.
inputs
=
{
'X'
:
(
np
.
random
.
uniform
(
0
,
1
,
[
1
,
dims
]).
astype
(
"float32"
),
lod
),
...
...
@@ -43,5 +78,55 @@ class TestCVMOp(OpTest):
self
.
check_output
()
class
TestCVMOpWithOutLodTensor1
(
OpTest
):
"""
Test cvm op with discrete one-hot labels.
"""
def
setUp
(
self
):
self
.
op_type
=
"cvm"
self
.
use_cvm
=
True
batch_size
=
2
item_width
=
11
input
=
np
.
random
.
uniform
(
0
,
1
,
(
batch_size
,
item_width
)).
astype
(
'float32'
)
output
=
cvm_compute
(
input
,
item_width
,
self
.
use_cvm
)
cvm
=
np
.
array
([[
0.6
,
0.4
]]).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
input
,
'CVM'
:
cvm
}
self
.
attrs
=
{
'use_cvm'
:
self
.
use_cvm
}
self
.
outputs
=
{
'Y'
:
output
}
def
test_check_output
(
self
):
self
.
check_output
()
class
TestCVMOpWithOutLodTensor2
(
OpTest
):
"""
Test cvm op with discrete one-hot labels.
"""
def
setUp
(
self
):
self
.
op_type
=
"cvm"
self
.
use_cvm
=
False
batch_size
=
2
item_width
=
11
input
=
np
.
random
.
uniform
(
0
,
1
,
(
batch_size
,
item_width
)).
astype
(
'float32'
)
output
=
cvm_compute
(
input
,
item_width
,
self
.
use_cvm
)
cvm
=
np
.
array
([[
0.6
,
0.4
]]).
astype
(
"float32"
)
self
.
inputs
=
{
'X'
:
input
,
'CVM'
:
cvm
}
self
.
attrs
=
{
'use_cvm'
:
self
.
use_cvm
}
self
.
outputs
=
{
'Y'
:
output
}
def
test_check_output
(
self
):
self
.
check_output
()
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录