Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3d8571e8
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3d8571e8
编写于
3月 10, 2020
作者:
G
guofei
提交者:
GitHub
3月 10, 2020
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
modify assign op and add unittest of assign op (#22769)
As the title.
上级
e081c7a0
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
44 addition
and
1 deletion
+44
-1
paddle/fluid/operators/assign_op.cc
paddle/fluid/operators/assign_op.cc
+13
-1
python/paddle/fluid/tests/unittests/test_assign_op.py
python/paddle/fluid/tests/unittests/test_assign_op.py
+31
-0
未找到文件。
paddle/fluid/operators/assign_op.cc
浏览文件 @
3d8571e8
...
...
@@ -57,6 +57,17 @@ class AssignOp : public framework::OperatorWithKernel {
}
};
class
AssignInferVarType
:
public
framework
::
VarTypeInference
{
public:
void
operator
()(
framework
::
InferVarTypeContext
*
ctx
)
const
override
{
auto
out_var_name
=
ctx
->
Output
(
"Out"
)[
0
];
auto
input_type
=
ctx
->
GetType
(
ctx
->
Input
(
"X"
)[
0
]);
auto
input_data_type
=
ctx
->
GetDataType
(
ctx
->
Input
(
"X"
)[
0
]);
ctx
->
SetType
(
out_var_name
,
input_type
);
ctx
->
SetDataType
(
out_var_name
,
input_data_type
);
}
};
class
AssignKernel
{
public:
void
operator
()(
const
framework
::
ExecutionContext
&
ctx
)
const
{
...
...
@@ -116,7 +127,8 @@ namespace plat = paddle::platform;
REGISTER_OPERATOR
(
assign
,
ops
::
AssignOp
,
ops
::
AssignGradMaker
<
paddle
::
framework
::
OpDesc
>
,
ops
::
AssignGradMaker
<
paddle
::
imperative
::
OpBase
>
,
ops
::
AssignOpProtoMaker
,
ops
::
AssignOpInplaceInferer
);
ops
::
AssignOpProtoMaker
,
ops
::
AssignOpInplaceInferer
,
ops
::
AssignInferVarType
);
REGISTER_OP_CPU_KERNEL_FUNCTOR
(
assign
,
float
,
ops
::
AssignKernel
,
double
,
ops
::
AssignKernel
,
int
,
ops
::
AssignKernel
,
...
...
python/paddle/fluid/tests/unittests/test_assign_op.py
浏览文件 @
3d8571e8
...
...
@@ -21,6 +21,7 @@ import paddle.fluid.core as core
from
paddle.fluid.op
import
Operator
import
paddle.fluid
as
fluid
from
paddle.fluid
import
compiler
,
Program
,
program_guard
from
paddle.fluid.backward
import
append_backward
class
TestAssignOp
(
op_test
.
OpTest
):
...
...
@@ -51,6 +52,36 @@ class TestAssignFP16Op(op_test.OpTest):
self
.
check_grad
([
'X'
],
'Out'
)
class
TestAssignOpWithLoDTensorArray
(
unittest
.
TestCase
):
def
test_assign_LoDTensorArray
(
self
):
main_program
=
Program
()
startup_program
=
Program
()
with
program_guard
(
main_program
):
x
=
fluid
.
data
(
name
=
'x'
,
shape
=
[
100
,
10
],
dtype
=
'float32'
)
x
.
stop_gradient
=
False
y
=
fluid
.
layers
.
fill_constant
(
shape
=
[
100
,
10
],
dtype
=
'float32'
,
value
=
1
)
z
=
fluid
.
layers
.
elementwise_add
(
x
=
x
,
y
=
y
)
i
=
fluid
.
layers
.
fill_constant
(
shape
=
[
1
],
dtype
=
'int64'
,
value
=
0
)
init_array
=
fluid
.
layers
.
array_write
(
x
=
z
,
i
=
i
)
array
=
fluid
.
layers
.
assign
(
init_array
)
sums
=
fluid
.
layers
.
array_read
(
array
=
init_array
,
i
=
i
)
mean
=
fluid
.
layers
.
mean
(
sums
)
append_backward
(
mean
)
place
=
fluid
.
CUDAPlace
(
0
)
if
core
.
is_compiled_with_cuda
(
)
else
fluid
.
CPUPlace
()
exe
=
fluid
.
Executor
(
place
)
feed_x
=
np
.
random
.
random
(
size
=
(
100
,
10
)).
astype
(
'float32'
)
ones
=
np
.
ones
((
100
,
10
)).
astype
(
'float32'
)
feed_add
=
feed_x
+
ones
res
=
exe
.
run
(
main_program
,
feed
=
{
'x'
:
feed_x
},
fetch_list
=
[
sums
.
name
,
x
.
grad_name
])
self
.
assertTrue
(
np
.
allclose
(
res
[
0
],
feed_add
))
self
.
assertTrue
(
np
.
allclose
(
res
[
1
],
ones
/
1000.0
))
class
TestAssignOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录