Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
8672e153
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
提交
8672e153
编写于
9月 04, 2019
作者:
D
danleifeng
提交者:
gongweibao
9月 04, 2019
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
elementwise broadcast function enhancement (#19536)
elementwise broadcast function enhancement
上级
a50785b0
变更
12
隐藏空白更改
内联
并排
Showing
12 changed file
with
293 addition
and
32 deletion
+293
-32
paddle/fluid/operators/elementwise/elementwise_op_function.h
paddle/fluid/operators/elementwise/elementwise_op_function.h
+172
-23
python/paddle/fluid/tests/unittests/ngraph/test_elementwise_max_ngraph_op.py
.../tests/unittests/ngraph/test_elementwise_max_ngraph_op.py
+1
-1
python/paddle/fluid/tests/unittests/ngraph/test_elementwise_min_ngraph_op.py
.../tests/unittests/ngraph/test_elementwise_min_ngraph_op.py
+1
-1
python/paddle/fluid/tests/unittests/ngraph/test_elementwise_pow_ngraph_op.py
.../tests/unittests/ngraph/test_elementwise_pow_ngraph_op.py
+1
-1
python/paddle/fluid/tests/unittests/ngraph/test_elementwise_sub_ngraph_op.py
.../tests/unittests/ngraph/test_elementwise_sub_ngraph_op.py
+1
-1
python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_add_op.py
+28
-0
python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_div_op.py
+20
-0
python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_max_op.py
+12
-0
python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_min_op.py
+17
-5
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
+20
-0
python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
+10
-0
python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
...n/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
+10
-0
未找到文件。
paddle/fluid/operators/elementwise/elementwise_op_function.h
浏览文件 @
8672e153
...
...
@@ -47,25 +47,65 @@ namespace operators {
* 2. shape(X) = (2, 3, 4, 5), shape(Y) = (4,5)
* pre=2*3, n=4*5, post=1
* x.shape(6, 20, 1) * y.shape(1, 20, 1).broadcast(6, 20, 1)
*
* New parameter: *mid_flag* is added to solve m*n*k & m*1*k
* broadcast cases.
* 3. shape(X) = (2, 3, 4, 5), shape(Y) = (2, 1, 4, 5)
* mid_flag should not be NULL.
* x.shape(2, 3, 20) * y.shape(2, 1, 20).broadcast(2, 3, 20)
*/
inline
void
get_mid_dims
(
const
framework
::
DDim
&
x_dims
,
const
framework
::
DDim
&
y_dims
,
const
int
axis
,
int
*
pre
,
int
*
n
,
int
*
post
)
{
int
*
pre
,
int
*
n
,
int
*
post
,
int
*
mid_flag
=
NULL
)
{
*
pre
=
1
;
*
n
=
1
;
*
post
=
1
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
(
*
pre
)
*=
x_dims
[
i
];
}
if
(
mid_flag
!=
NULL
)
{
*
mid_flag
=
0
;
int
mid
=
0
;
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
(
*
pre
)
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
if
(
x_dims
[
i
+
axis
]
!=
y_dims
[
i
])
{
// only support single y_dims[i] = 1 now.
PADDLE_ENFORCE_EQ
(
*
mid_flag
,
0
,
"Broadcast support y_dims with single 1."
);
PADDLE_ENFORCE_EQ
(
y_dims
[
i
],
1
,
"Broadcast dimension mismatch."
);
// m*n*k m*1*k
for
(
int
j
=
0
;
j
<
i
;
++
j
)
{
(
*
pre
)
*=
y_dims
[
j
];
}
*
n
=
std
::
max
(
x_dims
[
i
+
axis
],
y_dims
[
i
]);
*
mid_flag
=
1
;
mid
=
i
;
break
;
}
(
*
n
)
*=
y_dims
[
i
];
}
if
(
*
mid_flag
)
{
for
(
int
i
=
mid
+
1
;
i
<
x_dims
.
size
();
++
i
)
{
(
*
post
)
*=
x_dims
[
i
];
}
}
else
{
for
(
int
i
=
axis
+
y_dims
.
size
();
i
<
x_dims
.
size
();
++
i
)
{
(
*
post
)
*=
x_dims
[
i
];
}
}
}
else
{
// for fused_elementwise_activation_op. keep the old version.
for
(
int
i
=
0
;
i
<
axis
;
++
i
)
{
(
*
pre
)
*=
x_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
i
+
axis
],
y_dims
[
i
],
"Broadcast dimension mismatch."
);
(
*
n
)
*=
y_dims
[
i
];
}
for
(
int
i
=
0
;
i
<
y_dims
.
size
();
++
i
)
{
PADDLE_ENFORCE_EQ
(
x_dims
[
i
+
axis
],
y_dims
[
i
],
"Broadcast dimension mismatch."
);
(
*
n
)
*=
y_dims
[
i
];
}
for
(
int
i
=
axis
+
y_dims
.
size
();
i
<
x_dims
.
size
();
++
i
)
{
(
*
post
)
*=
x_dims
[
i
];
for
(
int
i
=
axis
+
y_dims
.
size
();
i
<
x_dims
.
size
();
++
i
)
{
(
*
post
)
*=
x_dims
[
i
];
}
}
}
...
...
@@ -171,7 +211,6 @@ class MidWiseTransformIterator<T, platform::CPUDeviceContext>
}
}
}
return
*
this
;
}
...
...
@@ -268,6 +307,15 @@ class TransformFunctor {
MidWiseTransformIterator
<
T
,
DeviceContext
>
(
y_
,
n
,
post
),
z_
,
func_
);
}
inline
void
RunMidRowWise
(
int
n
,
int
pre
,
int
post
)
const
{
platform
::
Transform
<
DeviceContext
>
trans
;
for
(
int
i
=
0
;
i
<
pre
;
i
++
)
{
trans
(
ctx_
,
x_
+
i
*
n
*
post
,
x_
+
(
i
+
1
)
*
n
*
post
,
RowwiseTransformIterator
<
T
,
DeviceContext
>
(
y_
+
i
*
post
,
post
),
z_
+
i
*
n
*
post
,
func_
);
}
}
private:
const
T
*
x_
;
const
T
*
y_
;
...
...
@@ -501,6 +549,88 @@ static void ElemwiseGradBroadcast2CUDA(cudaStream_t stream, const T *x,
#endif
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcastMid2CPU
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
for
(
int
i
=
0
;
i
<
pre
;
++
i
)
{
for
(
int
j
=
0
;
j
<
n
;
++
j
)
{
for
(
int
k
=
0
;
k
<
post
;
++
k
)
{
int
x_offset
=
i
*
n
*
post
+
j
*
post
+
k
;
int
y_offset
=
i
*
post
+
k
;
if
(
dx
!=
nullptr
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
y_offset
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
!=
nullptr
)
{
T
tmp
=
dy_op
(
x
[
x_offset
],
y
[
y_offset
],
out
[
x_offset
],
dout
[
x_offset
]);
if
(
j
==
0
)
{
dy
[
y_offset
]
=
tmp
;
}
else
{
dy
[
y_offset
]
+=
tmp
;
}
}
}
}
}
}
#ifdef __NVCC__
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
__global__
void
ElemwiseGradBroadcastMid2CUDAKernel
(
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
j
=
threadIdx
.
x
;
int
tid
=
blockIdx
.
x
;
T
val
(
0
);
int
ttid
=
tid
;
while
(
true
)
{
int
i
=
ttid
/
post
;
int
k
=
ttid
%
post
;
if
(
i
>=
pre
)
break
;
int
x_offset
=
i
*
n
*
post
+
j
*
post
+
k
;
int
y_offset
=
i
*
post
+
k
;
if
(
dx
!=
nullptr
)
{
dx
[
x_offset
]
=
dx_op
(
x
[
x_offset
],
y
[
y_offset
],
out
[
x_offset
],
dout
[
x_offset
]);
}
if
(
dy
!=
nullptr
)
{
val
+=
dy_op
(
x
[
x_offset
],
y
[
y_offset
],
out
[
x_offset
],
dout
[
x_offset
]);
}
ttid
+=
ELEMWISE_MAX_BLOCK_DIM
;
}
if
(
dy
)
{
int
h
=
n
;
h
=
h
>
ELEMWISE_MAX_BLOCK_DIM
?
ELEMWISE_MAX_BLOCK_DIM
:
h
;
val
=
paddle
::
platform
::
reduceSum
(
val
,
j
,
h
);
if
(
threadIdx
.
x
==
0
)
{
dy
[
tid
]
=
val
;
}
}
}
template
<
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
static
void
ElemwiseGradBroadcastMid2CUDA
(
cudaStream_t
stream
,
const
T
*
x
,
const
T
*
y
,
const
T
*
out
,
const
T
*
dout
,
int
pre
,
int
n
,
int
post
,
DX_OP
dx_op
,
DY_OP
dy_op
,
T
*
dx
,
T
*
dy
)
{
int
block_size
=
std
::
min
(
ELEMWISE_MAX_BLOCK_DIM
,
n
);
int
gird_size
=
pre
*
post
;
ElemwiseGradBroadcastMid2CUDAKernel
<<<
gird_size
,
block_size
,
0
,
stream
>>>
(
x
,
y
,
out
,
dout
,
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
,
dy
);
}
#endif
template
<
typename
DeviceContext
,
typename
T
,
typename
DX_OP
,
typename
DY_OP
>
void
ElemwiseGradComputeNoBroadcast
(
const
framework
::
ExecutionContext
&
ctx
,
const
framework
::
DDim
&
x_dim
,
...
...
@@ -533,23 +663,39 @@ void ElemwiseGradComputeWithBroadcast(
auto
y_dim
=
trim_trailing_singular_dims
(
y_dim_untrimed
);
axis
=
(
y_dim
.
size
()
==
0
)
?
x_dim
.
size
()
:
axis
;
int
pre
,
n
,
post
;
get_mid_dims
(
x_dim
,
y_dim
,
axis
,
&
pre
,
&
n
,
&
post
);
if
(
post
==
1
)
{
int
h
=
pre
;
int
w
=
n
;
int
pre
,
n
,
post
,
mid_flag
=
0
;
get_mid_dims
(
x_dim
,
y_dim
,
axis
,
&
pre
,
&
n
,
&
post
,
&
mid_flag
);
if
(
mid_flag
)
{
PADDLE_ENFORCE_EQ
(
mid_flag
,
1
,
"mid_flag should be no more than 1."
);
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef __NVCC__
ElemwiseGradBroadcastMid2CUDA
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
(),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
#endif
}
else
{
ElemwiseGradBroadcastMid2CPU
(
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
post
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
else
if
(
post
==
1
)
{
if
(
platform
::
is_gpu_place
(
ctx
.
GetPlace
()))
{
#ifdef __NVCC__
ElemwiseGradBroadcast1CUDA
(
ctx
.
template
device_context
<
DeviceContext
>().
stream
(),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
h
,
w
,
dx_op
,
dy_op
,
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
#endif
}
else
{
ElemwiseGradBroadcast1CPU
(
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
h
,
w
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
x
.
data
<
T
>
(),
y
.
data
<
T
>
(),
out
.
data
<
T
>
(),
dout
.
data
<
T
>
(),
pre
,
n
,
dx_op
,
dy_op
,
dx
==
nullptr
?
nullptr
:
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()),
dy
==
nullptr
?
nullptr
:
dy
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
}
}
else
{
...
...
@@ -689,9 +835,12 @@ void ElementwiseComputeEx(const framework::ExecutionContext &ctx,
"Axis should be in range [0, x_dims)"
);
auto
y_dims
=
trim_trailing_singular_dims
(
y_dims_untrimed
);
axis
=
(
y_dims
.
size
()
==
0
)
?
x_dims
.
size
()
:
axis
;
int
pre
,
n
,
post
;
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
);
int
pre
,
n
,
post
,
mid_flag
=
0
;
get_mid_dims
(
x_dims
,
y_dims
,
axis
,
&
pre
,
&
n
,
&
post
,
&
mid_flag
);
if
(
mid_flag
)
{
functor
.
RunMidRowWise
(
n
,
pre
,
post
);
return
;
}
if
(
post
==
1
)
{
functor
.
RunRowWise
(
n
,
pre
);
return
;
...
...
python/paddle/fluid/tests/unittests/ngraph/test_elementwise_max_ngraph_op.py
浏览文件 @
8672e153
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
unittest
,
sys
sys
.
path
.
append
(
"../"
)
from
test_elementwise_max_op
import
*
from
test_elementwise_max_op
import
TestElementwiseMaxOp_scalar
,
TestElementwiseMaxOp_Vector
,
TestElementwiseMaxOp_broadcast_0
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/ngraph/test_elementwise_min_ngraph_op.py
浏览文件 @
8672e153
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
unittest
,
sys
sys
.
path
.
append
(
"../"
)
from
test_elementwise_min_op
import
*
from
test_elementwise_min_op
import
TestElementwiseMinOp_scalar
,
TestElementwiseMinOp_Vector
,
TestElementwiseMinOp_broadcast_0
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/ngraph/test_elementwise_pow_ngraph_op.py
浏览文件 @
8672e153
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
unittest
,
sys
sys
.
path
.
append
(
"../"
)
from
test_elementwise_pow_op
import
*
from
test_elementwise_pow_op
import
TestElementwisePowOp_scalar
,
TestElementwisePowOp_tensor
,
TestElementwisePowOp_broadcast_0
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/ngraph/test_elementwise_sub_ngraph_op.py
浏览文件 @
8672e153
...
...
@@ -16,7 +16,7 @@ from __future__ import print_function
import
unittest
,
sys
sys
.
path
.
append
(
"../"
)
from
test_elementwise_sub_op
import
*
from
test_elementwise_sub_op
import
TestElementwiseSubOp_scalar
,
TestElementwiseSubOp_Vector
,
TestElementwiseSubOp_broadcast_0
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_elementwise_add_op.py
浏览文件 @
8672e153
...
...
@@ -218,6 +218,34 @@ class TestFP16ElementwiseAddOp_broadcast_4(TestFP16ElementwiseAddOp):
self
.
axis
=
0
class
TestElementwiseAddOp_broadcast_5
(
TestElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
rand
(
2
,
3
,
4
).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
rand
(
2
,
1
,
4
).
astype
(
self
.
dtype
)
self
.
out
=
self
.
x
+
self
.
y
class
TestFP16ElementwiseAddOp_broadcast_5
(
TestFP16ElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
rand
(
2
,
3
,
4
).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
rand
(
2
,
1
,
4
).
astype
(
self
.
dtype
)
self
.
out
=
self
.
x
+
self
.
y
class
TestElementwiseAddOp_broadcast_6
(
TestElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
rand
(
2
,
3
,
4
,
5
).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
rand
(
2
,
3
,
1
,
5
).
astype
(
self
.
dtype
)
self
.
out
=
self
.
x
+
self
.
y
class
TestFP16ElementwiseAddOp_broadcast_6
(
TestFP16ElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
rand
(
2
,
3
,
4
,
5
).
astype
(
self
.
dtype
)
self
.
y
=
np
.
random
.
rand
(
2
,
3
,
1
,
5
).
astype
(
self
.
dtype
)
self
.
out
=
self
.
x
+
self
.
y
class
TestElementwiseAddOp_rowwise_add_0
(
TestElementwiseAddOp
):
def
init_input_output
(
self
):
self
.
x
=
np
.
random
.
rand
(
2
,
3
,
4
).
astype
(
self
.
dtype
)
...
...
python/paddle/fluid/tests/unittests/test_elementwise_div_op.py
浏览文件 @
8672e153
...
...
@@ -131,6 +131,26 @@ class TestElementwiseDivOp_broadcast_3(ElementwiseDivOp):
}
class
TestElementwiseDivOp_broadcast_4
(
ElementwiseDivOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_div"
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
3
,
4
]).
astype
(
"float32"
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
1
,
4
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'Out'
:
np
.
divide
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwiseDivOp_broadcast_5
(
ElementwiseDivOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_div"
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
3
,
4
,
5
]).
astype
(
"float32"
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
3
,
1
,
5
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'Out'
:
np
.
divide
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwiseDivOpFp16
(
ElementwiseDivOp
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
...
...
python/paddle/fluid/tests/unittests/test_elementwise_max_op.py
浏览文件 @
8672e153
...
...
@@ -128,5 +128,17 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
}
class
TestElementwiseMaxOp_broadcast_4
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_max"
x
=
np
.
random
.
uniform
(
0.5
,
1
,
(
2
,
3
,
4
,
5
)).
astype
(
np
.
float32
)
sgn
=
np
.
random
.
choice
([
-
1
,
1
],
(
2
,
3
,
1
,
5
)).
astype
(
np
.
float32
)
y
=
x
+
sgn
*
\
np
.
random
.
uniform
(
1
,
2
,
(
2
,
3
,
1
,
5
)).
astype
(
np
.
float32
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
np
.
maximum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_elementwise_min_op.py
浏览文件 @
8672e153
...
...
@@ -55,7 +55,7 @@ class TestElementwiseMinOp_scalar(TestElementwiseOp):
self
.
outputs
=
{
'Out'
:
np
.
minimum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwiseM
ax
Op_Vector
(
TestElementwiseOp
):
class
TestElementwiseM
in
Op_Vector
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
x
=
np
.
random
.
random
((
32
,
)).
astype
(
"float32"
)
...
...
@@ -65,7 +65,7 @@ class TestElementwiseMaxOp_Vector(TestElementwiseOp):
self
.
outputs
=
{
'Out'
:
np
.
minimum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
class
TestElementwiseM
ax
Op_broadcast_0
(
TestElementwiseOp
):
class
TestElementwiseM
in
Op_broadcast_0
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
x
=
np
.
random
.
uniform
(
0.5
,
1
,
(
2
,
3
,
4
)).
astype
(
np
.
float32
)
...
...
@@ -81,7 +81,7 @@ class TestElementwiseMaxOp_broadcast_0(TestElementwiseOp):
}
class
TestElementwiseM
ax
Op_broadcast_1
(
TestElementwiseOp
):
class
TestElementwiseM
in
Op_broadcast_1
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
x
=
np
.
random
.
uniform
(
0.5
,
1
,
(
2
,
3
,
4
)).
astype
(
np
.
float32
)
...
...
@@ -97,7 +97,7 @@ class TestElementwiseMaxOp_broadcast_1(TestElementwiseOp):
}
class
TestElementwiseM
ax
Op_broadcast_2
(
TestElementwiseOp
):
class
TestElementwiseM
in
Op_broadcast_2
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
x
=
np
.
random
.
uniform
(
0.5
,
1
,
(
2
,
3
,
4
)).
astype
(
np
.
float32
)
...
...
@@ -112,7 +112,7 @@ class TestElementwiseMaxOp_broadcast_2(TestElementwiseOp):
}
class
TestElementwiseM
ax
Op_broadcast_3
(
TestElementwiseOp
):
class
TestElementwiseM
in
Op_broadcast_3
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
x
=
np
.
random
.
uniform
(
0.5
,
1
,
(
2
,
3
,
4
,
5
)).
astype
(
np
.
float32
)
...
...
@@ -128,5 +128,17 @@ class TestElementwiseMaxOp_broadcast_3(TestElementwiseOp):
}
class
TestElementwiseMinOp_broadcast_4
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_min"
x
=
np
.
random
.
uniform
(
0.5
,
1
,
(
2
,
3
,
4
,
5
)).
astype
(
np
.
float32
)
sgn
=
np
.
random
.
choice
([
-
1
,
1
],
(
2
,
3
,
1
,
5
)).
astype
(
np
.
float32
)
y
=
x
+
sgn
*
\
np
.
random
.
uniform
(
1
,
2
,
(
2
,
3
,
1
,
5
)).
astype
(
np
.
float32
)
self
.
inputs
=
{
'X'
:
x
,
'Y'
:
y
}
self
.
outputs
=
{
'Out'
:
np
.
minimum
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_elementwise_mul_op.py
浏览文件 @
8672e153
...
...
@@ -135,6 +135,26 @@ class TestElementwiseMulOp_broadcast_3(ElementwiseMulOp):
}
class
TestElementwiseMulOp_broadcast_4
(
ElementwiseMulOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_mul"
self
.
inputs
=
{
'X'
:
np
.
random
.
rand
(
2
,
3
,
4
).
astype
(
np
.
float64
),
'Y'
:
np
.
random
.
rand
(
2
,
1
,
4
).
astype
(
np
.
float64
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
*
self
.
inputs
[
'Y'
]}
class
TestElementwiseMulOp_broadcast_5
(
ElementwiseMulOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_mul"
self
.
inputs
=
{
'X'
:
np
.
random
.
rand
(
2
,
3
,
4
,
5
).
astype
(
np
.
float64
),
'Y'
:
np
.
random
.
rand
(
2
,
3
,
1
,
5
).
astype
(
np
.
float64
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
*
self
.
inputs
[
'Y'
]}
class
TestElementwiseMulOpFp16
(
ElementwiseMulOp
):
def
init_dtype
(
self
):
self
.
dtype
=
np
.
float16
...
...
python/paddle/fluid/tests/unittests/test_elementwise_pow_op.py
浏览文件 @
8672e153
...
...
@@ -104,5 +104,15 @@ class TestElementwisePowOp_broadcast_3(TestElementwisePowOp):
}
class
TestElementwisePowOp_broadcast_4
(
TestElementwisePowOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_pow"
self
.
inputs
=
{
'X'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
3
,
4
,
5
]).
astype
(
"float32"
),
'Y'
:
np
.
random
.
uniform
(
0.1
,
1
,
[
2
,
3
,
1
,
5
]).
astype
(
"float32"
)
}
self
.
outputs
=
{
'Out'
:
np
.
power
(
self
.
inputs
[
'X'
],
self
.
inputs
[
'Y'
])}
if
__name__
==
'__main__'
:
unittest
.
main
()
python/paddle/fluid/tests/unittests/test_elementwise_sub_op.py
浏览文件 @
8672e153
...
...
@@ -117,5 +117,15 @@ class TestElementwiseSubOp_broadcast_3(TestElementwiseOp):
}
class
TestElementwiseSubOp_broadcast_4
(
TestElementwiseOp
):
def
setUp
(
self
):
self
.
op_type
=
"elementwise_sub"
self
.
inputs
=
{
'X'
:
np
.
random
.
rand
(
2
,
3
,
4
,
5
).
astype
(
np
.
float32
),
'Y'
:
np
.
random
.
rand
(
2
,
3
,
1
,
5
).
astype
(
np
.
float32
)
}
self
.
outputs
=
{
'Out'
:
self
.
inputs
[
'X'
]
-
self
.
inputs
[
'Y'
]}
if
__name__
==
'__main__'
:
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录