Unverified commit 1281b612, authored by ykkk2333, committed by GitHub

add xpu pool3d kernels (#50233)

* add xpu adagrad and where_grad kernels, test=kunlun

* add xpu pool3d kernels, test=kunlun
Parent 9c24a4ac
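For context, here is a minimal usage sketch of the op this commit wires up. It assumes a PaddlePaddle build with XPU support and an attached Kunlun device; the Python API itself is backend-agnostic:

import paddle
import paddle.nn.functional as F

paddle.set_device('xpu')  # assumes an XPU build; use 'cpu' to try the API elsewhere
x = paddle.rand([1, 3, 8, 16, 16], dtype='float32')  # NCDHW input
y = F.max_pool3d(x, kernel_size=2, stride=2, padding=0)
print(y.shape)  # [1, 3, 4, 8, 8]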
@@ -628,10 +628,11 @@ class ReduceOp : public framework::OperatorWithKernel {
platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace()) ||
platform::is_mlu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace()) ||
platform::is_custom_place(ctx.GetPlace()),
true,
platform::errors::InvalidArgument(
"float16 can only be used on GPU or NPU or MLU place"));
"float16 can only be used on GPU or NPU or MLU or XPU place"));
}
return phi::KernelKey(input_data_type, ctx.GetPlace());
}
......
@@ -436,6 +436,10 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"pool2d",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"pool3d_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"pool3d",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"pow", XPUKernelSet({phi::DataType::FLOAT32})},
{"pow_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"pow2_decay_with_linear_warmup", XPUKernelSet({phi::DataType::FLOAT32})},
@@ -456,9 +460,10 @@ XPUOpMap& get_kl2_ops() {
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_sum",
     XPUKernelSet({phi::DataType::FLOAT16,
                   phi::DataType::INT64,
                   phi::DataType::INT8,
                   phi::DataType::FLOAT32})},
{"relu6", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu6_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"relu_grad",
......
@@ -83,8 +83,14 @@ PD_REGISTER_KERNEL(
#endif
#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(sum,
                   XPU,
                   ALL_LAYOUT,
                   phi::SumKernel,
                   float,
                   phi::dtype::float16,
                   int8_t,
                   int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
#endif
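With float16 added to the registration above, a quick sanity check of the XPU sum kernel could look like this (a sketch, assuming an XPU device is available):

import paddle

paddle.set_device('xpu')  # assumes an XPU build
x = paddle.ones([2, 3], dtype='float16')
print(paddle.sum(x, axis=1))  # [3., 3.] in float16, via phi::SumKernel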
@@ -189,8 +189,10 @@ struct XPULogGradFunctor : public funcs::BaseActivationFunctor<T> {
if (dOut != nullptr) dout_data = dOut->data<T>();
T* dx_data = dev_ctx.template Alloc<T>(dX);
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
float* tmp = RAII_GUARD.alloc_l3_or_gm<T>(x->numel());
int r = xpu::constant<T>(
      dev_ctx.x_context(), tmp, x->numel(), static_cast<T>(1.0));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
auto x_dims = vectorize<int>(x->dims());
@@ -199,18 +201,16 @@ struct XPULogGradFunctor : public funcs::BaseActivationFunctor<T> {
if (x_dims.size() == 0) {
x_dims = std::vector<int>({1});
}
// dx.device(d) = dout * (static_cast<T>(1) / x);
r = xpu::broadcast_div(dev_ctx.x_context(),
                       reinterpret_cast<const float*>(tmp),
reinterpret_cast<const float*>(x_data),
                       reinterpret_cast<float*>(tmp),
x_dims,
x_dims);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_div");
r = xpu::broadcast_mul(dev_ctx.x_context(),
                       reinterpret_cast<const float*>(tmp),
reinterpret_cast<const float*>(dout_data),
reinterpret_cast<float*>(dx_data),
x_dims,
......
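The rewritten functor computes dx = dout * (1 / x) through a scratch buffer instead of using dx_data as both input and output of the same XDNN call. A NumPy reference for the same math, independent of the XPU code path:

import numpy as np

x = np.array([0.5, 1.0, 2.0], dtype=np.float32)
dout = np.ones_like(x)
dx = dout * (1.0 / x)  # constant(1) -> broadcast_div -> broadcast_mul
print(dx)              # [2.  1.  0.5]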
@@ -180,7 +180,6 @@ void AdamDenseKernel(const Context& dev_ctx,
epsilon_,
param.numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
funcs::FreeData<float>(grad, grad_c);
@@ -213,7 +212,6 @@ void AdamDenseKernel(const Context& dev_ctx,
false,
beta1_,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
@@ -231,7 +229,6 @@ void AdamDenseKernel(const Context& dev_ctx,
false,
beta2_,
0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adam");
}
}
......
@@ -95,51 +95,236 @@ void Pool2dGradKernel(const Context& ctx,
const int* index_data = nullptr;
int r = xpu::Error_t::SUCCESS;
if (adaptive) {
// floor for stride
strides = {in_h / out_h, in_w / out_w};
int kh = in_h - (out_h - 1) * strides[0];
int kw = in_w - (out_w - 1) * strides[1];
kernel_size = {kh, kw};
paddings = {0, 0, 0, 0};
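      // e.g. in_h = 7, out_h = 3 gives stride floor(7 / 3) = 2 and kernel
      // 7 - (3 - 1) * 2 = 3, so the windows cover the whole input.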
if (pooling_type == "max") {
r = xpu::adaptive_max_pool2d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(out.data<T>()),
index_data,
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_h,
in_w,
out_h,
out_w,
true);
} else if (pooling_type == "avg") {
r = xpu::adaptive_avg_pool2d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_h,
in_w,
out_h,
out_w,
true);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool2d_grad");
} else {
if (pooling_type == "max") {
// TODO(zhanghuan05) to bind max_pool2d_grad_indices xpu api
r = xpu::max_pool2d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(out.data<T>()),
index_data,
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_h,
in_w,
kernel_size,
strides,
paddings,
true);
} else if (pooling_type == "avg") {
r = xpu::avg_pool2d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(out.data<T>()),
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_h,
in_w,
kernel_size,
strides,
paddings,
!exclusive,
true);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2dgrad");
}
}
template <typename T, typename Context>
void Pool3dGradKernel(const Context& ctx,
const DenseTensor& x,
const DenseTensor& out,
const DenseTensor& dout,
const std::vector<int>& kernel_size_t,
const std::vector<int>& strides_t,
const std::vector<int>& paddings_t,
bool ceil_mode,
bool exclusive,
const std::string& data_format,
const std::string& pooling_type,
bool global_pooling,
bool adaptive,
const std::string& padding_algorithm,
DenseTensor* dx) {
using XPUType = typename XPUTypeTrait<T>::Type;
auto x_dims = x.dims();
const bool channel_last = data_format == "NDHWC";
std::vector<int> paddings(paddings_t);
std::vector<int> kernel_size(kernel_size_t);
std::vector<int> strides(strides_t);
PADDLE_ENFORCE_EQ(
data_format,
"NCDHW",
phi::errors::InvalidArgument("The Pool3d_grad XPU OP only support"
"data_format is 'NCDHW', but received %s",
data_format));
if (!dx) {
return;
}
int n = x.dims()[0];
int c = x.dims()[1];
int in_d = x.dims()[2];
int in_h = x.dims()[3];
int in_w = x.dims()[4];
int out_d = out.dims()[2];
int out_h = out.dims()[3];
int out_w = out.dims()[4];
if (channel_last) {
c = x.dims()[4];
in_d = x.dims()[1];
in_h = x.dims()[2];
in_w = x.dims()[3];
out_d = out.dims()[1];
out_h = out.dims()[2];
out_w = out.dims()[3];
}
DDim data_dims;
if (channel_last) {
data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1);
} else {
data_dims = slice_ddim(x_dims, 2, x_dims.size());
}
funcs::UpdatePadding(&paddings,
global_pooling,
adaptive,
padding_algorithm,
data_dims,
strides,
kernel_size);
if (global_pooling) {
funcs::UpdateKernelSize(&kernel_size, data_dims);
}
ctx.template Alloc<T>(dx);
const int* index_data = nullptr;
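  // No max-index tensor is produced on this path, so nullptr is passed; see
  // the TODO in Pool2dGradKernel about binding an index-based grad API.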
int r = xpu::Error_t::SUCCESS;
if (adaptive) {
if (pooling_type == "max") {
r = xpu::adaptive_max_pool3d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(out.data<T>()),
index_data,
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_d,
in_h,
in_w,
out_d,
out_h,
out_w,
!channel_last);
} else if (pooling_type == "avg") {
r = xpu::adaptive_avg_pool3d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_d,
in_h,
in_w,
out_d,
out_h,
out_w,
!channel_last);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "adaptive_pool3d_grad");
} else {
if (pooling_type == "max") {
r = xpu::max_pool3d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<const XPUType*>(out.data<T>()),
index_data,
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_d,
in_h,
in_w,
kernel_size,
strides,
paddings,
!channel_last);
} else if (pooling_type == "avg") {
r = xpu::avg_pool3d_grad<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(dout.data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()),
n,
c,
in_d,
in_h,
in_w,
kernel_size,
strides,
paddings,
!exclusive,
!channel_last);
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool3dgrad");
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2dgrad");
}
template <typename T, typename Context>
@@ -210,6 +395,12 @@ PD_REGISTER_KERNEL(pool2d_grad,
phi::Pool2dGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(pool3d_grad,
XPU,
ALL_LAYOUT,
phi::Pool3dGradKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_index_grad,
XPU,
ALL_LAYOUT,
......
@@ -155,6 +155,144 @@ void Pool2dKernel(const Context& ctx,
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool2d");
}
template <typename T, typename Context>
void Pool3dKernel(const Context& ctx,
const DenseTensor& x,
const std::vector<int>& kernel_size_t,
const std::vector<int>& strides,
const std::vector<int>& paddings_t,
bool ceil_mode,
bool exclusive,
const std::string& data_format,
const std::string& pooling_type,
bool global_pooling,
bool adaptive,
const std::string& padding_algorithm,
DenseTensor* out) {
using XPUType = typename XPUTypeTrait<T>::Type;
const bool channel_last = data_format == "NDHWC";
std::vector<int> paddings(paddings_t);
std::vector<int> kernel_size(kernel_size_t);
auto x_dims = x.dims();
int n = x.dims()[0];
int c = x.dims()[1];
int in_d = x.dims()[2];
int in_h = x.dims()[3];
int in_w = x.dims()[4];
int out_d = out->dims()[2];
int out_h = out->dims()[3];
int out_w = out->dims()[4];
if (data_format == "NDHWC") {
c = x.dims()[4];
in_d = x.dims()[1];
in_h = x.dims()[2];
in_w = x.dims()[3];
out_d = out->dims()[1];
out_h = out->dims()[2];
out_w = out->dims()[3];
}
DDim data_dims;
if (channel_last) {
data_dims = slice_ddim(x_dims, 1, x_dims.size() - 1);
} else {
data_dims = slice_ddim(x_dims, 2, x_dims.size());
}
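  // data_dims now holds only the spatial extent (D, H, W) for either layout,
  // so the padding/kernel updates below are layout-agnostic.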
funcs::UpdatePadding(&paddings,
global_pooling,
adaptive,
padding_algorithm,
data_dims,
strides,
kernel_size);
if (global_pooling) {
funcs::UpdateKernelSize(&kernel_size, data_dims);
}
ctx.template Alloc<T>(out);
int* index_data = nullptr;
int r = xpu::Error_t::SUCCESS;
if (!adaptive) {
if (pooling_type == "max") {
r = xpu::max_pool3d<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
index_data,
n,
c,
in_d,
in_h,
in_w,
kernel_size,
strides,
paddings,
data_format == "NCDHW");
} else if (pooling_type == "avg") {
r = xpu::avg_pool3d<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
n,
c,
in_d,
in_h,
in_w,
kernel_size,
strides,
paddings,
!exclusive,
data_format == "NCDHW");
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
} else {
if (pooling_type == "max") {
r = xpu::adaptive_max_pool3d<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
index_data,
n,
c,
in_d,
in_h,
in_w,
out_d,
out_h,
out_w,
data_format == "NCDHW");
} else if (pooling_type == "avg") {
r = xpu::adaptive_avg_pool3d<XPUType>(
ctx.x_context(),
reinterpret_cast<const XPUType*>(x.data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()),
n,
c,
in_d,
in_h,
in_w,
out_d,
out_h,
out_w,
data_format == "NCDHW");
} else {
PADDLE_THROW(phi::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
}
PADDLE_ENFORCE_XDNN_SUCCESS(r, "pool3d");
}
template <typename T, typename Context>
void MaxPool2dWithIndexKernel(const Context& ctx,
const DenseTensor& x,
@@ -216,6 +354,8 @@ void MaxPool2dWithIndexKernel(const Context& ctx,
PD_REGISTER_KERNEL(
pool2d, XPU, ALL_LAYOUT, phi::Pool2dKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
pool3d, XPU, ALL_LAYOUT, phi::Pool3dKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(max_pool2d_with_index,
XPU,
......
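A short sketch of exercising the newly registered forward kernel through the public API (assumes an XPU device; adaptive pooling dispatches to the adaptive_*_pool3d branch above):

import paddle
import paddle.nn.functional as F

paddle.set_device('xpu')  # assumes an XPU build
x = paddle.rand([1, 3, 4, 8, 8])                     # NCDHW
y = F.adaptive_avg_pool3d(x, output_size=[2, 4, 4])  # adaptive branch
print(y.shape)                                       # [1, 3, 2, 4, 4]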
@@ -32,11 +32,15 @@ void ProdKernel(const Context& dev_ctx,
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
              const T* x,
              T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
    return xpu::reduce_prod<XPUType>(ctx,
                                     reinterpret_cast<const XPUType*>(x),
                                     reinterpret_cast<XPUType*>(y),
                                     xdims,
                                     reduce_dims);
};
int r = XPUReduce<Context, T>(
......
@@ -33,6 +33,7 @@ int XPUReduce(const Context& dev_ctx,
T*,
const std::vector<int>&,
const std::vector<int>&)> func) {
using XPUType = typename XPUTypeTrait<T>::Type;
reduce_all = recompute_reduce_all(x, dims, reduce_all);
dev_ctx.template Alloc<T>(out);
@@ -70,8 +71,10 @@ int XPUReduce(const Context& dev_ctx,
int r = xpu::SUCCESS;
if (reduce_dims.size() == 0) {
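    // Nothing to reduce: the result is a straight copy of the input.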
    r = xpu::copy<XPUType>(dev_ctx.x_context(),
                           reinterpret_cast<const XPUType*>(x_data),
                           reinterpret_cast<XPUType*>(y_data),
                           x.numel() * sizeof(T));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "copy");
} else {
r = func(dev_ctx.x_context(), x_data, y_data, xdims, reduce_dims);
......
@@ -31,14 +31,18 @@ void MaxRawKernel(const Context& dev_ctx,
reduce_all = recompute_reduce_all(x, dims, reduce_all);
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
              const T* x,
              T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
    return xpu::reduce_max<XPUType>(ctx,
                                    reinterpret_cast<const XPUType*>(x),
                                    reinterpret_cast<XPUType*>(y),
                                    xdims,
                                    reduce_dims);
};
  int r = XPUReduce<Context, T>(
dev_ctx, x, dims.GetData(), keep_dim, reduce_all, out, f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_max");
}
......
@@ -31,14 +31,18 @@ void MeanRawKernel(const Context& dev_ctx,
reduce_all = recompute_reduce_all(x, dims, reduce_all);
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
              const T* x,
              T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
    return xpu::reduce_mean<XPUType>(ctx,
                                     reinterpret_cast<const XPUType*>(x),
                                     reinterpret_cast<XPUType*>(y),
                                     xdims,
                                     reduce_dims);
};
  int r = XPUReduce<Context, T>(
dev_ctx, x, dims.GetData(), keep_dim, reduce_all, out, f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_mean");
......
@@ -32,14 +32,18 @@ void MinRawKernel(const Context& dev_ctx,
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
              const T* x,
              T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
    return xpu::reduce_min<XPUType>(ctx,
                                    reinterpret_cast<const XPUType*>(x),
                                    reinterpret_cast<XPUType*>(y),
                                    xdims,
                                    reduce_dims);
};
  int r = XPUReduce<Context, T>(
dev_ctx, x, dims.GetData(), keep_dim, reduce_all, out, f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_min");
}
......
@@ -33,18 +33,28 @@ void SumRawKernel(const Context& dev_ctx,
using XPUType = typename XPUTypeTrait<T>::Type;
auto f = [](xpu::Context* ctx,
              const T* x,
              T* y,
const std::vector<int>& xdims,
const std::vector<int>& reduce_dims) {
    return xpu::reduce_sum<XPUType>(ctx,
                                    reinterpret_cast<const XPUType*>(x),
                                    reinterpret_cast<XPUType*>(y),
                                    xdims,
                                    reduce_dims);
};
  int r = XPUReduce<Context, T>(
dev_ctx, x, dims.GetData(), keep_dim, reduce_all, out, f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "reduce_sum");
}
} // namespace phi
PD_REGISTER_KERNEL(sum_raw,
XPU,
ALL_LAYOUT,
phi::SumRawKernel,
float,
phi::dtype::float16,
int8_t,
int64_t) {}
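The reduce lambdas above now take T pointers and cast to XPUType internally, so XPUReduce<Context, T> is instantiated uniformly for every element type, including float16. A behavioral sanity check through the Python API (a sketch; any backend runs it, and on an XPU build these calls hit the kernels edited above):

import paddle

x = paddle.to_tensor([[1.0, 2.0], [3.0, 4.0]])
print(paddle.sum(x, axis=0))   # [4., 6.]
print(paddle.max(x))           # 4.
print(paddle.min(x, axis=1))   # [1., 3.]
print(paddle.prod(x, axis=1))  # [2., 12.]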
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
import unittest
import numpy as np
import paddle
sys.path.append("..")
from op_test_xpu import XPUOpTest
from xpu.get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
paddle.enable_static()
def adaptive_start_index(index, input_size, output_size):
return int(np.floor(index * input_size / output_size))
def adaptive_end_index(index, input_size, output_size):
return int(np.ceil((index + 1) * input_size / output_size))
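# e.g. input_size = 5, output_size = 3 yields windows [0, 2), [1, 4), [3, 5).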
def pool3D_forward_naive(
x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False,
data_format='NCDHW',
pool_type='max',
padding_algorithm="EXPLICIT",
):
# update paddings
def _get_padding_with_SAME(input_shape, pool_size, pool_stride):
padding = []
for input_size, filter_size, stride_size in zip(
input_shape, pool_size, pool_stride
):
out_size = int((input_size + stride_size - 1) / stride_size)
pad_sum = np.max(
((out_size - 1) * stride_size + filter_size - input_size, 0)
)
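            # e.g. input 5, kernel 3, stride 2: out_size = 3, pad_sum = 2,
            # split below into pad_0 = 1 and pad_1 = 1.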
pad_0 = int(pad_sum / 2)
pad_1 = int(pad_sum - pad_0)
padding.append(pad_0)
padding.append(pad_1)
return padding
if isinstance(padding_algorithm, str):
padding_algorithm = padding_algorithm.upper()
if padding_algorithm not in ["SAME", "VALID", "EXPLICIT"]:
raise ValueError(
"Unknown Attr(padding_algorithm): '%s'. "
"It can only be 'SAME' or 'VALID'." % str(padding_algorithm)
)
if padding_algorithm == "VALID":
paddings = [0, 0, 0, 0, 0, 0]
if ceil_mode is not False:
raise ValueError(
"When Attr(pool_padding) is \"VALID\", Attr(ceil_mode)"
" must be False. "
"Received ceil_mode: True."
)
elif padding_algorithm == "SAME":
input_data_shape = []
if data_format == "NCDHW":
input_data_shape = x.shape[2:5]
elif data_format == "NDHWC":
input_data_shape = x.shape[1:4]
paddings = _get_padding_with_SAME(input_data_shape, ksize, strides)
assert len(paddings) == 3 or len(paddings) == 6
    is_sys = len(paddings) == 3
N = x.shape[0]
C, D, H, W = (
[x.shape[1], x.shape[2], x.shape[3], x.shape[4]]
if data_format == 'NCDHW'
else [x.shape[4], x.shape[1], x.shape[2], x.shape[3]]
)
if global_pool == 1:
ksize = [D, H, W]
paddings = [0 for _ in range(len(paddings))]
pad_d_forth = paddings[0] if is_sys else paddings[0]
pad_d_back = paddings[0] if is_sys else paddings[1]
pad_h_up = paddings[1] if is_sys else paddings[2]
pad_h_down = paddings[1] if is_sys else paddings[3]
pad_w_left = paddings[2] if is_sys else paddings[4]
pad_w_right = paddings[2] if is_sys else paddings[5]
if adaptive:
D_out, H_out, W_out = ksize
else:
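        # Standard pooling arithmetic: floor((size + pads - ksize) / stride) + 1;
        # the extra 'strides[i] - 1' term in the ceil_mode branch rounds the
        # division up instead of down.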
D_out = (
(D - ksize[0] + pad_d_forth + pad_d_back + strides[0] - 1)
// strides[0]
+ 1
if ceil_mode
else (D - ksize[0] + pad_d_forth + pad_d_back) // strides[0] + 1
)
H_out = (
(H - ksize[1] + pad_h_up + pad_h_down + strides[1] - 1)
// strides[1]
+ 1
if ceil_mode
else (H - ksize[1] + pad_h_up + pad_h_down) // strides[1] + 1
)
W_out = (
(W - ksize[2] + pad_w_left + pad_w_right + strides[2] - 1)
// strides[2]
+ 1
if ceil_mode
else (W - ksize[2] + pad_w_left + pad_w_right) // strides[2] + 1
)
out = (
np.zeros((N, C, D_out, H_out, W_out))
if data_format == 'NCDHW'
else np.zeros((N, D_out, H_out, W_out, C))
)
for k in range(D_out):
if adaptive:
d_start = adaptive_start_index(k, D, ksize[0])
d_end = adaptive_end_index(k, D, ksize[0])
for i in range(H_out):
if adaptive:
h_start = adaptive_start_index(i, H, ksize[1])
h_end = adaptive_end_index(i, H, ksize[1])
for j in range(W_out):
if adaptive:
w_start = adaptive_start_index(j, W, ksize[2])
w_end = adaptive_end_index(j, W, ksize[2])
else:
d_start = k * strides[0] - pad_d_forth
d_end = np.min(
(
k * strides[0] + ksize[0] - pad_d_forth,
D + pad_d_back,
)
)
h_start = i * strides[1] - pad_h_up
h_end = np.min(
(i * strides[1] + ksize[1] - pad_h_up, H + pad_h_down)
)
w_start = j * strides[2] - pad_w_left
w_end = np.min(
(
j * strides[2] + ksize[2] - pad_w_left,
W + pad_w_right,
)
)
field_size = (
(d_end - d_start)
* (h_end - h_start)
* (w_end - w_start)
)
w_start = np.max((w_start, 0))
d_start = np.max((d_start, 0))
h_start = np.max((h_start, 0))
w_end = np.min((w_end, W))
d_end = np.min((d_end, D))
h_end = np.min((h_end, H))
if data_format == 'NCDHW':
x_masked = x[
:, :, d_start:d_end, h_start:h_end, w_start:w_end
]
if pool_type == 'avg':
if exclusive or adaptive:
field_size = (
(d_end - d_start)
* (h_end - h_start)
* (w_end - w_start)
)
out[:, :, k, i, j] = (
np.sum(x_masked, axis=(2, 3, 4)) / field_size
)
elif pool_type == 'max':
out[:, :, k, i, j] = np.max(x_masked, axis=(2, 3, 4))
elif data_format == 'NDHWC':
x_masked = x[
:, d_start:d_end, h_start:h_end, w_start:w_end, :
]
if pool_type == 'avg':
if exclusive or adaptive:
field_size = (
(d_end - d_start)
* (h_end - h_start)
* (w_end - w_start)
)
out[:, k, i, j, :] = (
np.sum(x_masked, axis=(1, 2, 3)) / field_size
)
elif pool_type == 'max':
out[:, k, i, j, :] = np.max(x_masked, axis=(1, 2, 3))
return out
def max_pool3D_forward_naive(
x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False,
):
out = pool3D_forward_naive(
x=x,
ksize=ksize,
strides=strides,
paddings=paddings,
global_pool=global_pool,
ceil_mode=ceil_mode,
exclusive=exclusive,
adaptive=adaptive,
data_format='NCDHW',
pool_type="max",
)
return out
def avg_pool3D_forward_naive(
x,
ksize,
strides,
paddings,
global_pool=0,
ceil_mode=False,
exclusive=True,
adaptive=False,
):
out = pool3D_forward_naive(
x=x,
ksize=ksize,
strides=strides,
paddings=paddings,
global_pool=global_pool,
ceil_mode=ceil_mode,
exclusive=exclusive,
adaptive=adaptive,
data_format='NCDHW',
pool_type="avg",
)
return out
class XPUTestPool3DOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'pool3d'
self.use_dynamic_create_class = False
class TestPool3D_Op(XPUOpTest):
def setUp(self):
self.op_type = "pool3d"
self.init_kernel_type()
self.dtype = self.in_type
self.init_test_case()
self.padding_algorithm = "EXPLICIT"
self.init_paddings()
self.init_global_pool()
self.init_kernel_type()
self.init_pool_type()
self.init_ceil_mode()
self.init_exclusive()
self.init_adaptive()
self.init_data_format()
self.init_shape()
paddle.enable_static()
input = np.random.random(self.shape).astype(self.dtype)
output = pool3D_forward_naive(
input,
self.ksize,
self.strides,
self.paddings,
self.global_pool,
self.ceil_mode,
self.exclusive,
self.adaptive,
self.data_format,
self.pool_type,
self.padding_algorithm,
).astype(self.dtype)
self.inputs = {'X': XPUOpTest.np_dtype_to_fluid_dtype(input)}
self.attrs = {
'strides': self.strides,
'paddings': self.paddings,
'ksize': self.ksize,
'pooling_type': self.pool_type,
'global_pooling': self.global_pool,
'ceil_mode': self.ceil_mode,
'data_format': self.data_format,
'exclusive': self.exclusive,
'adaptive': self.adaptive,
"padding_algorithm": self.padding_algorithm,
}
self.outputs = {'Out': output}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place)
def test_check_grad(self):
if self.dtype == np.float16:
return
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, set(['X']), 'Out')
def init_data_format(self):
self.data_format = "NCDHW"
def init_shape(self):
self.shape = [1, 3, 5, 6, 5]
def init_test_case(self):
self.ksize = [2, 3, 1]
self.strides = [2, 2, 3]
def init_paddings(self):
self.paddings = [0, 0, 0]
self.padding_algorithm = "EXPLICIT"
def init_kernel_type(self):
self.use_cudnn = False
def init_pool_type(self):
self.pool_type = "avg"
def init_global_pool(self):
self.global_pool = True
def init_ceil_mode(self):
self.ceil_mode = False
def init_exclusive(self):
self.exclusive = True
def init_adaptive(self):
self.adaptive = False
class TestCase1(TestPool3D_Op):
def init_shape(self):
self.shape = [1, 3, 7, 7, 7]
def init_test_case(self):
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
def init_paddings(self):
self.paddings = [0, 0, 0]
def init_pool_type(self):
self.pool_type = "avg"
def init_global_pool(self):
self.global_pool = False
class TestCase2(TestPool3D_Op):
def init_shape(self):
self.shape = [1, 3, 6, 7, 7]
def init_test_case(self):
self.ksize = [3, 3, 4]
self.strides = [1, 3, 2]
def init_paddings(self):
self.paddings = [1, 1, 1]
def init_pool_type(self):
self.pool_type = "avg"
def init_global_pool(self):
self.global_pool = False
class TestCase3(TestPool3D_Op):
def init_pool_type(self):
self.pool_type = "max"
class TestCase4(TestCase1):
def init_pool_type(self):
self.pool_type = "max"
class TestCase5(TestCase2):
def init_pool_type(self):
self.pool_type = "max"
class TestAvgInclude(TestCase2):
def init_exclusive(self):
self.exclusive = False
class TestAvgPoolAdaptive(TestCase1):
def init_adaptive(self):
self.adaptive = True
class TestAvgPoolAdaptiveAsyOutSize(TestCase1):
def init_adaptive(self):
self.adaptive = True
def init_shape(self):
self.shape = [1, 3, 3, 4, 4]
def init_test_case(self):
self.ksize = [2, 2, 3]
self.strides = [1, 1, 1]
# -------test pool3d with asymmetric padding------
class TestPool3D_Op_AsyPadding(TestPool3D_Op):
def init_test_case(self):
self.ksize = [3, 4, 3]
self.strides = [1, 1, 2]
def init_paddings(self):
self.paddings = [0, 0, 0, 2, 3, 0]
def init_shape(self):
self.shape = [1, 3, 5, 5, 6]
class TestCase1_AsyPadding(TestCase1):
def init_test_case(self):
self.ksize = [3, 3, 4]
self.strides = [1, 1, 2]
def init_paddings(self):
self.paddings = [1, 0, 2, 1, 2, 1]
def init_shape(self):
self.shape = [1, 3, 7, 7, 6]
class TestCase2_AsyPadding(TestCase2):
def init_test_case(self):
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
def init_paddings(self):
self.paddings = [1, 2, 1, 1, 1, 0]
def init_shape(self):
self.shape = [1, 3, 7, 7, 7]
class TestCase3_AsyPadding(TestCase3):
def init_test_case(self):
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
def init_paddings(self):
self.paddings = [1, 0, 0, 0, 1, 0]
def init_shape(self):
self.shape = [1, 3, 5, 5, 5]
class TestCase4_AsyPadding(TestCase4):
def init_test_case(self):
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
def init_paddings(self):
self.paddings = [1, 0, 2, 1, 2, 1]
def init_shape(self):
self.shape = [1, 3, 7, 7, 7]
class TestCase5_AsyPadding(TestCase5):
def init_test_case(self):
self.ksize = [3, 3, 3]
self.strides = [1, 1, 1]
def init_paddings(self):
self.paddings = [1, 2, 1, 1, 1, 0]
def init_shape(self):
self.shape = [1, 3, 7, 7, 7]
class TestAvgInclude_AsyPadding(TestCase2):
def init_exclusive(self):
self.exclusive = False
def init_paddings(self):
self.paddings = [2, 2, 1, 1, 0, 0]
class TestAvgPoolAdaptive_AsyPadding(TestCase1):
def init_adaptive(self):
self.adaptive = True
def init_paddings(self):
self.paddings = [1, 0, 2, 1, 2, 1]
class TestCase5_Max(TestCase2):
def init_pool_type(self):
self.pool_type = "max"
def test_check_grad(self):
if self.dtype == np.float16:
return
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, set(['X']), 'Out')
support_types = get_xpu_op_support_types('pool3d')
for stype in ["float32"]:
create_test_class(globals(), XPUTestPool3DOp, stype)
if __name__ == '__main__':
unittest.main()
@@ -47,6 +47,7 @@ class XPUTestReduceSumOp(XPUOpTestWrapper):
'use_xpu': True,
'reduce_all': self.reduce_all,
'keep_dim': self.keep_dim,
'dim': self.axis,
}
self.inputs = {'X': np.random.random(self.shape).astype(self.dtype)}
if self.attrs['reduce_all']:
......