Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
BaiXuePrincess
Paddle
提交
3b9b4c34
P
Paddle
项目概览
BaiXuePrincess
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
0
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
0
Issue
0
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
3b9b4c34
编写于
9月 02, 2022
作者:
Y
ykkk2333
提交者:
GitHub
9月 02, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
migrate shaple sgd, split,sign xpu kernels to phi, test=kunlun (#45607)
上级
445fce62
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
200 addition
and
229 deletion
+200
-229
paddle/fluid/operators/optimizers/sgd_op_xpu.cc
paddle/fluid/operators/optimizers/sgd_op_xpu.cc
+0
-102
paddle/fluid/operators/shape_op_xpu.cc
paddle/fluid/operators/shape_op_xpu.cc
+0
-54
paddle/fluid/operators/split_op_xpu.cc
paddle/fluid/operators/split_op_xpu.cc
+0
-73
paddle/phi/kernels/shape_kernel.cc
paddle/phi/kernels/shape_kernel.cc
+14
-0
paddle/phi/kernels/xpu/sgd_kernel.cc
paddle/phi/kernels/xpu/sgd_kernel.cc
+86
-0
paddle/phi/kernels/xpu/sign_kernel.cc
paddle/phi/kernels/xpu/sign_kernel.cc
+33
-0
paddle/phi/kernels/xpu/split_kernel.cc
paddle/phi/kernels/xpu/split_kernel.cc
+67
-0
未找到文件。
paddle/fluid/operators/optimizers/sgd_op_xpu.cc
已删除
100644 → 0
浏览文件 @
445fce62
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <string>
#include "paddle/fluid/operators/optimizers/sgd_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
SGDOpXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
const
auto
*
learning_rate
=
ctx
.
Input
<
framework
::
Tensor
>
(
"LearningRate"
);
const
auto
*
param_var
=
ctx
.
InputVar
(
"Param"
);
const
auto
*
grad_var
=
ctx
.
InputVar
(
"Grad"
);
if
(
param_var
->
IsType
<
framework
::
LoDTensor
>
()
&&
grad_var
->
IsType
<
framework
::
LoDTensor
>
())
{
const
auto
*
param
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Param"
);
auto
*
param_out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"ParamOut"
);
// Actually, all tensors are LoDTensor except SelectedRows.
const
auto
*
grad
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Grad"
);
auto
sz
=
param_out
->
numel
();
PADDLE_ENFORCE_EQ
(
param
->
numel
(),
sz
,
platform
::
errors
::
InvalidArgument
(
"The input tensor Param's numel of SgdOp "
"should be equal with ParamOut's numel. "
"But received Param's "
"numel = [%s], ParamOut's numel = [%s]"
,
param
->
numel
(),
sz
));
PADDLE_ENFORCE_EQ
(
grad
->
numel
(),
sz
,
platform
::
errors
::
InvalidArgument
(
"The input tensor Grad's numel of SgdOp "
"should be equal with ParamOut's numel. "
"But received Grad's "
"numel = [%s], ParamOut's numel = [%s]"
,
grad
->
numel
(),
sz
));
const
T
*
lr_t
=
learning_rate
->
data
<
T
>
();
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
const
float
*
lr
=
nullptr
;
if
(
std
::
is_same
<
T
,
paddle
::
platform
::
float16
>::
value
)
{
float
*
lr_float
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
learning_rate
->
numel
());
int
r
=
xpu
::
cast_v2
<
XPUType
,
float
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
lr_t
),
lr_float
,
learning_rate
->
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"clip_v2"
);
lr
=
lr_float
;
}
else
{
lr
=
reinterpret_cast
<
const
float
*>
(
lr_t
);
}
const
T
*
param_data
=
param
->
data
<
T
>
();
const
T
*
grad_data
=
grad
->
data
<
T
>
();
T
*
out_data
=
param_out
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
r
=
xpu
::
sgd
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
grad_data
),
reinterpret_cast
<
const
XPUType
*>
(
param_data
),
lr
,
reinterpret_cast
<
XPUType
*>
(
out_data
),
sz
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"sgd"
);
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
namespace
plat
=
paddle
::
platform
;
REGISTER_OP_XPU_KERNEL
(
sgd
,
ops
::
SGDOpXPUKernel
<
paddle
::
platform
::
XPUDeviceContext
,
float
>
,
ops
::
SGDOpXPUKernel
<
paddle
::
platform
::
XPUDeviceContext
,
plat
::
float16
>
);
#endif
paddle/fluid/operators/shape_op_xpu.cc
已删除
100644 → 0
浏览文件 @
445fce62
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <algorithm>
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
using
LoDTensor
=
framework
::
LoDTensor
;
using
SelectedRows
=
phi
::
SelectedRows
;
template
<
typename
T
>
class
ShapeXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
in_var
=
ctx
.
InputVar
(
"Input"
);
framework
::
DDim
in_dims
;
if
(
in_var
->
IsType
<
phi
::
SelectedRows
>
())
{
in_dims
=
in_var
->
Get
<
phi
::
SelectedRows
>
().
value
().
dims
();
}
else
{
in_dims
=
in_var
->
Get
<
LoDTensor
>
().
dims
();
}
auto
*
out_t
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
out_t
->
Resize
({
in_dims
.
size
()});
auto
out_data
=
out_t
->
mutable_data
<
int32_t
>
(
platform
::
CPUPlace
());
for
(
int
i
=
0
;
i
<
in_dims
.
size
();
++
i
)
{
out_data
[
i
]
=
in_dims
[
i
];
}
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
shape
,
ops
::
ShapeXPUKernel
<
bool
>
,
ops
::
ShapeXPUKernel
<
int
>
,
ops
::
ShapeXPUKernel
<
int64_t
>
,
ops
::
ShapeXPUKernel
<
float
>
,
ops
::
ShapeXPUKernel
<
double
>
);
#endif
paddle/fluid/operators/split_op_xpu.cc
已删除
100644 → 0
浏览文件 @
445fce62
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <string>
#include <vector>
#include "paddle/fluid/operators/split_op.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
SplitXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
input
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
output
=
ctx
.
MultiOutput
<
framework
::
Tensor
>
(
"Out"
);
int
num
=
ctx
.
Attr
<
int
>
(
"num"
);
std
::
vector
<
int
>
sections
=
ctx
.
Attr
<
std
::
vector
<
int
>>
(
"sections"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
auto
&
dev_ctx
=
ctx
.
template
device_context
<
DeviceContext
>();
auto
in_dims
=
input
->
dims
();
auto
input_shape
=
phi
::
vectorize
<
int
>
(
in_dims
);
std
::
vector
<
int
>
split_lists
;
std
::
vector
<
T
*>
out_ptrs
;
auto
outs_number
=
output
.
size
();
std
::
vector
<
framework
::
DDim
>
outs_dims
=
UpdateOutsDims
(
true
,
true
,
in_dims
,
num
,
sections
,
axis
,
outs_number
);
for
(
size_t
i
=
0
;
i
<
output
.
size
();
++
i
)
{
output
[
i
]
->
Resize
(
outs_dims
[
i
]);
out_ptrs
.
push_back
(
output
[
i
]
->
mutable_data
<
T
>
(
ctx
.
GetPlace
()));
split_lists
.
push_back
(
output
[
i
]
->
dims
()[
axis
]);
}
int
r
=
xpu
::
split
<
T
>
(
dev_ctx
.
x_context
(),
input
->
data
<
T
>
(),
out_ptrs
,
input_shape
,
split_lists
,
axis
);
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU split kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
split
,
ops
::
SplitXPUKernel
<
paddle
::
platform
::
XPUDeviceContext
,
float
>
,
ops
::
SplitXPUKernel
<
paddle
::
platform
::
XPUDeviceContext
,
int
>
);
#endif
paddle/phi/kernels/shape_kernel.cc
浏览文件 @
3b9b4c34
...
...
@@ -67,3 +67,17 @@ PD_REGISTER_KERNEL(shape,
kernel
->
InputAt
(
0
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL
(
shape
,
XPU
,
ALL_LAYOUT
,
phi
::
ShapeKernel
,
bool
,
int
,
int64_t
,
float
,
double
)
{
kernel
->
InputAt
(
0
).
SetBackend
(
phi
::
Backend
::
ALL_BACKEND
);
}
#endif
paddle/phi/kernels/xpu/sgd_kernel.cc
0 → 100644
浏览文件 @
3b9b4c34
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sgd_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
SGDDenseKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
param
,
const
DenseTensor
&
learning_rate
,
const
DenseTensor
&
grad
,
const
paddle
::
optional
<
DenseTensor
>
&
master_param
,
bool
multi_precision
,
DenseTensor
*
param_out
,
DenseTensor
*
master_param_out
)
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
auto
sz
=
param_out
->
numel
();
PADDLE_ENFORCE_EQ
(
param
.
numel
(),
sz
,
errors
::
InvalidArgument
(
"The input tensor Param's numel of SgdOp "
"should be equal with ParamOut's numel. "
"But received Param's "
"numel = [%s], ParamOut's numel = [%s]"
,
param
.
numel
(),
sz
));
PADDLE_ENFORCE_EQ
(
grad
.
numel
(),
sz
,
errors
::
InvalidArgument
(
"The input tensor Grad's numel of SgdOp "
"should be equal with ParamOut's numel. "
"But received Grad's "
"numel = [%s], ParamOut's numel = [%s]"
,
grad
.
numel
(),
sz
));
const
T
*
lr_t
=
learning_rate
.
data
<
T
>
();
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
const
float
*
lr
=
nullptr
;
if
(
std
::
is_same
<
T
,
dtype
::
float16
>::
value
)
{
float
*
lr_float
=
RAII_GUARD
.
alloc_l3_or_gm
<
float
>
(
learning_rate
.
numel
());
int
r
=
xpu
::
cast_v2
<
XPUType
,
float
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
lr_t
),
lr_float
,
learning_rate
.
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"clip_v2"
);
lr
=
lr_float
;
}
else
{
lr
=
reinterpret_cast
<
const
float
*>
(
lr_t
);
}
const
T
*
param_data
=
param
.
data
<
T
>
();
const
T
*
grad_data
=
grad
.
data
<
T
>
();
dev_ctx
.
template
Alloc
<
T
>(
param_out
);
T
*
out_data
=
param_out
->
data
<
T
>
();
int
r
=
xpu
::
sgd
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
grad_data
),
reinterpret_cast
<
const
XPUType
*>
(
param_data
),
lr
,
reinterpret_cast
<
XPUType
*>
(
out_data
),
sz
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"sgd"
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
sgd
,
XPU
,
ALL_LAYOUT
,
phi
::
SGDDenseKernel
,
phi
::
dtype
::
float16
,
float
)
{}
paddle/
fluid/operators/sign_op_xpu
.cc
→
paddle/
phi/kernels/xpu/sign_kernel
.cc
浏览文件 @
3b9b4c34
/* Copyright (c) 202
0
PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 202
2
PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
...
...
@@ -12,32 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
DeviceContext
,
typename
T
>
class
SignXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
virtual
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
{
auto
*
out
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
auto
*
in
=
context
.
Input
<
framework
::
Tensor
>
(
"X"
);
out
->
mutable_data
<
T
>
(
in
->
place
());
auto
xpu_context
=
context
.
device_context
<
DeviceContext
>
().
x_context
();
// int sign(Context* ctx, const T* x , T* y, int len);
int
r
=
xpu
::
sign
(
xpu_context
,
in
->
data
<
T
>
(),
out
->
data
<
T
>
(),
in
->
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"sign"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
sign
,
ops
::
SignXPUKernel
<
paddle
::
platform
::
XPUDeviceContext
,
float
>
);
#endif
#include "paddle/phi/kernels/sign_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
SignKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
DenseTensor
*
out
)
{
dev_ctx
.
template
Alloc
<
T
>(
out
);
auto
xpu_context
=
dev_ctx
.
x_context
();
int
r
=
xpu
::
sign
(
xpu_context
,
x
.
data
<
T
>
(),
out
->
data
<
T
>
(),
x
.
numel
());
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"sign"
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
sign
,
XPU
,
ALL_LAYOUT
,
phi
::
SignKernel
,
float
)
{}
paddle/phi/kernels/xpu/split_kernel.cc
0 → 100644
浏览文件 @
3b9b4c34
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/split_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
SplitKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
IntArray
&
sections
,
const
Scalar
&
axis_scalar
,
std
::
vector
<
DenseTensor
*>
outs
)
{
int
axis
=
axis_scalar
.
to
<
int
>
();
auto
in_dims
=
x
.
dims
();
auto
input_shape
=
vectorize
<
int
>
(
in_dims
);
std
::
vector
<
T
*>
out_ptrs
;
std
::
vector
<
int
>
split_lists
;
for
(
size_t
j
=
0
;
j
<
outs
.
size
();
++
j
)
{
dev_ctx
.
template
Alloc
<
T
>(
outs
[
j
]);
out_ptrs
.
push_back
(
outs
[
j
]
->
data
<
T
>
());
split_lists
.
push_back
(
outs
[
j
]
->
dims
()[
axis
]);
}
int
r
=
xpu
::
split
<
T
>
(
dev_ctx
.
x_context
(),
x
.
data
<
T
>
(),
out_ptrs
,
input_shape
,
split_lists
,
axis
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"split"
);
}
template
<
typename
T
,
typename
Context
>
void
SplitWithNumKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
int
num
,
const
Scalar
&
axis_scalar
,
std
::
vector
<
DenseTensor
*>
outs
)
{
int
axis_value
=
axis_scalar
.
to
<
int
>
();
auto
input_axis_dim
=
x
.
dims
().
at
(
axis_value
);
std
::
vector
<
int64_t
>
sections_vec
;
for
(
int
i
=
0
;
i
<
num
;
++
i
)
{
sections_vec
.
push_back
(
input_axis_dim
/
num
);
}
IntArray
sections
(
sections_vec
);
SplitKernel
<
T
,
Context
>
(
dev_ctx
,
x
,
sections
,
axis_scalar
,
outs
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
split
,
XPU
,
ALL_LAYOUT
,
phi
::
SplitKernel
,
float
,
int
)
{}
PD_REGISTER_KERNEL
(
split_with_num
,
XPU
,
ALL_LAYOUT
,
phi
::
SplitWithNumKernel
,
float
,
int
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录