Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
c3ae0d40
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2299
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
c3ae0d40
编写于
5月 13, 2021
作者:
B
Baibaifan
提交者:
GitHub
5月 13, 2021
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
solved some npu bugs (#32793)
上级
3e47eee9
变更
12
显示空白变更内容
内联
并排
Showing
12 changed file
with
166 addition
and
31 deletion
+166
-31
paddle/fluid/framework/operator.cc
paddle/fluid/framework/operator.cc
+7
-1
paddle/fluid/framework/operator.h
paddle/fluid/framework/operator.h
+8
-0
paddle/fluid/framework/section_worker.cc
paddle/fluid/framework/section_worker.cc
+15
-1
paddle/fluid/operators/collective/recv_v2_op_npu.cc
paddle/fluid/operators/collective/recv_v2_op_npu.cc
+9
-6
paddle/fluid/operators/lookup_table_v2_op_npu.cc
paddle/fluid/operators/lookup_table_v2_op_npu.cc
+5
-0
python/paddle/distributed/collective.py
python/paddle/distributed/collective.py
+93
-6
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
...addle/distributed/fleet/meta_optimizers/sharding/utils.py
+7
-2
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
...e/distributed/fleet/meta_optimizers/sharding_optimizer.py
+10
-9
python/paddle/fluid/dataset.py
python/paddle/fluid/dataset.py
+3
-1
python/paddle/fluid/layers/nn.py
python/paddle/fluid/layers/nn.py
+2
-2
python/paddle/fluid/optimizer.py
python/paddle/fluid/optimizer.py
+6
-2
python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py
.../fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py
+1
-1
未找到文件。
paddle/fluid/framework/operator.cc
浏览文件 @
c3ae0d40
...
@@ -1228,6 +1228,8 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
...
@@ -1228,6 +1228,8 @@ void OperatorWithKernel::ChooseKernel(const RuntimeContext& ctx,
// will be executed and a warning will be given at the same time.
// will be executed and a warning will be given at the same time.
if
(
SupportGPU
())
{
if
(
SupportGPU
())
{
expected_kernel_key
.
place_
=
dev_ctx
->
GetPlace
();
expected_kernel_key
.
place_
=
dev_ctx
->
GetPlace
();
}
else
if
(
SupportNPU
())
{
expected_kernel_key
.
place_
=
dev_ctx
->
GetPlace
();
}
else
{
}
else
{
expected_kernel_key
.
place_
=
platform
::
CPUPlace
();
expected_kernel_key
.
place_
=
platform
::
CPUPlace
();
LOG_FIRST_N
(
WARNING
,
1
)
LOG_FIRST_N
(
WARNING
,
1
)
...
@@ -1299,8 +1301,12 @@ void OperatorWithKernel::TransferInplaceVarsBack(
...
@@ -1299,8 +1301,12 @@ void OperatorWithKernel::TransferInplaceVarsBack(
auto
*
transformed_tensor
=
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
auto
*
transformed_tensor
=
GetLoDTensorOrSelectedRowsValueFromVar
(
*
var
);
auto
original_dims
=
original_tensor
->
dims
();
auto
original_dims
=
original_tensor
->
dims
();
original_tensor
->
ShareDataWith
(
*
transformed_tensor
);
original_tensor
->
ShareDataWith
(
*
transformed_tensor
);
// In order to solve the problem that the output latitude of NPU reshape
// operator is not changed when inplace.
if
(
type_
!=
"reshape2"
&&
type_
!=
"reshape2_grad"
)
{
original_tensor
->
Resize
(
original_dims
);
original_tensor
->
Resize
(
original_dims
);
}
}
}
}
}
void
OperatorWithKernel
::
HandleComplexGradToRealGrad
(
void
OperatorWithKernel
::
HandleComplexGradToRealGrad
(
...
...
paddle/fluid/framework/operator.h
浏览文件 @
c3ae0d40
...
@@ -154,6 +154,7 @@ class OperatorBase {
...
@@ -154,6 +154,7 @@ class OperatorBase {
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
std
::
string
DebugString
()
const
{
return
DebugStringEx
(
nullptr
);
}
virtual
bool
SupportGPU
()
const
{
return
false
;
}
virtual
bool
SupportGPU
()
const
{
return
false
;
}
virtual
bool
SupportNPU
()
const
{
return
false
;
}
const
std
::
string
&
Type
()
const
{
return
type_
;
}
const
std
::
string
&
Type
()
const
{
return
type_
;
}
...
@@ -490,6 +491,13 @@ class OperatorWithKernel : public OperatorBase {
...
@@ -490,6 +491,13 @@ class OperatorWithKernel : public OperatorBase {
return
platform
::
is_gpu_place
(
kern_pair
.
first
.
place_
);
return
platform
::
is_gpu_place
(
kern_pair
.
first
.
place_
);
});
});
}
}
bool
SupportNPU
()
const
override
{
auto
&
op_kernels
=
OperatorWithKernel
::
AllOpKernels
().
at
(
type_
);
return
std
::
any_of
(
op_kernels
.
begin
(),
op_kernels
.
end
(),
[](
OpKernelMap
::
const_reference
kern_pair
)
{
return
platform
::
is_npu_place
(
kern_pair
.
first
.
place_
);
});
}
bool
SupportsMKLDNN
(
proto
::
VarType
::
Type
data_type
)
const
;
bool
SupportsMKLDNN
(
proto
::
VarType
::
Type
data_type
)
const
;
bool
CanMKLDNNBeUsed
(
const
framework
::
ExecutionContext
&
ctx
,
bool
CanMKLDNNBeUsed
(
const
framework
::
ExecutionContext
&
ctx
,
...
...
paddle/fluid/framework/section_worker.cc
浏览文件 @
c3ae0d40
...
@@ -110,8 +110,22 @@ void SectionWorker::TrainFiles() {
...
@@ -110,8 +110,22 @@ void SectionWorker::TrainFiles() {
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place_
),
max_memory_size
));
BOOST_GET_CONST
(
platform
::
CUDAPlace
,
place_
),
max_memory_size
));
}
}
}
}
#endif
#elif defined(PADDLE_WITH_ASCEND_CL)
if
(
IsFastEagerDeletionModeEnabled
())
{
VLOG
(
4
)
<<
"Use unsafe fast gc for NPU."
;
gc
.
reset
(
new
NPUUnsafeFastGarbageCollector
(
BOOST_GET_CONST
(
platform
::
NPUPlace
,
place_
),
max_memory_size
));
}
else
{
PADDLE_THROW
(
platform
::
errors
::
Unimplemented
(
"Please set FLAGS_fast_eager_deletion_mode=true to use "
"GarbageCollector on NPU."
));
// TODO(zhiqiu): fix bugs and enable NPUDefaultStreamGarbageCollector.
VLOG
(
4
)
<<
"Use default stream gc for NPU."
;
gc
.
reset
(
new
NPUDefaultStreamGarbageCollector
(
BOOST_GET_CONST
(
platform
::
NPUPlace
,
place_
),
max_memory_size
));
}
}
#endif
}
// max_memory_size >= 0
if
(
schedule_mode_
==
0
)
{
if
(
schedule_mode_
==
0
)
{
// F-then-B scheduler which runs Forward phase for all microbatches,
// F-then-B scheduler which runs Forward phase for all microbatches,
...
...
paddle/fluid/operators/collective/recv_v2_op_npu.cc
浏览文件 @
c3ae0d40
...
@@ -27,10 +27,11 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
...
@@ -27,10 +27,11 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
public:
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
#if defined(PADDLE_WITH_ASCEND_CL)
#if defined(PADDLE_WITH_ASCEND_CL)
auto
x
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
auto
out
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
void
*
ptr
=
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
x
->
data
<
T
>
()));
out
->
mutable_data
<
T
>
(
out
->
dims
(),
ctx
.
GetPlace
());
int
numel
=
x
->
numel
();
void
*
ptr
=
reinterpret_cast
<
void
*>
(
const_cast
<
T
*>
(
out
->
data
<
T
>
()));
HcclDataType
dtype
=
platform
::
ToHCCLDataType
(
x
->
type
());
int
numel
=
out
->
numel
();
HcclDataType
dtype
=
platform
::
ToHCCLDataType
(
out
->
type
());
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
int
ring_id
=
ctx
.
Attr
<
int
>
(
"ring_id"
);
auto
place
=
ctx
.
GetPlace
();
auto
place
=
ctx
.
GetPlace
();
...
@@ -54,8 +55,10 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
...
@@ -54,8 +55,10 @@ class CRecvOpASCENDKernel : public framework::OpKernel<T> {
int
root
=
peer
;
int
root
=
peer
;
VLOG
(
3
)
<<
"begin hccl recv, parameter is: "
VLOG
(
3
)
<<
"begin hccl recv, parameter is: "
<<
"root "
<<
root
<<
", comm: "
<<
comm
->
comm
()
<<
"ring_id:"
<<
ring_id
<<
", nranks:"
<<
nranks
<<
", stream: "
<<
stream
;
<<
", peer:"
<<
peer
<<
", numel:"
<<
numel
<<
", ptr:"
<<
ptr
<<
", dtype:"
<<
dtype
<<
", root:"
<<
root
<<
", comm: "
<<
comm
->
comm
()
<<
", stream: "
<<
stream
;
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
HcclBroadcast
(
PADDLE_ENFORCE_NPU_SUCCESS
(
platform
::
dynload
::
HcclBroadcast
(
ptr
,
numel
,
dtype
,
(
uint32_t
)
root
,
comm
->
comm
(),
stream
));
ptr
,
numel
,
dtype
,
(
uint32_t
)
root
,
comm
->
comm
(),
stream
));
...
...
paddle/fluid/operators/lookup_table_v2_op_npu.cc
浏览文件 @
c3ae0d40
...
@@ -29,6 +29,11 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
...
@@ -29,6 +29,11 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
auto
*
output_t
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
// float tensor
auto
*
output_t
=
ctx
.
Output
<
framework
::
LoDTensor
>
(
"Out"
);
// float tensor
auto
*
table_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"W"
);
auto
*
table_t
=
ctx
.
Input
<
framework
::
LoDTensor
>
(
"W"
);
// It seems cann 20.1 accepts int64, but cann 20.2+ not.
PADDLE_ENFORCE_EQ
(
ids_t
->
type
(),
framework
::
proto
::
VarType
::
INT32
,
platform
::
errors
::
Unimplemented
(
"The index of LookupTableV2 should be int32."
));
auto
*
table_var
=
ctx
.
InputVar
(
"W"
);
auto
*
table_var
=
ctx
.
InputVar
(
"W"
);
PADDLE_ENFORCE_EQ
(
PADDLE_ENFORCE_EQ
(
table_var
->
IsType
<
framework
::
LoDTensor
>
(),
true
,
table_var
->
IsType
<
framework
::
LoDTensor
>
(),
true
,
...
...
python/paddle/distributed/collective.py
浏览文件 @
c3ae0d40
...
@@ -25,6 +25,7 @@ from ..fluid.data_feeder import check_type
...
@@ -25,6 +25,7 @@ from ..fluid.data_feeder import check_type
from
..fluid.data_feeder
import
check_dtype
from
..fluid.data_feeder
import
check_dtype
from
..fluid.layers.tensor
import
fill_constant
from
..fluid.layers.tensor
import
fill_constant
from
..fluid.layers
import
utils
from
..fluid.layers
import
utils
from
..fluid.dygraph
import
layers
from
..fluid.dygraph.parallel
import
prepare_context
from
..fluid.dygraph.parallel
import
prepare_context
import
paddle
import
paddle
from
.fleet
import
fleet
from
.fleet
import
fleet
...
@@ -875,6 +876,84 @@ def _mp_allreduce(tensor,
...
@@ -875,6 +876,84 @@ def _mp_allreduce(tensor,
raise
NotImplementedError
(
"No support _mp_allreduce in dygraph mode."
)
raise
NotImplementedError
(
"No support _mp_allreduce in dygraph mode."
)
class
_Linear
(
layers
.
Layer
):
"""
Linear
"""
def
__init__
(
self
,
in_features
,
out_features
,
weight_attr
=
None
,
bias_attr
=
None
,
name
=
None
):
super
(
_Linear
,
self
).
__init__
()
self
.
_dtype
=
self
.
_helper
.
get_default_dtype
()
self
.
_weight_attr
=
weight_attr
self
.
_bias_attr
=
bias_attr
self
.
weight
=
self
.
create_parameter
(
shape
=
[
in_features
,
out_features
],
attr
=
self
.
_weight_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
False
)
self
.
bias
=
self
.
create_parameter
(
shape
=
[
out_features
],
attr
=
self
.
_bias_attr
,
dtype
=
self
.
_dtype
,
is_bias
=
True
)
self
.
name
=
name
def
forward
(
self
,
input
):
out
=
_linear
(
x
=
input
,
weight
=
self
.
weight
,
bias
=
self
.
bias
,
name
=
self
.
name
)
return
out
def
extra_repr
(
self
):
name_str
=
', name={}'
.
format
(
self
.
name
)
if
self
.
name
else
''
return
'in_features={}, out_features={}, dtype={}{}'
.
format
(
self
.
weight
.
shape
[
0
],
self
.
weight
.
shape
[
1
],
self
.
_dtype
,
name_str
)
def
_linear
(
x
,
weight
,
bias
=
None
,
name
=
None
):
"""
Fuction Linear
"""
if
in_dygraph_mode
():
pre_bias
=
_varbase_creator
(
dtype
=
x
.
dtype
)
core
.
ops
.
matmul
(
x
,
weight
,
pre_bias
,
'transpose_X'
,
False
,
'transpose_Y'
,
False
,
"alpha"
,
1
)
return
dygraph_utils
.
_append_bias_in_dygraph
(
pre_bias
,
bias
,
axis
=
len
(
x
.
shape
)
-
1
)
else
:
helper
=
LayerHelper
(
'linear'
,
**
locals
())
dtype
=
x
.
dtype
check_variable_and_dtype
(
x
,
'x'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
check_dtype
(
dtype
,
'dtype'
,
[
'float16'
,
'float32'
,
'float64'
],
'linear'
)
inputs
=
{
'X'
:
[
x
],
'Y'
:
[
weight
]}
attrs
=
{
'transpose_X'
:
False
,
'transpose_Y'
:
False
,
'alpha'
:
1
,
}
tmp
=
helper
.
create_variable_for_type_inference
(
dtype
)
helper
.
append_op
(
type
=
'matmul_v2'
,
inputs
=
inputs
,
outputs
=
{
'Out'
:
tmp
},
attrs
=
attrs
)
if
bias
is
not
None
:
res
=
helper
.
create_variable_for_type_inference
(
dtype
)
helper
.
append_op
(
type
=
'elementwise_add'
,
inputs
=
{
'X'
:
[
tmp
],
'Y'
:
[
bias
]},
outputs
=
{
'Out'
:
[
res
]},
attrs
=
{
'axis'
:
len
(
x
.
shape
)
-
1
})
else
:
res
=
tmp
return
res
def
_parallel_linear
(
x
,
def
_parallel_linear
(
x
,
num_rows
,
num_rows
,
num_cols
,
num_cols
,
...
@@ -900,6 +979,14 @@ def _parallel_linear(x,
...
@@ -900,6 +979,14 @@ def _parallel_linear(x,
else
:
else
:
x
=
_c_identity
(
x
,
group
=
group
)
x
=
_c_identity
(
x
,
group
=
group
)
if
core
.
is_compiled_with_npu
():
linear
=
_Linear
(
num_rows
,
num_cols
,
weight_attr
=
param_attr
,
bias_attr
=
bias_attr
,
name
=
name
)
else
:
linear
=
paddle
.
nn
.
Linear
(
linear
=
paddle
.
nn
.
Linear
(
num_rows
,
num_rows
,
num_cols
,
num_cols
,
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding/utils.py
浏览文件 @
c3ae0d40
...
@@ -402,13 +402,18 @@ def get_grad_device(grad_name, shard):
...
@@ -402,13 +402,18 @@ def get_grad_device(grad_name, shard):
return
shard
.
global_param2device
[
base_name
]
return
shard
.
global_param2device
[
base_name
]
def
get_first_check_finite_and_unscale_op_idx
(
block
):
def
get_first_check_finite_and_unscale_op_idx
(
block
,
raise_error
=
True
):
for
idx
,
op
in
enumerate
(
block
.
ops
):
for
idx
,
op
in
enumerate
(
block
.
ops
):
if
op
.
type
==
"check_finite_and_unscale"
:
if
op
.
type
==
"check_finite_and_unscale"
:
return
idx
return
idx
raise
ValueError
(
"check_finite_and_unscale does not exist in block"
)
if
raise_error
:
raise
ValueError
(
"amp is turned on but check_finite_and_unscale op does not exist in main block"
)
return
-
1
def
insert_broadcast_ops
(
block
,
insert_idx
,
ring_id
,
broadcast2root
):
def
insert_broadcast_ops
(
block
,
insert_idx
,
ring_id
,
broadcast2root
):
...
...
python/paddle/distributed/fleet/meta_optimizers/sharding_optimizer.py
浏览文件 @
c3ae0d40
...
@@ -298,7 +298,7 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -298,7 +298,7 @@ class ShardingOptimizer(MetaOptimizerBase):
print
(
"persistable FP32 grad: "
)
print
(
"persistable FP32 grad: "
)
print
(
accumulated_grad_names
)
print
(
accumulated_grad_names
)
first_optimize_op_index
=
get_first_check_finite_and_unscale_op_idx
(
first_optimize_op_index
=
get_first_check_finite_and_unscale_op_idx
(
main_block
)
main_block
,
raise_error
=
self
.
user_defined_strategy
.
amp
)
insert_reduce_ops
(
insert_reduce_ops
(
main_block
,
main_block
,
first_optimize_op_index
,
first_optimize_op_index
,
...
@@ -309,7 +309,8 @@ class ShardingOptimizer(MetaOptimizerBase):
...
@@ -309,7 +309,8 @@ class ShardingOptimizer(MetaOptimizerBase):
use_calc_stream
=
True
)
use_calc_stream
=
True
)
if
self
.
hybrid_dp
and
self
.
hybrid_dp_mode
==
"pp_hybrid_dp"
:
if
self
.
hybrid_dp
and
self
.
hybrid_dp_mode
==
"pp_hybrid_dp"
:
first_optimize_op_index
=
get_first_check_finite_and_unscale_op_idx
(
first_optimize_op_index
=
get_first_check_finite_and_unscale_op_idx
(
main_block
)
main_block
,
raise_error
=
self
.
user_defined_strategy
.
amp
)
if
first_optimize_op_index
>=
0
:
insert_allreduce_ops
(
insert_allreduce_ops
(
main_block
,
main_block
,
first_optimize_op_index
,
first_optimize_op_index
,
...
...
python/paddle/fluid/dataset.py
浏览文件 @
c3ae0d40
...
@@ -252,9 +252,11 @@ class DatasetBase(object):
...
@@ -252,9 +252,11 @@ class DatasetBase(object):
slot_var
.
type
=
"float"
slot_var
.
type
=
"float"
elif
var
.
dtype
==
core
.
VarDesc
.
VarType
.
INT64
:
elif
var
.
dtype
==
core
.
VarDesc
.
VarType
.
INT64
:
slot_var
.
type
=
"uint64"
slot_var
.
type
=
"uint64"
elif
var
.
dtype
==
core
.
VarDesc
.
VarType
.
INT32
:
slot_var
.
type
=
"uint32"
else
:
else
:
raise
ValueError
(
raise
ValueError
(
"Currently, fluid.dataset only supports dtype=float32 and dtype=int64"
"Currently, fluid.dataset only supports dtype=float32
, dtype=int32
and dtype=int64"
)
)
def
set_hdfs_config
(
self
,
fs_name
,
fs_ugi
):
def
set_hdfs_config
(
self
,
fs_name
,
fs_ugi
):
...
...
python/paddle/fluid/layers/nn.py
浏览文件 @
c3ae0d40
...
@@ -14772,7 +14772,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
...
@@ -14772,7 +14772,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
the size of the last shard will be less than the calculated `shard_size`
the size of the last shard will be less than the calculated `shard_size`
Args:
Args:
input (Tensor): Input indices with data type int64. It's last dimension must be 1.
input (Tensor): Input indices with data type int64
or int32
. It's last dimension must be 1.
index_num (int): An integer defining the range of the index.
index_num (int): An integer defining the range of the index.
nshards (int): The number of shards.
nshards (int): The number of shards.
shard_id (int): The index of the current shard.
shard_id (int): The index of the current shard.
...
@@ -14793,7 +14793,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
...
@@ -14793,7 +14793,7 @@ def shard_index(input, index_num, nshards, shard_id, ignore_value=-1):
print(shard_label)
print(shard_label)
# [[-1], [1]]
# [[-1], [1]]
"""
"""
check_variable_and_dtype(input, 'input', ['int64'], 'shard_index')
check_variable_and_dtype(input, 'input', ['int64'
, 'int32'
], 'shard_index')
op_type = 'shard_index'
op_type = 'shard_index'
helper = LayerHelper(op_type, **locals())
helper = LayerHelper(op_type, **locals())
if shard_id < 0 or shard_id >= nshards:
if shard_id < 0 or shard_id >= nshards:
...
...
python/paddle/fluid/optimizer.py
浏览文件 @
c3ae0d40
...
@@ -4200,6 +4200,8 @@ class PipelineOptimizer(object):
...
@@ -4200,6 +4200,8 @@ class PipelineOptimizer(object):
op
.
type
==
'elementwise_div'
):
op
.
type
==
'elementwise_div'
):
device
=
"gpu:all"
device
=
"gpu:all"
op
.
_set_attr
(
self
.
_op_device_key
,
device
)
op
.
_set_attr
(
self
.
_op_device_key
,
device
)
elif
op
.
type
==
"alloc_float_status"
:
op
.
_set_attr
(
self
.
_op_device_key
,
"gpu:all"
)
else
:
else
:
other_known_ops
=
[
other_known_ops
=
[
'update_loss_scaling'
,
'update_loss_scaling'
,
...
@@ -4207,6 +4209,7 @@ class PipelineOptimizer(object):
...
@@ -4207,6 +4209,7 @@ class PipelineOptimizer(object):
'concat'
,
'concat'
,
'sum'
,
'sum'
,
'check_finite_and_unscale'
,
'check_finite_and_unscale'
,
'alloc_float_status'
,
]
]
assert
op
.
type
in
other_known_ops
,
"For other ops without "
\
assert
op
.
type
in
other_known_ops
,
"For other ops without "
\
"op_device set, they must be one of {}, but it "
\
"op_device set, they must be one of {}, but it "
\
...
@@ -4272,7 +4275,8 @@ class PipelineOptimizer(object):
...
@@ -4272,7 +4275,8 @@ class PipelineOptimizer(object):
"{} has not been set."
.
format
(
op
.
type
))
"{} has not been set."
.
format
(
op
.
type
))
if
device
==
"gpu:all"
:
continue
if
device
==
"gpu:all"
:
continue
dev_type
=
device
.
split
(
':'
)[
0
]
dev_type
=
device
.
split
(
':'
)[
0
]
assert
dev_type
==
"gpu"
,
(
"Now only gpu devices are supported "
assert
dev_type
==
"gpu"
or
dev_type
==
'npu'
,
(
"Now only gpu and npu devices are supported "
"for pipeline parallelism."
)
"for pipeline parallelism."
)
if
not
device
in
device_list
:
if
not
device
in
device_list
:
device_list
.
append
(
device
)
device_list
.
append
(
device
)
...
...
python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py
浏览文件 @
c3ae0d40
...
@@ -41,7 +41,7 @@ class TestLookupTableV2(OpTest):
...
@@ -41,7 +41,7 @@ class TestLookupTableV2(OpTest):
vocab
=
10
vocab
=
10
dim
=
20
dim
=
20
w
=
np
.
ones
([
vocab
,
dim
]).
astype
(
self
.
dtype
)
w
=
np
.
ones
([
vocab
,
dim
]).
astype
(
self
.
dtype
)
x
=
np
.
random
.
randint
(
0
,
vocab
,
size
=
(
bsz
,
seqlen
)).
astype
(
np
.
int
64
)
x
=
np
.
random
.
randint
(
0
,
vocab
,
size
=
(
bsz
,
seqlen
)).
astype
(
np
.
int
32
)
out
=
np
.
ones
([
bsz
,
seqlen
,
dim
]).
astype
(
self
.
dtype
)
out
=
np
.
ones
([
bsz
,
seqlen
,
dim
]).
astype
(
self
.
dtype
)
self
.
inputs
=
{
self
.
inputs
=
{
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录