Commit 0e7baabe
Authored Nov 19, 2019 by danleifeng
Committed by gongweibao, Nov 19, 2019
extend elementwise broadcast function (#20957)
Parent: d623e863

Showing 21 changed files with 1,074 additions and 503 deletions (+1074 / -503).
paddle/fluid/operators/elementwise/elementwise_add_op.cc                               +2    -2
paddle/fluid/operators/elementwise/elementwise_add_op.h                                +10   -4
paddle/fluid/operators/elementwise/elementwise_div_op.cc                               +1    -0
paddle/fluid/operators/elementwise/elementwise_div_op.h                                +12   -10
paddle/fluid/operators/elementwise/elementwise_mul_op.h                                +8    -2
paddle/fluid/operators/elementwise/elementwise_op.h                                    +39   -101
paddle/fluid/operators/elementwise/elementwise_op_function.cu.h                        +6    -0
paddle/fluid/operators/elementwise/elementwise_op_function.h                           +828  -352
paddle/fluid/operators/elementwise/elementwise_sub_op.cc                               +4    -3
paddle/fluid/operators/elementwise/elementwise_sub_op.h                                +10   -4
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc                 +5    -2
paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc                 +6    -4
paddle/fluid/operators/layer_norm_op.h                                                 +1    -15
paddle/fluid/operators/op_debug_string_test.cc                                         +5    -0
paddle/fluid/train/demo/run.sh                                                         +2    -2
python/paddle/fluid/tests/unittests/test_elementwise_add_op.py                         +30   -0
python/paddle/fluid/tests/unittests/test_elementwise_div_op.py                         +33   -0
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py                         +35   -0
python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py                         +35   -0
python/paddle/fluid/tests/unittests/test_executor_return_tensor_not_overwriting.py     +1    -1
python/paddle/fluid/tests/unittests/test_fl_listen_and_serv_op.py                      +1    -1
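Most of the change lives in elementwise_op_function.h: the old (pre, n, post) decomposition only covers cases where Y lines up with a contiguous block of X's dimensions, while the new "common broadcast" path handles general NumPy-style shape pairs (several broadcast axes on either operand) and also lets Y have more elements than X. The shape rule itself is the usual one; the following is a standalone illustrative sketch written for this summary, not code taken from the commit:

#include <algorithm>
#include <cstdio>
#include <vector>

// NumPy-style rule on two already-aligned dim arrays: sizes must match or one
// of them must be 1, and the output takes the larger size (checks omitted).
std::vector<int> BroadcastOutDims(const std::vector<int>& x, const std::vector<int>& y) {
  std::vector<int> out(x.size());
  for (size_t i = 0; i < x.size(); ++i) out[i] = std::max(x[i], y[i]);
  return out;
}

int main() {
  // e.g. X = (2, 3, 4, 5) with Y = (2, 1, 4, 1); Y larger than X also works.
  for (int d : BroadcastOutDims({2, 3, 4, 5}, {2, 1, 4, 1})) std::printf("%d ", d);
  std::printf("\n");  // prints: 2 3 4 5
  return 0;
}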
paddle/fluid/operators/elementwise/elementwise_add_op.cc

@@ -99,8 +99,8 @@ REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_add, Add);
 namespace ops = paddle::operators;

 REGISTER_OPERATOR(
-    elementwise_add_grad, ops::ElementwiseOpExplicitGrad, ops::ElementwiseGradOpInplace,
-    ops::ElementwiseGradNoBufVarsInference,
+    elementwise_add_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace,
+    ops::ElementwiseGradNoBufVarsInference,
     ops::ElementwiseAddDoubleGradMaker<paddle::framework::OpDesc>,
     ops::ElementwiseAddDoubleGradMaker<paddle::imperative::OpBase>);
paddle/fluid/operators/elementwise/elementwise_add_op.h

@@ -25,8 +25,13 @@ void default_elementwise_add(const framework::ExecutionContext &ctx,
                              const framework::Tensor *x,
                              const framework::Tensor *y, framework::Tensor *z) {
   int axis = ctx.Attr<int>("axis");
-  ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                        AddFunctor<T>(), z);
+  if (x->numel() >= y->numel()) {
+    ElementwiseComputeEx<AddFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          AddFunctor<T>(), z);
+  } else {
+    ElementwiseComputeEx<InverseAddFunctor<T>, DeviceContext, T>(
+        ctx, x, y, axis, InverseAddFunctor<T>(), z);
+  }
 }

 template <typename DeviceContext, typename T, class Enable = void>

@@ -128,12 +133,13 @@ class ElementwiseAddGradKernel : public ElemwiseGradKernel<T> {
     using Tensor = framework::Tensor;

-    auto *x = ctx.Input<Tensor>("X");
-    auto *y = ctx.Input<Tensor>("Y");
     auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    // skip out
+    // skip out, x, y
     auto *out = dout;
+    auto *x = dout, *y = dout;

     if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
       elementwise_add_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
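The second hunk works because the add gradient needs only dOut: for Out = X + Y (with broadcasting), each input's gradient is dOut reduced over the axes on which that input was broadcast, so x, y and out can all safely alias dout. As a reminder of the underlying formulas (not part of the diff):

\mathrm{Out} = X + Y \;\Longrightarrow\;
\frac{\partial L}{\partial X} = \operatorname{reduce\_sum}_{\text{broadcast axes of } X}\!\left(\frac{\partial L}{\partial \mathrm{Out}}\right),
\qquad
\frac{\partial L}{\partial Y} = \operatorname{reduce\_sum}_{\text{broadcast axes of } Y}\!\left(\frac{\partial L}{\partial \mathrm{Out}}\right)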
paddle/fluid/operators/elementwise/elementwise_div_op.cc

@@ -76,6 +76,7 @@ class ElementwiseDivGradOpMaker : public framework::SingleGradOpMaker<T> {
   std::unique_ptr<T> Apply() const override {
     std::unique_ptr<T> op(new T());
     op->SetType("elementwise_div_grad");
     op->SetInput("X", this->Input("X"));
     op->SetInput("Y", this->Input("Y"));
     op->SetInput("Out", this->Output("Out"));
     op->SetInput(framework::GradVarName("Out"), this->OutputGrad("Out"));
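Apply() builds the description of the backward op from the forward op: it sets the grad op's type and wires forward inputs/outputs plus Out@GRAD into its inputs. The following is a rough standalone sketch of what such a maker conceptually produces (a simplified illustration with made-up types, not the actual SingleGradOpMaker API):

#include <map>
#include <string>
#include <vector>

// Minimal stand-in for an op description: a type plus named variable lists.
struct OpDesc {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs, outputs;
};

// Conceptually what the elementwise_div grad maker assembles.
OpDesc MakeDivGradDesc(const OpDesc& fwd) {
  OpDesc grad;
  grad.type = "elementwise_div_grad";
  grad.inputs["X"] = fwd.inputs.at("X");
  grad.inputs["Y"] = fwd.inputs.at("Y");
  grad.inputs["Out"] = fwd.outputs.at("Out");
  grad.inputs["Out@GRAD"] = {fwd.outputs.at("Out")[0] + "@GRAD"};
  grad.outputs["X@GRAD"] = {fwd.inputs.at("X")[0] + "@GRAD"};
  grad.outputs["Y@GRAD"] = {fwd.inputs.at("Y")[0] + "@GRAD"};
  return grad;
}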
paddle/fluid/operators/elementwise/elementwise_div_op.h

@@ -31,8 +31,13 @@ void default_elementwise_div(const framework::ExecutionContext& ctx,
                              const framework::Tensor* x,
                              const framework::Tensor* y, framework::Tensor* z) {
   int axis = ctx.Attr<int>("axis");
-  ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                        DivFunctor<T>(), z);
+  if (x->numel() >= y->numel()) {
+    ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          DivFunctor<T>(), z);
+  } else {
+    ElementwiseComputeEx<InverseDivFunctor<T>, DeviceContext, T>(
+        ctx, x, y, axis, InverseDivFunctor<T>(), z);
+  }
 }

 template <typename DeviceContext, typename T, class Enable = void>

@@ -112,13 +117,13 @@ class ElementwiseDivGradKernel : public ElemwiseGradKernel<T> {
     ElemwiseGradKernel<T>::Compute(ctx);
     using Tensor = framework::Tensor;

-    auto* x = ctx.Input<Tensor>("X");
     auto* y = ctx.Input<Tensor>("Y");
     auto* out = ctx.Input<Tensor>("Out");
     auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
     auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
     auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
     int axis = ctx.Attr<int>("axis");
+    auto* x = dout;  // Fake x, not used

     if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
       elementwise_div_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);

@@ -191,7 +196,7 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
     // ddX_safe == null ? 0 : ddX
     // ddY_safe == null ? 0 : ddY
     Tensor ddX_safe, ddY_safe;
-    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Out, ddX, &ddX_safe);
+    GetDoubleGradSafeTensor<DeviceContext, T>(ctx, dX, ddX, &ddX_safe);
     GetDoubleGradSafeTensor<DeviceContext, T>(ctx, Y, ddY, &ddY_safe);

     // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y

@@ -209,8 +214,7 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
     if (dY) {
       // dX_div_Y = dX / Y;
       Tensor dX_div_Y = tmp;
-      ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
-          ctx, dX, Y, axis, DivFunctor<T>(), &dX_div_Y);
+      default_elementwise_div<DeviceContext, T>(ctx, dX, Y, &dX_div_Y);

       // NOTE(dengkaipeng): in the following ElemwiseGradCompute, for the
       // first output tensor is nullptr, the branch to calculate first

@@ -227,10 +231,8 @@ class ElementwiseDivDoubleGradKernel : public framework::OpKernel<T> {
     if (ddOut) {
       // ddOut = ddX / Y - Out * ddY / Y = (ddX - Out * ddY) / Y
       default_elementwise_mul<DeviceContext, T>(ctx, Out, &ddY_safe, &tmp);
-      ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(
-          ctx, &ddX_safe, &tmp, 0, SubFunctor<T>(), &tmp);
-      ElementwiseComputeEx<DivFunctor<T>, DeviceContext, T>(
-          ctx, &tmp, Y, axis, DivFunctor<T>(), ddOut);
+      default_elementwise_sub<DeviceContext, T>(ctx, &ddX_safe, &tmp, &tmp);
+      default_elementwise_div<DeviceContext, T>(ctx, &tmp, Y, ddOut);
     }

     if (dOut) {
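The ddOut branch relies on the algebraic identity quoted in the comment; written out once (this is just the algebra, not new material from the diff):

\mathrm{Out} = \frac{X}{Y} \;\Longrightarrow\;
\mathrm{ddOut} \;=\; \frac{\mathrm{ddX}}{Y} - \frac{X}{Y^{2}}\,\mathrm{ddY}
            \;=\; \frac{\mathrm{ddX}}{Y} - \frac{\mathrm{Out}\cdot \mathrm{ddY}}{Y}
            \;=\; \frac{\mathrm{ddX} - \mathrm{Out}\cdot \mathrm{ddY}}{Y}

which is exactly one multiply, one subtract and one divide, matching the three default_elementwise_* calls above.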
paddle/fluid/operators/elementwise/elementwise_mul_op.h

@@ -26,9 +26,15 @@ void default_elementwise_mul(const framework::ExecutionContext& ctx,
                              const framework::Tensor* x,
                              const framework::Tensor* y, framework::Tensor* z) {
   int axis = ctx.Attr<int>("axis");
-  ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
-                                                        MulFunctor<T>(), z);
+  if (x->numel() >= y->numel()) {
+    ElementwiseComputeEx<MulFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
+                                                          MulFunctor<T>(), z);
+  } else {
+    ElementwiseComputeEx<InverseMulFunctor<T>, DeviceContext, T>(
+        ctx, x, y, axis, InverseMulFunctor<T>(), z);
+  }
 }

 template <typename DeviceContext, typename T, class Enable = void>
 struct SameDimsElemwiseMul {
   void operator()(const framework::ExecutionContext& ctx,
paddle/fluid/operators/elementwise/elementwise_op.h

@@ -14,12 +14,15 @@ limitations under the License. */
 #pragma once

 #include <algorithm>  // for max
 #include <memory>
 #include <string>
 #include <unordered_map>
 #include <vector>

 #include "paddle/fluid/framework/data_layout.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.h"

 #ifdef PADDLE_WITH_MKLDNN
 #include "paddle/fluid/platform/mkldnn_helper.h"

@@ -35,12 +38,12 @@ class ElementwiseOp : public framework::OperatorWithKernel {
   using Tensor = framework::Tensor;

   void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"),
-                   "Input(X) of elementwise op should not be null.");
-    PADDLE_ENFORCE(ctx->HasInput("Y"),
-                   "Input(Y) of elementwise op should not be null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Out"),
-                   "Output(Out) of elementwise op should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true,
+                      "Input(X) of elementwise op should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true,
+                      "Input(Y) of elementwise op should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Out"), true,
+                      "Output(Out) of elementwise op should not be null.");

     PADDLE_ENFORCE(ctx->GetInputsVarType("Y").front() ==

@@ -49,18 +52,7 @@ class ElementwiseOp : public framework::OperatorWithKernel {
                    ctx->GetInputsVarType("Y").front(), ctx->Inputs("Y").front());

-    if (ctx->GetInputsVarType("X").front() ==
-        framework::proto::VarType::LOD_TENSOR) {
-      auto x_dim = ctx->GetInputDim("X");
-      auto y_dim = ctx->GetInputDim("Y");
-      PADDLE_ENFORCE_GE(
-          x_dim.size(), y_dim.size(),
-          "ShapeError: the dimension of input X must greater than or equal to "
-          "the one of input Y. But received: the shape of input X = [%s], the "
-          "dimension of input X = %d, the shape of input Y = [%s], the "
-          "dimension of input Y = %d",
-          x_dim, x_dim.size(), y_dim, y_dim.size());
-    } else if (ctx->GetInputsVarType("X").front() ==
-               framework::proto::VarType::SELECTED_ROWS) {
+    if (ctx->GetInputsVarType("X").front() ==
+        framework::proto::VarType::SELECTED_ROWS) {
       PADDLE_ENFORCE_EQ(
           ctx->GetInputDim("Y").size(), 1u,
           "ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"

@@ -71,13 +63,31 @@ class ElementwiseOp : public framework::OperatorWithKernel {
           "ShapeError: For elementwise_op, if X is Sparse(VarType.SELECTED_ROWS"
           "), Y must be scalar. But reveived the first dimension of Y = %s",
           ctx->GetInputDim("Y")[0]);
-    } else {
+    } else if (ctx->GetInputsVarType("X").front() !=
+               framework::proto::VarType::LOD_TENSOR) {
       PADDLE_THROW("X's type[%s] is not supported by elementwise_op.",
                    ctx->GetInputsVarType("X").front());
     }

-    ctx->ShareDim("X", /*->*/ "Out");
-    ctx->ShareLoD("X", /*->*/ "Out");
+    if (ctx->GetInputDim("X") == ctx->GetInputDim("Y")) {
+      ctx->ShareDim("X", /*->*/ "Out");
+      ctx->ShareLoD("X", /*->*/ "Out");
+    } else {
+      auto x_dims = ctx->GetInputDim("X");
+      auto y_dims = ctx->GetInputDim("Y");
+      int max_dim = std::max(x_dims.size(), y_dims.size());
+      int axis = ctx->Attrs().Get<int>("axis");
+      axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
+      std::vector<int> x_dims_array(max_dim);
+      std::vector<int> y_dims_array(max_dim);
+      std::vector<int> out_dims_array(max_dim);
+      GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
+                             y_dims_array.data(), out_dims_array.data(),
+                             max_dim, axis);
+      ctx->SetOutputDim("Out", framework::make_ddim(out_dims_array));
+      // to do
+      ctx->ShareLoD("X", /*->*/ "Out");
+    }
   }

   framework::OpKernelType GetExpectedKernelType(

@@ -207,26 +217,14 @@ class ElementwiseOpGrad : public framework::OperatorWithKernel {
   void InferShape(framework::InferShapeContext *ctx) const override {
     auto out_grad_name = framework::GradVarName("Out");
-    PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
-    PADDLE_ENFORCE(ctx->HasInput(out_grad_name),
-                   "Input(Out@GRAD) should not be null");
-    auto x_dims = ctx->GetInputDim(out_grad_name);
-    auto y_dims = ctx->GetInputDim("Y");
-    PADDLE_ENFORCE_GE(
-        x_dims.size(), y_dims.size(),
-        "ShapeError: the dimension of Out@GRAD must greater than or equal to "
-        "the one of input Y. But received: the shape of Out@GRAD = [%s], the "
-        "dimension of Out@GRAD = %d, the shape of input Y = [%s], the "
-        "dimension of of input Y = %d",
-        x_dims, x_dims.size(), y_dims, y_dims.size());
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Y"), true, "Input(Y) should not be null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(out_grad_name), true,
+                      "Input(Out@GRAD) should not be null.");

     auto x_grad_name = framework::GradVarName("X");
     auto y_grad_name = framework::GradVarName("Y");
     if (ctx->HasOutput(x_grad_name)) {
-      ctx->ShareDim(out_grad_name, /*->*/ x_grad_name);
-      ctx->ShareLoD(out_grad_name, /*->*/ x_grad_name);
+      ctx->ShareDim("X", /*->*/ x_grad_name);
+      ctx->ShareLoD("X", /*->*/ x_grad_name);
     }
     if (ctx->HasOutput(y_grad_name)) {
       ctx->ShareDim("Y", /*->*/ y_grad_name);

@@ -326,32 +324,6 @@ class ElementwiseOpDoubleGradWithoutDXDY
   }
 };

-// For Add, Sub op, the X, Out is not needed.
-class ElementwiseOpExplicitGrad : public ElementwiseOpGrad {
- public:
-  using operators::ElementwiseOpGrad::ElementwiseOpGrad;
-  using operators::ElementwiseOpGrad::GetExpectedKernelType;
-  using Tensor = framework::Tensor;
-
-  void InferShape(framework::InferShapeContext *ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Out")),
-                   "Input(Out@GRAD) should not be null");
-
-    auto x_grad_name = framework::GradVarName("X");
-    if (ctx->HasOutput(x_grad_name)) {
-      ctx->ShareDim(framework::GradVarName("Out"), /*->*/ x_grad_name);
-      ctx->ShareLoD(framework::GradVarName("Out"), /*->*/ x_grad_name);
-    }
-    auto y_grad_name = framework::GradVarName("Y");
-    if (ctx->HasOutput(y_grad_name)) {
-      PADDLE_ENFORCE(ctx->HasInput("Y"), "Input(Y) should not be null");
-      ctx->ShareDim("Y", /*->*/ y_grad_name);
-      ctx->ShareLoD("Y", /*->*/ y_grad_name);
-    }
-  }
-};
-
 template <typename T>
 class ElemwiseGradKernel : public framework::OpKernel<T> {
  public:

@@ -372,13 +344,13 @@ DECLARE_INPLACE_OP_INFERER(ElementwiseGradOpInplace,
                            framework::GradVarName("X")});
 DECLARE_INPLACE_OP_INFERER(ElementwiseDoubleGradOpInplace, {"DDX", "DDOut"});

-DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "Y");
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseGradNoBufVarsInference, "X", "Y");
 DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
                                       "Y", "DOut");

 }  // namespace operators
 }  // namespace paddle

 #define REGISTER_ELEMWISE_GRAD_MAKER(kernel_type, op_name)                  \
   template <typename T>                                                     \
   class kernel_type##GradMaker                                              \

@@ -390,6 +362,7 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
     std::unique_ptr<T> Apply() const override {                             \
       auto *op = new T();                                                   \
       op->SetType(#kernel_type "_grad");                                    \
       op->SetInput("X", this->Input("X"));                                  \
       op->SetInput("Y", this->Input("Y"));                                  \
       op->SetInput(::paddle::framework::GradVarName("Out"),                 \
                    this->OutputGrad("Out"));                                \

@@ -402,41 +375,6 @@ DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ElementwiseDoubleGradNoBufVarsInference,
     }                                                                       \
   }

-#define REGISTER_ELEMWISE_OP(op_type, op_name, equation)                    \
-  class __ElemwiseOp##op_type##Maker__                                      \
-      : public ::paddle::operators::ElementwiseOpMaker {                    \
-   protected:                                                               \
-    virtual std::string GetName() const { return op_name; }                 \
-    virtual std::string GetEquation() const { return equation; }            \
-  };                                                                        \
-  REGISTER_OPERATOR(                                                        \
-      op_type, ::paddle::operators::ElementwiseOp,                          \
-      __ElemwiseOp##op_type##Maker__,                                       \
-      ::paddle::operators::ElementwiseOpInferVarType,                       \
-      ::paddle::framework::DefaultGradOpMaker<::paddle::framework::OpDesc,  \
-                                              true>,                        \
-      ::paddle::framework::DefaultGradOpMaker<::paddle::imperative::OpBase, \
-                                              true>);                       \
-  REGISTER_OPERATOR(op_type##_grad, ::paddle::operators::ElementwiseOpGrad)
-
-#define REGISTER_ELEMWISE_EXPLICIT_OP(op_type, op_name, equation)      \
-  class __ElemwiseOp##op_type##Maker__                                 \
-      : public ::paddle::operators::ElementwiseOpMaker {               \
-   protected:                                                          \
-    virtual std::string GetName() const { return op_name; }            \
-    virtual std::string GetEquation() const { return equation; }       \
-  };                                                                   \
-  REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp,       \
-                    __ElemwiseOp##op_type##Maker__,                    \
-                    ::paddle::operators::ElementwiseOpInferVarType,    \
-                    op_type##GradMaker<::paddle::framework::OpDesc>,   \
-                    op_type##GradMaker<::paddle::imperative::OpBase>,  \
-                    ::paddle::operators::ElementwiseOpInplace);        \
-  REGISTER_OPERATOR(op_type##_grad,                                    \
-                    ::paddle::operators::ElementwiseOpExplicitGrad,    \
-                    ::paddle::operators::ElementwiseGradOpInplace,     \
-                    ::paddle::operators::ElementwiseGradNoBufVarsInference)
-
 #define REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(op_type, op_name)    \
   REGISTER_OPERATOR(op_type, ::paddle::operators::ElementwiseOp,        \
                     ::paddle::operators::Elementwise##op_name##OpMaker, \
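When the ranks differ, GetBroadcastDimsArrays pads the shorter shape with 1s around the slot chosen by `axis`, so an explicit `axis` selects where Y's dimensions line up inside X's. A small illustration of that alignment step, assumed to be equivalent to the diff's std::fill/std::copy logic (simplified, no Paddle types):

#include <cstdio>
#include <vector>

// Place the shorter shape into a max_dim-long array starting at `axis`,
// padding every other position with 1.
std::vector<int> AlignDims(const std::vector<int>& dims, int max_dim, int axis) {
  std::vector<int> aligned(max_dim, 1);
  for (size_t i = 0; i < dims.size(); ++i) aligned[axis + i] = dims[i];
  return aligned;
}

int main() {
  // X = [2, 3, 4, 5], Y = [4, 5]; axis == -1 resolves to |4 - 2| = 2, so Y
  // lines up with X's trailing dimensions: Y -> [1, 1, 4, 5], Out = [2, 3, 4, 5].
  std::vector<int> x = {2, 3, 4, 5}, y = {4, 5};
  int axis = 2;  // the resolved value of axis == -1 for these ranks
  for (int d : AlignDims(y, x.size(), axis)) std::printf("%d ", d);  // prints: 1 1 4 5
  std::printf("\n");
  return 0;
}

Passing axis = 1 instead would align a Y of shape [3, 4] against X's middle dimensions, which is the other case InferShape now has to size correctly.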
paddle/fluid/operators/elementwise/elementwise_op_function.cu.h

@@ -44,6 +44,12 @@ namespace operators {
     inline HOSTDEVICE T operator()(const T& a, const T& b) const {  \
       return a expr b;                                              \
     }                                                               \
   };                                                                \
+  template <typename T, class Enable = void>                        \
+  struct Inverse##Func##Functor {                                   \
+    inline HOSTDEVICE T operator()(const T& a, const T& b) const {  \
+      return b expr a;                                              \
+    }                                                               \
+  };

 DEFINE_SIMPLE_BINARY_FUNCTOR(Add, +)
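For orientation, DEFINE_SIMPLE_BINARY_FUNCTOR(Sub, -) would now roughly expand to the pair below; the inverse functor applies the operands in swapped order, which is what the numel-based dispatch in the *_op.h files relies on when Y is the larger tensor. This is an approximate sketch of the expansion (HOSTDEVICE spelled as a plain inline function, and the forward functor's exact template header is assumed, since the hunk only shows its body):

// Approximate expansion of DEFINE_SIMPLE_BINARY_FUNCTOR(Sub, -) after this change.
template <typename T, class Enable = void>
struct SubFunctor {
  inline T operator()(const T& a, const T& b) const { return a - b; }
};
template <typename T, class Enable = void>
struct InverseSubFunctor {
  // Called as func(y_elem, x_elem) when iteration is driven by the larger Y,
  // so swapping the arguments restores the meaning x - y.
  inline T operator()(const T& a, const T& b) const { return b - a; }
};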
paddle/fluid/operators/elementwise/elementwise_op_function.h

@@ -16,11 +16,15 @@ limitations under the License. */
 #include <glog/logging.h>

 #include <algorithm>
 #include <functional>  // for multiplies
 #include <iterator>
 #include <vector>

 #include "paddle/fluid/framework/eigen.h"
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/memory/malloc.h"
 #include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
 #include "paddle/fluid/platform/gpu_info.h"
 #include "paddle/fluid/platform/transform.h"

 #ifdef __NVCC__

@@ -33,6 +37,12 @@ constexpr int ELEMWISE_MAX_BLOCK_DIM = 1024;
 #include "paddle/fluid/operators/math/math_function.h"
 #include "paddle/fluid/platform/for_range.h"

+#define GetDivMod(dividend, divisor, div, mod) \
+  do {                                         \
+    const auto dividend_copy = dividend;       \
+    *div = dividend_copy / divisor;            \
+    *mod = dividend_copy % divisor;            \
+  } while (0)

 namespace paddle {
 namespace operators {
@@ -48,72 +58,453 @@ namespace operators {
  * pre=2*3, n=4*5, post=1
  * x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
  *
- * New parameter: *mid_flag* is added to solve m*n*k & m*1*k
- * broadcast cases.
- * 3. shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5)
- *    mid_flag should not be NULL.
- *    x.shape(2, 3, 20) * y.shape(2, 1, 20).broadcast(2, 3, 20)
+ * New parameter: *is_run_common_broadcast* is a flag to record whether to run
+ * common broadcast code.
  */
 inline void get_mid_dims(const framework::DDim &x_dims,
                          const framework::DDim &y_dims, const int axis,
-                         int *pre, int *n, int *post, int *mid_flag = NULL) {
+                         int *pre, int *n, int *post,
+                         int *is_run_common_broadcast) {
   *pre = 1;
   *n = 1;
   *post = 1;
-  if (mid_flag != NULL) {
-    *mid_flag = 0;
-    int mid = 0;
-    for (int i = 0; i < axis; ++i) {
-      (*pre) *= x_dims[i];
-    }
-    for (int i = 0; i < y_dims.size(); ++i) {
-      if (x_dims[i + axis] != y_dims[i]) {
-        // only support single y_dims[i] = 1 now.
-        PADDLE_ENFORCE_EQ(*mid_flag, 0,
-                          "Broadcast support y_dims with single 1.");
-        PADDLE_ENFORCE_EQ(y_dims[i], 1,
-                          "ShapeError: broadcast dimension mismatch. Operands "
-                          "could not be broadcast together with the shape of "
-                          "X = [%s] and the shape of Y = [%s]. Received [%d] "
-                          "in X is not equal to [%d] in Y",
-                          x_dims, y_dims, x_dims[i + axis], y_dims[i]);
-        // m*n*k m*1*k
-        for (int j = 0; j < i; ++j) {
-          (*pre) *= y_dims[j];
-        }
-        *n = std::max(x_dims[i + axis], y_dims[i]);
-        *mid_flag = 1;
-        mid = i;
-        break;
-      }
-      (*n) *= y_dims[i];
-    }
-    if (*mid_flag) {
-      for (int i = mid + 1; i < x_dims.size(); ++i) {
-        (*post) *= x_dims[i];
-      }
-    } else {
-      for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
-        (*post) *= x_dims[i];
-      }
-    }
-  } else {  // for fused_elementwise_activation_op. keep the old version.
-    for (int i = 0; i < axis; ++i) {
-      (*pre) *= x_dims[i];
-    }
-    for (int i = 0; i < y_dims.size(); ++i) {
-      PADDLE_ENFORCE_EQ(x_dims[i + axis], y_dims[i],
-                        "Broadcast dimension mismatch.");
-      (*n) *= y_dims[i];
-    }
-    for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
-      (*post) *= x_dims[i];
-    }
-  }
-}
+  *is_run_common_broadcast = 0;
+  for (int i = 0; i < axis; ++i) {
+    (*pre) *= x_dims[i];
+  }
+  for (int i = 0; i < y_dims.size(); ++i) {
+    if (x_dims[i + axis] != y_dims[i]) {
+      PADDLE_ENFORCE(y_dims[i] == 1 || x_dims[i + axis] == 1,
+                     "ShapeError: broadcast dimension mismatch. Operands "
+                     "could not be broadcast together with the shape of "
+                     "X = [%s] and the shape of Y = [%s]. Received [%d] "
+                     "in X is not equal to [%d] in Y",
+                     x_dims, y_dims, x_dims[i + axis], y_dims[i]);
+      *is_run_common_broadcast = 1;
+      return;
+    }
+    (*n) *= y_dims[i];
+  }
+  for (int i = axis + y_dims.size(); i < x_dims.size(); ++i) {
+    (*post) *= x_dims[i];
+  }
+}

 inline int GetElementwiseIndex(const int *x_dims_array, const int max_dim,
                                const int *index_array) {
   int index_ = 0;
   for (int i = 0; i < max_dim; i++) {
     if (x_dims_array[i] > 1) {
       index_ = index_ * x_dims_array[i] + index_array[i];
     }
   }
   return index_;
 }

 inline void UpdateElementwiseIndexArray(const int *out_dims_array,
                                         const int max_dim, int *index_array) {
   for (int i = max_dim - 1; i >= 0; --i) {
     ++index_array[i];
     if (index_array[i] >= out_dims_array[i]) {
       index_array[i] -= out_dims_array[i];
     } else {
       break;
     }
   }
 }

 inline void GetBroadcastDimsArrays(const framework::DDim &x_dims,
                                    const framework::DDim &y_dims,
                                    int *x_dims_array, int *y_dims_array,
                                    int *out_dims_array, const int max_dim,
                                    const int axis) {
   PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
   PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
   if (x_dims.size() > y_dims.size()) {
     std::fill(y_dims_array, y_dims_array + axis, 1);
     if (axis + y_dims.size() < max_dim) {
       std::fill(y_dims_array + axis + y_dims.size(), y_dims_array + max_dim, 1);
     }
     std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array);
     std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array + axis);
   } else {
     std::fill(x_dims_array, x_dims_array + axis, 1);
     if (axis + x_dims.size() < max_dim) {
       std::fill(x_dims_array + axis + x_dims.size(), x_dims_array + max_dim, 1);
     }
     std::copy(x_dims.Get(), x_dims.Get() + x_dims.size(), x_dims_array + axis);
     std::copy(y_dims.Get(), y_dims.Get() + y_dims.size(), y_dims_array);
   }

   for (int i = 0; i < max_dim; i++) {
     PADDLE_ENFORCE(x_dims_array[i] == y_dims_array[i] || x_dims_array[i] <= 1 ||
                        y_dims_array[i] <= 1,
                    "ShapeError: broadcast dimension mismatch. Operands could "
                    "not be broadcast together with the shape of X = [%s] and "
                    "the shape of Y = [%s]. Received [%d] in X is not equal to "
                    "[%d] in Y",
                    x_dims, y_dims, x_dims_array[i], y_dims_array[i]);
     if (x_dims_array[i] == -1 || y_dims_array[i] == -1) {
       out_dims_array[i] = -1;
     } else {
       out_dims_array[i] = std::max(x_dims_array[i], y_dims_array[i]);
     }
   }
 }

 template <typename Functor, typename T, typename OutType = T>
 void CommonForwardBroadcastCPU(const framework::Tensor *x,
                                const framework::Tensor *y, framework::Tensor *z,
                                int *x_dims_array, int *y_dims_array,
                                int *out_dims_array, int max_dim,
                                const platform::CPUDeviceContext &ctx,
                                Functor func,
                                const bool is_xsize_larger = true) {
   std::vector<int> index_array(max_dim, 0);
   const T *x_data = x->data<T>();
   const T *y_data = y->data<T>();
   OutType *out_data = z->mutable_data<OutType>(ctx.GetPlace());
   const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
                                        1, std::multiplies<int>());
   int x_index, y_index;
   for (int out_index = 0; out_index < out_size; ++out_index) {
     x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
     y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data());
     if (is_xsize_larger) {
       out_data[out_index] = func(x_data[x_index], y_data[y_index]);
     } else {
       out_data[out_index] = func(y_data[y_index], x_data[x_index]);
     }
     UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data());
   }
 }

 #ifdef __NVCC__
 template <typename Functor, typename T>
 __global__ void CommonForwardBroadcastCUDAKernel(
     const int *x_strides_array, const int *y_strides_array,
     const int *out_dims_array, const T *x, const T *y, T *out, int out_size,
     int max_dim, Functor func, const bool is_xsize_larger) {
   for (int out_index = blockIdx.x * blockDim.x + threadIdx.x;
        out_index < out_size; out_index += blockDim.x * gridDim.x) {
     int x_index = 0;
     int y_index = 0;
     int out_index_quotient = out_index;
     int remainder = 0;
 #pragma unroll
     for (int i = max_dim - 1; i >= 0; --i) {
       GetDivMod(out_index_quotient, out_dims_array[i], &out_index_quotient,
                 &remainder);
       x_index += remainder * x_strides_array[i];
       y_index += remainder * y_strides_array[i];
     }
     if (is_xsize_larger) {
       out[out_index] = func(x[x_index], y[y_index]);
     } else {
       out[out_index] = func(y[y_index], x[x_index]);
     }
   }
 }

 template <typename Functor, typename T>
 void CommonForwardBroadcastCUDA(
     const framework::Tensor *x, const framework::Tensor *y,
     framework::Tensor *z, int *x_dims_array, int *y_dims_array,
     int *out_dims_array, int max_dim, const platform::CUDADeviceContext &ctx,
     Functor func, const bool is_xsize_larger = true) {
   const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
   auto cplace = platform::CPUPlace();
   const T *x_data = x->data<T>();
   const T *y_data = y->data<T>();
   T *out_data = z->mutable_data<T>(ctx.GetPlace());

   std::vector<int> x_strides_array(max_dim);
   std::vector<int> y_strides_array(max_dim);
   int x_stride = 1;
   int y_stride = 1;
   for (int i = max_dim - 1; i >= 0; i--) {
     x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
     y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
     x_stride *= x_dims_array[i];
     y_stride *= y_dims_array[i];
   }

   int bytes = max_dim * sizeof(int);
   auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
   int *x_strides_array_gpu = reinterpret_cast<int *>(x_strides_array_tmp->ptr());
   memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(),
                bytes, ctx.stream());

   auto y_strides_array_tmp = memory::Alloc(ctx, bytes);
   int *y_strides_array_gpu = reinterpret_cast<int *>(y_strides_array_tmp->ptr());
   memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(),
                bytes, ctx.stream());

   auto out_dims_array_tmp = memory::Alloc(ctx, bytes);
   int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
   memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes,
                ctx.stream());

   const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
                                        1, std::multiplies<int>());
   dim3 gird_size = dim3(
       (out_size + PADDLE_CUDA_THREAD_SIZE - 1) / PADDLE_CUDA_THREAD_SIZE, 1);
   dim3 block_size = dim3(PADDLE_CUDA_THREAD_SIZE, 1);
   CommonForwardBroadcastCUDAKernel<
       Functor, T><<<gird_size, block_size, 0, ctx.stream()>>>(
       x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu, x_data,
       y_data, out_data, out_size, max_dim, func, is_xsize_larger);
 }
 #endif  // __NVCC__

 template <typename T, typename DX_OP, typename DY_OP>
 void CommonGradBroadcastCPU(
     const framework::Tensor &x, const framework::Tensor &y,
     const framework::Tensor &out, const framework::Tensor &dout,
     framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array,
     int *y_dims_array, int *out_dims_array, int max_dim,
     const platform::CPUDeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) {
   std::vector<int> index_array(max_dim, 0);
   const T *x_data = x.data<T>();
   const T *y_data = y.data<T>();
   const T *out_data = out.data<T>();
   const T *dout_data = dout.data<T>();
   T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
   T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
   if (dx_data != nullptr) {
     memset(dx_data, 0, dx->numel() * sizeof(T));
   }
   if (dy_data != nullptr) {
     memset(dy_data, 0, dy->numel() * sizeof(T));
   }
   const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
                                        1, std::multiplies<int>());
   int x_index, y_index;
   for (int out_index = 0; out_index < out_size; ++out_index) {
     x_index = GetElementwiseIndex(x_dims_array, max_dim, index_array.data());
     y_index = GetElementwiseIndex(y_dims_array, max_dim, index_array.data());
     if (dx_data != nullptr) {
       dx_data[x_index] += dx_op(x_data[x_index], y_data[y_index],
                                 out_data[out_index], dout_data[out_index]);
     }
     if (dy_data != nullptr) {
       dy_data[y_index] += dy_op(x_data[x_index], y_data[y_index],
                                 out_data[out_index], dout_data[out_index]);
     }
     UpdateElementwiseIndexArray(out_dims_array, max_dim, index_array.data());
   }
 }

 inline void ComputeBroadcastKernelSize(int *x_dims_array, int *out_dims_array,
                                        int *x_blocks, int *x_threads,
                                        int max_dim) {
   *x_blocks = 1;
   *x_threads = 1;
   for (int i = 0; i < max_dim; i++) {
     if (x_dims_array[i] == out_dims_array[i]) {
       *x_blocks *= x_dims_array[i];
     } else {
       *x_threads *= out_dims_array[i];
     }
   }
 }

 inline void ComputeBroadcastTranspositionArray(const int *x_one_indexs,
                                                int *x_trans_indexs,
                                                const int max_dim,
                                                const int x_one_size) {
   int diff = max_dim - x_one_size;
   std::copy_n(x_one_indexs, x_one_size, x_trans_indexs + diff);
   int p = 0;
   int q = diff;
   for (int i = 0; i < max_dim; ++i) {
     if (q < max_dim && i == x_trans_indexs[q]) {
       ++q;
     } else {
       x_trans_indexs[p++] = i;
     }
   }
 }

 #ifdef __NVCC__
 template <typename T, typename DX_OP>
 __global__ void CommonGradBroadcastCUDAKernel(
     const int *x_strides_array, const int *y_strides_array,
     const int *out_dims_array, const int *y_strides_order,
     const int *y_dims_order, const T *x, const T *y, const T *out,
     const T *dout, T *dx, int out_size, int max_dim, int thread_num,
     DX_OP dx_op) {
   T val(0);
   int i = blockIdx.x;
   int tid = threadIdx.x;
   for (int j = tid; j < thread_num; j += blockDim.x) {
     const int X_index = i * thread_num + j;
     int out_index = X_index;
     int C_index = 0;
     int B_index = i * thread_num + j;
     int remainder = 0;
 #pragma unroll
     for (int d = max_dim - 1; d >= 0; --d) {
       GetDivMod(B_index, y_dims_order[d], &B_index, &remainder);
       C_index += remainder * y_strides_order[d];
     }
     int x_index = 0;
     int y_index = 0;
     int C_index_val = C_index;
 #pragma unroll
     for (int d = max_dim - 1; d >= 0; --d) {
       GetDivMod(C_index_val, out_dims_array[d], &C_index_val, &remainder);
       x_index += remainder * x_strides_array[d];
       y_index += remainder * y_strides_array[d];
     }
     out_index = C_index;
     val += dx_op(x[x_index], y[y_index], out[out_index], dout[out_index]);
   }
   val = paddle::platform::reduceSum(val, tid, thread_num);
   if (threadIdx.x == 0) {
     dx[i] = val;
   }
 }

 template <typename T, typename DX_OP, typename DY_OP>
 void CommonGradBroadcastCUDA(
     const framework::Tensor &x, const framework::Tensor &y,
     const framework::Tensor &out, const framework::Tensor &dout,
     framework::Tensor *dx, framework::Tensor *dy, int *x_dims_array,
     int *y_dims_array, int *out_dims_array, int max_dim,
     const platform::CUDADeviceContext &ctx, DX_OP dx_op, DY_OP dy_op) {
   const auto gplace = boost::get<platform::CUDAPlace>(ctx.GetPlace());
   auto cplace = platform::CPUPlace();
   const T *x_data = x.data<T>();
   const T *y_data = y.data<T>();
   const T *out_data = out.data<T>();
   const T *dout_data = dout.data<T>();
   T *dx_data = dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace());
   T *dy_data = dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace());
   std::vector<int> x_one_indexs;
   std::vector<int> y_one_indexs;
   for (int i = 0; i < max_dim; i++) {
     if (x_dims_array[i] != y_dims_array[i]) {
       if (x_dims_array[i] == 1) {
         x_one_indexs.push_back(i);
       }
       if (y_dims_array[i] == 1) {
         y_one_indexs.push_back(i);
       }
     }
   }

   std::vector<int> x_trans_indexs(max_dim);
   std::vector<int> y_trans_indexs(max_dim);
   ComputeBroadcastTranspositionArray(x_one_indexs.data(), x_trans_indexs.data(),
                                      max_dim, x_one_indexs.size());
   ComputeBroadcastTranspositionArray(y_one_indexs.data(), y_trans_indexs.data(),
                                      max_dim, y_one_indexs.size());

   // compute array stride for cuda kernel;
   // e.g. x.dims=[2,3,4], x_stride=[12,4,1]
   std::vector<int> x_strides_array(max_dim);
   std::vector<int> y_strides_array(max_dim);
   std::vector<int> out_strides_array(max_dim);
   int x_stride = 1;
   int y_stride = 1;
   int z_stride = 1;
   for (int i = max_dim - 1; i >= 0; i--) {
     x_strides_array[i] = x_dims_array[i] == 1 ? 0 : x_stride;
     y_strides_array[i] = y_dims_array[i] == 1 ? 0 : y_stride;
     out_strides_array[i] = z_stride;
     x_stride *= x_dims_array[i];
     y_stride *= y_dims_array[i];
     z_stride *= out_dims_array[i];
   }

   std::vector<int> x_strides_order(max_dim);
   std::vector<int> y_strides_order(max_dim);
   std::vector<int> x_dims_order(max_dim);
   std::vector<int> y_dims_order(max_dim);
   for (int i = 0; i < max_dim; ++i) {
     x_strides_order[i] = out_strides_array[x_trans_indexs[i]];
     y_strides_order[i] = out_strides_array[y_trans_indexs[i]];
     x_dims_order[i] = out_dims_array[x_trans_indexs[i]];
     y_dims_order[i] = out_dims_array[y_trans_indexs[i]];
   }

   int x_blocks = 0;
   int x_threads = 0;
   ComputeBroadcastKernelSize(x_dims_array, out_dims_array, &x_blocks,
                              &x_threads, max_dim);
   int y_blocks = 0;
   int y_threads = 0;
   ComputeBroadcastKernelSize(y_dims_array, out_dims_array, &y_blocks,
                              &y_threads, max_dim);

   int bytes = max_dim * sizeof(int);
   auto x_strides_array_tmp = memory::Alloc(ctx, bytes);
   int *x_strides_array_gpu = reinterpret_cast<int *>(x_strides_array_tmp->ptr());
   memory::Copy(gplace, x_strides_array_gpu, cplace, x_strides_array.data(),
                bytes, ctx.stream());

   auto y_strides_array_tmp = memory::Alloc(ctx, bytes);
   int *y_strides_array_gpu = reinterpret_cast<int *>(y_strides_array_tmp->ptr());
   memory::Copy(gplace, y_strides_array_gpu, cplace, y_strides_array.data(),
                bytes, ctx.stream());

   auto out_dims_array_tmp = memory::Alloc(ctx, bytes);
   int *out_dims_array_gpu = reinterpret_cast<int *>(out_dims_array_tmp->ptr());
   memory::Copy(gplace, out_dims_array_gpu, cplace, out_dims_array, bytes,
                ctx.stream());

   const int out_size = std::accumulate(out_dims_array, out_dims_array + max_dim,
                                        1, std::multiplies<int>());
   int x_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, x_threads);
   int y_block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, y_threads);
   if (dx) {
     auto x_strides_order_tmp = memory::Alloc(ctx, bytes);
     int *x_strides_order_gpu = reinterpret_cast<int *>(x_strides_order_tmp->ptr());
     memory::Copy(gplace, x_strides_order_gpu, cplace, x_strides_order.data(),
                  bytes, ctx.stream());

     auto x_dims_order_tmp = memory::Alloc(ctx, bytes);
     int *x_dims_order_gpu = reinterpret_cast<int *>(x_dims_order_tmp->ptr());
     memory::Copy(gplace, x_dims_order_gpu, cplace, x_dims_order.data(), bytes,
                  ctx.stream());
     CommonGradBroadcastCUDAKernel<
         T, DX_OP><<<x_blocks, x_block_size, 0, ctx.stream()>>>(
         x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
         x_strides_order_gpu, x_dims_order_gpu, x_data, y_data, out_data,
         dout_data, dx_data, out_size, max_dim, x_threads, dx_op);
   }
   if (dy) {
     auto y_strides_order_tmp = memory::Alloc(ctx, bytes);
     int *y_strides_order_gpu = reinterpret_cast<int *>(y_strides_order_tmp->ptr());
     memory::Copy(gplace, y_strides_order_gpu, cplace, y_strides_order.data(),
                  bytes, ctx.stream());

     auto y_dims_order_tmp = memory::Alloc(ctx, bytes);
     int *y_dims_order_gpu = reinterpret_cast<int *>(y_dims_order_tmp->ptr());
     memory::Copy(gplace, y_dims_order_gpu, cplace, y_dims_order.data(), bytes,
                  ctx.stream());
     CommonGradBroadcastCUDAKernel<
         T, DY_OP><<<y_blocks, y_block_size, 0, ctx.stream()>>>(
         x_strides_array_gpu, y_strides_array_gpu, out_dims_array_gpu,
         y_strides_order_gpu, y_dims_order_gpu, x_data, y_data, out_data,
         dout_data, dy_data, out_size, max_dim, y_threads, dy_op);
   }
 }
 #endif  // __NVCC__

 inline framework::DDim trim_trailing_singular_dims(
     const framework::DDim &dims) {
   // Remove trailing dimensions of size 1 for y
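A quick way to see what GetElementwiseIndex and UpdateElementwiseIndexArray are doing: the output position is walked as a multi-dimensional counter, and each input's linear index is rebuilt from that counter while skipping the dimensions where the input has size 1 (its effective stride there is 0). A minimal standalone sketch of the same mapping, assumed equivalent to the diff's helpers (simplified, no Paddle types):

#include <cstdio>
#include <vector>

// Linear index into a tensor whose dims may contain 1s, given the current
// multi-dimensional output counter.  Size-1 dims contribute nothing.
int ElemIndex(const std::vector<int>& dims, const std::vector<int>& counter) {
  int idx = 0;
  for (size_t i = 0; i < dims.size(); ++i)
    if (dims[i] > 1) idx = idx * dims[i] + counter[i];
  return idx;
}

int main() {
  std::vector<int> out_dims = {2, 3}, y_dims = {1, 3}, counter = {0, 0};
  // Walk all 6 output positions; y repeats its single row for both output rows.
  for (int flat = 0; flat < 6; ++flat) {
    std::printf("out(%d,%d) <- y[%d]\n", counter[0], counter[1],
                ElemIndex(y_dims, counter));
    for (int i = (int)out_dims.size() - 1; i >= 0; --i) {  // odometer-style increment
      if (++counter[i] < out_dims[i]) break;
      counter[i] = 0;
    }
  }
  return 0;
}

The CUDA kernels use the same idea but recover the counter on the fly from the flat index with GetDivMod and precomputed stride arrays.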
...
...
@@ -121,7 +512,7 @@ inline framework::DDim trim_trailing_singular_dims(
   for (; actual_dims_size != 0; --actual_dims_size) {
     if (dims[actual_dims_size - 1] != 1) break;
   }
   if (actual_dims_size == dims.size()) return dims;
   std::vector<int> trim_dims;
   trim_dims.resize(actual_dims_size);
   for (int i = 0; i < actual_dims_size; ++i) {
...
...
@@ -287,13 +678,19 @@ template <typename Functor, typename T, typename DeviceContext,
 class TransformFunctor {
  public:
   TransformFunctor(const framework::Tensor *x, const framework::Tensor *y,
-                   framework::Tensor *z, const DeviceContext &ctx, Functor func)
+                   framework::Tensor *z, const DeviceContext &ctx, Functor func,
+                   const bool is_xsize_larger = true)
       : x_(x->data<T>()),
         y_(y->data<T>()),
         z_(z->mutable_data<OutType>(ctx.GetPlace())),
         nx_(x->numel()),
         ctx_(ctx),
-        func_(func) {}
+        func_(func),
+        is_xsize_larger_(is_xsize_larger) {
+    if (is_xsize_larger_ == false) {
+      nx_ = y->numel();
+    }
+  }

   inline void Run() const {
     platform::Transform<DeviceContext> trans;
...
...
@@ -302,22 +699,23 @@ class TransformFunctor {
   inline void RunRowWise(int n, int pre) const {
     platform::Transform<DeviceContext> trans;
-    trans(ctx_, x_, x_ + nx_, RowwiseTransformIterator<T, DeviceContext>(y_, n),
-          z_, func_);
+    if (is_xsize_larger_) {
+      trans(ctx_, x_, x_ + nx_,
+            RowwiseTransformIterator<T, DeviceContext>(y_, n), z_, func_);
+    } else {
+      trans(ctx_, y_, y_ + nx_,
+            RowwiseTransformIterator<T, DeviceContext>(x_, n), z_, func_);
+    }
   }

   inline void RunMidWise(int n, int pre, int post) const {
     platform::Transform<DeviceContext> trans;
-    trans(ctx_, x_, x_ + nx_,
-          MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
-  }
-
-  inline void RunMidRowWise(int n, int pre, int post) const {
-    platform::Transform<DeviceContext> trans;
-    for (int i = 0; i < pre; i++) {
-      trans(ctx_, x_ + i * n * post, x_ + (i + 1) * n * post,
-            RowwiseTransformIterator<T, DeviceContext>(y_ + i * post, post),
-            z_ + i * n * post, func_);
+    if (is_xsize_larger_) {
+      trans(ctx_, x_, x_ + nx_,
+            MidWiseTransformIterator<T, DeviceContext>(y_, n, post), z_, func_);
+    } else {
+      trans(ctx_, y_, y_ + nx_,
+            MidWiseTransformIterator<T, DeviceContext>(x_, n, post), z_, func_);
     }
   }
...
...
@@ -328,6 +726,7 @@ class TransformFunctor {
   int64_t nx_;
   const DeviceContext &ctx_;
   Functor func_;
+  bool is_xsize_larger_;
 };

 template <typename T, typename DX_OP, typename DY_OP>
...
...
@@ -354,20 +753,42 @@ struct ElemwiseGradNoBroadcast {
 template <typename T, typename DX_OP, typename DY_OP>
 static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
-                                      const T *dout, int h, int w, DX_OP dx_op,
+                                      const T *dout, int h, int w,
+                                      bool is_xsize_larger, DX_OP dx_op,
                                       DY_OP dy_op, T *dx, T *dy) {
   if (is_xsize_larger) {
     for (int i = 0; i < h; ++i) {
       for (int j = 0; j < w; ++j) {
         int x_offset = i * w + j;
         if (dx != nullptr) {
           dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
         }
         if (dy != nullptr) {
           T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
           if (i == 0) {
             dy[j] = tmp;
           } else {
             dy[j] += tmp;
           }
         }
       }
     }
   } else {  // x.dims < y.dims, broadcast for x.
     for (int i = 0; i < h; ++i) {
       for (int j = 0; j < w; ++j) {
         int y_offset = i * w + j;
         if (dy != nullptr) {
           dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
         }
         if (dx != nullptr) {
           T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
           if (i == 0) {
             dx[j] = tmp;
           } else {
             dx[j] += tmp;
           }
         }
       }
     }
   }
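For the row-wise case, (h, w) is just the flattened (pre, n) view: with X of shape (h, w) and Y of shape (w), dY accumulates one partial sum per column, which is the i == 0 ? "=" : "+=" pattern above. A tiny standalone check of that accumulation for multiplication (illustrative only, CPU, fixed sizes):

#include <cstdio>

int main() {
  // X: 2 x 3, Y: 3 (broadcast across rows), Out = X * Y.
  // dY[j] = sum_i dOut[i][j] * X[i][j]
  const int h = 2, w = 3;
  double x[h][w] = {{1, 2, 3}, {4, 5, 6}}, dout[h][w] = {{1, 1, 1}, {1, 1, 1}};
  double dy[w];
  for (int i = 0; i < h; ++i)
    for (int j = 0; j < w; ++j) {
      double tmp = dout[i][j] * x[i][j];   // the dy_op of elementwise_mul
      if (i == 0) dy[j] = tmp; else dy[j] += tmp;
    }
  for (int j = 0; j < w; ++j) std::printf("%g ", dy[j]);  // prints: 5 7 9
  std::printf("\n");
  return 0;
}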
...
...
@@ -378,28 +799,48 @@ static void ElemwiseGradBroadcast1CPU(const T *x, const T *y, const T *out,
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast1CUDAKernel(
-    const T *x, const T *y, const T *out, const T *dout, int h, int w,
-    DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+    const T *x, const T *y, const T *out, const T *dout, int h, int w,
+    bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   int j = blockIdx.x;
   int i = threadIdx.x;
   int tid = threadIdx.x;
   T val(0);
   if (is_xsize_larger) {
     do {
       int x_offset = i * w + j;
       if (dx) {
         dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
       }
       if (dy) {
         val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
       }
       i += ELEMWISE_MAX_BLOCK_DIM;
     } while (i < h);

     if (dy) {
       h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
       val = paddle::platform::reduceSum(val, tid, h);
       if (threadIdx.x == 0) {
         dy[j] = val;
       }
     }
   } else {  // x.dims < y.dims, broadcast for x.
     do {
       int y_offset = i * w + j;
       if (dy) {
         dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
       }
       if (dx) {
         val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
       }
       i += ELEMWISE_MAX_BLOCK_DIM;
     } while (i < h);

     if (dx) {
       h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
       val = paddle::platform::reduceSum(val, tid, h);
       if (threadIdx.x == 0) {
         dx[j] = val;
       }
     }
   }
...
...
@@ -412,7 +853,7 @@ static __global__ void ElemwiseGradBroadcast1CUDAKernel(
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
     const T *x, const T *y, const T *out, const T *dout, int h, int w,
-    DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+    bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   __shared__ T sdata[BLOCK_Y][BLOCK_X + 1];

   T val(0);
...
...
@@ -422,33 +863,66 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
   size_t full_width =
       (w & (~((uint64_t)(BLOCK_X - 1)))) + ((w & (BLOCK_X - 1)) ? BLOCK_X : 0);
   size_t full_height =
       (h & (~((uint64_t)(BLOCK_Y - 1)))) + ((h & (BLOCK_Y - 1)) ? BLOCK_Y : 0);
   if (is_xsize_larger) {
     for (int m = idx; m < full_width; m += width_stride) {
       sdata[threadIdx.y][threadIdx.x] = 0;
       for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
         int x_offset = n * w + m;
         if (dx && m < w && n < h) {
           dx[x_offset] = dx_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
         }
         if (dy) {
           if (m < w && n < h) {
             T val = dy_op(x[x_offset], y[m], out[x_offset], dout[x_offset]);
             sdata[threadIdx.y][threadIdx.x] += val;
           }
           __syncthreads();
         }
       }
       if (dy) {
         T my_val = sdata[threadIdx.x][threadIdx.y];
         for (int i = warpSize >> 1; i > 0; i >>= 1)
           my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
         __syncthreads();
         if ((threadIdx.x == 0)) {
           sdata[0][threadIdx.y] = my_val;
         }
         __syncthreads();
         if (threadIdx.y == 0 && m < w) {
           dy[m] = sdata[0][threadIdx.x];
         }
       }
     }
   } else {  // x.dims < y.dims, broadcast for x.
     for (int m = idx; m < full_width; m += width_stride) {
       sdata[threadIdx.y][threadIdx.x] = 0;
       for (int n = threadIdx.y; n < full_height; n += BLOCK_Y) {
         int y_offset = n * w + m;
         if (dy && m < w && n < h) {
           dy[y_offset] = dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
         }
         if (dx) {
           if (m < w && n < h) {
             T val = dy_op(x[m], y[y_offset], out[y_offset], dout[y_offset]);
             sdata[threadIdx.y][threadIdx.x] += val;
           }
           __syncthreads();
         }
       }
       if (dx) {
         T my_val = sdata[threadIdx.x][threadIdx.y];
         for (int i = warpSize >> 1; i > 0; i >>= 1)
           my_val += platform::CudaShuffleXorSync(0xFFFFFFFF, my_val, i);
         __syncthreads();
         if ((threadIdx.x == 0)) {
           sdata[0][threadIdx.y] = my_val;
         }
         __syncthreads();
         if (threadIdx.y == 0 && m < w) {
           dx[m] = sdata[0][threadIdx.x];
         }
       }
     }
   }
...
...
@@ -457,21 +931,21 @@ static __global__ void FastElemwiseGradBroadcast1CUDAKernel(
 template <typename T, typename DX_OP, typename DY_OP>
 static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
                                        const T *y, const T *out, const T *dout,
-                                       int h, int w, DX_OP dx_op, DY_OP dy_op,
-                                       T *dx, T *dy) {
+                                       int h, int w, bool is_xsize_larger,
+                                       DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   // For small case use 1D block
   constexpr int half_walf = 16;
   if (w < half_walf || h < half_walf) {
     int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, h);
     int gird_size = w;
     ElemwiseGradBroadcast1CUDAKernel<<<gird_size, block_size, 0, stream>>>(
-        x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
+        x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
   } else {
     // suppose perfoemance improves with h increased.
     dim3 block_size = dim3(BLOCK_X, BLOCK_Y);
     int grid_size = (w + BLOCK_X - 1) / BLOCK_X;
     FastElemwiseGradBroadcast1CUDAKernel<<<grid_size, block_size, 0, stream>>>(
-        x, y, out, dout, h, w, dx_op, dy_op, dx, dy);
+        x, y, out, dout, h, w, is_xsize_larger, dx_op, dy_op, dx, dy);
   }
 }
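The dispatcher above uses a 1D launch for small matrices and a BLOCK_X by BLOCK_Y tiled kernel otherwise; the grid/block arithmetic is only this (a standalone sketch with an assumed 32x32 tile and 1024-thread cap, not the exact Paddle constants):

#include <algorithm>
#include <cstdio>

int main() {
  const int kMaxBlockDim = 1024, kBlockX = 32, kBlockY = 32;
  int h = 4096, w = 20;  // rows get reduced into dy (or dx), columns are kept
  if (w < 16 || h < 16) {
    // 1D: one block per output column, up to kMaxBlockDim threads reduce the rows.
    std::printf("grid=%d block=%d\n", w, std::min(kMaxBlockDim, h));
  } else {
    // 2D tiles: each block covers kBlockX columns and strides down the rows.
    std::printf("grid=%d block=%dx%d\n", (w + kBlockX - 1) / kBlockX, kBlockX, kBlockY);
  }
  return 0;
}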
...
...
@@ -480,21 +954,44 @@ static void ElemwiseGradBroadcast1CUDA(cudaStream_t stream, const T *x,
 template <typename T, typename DX_OP, typename DY_OP>
 static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
                                       const T *dout, int pre, int n, int post,
-                                      DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+                                      bool is_xsize_larger, DX_OP dx_op,
+                                      DY_OP dy_op, T *dx, T *dy) {
   if (is_xsize_larger) {
     for (int i = 0; i < pre; ++i) {
       for (int j = 0; j < n; ++j) {
         for (int k = 0; k < post; ++k) {
           int x_offset = i * n * post + j * post + k;
           if (dx != nullptr) {
             dx[x_offset] =
                 dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
           }
           if (dy != nullptr) {
             T tmp = dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
             if (i == 0 && k == 0) {
               dy[j] = tmp;
             } else {
               dy[j] += tmp;
             }
           }
         }
       }
     }
   } else {  // x.dims < y.dims, broadcast for x.
     for (int i = 0; i < pre; ++i) {
       for (int j = 0; j < n; ++j) {
         for (int k = 0; k < post; ++k) {
           int y_offset = i * n * post + j * post + k;
           if (dy != nullptr) {
             dy[y_offset] =
                 dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
           }
           if (dx != nullptr) {
             T tmp = dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
             if (i == 0 && k == 0) {
               dx[j] = tmp;
             } else {
               dx[j] += tmp;
             }
           }
         }
       }
     }
   }
...
...
@@ -506,37 +1003,66 @@ static void ElemwiseGradBroadcast2CPU(const T *x, const T *y, const T *out,
 template <typename T, typename DX_OP, typename DY_OP>
 static __global__ void ElemwiseGradBroadcast2CUDAKernel(
     const T *x, const T *y, const T *out, const T *dout, int pre, int n,
-    int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
+    int post, bool is_xsize_larger, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
   int tid = threadIdx.x;
   int j = blockIdx.x;

   T val(0);
   int ttid = tid;

   if (is_xsize_larger) {
     while (true) {
       int i = ttid / post;
       int k = ttid % post;
       if (i >= pre) break;

       int x_offset = i * n * post + j * post + k;
       if (dx != nullptr) {
         dx[x_offset] = dx_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
       }
       if (dy != nullptr) {
         val += dy_op(x[x_offset], y[j], out[x_offset], dout[x_offset]);
       }
       ttid += ELEMWISE_MAX_BLOCK_DIM;
     }
     if (dy) {
       int h = pre * post;
       h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
       val = paddle::platform::reduceSum(val, tid, h);
       if (threadIdx.x == 0) {
         dy[j] = val;
       }
     }
   } else {  // x.dims < y.dims, broadcast for x.
     while (true) {
       int i = ttid / post;
       int k = ttid % post;
       if (i >= pre) break;

       int y_offset = i * n * post + j * post + k;
       if (dy != nullptr) {
         dy[y_offset] = dy_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
       }
       if (dx != nullptr) {
         val += dx_op(x[j], y[y_offset], out[y_offset], dout[y_offset]);
       }
       ttid += ELEMWISE_MAX_BLOCK_DIM;
     }
     if (dx) {
       int h = pre * post;
       h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
       val = paddle::platform::reduceSum(val, tid, h);
       if (threadIdx.x == 0) {
         dx[j] = val;
       }
     }
   }
...
...
@@ -544,98 +1070,57 @@ static __global__ void ElemwiseGradBroadcast2CUDAKernel(
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcast2CUDA
(
cudaStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
int
pre
,
int
n
,
int
post
,
bool
is_xsize_larger
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
pre
*
post
);
int
gird_size
=
n
;
ElemwiseGradBroadcast2CUDAKernel
<<<
gird_size
,
block_size
,
0
,
stream
>>>
(
x
,
y
,
out
,
dout
,
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
,
dy
);
x
,
y
,
out
,
dout
,
pre
,
n
,
post
,
is_xsize_larger
,
dx_op
,
dy_op
,
dx
,
dy
);
}
#endif
template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcastMid2CPU(const T *x, const T *y, const T *out,
                                         const T *dout, int pre, int n,
                                         int post, DX_OP dx_op, DY_OP dy_op,
                                         T *dx, T *dy) {
  for (int i = 0; i < pre; ++i) {
    for (int j = 0; j < n; ++j) {
      for (int k = 0; k < post; ++k) {
        int x_offset = i * n * post + j * post + k;
        int y_offset = i * post + k;
        if (dx != nullptr) {
          dx[x_offset] =
              dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
        }
        if (dy != nullptr) {
          T tmp =
              dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
          if (j == 0) {
            dy[y_offset] = tmp;
          } else {
            dy[y_offset] += tmp;
          }
        }
      }
    }
  }
}

#ifdef __NVCC__
template <typename T, typename DX_OP, typename DY_OP>
static __global__ void ElemwiseGradBroadcastMid2CUDAKernel(
    const T *x, const T *y, const T *out, const T *dout, int pre, int n,
    int post, DX_OP dx_op, DY_OP dy_op, T *dx, T *dy) {
  int j = threadIdx.x;
  int tid = blockIdx.x;

  T val(0);
  int ttid = tid;

  while (true) {
    int i = ttid / post;
    int k = ttid % post;
    if (i >= pre) break;

    int x_offset = i * n * post + j * post + k;
    int y_offset = i * post + k;

    if (dx != nullptr) {
      dx[x_offset] =
          dx_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
    }
    if (dy != nullptr) {
      val += dy_op(x[x_offset], y[y_offset], out[x_offset], dout[x_offset]);
    }

    ttid += ELEMWISE_MAX_BLOCK_DIM;
  }

  if (dy) {
    int h = n;
    h = h > ELEMWISE_MAX_BLOCK_DIM ? ELEMWISE_MAX_BLOCK_DIM : h;
    val = paddle::platform::reduceSum(val, j, h);
    if (threadIdx.x == 0) {
      dy[tid] = val;
    }
  }
}

template <typename T, typename DX_OP, typename DY_OP>
static void ElemwiseGradBroadcastMid2CUDA(cudaStream_t stream, const T *x,
                                          const T *y, const T *out,
                                          const T *dout, int pre, int n,
                                          int post, DX_OP dx_op, DY_OP dy_op,
                                          T *dx, T *dy) {
  int block_size = std::min(ELEMWISE_MAX_BLOCK_DIM, n);
  int gird_size = pre * post;
  ElemwiseGradBroadcastMid2CUDAKernel<<<gird_size, block_size, 0, stream>>>(
      x, y, out, dout, pre, n, post, dx_op, dy_op, dx, dy);
}
#endif

template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void CommonElementwiseBroadcastBackward(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
    const framework::DDim &y_dims, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
  int max_dim = std::max(x_dims.size(), y_dims.size());
  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  std::vector<int> x_dims_array(max_dim);
  std::vector<int> y_dims_array(max_dim);
  std::vector<int> out_dims_array(max_dim);
  GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                         y_dims_array.data(), out_dims_array.data(), max_dim,
                         axis);
  // for inplace strategy. memset will make dx and dout clear and get wrong
  // result.
  if (dx && dout.Holder() == dx->Holder()) {
    dx->clear();
    dx->mutable_data<T>(x_dims, ctx.GetPlace());
  }

  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
    CommonGradBroadcastCUDA<T, DX_OP, DY_OP>(
        x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
        out_dims_array.data(), max_dim,
        ctx.template device_context<platform::CUDADeviceContext>(), dx_op,
        dy_op);
#endif
  } else {
    CommonGradBroadcastCPU<T, DX_OP, DY_OP>(
        x, y, out, dout, dx, dy, x_dims_array.data(), y_dims_array.data(),
        out_dims_array.data(), max_dim,
        ctx.template device_context<platform::CPUDeviceContext>(), dx_op,
        dy_op);
  }
}
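The common backward path above reduces the upstream gradient over every axis that the forward pass broadcast. The following numpy sketch illustrates that reduction rule only (it is not Paddle's kernel; the helper name and shapes are illustrative assumptions):

import numpy as np

# Shapes taken from the "commonuse" broadcast cases added in this change.
x = np.random.rand(2, 3, 1, 5)
y = np.random.rand(2, 1, 4, 1)
dout = np.ones(np.broadcast(x, y).shape)  # upstream gradient, shape (2, 3, 4, 5)

def reduce_like(grad, shape):
    # Sum over every axis that broadcasting expanded, keeping dims so the
    # result has the same shape as the original operand.
    axes = tuple(i for i, (g, s) in enumerate(zip(grad.shape, shape))
                 if s == 1 and g != 1)
    return grad.sum(axis=axes, keepdims=True)

dx = reduce_like(dout, x.shape)  # gradient of (x + y) w.r.t. x, shape (2, 3, 1, 5)
dy = reduce_like(dout, y.shape)  # gradient of (x + y) w.r.t. y, shape (2, 1, 4, 1)
assert dx.shape == x.shape and dy.shape == y.shape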
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeNoBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dim,
...
...
@@ -659,47 +1144,54 @@ void ElemwiseGradComputeNoBroadcast(
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradComputeWithBroadcast(
    const framework::ExecutionContext &ctx, const framework::DDim &x_dims,
    const framework::DDim &y_dims, const framework::Tensor &x,
    const framework::Tensor &y, const framework::Tensor &out,
    const framework::Tensor &dout, int axis, framework::Tensor *dx,
    framework::Tensor *dy, DX_OP dx_op, DY_OP dy_op) {
  bool is_xsize_larger = true;
  int max_dim = x_dims.size();
  if (x_dims.size() < y_dims.size()) {
    is_xsize_larger = false;
    max_dim = y_dims.size();
  }

  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
  PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);

  int pre, n, post, is_run_common_broadcast, axis_trim = 0;
  if (is_xsize_larger) {
    auto y_dims_trimed = trim_trailing_singular_dims(y_dims);
    axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
    get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post,
                 &is_run_common_broadcast);
  } else {
    auto x_dims_trimed = trim_trailing_singular_dims(x_dims);
    axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
    get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
                 &is_run_common_broadcast);
  }
  // special case for common backward implementation.
  if (is_run_common_broadcast) {
    CommonElementwiseBroadcastBackward<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dims, y_dims, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
    return;
  }
  if (post == 1) {
    if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
      ElemwiseGradBroadcast1CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, is_xsize_larger,
          dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast1CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
          is_xsize_larger, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
...
...
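When the smaller operand lines up with a contiguous run of the larger operand's dimensions, the dispatch above only needs the product of the leading dims (pre), the matched dims (n) and the trailing dims (post). A short numpy sketch of that factorisation, under assumed shapes (not the Paddle API itself):

import numpy as np

x = np.random.rand(2, 3, 4, 5)   # larger operand
y = np.random.rand(3, 4)         # smaller operand, aligned at axis=1

axis = 1
pre = int(np.prod(x.shape[:axis]))                # 2
n = int(np.prod(x.shape[axis:axis + y.ndim]))     # 3 * 4 = 12
post = int(np.prod(x.shape[axis + y.ndim:]))      # 5

# Broadcasting y over x is then a (pre, n, post) view against (1, n, 1).
out = x.reshape(pre, n, post) + y.reshape(1, n, 1)
ref = x + y.reshape(1, 3, 4, 1)
assert np.allclose(out.reshape(x.shape), ref)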
@@ -708,20 +1200,56 @@ void ElemwiseGradComputeWithBroadcast(
#ifdef __NVCC__
      ElemwiseGradBroadcast2CUDA(
          ctx.template device_context<DeviceContext>().stream(), x.data<T>(),
          y.data<T>(), out.data<T>(), dout.data<T>(), pre, n, post,
          is_xsize_larger, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
#endif
    } else {
      ElemwiseGradBroadcast2CPU(
          x.data<T>(), y.data<T>(), out.data<T>(), dout.data<T>(), pre, n,
          post, is_xsize_larger, dx_op, dy_op,
          dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()),
          dy == nullptr ? nullptr : dy->mutable_data<T>(ctx.GetPlace()));
    }
  }
}
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
void CommonElementwiseBroadcastForward(
    const framework::ExecutionContext &ctx, const framework::Tensor *x,
    const framework::Tensor *y, framework::Tensor *z,
    const framework::DDim &x_dims, const framework::DDim &y_dims, Functor func,
    int axis, const bool is_xsize_larger = true) {
  int max_dim = std::max(x_dims.size(), y_dims.size());
  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
  PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);
  std::vector<int> x_dims_array(max_dim);
  std::vector<int> y_dims_array(max_dim);
  std::vector<int> out_dims_array(max_dim);
  GetBroadcastDimsArrays(x_dims, y_dims, x_dims_array.data(),
                         y_dims_array.data(), out_dims_array.data(), max_dim,
                         axis);

  if (platform::is_gpu_place(ctx.GetPlace())) {
#ifdef __NVCC__
    CommonForwardBroadcastCUDA<Functor, T>(
        x, y, z, x_dims_array.data(), y_dims_array.data(),
        out_dims_array.data(), max_dim,
        ctx.template device_context<platform::CUDADeviceContext>(), func,
        is_xsize_larger);
#endif
  } else {
    CommonForwardBroadcastCPU<Functor, T, OutType>(
        x, y, z, x_dims_array.data(), y_dims_array.data(),
        out_dims_array.data(), max_dim,
        ctx.template device_context<platform::CPUDeviceContext>(), func,
        is_xsize_larger);
  }
}
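The forward path pads both shapes to max_dim, aligning the shorter one at the given axis, and takes the elementwise maximum as the output shape. A minimal Python sketch of that alignment rule (the helper name and its exact padding are illustrative assumptions, not the C++ GetBroadcastDimsArrays signature):

def broadcast_dims(x_dims, y_dims, axis=-1):
    """Align the shorter shape at `axis` of the longer one and
    return (padded_x, padded_y, out_dims)."""
    max_dim = max(len(x_dims), len(y_dims))
    if axis == -1:
        axis = abs(len(x_dims) - len(y_dims))

    def pad(dims):
        # Pad the shorter shape with 1s: `axis` leading ones, the rest trailing.
        if len(dims) == max_dim:
            return list(dims)
        return [1] * axis + list(dims) + [1] * (max_dim - axis - len(dims))

    x_pad, y_pad = pad(x_dims), pad(y_dims)
    out = []
    for xd, yd in zip(x_pad, y_pad):
        assert xd == yd or xd == 1 or yd == 1, "dims must be equal or 1"
        out.append(max(xd, yd))
    return x_pad, y_pad, out

print(broadcast_dims([2, 3, 4, 5], [3, 4], axis=1))
# ([2, 3, 4, 5], [1, 3, 4, 1], [2, 3, 4, 5])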
template <typename DeviceContext, typename T, typename DX_OP, typename DY_OP>
void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
                         const framework::Tensor &x,
                         const framework::Tensor &y,
...
...
@@ -734,7 +1262,7 @@ void ElemwiseGradCompute(const framework::ExecutionContext &ctx,
  if (x.dims() == y.dims()) {
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, x, y, out, dout, axis, dx, dy, dx_op, dy_op);
  }
...
...
@@ -752,103 +1280,61 @@ void ElemwiseExplicitGradCompute(const framework::ExecutionContext &ctx,
                                 const framework::Tensor &dout, int axis,
                                 framework::Tensor *dx, framework::Tensor *dy,
                                 DX_OP dx_op, DY_OP dy_op) {
  const framework::DDim &x_dim = x.dims();
  const framework::DDim &y_dim = y.dims();
  if (x.dims() == y.dims()) {
    ElemwiseGradComputeNoBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op);
  } else {
    ElemwiseGradComputeWithBroadcast<DeviceContext, T, DX_OP, DY_OP>(
        ctx, x_dim, y_dim, dout, dout, out, dout, axis, dx, dy, dx_op, dy_op);
  }
}

// Deprecated
template <typename DeviceContext, typename T, typename functor,
          typename broadcastfunctor, typename broadcast2functor>
void ElementwiseGradCompute(const framework::ExecutionContext &ctx,
                            const framework::Tensor *x,
                            const framework::Tensor *y,
                            const framework::Tensor *out,
                            const framework::Tensor *dout, int axis,
                            framework::Tensor *dx, framework::Tensor *dy) {
  auto &place = *ctx.template device_context<DeviceContext>().eigen_device();

  auto x_dims = x->dims();
  auto y_dims = y->dims();

  if (dx) {
    dx->mutable_data<T>(ctx.GetPlace());
  }
  if (dy) {
    dy->mutable_data<T>(ctx.GetPlace());
  }

  if (x_dims == y_dims) {
    functor f;
    f(place, x, y, out, dx, dy, dout);
    return;
  }

  axis = (axis == -1 ? x_dims.size() - y_dims.size() : axis);
  trim_trailing_singular_dims(y_dims);
  axis = (y_dims.size() == 0) ? x_dims.size() : axis;

  int pre, n, post;
  get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post);

  if (post == 1) {
    broadcastfunctor f;
    f(place, x, y, out, dx, dy, dout, pre, n);
    return;
  } else {
    broadcast2functor f;
    f(place, x, y, out, dx, dy, dout, pre, n, post);
    return;
  }
}
template <typename Functor, typename DeviceContext, typename T,
          typename OutType = T>
void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
                          const framework::Tensor *x,
                          const framework::Tensor *y, int axis, Functor func,
                          framework::Tensor *z) {
  auto x_dims = x->dims();
  auto y_dims = y->dims();
  bool is_xsize_larger = true;
  int max_dim = x_dims.size();
  if (x_dims.size() < y_dims.size()) {
    is_xsize_larger = false;
    max_dim = y_dims.size();
  }
  TransformFunctor<Functor, T, DeviceContext, OutType> functor(
      x, y, z, ctx.template device_context<DeviceContext>(), func,
      is_xsize_larger);
  if (x_dims == y_dims) {
    functor.Run();
    return;
  }

  axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
  PADDLE_ENFORCE_GE(axis, 0, "Axis should be in range [0, %d)", axis);
  PADDLE_ENFORCE_LT(axis, max_dim, "Axis should be in range [0, %d)", axis);

  int pre, n, post, is_run_common_broadcast, axis_trim = 0;
  if (is_xsize_larger) {
    auto y_dims_trimed = trim_trailing_singular_dims(y_dims);
    axis_trim = (y_dims_trimed.size() == 0) ? x_dims.size() : axis;
    get_mid_dims(x_dims, y_dims_trimed, axis_trim, &pre, &n, &post,
                 &is_run_common_broadcast);
  } else {
    auto x_dims_trimed = trim_trailing_singular_dims(x_dims);
    axis_trim = (x_dims_trimed.size() == 0) ? y_dims.size() : axis;
    get_mid_dims(y_dims, x_dims_trimed, axis_trim, &pre, &n, &post,
                 &is_run_common_broadcast);
  }
  // special case for common implementation.
  // case 1: x=[2,3,1,5], y=[2,1,4,1]
  // case 2: x=[2,3,4], y=[1,1,4]
  if (is_run_common_broadcast == 1) {
    CommonElementwiseBroadcastForward<Functor, DeviceContext, T, OutType>(
        ctx, x, y, z, x_dims, y_dims, func, axis, is_xsize_larger);
    return;
  }
  if (post == 1) {
...
...
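The two "common" cases named in the comment above are exactly the numpy-style broadcasts that the older row/column kernels could not express. A quick numpy check of the shapes involved (illustrative only, not part of the Paddle test suite):

import numpy as np

# case 1: neither shape is a prefix or suffix of the other
x1, y1 = np.ones((2, 3, 1, 5)), np.ones((2, 1, 4, 1))
assert (x1 + y1).shape == (2, 3, 4, 5)

# case 2: the leading dims of y are all 1
x2, y2 = np.ones((2, 3, 4)), np.ones((1, 1, 4))
assert (x2 + y2).shape == (2, 3, 4)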
@@ -1114,9 +1600,8 @@ void FusedElemwiseAndActComputeWithBroadcast(
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post, is_run_common_broadcast;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast);
  if (post == 1) {
    int h = pre;
    int w = n;
...
...
@@ -1628,8 +2113,8 @@ void FusedElemwiseAndActGradComputeWithBroadcast(
  auto y_dim = trim_trailing_singular_dims(y_dim_untrimed);
  axis = (y_dim.size() == 0) ? x_dim.size() : axis;

  int pre, n, post, is_run_common_broadcast;
  get_mid_dims(x_dim, y_dim, axis, &pre, &n, &post, &is_run_common_broadcast);
  if (post == 1) {
    int h = pre;
    int w = n;
...
...
@@ -1763,16 +2248,7 @@ void FusedElemwiseAndActComputeEx(const framework::ExecutionContext &ctx,
  } else {
    // Whether the shape of Y is a continuous subsequence of X,
    // For more information please refer to the op's introduction.
    bool bcast_y = x.numel() >= y.numel();
    // z = f1(x, f2(y))
    // z = f1(f2(x, y))
    if (bcast_y) {  // Y should be broadcast.
...
...
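The broadcast direction is now decided by comparing element counts rather than ranks: whichever operand has fewer elements is the one that gets broadcast. A small sketch of that rule (the helper name is hypothetical, not the fused-op API):

import numpy as np

def pick_broadcast_target(x, y):
    # The operand with fewer elements is treated as the one to broadcast.
    return "y" if x.size >= y.size else "x"

x = np.ones((4, 5))
y = np.ones((2, 3, 4, 5))
print(pick_broadcast_target(x, y))  # "x": x has fewer elements, so x is broadcast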
paddle/fluid/operators/elementwise/elementwise_sub_op.cc (view file @ 0e7baabe)
...
...
@@ -93,13 +93,14 @@ class ElementwiseSubDoubleGradMaker : public framework::SingleGradOpMaker<T> {
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_ELEMWISE_GRAD_MAKER(elementwise_sub, Sub);
REGISTER_ELEMWISE_EXPLICIT_OP_WITHOUT_GRAD(elementwise_sub, Sub);

REGISTER_OPERATOR(
    elementwise_sub_grad, ops::ElementwiseOpGrad, ops::ElementwiseGradOpInplace,
    ops::ElementwiseGradNoBufVarsInference,
    ops::ElementwiseSubDoubleGradMaker<paddle::framework::OpDesc>,
    ops::ElementwiseSubDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(elementwise_sub_grad_grad,
...
...
paddle/fluid/operators/elementwise/elementwise_sub_op.h (view file @ 0e7baabe)
...
...
@@ -26,8 +26,13 @@ void default_elementwise_sub(const framework::ExecutionContext& ctx,
                             const framework::Tensor* x,
                             const framework::Tensor* y, framework::Tensor* z) {
  int axis = ctx.Attr<int>("axis");
  if (x->numel() >= y->numel()) {
    ElementwiseComputeEx<SubFunctor<T>, DeviceContext, T>(ctx, x, y, axis,
                                                          SubFunctor<T>(), z);
  } else {
    ElementwiseComputeEx<InverseSubFunctor<T>, DeviceContext, T>(
        ctx, x, y, axis, InverseSubFunctor<T>(), z);
  }
}
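Because subtraction is not commutative, simply iterating over the larger operand y would compute y - x; the InverseSubFunctor branch above keeps the result equal to x - y. A numpy sketch of that functor choice (illustrative of the sign handling only, assumed shapes):

import numpy as np

x = np.random.rand(4, 5)         # smaller operand
y = np.random.rand(2, 3, 4, 5)   # larger operand

sub = lambda a, b: a - b          # SubFunctor: a - b
inverse_sub = lambda a, b: b - a  # InverseSubFunctor: operands arrive swapped

# Iterating with the larger tensor first but applying the inverse functor
# still yields x - y, matching the expected test output.
out = inverse_sub(y, x.reshape(1, 1, 4, 5))
assert np.allclose(out, x.reshape(1, 1, 4, 5) - y)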
template <typename DeviceContext, typename T, class Enable = void>
...
...
@@ -98,13 +103,14 @@ class ElementwiseSubGradKernel : public ElemwiseGradKernel<T> {
    ElemwiseGradKernel<T>::Compute(ctx);
    using Tensor = framework::Tensor;

    auto* x = ctx.Input<Tensor>("X");
    auto* y = ctx.Input<Tensor>("Y");
    auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
    int axis = ctx.Attr<int>("axis");
    // skip out
    auto* out = dout;
    if (dx != nullptr && dy != nullptr && (dx->dims() == dy->dims())) {
      elementwise_sub_grad<DeviceContext, T>(ctx, x, y, out, dout, dx, dy);
    } else {
...
paddle/fluid/operators/elementwise/mkldnn/elementwise_add_mkldnn_op.cc (view file @ 0e7baabe)
...
...
@@ -108,8 +108,9 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
      auto y_dims = trim_trailing_singular_dims(y_dims_untrimed);
      axis = (y_dims.size() == 0) ? x_dims.size() : axis;

      int pre, n, post, is_run_common_broadcast;
      get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post,
                   &is_run_common_broadcast);

      if (post == 1) {
        functor.RunRowWise(n, pre);
...
...
@@ -212,6 +213,8 @@ class EltwiseAddMKLDNNGradKernel : public ElemwiseGradKernel<T> {
      }
    } else {
      // Execute default kernel when broadcast is needed
      x = ctx.Input<Tensor>("X");
      y = ctx.Input<Tensor>("Y");
      ElemwiseExplicitGradCompute<paddle::platform::CPUDeviceContext, T,
                                  IdentityGrad<T>, IdentityGrad<T>>(
          ctx, *x, *y, *out, *dout, axis, dx, dy, IdentityGrad<T>(),
...
...
paddle/fluid/operators/elementwise/mkldnn/elementwise_mul_mkldnn_op.cc (view file @ 0e7baabe)
...
...
@@ -91,8 +91,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
    const bool is_y_format_correct = y->format() == MKLDNNMemoryFormat::nc;
    if (is_x_format_correct && is_y_format_correct && are_dims_divisable &&
        is_avx512_enabled) {
      int pre, n, post, is_run_common_broadcast;
      get_mid_dims(x_dims, y_dims_untrimmed, axis, &pre, &n, &post,
                   &is_run_common_broadcast);

      if (post == 1) {
        PADDLE_THROW("Not implemented when post is 1");
...
...
@@ -168,8 +169,9 @@ class ElementwiseMulMKLDNNKernel : public framework::OpKernel<T> {
        auto y_dims = trim_trailing_singular_dims(y_dims_untrimmed);
        axis = (y_dims.size() == 0) ? x_dims.size() : axis;

        int pre, n, post, is_run_common_broadcast;
        get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post,
                     &is_run_common_broadcast);

        if (post == 1) {
          functor.RunRowWise(n, pre);
...
...
paddle/fluid/operators/layer_norm_op.h (view file @ 0e7baabe)
...
...
@@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.cu.h"
#include "paddle/fluid/operators/elementwise/elementwise_op_function.h"
#include "paddle/fluid/operators/math/blas.h"
#if !defined(PADDLE_WITH_CUDA) && !defined(_WIN32) && !defined(__APPLE__) && \
...
...
@@ -139,21 +140,6 @@ struct DivAndSqrtFunctor {
  T epsilon_;
};

template <typename T>
struct MulFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a * b; }
};

template <typename T>
struct AddFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a + b; }
};

template <typename T>
struct SubFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const { return a - b; }
};

template <typename T>
struct MulInvVarFunctor {
  inline HOSTDEVICE T operator()(T a, T b) const {
...
...
paddle/fluid/operators/op_debug_string_test.cc (view file @ 0e7baabe)
...
...
@@ -32,6 +32,7 @@ TEST(op_debug_str, test_unknown_dtype) {
  framework::Scope scope;

  desc.SetType("elementwise_add_grad");
  desc.SetInput("X", {"X"});
  desc.SetInput("Y", {"Y"});
  desc.SetInput(framework::GradVarName("Out"), {framework::GradVarName("Out")});
  desc.SetOutput(framework::GradVarName("X"), {framework::GradVarName("X")});
...
...
@@ -41,6 +42,10 @@ TEST(op_debug_str, test_unknown_dtype) {
  desc.SetAttr("x_data_format", "");
  desc.SetAttr("y_data_format", "");

  auto x_tensor = scope.Var("X")->GetMutable<framework::LoDTensor>();
  x_tensor->Resize(dim);
  x_tensor->mutable_data<float>(place);

  auto y_tensor = scope.Var("Y")->GetMutable<framework::LoDTensor>();
  y_tensor->Resize(dim);
  y_tensor->mutable_data<float>(place);
...
...
paddle/fluid/train/demo/run.sh (view file @ 0e7baabe)
...
...
@@ -7,8 +7,8 @@ TURN_ON_MKL=$2 # use MKL or Openblas
# download models
function download() {
    wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR-1-7/main_program
    wget -q http://paddle-tar.bj.bcebos.com/train_demo/LR-1-7/startup_program
}

download
...
...
python/paddle/fluid/tests/unittests/test_elementwise_add_op.py (view file @ 0e7baabe)
...
...
@@ -308,6 +308,36 @@ class TestFP16ElementwiseAddOp_channelwise_add(TestFP16ElementwiseAddOp):
        self.axis = -1


class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp):
    def init_input_output(self):
        self.x = np.random.rand(2, 3, 4).astype(self.dtype)
        self.y = np.random.rand(1, 1, 4).astype(self.dtype)
        self.out = self.x + self.y

    def init_axis(self):
        self.axis = -1


class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp):
    def init_input_output(self):
        self.x = np.random.rand(2, 3, 1, 5).astype(self.dtype)
        self.y = np.random.rand(2, 1, 4, 1).astype(self.dtype)
        self.out = self.x + self.y

    def init_axis(self):
        self.axis = -1


class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp):
    def init_input_output(self):
        self.x = np.random.rand(4, 5).astype(self.dtype)
        self.y = np.random.rand(2, 3, 4, 5).astype(self.dtype)
        self.out = self.x + self.y

    def init_axis(self):
        self.axis = 2


class TestElementwiseAddOpError(OpTest):
    def test_errors(self):
        with program_guard(Program(), Program()):
...
...
python/paddle/fluid/tests/unittests/test_elementwise_div_op.py (view file @ 0e7baabe)
...
...
@@ -151,6 +151,39 @@ class TestElementwiseDivOp_broadcast_5(ElementwiseDivOp):
        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseDivOp_commonuse_1(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 4]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [1, 1, 4]).astype("float32"),
        }
        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseDivOp_commonuse_2(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [2, 3, 1, 5]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [2, 1, 4, 1]).astype("float32"),
        }
        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseDivOp_xsize_lessthan_ysize(ElementwiseDivOp):
    def setUp(self):
        self.op_type = "elementwise_div"
        self.inputs = {
            'X': np.random.uniform(0.1, 1, [4, 5]).astype("float32"),
            'Y': np.random.uniform(0.1, 1, [2, 3, 4, 5]).astype("float32"),
        }
        self.attrs = {'axis': 2}
        self.outputs = {'Out': np.divide(self.inputs['X'], self.inputs['Y'])}


class TestElementwiseDivOp_INT(OpTest):
    def setUp(self):
        self.op_type = "elementwise_div"
...
...
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py (view file @ 0e7baabe)
...
...
@@ -162,6 +162,41 @@ class TestElementwiseMulOpFp16(ElementwiseMulOp):
        self.dtype = np.float16


class TestElementwiseMulOp_commonuse_1(ElementwiseMulOp):
    def setUp(self):
        self.op_type = "elementwise_mul"
        self.inputs = {
            'X': np.random.rand(2, 3, 4).astype(np.float64),
            'Y': np.random.rand(1, 1, 4).astype(np.float64)
        }
        self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}


class TestElementwiseMulOp_commonuse_2(ElementwiseMulOp):
    def setUp(self):
        self.op_type = "elementwise_mul"
        self.inputs = {
            'X': np.random.rand(2, 3, 1, 5).astype(np.float64),
            'Y': np.random.rand(2, 1, 4, 1).astype(np.float64)
        }
        self.outputs = {'Out': self.inputs['X'] * self.inputs['Y']}


class TestElementwiseMulOp_xsize_lessthan_ysize(ElementwiseMulOp):
    def setUp(self):
        self.op_type = "elementwise_mul"
        self.inputs = {
            'X': np.random.rand(4, 5).astype(np.float64),
            'Y': np.random.rand(2, 3, 4, 5).astype(np.float64)
        }
        self.attrs = {'axis': 2}
        self.outputs = {
            'Out': self.inputs['X'].reshape(1, 1, 4, 5) * self.inputs['Y']
        }


class TestElementwiseMulOpError(OpTest):
    def test_errors(self):
        with program_guard(Program(), Program()):
...
...
python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py (view file @ 0e7baabe)
...
...
@@ -127,5 +127,40 @@ class TestElementwiseSubOp_broadcast_4(TestElementwiseOp):
        self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}


class TestElementwiseSubOp_commonuse_1(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_sub"
        self.inputs = {
            'X': np.random.rand(2, 3, 4).astype(np.float32),
            'Y': np.random.rand(1, 1, 4).astype(np.float32)
        }
        self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}


class TestElementwiseSubOp_commonuse_2(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_sub"
        self.inputs = {
            'X': np.random.rand(2, 3, 1, 5).astype(np.float32),
            'Y': np.random.rand(2, 1, 4, 1).astype(np.float32)
        }
        self.outputs = {'Out': self.inputs['X'] - self.inputs['Y']}


class TestElementwiseSubOp_xsize_lessthan_ysize(TestElementwiseOp):
    def setUp(self):
        self.op_type = "elementwise_sub"
        self.inputs = {
            'X': np.random.rand(4, 5).astype(np.float32),
            'Y': np.random.rand(2, 3, 4, 5).astype(np.float32)
        }
        self.attrs = {'axis': 2}
        self.outputs = {
            'Out': self.inputs['X'].reshape(1, 1, 4, 5) - self.inputs['Y']
        }


if __name__ == '__main__':
    unittest.main()
python/paddle/fluid/tests/unittests/test_executor_return_tensor_not_overwriting.py (view file @ 0e7baabe)
...
...
@@ -47,7 +47,7 @@ class TestExecutorReturnTensorNotOverwritingWithOptest(OpTest):
            'Y': OpTest.np_dtype_to_fluid_dtype(self.y)
        }
        self.outputs = {'Out': self.out}
        self.op_type = "mul"
        self.dtype = np.float32
        outs, fetch_list = self._calc_output(place, parallel=parallel)
        return outs
...
...
python/paddle/fluid/tests/unittests/test_fl_listen_and_serv_op.py (view file @ 0e7baabe)
...
...
@@ -57,7 +57,7 @@ def run_trainer(use_cuda, sync_mode, ip, port, trainers, trainer_id):
    exe.run(trainer_startup_program)
    for i in range(5):
        exe.run(recv_program)
        exe.run(fluid.default_main_program(),
                feed={
                    "x": numpy.array([1, 2]).astype('float32').reshape(2, 1),
                    "y": numpy.array([2, 3]).astype('float32').reshape(2, 1)
...
...