Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
5061d3db
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
5061d3db
编写于
7月 12, 2022
作者:
F
fwenguang
提交者:
GitHub
7月 12, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[MLU] fix sync copy bugs (#44127)
上级
d55ee95f
变更
4
隐藏空白更改
内联
并排
Showing
4 changed file
with
52 addition
and
32 deletion
+52
-32
paddle/fluid/operators/mlu/mlu_baseop.cc
paddle/fluid/operators/mlu/mlu_baseop.cc
+3
-12
paddle/fluid/operators/mlu/mlu_baseop.h
paddle/fluid/operators/mlu/mlu_baseop.h
+2
-2
paddle/fluid/operators/randperm_op_mlu.cc
paddle/fluid/operators/randperm_op_mlu.cc
+25
-2
paddle/fluid/operators/where_index_op_mlu.cc
paddle/fluid/operators/where_index_op_mlu.cc
+22
-16
未找到文件。
paddle/fluid/operators/mlu/mlu_baseop.cc
浏览文件 @
5061d3db
...
...
@@ -4274,21 +4274,12 @@ MLURNNDesc::~MLURNNDesc() {
/* static */
void
MLUCnnl
::
NumTrue
(
const
ExecutionContext
&
ctx
,
const
cnnlTensorDescriptor_t
x_desc
,
const
void
*
x
,
Tensor
index
,
uint32_t
*
num_true
)
{
const
cnnlTensorDescriptor_t
num_true_desc
,
void
*
num_true
)
{
cnnlHandle_t
handle
=
GetHandleFromCTX
(
ctx
);
size_t
workspace_size
=
0
;
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlGetNumTrueWorkspaceSize
(
handle
,
x_desc
,
&
workspace_size
));
auto
&
dev_ctx
=
GetDevCtxFromCTX
(
ctx
);
index
=
ctx
.
AllocateTmpTensor
<
int8_t
,
MLUDeviceContext
>
(
{
static_cast
<
int64_t
>
(
workspace_size
)},
dev_ctx
);
void
*
index_ptr
=
index
.
mutable_data
(
ctx
.
GetPlace
());
PADDLE_ENFORCE_MLU_SUCCESS
(
cnnlNumTrue
(
handle
,
x_desc
,
x
,
static_cast
<
uint32_t
*>
(
index_ptr
),
num_true
));
cnnlNumTrue_v2
(
handle
,
x_desc
,
x
,
num_true_desc
,
num_true
));
}
/* static */
void
MLUCnnl
::
Where
(
const
ExecutionContext
&
ctx
,
...
...
paddle/fluid/operators/mlu/mlu_baseop.h
浏览文件 @
5061d3db
...
...
@@ -1703,8 +1703,8 @@ class MLUCnnl {
static
void
NumTrue
(
const
ExecutionContext
&
ctx
,
const
cnnlTensorDescriptor_t
x_desc
,
const
void
*
x
,
Tensor
index
,
uint32_t
*
num_true
);
const
cnnlTensorDescriptor_t
num_true_desc
,
void
*
num_true
);
static
void
Where
(
const
ExecutionContext
&
ctx
,
const
cnnlTensorDescriptor_t
x_desc
,
...
...
paddle/fluid/operators/randperm_op_mlu.cc
浏览文件 @
5061d3db
...
...
@@ -15,9 +15,32 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/randperm_op.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
RandpermMLUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
int
n
=
ctx
.
Attr
<
int
>
(
"n"
);
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
ctx
.
Attr
<
int
>
(
"seed"
));
framework
::
Variable
*
out_var
=
ctx
.
OutputVar
(
"Out"
);
framework
::
Tensor
*
out_tensor
=
framework
::
GetMutableLoDTensorOrSelectedRowsValueFromVar
(
out_var
);
framework
::
Tensor
tmp_tensor
;
tmp_tensor
.
Resize
(
phi
::
make_ddim
({
n
}));
T
*
tmp_data
=
tmp_tensor
.
mutable_data
<
T
>
(
platform
::
CPUPlace
());
random_permate
<
T
>
(
tmp_data
,
n
,
seed
);
framework
::
TensorCopySync
(
tmp_tensor
,
ctx
.
GetPlace
(),
out_tensor
);
}
};
}
// namespace operators
}
// namespace paddle
template
<
typename
T
>
using
kernel
=
paddle
::
operators
::
RandpermKernel
<
paddle
::
platform
::
MLUDeviceContext
,
T
>
;
using
kernel
=
paddle
::
operators
::
RandpermMLUKernel
<
T
>
;
REGISTER_OP_MLU_KERNEL
(
randperm
,
kernel
<
int64_t
>
,
kernel
<
int
>
,
kernel
<
float
>
,
kernel
<
double
>
);
paddle/fluid/operators/where_index_op_mlu.cc
浏览文件 @
5061d3db
...
...
@@ -30,30 +30,36 @@ class MLUWhereIndexKernel : public framework::OpKernel<T> {
auto
*
out
=
context
.
Output
<
Tensor
>
(
"Out"
);
auto
dims
=
condition
->
dims
();
const
int
rank
=
dims
.
size
();
std
::
vector
<
int
>
true_num
=
{
0
};
std
::
vector
<
T
>
vec_condition
;
paddle
::
framework
::
TensorToVector
(
*
condition
,
context
.
device_context
(),
&
vec_condition
);
int
vec_con_size
=
vec_condition
.
size
();
for
(
int
i
=
0
;
i
<
vec_con_size
;
++
i
)
{
if
(
vec_condition
[
i
]
>
0
)
true_num
[
0
]
++
;
}
out
->
Resize
(
phi
::
make_ddim
({
true_num
[
0
],
rank
}));
Tensor
num_true
;
num_true
.
mutable_data
<
int
>
({
1
},
context
.
GetPlace
());
MLUCnnlTensorDesc
con_desc
(
*
condition
);
MLUCnnlTensorDesc
num_true_desc
(
num_true
);
MLUCnnl
::
NumTrue
(
context
,
con_desc
.
get
(),
GetBasePtr
(
condition
),
num_true_desc
.
get
(),
GetBasePtr
(
&
num_true
));
Tensor
local_true_num
;
paddle
::
framework
::
TensorCopySync
(
num_true
,
platform
::
CPUPlace
(),
&
local_true_num
);
auto
true_num
=
*
local_true_num
.
data
<
int
>
();
out
->
Resize
(
phi
::
make_ddim
({
true_num
,
rank
}));
out
->
mutable_data
<
int64_t
>
(
context
.
GetPlace
());
if
(
true_num
==
0
)
{
return
;
}
auto
&
dev_ctx
=
context
.
template
device_context
<
MLUDeviceContext
>();
framework
::
Tensor
out_int32
=
context
.
AllocateTmpTensor
<
int32_t
,
MLUDeviceContext
>
(
out
->
dims
(),
dev_ctx
);
Tensor
num_true
;
paddle
::
framework
::
TensorFromVector
(
true_num
,
context
.
device_context
(),
&
num_true
);
num_true
.
mutable_data
<
int
>
(
context
.
GetPlace
());
bool
as_tuple
=
false
;
MLUCnnlTensorDesc
con_desc
(
*
condition
);
MLUCnnlTensorDesc
num_true_desc
(
num_true
);
MLUCnnlTensorDesc
out_int32_desc
(
out_int32
);
MLUCnnlTensorDesc
out_desc
(
*
out
);
bool
as_tuple
=
false
;
MLUCnnl
::
Where
(
context
,
con_desc
.
get
(),
GetBasePtr
(
condition
),
...
...
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录