未验证 提交 2de7f3cf 编写于 作者: H Hongyu Liu 提交者: GitHub

Merge pull request #16799 from phlrain/sigmoid_corss_entropy_support_high_rank

support high rank
...@@ -34,15 +34,22 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel { ...@@ -34,15 +34,22 @@ class SigmoidCrossEntropyWithLogitsOp : public framework::OperatorWithKernel {
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label"); auto labels_dims = ctx->GetInputDim("Label");
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(labels_dims.size(), 2, int rank = x_dims.size();
"Input(Label)'s rank should be 2."); PADDLE_ENFORCE_EQ(rank, labels_dims.size(),
PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0], "Input(X) and Input(Label) shall have the same rank.");
"The 1st dimension of Input(X) and Input(Label) should " bool check = true;
"be equal."); if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1], framework::product(labels_dims) <= 0)) {
"The 2nd dimension of Input(X) and Input(Label) should " check = false;
"be equal."); }
if (check) {
PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
framework::slice_ddim(labels_dims, 0, rank),
"Input(X) and Input(Label) shall have the same shape "
"except the last dimension.");
}
ctx->ShareDim("X", /*->*/ "Out"); ctx->ShareDim("X", /*->*/ "Out");
ctx->ShareLoD("X", /*->*/ "Out"); ctx->ShareLoD("X", /*->*/ "Out");
...@@ -65,23 +72,24 @@ class SigmoidCrossEntropyWithLogitsGradOp ...@@ -65,23 +72,24 @@ class SigmoidCrossEntropyWithLogitsGradOp
auto x_dims = ctx->GetInputDim("X"); auto x_dims = ctx->GetInputDim("X");
auto labels_dims = ctx->GetInputDim("Label"); auto labels_dims = ctx->GetInputDim("Label");
auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out")); auto dout_dims = ctx->GetInputDim(framework::GradVarName("Out"));
PADDLE_ENFORCE_EQ(x_dims.size(), 2, "Input(X)'s rank should be 2.");
PADDLE_ENFORCE_EQ(labels_dims.size(), 2, int rank = x_dims.size();
"Input(Label)'s rank should be 2."); bool check = true;
PADDLE_ENFORCE_EQ(dout_dims.size(), 2, if ((!ctx->IsRuntime()) && (framework::product(x_dims) <= 0 ||
"Input(Out@Grad)'s rank should be 2."); framework::product(labels_dims) <= 0)) {
PADDLE_ENFORCE_EQ(x_dims[0], labels_dims[0], check = false;
"The 1st dimension of Input(X) and Input(Label) should " }
"be equal.");
PADDLE_ENFORCE_EQ(x_dims[1], labels_dims[1], if (check) {
"The 2nd dimension of Input(X) and Input(Label) should " PADDLE_ENFORCE_EQ(framework::slice_ddim(x_dims, 0, rank),
"be equal."); framework::slice_ddim(labels_dims, 0, rank),
PADDLE_ENFORCE_EQ(x_dims[0], dout_dims[0], "Input(X) and Input(Label) shall have the same shape.");
"The 1st dimension of Input(X) and Input(Out@Grad) "
"should be equal."); PADDLE_ENFORCE_EQ(
PADDLE_ENFORCE_EQ(x_dims[1], dout_dims[1], framework::slice_ddim(x_dims, 0, rank),
"The 2nd dimension of Input(X) and Input(Out@Grad) " framework::slice_ddim(dout_dims, 0, rank),
"should be equal."); "Input(X) and Input(Out@Grad) shall have the same shape.");
}
ctx->SetOutputDim(framework::GradVarName("X"), x_dims); ctx->SetOutputDim(framework::GradVarName("X"), x_dims);
} }
......
...@@ -149,5 +149,98 @@ class TestSigmoidCrossEntropyWithNorm(OpTest): ...@@ -149,5 +149,98 @@ class TestSigmoidCrossEntropyWithNorm(OpTest):
self.check_grad(['X'], 'Out') self.check_grad(['X'], 'Out')
class TestSigmoidCrossEntropyWithLogitsOp5(OpTest):
    """Exercise sigmoid_cross_entropy_with_logits on rank-3 inputs whose
    labels are probabilistic (soft) values sampled uniformly between 0 and 1.
    """

    def setUp(self):
        self.op_type = "sigmoid_cross_entropy_with_logits"
        batch_size = [10, 10]
        num_classes = 20
        shape = tuple(batch_size + [num_classes])

        # X is sampled before Label so the random stream is consumed in the
        # same order as the dict literal would consume it.
        logits = logit(np.random.uniform(0, 1, shape).astype("float32"))
        labels = np.random.uniform(0, 1, shape).astype("float32")
        self.inputs = {'X': logits, 'Label': labels}

        # Reference forward pass: elementwise sigmoid followed by the
        # elementwise logistic loss
        #   Label * -log(sigmoid(X)) + (1 - Label) * -log(1 - sigmoid(X))
        prob = expit(logits)
        positive_part = labels * np.log(prob)
        negative_part = (1 - labels) * np.log(1 - prob)
        self.outputs = {'Out': -positive_part - negative_part}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
