未验证 提交 1a7f2de3 编写于 作者: Y ykkk2333 提交者: GitHub

add adaptive pool and softmax with cross entropy supports different axis, * test = kunlun (#44428)

* add xpu pnorm op and fix pool op, *test=kunlun

* add adaptive pool, and softmax with cross entropy supports different axis, *test=kunlun
上级 5414694b
......@@ -102,6 +102,7 @@ class P_NormXPUKernel : public framework::OpKernel<T> {
XPUType* zeros = RAII_GUARD.alloc_l3_or_gm<XPUType>(1);
PADDLE_ENFORCE_XDNN_NOT_NULL(zeros);
r = xpu::constant(dev_ctx.x_context(), zeros, 1, 0.0f);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
std::vector<int> zeros_dim(1, 1);
bool* tmp2_x = RAII_GUARD.alloc_l3_or_gm<bool>(m * t * n);
......
......@@ -60,11 +60,7 @@ class PoolXPUKernel : public framework::OpKernel<T> {
2,
platform::errors::InvalidArgument(
"The Pool2d XPU OP only support 2 dimension pooling!"));
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1),
true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
int* index_data = nullptr;
bool global_pooling = context.Attr<bool>("global_pooling") ||
(adaptive && (ksize[0] * ksize[1] == 1));
......@@ -80,6 +76,9 @@ class PoolXPUKernel : public framework::OpKernel<T> {
const int in_h = in_x->dims()[2];
const int in_w = in_x->dims()[3];
const int out_h = out->dims()[2];
const int out_w = out->dims()[3];
framework::DDim data_dims;
data_dims = phi::slice_ddim(in_x->dims(), 2, in_x->dims().size());
......@@ -90,9 +89,13 @@ class PoolXPUKernel : public framework::OpKernel<T> {
data_dims,
strides,
ksize);
if (ceil_mode) {
paddings[1] += (strides[0] - 1);
paddings[3] += (strides[1] - 1);
int in_h_ceil = (out_h - 1) * strides[0] + ksize[0] - 2 * paddings[0];
int in_w_ceil = (out_w - 1) * strides[1] + ksize[1] - 2 * paddings[2];
paddings[1] += (in_h_ceil - in_h);
paddings[3] += (in_w_ceil - in_w);
}
auto input = reinterpret_cast<const XPUType*>(in_x->data<T>());
......@@ -100,6 +103,7 @@ class PoolXPUKernel : public framework::OpKernel<T> {
auto output = reinterpret_cast<XPUType*>(out->data<T>());
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::Error_t::SUCCESS;
if (!adaptive) {
if (pooling_type == "max") {
r = xpu::max_pool2d<XPUType>(dev_ctx.x_context(),
input,
......@@ -130,6 +134,35 @@ class PoolXPUKernel : public framework::OpKernel<T> {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
} else {
if (pooling_type == "max") {
r = xpu::adaptive_max_pool2d<XPUType>(dev_ctx.x_context(),
input,
output,
index_data,
n,
c,
in_h,
in_w,
out_h,
out_w,
true);
} else if (pooling_type == "avg") {
r = xpu::adaptive_avg_pool2d<XPUType>(dev_ctx.x_context(),
input,
output,
n,
c,
in_h,
in_w,
out_h,
out_w,
true);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Unsupported pooling type for kunlun ", pooling_type));
}
}
PADDLE_ENFORCE_EQ(r,
xpu::Error_t::SUCCESS,
platform::errors::External(
......@@ -167,11 +200,6 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
"dimension pooling!, but received "
"%d-dimension pool kernel size",
ksize.size()));
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1),
true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
bool global_pooling = context.Attr<bool>("global_pooling") ||
(adaptive && (ksize[0] * ksize[1] == 1));
if (global_pooling) {
......@@ -188,6 +216,16 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
const int in_h = in_x->dims()[2];
const int in_w = in_x->dims()[3];
const int out_h = out->dims()[2];
const int out_w = out->dims()[3];
PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1) ||
(in_h % out_h == 0 && in_w % out_w == 0),
true,
platform::errors::InvalidArgument(
"The Pool2d XPU OP does not support (adaptive == "
"true && output_size != 1)"));
framework::DDim data_dims;
data_dims = phi::slice_ddim(in_x->dims(), 2, in_x->dims().size());
......@@ -199,8 +237,11 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
strides,
ksize);
if (ceil_mode) {
paddings[1] += (strides[0] - 1);
paddings[3] += (strides[1] - 1);
int in_h_ceil = (out_h - 1) * strides[0] + ksize[0] - 2 * paddings[0];
int in_w_ceil = (out_w - 1) * strides[1] + ksize[1] - 2 * paddings[2];
paddings[1] += (in_h_ceil - in_h);
paddings[3] += (in_w_ceil - in_w);
}
auto input = reinterpret_cast<const XPUType*>(in_x->data<T>());
......@@ -210,6 +251,14 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
auto input_grad = reinterpret_cast<XPUType*>(in_x_grad->data<T>());
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::Error_t::SUCCESS;
if (adaptive && in_h % out_h == 0 && in_w % out_w == 0) {
strides = {in_h / out_h, in_w / out_w};
int kh = in_h - (out_h - 1) * strides[0];
int kw = in_w - (out_w - 1) * strides[1];
ksize = {kh, kw};
paddings = {0, 0, 0, 0};
}
if (pooling_type == "max") {
r = xpu::max_pool2d_grad<XPUType>(dev_ctx.x_context(),
input,
......
......@@ -45,10 +45,6 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
Tensor* loss = context.Output<Tensor>("Loss");
const int rank = logits->dims().size();
const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
PADDLE_ENFORCE_EQ(
axis,
rank - 1,
platform::errors::InvalidArgument("axis should == rank - 1"));
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
const int n = phi::funcs::SizeToAxis(axis, logits->dims());
......@@ -56,17 +52,20 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
std::vector<int> logits_dims = phi::vectorize<int>(logits->dims());
const bool soft_label = context.Attr<bool>("soft_label");
int t = logits_dims[axis];
auto logits_data = reinterpret_cast<const XPUType*>(logits->data<T>());
auto softmax_data = reinterpret_cast<XPUType*>(softmax->data<T>());
auto loss_data = reinterpret_cast<XPUType*>(loss->data<T>());
// softmax
auto& dev_ctx =
context.template device_context<platform::XPUDeviceContext>();
int r = XPU_SUCCESS;
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
if (platform::get_xpu_version(context.GetPlace().GetDeviceId()) ==
phi::backends::xpu::XPUVersion::XPU2 &&
soft_label) {
soft_label && axis == rank - 1) {
auto labels_data = reinterpret_cast<const XPUType*>(labels->data<T>());
r = xpu::soft_softmax_with_cross_entropy<XPUType>(dev_ctx.x_context(),
logits_data,
......@@ -79,7 +78,6 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
return;
}
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int len = logits->numel();
T* clip_logits = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits);
......@@ -105,10 +103,38 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
// cross_entropy
if (axis != rank - 1) {
XPUType* trans_softmax = RAII_GUARD.alloc_l3_or_gm<XPUType>(n * d);
PADDLE_ENFORCE_XDNN_NOT_NULL(trans_softmax);
r = xpu::transpose(dev_ctx.x_context(),
softmax_data,
trans_softmax,
{n, t, d / t},
{0, 2, 1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
softmax_data = trans_softmax;
}
if (soft_label) {
auto labels_data = reinterpret_cast<const XPUType*>(labels->data<T>());
r = xpu::soft_cross_entropy<XPUType>(
dev_ctx.x_context(), softmax_data, labels_data, loss_data, n, d);
if (axis != rank - 1) {
XPUType* trans_label = RAII_GUARD.alloc_l3_or_gm<XPUType>(n * d);
PADDLE_ENFORCE_XDNN_NOT_NULL(trans_label);
r = xpu::transpose(dev_ctx.x_context(),
labels_data,
trans_label,
{n, t, d / t},
{0, 2, 1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
labels_data = trans_label;
}
r = xpu::soft_cross_entropy<XPUType>(dev_ctx.x_context(),
softmax_data,
labels_data,
loss_data,
axis == rank - 1 ? n : n * d / t,
axis == rank - 1 ? d : t);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy");
} else {
auto ignore_index = context.Attr<int>("ignore_index");
......@@ -127,8 +153,8 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
labels_int32.data<int32_t>(),
loss_data,
nullptr,
n,
d,
axis == rank - 1 ? n : n * d / t,
axis == rank - 1 ? d : t,
ignore_index);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_cross_entropy");
}
......@@ -157,10 +183,6 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
const int rank = logit_grad->dims().size();
const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
PADDLE_ENFORCE_EQ(
axis,
rank - 1,
platform::errors::InvalidArgument("axis should == rank - 1"));
const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());
......@@ -168,6 +190,7 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
context.template device_context<platform::XPUDeviceContext>();
int r = XPU_SUCCESS;
if (axis == rank - 1) {
if (soft_label) {
r = xpu::soft_softmax_with_cross_entropy_grad<XPUType>(
dev_ctx.x_context(),
......@@ -203,6 +226,72 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
d);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_softmax_with_cross_entropy_grad");
}
} else {
int t = logit_grad->dims()[axis];
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int len = softmax->numel();
XPUType* trans_logit = RAII_GUARD.alloc_l3_or_gm<XPUType>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(trans_logit);
XPUType* trans_softmax = RAII_GUARD.alloc_l3_or_gm<XPUType>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(trans_softmax);
r = xpu::transpose(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(softmax->data<T>()),
trans_softmax,
{n, t, d / t},
{0, 2, 1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
if (soft_label) {
XPUType* trans_labels = RAII_GUARD.alloc_l3_or_gm<XPUType>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(trans_labels);
r = xpu::transpose(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(labels->data<T>()),
trans_labels,
{n, t, d / t},
{0, 2, 1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
r = xpu::soft_softmax_with_cross_entropy_grad<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad->data<T>()),
trans_labels,
trans_softmax,
trans_logit,
use_softmax,
n * d / t,
t);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
} else {
int* labels_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
labels->data<int64_t>(),
labels_int_ptr_l3,
labels->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad->data<T>()),
labels_int_ptr_l3,
trans_softmax,
trans_logit,
ignore_index,
use_softmax,
n * d / t,
t);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_softmax_with_cross_entropy_grad");
}
r = xpu::transpose<XPUType>(
dev_ctx.x_context(),
trans_logit,
reinterpret_cast<XPUType*>(logit_grad->data<T>()),
{n, d / t, t},
{0, 2, 1});
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
}
}
};
......
......@@ -541,6 +541,50 @@ class XPUTestPool2D_Op(XPUOpTestWrapper):
def init_ceil_mode(self):
self.ceil_mode = True
class TestCaseAdaptiveAvg(TestPool2D_Op):
def init_test_case(self):
self.ksize = [2, 2]
self.strides = [2, 2]
def init_paddings(self):
self.paddings = [0, 0]
def init_pool_type(self):
self.pool_type = "avg"
self.pool2D_forward_naive = avg_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = False
def init_shape(self):
self.shape = [2, 4, 8, 8]
def init_adaptive_mode(self):
self.adaptive = True
class TestCaseAdaptiveMax(TestPool2D_Op):
def init_test_case(self):
self.ksize = [2, 2]
self.strides = [2, 2]
def init_paddings(self):
self.paddings = [0, 0]
def init_pool_type(self):
self.pool_type = "max"
self.pool2D_forward_naive = max_pool2D_forward_naive
def init_global_pool(self):
self.global_pool = False
def init_shape(self):
self.shape = [2, 4, 8, 8]
def init_adaptive_mode(self):
self.adaptive = True
support_types = get_xpu_op_support_types('pool2d')
for stype in support_types:
......
......@@ -62,15 +62,20 @@ class XPUTestSoftmaxWithCrossEntropyOp(XPUOpTestWrapper):
for numeric_stable_mode in [True, False]:
for shape in shapes:
for logits_type in [0, 1, 2]:
for axis in range(len(shape)):
if (not numeric_stable_mode):
axis = -1
class_name = 'XPUTestSoftmaxWithCrossEntropy_' + \
str(soft_label) + "_" + \
str(numeric_stable_mode) + "_" + \
str(shape) + "_" + \
str(logits_type)
str(logits_type) + "_" + \
str(axis)
attr_dict = {'soft_label': soft_label, \
'numeric_stable_mode': numeric_stable_mode, \
'shape': shape, \
'logits_type': logits_type}
'logits_type': logits_type,
'axis': axis}
classes.append([class_name, attr_dict])
return base_class, classes
......@@ -83,7 +88,6 @@ class XPUTestSoftmaxWithCrossEntropyOp(XPUOpTestWrapper):
self.op_type = "softmax_with_cross_entropy"
self.use_xpu = True
self.dtype = np.float32
self.axis = -1
self.ignore_index = -1
if not hasattr(self, 'shape'):
......@@ -91,6 +95,7 @@ class XPUTestSoftmaxWithCrossEntropyOp(XPUOpTestWrapper):
self.numeric_stable_mode = True
self.logits_type = 0
self.soft_label = True
self.axis = -1
logits = getattr(
self, "logits",
np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册