未验证 提交 2b9bb8bb 编写于 作者: Q QingshuChen 提交者: GitHub

optimize kunlun/xpu softmax_with_cross_entropy add add unitest (#39180)

* optimize kunlun/xpu softmax_with_cross_entropy add add unitest
*test=kunlun

* minor
*test=kunlun

* minor
*test=kunlun

* minor
*test=kunlun

* minor
*test=kunlun
上级 9b79988c
......@@ -94,7 +94,7 @@ class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
inverse_scale = 0.0;
}
auto version = dev_ctx.xpu_version();
auto version = platform::get_xpu_version(ctx.GetPlace().GetDeviceId());
framework::Tensor float_x;
framework::Tensor float_out;
if (std::is_same<T, paddle::platform::float16>::value &&
......
......@@ -107,7 +107,7 @@ class DropoutGradXPUKernel : public framework::OpKernel<T> {
return;
}
auto version = dev_ctx.xpu_version();
auto version = platform::get_xpu_version(context.GetPlace().GetDeviceId());
if (version == pten::backends::xpu::XPUVersion::XPU1) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* mask_new = RAII_GUARD.alloc_l3_or_gm<XPUType>(mask->numel());
......
......@@ -45,7 +45,7 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = XPU_SUCCESS;
auto version = dev_ctx.xpu_version();
auto version = platform::get_xpu_version(context.GetPlace().GetDeviceId());
if (version == pten::backends::xpu::XPUVersion::XPU1) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* clip_x_data_l3 = RAII_GUARD.alloc_l3_or_gm<XPUType>(x->numel());
......
......@@ -19,6 +19,7 @@ limitations under the License. */
#include <unordered_map>
#include <vector>
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "xpu/refactor/math.h"
#include "xpu/refactor/nn.h"
......@@ -45,68 +46,57 @@ class SoftmaxWithCrossEntropyXPUKernel : public framework::OpKernel<T> {
const int n = SizeToAxis(axis, logits->dims());
const int d = SizeFromAxis(axis, logits->dims());
std::vector<int> logits_dims = framework::vectorize<int>(logits->dims());
const bool soft_label = context.Attr<bool>("soft_label");
// softmax
auto& dev_ctx =
context.template device_context<platform::XPUDeviceContext>();
int r = XPU_SUCCESS;
Tensor clip_logits;
if (platform::get_xpu_version(context.GetPlace().GetDeviceId()) ==
pten::backends::xpu::XPUVersion::XPU2 &&
soft_label) {
r = xpu::soft_softmax_with_cross_entropy(
dev_ctx.x_context(), logits->data<float>(), labels->data<T>(),
softmax->data<T>(), loss->data<T>(), n, d);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy");
return;
}
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int len = logits->numel();
T* clip_logits_data =
clip_logits.mutable_data<T>(context.GetPlace(), len * sizeof(T));
T* clip_logits_data = RAII_GUARD.alloc_l3_or_gm<T>(len);
PADDLE_ENFORCE_XDNN_NOT_NULL(clip_logits_data);
r = xpu::clip_v2(dev_ctx.x_context(), logits->data<float>(),
clip_logits_data, len, static_cast<float>(-1e20),
static_cast<float>(1e20));
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error. clip "
"execution not succeed, error code=%d",
r));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
r = xpu::softmax(dev_ctx.x_context(), clip_logits_data,
softmax->data<float>(), logits_dims, axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax");
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error. Softmax2d_forward "
"execution not succeed, error code=%d",
r));
// cross_entropy
auto ignore_index = context.Attr<int>("ignore_index");
const bool soft_label = context.Attr<bool>("soft_label");
if (soft_label) {
r = xpu::soft_cross_entropy<float>(
dev_ctx.x_context(), softmax->data<float>(), labels->data<float>(),
loss->data<float>(), n, d);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error. soft_cross_entropy "
"execution not succeed, error code=%d",
r));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_cross_entropy");
} else {
auto ignore_index = context.Attr<int>("ignore_index");
Tensor labels_int32;
labels_int32.mutable_data<int32_t>(context.GetPlace(),
labels->numel() * sizeof(int32_t));
r = xpu::cast_v2<int64_t, int32_t>(
dev_ctx.x_context(), labels->data<int64_t>(),
labels_int32.data<int32_t>(), labels->numel());
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error. cast_v2 "
"execution not succeed, error code=%d",
r));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "clip_v2");
r = xpu::hard_cross_entropy<float, int32_t>(
dev_ctx.x_context(), softmax->data<float>(),
labels_int32.data<int32_t>(), loss->data<float>(), nullptr, n, d,
ignore_index);
PADDLE_ENFORCE_EQ(
r, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error. hard_cross_entropy "
"execution not succeed, error code=%d",
r));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_cross_entropy");
}
}
};
......@@ -149,23 +139,17 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
reinterpret_cast<const XPUType*>(labels->data<T>()),
reinterpret_cast<const XPUType*>(softmax->data<T>()),
reinterpret_cast<XPUType*>(logit_grad->data<T>()), use_softmax, n, d);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU API(soft_softmax_with_cross_entropy_grad) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "soft_softmax_with_cross_entropy_grad");
} else {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int* labels_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(labels->numel());
PADDLE_ENFORCE_XDNN_NOT_NULL(labels_int_ptr_l3);
r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
labels->data<int64_t>(),
labels_int_ptr_l3, labels->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(cast_v2) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2");
r = xpu::hard_softmax_with_cross_entropy_grad<XPUType, int>(
dev_ctx.x_context(),
......@@ -174,12 +158,7 @@ class SoftmaxWithCrossEntropyGradXPUKernel : public framework::OpKernel<T> {
reinterpret_cast<const XPUType*>(softmax->data<T>()),
reinterpret_cast<XPUType*>(logit_grad->data<T>()), ignore_index,
use_softmax, n, d);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
"XPU API(hard_softmax_with_cross_entropy_grad) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
PADDLE_ENFORCE_XDNN_SUCCESS(r, "hard_softmax_with_cross_entropy_grad");
}
}
};
......
......@@ -189,6 +189,17 @@ DEFINE_EXTERNAL_API_TYPE(BKCLResult_t, BKCL_SUCCESS);
} \
} while (0)
#define PADDLE_ENFORCE_XDNN_NOT_NULL(ptr) \
do { \
if (UNLIKELY(ptr == nullptr)) { \
auto __summary__ = paddle::platform::errors::External( \
::pten::backends::xpu::build_xpu_xdnn_error_msg( \
baidu::xpu::api::Error_t::NO_ENOUGH_WORKSPACE, \
"XPU memory is not enough")); \
__THROW_ERROR_INTERNAL__(__summary__); \
} \
} while (0)
} // namespace xpu
} // namespace backends
} // namespace pten
......@@ -193,11 +193,6 @@ class XPUOpTest(OpTest):
for input_to_check in inputs_to_check:
set_input(self.scope, self.op, self.inputs, place)
tensor_to_check = self.scope.find_var(input_to_check).get_tensor()
tensor_size = six.moves.reduce(lambda a, b: a * b,
tensor_to_check.shape(), 1)
if tensor_size < 100:
self.__class__.input_shape_is_large = False
if not type(output_names) is list:
output_names = [output_names]
......
......@@ -23,6 +23,7 @@ import paddle
import unittest
import numpy as np
from xpu.get_test_cover_info import create_test_class, get_xpu_op_support_types, XPUOpTestWrapper
def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
......@@ -45,353 +46,103 @@ def cross_entropy(softmax, label, soft_label, axis, ignore_index=-1):
return result.reshape(label.shape)
class TestSoftmaxWithCrossEntropyOp(XPUOpTest):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = False
self.soft_label = False
self.dtype = np.float32
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
self.use_xpu = True
def setUp(self):
self.initParams()
logits = getattr(
self, "logits",
np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))
softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
if self.soft_label:
labels = np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype)
labels /= np.sum(labels, axis=self.axis, keepdims=True)
else:
axis_dim = self.shape[self.axis]
self.shape[self.axis] = 1
labels = np.random.randint(0, axis_dim, self.shape, dtype="int64")
loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index)
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": loss.astype(self.dtype)
}
self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
}
if self.ignore_index >= 0:
self.attrs['ignore_index'] = self.ignore_index
if self.axis != -1:
self.attrs['axis'] = self.axis
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=1e-2)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ["Logits"], "Loss", max_relative_error=0.2)
class TestXPUSoftmaxWithCrossEntropyOp(TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float32
self.use_xpu = True
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=1e-2)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ["Logits"], "Loss", max_relative_error=0.2)
class TestXPUSoftmaxWithCrossEntropyOp2(TestXPUSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with soft labels.
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.dtype = np.float32
self.axis = -1
self.ignore_index = -1
self.shape = [41, 37]
self.use_xpu = True
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=1e-2)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ["Logits"], "Loss", max_relative_error=0.2)
class TestXPUSoftmaxWithCrossEntropyOp3(TestXPUSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with ignore_index.
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [41, 37]
self.ignore_index = 5
self.axis = -1
self.dtype = np.float32
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpAxis1(TestXPUSoftmaxWithCrossEntropyOp):
# """
# Test softmax with cross entropy operator with discreate one-hot labels.
# Given axis != -1
# """
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = False
# self.dtype = np.float32
# self.axis = 0
# self.ignore_index = -1
# self.shape = [3, 5, 7, 11]
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpAxis2(TestXPUSoftmaxWithCrossEntropyOp):
# """
# Test softmax with cross entropy operator with discreate one-hot labels.
# Given axis != -1
# """
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = False
# self.dtype = np.float32
# self.axis = 1
# self.ignore_index = -1
# self.shape = [3, 5, 7, 11]
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpAxis3(TestXPUSoftmaxWithCrossEntropyOp):
# """
# Test softmax with cross entropy operator with discreate one-hot labels.
# Given axis != -1
# """
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = False
# self.dtype = np.float32
# self.axis = 2
# self.ignore_index = -1
# self.shape = [3, 5, 7, 11]
class TestXPUSoftmaxWithCrossEntropyOpAxis4(TestXPUSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float32
self.axis = 3
self.ignore_index = -1
self.shape = [3, 5, 7, 11]
class TestXPUSoftmaxWithCrossEntropyOpAxisDimEqualOne(
TestXPUSoftmaxWithCrossEntropyOp):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
Given axis != -1
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.dtype = np.float32
self.axis = -1
self.ignore_index = -1
self.shape = [3, 5, 7, 1]
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpSoftLabelAxis1(
# TestXPUSoftmaxWithCrossEntropyOp):
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = True
# self.shape = [3, 5, 7, 11]
# self.axis = 0
# self.ignore_index = -1
# self.dtype = np.float32
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpSoftLabelAxis2(
# TestXPUSoftmaxWithCrossEntropyOp2):
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = True
# self.shape = [3, 5, 7, 11]
# self.axis = 1
# self.ignore_index = -1
# self.dtype = np.float32
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpSoftLabelAxis3(
# TestXPUSoftmaxWithCrossEntropyOp2):
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = True
# self.shape = [3, 5, 7, 11]
# self.axis = 2
# self.ignore_index = -1
# self.dtype = np.float32
class TestXPUSoftmaxWithCrossEntropyOpSoftLabelAxis4(
TestXPUSoftmaxWithCrossEntropyOp2):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = True
self.shape = [3, 5, 7, 11]
self.axis = 3
self.ignore_index = -1
self.dtype = np.float32
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis1(
# TestXPUSoftmaxWithCrossEntropyOp3):
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = False
# self.shape = [3, 5, 7, 11]
# self.ignore_index = 1
# self.axis = 0
# self.dtype = np.float32
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis2(
# TestXPUSoftmaxWithCrossEntropyOp3):
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = False
# self.shape = [3, 5, 7, 11]
# self.ignore_index = 0
# self.axis = 1
# self.dtype = np.float32
# xpu only support axis = rank -1
# class TestXPUSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis3(
# TestXPUSoftmaxWithCrossEntropyOp3):
# def initParams(self):
# self.op_type = "softmax_with_cross_entropy"
# self.numeric_stable_mode = True
# self.soft_label = False
# self.shape = [3, 5, 7, 11]
# self.ignore_index = 3
# self.axis = 2
# self.dtype = np.float32
class TestXPUSoftmaxWithCrossEntropyOpIgnoreIndexNoCudnnAxis4(
TestXPUSoftmaxWithCrossEntropyOp3):
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.ignore_index = 3
self.axis = 3
self.dtype = np.float32
class TestXPUSoftmaxWithCrossEntropyOpBoundary0(
TestXPUSoftmaxWithCrossEntropyOp):
"""
Test stable softmax with cross entropy operator will not product INF
with small logits value.
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float32
self.logits = np.full(self.shape, -500.0).astype(self.dtype)
class TestXPUSoftmaxWithCrossEntropyOpBoundary1(
TestXPUSoftmaxWithCrossEntropyOp):
"""
Test stable softmax with cross entropy operator will not product INF
with small logits value.
"""
def initParams(self):
self.op_type = "softmax_with_cross_entropy"
self.numeric_stable_mode = True
self.soft_label = False
self.shape = [3, 5, 7, 11]
self.axis = -1
self.ignore_index = -1
self.dtype = np.float32
self.logits = np.full(self.shape, 1000.0).astype(self.dtype)
self.logits[:, :, 0, :] = -1000.0
class XPUTestSoftmaxWithCrossEntropyOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = 'softmax_with_cross_entropy'
self.use_dynamic_create_class = True
def dynamic_create_class(self):
base_class = self.TestSoftmaxWithCrossEntropyOp
classes = []
shapes = [[41, 37], [3, 5, 7, 11], [3, 5, 7, 1], [1023, 38512],
[1, 511]]
for soft_label in [True, False]:
for numeric_stable_mode in [True, False]:
for shape in shapes:
for logits_type in [0, 1, 2]:
class_name = 'XPUTestSoftmaxWithCrossEntropy_' + \
str(soft_label) + "_" + \
str(numeric_stable_mode) + "_" + \
str(shape) + "_" + \
str(logits_type)
attr_dict = {'soft_label': soft_label, \
'numeric_stable_mode': numeric_stable_mode, \
'shape': shape, \
'logits_type': logits_type}
classes.append([class_name, attr_dict])
return base_class, classes
class TestSoftmaxWithCrossEntropyOp(XPUOpTest):
"""
Test softmax with cross entropy operator with discreate one-hot labels.
"""
def setUp(self):
self.op_type = "softmax_with_cross_entropy"
self.use_xpu = True
self.dtype = np.float32
self.axis = -1
self.ignore_index = -1
if not hasattr(self, 'shape'):
self.shape = [43, 6]
self.numeric_stable_mode = True
self.logits_type = 0
self.soft_label = True
logits = getattr(
self, "logits",
np.random.uniform(0.1, 1.0, self.shape).astype(self.dtype))
if self.logits_type == 1:
self.logits = np.full(self.shape, -500.0).astype(self.dtype)
elif self.logits_type == 2 and len(self.shape) == 4:
self.logits = np.full(self.shape, 1000.0).astype(self.dtype)
self.logits[:, :, 0, :] = -1000.0
softmax = np.apply_along_axis(stable_softmax, self.axis, logits)
if self.soft_label:
labels = np.random.uniform(0.1, 1.0,
self.shape).astype(self.dtype)
labels /= np.sum(labels, axis=self.axis, keepdims=True)
else:
axis_dim = self.shape[self.axis]
self.shape[self.axis] = 1
labels = np.random.randint(
0, axis_dim, self.shape, dtype="int64")
loss = cross_entropy(softmax, labels, self.soft_label, self.axis,
self.ignore_index)
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": loss.astype(self.dtype)
}
self.attrs = {
"numeric_stable_mode": self.numeric_stable_mode,
"soft_label": self.soft_label,
}
if self.ignore_index >= 0:
self.attrs['ignore_index'] = self.ignore_index
if self.axis != -1:
self.attrs['axis'] = self.axis
def test_check_output(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_output_with_place(place, atol=1e-2)
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
paddle.enable_static()
place = paddle.XPUPlace(0)
self.check_grad_with_place(
place, ["Logits"], "Loss", max_relative_error=0.2)
support_types = get_xpu_op_support_types('softmax_with_cross_entropy')
for stype in support_types:
create_test_class(globals(), XPUTestSoftmaxWithCrossEntropyOp, stype)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册