未验证 提交 d5a7c098 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #16798 from phlrain/softmax_cross_support_high_rank

softmax cross entropy support high rank
......@@ -106,24 +106,36 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
int rank = logits_dims.size();
PADDLE_ENFORCE_EQ(
logits_dims.size(), 2UL,
"The input of softmax_with_cross_entropy should be a 2-D tensor.");
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The labels should be a 2-D tensor.");
rank, labels_dims.size(),
"Input(logits) and Input(Label) shall have the same rank.");
bool check = ctx->IsRuntime() || (framework::product(logits_dims) > 0 &&
framework::product(labels_dims) > 0);
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(logits_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
if (ctx->Attrs().Get<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
"If Attr(soft_label) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
if (check) {
PADDLE_ENFORCE_EQ(logits_dims[rank - 1], labels_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal.");
}
} else {
PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
"If Attr(soft_label) == false, the 2nd dimension of "
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1.");
}
ctx->SetOutputDim("Softmax", logits_dims);
ctx->SetOutputDim("Loss", {logits_dims[0], 1});
auto loss_dims = logits_dims;
loss_dims[rank - 1] = 1;
ctx->SetOutputDim("Loss", loss_dims);
ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss");
......@@ -152,16 +164,33 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
auto softmax_dims = ctx->GetInputDim("Softmax");
auto labels_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
"The labels should be a 2-D tensor.");
int rank = softmax_dims.size();
PADDLE_ENFORCE_EQ(
rank, labels_dims.size(),
"Input(logits) and Input(Label) shall have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(softmax_dims) <= 0 ||
framework::product(labels_dims) <= 0)) {
check = false;
}
if (check) {
PADDLE_ENFORCE_EQ(
framework::slice_ddim(softmax_dims, 0, rank - 1),
framework::slice_ddim(labels_dims, 0, rank - 1),
"Input(Softmax) and Input(Label) shall have the same shape "
"except the last dimension.");
}
if (ctx->Attrs().Get<bool>("soft_label")) {
PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
"When Attr(soft_label) == true, the 2nd dimension of "
"Input(X) and Input(Label) should be equal.");
if (check) {
PADDLE_ENFORCE_EQ(softmax_dims[rank - 1], labels_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of "
"Input( Softmax) and Input(Label) should be equal.");
}
} else {
PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
"When Attr(soft_label) == false, the 2nd dimension of "
PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1.");
}
......
......@@ -400,9 +400,15 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
auto soft_label = context.Attr<bool>("soft_label");
auto ignore_index = context.Attr<int>("ignore_index");
int rank = logits->dims().size();
if (soft_label) {
int batch_size = logits->dims()[0];
int feature_size = logits->dims()[1];
int batch_size = 1;
for (int i = 0; i < rank - 1; ++i) {
batch_size *= logits->dims()[i];
}
int feature_size = logits->dims()[rank - 1];
auto* logits_data = logits->data<T>();
auto* labels_data = labels->data<T>();
SoftmaxWithCrossEntropyFusedKernel(
......@@ -410,14 +416,23 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
feature_size, context.cuda_device_context().stream());
} else {
if (!context.Attr<bool>("numeric_stable_mode")) {
math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
softmax);
// reshape to 2d
Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
&logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
context.cuda_device_context(), loss, softmax, labels, false,
ignore_index);
context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
false, ignore_index);
} else {
int batch_size = logits->dims()[0];
int feature_size = logits->dims()[1];
int batch_size = 1;
for (int i = 0; i < rank - 1; ++i) {
batch_size *= logits->dims()[i];
}
int feature_size = logits->dims()[rank - 1];
auto* logits_data = logits->data<T>();
auto* labels_data = labels->data<int64_t>();
HardLabelSoftmaxWithCrossEntropy<T>(
......@@ -443,8 +458,13 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
context.device_context(), logit_grad);
T* logit_grad_data = logit_grad->data<T>();
const int batch_size = logit_grad->dims()[0];
const int class_num = logit_grad->dims()[1];
int rank = logit_grad->dims().size();
int batch_size = 1;
for (int i = 0; i < rank - 1; ++i) {
batch_size *= logit_grad->dims()[i];
}
const int class_num = logit_grad->dims()[rank - 1];
int block = 512;
auto stream = context.cuda_device_context().stream();
auto ignore_index = context.Attr<int>("ignore_index");
......
......@@ -40,15 +40,22 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
int axis_dim = logits->dims()[logits->dims().size() - 1];
// reshape to 2D tensor
int rank = logits->dims().size();
Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
int axis_dim = logits->dims()[rank - 1];
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
dev_ctx, axis_dim, logits, softmax);
dev_ctx, axis_dim, &logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
context.Attr<int>("ignore_index"));
dev_ctx, &loss_2d, &softmax_2d, &labels_2d,
context.Attr<bool>("soft_label"), context.Attr<int>("ignore_index"));
}
};
......@@ -63,13 +70,19 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
const int class_num = logit_grad->dims()[1];
auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
int rank = logit_grad->dims().size();
const int class_num = logit_grad->dims()[rank - 1];
// reshape to 2d
Tensor logit_grad_2d = framework::ReshapeToMatrix(*logit_grad, rank - 1);
Tensor out_grad_2d = framework::ReshapeToMatrix(*out_grad, rank - 1);
auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d);
auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
if (context.Attr<bool>("soft_label")) {
auto lbl_mat = EigenMatrix<T>::From(*labels);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) *
(logit_grad_mat - lbl_mat);
......@@ -78,7 +91,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
logit_grad_mat *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num));
const int batch_size = logit_grad->dims()[0];
const int batch_size = logit_grad_2d.dims()[0];
const int64_t* label_data = labels->data<int64_t>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
......
......@@ -195,5 +195,144 @@ class TestSoftmaxWithCrossEntropyOp3NoCudnn(TestSoftmaxWithCrossEntropyOp3):
self.numeric_stable_mode = True
class TestSoftmaxWithCrossEntropyOp5(OpTest):
"""
Test softmax with cross entropy operator with ignore_index.
"""
def initParams(self):
self.numeric_stable_mode = False
def setUp(self):
self.initParams()
self.op_type = "softmax_with_cross_entropy"
batch_size = [6, 10]
class_num = 47
logits = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 2, logits)
labels = np.random.randint(
0, class_num, tuple(batch_size + [1]), dtype="int64")
ignore_index = 7
softmax_2d = np.reshape(softmax, [-1, class_num])
labels_2d = np.reshape(labels, [-1, 1])
cross_entropy = np.asmatrix(
[[-np.log(softmax_2d[i][labels_2d[i][0]])]
if labels_2d[i] != ignore_index else [0]
for i in range(softmax_2d.shape[0])],
dtype="float64")
cross_entropy = np.reshape(cross_entropy, batch_size)
output_shape = tuple(batch_size + [1])
output_res = cross_entropy.astype("float64")
output_res = np.expand_dims(output_res, axis=2)
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": output_res,
}
self.attrs = {
"ignore_index": ignore_index,
"numeric_stable_mode": self.numeric_stable_mode
}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOp5NoCudnn(TestSoftmaxWithCrossEntropyOp5):
def initParams(self):
self.numeric_stable_mode = True
class TestSoftmaxWithCrossEntropyOp6(OpTest):
"""
Test softmax with cross entropy operator with soft labels.
"""
def setUp(self):
self.op_type = "softmax_with_cross_entropy"
batch_size = [6, 10]
class_num = 37
logits = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
softmax = np.apply_along_axis(stable_softmax, 2, logits)
labels = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype("float64")
labels /= np.sum(labels, axis=2, keepdims=True)
cross_entropy = (-labels * np.log(softmax)).sum(
axis=2, keepdims=True).astype("float64")
self.inputs = {"Logits": logits, "Label": labels}
self.outputs = {
"Softmax": softmax.astype("float64"),
"Loss": cross_entropy.astype("float64")
}
self.attrs = {"soft_label": True}
def test_check_output(self):
self.check_output()
def test_check_grad(self):
self.check_grad(["Logits"], "Loss")
class TestSoftmaxWithCrossEntropyOpFp16_2(TestSoftmaxWithCrossEntropyOp):
def initParams(self):
self.numeric_stable_mode = False
self.dtype = np.float16
def setUp(self):
self.initParams()
self.op_type = "softmax_with_cross_entropy"
batch_size = [64, 10]
class_num = 37
# NOTE: numpy float16 have very low accuracy, use float32 for numpy check.
logits = np.random.uniform(
0.1, 1.0, tuple(batch_size + [class_num])).astype(np.float32)
softmax = np.apply_along_axis(stable_softmax, 2, logits)
labels = np.random.randint(
0, class_num, tuple(batch_size + [1]), dtype="int64")
softmax_2d = np.reshape(softmax, [-1, class_num])
labels_2d = np.reshape(labels, [-1, 1])
cross_entropy = np.asmatrix(
[[-np.log(softmax_2d[i][labels_2d[i][0]])]
for i in range(softmax_2d.shape[0])],
dtype=np.float32)
cross_entropy = np.reshape(cross_entropy, batch_size)
output_shape = tuple(batch_size + [1])
output_res = cross_entropy.astype(self.dtype)
output_res = np.expand_dims(output_res, axis=2)
self.inputs = {"Logits": logits, "Label": labels}
self.inputs = {
"Logits": logits.astype(self.dtype).view(np.uint16),
"Label": labels
}
self.outputs = {
"Softmax": softmax.astype(self.dtype),
"Loss": output_res,
}
self.attrs = {"numeric_stable_mode": self.numeric_stable_mode}
def test_check_output(self):
self.check_output(atol=1e-2)
def test_check_grad(self):
self.check_grad(["Logits"], "Loss", max_relative_error=0.1)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册