class TestSigmoidCrossEntropyWithNorm2(OpTest):
    """Exercise sigmoid_cross_entropy_with_logits on rank-3 inputs with
    ``ignore_index=-1`` and ``normalize=True``.

    Labels are integers drawn from {-1, 0, 1} and cast to float32, so some
    entries carry the ignore value.
    """

    def setUp(self):
        self.op_type = "sigmoid_cross_entropy_with_logits"
        batch_size = [10, 10]
        num_classes = 20
        ignore_index = -1
        shape = tuple(batch_size + [num_classes])

        # X is sampled before Label so the random stream is consumed in the
        # same order as the dict literal would consume it.
        logits = logit(np.random.uniform(0, 1, shape).astype("float32"))
        labels = np.random.randint(-1, 2, shape).astype("float32")
        self.inputs = {'X': logits, 'Label': labels}
        self.attrs = {'ignore_index': ignore_index, 'normalize': True}

        # Reference loss: elementwise logistic loss on sigmoid(X).
        prob = expit(logits)
        positive_part = labels * np.log(prob)
        negative_part = (1 - labels) * np.log(1 - prob)
        expected = -positive_part - negative_part

        # Entries whose label equals ignore_index contribute zero loss.
        expected[labels == ignore_index] = 0

        if self.attrs['normalize']:
            # Scale by the number of entries that were not ignored.
            kept = np.count_nonzero(labels != ignore_index)
            expected = expected / float(kept)
        self.outputs = {'Out': expected}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
class TestSigmoidCrossEntropyWithLogitsOp6(OpTest):
    """Exercise sigmoid_cross_entropy_with_logits on rank-3 inputs whose
    labels are binary (0 or 1, stored as float32).
    """

    def setUp(self):
        self.op_type = "sigmoid_cross_entropy_with_logits"
        batch_size = [10, 10]
        num_classes = 20
        shape = tuple(batch_size + [num_classes])

        # X is sampled before Label so the random stream is consumed in the
        # same order as the dict literal would consume it.
        logits = logit(np.random.uniform(0, 1, shape).astype("float32"))
        labels = np.random.randint(0, 2, shape).astype("float32")
        self.inputs = {'X': logits, 'Label': labels}

        # Reference forward pass: elementwise sigmoid followed by the
        # elementwise logistic loss
        #   Label * -log(sigmoid(X)) + (1 - Label) * -log(1 - sigmoid(X))
        prob = expit(logits)
        positive_part = labels * np.log(prob)
        negative_part = (1 - labels) * np.log(1 - prob)
        self.outputs = {'Out': -positive_part - negative_part}

    def test_check_output(self):
        self.check_output()

    def test_check_grad(self):
        self.check_grad(['X'], 'Out')
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册