提交 7aafb97f 编写于 作者: A A. Unique TensorFlower 提交者: TensorFlower Gardener

Swaps order of input/label arguments for weighted_cross_entropy_with_logits

Change: 141569831
上级 0cfa416b
......@@ -164,7 +164,7 @@ def sigmoid_cross_entropy_with_logits(logits, targets, name=None):
name=name)
def weighted_cross_entropy_with_logits(logits, targets, pos_weight, name=None):
def weighted_cross_entropy_with_logits(targets, logits, pos_weight, name=None):
"""Computes a weighted cross entropy.
This is like `sigmoid_cross_entropy_with_logits()` except that `pos_weight`,
......@@ -198,8 +198,8 @@ def weighted_cross_entropy_with_logits(logits, targets, pos_weight, name=None):
`logits` and `targets` must have the same type and shape.
Args:
logits: A `Tensor` of type `float32` or `float64`.
targets: A `Tensor` of the same type and shape as `logits`.
logits: A `Tensor` of type `float32` or `float64`.
pos_weight: A coefficient to use on the positive examples.
name: A name for the operation (optional).
......
......@@ -119,7 +119,7 @@ class WeightedCrossEntropyTest(tf.test.TestCase):
def testConstructionNamed(self):
with self.test_session():
logits, targets, pos_weight, _ = self._Inputs()
loss = tf.nn.weighted_cross_entropy_with_logits(logits, targets,
loss = tf.nn.weighted_cross_entropy_with_logits(targets, logits,
pos_weight, name="mybce")
self.assertEqual("mybce", loss.op.name)
......@@ -127,7 +127,7 @@ class WeightedCrossEntropyTest(tf.test.TestCase):
for use_gpu in [True, False]:
with self.test_session(use_gpu=use_gpu):
logits, targets, pos_weight, losses = self._Inputs(dtype=tf.float32)
loss = tf.nn.weighted_cross_entropy_with_logits(logits, targets,
loss = tf.nn.weighted_cross_entropy_with_logits(targets, logits,
pos_weight)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
......@@ -138,7 +138,7 @@ class WeightedCrossEntropyTest(tf.test.TestCase):
with self.test_session(use_gpu=use_gpu):
logits, targets, pos_weight, losses = self._Inputs(dtype=tf.float32,
sizes=[2, 2, 2])
loss = tf.nn.weighted_cross_entropy_with_logits(logits, targets,
loss = tf.nn.weighted_cross_entropy_with_logits(targets, logits,
pos_weight)
np_loss = np.array(losses).astype(np.float32)
tf_loss = loss.eval()
......@@ -148,7 +148,7 @@ class WeightedCrossEntropyTest(tf.test.TestCase):
sizes = [4, 2]
with self.test_session():
logits, targets, pos_weight, _ = self._Inputs(sizes=sizes)
loss = tf.nn.weighted_cross_entropy_with_logits(logits, targets,
loss = tf.nn.weighted_cross_entropy_with_logits(targets, logits,
pos_weight)
err = tf.test.compute_gradient_error(logits, sizes, loss, sizes)
print("logistic loss gradient err = ", err)
......@@ -156,7 +156,7 @@ class WeightedCrossEntropyTest(tf.test.TestCase):
def testShapeError(self):
with self.assertRaisesRegexp(ValueError, "must have the same shape"):
tf.nn.weighted_cross_entropy_with_logits([[2, 1]], [1, 2, 3], 2.0)
tf.nn.weighted_cross_entropy_with_logits([1, 2, 3], [[2, 1]], 2.0)
if __name__ == "__main__":
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册