未验证 提交 e5021ee9 编写于 作者: W wangshengxiang 提交者: GitHub

[XPU] bind 3D grid sample, fix edge cases in slice & reshape (#53981)

* bind xpu op: 3D grid sample

* fix edge cases in xpu op: reshape & slice
上级 83a12b11
......@@ -8,7 +8,7 @@ set(XPU_API_LIB_NAME "libxpuapi.so")
set(XPU_RT_LIB_NAME "libxpurt.so")
set(XPU_XFT_LIB_NAME "libxft.so")
set(XPU_BASE_DATE "20230510")
set(XPU_BASE_DATE "20230519")
set(XPU_XCCL_BASE_VERSION "1.0.49.2")
set(XPU_XFT_BASE_VERSION "latest")
......
......@@ -57,6 +57,9 @@ void ReshapeInferKernel<phi::XPUContext>(const XPUContext& dev_ctx,
DenseTensor* out) {
MetaTensor meta_out(out);
InferMetaFromVecValue(x, shape.GetData(), &meta_out);
if (x.numel() == 0) {
return;
}
if (x.initialized() && x.Holder() == out->Holder()) {
dev_ctx.Alloc(out, x.dtype());
return;
......
......@@ -28,13 +28,6 @@ void GridSampleKernel(const Context& dev_ctx,
const std::string& padding_mode,
bool align_corners,
DenseTensor* out) {
int n = x.dims()[0];
int c = x.dims()[1];
int h = x.dims()[2];
int w = x.dims()[3];
int out_h = grid.dims()[1];
int out_w = grid.dims()[2];
// attrs
// paddle.nn.functional.grid_sample(x, grid, mode='bilinear',
// padding_mode='zeros', align_corners=True, name=None)
......@@ -68,42 +61,77 @@ void GridSampleKernel(const Context& dev_ctx,
padding_mode));
}
bool is_nchw_bool;
if (data_format == "NCHW") {
is_nchw_bool = true;
} else if (data_format == "NHWC") {
is_nchw_bool = false;
} else {
PADDLE_THROW(errors::InvalidArgument(
"should not reach here: data_format should be either 'NCHW' or "
"'NHWC', bot got %s.",
data_format));
}
// data pointers
const T* input_data = x.data<T>();
const T* grid_data = grid.data<T>();
out->Resize(make_ddim({n, c, out_h, out_w}));
T* output_data = dev_ctx.template Alloc<T>(out);
// int grid_sample(Context* ctx, const T* x, const T* grid, T* y, int n, int
// c, int xh, int xw, int yh, int yw, bool is_nearest, bool align_corners,
// int padding_mode, bool is_nchw);
int r = xpu::grid_sample(dev_ctx.x_context(),
input_data,
grid_data,
output_data,
n,
c,
h,
w,
out_h,
out_w,
is_nearest_bool,
align_corners,
padding_mode_int,
is_nchw_bool);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "grid_sampler");
int n = x.dims()[0];
int c = x.dims()[1];
if (x.dims().size() == 4) { // 2D grid sample
int h = x.dims()[2];
int w = x.dims()[3];
int out_h = grid.dims()[1];
int out_w = grid.dims()[2];
bool is_nchw_bool;
if (data_format == "NCHW") {
is_nchw_bool = true;
} else if (data_format == "NHWC") {
is_nchw_bool = false;
} else {
PADDLE_THROW(errors::InvalidArgument(
"should not reach here: data_format should be either 'NCHW' or "
"'NHWC', bot got %s.",
data_format));
}
out->Resize(make_ddim({n, c, out_h, out_w}));
T* output_data = dev_ctx.template Alloc<T>(out);
int r = xpu::grid_sample(dev_ctx.x_context(),
input_data,
grid_data,
output_data,
n,
c,
h,
w,
out_h,
out_w,
is_nearest_bool,
align_corners,
padding_mode_int,
is_nchw_bool);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "grid_sampler");
} else { // 3D grid sample
int d = x.dims()[2];
int h = x.dims()[3];
int w = x.dims()[4];
int out_d = grid.dims()[1];
int out_h = grid.dims()[2];
int out_w = grid.dims()[3];
out->Resize(make_ddim({n, c, out_d, out_h, out_w}));
T* output_data = dev_ctx.template Alloc<T>(out);
int r = xpu::grid_sample3d(dev_ctx.x_context(),
input_data,
grid_data,
output_data,
n,
c,
d,
h,
w,
out_d,
out_h,
out_w,
is_nearest_bool,
align_corners,
padding_mode_int,
true);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "grid_sampler3d");
}
}
} // namespace phi
......
......@@ -97,6 +97,12 @@ void SliceKernel(const Context& ctx,
}
ctx.template Alloc<T>(out);
for (size_t i = 0; i < shape_size; ++i) {
if (starts_extension[i] == ends_extension[i] || shape[i] == 0) {
return;
}
}
int r = xpu::slice<XPUType>(ctx.x_context(),
reinterpret_cast<const XPUType*>(input.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
......
......@@ -77,6 +77,71 @@ def getGridPointValue(data, x, y):
return out
def AffineGrid3D(theta, grid_shape):
n = grid_shape[0]
d = grid_shape[1]
h = grid_shape[2]
w = grid_shape[3]
d_idx = np.repeat(
np.repeat(np.linspace(-1, 1, d)[:, np.newaxis, np.newaxis], h, axis=1),
w,
axis=2,
)[:, :, :, np.newaxis]
h_idx = np.repeat(
np.repeat(np.linspace(-1, 1, h)[np.newaxis, :, np.newaxis], w, axis=2),
d,
axis=0,
)[:, :, :, np.newaxis]
w_idx = np.repeat(
np.repeat(np.linspace(-1, 1, w)[np.newaxis, np.newaxis, :], h, axis=1),
d,
axis=0,
)[:, :, :, np.newaxis]
grid = np.concatenate(
[w_idx, h_idx, d_idx, np.ones([d, h, w, 1])], axis=3
) # d * h * w * 4
grid = np.repeat(grid[np.newaxis, :], n, axis=0) # n * d * h * w *4
ret = np.zeros([n, d * h * w, 3])
theta = theta.transpose([0, 2, 1])
for i in range(len(theta)):
ret[i] = np.dot(grid[i].reshape([d * h * w, 4]), theta[i])
return ret.reshape([n, d, h, w, 3]).astype("float64")
def getGridPointValue3D(data, x, y, z):
data_shape = data.shape
N = data_shape[0]
C = data_shape[1]
in_D = data_shape[2]
in_H = data_shape[3]
in_W = data_shape[4]
out_D = x.shape[1]
out_H = x.shape[2]
out_W = x.shape[3]
out = np.zeros([N, C, out_D, out_H, out_W], dtype='float64')
for i in range(N):
for j in range(out_D):
for k in range(out_H):
for l in range(out_W):
if (
y[i, j, k, l] < 0
or y[i, j, k, l] > in_H - 1
or x[i, j, k, l] < 0
or x[i, j, k, l] > in_W - 1
or z[i, j, k, l] < 0
or z[i, j, k, l] > in_D - 1
):
out[i, :, j, k, l] = 0
else:
out[i, :, j, k, l] = data[
i, :, z[i, j, k, l], y[i, j, k, l], x[i, j, k, l]
]
return out
def clip(x, min_n, max_n):
return np.maximum(np.minimum(x, max_n), min_n)
......@@ -156,6 +221,117 @@ def GridSampler(
return out
def GridSampler3D(
data, grid, align_corners=True, mode="bilinear", padding_mode="zeros"
):
dims = data.shape
N = dims[0]
in_C = dims[1]
in_D = dims[2]
in_H = dims[3]
in_W = dims[4]
out_D = grid.shape[1]
out_H = grid.shape[2]
out_W = grid.shape[3]
x = grid[:, :, :, :, 0]
y = grid[:, :, :, :, 1]
z = grid[:, :, :, :, 2]
z_max = in_D - 1
y_max = in_H - 1
x_max = in_W - 1
x = unnormalizeAndClip(x, x_max, align_corners, padding_mode)
y = unnormalizeAndClip(y, y_max, align_corners, padding_mode)
z = unnormalizeAndClip(z, z_max, align_corners, padding_mode)
if mode == "bilinear":
x0 = np.floor(x).astype('int32')
x1 = x0 + 1
y0 = np.floor(y).astype('int32')
y1 = y0 + 1
z0 = np.floor(z).astype('int32')
z1 = z0 + 1
w_tnw = np.tile(
((x1 - x) * (y1 - y) * (z1 - z)).reshape(
(N, 1, out_D, out_H, out_W)
),
(1, in_C, 1, 1, 1),
)
w_tne = np.tile(
((x - x0) * (y1 - y) * (z1 - z)).reshape(
(N, 1, out_D, out_H, out_W)
),
(1, in_C, 1, 1, 1),
)
w_tsw = np.tile(
((x1 - x) * (y - y0) * (z1 - z)).reshape(
(N, 1, out_D, out_H, out_W)
),
(1, in_C, 1, 1, 1),
)
w_tse = np.tile(
((x - x0) * (y - y0) * (z1 - z)).reshape(
(N, 1, out_D, out_H, out_W)
),
(1, in_C, 1, 1, 1),
)
w_bnw = np.tile(
((x1 - x) * (y1 - y) * (z - z0)).reshape(
(N, 1, out_D, out_H, out_W)
),
(1, in_C, 1, 1, 1),
)
w_bne = np.tile(
((x - x0) * (y1 - y) * (z - z0)).reshape(
(N, 1, out_D, out_H, out_W)
),
(1, in_C, 1, 1, 1),
)
w_bsw = np.tile(
((x1 - x) * (y - y0) * (z - z0)).reshape(
(N, 1, out_D, out_H, out_W)
),
(1, in_C, 1, 1, 1),
)
w_bse = np.tile(
((x - x0) * (y - y0) * (z - z0)).reshape(
(N, 1, out_D, out_H, out_W)
),
(1, in_C, 1, 1, 1),
)
v_tnw = getGridPointValue3D(data, x0, y0, z0)
v_tne = getGridPointValue3D(data, x1, y0, z0)
v_tsw = getGridPointValue3D(data, x0, y1, z0)
v_tse = getGridPointValue3D(data, x1, y1, z0)
v_bnw = getGridPointValue3D(data, x0, y0, z1)
v_bne = getGridPointValue3D(data, x1, y0, z1)
v_bsw = getGridPointValue3D(data, x0, y1, z1)
v_bse = getGridPointValue3D(data, x1, y1, z1)
out = (
w_tnw * v_tnw
+ w_tne * v_tne
+ w_tsw * v_tsw
+ w_tse * v_tse
+ w_bnw * v_bnw
+ w_bne * v_bne
+ w_bsw * v_bsw
+ w_bse * v_bse
).astype('float64')
elif mode == "nearest":
x = np.round(x).astype('int32')
y = np.round(y).astype('int32')
z = np.round(z).astype('int32')
out = getGridPointValue3D(data, x, y, z)
return out
class XPUTestGridSamplerOP(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'grid_sampler'
......@@ -178,24 +354,52 @@ class XPUTestGridSamplerOP(XPUOpTestWrapper):
x = np.random.uniform(-10, 10, self.x_shape).astype(self.dtype)
theta = np.zeros(self.theta_shape).astype(self.dtype)
for i in range(self.theta_shape[0]):
for j in range(2):
for k in range(3):
theta[i, j, k] = np.random.rand(1)[0]
grid = AffineGrid(theta, self.grid_shape).astype(self.dtype)
self.inputs = {'X': x, 'Grid': grid}
self.attrs = {
'use_cudnn': self.use_cudnn,
"align_corners": self.align_corners,
"padding_mode": self.padding_mode,
"mode": self.mode,
}
self.outputs = {
'Output': GridSampler(
x, grid, self.align_corners, self.mode, self.padding_mode
)
}
if len(self.grid_shape) == 4:
for i in range(self.theta_shape[0]):
for j in range(2):
for k in range(3):
theta[i, j, k] = np.random.rand(1)[0]
grid = AffineGrid(theta, self.grid_shape).astype(self.dtype)
self.inputs = {'X': x, 'Grid': grid}
self.attrs = {
'use_cudnn': self.use_cudnn,
"align_corners": self.align_corners,
"padding_mode": self.padding_mode,
"mode": self.mode,
}
self.outputs = {
'Output': GridSampler(
x,
grid,
self.align_corners,
self.mode,
self.padding_mode,
)
}
else:
for i in range(self.theta_shape[0]):
for j in range(3):
for k in range(4):
theta[i, j, k] = np.random.rand(1)[0]
grid = AffineGrid3D(theta, self.grid_shape)
self.inputs = {'X': x, 'Grid': grid}
self.attrs = {
'use_cudnn': self.use_cudnn,
"align_corners": self.align_corners,
"padding_mode": self.padding_mode,
"mode": self.mode,
}
self.outputs = {
'Output': GridSampler3D(
x,
grid,
self.align_corners,
self.mode,
self.padding_mode,
)
}
def initTestCase(self):
self.x_shape = (2, 3, 8, 8)
......@@ -212,6 +416,8 @@ class XPUTestGridSamplerOP(XPUOpTestWrapper):
self.check_output_with_place(self.place)
def test_check_grad(self):
if hasattr(self, "no_need_check_grad") and self.no_need_check_grad:
return
self.check_grad_with_place(self.place, ['X', 'Grid'], 'Output')
class TestGridSample1(TestXPUGridSamplerOp):
......@@ -277,6 +483,62 @@ class XPUTestGridSamplerOP(XPUOpTestWrapper):
self.padding_mode = "zeros"
self.mode = "bilinear"
# 3d grid_sample_grad is not supported yet
class TestGridSample3DBilinear(TestXPUGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7)
self.grid_shape = (2, 8, 9, 10, 3)
self.theta_shape = (2, 3, 4)
self.align_corners = True
self.padding_mode = "zeros"
self.mode = "bilinear"
self.no_need_check_grad = True
class TestGridSample3DNearest(TestXPUGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7)
self.grid_shape = (2, 8, 9, 10, 3)
self.theta_shape = (2, 3, 4)
self.align_corners = True
self.padding_mode = "zeros"
self.mode = "nearest"
self.no_need_check_grad = True
class TestGridSample3DBorder(TestXPUGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7)
self.grid_shape = (2, 8, 9, 10, 3)
self.theta_shape = (2, 3, 4)
self.align_corners = True
self.padding_mode = "border"
self.mode = "nearest"
self.no_need_check_grad = True
class TestGridSample3DReflection(TestXPUGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7)
self.grid_shape = (2, 8, 9, 10, 3)
self.theta_shape = (2, 3, 4)
self.align_corners = True
self.padding_mode = "reflection"
self.mode = "bilinear"
self.no_need_check_grad = True
class TestGridSample3DAlignCornersFalse(TestXPUGridSamplerOp):
def initTestCase(self):
self.x_shape = (2, 3, 5, 6, 7)
self.grid_shape = (2, 8, 9, 10, 3)
self.theta_shape = (2, 3, 4)
self.align_corners = False
self.padding_mode = "reflection"
self.mode = "bilinear"
self.no_need_check_grad = True
support_types = get_xpu_op_support_types('grid_sampler')
for stype in support_types:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册