Skip to content
体验新版
项目
组织
正在加载...
登录
切换导航
打开侧边栏
Crayon鑫
Paddle
提交
2c89bccb
P
Paddle
项目概览
Crayon鑫
/
Paddle
与 Fork 源项目一致
Fork自
PaddlePaddle / Paddle
通知
1
Star
1
Fork
0
代码
文件
提交
分支
Tags
贡献者
分支图
Diff
Issue
1
列表
看板
标记
里程碑
合并请求
0
Wiki
0
Wiki
分析
仓库
DevOps
项目成员
Pages
P
Paddle
项目概览
项目概览
详情
发布
仓库
仓库
文件
提交
分支
标签
贡献者
分支图
比较
Issue
1
Issue
1
列表
看板
标记
里程碑
合并请求
0
合并请求
0
Pages
分析
分析
仓库分析
DevOps
Wiki
0
Wiki
成员
成员
收起侧边栏
关闭侧边栏
动态
分支图
创建新Issue
提交
Issue看板
未验证
提交
2c89bccb
编写于
8月 26, 2022
作者:
R
Ruibiao Chen
提交者:
GitHub
8月 26, 2022
浏览文件
操作
浏览文件
下载
电子邮件补丁
差异文件
Move grid_sample XPU kernel to PHI, test=kunlun (#45425)
上级
1f1a7835
变更
2
显示空白变更内容
内联
并排
Showing
2 changed file
with
113 addition
and
138 deletion
+113
-138
paddle/fluid/operators/grid_sampler_op_xpu.cc
paddle/fluid/operators/grid_sampler_op_xpu.cc
+0
-138
paddle/phi/kernels/xpu/grid_sample_kernel.cc
paddle/phi/kernels/xpu/grid_sample_kernel.cc
+113
-0
未找到文件。
paddle/fluid/operators/grid_sampler_op_xpu.cc
已删除
100644 → 0
浏览文件 @
1f1a7835
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include <memory>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace
paddle
{
namespace
operators
{
using
Tensor
=
framework
::
Tensor
;
template
<
typename
DeviceContext
,
typename
T
>
class
GridSamplerXPUKernel
:
public
framework
::
OpKernel
<
T
>
{
using
XPUType
=
typename
XPUTypeTrait
<
T
>::
Type
;
public:
void
Compute
(
const
framework
::
ExecutionContext
&
context
)
const
override
{
PADDLE_ENFORCE_EQ
(
platform
::
is_xpu_place
(
context
.
GetPlace
()),
true
,
platform
::
errors
::
Unavailable
(
"This kernel only runs on XPU."
));
// input and output data
const
Tensor
*
input
=
context
.
Input
<
Tensor
>
(
"X"
);
const
Tensor
*
grid
=
context
.
Input
<
Tensor
>
(
"Grid"
);
Tensor
*
output
=
context
.
Output
<
Tensor
>
(
"Output"
);
int
n
=
input
->
dims
()[
0
];
int
c
=
input
->
dims
()[
1
];
int
h
=
input
->
dims
()[
2
];
int
w
=
input
->
dims
()[
3
];
int
out_h
=
grid
->
dims
()[
1
];
int
out_w
=
grid
->
dims
()[
2
];
// attrs
// paddle.nn.functional.grid_sample(x, grid, mode='bilinear',
// padding_mode='zeros', align_corners=True, name=None)
const
std
::
string
mode
=
context
.
Attr
<
std
::
string
>
(
"mode"
);
const
std
::
string
padding_mode
=
context
.
Attr
<
std
::
string
>
(
"padding_mode"
);
bool
align_corners_bool
=
context
.
Attr
<
bool
>
(
"align_corners"
);
const
std
::
string
data_format
=
paddle
::
framework
::
DataLayoutToString
(
input
->
layout
());
// attr to real param
bool
is_nearest_bool
;
if
(
mode
==
"bilinear"
)
{
is_nearest_bool
=
false
;
}
else
if
(
mode
==
"nearest"
)
{
is_nearest_bool
=
true
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"should not reach here: mode should be either 'bilinear' or "
"'nearest', bot got %s."
,
mode
));
}
// attention: 0: zeros, 2: reflection, 1: border according to XDNN api.
int
padding_mode_int
;
if
(
padding_mode
==
"zeros"
)
{
padding_mode_int
=
0
;
}
else
if
(
padding_mode
==
"reflection"
)
{
padding_mode_int
=
2
;
}
else
if
(
padding_mode
==
"border"
)
{
padding_mode_int
=
1
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"should not reach here: padding_mode should be either 'zeros' or "
"'reflection' or 'border', bot got %s."
,
padding_mode
));
}
bool
is_nchw_bool
;
if
(
data_format
==
"NCHW"
)
{
is_nchw_bool
=
true
;
}
else
if
(
data_format
==
"NHWC"
)
{
is_nchw_bool
=
false
;
}
else
{
PADDLE_THROW
(
platform
::
errors
::
InvalidArgument
(
"should not reach here: data_format should be either 'NCHW' or "
"'NHWC', bot got %s."
,
data_format
));
}
// data pointers
const
T
*
input_data
=
input
->
data
<
T
>
();
const
T
*
grid_data
=
grid
->
data
<
T
>
();
T
*
output_data
=
output
->
mutable_data
<
T
>
({
n
,
c
,
out_h
,
out_w
},
context
.
GetPlace
());
auto
&
dev_ctx
=
context
.
template
device_context
<
DeviceContext
>();
// int grid_sample(Context* ctx, const T* x, const T* grid, T* y, int n, int
// c, int xh, int xw, int yh, int yw, bool is_nearest, bool align_corners,
// int padding_mode, bool is_nchw);
int
r
=
xpu
::
grid_sample
(
dev_ctx
.
x_context
(),
input_data
,
grid_data
,
output_data
,
n
,
c
,
h
,
w
,
out_h
,
out_w
,
is_nearest_bool
,
align_corners_bool
,
padding_mode_int
,
is_nchw_bool
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"grid_sampler"
);
}
};
}
// namespace operators
}
// namespace paddle
namespace
ops
=
paddle
::
operators
;
REGISTER_OP_XPU_KERNEL
(
grid_sampler
,
ops
::
GridSamplerXPUKernel
<
paddle
::
platform
::
XPUDeviceContext
,
float
>
);
#endif
paddle/phi/kernels/xpu/grid_sample_kernel.cc
0 → 100644
浏览文件 @
2c89bccb
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/grid_sample_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/common/layout.h"
#include "paddle/phi/core/kernel_registry.h"
namespace
phi
{
template
<
typename
T
,
typename
Context
>
void
GridSampleKernel
(
const
Context
&
dev_ctx
,
const
DenseTensor
&
x
,
const
DenseTensor
&
grid
,
const
std
::
string
&
mode
,
const
std
::
string
&
padding_mode
,
bool
align_corners
,
DenseTensor
*
out
)
{
int
n
=
x
.
dims
()[
0
];
int
c
=
x
.
dims
()[
1
];
int
h
=
x
.
dims
()[
2
];
int
w
=
x
.
dims
()[
3
];
int
out_h
=
grid
.
dims
()[
1
];
int
out_w
=
grid
.
dims
()[
2
];
// attrs
// paddle.nn.functional.grid_sample(x, grid, mode='bilinear',
// padding_mode='zeros', align_corners=True, name=None)
const
std
::
string
data_format
=
paddle
::
framework
::
DataLayoutToString
(
x
.
layout
());
// attr to real param
bool
is_nearest_bool
;
if
(
mode
==
"bilinear"
)
{
is_nearest_bool
=
false
;
}
else
if
(
mode
==
"nearest"
)
{
is_nearest_bool
=
true
;
}
else
{
PADDLE_THROW
(
errors
::
InvalidArgument
(
"should not reach here: mode should be either 'bilinear' or "
"'nearest', bot got %s."
,
mode
));
}
// attention: 0: zeros, 2: reflection, 1: border according to XDNN api.
int
padding_mode_int
;
if
(
padding_mode
==
"zeros"
)
{
padding_mode_int
=
0
;
}
else
if
(
padding_mode
==
"reflection"
)
{
padding_mode_int
=
2
;
}
else
if
(
padding_mode
==
"border"
)
{
padding_mode_int
=
1
;
}
else
{
PADDLE_THROW
(
errors
::
InvalidArgument
(
"should not reach here: padding_mode should be either 'zeros' or "
"'reflection' or 'border', bot got %s."
,
padding_mode
));
}
bool
is_nchw_bool
;
if
(
data_format
==
"NCHW"
)
{
is_nchw_bool
=
true
;
}
else
if
(
data_format
==
"NHWC"
)
{
is_nchw_bool
=
false
;
}
else
{
PADDLE_THROW
(
errors
::
InvalidArgument
(
"should not reach here: data_format should be either 'NCHW' or "
"'NHWC', bot got %s."
,
data_format
));
}
// data pointers
const
T
*
input_data
=
x
.
data
<
T
>
();
const
T
*
grid_data
=
grid
.
data
<
T
>
();
out
->
Resize
(
make_ddim
({
n
,
c
,
out_h
,
out_w
}));
T
*
output_data
=
dev_ctx
.
template
Alloc
<
T
>(
out
);
// int grid_sample(Context* ctx, const T* x, const T* grid, T* y, int n, int
// c, int xh, int xw, int yh, int yw, bool is_nearest, bool align_corners,
// int padding_mode, bool is_nchw);
int
r
=
xpu
::
grid_sample
(
dev_ctx
.
x_context
(),
input_data
,
grid_data
,
output_data
,
n
,
c
,
h
,
w
,
out_h
,
out_w
,
is_nearest_bool
,
align_corners
,
padding_mode_int
,
is_nchw_bool
);
PADDLE_ENFORCE_XDNN_SUCCESS
(
r
,
"grid_sampler"
);
}
}
// namespace phi
PD_REGISTER_KERNEL
(
grid_sample
,
XPU
,
ALL_LAYOUT
,
phi
::
GridSampleKernel
,
float
)
{
}
编辑
预览
Markdown
is supported
0%
请重试
或
添加新附件
.
添加附件
取消
You are about to add
0
people
to the discussion. Proceed with caution.
先完成此消息的编辑!
取消
想要评论请
注册
或
登录