Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
PaddlePaddle
Paddle
提交
60e1eccb
P
Paddle
项目概览
PaddlePaddle
/
Paddle
大约 1 年 前同步成功
通知
2298
Star
20931
Fork
5422
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1423
列表
看板
标记
里程碑
合并请求
543
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1,423
Issue
1,423
列表
看板
标记
里程碑
合并请求
543
合并请求
543
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
体验新版 GitCode,发现更多精彩内容 >>
未验证
提交
60e1eccb
编写于
8月 29, 2022
作者:
W
wanghuancoder
提交者:
GitHub
8月 29, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
[Phi] gather gather_grad gather_nd gaussian_random xpu to Phi (#45465)
* gather gather_grad gather_nd gaussian_random xpu to phi
上级
ca5567e1
变更
7
隐藏空白更改
内联
并排
Showing
7 changed file
with
335 addition
and
382 deletion
+335
-382
paddle/fluid/operators/gather_nd_op_xpu.cc
paddle/fluid/operators/gather_nd_op_xpu.cc
+0
-96
paddle/fluid/operators/gather_op_xpu.cc
paddle/fluid/operators/gather_op_xpu.cc
+0
-228
paddle/fluid/operators/gaussian_random_op_xpu.cc
paddle/fluid/operators/gaussian_random_op_xpu.cc
+0
-58
paddle/phi/kernels/xpu/gather_grad_kernel.cc
paddle/phi/kernels/xpu/gather_grad_kernel.cc
+111
-0
paddle/phi/kernels/xpu/gather_kernel.cc
paddle/phi/kernels/xpu/gather_kernel.cc
+86
-0
paddle/phi/kernels/xpu/gather_nd_kernel.cc
paddle/phi/kernels/xpu/gather_nd_kernel.cc
+83
-0
paddle/phi/kernels/xpu/gaussian_random_kernel.cc
paddle/phi/kernels/xpu/gaussian_random_kernel.cc
+55
-0
未找到文件。
paddle/fluid/operators/gather_nd_op_xpu.cc
已删除
100644 → 0
浏览文件 @
ca5567e1
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
GatherNdXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
auto
*
x
=
ctx
.
Input
<
framework
::
Tensor
>
(
"X"
);
auto
*
index
=
ctx
.
Input
<
framework
::
Tensor
>
(
"Index"
);
auto
*
out
=
ctx
.
Output
<
framework
::
Tensor
>
(
"Out"
);
out
->
template
mutable_data
<
T
>(
ctx
.
GetPlace
());
if
(
x
->
numel
()
==
0
)
return
;
if
(
index
->
numel
()
==
0
)
{
framework
::
TensorCopy
(
*
x
,
ctx
.
GetPlace
(),
ctx
.
device_context
(),
out
);
return
;
}
const
auto
&
index_type
=
framework
::
TransToProtoVarType
(
index
->
dtype
());
bool
index_type_match
=
index_type
==
framework
::
proto
::
VarType
::
INT32
||
index_type
==
framework
::
proto
::
VarType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
platform
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]"
,
paddle
::
framework
::
DataTypeToString
(
index_type
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT32
),
paddle
::
framework
::
DataTypeToString
(
framework
::
proto
::
VarType
::
INT64
)));
auto
x_shape
=
phi
::
vectorize
<
int
>
(
x
->
dims
());
auto
index_shape
=
phi
::
vectorize
<
int
>
(
index
->
dims
());
if
(
index_shape
.
size
()
==
1
)
{
index_shape
.
insert
(
index_shape
.
begin
(),
1
);
}
xpu
::
VectorParam
<
int
>
x_vec
=
{
x_shape
.
data
(),
static_cast
<
int
>
(
x_shape
.
size
()),
nullptr
};
auto
&
dev_ctx
=
ctx
.
template
device_context
<
paddle
::
platform
::
XPUDeviceContext
>();
int
ret
=
XPU_SUCCESS
;
if
(
index_type
==
framework
::
proto
::
VarType
::
INT32
)
{
ret
=
xpu
::
gather_nd
<
T
,
int
>
(
dev_ctx
.
x_context
(),
x
->
data
<
T
>
(),
index
->
data
<
int
>
(),
out
->
data
<
T
>
(),
x_vec
,
index_shape
);
}
else
{
ret
=
xpu
::
gather_nd
<
T
,
int64_t
>
(
dev_ctx
.
x_context
(),
x
->
data
<
T
>
(),
index
->
data
<
int64_t
>
(),
out
->
data
<
T
>
(),
x_vec
,
index_shape
);
}
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU gather_nd kernel return wrong value[%d %s]"
,
ret
,
XPUAPIErrorMsg
[
ret
]));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
gather_nd
,
ops
::
GatherNdXPUKernel
<
int
>
,
ops
::
GatherNdXPUKernel
<
int64_t
>
,
ops
::
GatherNdXPUKernel
<
float
>
);
#endif
paddle/fluid/operators/gather_op_xpu.cc
已删除
100644 → 0
浏览文件 @
ca5567e1
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
T
>
class
GatherOpXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on XPU."
));
auto
*
x
=
ctx
.
Input
<
Tensor
>
(
"X"
);
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
output
=
ctx
.
Output
<
Tensor
>
(
"Out"
);
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
if
(
ctx
.
HasInput
(
"Axis"
))
{
Tensor
cpu_axis
;
const
Tensor
*
axis_tensor
=
ctx
.
Input
<
Tensor
>
(
"Axis"
);
framework
::
TensorCopy
(
*
axis_tensor
,
platform
::
CPUPlace
(),
&
cpu_axis
);
const
auto
&
axis_type
=
axis_tensor
->
dtype
();
if
(
framework
::
TransToProtoVarType
(
axis_type
)
==
framework
::
proto
::
VarType
::
INT32
)
{
axis
=
static_cast
<
int
>
(
cpu_axis
.
data
<
int32_t
>
()[
0
]);
}
else
if
(
framework
::
TransToProtoVarType
(
axis_type
)
==
framework
::
proto
::
VarType
::
INT64
)
{
axis
=
static_cast
<
int
>
(
cpu_axis
.
data
<
int64_t
>
()[
0
]);
}
}
output
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
if
(
x
->
numel
()
==
0
)
return
;
const
auto
index_dims
=
index
->
dims
();
if
(
index_dims
.
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index_dims
[
1
],
1
,
platform
::
errors
::
InvalidArgument
(
"The last dim of index should be 1 when it is 2D, but we get %d"
,
index_dims
[
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
index_dims
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The index should be 1D, when it is not 2D, but we get %d"
,
index_dims
.
size
()));
}
std
::
vector
<
int
>
xshape
(
x
->
dims
().
size
());
for
(
int
i
=
0
;
i
<
x
->
dims
().
size
();
++
i
)
{
xshape
[
i
]
=
x
->
dims
()[
i
];
}
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
XPUDeviceContext
>();
int
r
=
XPU_SUCCESS
;
if
(
framework
::
TransToProtoVarType
(
index
->
dtype
())
==
framework
::
proto
::
VarType
::
INT32
)
{
r
=
xpu
::
gather
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
->
data
<
T
>
()),
index
->
data
<
int
>
(),
reinterpret_cast
<
XPUType
*>
(
output
->
data
<
T
>
()),
xshape
,
index
->
dims
()[
0
],
axis
);
}
else
{
r
=
xpu
::
gather
<
XPUType
,
int64_t
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
->
data
<
T
>
()),
index
->
data
<
int64_t
>
(),
reinterpret_cast
<
XPUType
*>
(
output
->
data
<
T
>
()),
xshape
,
index
->
dims
()[
0
],
axis
);
}
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
platform
::
errors
::
External
(
"XPU gather kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
};
template
<
typename
T
>
class
GatherGradOpXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
ctx
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
ctx
.
GetPlace
()),
true
,
platform
::
errors
::
PreconditionNotMet
(
"This kernel only runs on XPU."
));
auto
*
index
=
ctx
.
Input
<
Tensor
>
(
"Index"
);
auto
*
dx
=
ctx
.
Output
<
Tensor
>
(
framework
::
GradVarName
(
"X"
));
auto
*
dout
=
ctx
.
Input
<
Tensor
>
(
framework
::
GradVarName
(
"Out"
));
auto
&
dev_ctx
=
ctx
.
template
device_context
<
platform
::
XPUDeviceContext
>();
int
axis
=
ctx
.
Attr
<
int
>
(
"axis"
);
if
(
ctx
.
HasInput
(
"Axis"
))
{
Tensor
cpu_axis
;
const
Tensor
*
axis_tensor
=
ctx
.
Input
<
Tensor
>
(
"Axis"
);
framework
::
TensorCopy
(
*
axis_tensor
,
platform
::
CPUPlace
(),
&
cpu_axis
);
const
auto
&
axis_type
=
axis_tensor
->
dtype
();
if
(
framework
::
TransToProtoVarType
(
axis_type
)
==
framework
::
proto
::
VarType
::
INT32
)
{
axis
=
static_cast
<
int
>
(
cpu_axis
.
data
<
int32_t
>
()[
0
]);
}
else
if
(
framework
::
TransToProtoVarType
(
axis_type
)
==
framework
::
proto
::
VarType
::
INT64
)
{
axis
=
static_cast
<
int
>
(
cpu_axis
.
data
<
int64_t
>
()[
0
]);
}
}
if
(
dout
->
numel
()
==
0
)
{
return
;
}
bool
overwrite
=
ctx
.
Attr
<
bool
>
(
"overwrite"
);
const
auto
index_dims
=
index
->
dims
();
if
(
index_dims
.
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index_dims
[
1
],
1
,
platform
::
errors
::
InvalidArgument
(
"The last dim of index should be 1 when it is 2D, but we get %d"
,
index_dims
[
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
index_dims
.
size
(),
1
,
platform
::
errors
::
InvalidArgument
(
"The index should be 1D, when it is not 2D, but we get %d"
,
index_dims
.
size
()));
}
std
::
vector
<
int
>
xshape
(
dx
->
dims
().
size
());
for
(
int
i
=
0
;
i
<
dx
->
dims
().
size
();
++
i
)
{
xshape
[
i
]
=
dx
->
dims
()[
i
];
}
dx
->
mutable_data
<
T
>
(
ctx
.
GetPlace
());
int
r
=
XPU_SUCCESS
;
if
(
framework
::
TransToProtoVarType
(
index
->
dtype
())
==
framework
::
proto
::
VarType
::
INT32
)
{
r
=
xpu
::
gather_grad
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
dout
->
data
<
T
>
()),
index
->
data
<
int
>
(),
reinterpret_cast
<
XPUType
*>
(
dx
->
data
<
T
>
()),
xshape
,
index
->
dims
()[
0
],
axis
,
overwrite
);
}
else
{
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
int
*
index_int_ptr_l3
=
RAII_GUARD
.
alloc_l3_or_gm
<
int32_t
>
(
index
->
numel
());
r
=
xpu
::
cast_v2
<
int64_t
,
int32_t
>
(
dev_ctx
.
x_context
(),
index
->
data
<
int64_t
>
(),
index_int_ptr_l3
,
index
->
numel
());
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
platform
::
errors
::
External
(
"XPU API(cast_v2) return wrong "
"value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
r
=
xpu
::
gather_grad
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
dout
->
data
<
T
>
()),
index_int_ptr_l3
,
reinterpret_cast
<
XPUType
*>
(
dx
->
data
<
T
>
()),
xshape
,
index
->
dims
()[
0
],
axis
,
overwrite
);
}
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
platform
::
errors
::
External
(
"XPU gather grad kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
gather
,
ops
::
GatherOpXPUKernel
<
float
>
,
ops
::
GatherOpXPUKernel
<
paddle
::
platform
::
float16
>
);
REGISTER_OP_XPU_KERNEL
(
gather_grad
,
ops
::
GatherGradOpXPUKernel
<
float
>
,
ops
::
GatherGradOpXPUKernel
<
paddle
::
platform
::
float16
>
);
#endif
paddle/fluid/operators/gaussian_random_op_xpu.cc
已删除
100644 → 0
浏览文件 @
ca5567e1
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <random>
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"
namespace
paddle
{
namespace
operators
{
template
<
typename
T
>
class
XPUGaussianRandomKernel
:
public
framework
::
OpKernel
<
T
>
{
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
float
mean
=
context
.
Attr
<
float
>
(
"mean"
);
float
std
=
context
.
Attr
<
float
>
(
"std"
);
auto
*
tensor
=
context
.
Output
<
framework
::
Tensor
>
(
"Out"
);
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
int64_t
size
=
tensor
->
numel
();
T
*
data
=
tensor
->
mutable_data
<
T
>
(
context
.
GetPlace
());
unsigned
int
seed
=
static_cast
<
unsigned
int
>
(
context
.
Attr
<
int
>
(
"seed"
));
// TODO(pangyoki): implement GetXPURandomEngine to set different seeds on
// corresponding XPU device.
auto
engine
=
framework
::
GetCPURandomEngine
(
seed
);
std
::
unique_ptr
<
T
[]
>
data_cpu
(
new
T
[
size
]);
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data_cpu
[
i
]
=
dist
(
*
engine
);
}
memory
::
Copy
(
context
.
GetPlace
(),
data
,
platform
::
CPUPlace
(),
reinterpret_cast
<
void
*>
(
data_cpu
.
get
()),
size
*
sizeof
(
T
));
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
gaussian_random
,
ops
::
XPUGaussianRandomKernel
<
float
>
);
#endif
paddle/phi/kernels/xpu/gather_grad_kernel.cc
0 → 100644
浏览文件 @
60e1eccb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gather_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
GatherGradKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
index
,
const
DenseTensor
&
out_grad
,
const
Scalar
&
axis
,
bool
overwrite
,
DenseTensor
*
x_grad
)
{
auto
axis_v
=
axis
.
to
<
int
>
();
const
auto
&
index_type
=
index
.
dtype
();
if
(
out_grad
.
numel
()
==
0
)
{
return
;
}
const
auto
index_dims
=
index
.
dims
();
if
(
index_dims
.
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index_dims
[
1
],
1
,
phi
::
errors
::
InvalidArgument
(
"The last dim of index should be 1 when it is 2D, but we get %d"
,
index_dims
[
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
index_dims
.
size
(),
1
,
phi
::
errors
::
InvalidArgument
(
"The index should be 1D, when it is not 2D, but we get %d"
,
index_dims
.
size
()));
}
std
::
vector
<
int
>
xshape
(
x_grad
->
dims
().
size
());
for
(
int
i
=
0
;
i
<
x_grad
->
dims
().
size
();
++
i
)
{
xshape
[
i
]
=
x_grad
->
dims
()[
i
];
}
dev_ctx
.
template
Alloc
<
T
>(
x_grad
);
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
int
r
=
XPU_SUCCESS
;
if
(
index_type
==
DataType
::
INT32
)
{
r
=
xpu
::
gather_grad
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
out_grad
.
data
<
T
>
()),
index
.
data
<
int
>
(),
reinterpret_cast
<
XPUType
*>
(
x_grad
->
data
<
T
>
()),
xshape
,
index
.
dims
()[
0
],
axis_v
,
overwrite
);
}
else
{
xpu
::
ctx_guard
RAII_GUARD
(
dev_ctx
.
x_context
());
int
*
index_int_ptr_l3
=
RAII_GUARD
.
alloc_l3_or_gm
<
int32_t
>
(
index
.
numel
());
r
=
xpu
::
cast_v2
<
int64_t
,
int32_t
>
(
dev_ctx
.
x_context
(),
index
.
data
<
int64_t
>
(),
index_int_ptr_l3
,
index
.
numel
());
PADDLE_ENFORCE_EQ
(
r
,
XPU_SUCCESS
,
phi
::
errors
::
External
(
"XPU API(cast_v2) return wrong "
"value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
r
=
xpu
::
gather_grad
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
out_grad
.
data
<
T
>
()),
index_int_ptr_l3
,
reinterpret_cast
<
XPUType
*>
(
x_grad
->
data
<
T
>
()),
xshape
,
index
.
dims
()[
0
],
axis_v
,
overwrite
);
}
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
phi
::
errors
::
External
(
"XPU gather grad kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
gather_grad
,
XPU
,
ALL_LAYOUT
,
phi
::
GatherGradKernel
,
float
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/xpu/gather_kernel.cc
0 → 100644
浏览文件 @
60e1eccb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gather_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
GatherKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
index
,
const
Scalar
&
axis
,
DenseTensor
*
out
)
{
auto
axis_v
=
axis
.
to
<
int
>
();
const
auto
&
index_type
=
index
.
dtype
();
dev_ctx
.
template
Alloc
<
T
>(
out
);
if
(
x
.
numel
()
==
0
)
return
;
const
auto
index_dims
=
index
.
dims
();
if
(
index_dims
.
size
()
==
2
)
{
PADDLE_ENFORCE_EQ
(
index_dims
[
1
],
1
,
phi
::
errors
::
InvalidArgument
(
"The last dim of index should be 1 when it is 2D, but we get %d"
,
index_dims
[
1
]));
}
else
{
PADDLE_ENFORCE_EQ
(
index_dims
.
size
(),
1
,
phi
::
errors
::
InvalidArgument
(
"The index should be 1D, when it is not 2D, but we get %d"
,
index_dims
.
size
()));
}
std
::
vector
<
int
>
xshape
(
x
.
dims
().
size
());
for
(
int
i
=
0
;
i
<
x
.
dims
().
size
();
++
i
)
{
xshape
[
i
]
=
x
.
dims
()[
i
];
}
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
int
r
=
XPU_SUCCESS
;
if
(
index_type
==
DataType
::
INT32
)
{
r
=
xpu
::
gather
<
XPUType
,
int
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
index
.
data
<
int
>
(),
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
()),
xshape
,
index
.
dims
()[
0
],
axis_v
);
}
else
{
r
=
xpu
::
gather
<
XPUType
,
int64_t
>
(
dev_ctx
.
x_context
(),
reinterpret_cast
<
const
XPUType
*>
(
x
.
data
<
T
>
()),
index
.
data
<
int64_t
>
(),
reinterpret_cast
<
XPUType
*>
(
out
->
data
<
T
>
()),
xshape
,
index
.
dims
()[
0
],
axis_v
);
}
PADDLE_ENFORCE_EQ
(
r
,
xpu
::
Error_t
::
SUCCESS
,
phi
::
errors
::
External
(
"XPU gather kernel return wrong value[%d %s]"
,
r
,
XPUAPIErrorMsg
[
r
]));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
gather
,
XPU
,
ALL_LAYOUT
,
phi
::
GatherKernel
,
float
,
phi
::
dtype
::
float16
)
{}
paddle/phi/kernels/xpu/gather_nd_kernel.cc
0 → 100644
浏览文件 @
60e1eccb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gather_nd_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
GatherNdKernel
(
const
Context
&
ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
index
,
DenseTensor
*
out
)
{
ctx
.
template
Alloc
<
T
>(
out
);
const
auto
&
index_type
=
index
.
dtype
();
if
(
x
.
numel
()
==
0
)
return
;
if
(
index
.
numel
()
==
0
)
{
phi
::
Copy
(
ctx
,
x
,
phi
::
XPUPlace
(),
true
,
out
);
return
;
}
bool
index_type_match
=
index_type
==
DataType
::
INT32
||
index_type
==
DataType
::
INT64
;
PADDLE_ENFORCE_EQ
(
index_type_match
,
true
,
phi
::
errors
::
InvalidArgument
(
"Index holds the wrong type, it holds [%s],"
"but desires to be [%s] or [%s]"
,
index_type
,
DataType
::
INT32
,
DataType
::
INT64
));
auto
x_shape
=
phi
::
vectorize
<
int
>
(
x
.
dims
());
auto
index_shape
=
phi
::
vectorize
<
int
>
(
index
.
dims
());
if
(
index_shape
.
size
()
==
1
)
{
index_shape
.
insert
(
index_shape
.
begin
(),
1
);
}
xpu
::
VectorParam
<
int
>
x_vec
=
{
x_shape
.
data
(),
static_cast
<
int
>
(
x_shape
.
size
()),
nullptr
};
int
ret
=
XPU_SUCCESS
;
if
(
index_type
==
DataType
::
INT32
)
{
ret
=
xpu
::
gather_nd
<
T
,
int
>
(
ctx
.
x_context
(),
x
.
data
<
T
>
(),
index
.
data
<
int
>
(),
out
->
data
<
T
>
(),
x_vec
,
index_shape
);
}
else
{
ret
=
xpu
::
gather_nd
<
T
,
int64_t
>
(
ctx
.
x_context
(),
x
.
data
<
T
>
(),
index
.
data
<
int64_t
>
(),
out
->
data
<
T
>
(),
x_vec
,
index_shape
);
}
PADDLE_ENFORCE_EQ
(
ret
,
XPU_SUCCESS
,
phi
::
errors
::
External
(
"XPU gather_nd kernel return wrong value[%d %s]"
,
ret
,
XPUAPIErrorMsg
[
ret
]));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
gather_nd
,
XPU
,
ALL_LAYOUT
,
phi
::
GatherNdKernel
,
float
,
int64_t
,
int
)
{}
paddle/phi/kernels/xpu/gaussian_random_kernel.cc
0 → 100644
浏览文件 @
60e1eccb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gaussian_random_kernel.h"
#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
GaussianRandomKernel
(
const
Context
&
ctx
,
const
IntArray
&
shape
,
float
mean
,
float
std
,
int
seed
,
DataType
dtype
,
DenseTensor
*
out
)
{
std
::
normal_distribution
<
T
>
dist
(
mean
,
std
);
int64_t
size
=
out
->
numel
();
ctx
.
template
Alloc
<
T
>(
out
);
auto
*
data
=
out
->
data
();
uint64_t
seed_v
=
static_cast
<
uint64_t
>
(
seed
);
// TODO(pangyoki): implement GetXPURandomEngine to set different seeds on
// corresponding XPU device.
auto
engine
=
paddle
::
framework
::
GetCPURandomEngine
(
seed_v
);
std
::
unique_ptr
<
T
[]
>
data_cpu
(
new
T
[
size
]);
for
(
int64_t
i
=
0
;
i
<
size
;
++
i
)
{
data_cpu
[
i
]
=
dist
(
*
engine
);
}
paddle
::
memory
::
Copy
(
phi
::
XPUPlace
(),
data
,
phi
::
CPUPlace
(),
reinterpret_cast
<
void
*>
(
data_cpu
.
get
()),
size
*
sizeof
(
T
));
}
}
// namespace phi
PD_REGISTER_KERNEL
(
gaussian_random
,
XPU
,
ALL_LAYOUT
,
phi
::
GaussianRandomKernel
,
float
)
{}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录