Unverified commit 4eefdf9d, authored by Mihai Maruseac and committed by GitHub

Merge pull request #43357 from tensorflow/mm-patch-r2.2

Patch for TF 2.2.1
......@@ -250,21 +250,36 @@ void TFE_CallDLManagedTensorDeleter(void* dlm_ptr) {
}
void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
auto tf_dlm_context = GetDlContext(h, status);
if (!status->status.ok()) {
return nullptr;
}
auto* tf_dlm_data = TFE_TensorHandleDevicePointer(h, status);
if (!status->status.ok()) {
return nullptr;
}
const Tensor* tensor = GetTensorFromHandle(h, status);
TF_DataType data_type = static_cast<TF_DataType>(tensor->dtype());
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto tf_dlm_type = GetDlDataType(data_type, status);
if (!status->status.ok()) {
return nullptr;
}
TensorReference tensor_ref(*tensor); // This will call buf_->Ref()
auto* tf_dlm_tensor_ctx = new TfDlManagedTensorCtx(tensor_ref);
tf_dlm_tensor_ctx->reference = tensor_ref;
DLManagedTensor* dlm_tensor = &tf_dlm_tensor_ctx->tensor;
dlm_tensor->manager_ctx = tf_dlm_tensor_ctx;
dlm_tensor->deleter = &DLManagedTensorDeleter;
dlm_tensor->dl_tensor.ctx = GetDlContext(h, status);
dlm_tensor->dl_tensor.ctx = tf_dlm_context;
int ndim = tensor->dims();
dlm_tensor->dl_tensor.ndim = ndim;
dlm_tensor->dl_tensor.data = TFE_TensorHandleDevicePointer(h, status);
dlm_tensor->dl_tensor.dtype = GetDlDataType(data_type, status);
dlm_tensor->dl_tensor.data = tf_dlm_data;
dlm_tensor->dl_tensor.dtype = tf_dlm_type;
std::vector<int64_t>* shape_arr = &tf_dlm_tensor_ctx->shape;
std::vector<int64_t>* stride_arr = &tf_dlm_tensor_ctx->strides;
......@@ -277,13 +292,14 @@ void* TFE_HandleToDLPack(TFE_TensorHandle* h, TF_Status* status) {
(*stride_arr)[i] = (*shape_arr)[i + 1] * (*stride_arr)[i + 1];
}
dlm_tensor->dl_tensor.shape = &(*shape_arr)[0];
dlm_tensor->dl_tensor.shape = shape_arr->data();
// There are two ways to represent compact row-major data:
// 1) nullptr indicates the tensor is compact and row-major.
// 2) fill in the strides array explicitly, even for compact row-major data.
// Here we choose option 2, since some frameworks do not handle a null
// strides argument properly.
dlm_tensor->dl_tensor.strides = &(*stride_arr)[0];
dlm_tensor->dl_tensor.strides = stride_arr->data();
dlm_tensor->dl_tensor.byte_offset =
0; // TF doesn't handle the strides and byte_offsets here
return static_cast<void*>(dlm_tensor);
......
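For context, the strides filled in above follow the standard compact row-major rule: the innermost dimension has stride 1, and each outer stride is the product of all inner dimensions. A minimal standalone sketch of that computation (not part of the patch; the helper name is hypothetical):

#include <cstdint>
#include <vector>

// Hypothetical helper: element-count strides for a compact row-major
// tensor; e.g. shape {2, 3, 4} yields strides {12, 4, 1}.
std::vector<int64_t> RowMajorStrides(const std::vector<int64_t>& shape) {
  std::vector<int64_t> strides(shape.size(), 1);
  for (int i = static_cast<int>(shape.size()) - 2; i >= 0; --i) {
    strides[i] = shape[i + 1] * strides[i + 1];
  }
  return strides;
}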
......@@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/cc/saved_model/constants.h"
#include "tensorflow/cc/saved_model/reader.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/lib/io/path.h"
......@@ -72,26 +73,41 @@ uint64 GetLatencyMicroseconds(const uint64 start_microseconds) {
// Ensure that constant tensors loaded from the saved model have valid shape.
// Also ensure that constant nodes have a value assigned to them.
// TODO(b/154763635): this is temporary and will be replaced with a better audit
static Status ValidateNode(const NodeDef& node) {
const auto node_iterator = node.attr().find("value");
if (node_iterator != node.attr().end()) {
AttrValue node_value = node_iterator->second;
if (node_value.has_tensor()) {
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
if (node_shape.num_elements() < 0) {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(), "\" (op \"", node.op(),
"\") which initializes from a tensor with ",
node_shape.num_elements(), " elements");
}
}
} else if (node.op() == "Const") {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(),
"\" which is a constant tensor but no value has been provided");
}
return Status::OK();
}
static Status ValidateSavedTensors(const GraphDef& graph_def) {
for (const auto& node : graph_def.node()) {
const auto node_iterator = node.attr().find("value");
if (node_iterator != node.attr().end()) {
AttrValue node_value = node_iterator->second;
if (node_value.has_tensor()) {
const PartialTensorShape node_shape(node_value.tensor().tensor_shape());
if (node_shape.num_elements() < 0) {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(), "\" (op \"",
node.op(), "\") which initializes from a tensor with ",
node_shape.num_elements(), " elements");
}
TF_RETURN_IF_ERROR(ValidateNode(node));
}
if (graph_def.has_library()) {
const FunctionDefLibrary& library = graph_def.library();
for (const auto& function : library.function()) {
for (const auto& node : function.node_def()) {
TF_RETURN_IF_ERROR(ValidateNode(node));
}
} else if (node.op() == "Const") {
return errors::FailedPrecondition(
"Saved model contains node \"", node.name(),
"\" which is a constant tensor but no value has been provided");
}
}
return Status::OK();
}
......
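The refactored validation above hinges on PartialTensorShape::num_elements() returning -1 when any dimension is unknown, so a constant tensor with an incompletely specified shape is rejected. A rough sketch of that semantic, using a plain dims vector in place of the real class:

#include <cstdint>
#include <vector>

// Sketch: a negative element count signals an unknown shape, which is
// what ValidateNode treats as invalid for a constant.
int64_t NumElementsSketch(const std::vector<int64_t>& dims) {
  int64_t n = 1;
  for (int64_t d : dims) {
    if (d < 0) return -1;  // unknown dimension => unknown element count
    n *= d;
  }
  return n;
}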
......@@ -304,7 +304,12 @@ Status KernelAndDeviceOp::Run(
if (outputs != nullptr) {
outputs->clear();
for (int i = 0; i < context.num_outputs(); ++i) {
outputs->push_back(Tensor(*context.mutable_output(i)));
const auto* output_tensor = context.mutable_output(i);
if (output_tensor != nullptr) {
outputs->push_back(Tensor(*output_tensor));
} else {
outputs->push_back(Tensor());
}
}
}
return Status::OK();
......
......@@ -5897,6 +5897,24 @@ tf_kernel_library(
deps = STRING_DEPS,
)
tf_cc_test(
name = "as_string_op_test",
size = "small",
srcs = ["as_string_op_test.cc"],
deps = [
":as_string_op",
":ops_testutil",
":ops_util",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
],
)
tf_kernel_library(
name = "unicode_ops",
prefix = "unicode_ops",
......
......@@ -65,9 +65,26 @@ class AsStringOp : public OpKernel {
OP_REQUIRES(ctx, !(scientific && shortest),
errors::InvalidArgument(
"Cannot select both scientific and shortest notation"));
format_ = "%";
if (!fill_string.empty()) {
switch (fill_string[0]) {
case ' ':
case '+':
case '-':
case '0':
case '#':
strings::Appendf(&format_, "%s", fill_string.c_str());
break;
default:
bool fill_not_supported = true;
OP_REQUIRES(ctx, !fill_not_supported,
errors::InvalidArgument("Fill argument not supported: \"",
fill_string, "\""));
}
}
if (width > -1) {
strings::Appendf(&format_, "%s%d", fill_string.c_str(), width);
strings::Appendf(&format_, "%d", width);
}
if (precision > -1) {
strings::Appendf(&format_, ".%d", precision);
......
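The constructor above assembles a printf-style format string from the op's attributes; for example, fill="0" with width=4 produces "%04lld" for int64 inputs, which the FillWithZero test below exercises. A standalone sketch of the same assembly, with hypothetical names:

#include <string>

// Hypothetical sketch of the format assembly performed above, e.g.
// BuildFormat("0", 4, -1, "lld") returns "%04lld".
std::string BuildFormat(const std::string& fill, int width, int precision,
                        const std::string& conversion) {
  std::string format = "%";
  if (!fill.empty()) format += fill;  // validated to be one of " +-0#"
  if (width > -1) format += std::to_string(width);
  if (precision > -1) format += "." + std::to_string(precision);
  return format + conversion;
}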
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
namespace tensorflow {
namespace {
class AsStringGraphTest : public OpsTestBase {
protected:
Status Init(DataType input_type, const string& fill = "", int width = -1,
int precision = -1, bool scientific = false,
bool shortest = false) {
TF_CHECK_OK(NodeDefBuilder("op", "AsString")
.Input(FakeInput(input_type))
.Attr("fill", fill)
.Attr("precision", precision)
.Attr("scientific", scientific)
.Attr("shortest", shortest)
.Attr("width", width)
.Finalize(node_def()));
return InitOp();
}
};
TEST_F(AsStringGraphTest, Int8) {
TF_ASSERT_OK(Init(DT_INT8));
AddInputFromArray<int8>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42", "0", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Int64) {
TF_ASSERT_OK(Init(DT_INT64));
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42", "0", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatDefault) {
TF_ASSERT_OK(Init(DT_FLOAT));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(
&expected, {"-42.000000", "0.000000", "3.141590", "42.000000"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatScientific) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-4.200000e+01", "0.000000e+00",
"3.141590e+00", "4.200000e+01"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatShortest) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/false, /*shortest=*/true));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42", "0", "3.14159", "42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatPrecisionOnly) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/2));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42.00", "0.00", "3.14", "42.00"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FloatWidthOnly) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(
&expected, {"-42.000000", "0.000000", "3.141590", "42.000000"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Float_5_2_Format) {
TF_ASSERT_OK(Init(DT_FLOAT, /*fill=*/"", /*width=*/5, /*precision=*/2));
AddInputFromArray<float>(TensorShape({4}), {-42, 0, 3.14159, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({4}));
test::FillValues<tstring>(&expected, {"-42.00", " 0.00", " 3.14", "42.00"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Complex) {
TF_ASSERT_OK(Init(DT_COMPLEX64, /*fill=*/"", /*width=*/5, /*precision=*/2));
AddInputFromArray<complex64>(TensorShape({3}), {{-4, 2}, {0}, {3.14159, -1}});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(
&expected, {"(-4.00, 2.00)", "( 0.00, 0.00)", "( 3.14,-1.00)"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, Bool) {
TF_ASSERT_OK(Init(DT_BOOL));
AddInputFromArray<bool>(TensorShape({2}), {true, false});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({2}));
test::FillValues<tstring>(&expected, {"true", "false"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, String) {
Status s = Init(DT_STRING);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"Value for attr 'T' of string is not in the list of allowed values"));
}
TEST_F(AsStringGraphTest, OnlyOneOfScientificAndShortest) {
Status s = Init(DT_FLOAT, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true, /*shortest=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(),
"Cannot select both scientific and shortest notation"));
}
TEST_F(AsStringGraphTest, NoShortestForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/false, /*shortest=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"scientific and shortest format not supported for datatype"));
}
TEST_F(AsStringGraphTest, NoScientificForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/-1,
/*scientific=*/true);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(
s.error_message(),
"scientific and shortest format not supported for datatype"));
}
TEST_F(AsStringGraphTest, NoPrecisionForNonFloat) {
Status s = Init(DT_INT32, /*fill=*/"", /*width=*/-1, /*precision=*/5);
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(s.error_message(),
"precision not supported for datatype"));
}
TEST_F(AsStringGraphTest, LongFill) {
Status s = Init(DT_INT32, /*fill=*/"asdf");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(absl::StrContains(s.error_message(),
"Fill string must be one or fewer characters"));
}
TEST_F(AsStringGraphTest, FillWithZero) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"0", /*width=*/4));
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-042", "0000", "0042"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FillWithSpace) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/" ", /*width=*/4));
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {" -42", " 0", " 42"});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FillWithChar1) {
TF_ASSERT_OK(Init(DT_INT64, /*fill=*/"-", /*width=*/4));
AddInputFromArray<int64>(TensorShape({3}), {-42, 0, 42});
TF_ASSERT_OK(RunOpKernel());
Tensor expected(allocator(), DT_STRING, TensorShape({3}));
test::FillValues<tstring>(&expected, {"-42 ", "0 ", "42 "});
test::ExpectTensorEqual<tstring>(expected, *GetOutput(0));
}
TEST_F(AsStringGraphTest, FillWithChar3) {
Status s = Init(DT_INT32, /*fill=*/"s");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(), "Fill argument not supported"));
}
TEST_F(AsStringGraphTest, FillWithChar4) {
Status s = Init(DT_INT32, /*fill=*/"n");
ASSERT_EQ(error::INVALID_ARGUMENT, s.code());
ASSERT_TRUE(
absl::StrContains(s.error_message(), "Fill argument not supported"));
}
} // end namespace
} // end namespace tensorflow
......@@ -121,7 +121,7 @@ class BoostedTreesTrainingPredictOp : public OpKernel {
auto do_work = [&resource, &bucketized_features, &cached_tree_ids,
&cached_node_ids, &output_partial_logits,
&output_node_ids, latest_tree,
this](int32 start, int32 end) {
this](int64 start, int64 end) {
for (int32 i = start; i < end; ++i) {
int32 tree_id = cached_tree_ids(i);
int32 node_id = cached_node_ids(i);
......@@ -237,7 +237,7 @@ class BoostedTreesPredictOp : public OpKernel {
const int32 last_tree = resource->num_trees() - 1;
auto do_work = [&resource, &bucketized_features, &output_logits, last_tree,
this](int32 start, int32 end) {
this](int64 start, int64 end) {
for (int32 i = start; i < end; ++i) {
std::vector<float> tree_logits(logits_dimension_, 0.0);
int32 tree_id = 0;
......@@ -340,7 +340,7 @@ class BoostedTreesExampleDebugOutputsOp : public OpKernel {
// path. Note: feature_ids has one less value than logits_path because the
// first value of each logit path will be the bias.
auto do_work = [&resource, &bucketized_features, &output_debug_info,
last_tree](int32 start, int32 end) {
last_tree](int64 start, int64 end) {
for (int32 i = start; i < end; ++i) {
// Proto to store debug outputs, per example.
boosted_trees::DebugOutput example_debug_info;
......
......@@ -95,7 +95,8 @@ struct NthElementFunctor<CPUDevice, T> {
const int last_dim = input_tensor.dim_size(input_tensor.dims() - 1);
// Allocate each row to different shard.
auto SubNthElement = [&, input, output, last_dim, n](int start, int limit) {
auto SubNthElement = [&, input, output, last_dim, n](int64 start,
int64 limit) {
// std::nth_element would rearrange the array, so we need a new buffer.
std::vector<T> buf(last_dim);
......
......@@ -69,8 +69,8 @@ struct TruncatedNormalFunctor<CPUDevice, T> {
auto DoWork = [samples_per_batch, num_elements, &ctx, &means, &stddevs,
&minvals, &maxvals, &gen, &output,
kStdDevsInsideBoundsToUseRandnSampler](int start_batch,
int limit_batch) {
kStdDevsInsideBoundsToUseRandnSampler](int64 start_batch,
int64 limit_batch) {
// Capturing "gen" by-value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "gen" by reference and explicitly do a copy assignment here.
......
......@@ -182,7 +182,7 @@ struct RandomBinomialFunctor<CPUDevice, T, U> {
// the sample shape and [H1, ... Hm] for the batch shape of the samples.
// We have B1 * ... * Bk samples per batch member we need.
auto DoWork = [num_batches, samples_per_batch, &bcast, &counts, &probs,
&gen, &output](int start_output, int limit_output) {
&gen, &output](int64 start_output, int64 limit_output) {
// Vectorized intermediate calculations for uniform rejection sampling.
// We always generate at most 4 samples.
Eigen::array<T, 4> z;
......
......@@ -205,7 +205,7 @@ class RandomGammaOp : public OpKernel {
// avoid a couple flops which can be done on a per-alpha basis.
auto DoWork = [samples_per_alpha, num_alphas, &rng, samples_flat,
alpha_flat](int start_output, int limit_output) {
alpha_flat](int64 start_output, int64 limit_output) {
using Eigen::numext::exp;
using Eigen::numext::log;
using Eigen::numext::pow;
......
......@@ -97,7 +97,7 @@ struct PoissonFunctor<CPUDevice, T, U> {
typedef random::UniformDistribution<random::PhiloxRandom, CT> Uniform;
auto DoWork = [num_samples, num_rate, &rng, samples_flat, rate_flat](
int start_output, int limit_output) {
int64 start_output, int64 limit_output) {
// Capturing "rng" by value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "rng" by reference and explicitly do a copy assignment.
......
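The lambda signature changes in the hunks above all address the same issue: TensorFlow's work-sharding utilities hand each worker a 64-bit [start, limit) range, and accepting those bounds as int32 silently truncates them on very large workloads. A minimal sketch of the pattern, with a hypothetical ShardedFor helper standing in for the real sharding utility:

#include <cstdint>
#include <functional>

// Hypothetical stand-in for TF's sharding utility: ranges are 64-bit.
void ShardedFor(int64_t total,
                const std::function<void(int64_t, int64_t)>& work) {
  work(0, total);  // a real implementation splits [0, total) across workers
}

void Example() {
  // Taking the bounds as int64 (as the patched lambdas do) avoids
  // truncation when total exceeds INT32_MAX.
  ShardedFor(int64_t{5000000000}, [](int64_t start, int64_t limit) {
    for (int64_t i = start; i < limit; ++i) {
      // per-element work
    }
  });
}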
......@@ -16,6 +16,7 @@ limitations under the License.
// See docs in ../ops/data_flow_ops.cc.
#include <limits.h>
#include <vector>
#include "tensorflow/core/common_runtime/device.h"
......@@ -27,6 +28,7 @@ limitations under the License.
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
......@@ -42,7 +44,11 @@ class GetSessionHandleOp : public OpKernel {
void Compute(OpKernelContext* ctx) override {
const Tensor& val = ctx->input(0);
int64 id = ctx->session_state()->GetNewId();
auto session_state = ctx->session_state();
OP_REQUIRES(ctx, session_state != nullptr,
errors::FailedPrecondition(
"GetSessionHandle called on null session state"));
int64 id = session_state->GetNewId();
TensorStore::TensorAndKey tk{val, id, requested_device()};
OP_REQUIRES_OK(ctx, ctx->tensor_store()->AddTensor(name(), tk));
......
......@@ -232,6 +232,9 @@ class SparseFillEmptyRowsGradOp : public OpKernel {
context, TensorShapeUtils::IsVector(reverse_index_map_t->shape()),
errors::InvalidArgument("reverse_index_map must be a vector, saw: ",
reverse_index_map_t->shape().DebugString()));
OP_REQUIRES(context, TensorShapeUtils::IsVector(grad_values_t->shape()),
errors::InvalidArgument("grad_values must be a vector, saw: ",
grad_values_t->shape().DebugString()));
const auto reverse_index_map = reverse_index_map_t->vec<int64>();
const auto grad_values = grad_values_t->vec<T>();
......@@ -260,8 +263,13 @@ class SparseFillEmptyRowsGradOp : public OpKernel {
// Locate the index of the output of the forward prop associated
// with this location in the input of the forward prop. Copy
// the gradient into it. Mark it as visited.
d_values(i) = grad_values(reverse_index_map(i));
visited(reverse_index_map(i)) = true;
int64 reverse_index = reverse_index_map(i);
OP_REQUIRES(
context, 0 <= reverse_index && reverse_index < N_full,
errors::InvalidArgument("Elements in reverse index must be in [0, ",
N_full, ") but got ", reverse_index));
d_values(i) = grad_values(reverse_index);
visited(reverse_index) = true;
}
for (int j = 0; j < N_full; ++j) {
// The default value gradient gets the accumulated remainder of
......
......@@ -252,7 +252,7 @@ class StatelessRandomGammaOp : public StatelessRandomOpBase {
// avoid a couple flops which can be done on a per-alpha basis.
auto DoWork = [samples_per_alpha, num_alphas, &random, samples_flat,
alpha_flat](int start_output, int limit_output) {
alpha_flat](int64 start_output, int64 limit_output) {
// Capturing "random" by-value would only make a copy for the _shared_
// lambda. Since we want to let each worker have its own copy, we pass
// "random" by reference and explicitly do a copy assignment.
......
......@@ -19,6 +19,7 @@ limitations under the License.
#include "absl/strings/ascii.h"
#include "absl/strings/str_cat.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace text {
......@@ -60,6 +61,18 @@ class StringNGramsOp : public tensorflow::OpKernel {
OP_REQUIRES_OK(context, context->input("data_splits", &splits));
const auto& splits_vec = splits->flat<SPLITS_TYPE>();
// Validate that the splits are valid indices into data
const int input_data_size = data->flat<tstring>().size();
const int splits_vec_size = splits_vec.size();
for (int i = 0; i < splits_vec_size; ++i) {
bool valid_splits = splits_vec(i) >= 0;
valid_splits = valid_splits && (splits_vec(i) <= input_data_size);
OP_REQUIRES(
context, valid_splits,
errors::InvalidArgument("Invalid split value ", splits_vec(i),
", must be in [0,", input_data_size, "]"));
}
int num_batch_items = splits_vec.size() - 1;
tensorflow::Tensor* ngrams_splits;
OP_REQUIRES_OK(
......
......@@ -136,7 +136,7 @@ struct TopKFunctor<CPUDevice, T> {
return Status::OK();
}
auto SortIndices = [&](int start_batch, int limit_batch) {
auto SortIndices = [&](int64 start_batch, int64 limit_batch) {
for (int32 b = start_batch; b < limit_batch; ++b) {
const T* input_data = &input(b, 0);
const auto stable_comp = [input_data](const int32 a, const int32 b) {
......
......@@ -18,6 +18,7 @@ limitations under the License.
#include <algorithm>
#include "tensorflow/lite/arena_planner.h"
#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/core/api/tensor_utils.h"
......@@ -560,6 +561,33 @@ TfLiteStatus Subgraph::CheckTensorIndices(const char* label, const int* indices,
return kTfLiteOk;
}
// We have two arrays and we need to check that elements from one array don't
// show up in the other. We could sort both arrays and then iterate with two
// pointers from start to finish, always advancing the smaller one, but since
// these arrays are usually short (<25 elements for inputs, usually <3 for
// outputs), this may be slower than the naive approach: if the arrays have
// sizes n and m, with n >> m ~ O(1), the sorting approach is O(n log n)
// whereas the naive one is O(n). Also, sorting the input and output arrays
// may not be desirable, since it destroys the ordering of elements.
//
// If it turns out that this is an issue, we can switch to the other algorithm.
TfLiteStatus Subgraph::CheckInputAndOutputForOverlap(const int* input_indices,
int num_inputs,
const int* output_indices,
int num_outputs) {
for (int i = 0; i < num_inputs; i++) {
for (int j = 0; j < num_outputs; j++) {
if (input_indices[i] == output_indices[j]) {
ReportError("Tensor %d is both input %d and output %d\n",
input_indices[i], i, j);
consistent_ = false;
return kTfLiteError;
}
}
}
return kTfLiteOk;
}
namespace {
// Multiply two sizes and return true if overflow occurred.
// This is based off tensorflow/overflow.h but is simpler as we already
......@@ -681,6 +709,16 @@ TfLiteStatus Subgraph::AddNodeWithParameters(
&context_,
CheckTensorIndices("node outputs", outputs.data(), outputs.size()));
// For builtin ops, inputs and outputs must not overlap. Custom ops must do
// this check by themselves if they don't support overlapping tensors. This
// distinction is to allow custom ops to just forward a tensor, reusing it as
// both input and output.
if (builtin_data != nullptr) {
TF_LITE_ENSURE_OK(&context_, CheckInputAndOutputForOverlap(
inputs.data(), inputs.size(),
outputs.data(), outputs.size()));
}
int new_node_index = nodes_and_registration_.size();
if (node_index) *node_index = new_node_index;
nodes_and_registration_.resize(nodes_and_registration_.size() + 1);
......@@ -897,6 +935,19 @@ TfLiteStatus Subgraph::Invoke() {
tensor->data_is_stale) {
TF_LITE_ENSURE_STATUS(EnsureTensorDataIsReadable(tensor_index));
}
if (tensor->data.raw == nullptr && tensor->bytes > 0) {
if (registration.builtin_code == kTfLiteBuiltinReshape && i == 1) {
// In general, having a tensor here with no buffer will be an error.
// However, for the reshape operator, the second input tensor is only
// used for the shape, not for the data. Thus, a null buffer is OK.
continue;
} else {
// In all other cases, we need to return an error as otherwise we will
// trigger a null pointer dereference (likely).
ReportError("Input tensor %d lacks data", tensor_index);
return kTfLiteError;
}
}
}
if (check_cancelled_func_ != nullptr &&
......
......@@ -415,6 +415,15 @@ class Subgraph {
TfLiteStatus CheckTensorIndices(const char* label, const int* indices,
int length);
// Check that the input indices and the output indices don't overlap.
// This is needed because the same tensor must not be used both as an input
// and as an output of an operator.
// NOTE: this sets consistent_ to false if an input index also appears among
// the output indices.
TfLiteStatus CheckInputAndOutputForOverlap(const int* input_indices,
int num_inputs,
const int* output_indices,
int num_outputs);
// Compute the number of bytes required to represent a tensor with dimensions
// specified by the array dims (of length dims_size). Returns the status code
// and bytes.
......
......@@ -67,6 +67,9 @@ inline bool ResolveAxis(const int num_dims, const int* axis,
// eg: For num_dims=3, [0, 1, 2] is the same as [-3, -2, -1] */
int current = axis[idx] < 0 ? (axis[idx] + num_dims) : axis[idx];
TFLITE_DCHECK(current >= 0 && current < num_dims);
if (current < 0 || current >= num_dims) {
return false;
}
bool is_dup = false;
for (int j = 0; j < *out_num_axis; ++j) {
if (out_axis[j] == current) {
......
......@@ -432,7 +432,7 @@ int MatchingArraySize(const ArrayType1& array1, int index1,
inline int MatchingDim(const RuntimeShape& shape1, int index1,
const RuntimeShape& shape2, int index2) {
TFLITE_DCHECK_EQ(shape1.Dims(index1), shape2.Dims(index2));
return shape1.Dims(index1);
return std::min(shape1.Dims(index1), shape2.Dims(index2));
}
template <typename... Args>
......
......@@ -30,31 +30,48 @@ inline int SizeOfDimension(const TfLiteTensor* t, int dim) {
}
inline const TfLiteTensor* GetInput(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context
->tensors[flatbuffers::EndianScalar(node->inputs->data[index])];
const int tensor_index = flatbuffers::EndianScalar(node->inputs->data[index]);
if (tensor_index < 0) {
return nullptr;
}
return &context->tensors[tensor_index];
}
// Note: You must check if result is not null:
// TfLiteTensor* my_tensor = GetVariableInput(context, node, kMyTensorIdx);
// TF_LITE_ENSURE(context, my_tensor != nullptr);
inline TfLiteTensor* GetVariableInput(TfLiteContext* context,
const TfLiteNode* node, int index) {
TfLiteTensor* tensor =
&context->tensors[flatbuffers::EndianScalar(node->inputs->data[index])];
const int tensor_index = flatbuffers::EndianScalar(node->inputs->data[index]);
if (tensor_index < 0) {
return nullptr;
}
TfLiteTensor* tensor = &context->tensors[tensor_index];
return (tensor->is_variable) ? tensor : nullptr;
}
inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index) {
return &context
->tensors[flatbuffers::EndianScalar(node->outputs->data[index])];
const int tensor_index = flatbuffers::EndianScalar(node->outputs->data[index]);
if (tensor_index < 0) {
return nullptr;
}
return &context->tensors[tensor_index];
}
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[flatbuffers::EndianScalar(
node->temporaries->data[index])];
const int tensor_index = flatbuffers::EndianScalar(node->temporaries->data[index]);
if (tensor_index < 0) {
return nullptr;
}
return &context->tensors[tensor_index];
}
inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->intermediates->data[index]];
const int tensor_index = flatbuffers::EndianScalar(node->intermediates->data[index]);
if (tensor_index < 0) {
return nullptr;
}
return &context->tensors[tensor_index];
}
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
......@@ -77,13 +94,7 @@ inline int64_t NumElements(const TfLiteTensor* t) {
inline const TfLiteTensor* GetOptionalInputTensor(TfLiteContext* context,
const TfLiteNode* node,
int index) {
const bool use_tensor = index < node->inputs->size &&
node->inputs->data[index] != kTfLiteOptionalTensor;
if (use_tensor) {
return &context
->tensors[flatbuffers::EndianScalar(node->inputs->data[index])];
}
return nullptr;
return GetInput(context, node, index);
}
// Determines whether tensor is constant.
......
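With the accessors above now returning nullptr for a negative tensor index, callers are expected to check the result before dereferencing, as the GetVariableInput note already suggests. A hypothetical kernel Prepare function under this contract:

// Hypothetical caller: each accessor result is null-checked before use.
TfLiteStatus PrepareSketch(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input = GetInput(context, node, 0);
  TF_LITE_ENSURE(context, input != nullptr);
  TfLiteTensor* output = GetOutput(context, node, 0);
  TF_LITE_ENSURE(context, output != nullptr);
  TF_LITE_ENSURE_EQ(context, NumDimensions(input), NumDimensions(output));
  return kTfLiteOk;
}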
......@@ -32,11 +32,24 @@ TfLiteStatus ResizeOutputTensor(TfLiteContext* context,
const TfLiteTensor* data,
const TfLiteTensor* segment_ids,
TfLiteTensor* output) {
int max_index = -1;
// Segment ids should have the same cardinality as the first input dimension
// and should increase by at most 1, starting from 0 (e.g., [0, 0, 1, 2, 3]
// is valid).
const int segment_id_size = segment_ids->dims->data[0];
if (segment_id_size > 0) {
max_index = segment_ids->data.i32[segment_id_size - 1];
TF_LITE_ENSURE_EQ(context, segment_id_size, data->dims->data[0]);
int previous_segment_id = -1;
for (int i = 0; i < segment_id_size; i++) {
const int current_segment_id = GetTensorData<int32_t>(segment_ids)[i];
if (i == 0) {
TF_LITE_ENSURE_EQ(context, current_segment_id, 0);
} else {
int delta = current_segment_id - previous_segment_id;
TF_LITE_ENSURE(context, delta == 0 || delta == 1);
}
previous_segment_id = current_segment_id;
}
const int max_index = previous_segment_id;
const int data_rank = NumDimensions(data);
TfLiteIntArray* output_shape = TfLiteIntArrayCreate(NumDimensions(data));
output_shape->data[0] = max_index + 1;
......
......@@ -108,5 +108,37 @@ TEST(SegmentSumOpModelTest, Float32Test_ThreeDimensions) {
EXPECT_THAT(model.GetOutputShape(), ElementsAreArray({2, 2, 1}));
}
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotSorted) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 3, 1});
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
}
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotConsecutive) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 3, 5});
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
}
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNegative) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
{TensorType_INT32, {3}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {-1, 0, 1});
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
}
TEST(SegmentSumOpModelTest, TestFailIfSegmentsAreNotTheRightCardinality) {
SegmentSumOpModel<int32_t> model({TensorType_INT32, {3, 2}},
{TensorType_INT32, {2}});
model.PopulateTensor<int32_t>(model.data(), {1, 2, 3, 4, 5, 6});
model.PopulateTensor<int32_t>(model.segment_ids(), {0, 1});
ASSERT_EQ(model.InvokeUnchecked(), kTfLiteError);
}
} // namespace
} // namespace tflite
......@@ -728,6 +728,11 @@ TfLiteStatus InterpreterBuilder::operator()(
return cleanup_and_error();
}
if (!buffers) {
error_reporter_->Report("No buffers in the model.\n");
return cleanup_and_error();
}
interpreter->reset(new Interpreter(error_reporter_));
(*interpreter)->SetNumThreads(num_threads);
if (subgraphs->Length() > 1) {
......@@ -745,9 +750,9 @@ TfLiteStatus InterpreterBuilder::operator()(
(*interpreter)->subgraph(subgraph_index);
auto operators = subgraph->operators();
auto tensors = subgraph->tensors();
if (!operators || !tensors || !buffers) {
if (!operators || !tensors) {
error_reporter_->Report(
"Did not get operators, tensors, or buffers in subgraph %d.\n",
"Did not get operators or tensors in subgraph %d.\n",
subgraph_index);
return cleanup_and_error();
}
......
......@@ -20,9 +20,11 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.dlpack import dlpack
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.platform import test
......@@ -95,6 +97,12 @@ class DLPackTest(parameterized.TestCase, test.TestCase):
self.assertRaisesRegex(Exception, ".* is not supported by dlpack",
UnsupportedComplex64)
def testMustPassTensorArgumentToDLPack(self):
with self.assertRaisesRegex(
errors.InvalidArgumentError,
"The argument to `to_dlpack` must be a TF tensor, not Python object"):
dlpack.to_dlpack([1])
if __name__ == "__main__":
ops.enable_eager_execution()
......
......@@ -4554,6 +4554,14 @@ class ControlFlowTest(test.TestCase, parameterized.TestCase):
result = control_flow_ops.merge([v_f, v_t])
self.evaluate(result)
def testSwitchEagerMode(self):
if not context.executing_eagerly():
return
input_data = [1, 2, 3, 4]
vf, vt = control_flow_ops.switch(input_data, False)
self.assertAllEqual(vf, input_data)
self.assertAllEqual(vt, [])
@test_util.run_deprecated_v1
def testQIntArgAndRet(self):
......
......@@ -18,16 +18,22 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from absl.testing import parameterized
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gen_string_ops
from tensorflow.python.platform import test
@test_util.run_all_in_graph_and_eager_modes
class RawOpsTest(test.TestCase):
@test_util.disable_tfrt
class RawOpsTest(test.TestCase, parameterized.TestCase):
def testSimple(self):
x = constant_op.constant(1)
......@@ -58,6 +64,29 @@ class RawOpsTest(test.TestCase):
gen_math_ops.Any(input=x, axis=0),
gen_math_ops.Any(input=x, axis=0, keep_dims=False))
@parameterized.parameters([[0, 8]], [[-1, 6]])
def testStringNGramsBadDataSplits(self, splits):
data = ["aa", "bb", "cc", "dd", "ee", "ff"]
with self.assertRaisesRegex(errors.InvalidArgumentError,
"Invalid split value"):
self.evaluate(
gen_string_ops.string_n_grams(
data=data,
data_splits=splits,
separator="",
ngram_widths=[2],
left_pad="",
right_pad="",
pad_width=0,
preserve_short_sequences=False))
def testGetSessionHandle(self):
if context.executing_eagerly():
with self.assertRaisesRegex(
errors.FailedPreconditionError,
"GetSessionHandle called on null session state"):
gen_data_flow_ops.GetSessionHandle(value=[1])
if __name__ == "__main__":
ops.enable_eager_execution()
......
......@@ -21,6 +21,7 @@ from __future__ import print_function
from absl.testing import parameterized
import numpy as np
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
......@@ -29,6 +30,7 @@ from tensorflow.python.framework import test_util
# Need array_grad to register gradient for Identity.
from tensorflow.python.ops import array_grad # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_sparse_ops
from tensorflow.python.ops import gradient_checker_v2 as gradient_checker
from tensorflow.python.ops import math_ops
# Need sparse_grad to register gradient for SparseToDense.
......@@ -181,5 +183,57 @@ class SparseOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
self.assertAllEqual(expected, result)
@test_util.run_all_in_graph_and_eager_modes
class RawOpsTest(test_util.TensorFlowTestCase, parameterized.TestCase):
def testSparseFillEmptyRowsGrad(self):
reverse_index_map = [2, 1]
grad_values = [0, 1, 2, 3]
d_values, d_default_value = self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
self.assertAllEqual([2, 1], d_values)
self.assertEqual(3, d_default_value)
def testSparseFillEmptyRowsGradNegativeIndexMapValue(self):
reverse_index_map = [2, -1]
grad_values = [0, 1, 2, 3]
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r'Elements in reverse index must be in \[0, 4\)'):
self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
def testSparseFillEmptyRowsGradLargeIndexMapValue(self):
reverse_index_map = [2, 10]
grad_values = [0, 1, 2, 3]
with self.assertRaisesRegex(
errors.InvalidArgumentError,
r'Elements in reverse index must be in \[0, 4\)'):
self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
def testSparseFillEmptyRowsGradMatrix(self):
reverse_index_map = [0, 1]
grad_values = [[0, 1], [2, 3]]
# Note: Eager mode and graph mode throw different errors here. Graph mode
# will fail with a ValueError from the shape checking logic, while Eager
# will fail with an InvalidArgumentError from the kernel itself.
if context.executing_eagerly():
with self.assertRaisesRegex(errors.InvalidArgumentError,
r'grad_values must be a vector'):
self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
else:
with self.assertRaisesRegex(ValueError,
r'Shape must be rank 1 but is rank 2'):
self.evaluate(
gen_sparse_ops.SparseFillEmptyRowsGrad(
reverse_index_map=reverse_index_map, grad_values=grad_values))
if __name__ == '__main__':
googletest.main()
......@@ -1051,9 +1051,16 @@ PYBIND11_MODULE(_pywrap_tfe, m) {
// DLPack functions
m.def("TFE_ToDlpackCapsule", [](py::handle& o) {
PyObject* eager_tensor_pyobject_ptr = o.ptr();
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
tensorflow::Safe_TF_StatusPtr status =
tensorflow::make_safe(TF_NewStatus());
if (!EagerTensor_CheckExact(eager_tensor_pyobject_ptr)) {
status->status = tensorflow::errors::InvalidArgument(
"The argument to `to_dlpack` must be a TF tensor, not Python object");
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
}
TFE_TensorHandle* thandle = EagerTensor_Handle(eager_tensor_pyobject_ptr);
void* dlm_ptr = tensorflow::TFE_HandleToDLPack(thandle, status.get());
tensorflow::MaybeRaiseRegisteredFromTFStatus(status.get());
......