Unverified commit b125e327 authored by Aurelius84, committed by GitHub

Remove constraint that last dimension is forced to be 1 in cross_entropy (#19606)

* Remove constraint that last dimension is forced to be 1 in cross_entropy
test=develop

* Modify labels' last dims test=develop
Parent e8d3745c
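The change relaxes the hard-label case so that Input(Label) no longer has to carry a trailing dimension of 1: a rank of rank(X) - 1 is now accepted, and Output(Y) follows Label's shape. A minimal NumPy sketch of that behaviour (illustrative only; names are hypothetical and this is not part of the diff):

```python
import numpy as np

# Illustrative sketch: with soft_label == False, Label may keep the trailing
# dimension of 1 or drop it entirely; Y then mirrors Label's shape.
def hard_label_cross_entropy(x, label):
    # x: probabilities with shape [..., class_num]; label: int64 class indices
    class_num = x.shape[-1]
    x_2d = x.reshape(-1, class_num)
    flat_label = label.reshape(-1)                 # works for both label layouts
    loss = -np.log(x_2d[np.arange(x_2d.shape[0]), flat_label])
    return loss.reshape(label.shape)               # Y takes Label's shape

x = np.random.dirichlet(np.ones(16), size=(4, 8)).astype('float32')  # [4, 8, 16]
label_with_1 = np.random.randint(0, 16, size=(4, 8, 1))  # old layout: [..., 1]
label_no_1 = label_with_1.squeeze(-1)                     # new layout: [4, 8]
assert hard_label_cross_entropy(x, label_with_1).shape == (4, 8, 1)
assert hard_label_cross_entropy(x, label_no_1).shape == (4, 8)
```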
......@@ -25,19 +25,21 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
"Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
"Output(Y) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto label_dims = ctx->GetInputDim("Label");
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(rank, label_dims.size(),
"Input(X) and Input(Label) shall have the same rank.");
bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
framework::contain_unknown_dim(label_dims);
bool check = ctx->IsRuntime() || !contain_unknown_dim;
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
......@@ -46,19 +48,30 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
}
if (IsSoftLabel(ctx)) {
PADDLE_ENFORCE_EQ(
rank, label_dims.size(),
"If Attr(soft_label) == true, Input(X) and Input(Label) "
"shall have the same rank.");
if (check) {
PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
"If Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal.");
}
} else {
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
"If Attr(softLabel) == false, the last dimension of "
"Input(Label) should be 1.");
if (rank == label_dims.size()) {
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
"the last dimension of Input(Label) should be 1.");
} else {
PADDLE_ENFORCE_EQ(
rank, label_dims.size() + 1,
"The rank of Input(X) should be equal to Input(Label) plus 1.");
}
}
auto y_dims = x_dims;
y_dims[rank - 1] = 1;
auto y_dims = label_dims;
if (rank == label_dims.size()) {
y_dims[rank - 1] = 1;
}
ctx->SetOutputDim("Y", y_dims);
ctx->ShareLoD("X", /*->*/ "Y");
}
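The output-shape rule encoded above boils down to: Y takes Label's dims, with the last dimension forced to 1 when Label has the same rank as X. A small Python sketch of that rule (illustrative stand-in, not the operator's code):

```python
# List-based dims stand in for framework::DDim; the asserts mirror the
# PADDLE_ENFORCE checks in the hunk above.
def infer_y_dims(x_dims, label_dims, soft_label):
    rank = len(x_dims)
    assert x_dims[:rank - 1] == label_dims[:rank - 1]
    if soft_label:
        assert len(label_dims) == rank and label_dims[-1] == x_dims[-1]
    else:
        assert (len(label_dims) == rank and label_dims[-1] == 1) \
            or len(label_dims) == rank - 1
    y_dims = list(label_dims)
    if len(label_dims) == rank:
        y_dims[-1] = 1          # same rank: last dimension of Y collapses to 1
    return y_dims

assert infer_y_dims([32, 64], [32, 1], soft_label=False) == [32, 1]
assert infer_y_dims([32, 64], [32], soft_label=False) == [32]       # new case
assert infer_y_dims([32, 64], [32, 64], soft_label=True) == [32, 1]
```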
......@@ -82,20 +95,19 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const {
PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
"Input(Y@GRAD) shoudl be not null.");
PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
"Output(X@GRAD) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
"Input(Label) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Y")), true,
"Input(Y@GRAD) shoudl be not null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
"Output(X@GRAD) should be not null.");
auto x_dims = GetXDim(ctx);
auto label_dims = ctx->GetInputDim("Label");
auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
int rank = x_dims.size();
PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
"Input(Y@Grad) and Input(X) should have the same rank.");
PADDLE_ENFORCE_EQ(label_dims.size(), rank,
"Input(Label) and Input(X) should have the same rank.");
PADDLE_ENFORCE_EQ(dy_dims.size(), label_dims.size(),
"Input(Y@Grad) and Input(Y) should have the same rank.");
bool check = true;
if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
......@@ -104,30 +116,12 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
}
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(label_dims, 0, rank - 1),
"The Input(X) and Input(Label) should have the same "
"shape except the last dimension.");
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
framework::slice_ddim(dy_dims, 0, rank - 1),
"The Input(X) and Input(Y@Grad) should have the same "
"shape except the last dimension.");
}
if (IsSoftLabel(ctx)) {
if (check) {
PADDLE_ENFORCE_EQ(
x_dims[rank - 1], label_dims[rank - 1],
"When Attr(soft_label) == true, the last dimension of "
"Input(X) and Input(Label) should be equal.");
}
} else {
PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
"When Attr(soft_label) == false, the last dimension of "
"Input(Label) should be 1.");
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
"The last dimension of Input(Y@Grad) should be 1.");
ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
}
......@@ -231,7 +225,7 @@ class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should be not null.");
CrossEntropyGradientOpBase::InferShape(ctx);
}
};
......@@ -260,11 +254,11 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
void InferShape(framework::InferShapeContext* ctx) const override {
CrossEntropyOpBase::InferShape(ctx);
PADDLE_ENFORCE(ctx->HasOutput("XShape"),
"Output(XShape) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
"Output(XShape) should be not null.");
PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
"Output(MatchX) should be not null.");
PADDLE_ENFORCE_EQ(ctx->HasOutput("MatchX"), true,
"Output(MatchX) should be not null.");
auto x_dims = ctx->GetInputDim("X");
auto x_dims_vec = framework::vectorize(x_dims);
x_dims_vec.push_back(0);
......@@ -284,7 +278,8 @@ class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
public:
using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
void InferShape(framework::InferShapeContext* ctx) const override {
PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
PADDLE_ENFORCE_EQ(ctx->HasInput("MatchX"), true,
"Input(MatchX) must exist");
CrossEntropyGradientOpBase::InferShape(ctx);
}
......
......@@ -35,9 +35,20 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
y->mutable_data<T>(ctx.GetPlace());
int rank = x->dims().size();
auto label_dims = labels->dims();
Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
Tensor labels_2d, y_2d;
if (label_dims.size() < rank) {
labels_2d.ShareDataWith(*labels);
labels_2d.Resize({framework::product(label_dims), 1});
y_2d.ShareDataWith(*y);
y_2d.Resize({framework::product(y->dims()), 1});
} else {
labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
y_2d = framework::ReshapeToMatrix(*y, rank - 1);
}
int axis_dim = x->dims()[rank - 1];
math::CrossEntropyFunctor<DeviceContext, T>()(
......
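When Label arrives without the trailing dimension, the kernel above flattens it to a [product(label_dims), 1] view instead of calling ReshapeToMatrix. A NumPy stand-in for that branching (illustrative; reshape approximates the ShareDataWith + Resize view):

```python
import numpy as np

# When Label's rank is smaller than X's, build a [prod(label_dims), 1] view;
# otherwise flatten the leading rank-1 dimensions as ReshapeToMatrix does.
x = np.random.random((4, 8, 16)).astype('float32')     # rank 3, class_num = 16
label = np.random.randint(0, 16, size=(4, 8))           # trailing 1 dropped

rank = x.ndim
x_2d = x.reshape(-1, x.shape[-1])                        # [32, 16]
if label.ndim < rank:
    labels_2d = label.reshape(label.size, 1)             # [32, 1]
else:
    labels_2d = label.reshape(-1, label.shape[-1])       # ReshapeToMatrix(rank-1)
axis_dim = x.shape[-1]                                   # matches x->dims()[rank - 1]
```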
......@@ -20,7 +20,7 @@ import six
class CrossEntropy2OpTestBase(OpTest):
def initParameters(self):
return [32, 64], 'float32', -100
return [32, 64], 'float32', -100, False
def calc_output(self, logits, label, ignore_index):
ret = np.zeros(shape=label.shape, dtype=logits.dtype)
......@@ -33,21 +33,24 @@ class CrossEntropy2OpTestBase(OpTest):
return ret, match_x
def setUp(self):
self.shape, self.dtype, self.ignore_index = self.initParameters()
self.shape, self.dtype, self.ignore_index, self.drop_last_dim = self.initParameters(
)
self.op_type = 'cross_entropy2'
feature_size = int(self.shape[-1])
batch_size = int(np.prod(self.shape) / feature_size)
logits = (np.random.random(size=self.shape) + 1).astype(self.dtype)
label_shape = self.shape[0:-1] if self.drop_last_dim else self.shape[
0:-1] + [1]
label = np.random.random_integers(
low=0, high=feature_size - 1,
size=self.shape[0:-1] + [1]).astype('int64')
low=0, high=feature_size - 1, size=label_shape).astype('int64')
outputs, match_x = self.calc_output(
np.reshape(logits, [batch_size, feature_size]),
np.reshape(label, [batch_size, 1]), self.ignore_index)
self.inputs = {'X': logits, 'Label': label}
out_shape = label_shape
self.outputs = {
'Y': np.reshape(outputs, label.shape),
'MatchX': np.reshape(match_x, label.shape),
'Y': np.reshape(outputs, out_shape),
'MatchX': np.reshape(match_x, self.shape[:-1] + [1]),
'XShape': np.zeros(
shape=logits.shape, dtype=logits.dtype)
}
......@@ -65,17 +68,27 @@ class CrossEntropy2OpTestBase(OpTest):
class CrossEntropy2OpTest2(CrossEntropy2OpTestBase):
def initParameters(self):
return [32, 64], 'float64', 3
return [32, 64], 'float64', 3, False
class CrossEntropy2OpTest2RemoveLastDim(CrossEntropy2OpTestBase):
def initParameters(self):
return [32, 64], 'float64', 3, True
class CrossEntropy2OpTest3(CrossEntropy2OpTestBase):
def initParameters(self):
return [4, 8, 16, 32], 'float32', -100
return [4, 8, 16, 32], 'float32', -100, False
class CrossEntropy2OpTest3RemoveLastDim(CrossEntropy2OpTestBase):
def initParameters(self):
return [4, 8, 16, 32], 'float32', -100, True
class CrossEntropy2OpTest4(CrossEntropy2OpTestBase):
def initParameters(self):
return [4, 8, 16, 32], 'float32', 3
return [4, 8, 16, 32], 'float32', 3, False
if __name__ == '__main__':
......
......@@ -76,6 +76,23 @@ class TestCrossEntropyOp(OpTest):
self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
class TestCrossEntropyOpRemoveLastDim(TestCrossEntropyOp):
"""Test cross-entropy with discrete one-hot labels with shape [batch_size]
"""
def init_label(self):
self.label = np.random.randint(
0, self.class_num, (self.batch_size), dtype="int64")
def get_cross_entropy(self):
self.cross_entropy = np.asmatrix(
[
-np.log(self.x[i][self.label[i]])
for i in range(self.x.shape[0])
],
dtype="float64")
class TestCrossEntropyOp2(TestCrossEntropyOp):
"""Test cross-entropy with vectorized soft labels.
"""
......@@ -167,6 +184,22 @@ class TestCrossEntropyOp4(TestCrossEntropyOp):
self.class_num = 10
class TestCrossEntropyOp4RemoveLastDim(TestCrossEntropyOp4):
"""Test high rank tensor cross-entropy with discrete one-hot labels with shape [batch_size]
"""
def init_label(self):
self.label_2d = np.random.randint(
0, self.class_num, (self.ins_num, 1), dtype="int64")
self.label = self.label_2d.reshape(self.shape)
def get_cross_entropy(self):
cross_entropy_2d = np.asmatrix(
[[-np.log(self.X_2d[i][self.label_2d[i][0]])]
for i in range(self.X_2d.shape[0])]).astype(self.dtype)
self.cross_entropy = np.array(cross_entropy_2d).reshape(self.shape)
class TestCrossEntropyOp5(TestCrossEntropyOp):
"""Test high rank tensor cross-entropy with vectorized soft labels.
"""
......@@ -270,6 +303,23 @@ class TestCrossEntropyOp7(TestCrossEntropyOp):
self.class_num = 10
class TestCrossEntropyOp7RemoveLastDim(TestCrossEntropyOp7):
"""Test cross-entropy with ignore index with shape [batch_size].
"""
def init_label(self):
self.label = np.random.randint(
0, self.class_num, (self.batch_size), dtype="int64")
def get_cross_entropy(self):
self.cross_entropy = np.asmatrix(
[[-np.log(self.x[i][self.label[i]])]
if self.label[i] != self.ignore_index else [0]
for i in range(self.x.shape[0])]).astype(self.dtype)
self.cross_entropy = np.array(self.cross_entropy).reshape(
[self.batch_size]).astype(self.dtype)
# Add Fp16 test
def create_test_class(parent, cls_name):
@unittest.skipIf(not core.is_compiled_with_cuda(),
......@@ -298,9 +348,13 @@ create_test_class(TestCrossEntropyOp, "TestCrossEntropyF16Op")
#create_test_class(TestCrossEntropyOp2, "TestCrossEntropyF16Op2")
create_test_class(TestCrossEntropyOp3, "TestCrossEntropyF16Op3")
create_test_class(TestCrossEntropyOp4, "TestCrossEntropyF16Op4")
create_test_class(TestCrossEntropyOp4RemoveLastDim,
"TestCrossEntropyF16Op4RemoveLastDim")
#create_test_class(TestCrossEntropyOp5, "TestCrossEntropyF16Op5")
create_test_class(TestCrossEntropyOp6, "TestCrossEntropyF16Op6")
create_test_class(TestCrossEntropyOp7, "TestCrossEntropyF16Op7")
create_test_class(TestCrossEntropyOp7RemoveLastDim,
"TestCrossEntropyF16Op7RemoveLastDim")
if __name__ == "__main__":
unittest.main()