diff --git a/paddle/fluid/operators/p_norm_op_xpu.cc b/paddle/fluid/operators/p_norm_op_xpu.cc
index b37a65e794d08c222e5b9c4d0c5b6b69e6b7d826..0d2bb42790381a5f6e7bd376b47b16cfd1f313db 100644
--- a/paddle/fluid/operators/p_norm_op_xpu.cc
+++ b/paddle/fluid/operators/p_norm_op_xpu.cc
@@ -102,6 +102,7 @@ class P_NormXPUKernel : public framework::OpKernel<T> {
       XPUType* zeros = RAII_GUARD.alloc_l3_or_gm<XPUType>(1);
       PADDLE_ENFORCE_XDNN_NOT_NULL(zeros);
       r = xpu::constant(dev_ctx.x_context(), zeros, 1, 0.0f);
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "constant");
       std::vector<int> zeros_dim(1, 1);
 
       bool* tmp2_x = RAII_GUARD.alloc_l3_or_gm<bool>(m * t * n);
diff --git a/paddle/fluid/operators/pool_op_xpu.cc b/paddle/fluid/operators/pool_op_xpu.cc
index 7208b195b460062d96a1ea96fdcd02007eed4ee6..591559001309aa5fbf7932c53aed15e3bc062407 100644
--- a/paddle/fluid/operators/pool_op_xpu.cc
+++ b/paddle/fluid/operators/pool_op_xpu.cc
@@ -60,11 +60,7 @@ class PoolXPUKernel : public framework::OpKernel<T> {
         2,
         platform::errors::InvalidArgument(
             "The Pool2d XPU OP only support 2 dimension pooling!"));
-    PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "The Pool2d XPU OP does not support (adaptive == "
-                          "true && output_size != 1)"));
+
     int* index_data = nullptr;
     bool global_pooling = context.Attr<bool>("global_pooling") ||
                           (adaptive && (ksize[0] * ksize[1] == 1));
@@ -80,6 +76,9 @@ class PoolXPUKernel : public framework::OpKernel<T> {
     const int in_h = in_x->dims()[2];
     const int in_w = in_x->dims()[3];
 
+    const int out_h = out->dims()[2];
+    const int out_w = out->dims()[3];
+
     framework::DDim data_dims;
 
     data_dims = phi::slice_ddim(in_x->dims(), 2, in_x->dims().size());
@@ -90,9 +89,13 @@ class PoolXPUKernel : public framework::OpKernel<T> {
                              data_dims,
                              strides,
                              ksize);
+
     if (ceil_mode) {
-      paddings[1] += (strides[0] - 1);
-      paddings[3] += (strides[1] - 1);
+      int in_h_ceil = (out_h - 1) * strides[0] + ksize[0] - 2 * paddings[0];
+      int in_w_ceil = (out_w - 1) * strides[1] + ksize[1] - 2 * paddings[2];
+
+      paddings[1] += (in_h_ceil - in_h);
+      paddings[3] += (in_w_ceil - in_w);
     }
 
     auto input = reinterpret_cast<const XPUType*>(in_x->data<T>());
@@ -100,35 +103,65 @@ class PoolXPUKernel : public framework::OpKernel<T> {
     auto output = reinterpret_cast<XPUType*>(out->data<T>());
     auto& dev_ctx = context.template device_context<DeviceContext>();
     int r = xpu::Error_t::SUCCESS;
-    if (pooling_type == "max") {
-      r = xpu::max_pool2d(dev_ctx.x_context(),
-                          input,
-                          output,
-                          index_data,
-                          n,
-                          c,
-                          in_h,
-                          in_w,
-                          ksize,
-                          strides,
-                          paddings,
-                          true);
-    } else if (pooling_type == "avg") {
-      r = xpu::avg_pool2d(dev_ctx.x_context(),
-                          input,
-                          output,
-                          n,
-                          c,
-                          in_h,
-                          in_w,
-                          ksize,
-                          strides,
-                          paddings,
-                          !exclusive,
-                          true);
+    if (!adaptive) {
+      if (pooling_type == "max") {
+        r = xpu::max_pool2d(dev_ctx.x_context(),
+                            input,
+                            output,
+                            index_data,
+                            n,
+                            c,
+                            in_h,
+                            in_w,
+                            ksize,
+                            strides,
+                            paddings,
+                            true);
+      } else if (pooling_type == "avg") {
+        r = xpu::avg_pool2d(dev_ctx.x_context(),
+                            input,
+                            output,
+                            n,
+                            c,
+                            in_h,
+                            in_w,
+                            ksize,
+                            strides,
+                            paddings,
+                            !exclusive,
+                            true);
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unsupported pooling type for kunlun ", pooling_type));
+      }
     } else {
-      PADDLE_THROW(platform::errors::InvalidArgument(
-          "Unsupported pooling type for kunlun ", pooling_type));
+      if (pooling_type == "max") {
+        r = xpu::adaptive_max_pool2d(dev_ctx.x_context(),
+                                     input,
+                                     output,
+                                     index_data,
+                                     n,
+                                     c,
+                                     in_h,
+                                     in_w,
+                                     out_h,
+                                     out_w,
+                                     true);
+      } else if (pooling_type == "avg") {
+        r = xpu::adaptive_avg_pool2d(dev_ctx.x_context(),
+                                     input,
+                                     output,
+                                     n,
+                                     c,
+                                     in_h,
+                                     in_w,
+                                     out_h,
+                                     out_w,
+                                     true);
+      } else {
+        PADDLE_THROW(platform::errors::InvalidArgument(
+            "Unsupported pooling type for kunlun ", pooling_type));
+      }
     }
     PADDLE_ENFORCE_EQ(r,
                       xpu::Error_t::SUCCESS,
@@ -167,11 +200,6 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
                           "dimension pooling!, but received "
                           "%d-dimension pool kernel size",
                           ksize.size()));
-    PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "The Pool2d XPU OP does not support (adaptive == "
-                          "true && output_size != 1)"));
     bool global_pooling = context.Attr<bool>("global_pooling") ||
                           (adaptive && (ksize[0] * ksize[1] == 1));
     if (global_pooling) {
@@ -188,6 +216,16 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
     const int in_h = in_x->dims()[2];
     const int in_w = in_x->dims()[3];
 
+    const int out_h = out->dims()[2];
+    const int out_w = out->dims()[3];
+
+    PADDLE_ENFORCE_EQ(!adaptive || (ksize[0] * ksize[1] == 1) ||
+                          (in_h % out_h == 0 && in_w % out_w == 0),
+                      true,
+                      platform::errors::InvalidArgument(
+                          "The Pool2d XPU OP does not support (adaptive == "
+                          "true && output_size != 1)"));
+
     framework::DDim data_dims;
 
     data_dims = phi::slice_ddim(in_x->dims(), 2, in_x->dims().size());
@@ -199,8 +237,11 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
                              strides,
                              ksize);
     if (ceil_mode) {
-      paddings[1] += (strides[0] - 1);
-      paddings[3] += (strides[1] - 1);
+      int in_h_ceil = (out_h - 1) * strides[0] + ksize[0] - 2 * paddings[0];
+      int in_w_ceil = (out_w - 1) * strides[1] + ksize[1] - 2 * paddings[2];
+
+      paddings[1] += (in_h_ceil - in_h);
+      paddings[3] += (in_w_ceil - in_w);
     }
 
     auto input = reinterpret_cast<const XPUType*>(in_x->data<T>());
@@ -210,6 +251,14 @@ class PoolGradXPUKernel : public framework::OpKernel<T> {
     auto input_grad = reinterpret_cast<XPUType*>(in_x_grad->data<T>());
     auto& dev_ctx = context.template device_context<DeviceContext>();
     int r = xpu::Error_t::SUCCESS;
+    if (adaptive && in_h % out_h == 0 && in_w % out_w == 0) {
+      strides = {in_h / out_h, in_w / out_w};
+      int kh = in_h - (out_h - 1) * strides[0];
+      int kw = in_w - (out_w - 1) * strides[1];
+      ksize = {kh, kw};
+      paddings = {0, 0, 0, 0};
+    }
+
     if (pooling_type == "max") {
       r = xpu::max_pool2d_grad(dev_ctx.x_context(),
                                input,
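Note on the pool_op_xpu.cc hunks above. The old ceil_mode branch always padded the bottom/right edges by stride - 1, which over-pads whenever the ceil-mode output does not actually need a full extra stride; the new code computes the input extent implied by the output shape, in_ceil = (out - 1) * stride + ksize - 2 * pad, and pads only by the true deficit in_ceil - in. The grad kernel takes a different route for adaptive pooling: when the input size divides evenly by the output size, adaptive pooling is exactly a plain pool with stride = in / out and ksize = in - (out - 1) * stride, so the existing grad kernels can be reused. A NumPy sketch of that equivalence (illustration only, not part of the patch; both pooling helpers are hypothetical reference implementations and assume the divisible case the kernel guards for):

    import numpy as np

    def fixed_avg_pool2d(x, ksize, strides):
        # Plain average pooling, no padding.
        n, c, in_h, in_w = x.shape
        out_h = (in_h - ksize[0]) // strides[0] + 1
        out_w = (in_w - ksize[1]) // strides[1] + 1
        out = np.empty((n, c, out_h, out_w), x.dtype)
        for i in range(out_h):
            for j in range(out_w):
                hs, ws = i * strides[0], j * strides[1]
                window = x[:, :, hs:hs + ksize[0], ws:ws + ksize[1]]
                out[:, :, i, j] = window.mean(axis=(2, 3))
        return out

    def adaptive_avg_pool2d(x, out_h, out_w):
        # Adaptive average pooling; the floor-based window bounds below are
        # exact only in the divisible case, which is all this sketch needs.
        n, c, in_h, in_w = x.shape
        out = np.empty((n, c, out_h, out_w), x.dtype)
        for i in range(out_h):
            for j in range(out_w):
                h0, h1 = i * in_h // out_h, (i + 1) * in_h // out_h
                w0, w1 = j * in_w // out_w, (j + 1) * in_w // out_w
                out[:, :, i, j] = x[:, :, h0:h1, w0:w1].mean(axis=(2, 3))
        return out

    x = np.random.rand(2, 4, 8, 8).astype(np.float32)
    out_h = out_w = 2
    sh, sw = 8 // out_h, 8 // out_w                      # strides = in / out
    kh, kw = 8 - (out_h - 1) * sh, 8 - (out_w - 1) * sw  # ksize, as in the grad kernel
    assert np.allclose(adaptive_avg_pool2d(x, out_h, out_w),
                       fixed_avg_pool2d(x, (kh, kw), (sh, sw)))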
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc
index 46195a8023c5ec60597a5f776ead073662da92b3..8251fe21ea4d754cc1dc85dc2aa72217f04b2a50 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_xpu.cc
@@ -45,10 +45,6 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
     Tensor* loss = context.Output<Tensor>("Loss");
     const int rank = logits->dims().size();
     const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
-    PADDLE_ENFORCE_EQ(
-        axis,
-        rank - 1,
-        platform::errors::InvalidArgument("axis should == rank - 1"));
     softmax->mutable_data<T>(context.GetPlace());
     loss->mutable_data<T>(context.GetPlace());
     const int n = phi::funcs::SizeToAxis(axis, logits->dims());
@@ -56,17 +52,20 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
     std::vector<int> logits_dims = phi::vectorize<int>(logits->dims());
     const bool soft_label = context.Attr<bool>("soft_label");
 
+    int t = logits_dims[axis];
+
     auto logits_data = reinterpret_cast<const XPUType*>(logits->data<T>());
     auto softmax_data = reinterpret_cast<XPUType*>(softmax->data<T>());
     auto loss_data = reinterpret_cast<XPUType*>(loss->data<T>());
 
-    // softmax
     auto& dev_ctx =
         context.template device_context<platform::XPUDeviceContext>();
     int r = XPU_SUCCESS;
+    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+
     if (platform::get_xpu_version(context.GetPlace().GetDeviceId()) ==
             phi::backends::xpu::XPUVersion::XPU2 &&
-        soft_label) {
+        soft_label && axis == rank - 1) {
       auto labels_data = reinterpret_cast<const XPUType*>(labels->data<T>());
       r = xpu::soft_softmax_with_cross_entropy(dev_ctx.x_context(),
                                                logits_data,
@@ -79,7 +78,6 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
       return;
     }
 
-    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
    int len = logits->numel();
     T* clip_logits = RAII_GUARD.alloc_l3_or_gm<T>(len);
     PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits);
@@ -105,10 +103,38 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
     PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
 
     // cross_entropy
+    if (axis != rank - 1) {
+      XPUType* trans_softmax = RAII_GUARD.alloc_l3_or_gm<XPUType>(n * d);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(trans_softmax);
+
+      r = xpu::transpose(dev_ctx.x_context(),
+                         softmax_data,
+                         trans_softmax,
+                         {n, t, d / t},
+                         {0, 2, 1});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+      softmax_data = trans_softmax;
+    }
+
     if (soft_label) {
       auto labels_data = reinterpret_cast<const XPUType*>(labels->data<T>());
-      r = xpu::soft_cross_entropy(
-          dev_ctx.x_context(), softmax_data, labels_data, loss_data, n, d);
+      if (axis != rank - 1) {
+        XPUType* trans_label = RAII_GUARD.alloc_l3_or_gm<XPUType>(n * d);
+        PADDLE_ENFORCE_XDNN_NOT_NULL(trans_label);
+        r = xpu::transpose(dev_ctx.x_context(),
+                           labels_data,
+                           trans_label,
+                           {n, t, d / t},
+                           {0, 2, 1});
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+        labels_data = trans_label;
+      }
+      r = xpu::soft_cross_entropy(dev_ctx.x_context(),
+                                  softmax_data,
+                                  labels_data,
+                                  loss_data,
+                                  axis == rank - 1 ? n : n * d / t,
+                                  axis == rank - 1 ? d : t);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy");
     } else {
       auto ignore_index = context.Attr<int>("ignore_index");
@@ -127,8 +153,8 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
                                    labels_int32.data<int32_t>(),
                                    loss_data,
                                    nullptr,
-                                   n,
-                                   d,
+                                   axis == rank - 1 ? n : n * d / t,
+                                   axis == rank - 1 ? d : t,
                                    ignore_index);
       PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_cross_entropy");
     }
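Note on the forward hunks above. When axis != rank - 1, softmax is still computed along axis on the original layout; only the cross-entropy step needs the class dimension innermost, so the kernel views the data as [n, t, d / t] (with t = logits_dims[axis]), transposes with permutation {0, 2, 1}, and then runs the ordinary last-axis kernel over n * d / t rows of t classes. A NumPy sketch of why the transpose is sufficient (illustration only; the shape (2, 3, 4) and axis 1 are made-up values):

    import numpy as np

    def softmax(x, axis):
        e = np.exp(x - x.max(axis=axis, keepdims=True))
        return e / e.sum(axis=axis, keepdims=True)

    shape, axis = (2, 3, 4), 1                   # class axis in the middle
    rng = np.random.default_rng(0)
    logits = rng.random(shape)
    labels = softmax(rng.random(shape), axis)    # soft labels, sum to 1 on axis

    # Reference: soft cross entropy along the chosen axis.
    ref = -(labels * np.log(softmax(logits, axis))).sum(axis=axis)

    # Kernel view: n = 2, d = 3 * 4 = 12, t = 3; transpose {0, 2, 1} makes the
    # class axis innermost, giving n * d / t = 8 rows of t = 3 classes.
    n, t, d = 2, 3, 12
    p = softmax(logits, axis).reshape(n, t, d // t).transpose(0, 2, 1).reshape(-1, t)
    y = labels.reshape(n, t, d // t).transpose(0, 2, 1).reshape(-1, t)
    loss = -(y * np.log(p)).sum(axis=1)
    assert np.allclose(loss.reshape(n, d // t), ref)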
@@ -157,10 +183,6 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
 
     const int rank = logit_grad->dims().size();
     const int axis = phi::funcs::CanonicalAxis(context.Attr<int>("axis"), rank);
-    PADDLE_ENFORCE_EQ(
-        axis,
-        rank - 1,
-        platform::errors::InvalidArgument("axis should == rank - 1"));
     const int n = phi::funcs::SizeToAxis(axis, logit_grad->dims());
     const int d = phi::funcs::SizeFromAxis(axis, logit_grad->dims());
 
@@ -168,40 +190,107 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
         context.template device_context<platform::XPUDeviceContext>();
     int r = XPU_SUCCESS;
 
-    if (soft_label) {
-      r = xpu::soft_softmax_with_cross_entropy_grad(
-          dev_ctx.x_context(),
-          reinterpret_cast<const XPUType*>(out_grad->data<T>()),
-          reinterpret_cast<const XPUType*>(labels->data<T>()),
-          reinterpret_cast<const XPUType*>(softmax->data<T>()),
-          reinterpret_cast<XPUType*>(logit_grad->data<T>()),
-          use_softmax,
-          n,
-          d);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
+    if (axis == rank - 1) {
+      if (soft_label) {
+        r = xpu::soft_softmax_with_cross_entropy_grad(
+            dev_ctx.x_context(),
+            reinterpret_cast<const XPUType*>(out_grad->data<T>()),
+            reinterpret_cast<const XPUType*>(labels->data<T>()),
+            reinterpret_cast<const XPUType*>(softmax->data<T>()),
+            reinterpret_cast<XPUType*>(logit_grad->data<T>()),
+            use_softmax,
+            n,
+            d);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
+      } else {
+        xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
+        int* labels_int_ptr_l3 =
+            RAII_GUARD.alloc_l3_or_gm<int>(labels->numel());
+        PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
+
+        r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
+                                           labels->data<int64_t>(),
+                                           labels_int_ptr_l3,
+                                           labels->numel());
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2");
+
+        r = xpu::hard_softmax_with_cross_entropy_grad(
+            dev_ctx.x_context(),
+            reinterpret_cast<const XPUType*>(out_grad->data<T>()),
+            labels_int_ptr_l3,
+            reinterpret_cast<const XPUType*>(softmax->data<T>()),
+            reinterpret_cast<XPUType*>(logit_grad->data<T>()),
+            ignore_index,
+            use_softmax,
+            n,
+            d);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_softmax_with_cross_entropy_grad");
+      }
     } else {
+      int t = logit_grad->dims()[axis];
       xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
-      int* labels_int_ptr_l3 =
-          RAII_GUARD.alloc_l3_or_gm<int>(labels->numel());
-      PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
+      int len = softmax->numel();
+      XPUType* trans_logit = RAII_GUARD.alloc_l3_or_gm<XPUType>(len);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(trans_logit);
 
-      r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
-                                         labels->data<int64_t>(),
-                                         labels_int_ptr_l3,
-                                         labels->numel());
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2");
+      XPUType* trans_softmax = RAII_GUARD.alloc_l3_or_gm<XPUType>(len);
+      PADDLE_ENFORCE_XDNN_NOT_NULL(trans_softmax);
+      r = xpu::transpose(dev_ctx.x_context(),
+                         reinterpret_cast<const XPUType*>(softmax->data<T>()),
+                         trans_softmax,
+                         {n, t, d / t},
+                         {0, 2, 1});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+
+      if (soft_label) {
+        XPUType* trans_labels = RAII_GUARD.alloc_l3_or_gm<XPUType>(len);
+        PADDLE_ENFORCE_XDNN_NOT_NULL(trans_labels);
+        r = xpu::transpose(dev_ctx.x_context(),
+                           reinterpret_cast<const XPUType*>(labels->data<T>()),
+                           trans_labels,
+                           {n, t, d / t},
+                           {0, 2, 1});
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
+        r = xpu::soft_softmax_with_cross_entropy_grad(
+            dev_ctx.x_context(),
+            reinterpret_cast<const XPUType*>(out_grad->data<T>()),
+            trans_labels,
+            trans_softmax,
+            trans_logit,
+            use_softmax,
+            n * d / t,
+            t);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
+      } else {
+        int* labels_int_ptr_l3 =
+            RAII_GUARD.alloc_l3_or_gm<int>(labels->numel());
+        PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
+
+        r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
+                                           labels->data<int64_t>(),
+                                           labels_int_ptr_l3,
+                                           labels->numel());
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2");
+        r = xpu::hard_softmax_with_cross_entropy_grad(
+            dev_ctx.x_context(),
+            reinterpret_cast<const XPUType*>(out_grad->data<T>()),
+            labels_int_ptr_l3,
+            trans_softmax,
+            trans_logit,
+            ignore_index,
+            use_softmax,
+            n * d / t,
+            t);
+        PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_softmax_with_cross_entropy_grad");
+      }
 
-      r = xpu::hard_softmax_with_cross_entropy_grad(
+      r = xpu::transpose(
           dev_ctx.x_context(),
-          reinterpret_cast<const XPUType*>(out_grad->data<T>()),
-          labels_int_ptr_l3,
-          reinterpret_cast<const XPUType*>(softmax->data<T>()),
+          trans_logit,
           reinterpret_cast<XPUType*>(logit_grad->data<T>()),
-          ignore_index,
-          use_softmax,
-          n,
-          d);
-      PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_softmax_with_cross_entropy_grad");
+          {n, d / t, t},
+          {0, 2, 1});
+      PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose");
     }
   }
 };
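Note on the backward hunks above. The grad path mirrors the forward trick: softmax (and, for soft labels, the labels) are transposed from the [n, t, d / t] view to [n, d / t, t], the unchanged last-axis gradient kernels run over n * d / t rows of t classes, and the result is transposed back into logit_grad. Since permutation {0, 2, 1} just swaps the last two axes, applying it again to the {n, d / t, t} result undoes it. A minimal NumPy check of that round trip (illustration only):

    import numpy as np

    n, t, dt = 2, 3, 4                  # d = t * dt
    g = np.random.rand(n, t, dt)        # gradient in the original layout
    fwd = g.transpose(0, 2, 1)          # [n, d/t, t]: class axis innermost
    back = fwd.transpose(0, 2, 1)       # same perm {0, 2, 1} is self-inverse
    assert np.array_equal(back, g)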
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_pool2d_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_pool2d_op_xpu.py
index 0d7121144adabeb84b7170be311a65aceb885b29..370d7645a81fa5c01b93302a9f8fceff50e273a8 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_pool2d_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_pool2d_op_xpu.py
@@ -541,6 +541,50 @@ class XPUTestPool2D_Op(XPUOpTestWrapper):
         def init_ceil_mode(self):
             self.ceil_mode = True
 
+    class TestCaseAdaptiveAvg(TestPool2D_Op):
+
+        def init_test_case(self):
+            self.ksize = [2, 2]
+            self.strides = [2, 2]
+
+        def init_paddings(self):
+            self.paddings = [0, 0]
+
+        def init_pool_type(self):
+            self.pool_type = "avg"
+            self.pool2D_forward_naive = avg_pool2D_forward_naive
+
+        def init_global_pool(self):
+            self.global_pool = False
+
+        def init_shape(self):
+            self.shape = [2, 4, 8, 8]
+
+        def init_adaptive_mode(self):
+            self.adaptive = True
+
+    class TestCaseAdaptiveMax(TestPool2D_Op):
+
+        def init_test_case(self):
+            self.ksize = [2, 2]
+            self.strides = [2, 2]
+
+        def init_paddings(self):
+            self.paddings = [0, 0]
+
+        def init_pool_type(self):
+            self.pool_type = "max"
+            self.pool2D_forward_naive = max_pool2D_forward_naive
+
+        def init_global_pool(self):
+            self.global_pool = False
+
+        def init_shape(self):
+            self.shape = [2, 4, 8, 8]
+
+        def init_adaptive_mode(self):
+            self.adaptive = True
+
 
 support_types = get_xpu_op_support_types('pool2d')
 for stype in support_types:
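Note on the new test cases above. With adaptive = True, the pool2d op interprets ksize as the requested output size, so these cases ask for a 2x2 output from an 8x8 input; since 8 % 2 == 0 they stay on the evenly divisible path that the XPU kernels now support. A quick shape check through the public API (illustration only; uses the functional adaptive_avg_pool2d wrapper rather than the raw op):

    import paddle

    x = paddle.rand([2, 4, 8, 8])
    y = paddle.nn.functional.adaptive_avg_pool2d(x, output_size=[2, 2])
    assert y.shape == [2, 4, 2, 2]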
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py
index 661f11704187dd61022fe594710528a6a13d7029..ea60dd1fb908ef57c44b9d9369c006c5a0991051 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_softmax_with_cross_entropy_op_xpu.py
@@ -62,16 +62,21 @@ class XPUTestSoftmaxWithCrossEntropyOp(XPUOpTestWrapper):
         for numeric_stable_mode in [True, False]:
             for shape in shapes:
                 for logits_type in [0, 1, 2]:
-                    class_name = 'XPUTestSoftmaxWithCrossEntropy_' + \
-                        str(soft_label) + "_" + \
-                        str(numeric_stable_mode) + "_" + \
-                        str(shape) + "_" + \
-                        str(logits_type)
-                    attr_dict = {'soft_label': soft_label, \
-                        'numeric_stable_mode': numeric_stable_mode, \
-                        'shape': shape, \
-                        'logits_type': logits_type}
-                    classes.append([class_name, attr_dict])
+                    for axis in range(len(shape)):
+                        if (not numeric_stable_mode):
+                            axis = -1
+                        class_name = 'XPUTestSoftmaxWithCrossEntropy_' + \
+                            str(soft_label) + "_" + \
+                            str(numeric_stable_mode) + "_" + \
+                            str(shape) + "_" + \
+                            str(logits_type) + "_" + \
+                            str(axis)
+                        attr_dict = {'soft_label': soft_label, \
+                            'numeric_stable_mode': numeric_stable_mode, \
+                            'shape': shape, \
+                            'logits_type': logits_type,
+                            'axis': axis}
+                        classes.append([class_name, attr_dict])
         return base_class, classes
 
     class TestSoftmaxWithCrossEntropyOp(XPUOpTest):
@@ -83,7 +88,6 @@ class XPUTestSoftmaxWithCrossEntropyOp(XPUOpTestWrapper):
             self.op_type = "softmax_with_cross_entropy"
             self.use_xpu = True
             self.dtype = np.float32
-            self.axis = -1
             self.ignore_index = -1
 
             if not hasattr(self, 'shape'):
@@ -91,6 +95,7 @@ class XPUTestSoftmaxWithCrossEntropyOp(XPUOpTestWrapper):
                 self.numeric_stable_mode = True
                 self.logits_type = 0
                 self.soft_label = True
+                self.axis = -1
             logits = getattr(
                 self, "logits",
                 np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))