提交 d82cc200 编写于 作者: V Vijay Vasudevan 提交者: TensorFlower Gardener

TensorFlow: select op: add support for empty tensors: select

can propagate empty tensors.
Change: 116588842
上级 cc62d992
......@@ -77,11 +77,12 @@ class SelectOp : public OpKernel {
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
functor::BatchSelectFunctor<Device, T> func;
func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
cond->vec<bool>(), then->flat_outer_dims<T>(),
else_->flat_outer_dims<T>());
if (output->NumElements() > 0) {
functor::BatchSelectFunctor<Device, T> func;
func(ctx->eigen_device<Device>(), output->flat_outer_dims<T>(),
cond->vec<bool>(), then->flat_outer_dims<T>(),
else_->flat_outer_dims<T>());
}
}
void ComputeElementwise(OpKernelContext* ctx, const Tensor* cond,
......@@ -89,9 +90,11 @@ class SelectOp : public OpKernel {
if (!ctx->ValidateInputsAreSameShape(this)) return;
Tensor* output = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, then->shape(), &output));
functor::SelectFunctor<Device, T> func;
func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
then->flat<T>(), else_->flat<T>());
if (output->NumElements() > 0) {
functor::SelectFunctor<Device, T> func;
func(ctx->eigen_device<Device>(), output->flat<T>(), cond->flat<bool>(),
then->flat<T>(), else_->flat<T>());
}
}
private:
......
......@@ -967,6 +967,17 @@ class SelectOpTest(tf.test.TestCase):
with self.assertRaises(ValueError):
tf.select(c, xt, yt)
def testEmptyTensor(self):
c = np.random.randint(0, 3, 0).astype(np.bool).reshape(1, 3, 0)
x = np.random.rand(1, 3, 0) * 100
y = np.random.rand(1, 3, 0) * 100
z_expected = np.zeros((1, 3, 0), dtype=np.float32)
with self.test_session():
xt = x.astype(np.float32)
yt = y.astype(np.float32)
z = tf.select(c, xt, yt).eval()
self.assertAllEqual(z_expected, z)
class BatchSelectOpTest(tf.test.TestCase):
"""Test broadcasting of Select when 'c' is a vec and 't' &'e' are rank2+."""
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册