Unverified commit b125e327, authored by Aurelius84, committed by GitHub

Remove constraint that last dimension is forced to be 1 in cross_entropy (#19606)

* Remove constraint that last dimension is forced to be 1 in cross_entropy
test=develop

* modify labels last dims test=develop
Parent e8d3745c
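In effect: for an Input(X) of shape [d_1, ..., d_{k-1}, D], the hard-label path (soft_label == false) now accepts an Input(Label) of shape [d_1, ..., d_{k-1}] in addition to the old [d_1, ..., d_{k-1}, 1], and Output(Y) follows the label's shape. A minimal NumPy sketch of the intended semantics (names and shapes here are illustrative, not the Paddle API):

    import numpy as np

    def hard_label_cross_entropy(x, label):
        # x: [N, D] class probabilities; label: [N] or [N, 1] int64 class ids.
        squeeze = (label.ndim == x.ndim - 1)  # new: the trailing 1 may be dropped
        ids = label.reshape(-1)
        loss = -np.log(x[np.arange(x.shape[0]), ids])
        # Y keeps the label's shape: [N] if the last dim was dropped, else [N, 1].
        return loss if squeeze else loss.reshape(-1, 1)

    x = np.random.dirichlet(np.ones(4), size=3)  # [3, 4], rows summing to 1
    print(hard_label_cross_entropy(x, np.array([0, 1, 2])).shape)        # (3,)
    print(hard_label_cross_entropy(x, np.array([[0], [1], [2]])).shape)  # (3, 1)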
paddle/fluid/operators/cross_entropy_op.cc

@@ -25,19 +25,21 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput("Y"), "Output(Y) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
+                      "Input(Label) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("Y"), true,
+                      "Output(Y) should be not null.");
 
     auto x_dims = ctx->GetInputDim("X");
     auto label_dims = ctx->GetInputDim("Label");
     int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(rank, label_dims.size(),
-                      "Input(X) and Input(Label) shall have the same rank.");
+
     bool contain_unknown_dim = framework::contain_unknown_dim(x_dims) ||
                                framework::contain_unknown_dim(label_dims);
     bool check = ctx->IsRuntime() || !contain_unknown_dim;
     if (check) {
       PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                         framework::slice_ddim(label_dims, 0, rank - 1),
@@ -46,19 +48,30 @@ class CrossEntropyOpBase : public framework::OperatorWithKernel {
     }
 
     if (IsSoftLabel(ctx)) {
+      PADDLE_ENFORCE_EQ(
+          rank, label_dims.size(),
+          "If Attr(soft_label) == true, Input(X) and Input(Label) "
+          "shall have the same rank.");
       if (check) {
         PADDLE_ENFORCE_EQ(x_dims[rank - 1], label_dims[rank - 1],
                           "If Attr(soft_label) == true, the last dimension of "
                           "Input(X) and Input(Label) should be equal.");
       }
     } else {
-      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
-                        "If Attr(softLabel) == false, the last dimension of "
-                        "Input(Label) should be 1.");
+      if (rank == label_dims.size()) {
+        PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1UL,
+                          "the last dimension of Input(Label) should be 1.");
+      } else {
+        PADDLE_ENFORCE_EQ(
+            rank, label_dims.size() + 1,
+            "The rank of Input(X) should be equal to Input(Label) plus 1.");
+      }
     }
 
-    auto y_dims = x_dims;
-    y_dims[rank - 1] = 1;
+    auto y_dims = label_dims;
+    if (rank == label_dims.size()) {
+      y_dims[rank - 1] = 1;
+    }
     ctx->SetOutputDim("Y", y_dims);
     ctx->ShareLoD("X", /*->*/ "Y");
   }
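Restating the new InferShape contract: with soft labels, Input(X) and Input(Label) must have equal rank and equal last dimension; with hard labels, Input(Label) either keeps the trailing 1 (same rank as X) or drops it (rank of X minus one); and Output(Y) now takes its shape from the label rather than from X. A sketch of the branch structure in Python, with asserts standing in for PADDLE_ENFORCE_EQ (illustrative only):

    def infer_y_shape(x_dims, label_dims, soft_label):
        rank = len(x_dims)
        if soft_label:
            assert rank == len(label_dims)         # same rank
            assert x_dims[-1] == label_dims[-1]    # same class dimension
        elif rank == len(label_dims):
            assert label_dims[-1] == 1             # old layout: trailing 1 kept
        else:
            assert rank == len(label_dims) + 1     # new layout: trailing 1 dropped
        y_dims = list(label_dims)                  # Y follows the label's shape
        if rank == len(label_dims):
            y_dims[-1] = 1
        return y_dims

    assert infer_y_shape([4, 8, 16, 32], [4, 8, 16], False) == [4, 8, 16]
    assert infer_y_shape([4, 8, 16, 32], [4, 8, 16, 1], False) == [4, 8, 16, 1]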
@@ -82,20 +95,19 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
   using framework::OperatorWithKernel::OperatorWithKernel;
 
   void InferShape(framework::InferShapeContext* ctx) const {
-    PADDLE_ENFORCE(ctx->HasInput("Label"), "Input(Label) should be not null.");
-    PADDLE_ENFORCE(ctx->HasInput(framework::GradVarName("Y")),
-                   "Input(Y@GRAD) should be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput(framework::GradVarName("X")),
-                   "Output(X@GRAD) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("Label"), true,
+                      "Input(Label) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput(framework::GradVarName("Y")), true,
+                      "Input(Y@GRAD) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput(framework::GradVarName("X")), true,
+                      "Output(X@GRAD) should be not null.");
 
     auto x_dims = GetXDim(ctx);
     auto label_dims = ctx->GetInputDim("Label");
     auto dy_dims = ctx->GetInputDim(framework::GradVarName("Y"));
     int rank = x_dims.size();
-    PADDLE_ENFORCE_EQ(dy_dims.size(), rank,
-                      "Input(Y@Grad) and Input(X) should have the same rank.");
-    PADDLE_ENFORCE_EQ(label_dims.size(), rank,
-                      "Input(Label) and Input(X) should have the same rank.");
+    PADDLE_ENFORCE_EQ(dy_dims.size(), label_dims.size(),
+                      "Input(Y@Grad) and Input(Y) should have the same rank.");
 
     bool check = true;
     if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
@@ -104,30 +116,12 @@ class CrossEntropyGradientOpBase : public framework::OperatorWithKernel {
     }
 
     if (check) {
-      PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
-                        framework::slice_ddim(label_dims, 0, rank - 1),
-                        "The Input(X) and Input(Label) should have the same "
-                        "shape except the last dimension.");
       PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank - 1),
                         framework::slice_ddim(dy_dims, 0, rank - 1),
                         "The Input(X) and Input(Y@Grad) should have the same "
                         "shape except the last dimension.");
     }
-
-    if (IsSoftLabel(ctx)) {
-      if (check) {
-        PADDLE_ENFORCE_EQ(
-            x_dims[rank - 1], label_dims[rank - 1],
-            "When Attr(soft_label) == true, the last dimension of "
-            "Input(X) and Input(Label) should be equal.");
-      }
-    } else {
-      PADDLE_ENFORCE_EQ(label_dims[rank - 1], 1,
-                        "When Attr(soft_label) == false, the last dimension of "
-                        "Input(Label) should be 1.");
-    }
-
-    ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
-    PADDLE_ENFORCE_EQ(dy_dims[rank - 1], 1,
-                      "The last dimension of Input(Y@Grad) should be 1.");
     ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
     ctx->ShareLoD(VarNameWithXLoD(), framework::GradVarName("X"));
   }
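The gradient op drops its per-branch label checks (the forward op now owns them) and only requires that Y@Grad match the label's rank and the leading dimensions of X, while X@Grad always takes X's dims. That mirrors the math: for hard labels y_i = -log(x[i, label_i]), so the gradient is nonzero only at the labeled entries. A hypothetical NumPy sketch, not the Paddle kernel:

    import numpy as np

    def hard_label_cross_entropy_grad(x, label, dy):
        # d y_i / d x[i, j] = -1 / x[i, j] if j == label_i, else 0.
        n = x.shape[0]
        ids = label.reshape(-1)   # accepts [N] and [N, 1] labels alike
        dx = np.zeros_like(x)     # X@Grad always has X's shape
        dx[np.arange(n), ids] = -dy.reshape(-1) / x[np.arange(n), ids]
        return dx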
@@ -231,7 +225,7 @@ class CrossEntropyGradientOp : public CrossEntropyGradientOpBase {
   using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("X"), "Input(X) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("X"), true, "Input(X) should be not null.");
     CrossEntropyGradientOpBase::InferShape(ctx);
   }
 };
@@ -260,11 +254,11 @@ class CrossEntropyOp2 : public CrossEntropyOpBase {
   void InferShape(framework::InferShapeContext* ctx) const override {
     CrossEntropyOpBase::InferShape(ctx);
 
-    PADDLE_ENFORCE(ctx->HasOutput("XShape"),
-                   "Output(XShape) should be not null.");
-    PADDLE_ENFORCE(ctx->HasOutput("MatchX"),
-                   "Output(MatchX) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("XShape"), true,
+                      "Output(XShape) should be not null.");
+    PADDLE_ENFORCE_EQ(ctx->HasOutput("MatchX"), true,
+                      "Output(MatchX) should be not null.");
     auto x_dims = ctx->GetInputDim("X");
     auto x_dims_vec = framework::vectorize(x_dims);
     x_dims_vec.push_back(0);
@@ -284,7 +278,8 @@ class CrossEntropyGradientOp2 : public CrossEntropyGradientOpBase {
  public:
   using CrossEntropyGradientOpBase::CrossEntropyGradientOpBase;
 
   void InferShape(framework::InferShapeContext* ctx) const override {
-    PADDLE_ENFORCE(ctx->HasInput("MatchX"), "Input(MatchX) must exist");
+    PADDLE_ENFORCE_EQ(ctx->HasInput("MatchX"), true,
+                      "Input(MatchX) must exist");
     CrossEntropyGradientOpBase::InferShape(ctx);
   }
paddle/fluid/operators/cross_entropy_op.h

@@ -35,9 +35,20 @@ class CrossEntropyOpKernel : public framework::OpKernel<T> {
     y->mutable_data<T>(ctx.GetPlace());
 
     int rank = x->dims().size();
+    auto label_dims = labels->dims();
     Tensor x_2d = framework::ReshapeToMatrix(*x, rank - 1);
-    Tensor labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
-    Tensor y_2d = framework::ReshapeToMatrix(*y, rank - 1);
+    Tensor labels_2d, y_2d;
+    if (label_dims.size() < rank) {
+      labels_2d.ShareDataWith(*labels);
+      labels_2d.Resize({framework::product(label_dims), 1});
+
+      y_2d.ShareDataWith(*y);
+      y_2d.Resize({framework::product(y->dims()), 1});
+
+    } else {
+      labels_2d = framework::ReshapeToMatrix(*labels, rank - 1);
+      y_2d = framework::ReshapeToMatrix(*y, rank - 1);
+    }
 
     int axis_dim = x->dims()[rank - 1];
     math::CrossEntropyFunctor<DeviceContext, T>()(
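The kernel still flattens X to a [prod(leading dims), D] matrix; when the label has dropped its trailing 1, the same buffer is simply re-viewed as [prod(label_dims), 1] (ShareDataWith plus Resize, so no copy is made). The equivalent view logic in NumPy, for illustration only:

    import numpy as np

    x = np.random.random((4, 8, 16, 32))
    label = np.random.randint(0, 32, size=(4, 8, 16))  # no trailing 1

    x_2d = x.reshape(-1, x.shape[-1])       # ReshapeToMatrix(x, rank - 1) -> (512, 32)
    if label.ndim < x.ndim:                 # the new kernel branch
        labels_2d = label.reshape(-1, 1)    # re-view as (prod(label_dims), 1)
    else:
        labels_2d = label.reshape(-1, label.shape[-1])
    print(x_2d.shape, labels_2d.shape)      # (512, 32) (512, 1)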
python/paddle/fluid/tests/unittests/test_cross_entropy2_op.py

@@ -20,7 +20,7 @@ import six
 
 class CrossEntropy2OpTestBase(OpTest):
     def initParameters(self):
-        return [32, 64], 'float32', -100
+        return [32, 64], 'float32', -100, False
 
     def calc_output(self, logits, label, ignore_index):
         ret = np.zeros(shape=label.shape, dtype=logits.dtype)
@@ -33,21 +33,24 @@ class CrossEntropy2OpTestBase(OpTest):
         return ret, match_x
 
     def setUp(self):
-        self.shape, self.dtype, self.ignore_index = self.initParameters()
+        self.shape, self.dtype, self.ignore_index, self.drop_last_dim = self.initParameters(
+        )
         self.op_type = 'cross_entropy2'
         feature_size = int(self.shape[-1])
         batch_size = int(np.prod(self.shape) / feature_size)
         logits = (np.random.random(size=self.shape) + 1).astype(self.dtype)
+        label_shape = self.shape[0:-1] if self.drop_last_dim else self.shape[
+            0:-1] + [1]
         label = np.random.random_integers(
-            low=0, high=feature_size - 1,
-            size=self.shape[0:-1] + [1]).astype('int64')
+            low=0, high=feature_size - 1, size=label_shape).astype('int64')
         outputs, match_x = self.calc_output(
             np.reshape(logits, [batch_size, feature_size]),
             np.reshape(label, [batch_size, 1]), self.ignore_index)
         self.inputs = {'X': logits, 'Label': label}
+        out_shape = label_shape
         self.outputs = {
-            'Y': np.reshape(outputs, label.shape),
-            'MatchX': np.reshape(match_x, label.shape),
+            'Y': np.reshape(outputs, out_shape),
+            'MatchX': np.reshape(match_x, self.shape[:-1] + [1]),
             'XShape': np.zeros(
                 shape=logits.shape, dtype=logits.dtype)
         }
@@ -65,17 +68,27 @@ class CrossEntropy2OpTestBase(OpTest):
 
 class CrossEntropy2OpTest2(CrossEntropy2OpTestBase):
     def initParameters(self):
-        return [32, 64], 'float64', 3
+        return [32, 64], 'float64', 3, False
+
+
+class CrossEntropy2OpTest2RemoveLastDim(CrossEntropy2OpTestBase):
+    def initParameters(self):
+        return [32, 64], 'float64', 3, True
 
 
 class CrossEntropy2OpTest3(CrossEntropy2OpTestBase):
     def initParameters(self):
-        return [4, 8, 16, 32], 'float32', -100
+        return [4, 8, 16, 32], 'float32', -100, False
+
+
+class CrossEntropy2OpTest3RemoveLastDim(CrossEntropy2OpTestBase):
    def initParameters(self):
+        return [4, 8, 16, 32], 'float32', -100, True
 
 
 class CrossEntropy2OpTest4(CrossEntropy2OpTestBase):
     def initParameters(self):
-        return [4, 8, 16, 32], 'float32', 3
+        return [4, 8, 16, 32], 'float32', 3, False
 
 
 if __name__ == '__main__':
python/paddle/fluid/tests/unittests/test_cross_entropy_op.py

@@ -76,6 +76,23 @@ class TestCrossEntropyOp(OpTest):
         self.check_grad(["X"], "Y", numeric_grad_delta=0.001)
 
 
+class TestCrossEntropyOpRemoveLastDim(TestCrossEntropyOp):
+    """Test cross-entropy with discrete one-hot labels with shape [batch_size]
+    """
+
+    def init_label(self):
+        self.label = np.random.randint(
+            0, self.class_num, (self.batch_size), dtype="int64")
+
+    def get_cross_entropy(self):
+        self.cross_entropy = np.asmatrix(
+            [
+                -np.log(self.x[i][self.label[i]])
+                for i in range(self.x.shape[0])
+            ],
+            dtype="float64")
+
+
 class TestCrossEntropyOp2(TestCrossEntropyOp):
     """Test cross-entropy with vectorized soft labels.
     """
@@ -167,6 +184,22 @@ class TestCrossEntropyOp4(TestCrossEntropyOp):
         self.class_num = 10
 
 
+class TestCrossEntropyOp4RemoveLastDim(TestCrossEntropyOp4):
+    """Test high rank tensor cross-entropy with discrete one-hot labels with shape [batch_size]
+    """
+
+    def init_label(self):
+        self.label_2d = np.random.randint(
+            0, self.class_num, (self.ins_num, 1), dtype="int64")
+        self.label = self.label_2d.reshape(self.shape)
+
+    def get_cross_entropy(self):
+        cross_entropy_2d = np.asmatrix(
+            [[-np.log(self.X_2d[i][self.label_2d[i][0]])]
+             for i in range(self.X_2d.shape[0])]).astype(self.dtype)
+        self.cross_entropy = np.array(cross_entropy_2d).reshape(self.shape)
+
+
 class TestCrossEntropyOp5(TestCrossEntropyOp):
     """Test high rank tensor cross-entropy with vectorized soft labels.
     """
@@ -270,6 +303,23 @@ class TestCrossEntropyOp7(TestCrossEntropyOp):
         self.class_num = 10
 
 
+class TestCrossEntropyOp7RemoveLastDim(TestCrossEntropyOp7):
+    """Test cross-entropy with ignore index with shape [batch_size].
+    """
+
+    def init_label(self):
+        self.label = np.random.randint(
+            0, self.class_num, (self.batch_size), dtype="int64")
+
+    def get_cross_entropy(self):
+        self.cross_entropy = np.asmatrix(
+            [[-np.log(self.x[i][self.label[i]])]
+             if self.label[i] != self.ignore_index else [0]
+             for i in range(self.x.shape[0])]).astype(self.dtype)
+        self.cross_entropy = np.array(self.cross_entropy).reshape(
+            [self.batch_size]).astype(self.dtype)
+
+
 # Add Fp16 test
 def create_test_class(parent, cls_name):
     @unittest.skipIf(not core.is_compiled_with_cuda(),
@@ -298,9 +348,13 @@ create_test_class(TestCrossEntropyOp, "TestCrossEntropyF16Op")
 #create_test_class(TestCrossEntropyOp2, "TestCrossEntropyF16Op2")
 create_test_class(TestCrossEntropyOp3, "TestCrossEntropyF16Op3")
 create_test_class(TestCrossEntropyOp4, "TestCrossEntropyF16Op4")
+create_test_class(TestCrossEntropyOp4RemoveLastDim,
+                  "TestCrossEntropyF16Op4RemoveLastDim")
 #create_test_class(TestCrossEntropyOp5, "TestCrossEntropyF16Op5")
 create_test_class(TestCrossEntropyOp6, "TestCrossEntropyF16Op6")
 create_test_class(TestCrossEntropyOp7, "TestCrossEntropyF16Op7")
+create_test_class(TestCrossEntropyOp7RemoveLastDim,
+                  "TestCrossEntropyF16Op7RemoveLastDim")
 
 if __name__ == "__main__":
     unittest.main()