Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
2ddb1122
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
2ddb1122
编写于
8月 10, 2017
作者:
D
dongzhihong
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
"on hold"
上级
56faf513
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
55 addition
and
5 deletion
+55
-5
paddle/operators/mul_op.cc
paddle/operators/mul_op.cc
+23
-4
paddle/operators/mul_op.cu
paddle/operators/mul_op.cu
+2
-1
paddle/operators/mul_op.h
paddle/operators/mul_op.h
+28
-0
python/paddle/v2/framework/tests/test_mul_op.py
python/paddle/v2/framework/tests/test_mul_op.py
+2
-0
未找到文件。
paddle/operators/mul_op.cc
浏览文件 @
2ddb1122
...
...
@@ -54,10 +54,27 @@ The equation is: Out = X * Y
class
MulOpGrad
:
public
framework
::
OperatorWithKernel
{
protected:
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{}
std
::
string
DebugString
()
const
override
{
LOG
(
INFO
)
<<
"MulGrad"
;
return
""
;
void
InferShape
(
const
framework
::
InferShapeContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
ctx
.
InputSize
(),
3UL
,
"Input of MulOpGrad should be 3, X, Y, Out@GRAD"
);
PADDLE_ENFORCE_EQ
(
ctx
.
OutputSize
(),
2UL
,
"Output of MulOpGrad should be 2, X@GRAD, Y@GRAD"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"X"
),
"Input(X) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
"Y"
),
"Input(Y) should not be null"
);
PADDLE_ENFORCE_NOT_NULL
(
ctx
.
InputVar
(
framework
::
GradVarName
(
"Out"
)),
"Input(Out@GRAD) should not be null"
);
auto
*
x_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
y_grad
=
ctx
.
Output
<
framework
::
Tensor
>
(
framework
::
GradVarName
(
"Y"
));
auto
dim0
=
ctx
.
Input
<
Tensor
>
(
0
)
->
dims
();
auto
dim1
=
ctx
.
Input
<
Tensor
>
(
1
)
->
dims
();
auto
out_dims
=
ctx
.
Input
<
Tensor
>
(
2
)
->
dims
();
PADDLE_ENFORCE
(
dim0
[
0
]
*
dim1
[
0
]
==
out_dims
[
0
],
"Out@GRAD[0] must equal to X[0] * Y[0]"
);
PADDLE_ENFORCE
(
dim0
[
1
]
*
dim1
[
1
]
==
out_dims
[
1
],
"Out@GRAD shape must equal to X[1] * Y[1]"
);
x_grad
->
Resize
(
dim1
);
y_grad
->
Resize
(
dim0
);
}
};
...
...
@@ -69,3 +86,5 @@ REGISTER_OP(mul, ops::MulOp, ops::MulOpMaker);
REGISTER_GRADIENT_OP
(
mul
,
mul_grad
,
ops
::
MulOpGrad
);
REGISTER_OP_CPU_KERNEL
(
mul
,
ops
::
MulKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
REGISTER_OP_CPU_KERNEL
(
mul_grad
,
ops
::
MulGradKernel
<
paddle
::
platform
::
CPUPlace
,
float
>
);
paddle/operators/mul_op.cu
浏览文件 @
2ddb1122
...
...
@@ -16,5 +16,6 @@
#include "paddle/operators/mul_op.h"
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_GPU_KERNEL
(
mul
,
ops
::
MulKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
REGISTER_OP_GPU_KERNEL
(
mul_grad
,
ops
::
MulGradKernel
<
paddle
::
platform
::
GPUPlace
,
float
>
);
paddle/operators/mul_op.h
浏览文件 @
2ddb1122
...
...
@@ -46,5 +46,33 @@ class MulKernel : public framework::OpKernel {
}
};
template
<
typename
Place
,
typename
T
>
class
MulGradKernel
:
public
framework
::
OpKernel
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input0
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
input1
=
ctx
.
Input
<
Tensor
>
(
"Y"
);
auto
*
input2
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
*
output0
=
ctx
.
Output
<
Tensor
>
(
0
);
auto
*
output1
=
ctx
.
Output
<
Tensor
>
(
1
);
output0
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
output1
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
auto
X
=
EigenMatrix
<
T
>::
From
(
*
input0
);
auto
Y
=
EigenMatrix
<
T
>::
From
(
*
input1
);
auto
dOut
=
EigenMatrix
<
T
>::
From
(
*
input2
);
auto
dX
=
EigenMatrix
<
T
>::
From
(
*
output0
);
auto
dY
=
EigenMatrix
<
T
>::
From
(
*
output1
);
// dX = Out@G * Y'
// dY = X' * Out@G
auto
place
=
ctx
.
GetEigenDevice
<
Place
>
();
// TODO(dzh,qijun) : need transpose feature of blas library
// Eigen Tensor does not support it very well
// dX.device(place) = dOut.contract(dOut, transpose)
}
};
}
// namespace operators
}
// namespace paddle
python/paddle/v2/framework/tests/test_mul_op.py
浏览文件 @
2ddb1122
...
...
@@ -15,5 +15,7 @@ class TestMulOp(unittest.TestCase):
self
.
outputs
=
{
'Out'
:
np
.
dot
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
# TODO(dzh,qijun) : mulgrad test case need transpose feature of blas library
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录