提交 083c4dc9 编写于 作者: A A. Unique TensorFlower 提交者: Gunhan Gulsoy

Change StridedSlice to error on scalar input, in both

the shape inference function and the kernel.
Change: 134434589
上级 425f49e1
......@@ -72,7 +72,10 @@ REGISTER_OP("CudnnRNNParamsSize")
.Attr(kRNNInputModeAttrs)
.Attr(kRNNDirectionAttrs)
.Output("params_size: S")
.SetShapeFn(shape_inference::ScalarShape)
.SetShapeFn([](InferenceContext* c) {
c->set_output(0, c->Vector(1));
return Status::OK();
})
.Doc(strings::StrCat(R"doc(
Return the params size that can be used by the Cudnn RNN model. Subsequent
weight allocation and initialization should use this size.
......
......@@ -26,7 +26,7 @@ namespace tensorflow {
TEST(CudnnRNNOpsTest, ParamsSize_ShapeFn) {
ShapeInferenceTestOp op("CudnnRNNParamsSize");
INFER_OK(op, "[1];[1];[1]", "[]");
INFER_OK(op, "[1];[1];[1]", "[1]");
}
TEST(CudnnRNNOpsTest, ForwardLstm_ShapeFn) {
......
......@@ -94,8 +94,8 @@ struct StridedSliceDenseSpec {
} // namespace
template <class T>
static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
StridedSliceDenseSpec* dense) {
static Status TF_MUST_USE_RESULT BuildDenseSpec(
const StridedSliceSparseSpec& sparse, StridedSliceDenseSpec* dense) {
// Build expanded begin, end, strides, begin_mask, end_mask
// to remove any ellipsis
dense->begin.resize(dense->dims);
......@@ -130,6 +130,12 @@ static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
} else if ((1 << i) & sparse.new_axis_mask) {
dense->final_shape_gather_indices.push_back(kNewAxis);
} else {
if (full_index == dense->begin.size()) {
return errors::InvalidArgument("Index out of range using input dim ",
full_index, "; input has only ",
dense->dims, " dims");
}
// Gather slicing spec into appropriate index
dense->begin[full_index] = internal::SubtleMustCopy<T>(begin_flat(i));
dense->end[full_index] = internal::SubtleMustCopy<T>(end_flat(i));
......@@ -154,6 +160,7 @@ static void BuildDenseSpec(const StridedSliceSparseSpec& sparse,
}
}
}
return Status::OK();
}
Status ValidateStridedSliceOp(
......@@ -233,9 +240,9 @@ Status ValidateStridedSliceOp(
input_shape.dims(), 0, 0, *begin, *end, *strides};
if (begin_tensor.dtype() == DT_INT32) {
BuildDenseSpec<int32>(sparse_spec, &dense_spec);
TF_RETURN_IF_ERROR(BuildDenseSpec<int32>(sparse_spec, &dense_spec));
} else if (begin_tensor.dtype() == DT_INT64) {
BuildDenseSpec<int64>(sparse_spec, &dense_spec);
TF_RETURN_IF_ERROR(BuildDenseSpec<int64>(sparse_spec, &dense_spec));
} else {
LOG(FATAL) << "begin must be either int32 or int64";
}
......
......@@ -77,6 +77,34 @@ class SliceTest(tf.test.TestCase):
slice_val = slice_t.eval()
self.assertAllEqual(slice_val, inp[lo:hi])
def testScalarInput(self):
input_val = 0
with self.test_session() as sess:
# Test with constant input; shape inference fails.
with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
tf.constant(input_val)[:].get_shape()
# Test evaluating with non-constant input; kernel execution fails.
input_t = tf.placeholder(tf.int32)
slice_t = input_t[:]
with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
"out of range"):
sess.run([slice_t], feed_dict={input_t: input_val})
def testInvalidIndex(self):
input_val = [1, 2]
with self.test_session() as sess:
# Test with constant input; shape inference fails.
with self.assertRaisesWithPredicateMatch(ValueError, "out of range"):
tf.constant(input_val)[1:, 1:].get_shape()
# Test evaluating with non-constant input; kernel execution fails.
input_t = tf.placeholder(tf.int32)
slice_t = input_t[1:, 1:]
with self.assertRaisesWithPredicateMatch(tf.errors.InvalidArgumentError,
"out of range"):
sess.run([slice_t], feed_dict={input_t: input_val})
def _testSliceMatrixDim0(self, x, begin, size):
with self.test_session(use_gpu=True):
tf_ans = tf.slice(x, [begin, 0], [size, x.shape[1]]).eval()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册