Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2540b023
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2540b023
编写于
6月 17, 2022
作者:
F
fuyou765
提交者:
GitHub
6月 17, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MLU]add mlu kernel for where op (#43441)
上级
539a9e60
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
474 addition
and
12 deletion
+474
-12
paddle/fluid/operators/mlu/mlu_baseop.cc
paddle/fluid/operators/mlu/mlu_baseop.cc
+17
-7
paddle/fluid/operators/mlu/mlu_baseop.h
paddle/fluid/operators/mlu/mlu_baseop.h
+6
-5
paddle/fluid/operators/where_op_mlu.cc
paddle/fluid/operators/where_op_mlu.cc
+51
-0
python/paddle/fluid/tests/unittests/mlu/test_where_op_mlu.py
python/paddle/fluid/tests/unittests/mlu/test_where_op_mlu.py
+400
-0
未找到文件。
paddle/fluid/operators/mlu/mlu_baseop.cc
浏览文件 @
2540b023
...
...
@@ -1160,15 +1160,25 @@ MLUCnnlTrigonDesc::~MLUCnnlTrigonDesc() {
}
/* static */
void
MLUCnnl
::
Select
(
const
ExecutionContext
&
ctx
,
const
cnnlTensorDescriptor_t
then_desc
,
const
void
*
p_then
,
const
cnnlTensorDescriptor_t
else_desc
,
const
void
*
p_else
,
const
cnnlTensorDescriptor_t
output_desc
,
void
*
output
,
const
bool
*
condition
,
const
int
condition_size
)
{
const
ExecutionContext
&
ctx
,
const
cnnlTensorDescriptor_t
condition_desc
,
const
void
*
condition_ptr
,
const
cnnlTensorDescriptor_t
then_desc
,
const
void
*
then_ptr
,
const
cnnlTensorDescriptor_t
else_desc
,
const
void
*
else_ptr
,
const
cnnlTensorDescriptor_t
output_desc
,
void
*
output_ptr
)
{
cnnlHandle_t
handle
=
GetHandleFromCTX
(
ctx
);
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlSelect
(
handle
,
then_desc
,
p_then
,
else_desc
,
p_else
,
output_desc
,
output
,
condition
,
condition_size
));
size_t
workspace_size
=
0
;
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlGetSelectV2WorkspaceSize
(
handle
,
condition_desc
,
then_desc
,
else_desc
,
&
workspace_size
));
auto
&
dev_ctx
=
GetDevCtxFromCTX
(
ctx
);
Tensor
workspace
=
ctx
.
AllocateTmpTensor
<
int8_t
,
MLUDeviceContext
>
(
{
static_cast
<
int64_t
>
(
workspace_size
)},
dev_ctx
);
void
*
workspace_ptr
=
workspace
.
mutable_data
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlSelectV2
(
handle
,
condition_desc
,
condition_ptr
,
then_desc
,
then_ptr
,
else_desc
,
else_ptr
,
workspace_ptr
,
workspace_size
,
output_desc
,
output_ptr
));
}
/*static */
void
MLUCnnl
::
GatherNd
(
const
ExecutionContext
&
ctx
,
...
...
paddle/fluid/operators/mlu/mlu_baseop.h
浏览文件 @
2540b023
...
...
@@ -684,11 +684,12 @@ class MLUCnnl {
const
void
*
input2
,
const
cnnlTensorDescriptor_t
ouput_desc
,
void
*
output
);
static
void
Select
(
const
ExecutionContext
&
ctx
,
const
cnnlTensorDescriptor_t
then_desc
,
const
void
*
p_then
,
const
cnnlTensorDescriptor_t
else_desc
,
const
void
*
p_else
,
const
cnnlTensorDescriptor_t
output_desc
,
void
*
output
,
const
bool
*
condition
,
const
int
condition_size
);
static
void
Select
(
const
ExecutionContext
&
ctx
,
const
cnnlTensorDescriptor_t
condition_desc
,
const
void
*
condition_ptr
,
const
cnnlTensorDescriptor_t
then_desc
,
const
void
*
then_ptr
,
const
cnnlTensorDescriptor_t
else_desc
,
const
void
*
else_ptr
,
const
cnnlTensorDescriptor_t
output_desc
,
void
*
output_ptr
);
static
void
AssignAdd
(
const
ExecutionContext
&
ctx
,
const
void
*
alpha
,
const
void
*
beta
,
...
...
paddle/fluid/operators/where_op_mlu.cc
0 → 100644
浏览文件 @
2540b023
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
WhereMLUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
auto
*
condition
=
context
.
Input
<
framework
::
Tensor
>
(
"Condition"
);
auto
*
X
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
Y
=
context
.
Input
<
framework
::
Tensor
>
(
"Y"
);
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
place
=
context
.
GetPlace
();
out
->
mutable_data
<
T
>
(
place
);
MLUCnnlTensorDesc
x_desc
(
*
X
);
MLUCnnlTensorDesc
y_desc
(
*
Y
);
MLUCnnlTensorDesc
condition_desc
(
*
condition
);
MLUCnnlTensorDesc
out_desc
(
*
out
);
MLUCnnl
::
Select
(
context
,
condition_desc
.
get
(),
GetBasePtr
(
condition
),
x_desc
.
get
(),
GetBasePtr
(
X
),
y_desc
.
get
(),
GetBasePtr
(
Y
),
out_desc
.
get
(),
GetBasePtr
(
out
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_MLU_KERNEL
(
where
,
ops
::
WhereMLUKernel
<
paddle
::
platform
::
MLUDeviceContext
,
float
>
,
ops
::
WhereMLUKernel
<
paddle
::
platform
::
MLUDeviceContext
,
int
>
);
#endif
python/paddle/fluid/tests/unittests/mlu/test_where_op_mlu.py
0 → 100644
浏览文件 @
2540b023
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from
__future__
import
print_function
import
sys
sys
.
path
.
append
(
".."
)
import
unittest
import
numpy
as
np
import
paddle
import
paddle.fluid
as
fluid
import
paddle.fluid.layers
as
layers
import
paddle.fluid.core
as
core
from
op_test
import
OpTest
from
paddle.fluid
import
compiler
,
Program
,
program_guard
from
paddle.fluid.op
import
Operator
from
paddle.fluid.backward
import
append_backward
from
paddle.fluid.framework
import
_test_eager_guard
class
TestWhereOp
(
OpTest
):
def
setUp
(
self
):
self
.
op_type
=
'where'
self
.
place
=
paddle
.
device
.
MLUPlace
(
0
)
self
.
__class__
.
use_mlu
=
True
self
.
__class__
.
no_need_check_grad
=
True
self
.
python_api
=
paddle
.
where
self
.
init_config
()
self
.
inputs
=
{
'Condition'
:
self
.
cond
,
'X'
:
self
.
x
,
'Y'
:
self
.
y
}
self
.
outputs
=
{
'Out'
:
np
.
where
(
self
.
cond
,
self
.
x
,
self
.
y
)}
def
test_check_output
(
self
):
self
.
check_output_with_place
(
self
.
place
,
check_eager
=
False
)
def
test_check_grad
(
self
):
self
.
check_grad
([
'X'
,
'Y'
],
'Out'
,
check_eager
=
False
)
def
init_config
(
self
):
self
.
x
=
np
.
random
.
uniform
((
-
3
),
5
,
100
).
astype
(
'float32'
)
self
.
y
=
np
.
random
.
uniform
((
-
3
),
5
,
100
).
astype
(
'float32'
)
self
.
cond
=
np
.
zeros
(
100
).
astype
(
'bool'
)
class
TestWhereOp2
(
TestWhereOp
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
uniform
((
-
5
),
5
,
(
60
,
2
)).
astype
(
'float32'
)
self
.
y
=
np
.
random
.
uniform
((
-
5
),
5
,
(
60
,
2
)).
astype
(
'float32'
)
self
.
cond
=
np
.
ones
((
60
,
2
)).
astype
(
'bool'
)
class
TestWhereOp3
(
TestWhereOp
):
def
init_config
(
self
):
self
.
x
=
np
.
random
.
uniform
((
-
3
),
5
,
(
20
,
2
,
4
)).
astype
(
'float32'
)
self
.
y
=
np
.
random
.
uniform
((
-
3
),
5
,
(
20
,
2
,
4
)).
astype
(
'float32'
)
self
.
cond
=
np
.
array
(
np
.
random
.
randint
(
2
,
size
=
(
20
,
2
,
4
)),
dtype
=
bool
)
class
TestWhereAPI
(
unittest
.
TestCase
):
def
setUp
(
self
):
self
.
place
=
paddle
.
device
.
MLUPlace
(
0
)
self
.
__class__
.
use_mlu
=
True
self
.
__class__
.
no_need_check_grad
=
True
self
.
init_data
()
def
init_data
(
self
):
self
.
shape
=
[
10
,
15
]
self
.
cond
=
np
.
array
(
np
.
random
.
randint
(
2
,
size
=
self
.
shape
),
dtype
=
bool
)
self
.
x
=
np
.
random
.
uniform
((
-
2
),
3
,
self
.
shape
).
astype
(
np
.
float32
)
self
.
y
=
np
.
random
.
uniform
((
-
2
),
3
,
self
.
shape
).
astype
(
np
.
float32
)
self
.
out
=
np
.
where
(
self
.
cond
,
self
.
x
,
self
.
y
)
def
ref_x_backward
(
self
,
dout
):
return
np
.
where
((
self
.
cond
==
True
),
dout
,
0
)
def
ref_y_backward
(
self
,
dout
):
return
np
.
where
((
self
.
cond
==
False
),
dout
,
0
)
def
test_api
(
self
,
use_mlu
=
False
):
for
x_stop_gradient
in
[
False
,
True
]:
for
y_stop_gradient
in
[
False
,
True
]:
with
fluid
.
program_guard
(
Program
(),
Program
()):
cond
=
fluid
.
layers
.
data
(
name
=
'cond'
,
shape
=
self
.
shape
,
dtype
=
'bool'
)
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
self
.
shape
,
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
self
.
shape
,
dtype
=
'float32'
)
x
.
stop_gradient
=
x_stop_gradient
y
.
stop_gradient
=
y_stop_gradient
result
=
paddle
.
where
(
cond
,
x
,
y
)
append_backward
(
layers
.
mean
(
result
))
for
use_mlu
in
[
False
,
True
]:
place
=
(
paddle
.
device
.
MLUPlace
(
0
)
if
use_mlu
else
fluid
.
CPUPlace
())
exe
=
fluid
.
Executor
(
place
)
fetch_list
=
[
result
,
result
.
grad_name
]
if
(
x_stop_gradient
is
False
):
fetch_list
.
append
(
x
.
grad_name
)
if
(
y_stop_gradient
is
False
):
fetch_list
.
append
(
y
.
grad_name
)
out
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
'cond'
:
self
.
cond
,
'x'
:
self
.
x
,
'y'
:
self
.
y
},
fetch_list
=
fetch_list
)
assert
np
.
array_equal
(
out
[
0
],
self
.
out
)
if
(
x_stop_gradient
is
False
):
assert
np
.
array_equal
(
out
[
2
],
self
.
ref_x_backward
(
out
[
1
]))
if
(
y
.
stop_gradient
is
False
):
assert
np
.
array_equal
(
out
[
3
],
self
.
ref_y_backward
(
out
[
1
]))
elif
(
y
.
stop_gradient
is
False
):
assert
np
.
array_equal
(
out
[
2
],
self
.
ref_y_backward
(
out
[
1
]))
def
test_api_broadcast
(
self
,
use_mlu
=
False
):
main_program
=
Program
()
with
fluid
.
program_guard
(
main_program
):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
4
,
1
],
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
4
,
2
],
dtype
=
'float32'
)
x_i
=
np
.
array
([[
0.9383
,
0.1983
,
3.2
,
1.2
]]).
astype
(
'float32'
)
y_i
=
np
.
array
([[
1.0
,
1.0
,
1.0
,
1.0
],
[
1.0
,
1.0
,
1.0
,
1.0
]]).
astype
(
'float32'
)
result
=
paddle
.
where
((
x
>
1
),
x
=
x
,
y
=
y
)
for
use_mlu
in
[
False
,
True
]:
place
=
(
paddle
.
device
.
MLUPlace
(
0
)
if
use_mlu
else
fluid
.
CPUPlace
())
exe
=
fluid
.
Executor
(
place
)
out
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
'x'
:
x_i
,
'y'
:
y_i
},
fetch_list
=
[
result
])
assert
np
.
array_equal
(
out
[
0
],
np
.
where
((
x_i
>
1
),
x_i
,
y_i
))
def
test_scalar
(
self
):
paddle
.
enable_static
()
main_program
=
Program
()
with
fluid
.
program_guard
(
main_program
):
cond_shape
=
[
2
,
4
]
cond
=
fluid
.
layers
.
data
(
name
=
'cond'
,
shape
=
cond_shape
,
dtype
=
'bool'
)
x_data
=
1.0
y_data
=
2.0
cond_data
=
np
.
array
([
False
,
False
,
True
,
True
]).
astype
(
'bool'
)
result
=
paddle
.
where
(
condition
=
cond
,
x
=
x_data
,
y
=
y_data
)
for
use_mlu
in
[
False
,
True
]:
place
=
(
paddle
.
device
.
MLUPlace
(
0
)
if
use_mlu
else
fluid
.
CPUPlace
())
exe
=
fluid
.
Executor
(
place
)
out
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
'cond'
:
cond_data
},
fetch_list
=
[
result
])
expect
=
np
.
where
(
cond_data
,
x_data
,
y_data
)
assert
np
.
array_equal
(
out
[
0
],
expect
)
def
__test_where_with_broadcast_static
(
self
,
cond_shape
,
x_shape
,
y_shape
):
paddle
.
enable_static
()
main_program
=
Program
()
with
fluid
.
program_guard
(
main_program
):
cond
=
fluid
.
layers
.
data
(
name
=
'cond'
,
shape
=
cond_shape
,
dtype
=
'bool'
)
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
x_shape
,
dtype
=
'float32'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
y_shape
,
dtype
=
'float32'
)
cond_data_tmp
=
np
.
random
.
random
(
size
=
cond_shape
).
astype
(
'float32'
)
cond_data
=
(
cond_data_tmp
<
0.3
)
x_data
=
np
.
random
.
random
(
size
=
x_shape
).
astype
(
'float32'
)
y_data
=
np
.
random
.
random
(
size
=
y_shape
).
astype
(
'float32'
)
result
=
paddle
.
where
(
condition
=
cond
,
x
=
x
,
y
=
y
)
for
use_mlu
in
[
False
,
True
]:
place
=
(
paddle
.
device
.
MLUPlace
(
0
)
if
use_mlu
else
fluid
.
CPUPlace
())
exe
=
fluid
.
Executor
(
place
)
out
=
exe
.
run
(
fluid
.
default_main_program
(),
feed
=
{
'cond'
:
cond_data
,
'x'
:
x_data
,
'y'
:
y_data
},
fetch_list
=
[
result
])
expect
=
np
.
where
(
cond_data
,
x_data
,
y_data
)
assert
np
.
array_equal
(
out
[
0
],
expect
)
def
test_static_api_broadcast_1
(
self
):
cond_shape
=
[
2
,
4
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_static
(
cond_shape
,
a_shape
,
b_shape
)
def
test_static_api_broadcast_2
(
self
):
cond_shape
=
[
2
,
1
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_static
(
cond_shape
,
a_shape
,
b_shape
)
def
test_static_api_broadcast_3
(
self
):
cond_shape
=
[
2
,
2
,
1
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_static
(
cond_shape
,
a_shape
,
b_shape
)
def
test_static_api_broadcast_4
(
self
):
cond_shape
=
[
2
,
1
,
4
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_static
(
cond_shape
,
a_shape
,
b_shape
)
def
test_static_api_broadcast_5
(
self
):
cond_shape
=
[
3
,
2
,
2
,
4
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_static
(
cond_shape
,
a_shape
,
b_shape
)
def
test_static_api_broadcast_6
(
self
):
cond_shape
=
[
2
,
2
,
4
]
a_shape
=
[
2
,
2
,
1
]
b_shape
=
[
2
,
2
,
1
]
self
.
__test_where_with_broadcast_static
(
cond_shape
,
a_shape
,
b_shape
)
def
test_static_api_broadcast_7
(
self
):
cond_shape
=
[
2
,
2
,
4
]
a_shape
=
[
2
,
1
,
4
]
b_shape
=
[
2
,
1
,
4
]
self
.
__test_where_with_broadcast_static
(
cond_shape
,
a_shape
,
b_shape
)
def
test_static_api_broadcast_8
(
self
):
cond_shape
=
[
3
,
2
,
2
,
4
]
a_shape
=
[
2
,
2
,
1
]
b_shape
=
[
2
,
2
,
1
]
self
.
__test_where_with_broadcast_static
(
cond_shape
,
a_shape
,
b_shape
)
class
TestWhereDygraphAPI
(
unittest
.
TestCase
):
def
test_api
(
self
):
with
fluid
.
dygraph
.
guard
():
x_i
=
np
.
array
([
0.9383
,
0.1983
,
3.2
,
1.2
]).
astype
(
'float32'
)
y_i
=
np
.
array
([
1.0
,
1.0
,
1.0
,
1.0
]).
astype
(
'float32'
)
cond_i
=
np
.
array
([
False
,
False
,
True
,
True
]).
astype
(
'bool'
)
x
=
fluid
.
dygraph
.
to_variable
(
x_i
)
y
=
fluid
.
dygraph
.
to_variable
(
y_i
)
cond
=
fluid
.
dygraph
.
to_variable
(
cond_i
)
out
=
paddle
.
where
(
cond
,
x
,
y
)
assert
np
.
array_equal
(
out
.
numpy
(),
np
.
where
(
cond_i
,
x_i
,
y_i
))
def
test_scalar
(
self
):
with
fluid
.
dygraph
.
guard
():
cond_i
=
np
.
array
([
False
,
False
,
True
,
True
]).
astype
(
'bool'
)
x
=
1.0
y
=
2.0
cond
=
fluid
.
dygraph
.
to_variable
(
cond_i
)
out
=
paddle
.
where
(
cond
,
x
,
y
)
assert
np
.
array_equal
(
out
.
numpy
(),
np
.
where
(
cond_i
,
x
,
y
))
def
__test_where_with_broadcast_dygraph
(
self
,
cond_shape
,
a_shape
,
b_shape
):
with
fluid
.
dygraph
.
guard
():
cond_tmp
=
paddle
.
rand
(
cond_shape
)
cond
=
(
cond_tmp
<
0.3
)
a
=
paddle
.
rand
(
a_shape
)
b
=
paddle
.
rand
(
b_shape
)
result
=
paddle
.
where
(
cond
,
a
,
b
)
result
=
result
.
numpy
()
expect
=
np
.
where
(
cond
,
a
,
b
)
self
.
assertTrue
(
np
.
array_equal
(
expect
,
result
))
def
test_dygraph_api_broadcast_1
(
self
):
cond_shape
=
[
2
,
4
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_dygraph
(
cond_shape
,
a_shape
,
b_shape
)
def
test_dygraph_api_broadcast_2
(
self
):
cond_shape
=
[
2
,
1
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_dygraph
(
cond_shape
,
a_shape
,
b_shape
)
def
test_dygraph_api_broadcast_3
(
self
):
cond_shape
=
[
2
,
2
,
1
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_dygraph
(
cond_shape
,
a_shape
,
b_shape
)
def
test_dygraph_api_broadcast_4
(
self
):
cond_shape
=
[
2
,
1
,
4
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_dygraph
(
cond_shape
,
a_shape
,
b_shape
)
def
test_dygraph_api_broadcast_5
(
self
):
cond_shape
=
[
3
,
2
,
2
,
4
]
a_shape
=
[
2
,
2
,
4
]
b_shape
=
[
2
,
2
,
4
]
self
.
__test_where_with_broadcast_dygraph
(
cond_shape
,
a_shape
,
b_shape
)
def
test_dygraph_api_broadcast_6
(
self
):
cond_shape
=
[
2
,
2
,
4
]
a_shape
=
[
2
,
2
,
1
]
b_shape
=
[
2
,
2
,
1
]
self
.
__test_where_with_broadcast_dygraph
(
cond_shape
,
a_shape
,
b_shape
)
def
test_dygraph_api_broadcast_7
(
self
):
cond_shape
=
[
2
,
2
,
4
]
a_shape
=
[
2
,
1
,
4
]
b_shape
=
[
2
,
1
,
4
]
self
.
__test_where_with_broadcast_dygraph
(
cond_shape
,
a_shape
,
b_shape
)
def
test_dygraph_api_broadcast_8
(
self
):
cond_shape
=
[
3
,
2
,
2
,
4
]
a_shape
=
[
2
,
2
,
1
]
b_shape
=
[
2
,
2
,
1
]
self
.
__test_where_with_broadcast_dygraph
(
cond_shape
,
a_shape
,
b_shape
)
def
test_where_condition
(
self
):
data
=
np
.
array
([[
True
,
False
],
[
False
,
True
]])
with
program_guard
(
Program
(),
Program
()):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[(
-
1
),
2
])
y
=
paddle
.
where
(
x
)
self
.
assertEqual
(
type
(
y
),
tuple
)
self
.
assertEqual
(
len
(
y
),
2
)
z
=
fluid
.
layers
.
concat
(
list
(
y
),
axis
=
1
)
exe
=
fluid
.
Executor
(
paddle
.
device
.
MLUPlace
(
0
))
(
res
,
)
=
exe
.
run
(
feed
=
{
'x'
:
data
},
fetch_list
=
[
z
.
name
],
return_numpy
=
False
)
expect_out
=
np
.
array
([[
0
,
0
],
[
1
,
1
]])
self
.
assertTrue
(
np
.
allclose
(
expect_out
,
np
.
array
(
res
)))
data
=
np
.
array
([
True
,
True
,
False
])
with
program_guard
(
Program
(),
Program
()):
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[(
-
1
)])
y
=
paddle
.
where
(
x
)
self
.
assertEqual
(
type
(
y
),
tuple
)
self
.
assertEqual
(
len
(
y
),
1
)
z
=
fluid
.
layers
.
concat
(
list
(
y
),
axis
=
1
)
exe
=
fluid
.
Executor
(
paddle
.
device
.
MLUPlace
(
0
))
(
res
,
)
=
exe
.
run
(
feed
=
{
'x'
:
data
},
fetch_list
=
[
z
.
name
],
return_numpy
=
False
)
expect_out
=
np
.
array
([[
0
],
[
1
]])
self
.
assertTrue
(
np
.
allclose
(
expect_out
,
np
.
array
(
res
)))
class
TestWhereOpError
(
unittest
.
TestCase
):
def
test_errors
(
self
):
with
program_guard
(
Program
(),
Program
()):
x_i
=
np
.
array
([
0.9383
,
0.1983
,
3.2
,
1.2
]).
astype
(
'float32'
)
y_i
=
np
.
array
([
1.0
,
1.0
,
1.0
,
1.0
]).
astype
(
'float32'
)
cond_i
=
np
.
array
([
False
,
False
,
True
,
True
]).
astype
(
'bool'
)
def
test_Variable
():
paddle
.
where
(
cond_i
,
x_i
,
y_i
)
self
.
assertRaises
(
TypeError
,
test_Variable
)
def
test_type
():
x
=
fluid
.
layers
.
data
(
name
=
'x'
,
shape
=
[
4
],
dtype
=
'bool'
)
y
=
fluid
.
layers
.
data
(
name
=
'y'
,
shape
=
[
4
],
dtype
=
'float16'
)
cond
=
fluid
.
layers
.
data
(
name
=
'cond'
,
shape
=
[
4
],
dtype
=
'int32'
)
paddle
.
where
(
cond
,
x
,
y
)
self
.
assertRaises
(
TypeError
,
test_type
)
def
test_value_error
(
self
):
with
fluid
.
dygraph
.
guard
():
cond_shape
=
[
2
,
2
,
4
]
cond_tmp
=
paddle
.
rand
(
cond_shape
)
cond
=
(
cond_tmp
<
0.3
)
a
=
paddle
.
rand
(
cond_shape
)
self
.
assertRaises
(
ValueError
,
paddle
.
where
,
cond
,
a
)
if
__name__
==
"__main__"
:
paddle
.
enable_static
()
unittest
.
main
()
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录