提交 127d17a7 编写于 作者: G Geoffrey Irving 提交者: TensorFlower Gardener

Avoid race conditions in TensorShape and PartialTensorShape

Also fix MakePartialShape to correctly set is_unknown_ = false.
Change: 117285169
上级 477312ea
......@@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/kernels/bounds_check.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
......@@ -73,9 +74,10 @@ PartialTensorShape::PartialTensorShape(const TensorShapeProto& proto)
PartialTensorShape::PartialTensorShape(gtl::ArraySlice<int64> dim_sizes)
: is_unknown_(false) {
dim_sizes_.reserve(dim_sizes.size());
for (auto s : dim_sizes) {
CHECK_GE(s, -1);
dim_sizes_.push_back(s);
for (const int64& s : dim_sizes) {
const int64 dim = internal::SubtleMustCopy(s);
CHECK_GE(dim, -1);
dim_sizes_.push_back(dim);
}
}
......@@ -209,4 +211,56 @@ bool PartialTensorShape::IsCompatibleWith(const TensorShape& shape) const {
return true;
}
template <typename T>
static Status CheckAndCopyDims(const T* dims, int n,
gtl::InlinedVector<int64, 4>* out_dims) {
out_dims->reserve(n);
for (int i = 0; i < n; ++i) {
const int64 dim = internal::SubtleMustCopy(dims[i]);
if (dim >= -1) {
out_dims->push_back(dim);
} else {
return errors::InvalidArgument("Dimension ", dim, " must be >= -1");
}
}
return Status::OK();
}
#define MAKE_PARTIAL_SHAPE(T) \
Status PartialTensorShape::MakePartialShape(const T* dims, int n, \
PartialTensorShape* out) { \
out->is_unknown_ = false; \
return CheckAndCopyDims(dims, n, &out->dim_sizes_); \
}
MAKE_PARTIAL_SHAPE(int32)
MAKE_PARTIAL_SHAPE(int64)
#undef MAKE_PARTIAL_SHAPE
string PartialTensorShapeUtils::PartialShapeListString(
const gtl::ArraySlice<PartialTensorShape>& shapes) {
string result = "[";
bool first = true;
for (const PartialTensorShape& shape : shapes) {
strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
first = false;
}
strings::StrAppend(&result, "]");
return result;
}
bool PartialTensorShapeUtils::AreCompatible(
const gtl::ArraySlice<PartialTensorShape>& shapes0,
const gtl::ArraySlice<PartialTensorShape>& shapes1) {
if (shapes0.size() == shapes1.size()) {
for (size_t i = 0; i < shapes0.size(); ++i) {
if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
return false;
}
}
return true;
} else {
return false;
}
}
} // namespace tensorflow
......@@ -116,59 +116,25 @@ class PartialTensorShape {
/// \brief Returns a `PartialTensorShape` whose dimensions are
/// `dims[0]`, `dims[1]`, ..., `dims[n-1]`. Values of -1 are
/// considered "unknown".
template <typename T>
static Status MakePartialShape(const T* dims, int n, PartialTensorShape* out);
static Status MakePartialShape(const int32* dims, int n,
PartialTensorShape* out);
static Status MakePartialShape(const int64* dims, int n,
PartialTensorShape* out);
private:
bool is_unknown_;
gtl::InlinedVector<int64, 4> dim_sizes_;
};
template <typename T>
Status PartialTensorShape::MakePartialShape(const T* dims, int n,
PartialTensorShape* out) {
*out = PartialTensorShape();
out->dim_sizes_.reserve(n);
for (int i = 0; i < n; ++i) {
if (dims[i] >= -1) {
out->dim_sizes_.push_back(dims[i]);
} else {
return errors::InvalidArgument("Dimension ", dims[i], " must be >= -1");
}
}
return Status::OK();
}
/// \brief Static helper routines for `PartialTensorShape`. Includes a few
/// common predicates on a partially known tensor shape.
class PartialTensorShapeUtils {
public:
static string PartialShapeListString(
const gtl::ArraySlice<PartialTensorShape>& shapes) {
string result = "[";
bool first = true;
for (const PartialTensorShape& shape : shapes) {
strings::StrAppend(&result, (first ? "" : ", "), shape.DebugString());
first = false;
}
strings::StrAppend(&result, "]");
return result;
}
const gtl::ArraySlice<PartialTensorShape>& shapes);
static bool AreCompatible(
const gtl::ArraySlice<PartialTensorShape>& shapes0,
const gtl::ArraySlice<PartialTensorShape>& shapes1) {
if (shapes0.size() == shapes1.size()) {
for (size_t i = 0; i < shapes0.size(); ++i) {
if (!shapes0[i].IsCompatibleWith(shapes1[i])) {
return false;
}
}
return true;
} else {
return false;
}
}
static bool AreCompatible(const gtl::ArraySlice<PartialTensorShape>& shapes0,
const gtl::ArraySlice<PartialTensorShape>& shapes1);
};
} // namespace tensorflow
......
......@@ -16,6 +16,7 @@ limitations under the License.
#include "tensorflow/core/framework/partial_tensor_shape.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
......@@ -217,5 +218,25 @@ TEST(PartialTensorShapeTest, PartialShapeMergeWith) {
EXPECT_EQ(test.dim_size(2), 1);
}
TEST(PartialTensorShapeTest, MakePartialShapeEmpty) {
// Empty made partial shapes should still be fully defined
const int64 dims[0] = {};
PartialTensorShape shape;
EXPECT_FALSE(shape.IsFullyDefined());
TF_ASSERT_OK(PartialTensorShape::MakePartialShape(dims, 0, &shape));
EXPECT_TRUE(shape.IsFullyDefined());
}
TEST(PartialTensorShapeTest, MakePartialShapeFull) {
// Check that arrays are copied through correctly
const int64 dims[3] = {7, -1, 2};
PartialTensorShape shape;
TF_ASSERT_OK(PartialTensorShape::MakePartialShape(dims, 3, &shape));
ASSERT_EQ(shape.dims(), 3);
for (int i = 0; i < 3; i++) {
EXPECT_EQ(shape.dim_size(i), dims[i]);
}
}
} // namespace
} // namespace tensorflow
......@@ -74,8 +74,8 @@ TensorShape::TensorShape(gtl::ArraySlice<int64> dim_sizes) {
set_ndims_byte(0);
set_data_type(DT_INVALID);
num_elements_ = 1;
for (auto s : dim_sizes) {
AddDim(s);
for (const int64& s : dim_sizes) {
AddDim(internal::SubtleMustCopy(s));
}
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册