Commit bbfc82cc authored by phlrain

softmax cross entropy: support high-rank inputs

test=develop
Parent e2897ba1
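This commit lets softmax_with_cross_entropy accept Logits and Label tensors of any rank: the last dimension is treated as the class axis, all leading dimensions are folded into a single batch dimension before the existing 2-D kernels run, and Loss keeps the leading dimensions with a trailing dimension of 1. A minimal NumPy sketch of that flattening idea (shapes and variable names are illustrative, not taken from the patch):

import numpy as np

# Hypothetical high-rank logits: two leading dims, 5 classes on the last axis.
logits = np.random.rand(2, 3, 5).astype("float32")

# Fold every leading dimension into one batch dimension, keep the class axis.
batch_size = int(np.prod(logits.shape[:-1]))   # 2 * 3 = 6
feature_size = logits.shape[-1]                # 5
logits_2d = logits.reshape(batch_size, feature_size)

# The per-sample loss keeps the leading dims and gains a trailing dim of 1.
loss_shape = logits.shape[:-1] + (1,)          # (2, 3, 1)
print(logits_2d.shape, loss_shape)             # (6, 5) (2, 3, 1)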
......@@ -106,24 +106,40 @@ class SoftmaxWithCrossEntropyOp : public framework::OperatorWithKernel {
auto logits_dims = ctx->GetInputDim("Logits");
auto labels_dims = ctx->GetInputDim("Label");
+ int rank = logits_dims.size();
PADDLE_ENFORCE_EQ(
- logits_dims.size(), 2UL,
- "The input of softmax_with_cross_entropy should be a 2-D tensor.");
- PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
- "The labels should be a 2-D tensor.");
+ rank, labels_dims.size(),
+ "Input(logits) and Input(Label) shall have the same rank.");
+ bool check = true;
+ if ((!ctx->IsRuntime()) && (framework::product(logits_dims) <= 0 ||
+ framework::product(labels_dims) <= 0)) {
+ check = false;
+ }
+ if (check) {
+ PADDLE_ENFORCE_EQ(framework::slice_ddim(logits_dims, 0, rank - 1),
+ framework::slice_ddim(labels_dims, 0, rank - 1),
+ "Input(X) and Input(Label) shall have the same shape "
+ "except the last dimension.");
+ }
if (ctx->Attrs().Get<bool>("soft_label")) {
- PADDLE_ENFORCE_EQ(logits_dims[1], labels_dims[1],
- "If Attr(soft_label) == true, the 2nd dimension of "
- "Input(X) and Input(Label) should be equal.");
+ if (check) {
+ PADDLE_ENFORCE_EQ(logits_dims[rank - 1], labels_dims[rank - 1],
+ "If Attr(soft_label) == true, the last dimension of "
+ "Input(X) and Input(Label) should be equal.");
+ }
} else {
- PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
- "If Attr(soft_label) == false, the 2nd dimension of "
+ PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
+ "If Attr(soft_label) == false, the last dimension of "
"Input(Label) should be 1.");
}
ctx->SetOutputDim("Softmax", logits_dims);
- ctx->SetOutputDim("Loss", {logits_dims[0], 1});
+ auto loss_dims = logits_dims;
+ loss_dims[rank - 1] = 1;
+ ctx->SetOutputDim("Loss", loss_dims);
+ // ctx->SetOutputDim("Loss", {logits_dims[0], 1});
ctx->ShareLoD("Logits", /*->*/ "Softmax");
ctx->ShareLoD("Logits", /*->*/ "Loss");
......@@ -152,16 +168,33 @@ class SoftmaxWithCrossEntropyOpGrad : public framework::OperatorWithKernel {
auto softmax_dims = ctx->GetInputDim("Softmax");
auto labels_dims = ctx->GetInputDim("Label");
- PADDLE_ENFORCE_EQ(labels_dims.size(), 2UL,
- "The labels should be a 2-D tensor.");
+ int rank = softmax_dims.size();
+ PADDLE_ENFORCE_EQ(
+ rank, labels_dims.size(),
+ "Input(logits) and Input(Label) shall have the same rank.");
+ bool check = true;
+ if ((!ctx->IsRuntime()) && (framework::product(softmax_dims) <= 0 ||
+ framework::product(labels_dims) <= 0)) {
+ check = false;
+ }
+ if (check) {
+ PADDLE_ENFORCE_EQ(
+ framework::slice_ddim(softmax_dims, 0, rank - 1),
+ framework::slice_ddim(labels_dims, 0, rank - 1),
+ "Input(Softmax) and Input(Label) shall have the same shape "
+ "except the last dimension.");
+ }
if (ctx->Attrs().Get<bool>("soft_label")) {
- PADDLE_ENFORCE_EQ(softmax_dims[1], labels_dims[1],
- "When Attr(soft_label) == true, the 2nd dimension of "
- "Input(X) and Input(Label) should be equal.");
+ if (check) {
+ PADDLE_ENFORCE_EQ(softmax_dims[rank - 1], labels_dims[rank - 1],
+ "If Attr(soft_label) == true, the last dimension of "
+ "Input(Softmax) and Input(Label) should be equal.");
+ }
} else {
- PADDLE_ENFORCE_EQ(labels_dims[1], 1UL,
- "When Attr(soft_label) == false, the 2nd dimension of "
+ PADDLE_ENFORCE_EQ(labels_dims[rank - 1], 1UL,
+ "If Attr(soft_label) == false, the last dimension of "
"Input(Label) should be 1.");
}
......
......@@ -400,9 +400,15 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
auto soft_label = context.Attr<bool>("soft_label");
auto ignore_index = context.Attr<int>("ignore_index");
+ int rank = logits->dims().size();
if (soft_label) {
- int batch_size = logits->dims()[0];
- int feature_size = logits->dims()[1];
+ int batch_size = 1;
+ for (int i = 0; i < rank - 1; ++i) {
+ batch_size *= logits->dims()[i];
+ }
+ int feature_size = logits->dims()[rank - 1];
auto* logits_data = logits->data<T>();
auto* labels_data = labels->data<T>();
SoftmaxWithCrossEntropyFusedKernel(
......@@ -410,14 +416,23 @@ class SoftmaxWithCrossEntropyCUDAKernel : public framework::OpKernel<T> {
feature_size, context.cuda_device_context().stream());
} else {
if (!context.Attr<bool>("numeric_stable_mode")) {
- math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(), logits,
- softmax);
+ // reshape to 2d
+ Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
+ Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
+ Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
+ Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+ math::SoftmaxCUDNNFunctor<T>()(context.cuda_device_context(),
+ &logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CUDADeviceContext, T>()(
- context.cuda_device_context(), loss, softmax, labels, false,
- ignore_index);
+ context.cuda_device_context(), &loss_2d, &softmax_2d, &labels_2d,
+ false, ignore_index);
} else {
- int batch_size = logits->dims()[0];
- int feature_size = logits->dims()[1];
+ int batch_size = 1;
+ for (int i = 0; i < rank - 1; ++i) {
+ batch_size *= logits->dims()[i];
+ }
+ int feature_size = logits->dims()[rank - 1];
auto* logits_data = logits->data<T>();
auto* labels_data = labels->data<int64_t>();
HardLabelSoftmaxWithCrossEntropy<T>(
......@@ -443,8 +458,13 @@ class SoftmaxWithCrossEntropyGradCUDAKernel : public framework::OpKernel<T> {
context.device_context(), logit_grad);
T* logit_grad_data = logit_grad->data<T>();
- const int batch_size = logit_grad->dims()[0];
- const int class_num = logit_grad->dims()[1];
+ int rank = logit_grad->dims().size();
+ int batch_size = 1;
+ for (int i = 0; i < rank - 1; ++i) {
+ batch_size *= logit_grad->dims()[i];
+ }
+ const int class_num = logit_grad->dims()[rank - 1];
int block = 512;
auto stream = context.cuda_device_context().stream();
auto ignore_index = context.Attr<int>("ignore_index");
......
......@@ -40,15 +40,22 @@ class SoftmaxWithCrossEntropyKernel : public framework::OpKernel<T> {
softmax->mutable_data<T>(context.GetPlace());
loss->mutable_data<T>(context.GetPlace());
- int axis_dim = logits->dims()[logits->dims().size() - 1];
+ // reshape to 2D tensor
+ int rank = logits->dims().size();
+ Tensor logits_2d = framework::ReshapeToMatrix(*logits, rank - 1);
+ Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+ Tensor loss_2d = framework::ReshapeToMatrix(*loss, rank - 1);
+ Tensor softmax_2d = framework::ReshapeToMatrix(*softmax, rank - 1);
+ int axis_dim = logits->dims()[rank - 1];
auto& dev_ctx =
context.template device_context<platform::CPUDeviceContext>();
math::SoftmaxFunctor<platform::CPUDeviceContext, T, false>()(
- dev_ctx, axis_dim, logits, softmax);
+ dev_ctx, axis_dim, &logits_2d, &softmax_2d);
math::CrossEntropyFunctor<platform::CPUDeviceContext, T>()(
- dev_ctx, loss, softmax, labels, context.Attr<bool>("soft_label"),
- context.Attr<int>("ignore_index"));
+ dev_ctx, &loss_2d, &softmax_2d, &labels_2d,
+ context.Attr<bool>("soft_label"), context.Attr<int>("ignore_index"));
}
};
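As a cross-check of the reshaped CPU path, the following hedged NumPy reference computes the forward pass on the flattened 2-D view for hard labels (ignore_index is not handled); it sketches the math only and is not the SoftmaxFunctor/CrossEntropyFunctor implementation, and the function name is made up for illustration:

import numpy as np

def softmax_with_cross_entropy_ref(logits, labels):
    # logits: [..., C]; labels: [..., 1] holding integer class indices.
    n = int(np.prod(logits.shape[:-1]))
    c = logits.shape[-1]
    logits_2d = logits.reshape(n, c)
    labels_2d = labels.reshape(n)

    # Numerically stable row-wise softmax.
    shifted = logits_2d - logits_2d.max(axis=1, keepdims=True)
    exp = np.exp(shifted)
    softmax_2d = exp / exp.sum(axis=1, keepdims=True)

    # Cross entropy picks the probability of the true class in each row.
    loss_2d = -np.log(softmax_2d[np.arange(n), labels_2d])

    # Restore the leading dims; Loss ends with a dimension of 1.
    return (softmax_2d.reshape(logits.shape),
            loss_2d.reshape(logits.shape[:-1] + (1,)))

logits = np.random.rand(2, 3, 5)
labels = np.random.randint(0, 5, size=(2, 3, 1))
softmax, loss = softmax_with_cross_entropy_ref(logits, labels)
assert softmax.shape == (2, 3, 5) and loss.shape == (2, 3, 1)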
......@@ -63,13 +70,19 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
context.Output<Tensor>(framework::GradVarName("Logits"));
logit_grad->ShareDataWith(*context.Input<Tensor>("Softmax"));
- const int class_num = logit_grad->dims()[1];
- auto out_grad_mat = EigenMatrix<T>::From(*out_grad);
- auto logit_grad_mat = EigenMatrix<T>::From(*logit_grad);
+ int rank = logit_grad->dims().size();
+ const int class_num = logit_grad->dims()[rank - 1];
+ // reshape to 2d
+ Tensor logit_grad_2d = framework::ReshapeToMatrix(*logit_grad, rank - 1);
+ Tensor out_grad_2d = framework::ReshapeToMatrix(*out_grad, rank - 1);
+ auto out_grad_mat = EigenMatrix<T>::From(out_grad_2d);
+ auto logit_grad_mat = EigenMatrix<T>::From(logit_grad_2d);
auto& place = *context.template device_context<platform::CPUDeviceContext>()
.eigen_device();
if (context.Attr<bool>("soft_label")) {
- auto lbl_mat = EigenMatrix<T>::From(*labels);
+ Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+ auto lbl_mat = EigenMatrix<T>::From(labels_2d);
logit_grad_mat.device(place) =
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num)) *
(logit_grad_mat - lbl_mat);
......@@ -78,7 +91,8 @@ class SoftmaxWithCrossEntropyGradKernel : public framework::OpKernel<T> {
logit_grad_mat *
out_grad_mat.broadcast(Eigen::DSizes<int, 2>(1, class_num));
- const int batch_size = logit_grad->dims()[0];
+ const int batch_size = logit_grad_2d.dims()[0];
const int64_t* label_data = labels->data<int64_t>();
T* logit_grad_data = logit_grad->data<T>();
const T* out_grad_data = out_grad->data<T>();
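For reference, the soft-label branch above computes dLogits = dLoss * (softmax - label), with dLoss broadcast along the class axis, on the 2-D reshaped views; a hedged NumPy sketch of that formula (the helper name is illustrative only):

import numpy as np

def soft_label_grad_ref(softmax, labels, out_grad):
    # softmax, labels: [..., C]; out_grad (dLoss): [..., 1].
    c = softmax.shape[-1]
    softmax_2d = softmax.reshape(-1, c)
    labels_2d = labels.reshape(-1, c)
    out_grad_2d = out_grad.reshape(-1, 1)
    # Broadcast dLoss over the class axis, mirroring the Eigen broadcast above.
    logit_grad_2d = out_grad_2d * (softmax_2d - labels_2d)
    return logit_grad_2d.reshape(softmax.shape)

With an all-ones dLoss this reduces to softmax - label, the usual softmax cross-entropy gradient.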
......
......@@ -149,5 +149,98 @@ class TestSigmoidCrossEntropyWithNorm(OpTest):
self.check_grad(['X'], 'Out')
+ class TestSigmoidCrossEntropyWithLogitsOp5(OpTest):
+     """Test sigmoid_cross_entropy_with_logits_op with probabilistic label
+     """
+
+     def setUp(self):
+         self.op_type = "sigmoid_cross_entropy_with_logits"
+         batch_size = [10, 10]
+         num_classes = 20
+         self.inputs = {
+             'X': logit(
+                 np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+                 .astype("float32")),
+             'Label': np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+             .astype("float32")
+         }
+
+         # Fw Pass is implemented as elementwise sigmoid followed by
+         # elementwise logistic loss
+         # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
+         sigmoid_X = expit(self.inputs['X'])
+         term1 = self.inputs['Label'] * np.log(sigmoid_X)
+         term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
+         self.outputs = {'Out': -term1 - term2}
+
+     def test_check_output(self):
+         self.check_output()
+
+     def test_check_grad(self):
+         self.check_grad(['X'], 'Out')
+
+
+ class TestSigmoidCrossEntropyWithNorm2(OpTest):
+     def setUp(self):
+         self.op_type = "sigmoid_cross_entropy_with_logits"
+         batch_size = [10, 10]
+         num_classes = 20
+         ignore_index = -1
+         self.inputs = {
+             'X': logit(
+                 np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+                 .astype("float32")),
+             'Label': np.random.randint(-1, 2, tuple(batch_size + [num_classes]))
+             .astype("float32")
+         }
+         self.attrs = {'ignore_index': ignore_index, 'normalize': True}
+         sigmoid_X = expit(self.inputs['X'])
+         term1 = self.inputs['Label'] * np.log(sigmoid_X)
+         term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
+         out = -term1 - term2
+         out[np.where(self.inputs['Label'] == ignore_index)] = 0
+         if self.attrs['normalize']:
+             out = out / float(
+                 np.where(self.inputs['Label'] != ignore_index)[0].size)
+         self.outputs = {'Out': out}
+
+     def test_check_output(self):
+         self.check_output()
+
+     def test_check_grad(self):
+         self.check_grad(['X'], 'Out')
+
+
+ class TestSigmoidCrossEntropyWithLogitsOp6(OpTest):
+     """Test sigmoid_cross_entropy_with_logits_op with binary label
+     """
+
+     def setUp(self):
+         self.op_type = "sigmoid_cross_entropy_with_logits"
+         batch_size = [10, 10]
+         num_classes = 20
+         self.inputs = {
+             'X': logit(
+                 np.random.uniform(0, 1, tuple(batch_size + [num_classes]))
+                 .astype("float32")),
+             'Label': np.random.randint(0, 2, tuple(batch_size + [num_classes]))
+             .astype("float32")
+         }
+
+         # Fw Pass is implemented as elementwise sigmoid followed by
+         # elementwise logistic loss
+         # Label * -log(sigmoid(X)) + (1 - label) * -log(1 - sigmoid(X))
+         sigmoid_X = expit(self.inputs['X'])
+         term1 = self.inputs['Label'] * np.log(sigmoid_X)
+         term2 = (1 - self.inputs['Label']) * np.log(1 - sigmoid_X)
+         self.outputs = {'Out': -term1 - term2}
+
+     def test_check_output(self):
+         self.check_output()
+
+     def test_check_grad(self):
+         self.check_grad(['X'], 'Out')
if __name__ == '__main__':
unittest.main()