提交 e7d4fd62 编写于 作者: V Vijay Vasudevan 提交者: TensorFlower Gardener

Remove some uses of RE2 library that can be replaced with simpler

logic.
Change: 133908733
上级 454323fd
......@@ -26,7 +26,6 @@ limitations under the License.
#include "tensorflow/core/lib/core/notification.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
......@@ -46,14 +45,11 @@ class RemoteDeviceTest : public ::testing::Test {
(*options.config.mutable_device_count())["CPU"] = 2;
TF_CHECK_OK(test::TestCluster::MakeTestCluster(options, 1, &cluster_));
const string& hostport = cluster_->targets()[0];
string host;
int port;
CHECK(RE2::FullMatch(hostport, "(.+):(\\d+)", &host, &port));
GrpcChannelSpec spec;
spec.AddHostPortsJob("localhost", {hostport});
worker_cache_.reset(
NewGrpcWorkerCache(NewGrpcChannelCache(spec, NewHostPortGrpcChannel)));
remote_name_ = strings::StrCat("/job:", host, "/replica:0/task:0");
remote_name_ = "/job:localhost/replica:0/task:0";
wi_.reset(worker_cache_->CreateWorker(remote_name_));
}
......
......@@ -91,6 +91,7 @@ cc_library(
hdrs = ["grpc_channel.h"],
deps = [
":grpc_util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"@grpc//:grpc++_unsecure",
......
......@@ -24,14 +24,15 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/thread_annotations.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/device_name_utils.h"
namespace tensorflow {
......@@ -53,10 +54,11 @@ SharedGrpcChannelPtr NewHostPortGrpcChannel(const string& target) {
namespace {
Status ValidateHostPortPair(const string& host_port) {
const static RE2* kHostPortRE = new RE2("([^:/]+):(\\d+)");
string host;
int port;
if (!RE2::FullMatch(host_port, *kHostPortRE, &host, &port)) {
uint32 port;
std::vector<string> parts = str_util::Split(host_port, ':');
// Must be host:port, port must be a number, host must not contain a '/'.
if (parts.size() != 2 || !strings::safe_strtou32(parts[1], &port) ||
parts[0].find("/") != string::npos) {
return errors::InvalidArgument("Could not interpret \"", host_port,
"\" as a host-port pair.");
}
......@@ -204,23 +206,20 @@ class SparseGrpcChannelCache : public CachingGrpcChannelCache {
}
string TranslateTask(const string& target) override {
const static RE2* kTargetRE =
new RE2("^/job:([^/]+)/replica:([0-9]+)/task:([0-9]+)$");
RegexpStringPiece job;
int32 replica;
int32 task;
if (!RE2::FullMatch(target, *kTargetRE, &job, &replica, &task)) {
DeviceNameUtils::ParsedName parsed;
if (!DeviceNameUtils::ParseFullName(target, &parsed)) {
LOG(WARNING) << "Invalid target: " << target;
return "";
}
if (job != job_id_) {
if (!parsed.has_job || parsed.job != job_id_) {
return "";
}
if (replica != 0) {
if (!parsed.has_replica || parsed.replica != 0) {
LOG(WARNING) << "Replica ID must be 0 in target: " << target;
return "";
}
int32 task = parsed.has_task ? parsed.task : -1;
auto iter = host_ports_.find(task);
if (iter == host_ports_.end()) {
LOG(WARNING) << "Task " << task << " was not defined in sparse job "
......
......@@ -24,7 +24,6 @@ limitations under the License.
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/version.h"
......@@ -44,13 +43,17 @@ class GraphConstructorTest : public ::testing::Test {
CHECK(protobuf::TextFormat::ParseFromString(gdef_ascii, &gdef_));
}
void ExpectError(const string& gdef_ascii, const string& expected_error_re) {
void ExpectError(const string& gdef_ascii,
const std::vector<string>& expected_error_strs) {
Convert(gdef_ascii);
GraphConstructorOptions opts;
Status status = ConvertGraphDefToGraph(opts, gdef_, g_.get());
EXPECT_FALSE(status.ok());
EXPECT_TRUE(RE2::PartialMatch(status.error_message(), expected_error_re))
<< status;
for (const string& error : expected_error_strs) {
EXPECT_TRUE(status.error_message().find(error) != string::npos)
<< "Expected to find '" << error << "' in " << status;
}
}
void ExpectOK(const string& gdef_ascii) {
......@@ -129,14 +132,12 @@ REGISTER_OP("TestInt").Input("a: int32");
TEST_F(GraphConstructorTest, InvalidNodeName) {
auto expect_invalid_name = [this](const char* name) {
ExpectError(strings::StrCat("node { name: '", name, "' op: 'ABC' }"),
strings::StrCat("Node '", name,
"': Node name contains invalid characters"));
{"Node name contains invalid characters"});
};
expect_invalid_name("a:b");
expect_invalid_name("_abc"); // Can't start with '_'
// Name is a\b, but proto text format escapes slashes so we use a\\b here.
// This works for ExpectError too, since re2 also treats \\ as one slash.
expect_invalid_name(R"(a\\b)");
expect_invalid_name("/a");
expect_invalid_name("-a");
......@@ -153,7 +154,7 @@ TEST_F(GraphConstructorTest, InvalidSourceNodeName) {
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: 'W999' input: 'input' }",
"Unknown input node.*W999");
{"Unknown input node", "W999"});
}
TEST_F(GraphConstructorTest, InvalidSourceNodeIndex) {
......@@ -162,7 +163,7 @@ TEST_F(GraphConstructorTest, InvalidSourceNodeIndex) {
"node { name: 'input' op: 'TestInput' }"
"node { name: 't1' op: 'TestMul' input: [ 'W1:1', 'input:1' ] }",
"Connecting to invalid output 1 of source node W1");
{"Connecting to invalid output 1 of source node W1"});
}
TEST_F(GraphConstructorTest, GraphWithCycle) {
......@@ -171,7 +172,7 @@ TEST_F(GraphConstructorTest, GraphWithCycle) {
"node { name: 't1' op: 'TestMul' input: [ 'input:0', 't2' ] }"
"node { name: 't2' op: 'TestMul' input: [ 'input:1', 't1' ] }",
"cycle");
{"cycle"});
}
TEST_F(GraphConstructorTest, TypeMismatch) {
......@@ -179,8 +180,8 @@ TEST_F(GraphConstructorTest, TypeMismatch) {
"node { name: 'input' op: 'TestInput' }"
"node { name: 'int' op: 'TestInt' input: [ 'input' ] }",
"Input 0 of node int was passed float from input:0 incompatible with "
"expected int32.");
{"Input 0 of node int was passed float from input:0 incompatible with "
"expected int32."});
}
TEST_F(GraphConstructorTest, EmptyGraph) {
......@@ -197,20 +198,20 @@ TEST_F(GraphConstructorTest, VersionGraph) {
TEST_F(GraphConstructorTest, LowVersion) {
ExpectError(strings::StrCat("versions { producer: ", -1, " }"),
strings::StrCat(R"(^GraphDef producer version -1 below min )"
"producer ",
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
" supported by TensorFlow ", TF_VERSION_STRING,
R"(\. Please regenerate your graph\.$)"));
{strings::StrCat("GraphDef producer version -1 below min "
"producer ",
TF_GRAPH_DEF_VERSION_MIN_PRODUCER,
" supported by TensorFlow ", TF_VERSION_STRING,
". Please regenerate your graph.")});
}
TEST_F(GraphConstructorTest, HighVersion) {
const int version = TF_GRAPH_DEF_VERSION + 1;
ExpectError(strings::StrCat("versions { min_consumer: ", version, " }"),
strings::StrCat(R"(^GraphDef min consumer version )", version,
" above current version ", TF_GRAPH_DEF_VERSION,
" for TensorFlow ", TF_VERSION_STRING,
R"(\. Please upgrade TensorFlow\.$)"));
{strings::StrCat("GraphDef min consumer version ", version,
" above current version ", TF_GRAPH_DEF_VERSION,
" for TensorFlow ", TF_VERSION_STRING,
". Please upgrade TensorFlow.")});
}
TEST_F(GraphConstructorTest, BadVersion) {
......@@ -219,9 +220,9 @@ TEST_F(GraphConstructorTest, BadVersion) {
ExpectError(
strings::StrCat("versions { producer: ", version, " bad_consumers: ", bad,
" }"),
strings::StrCat(
R"(^GraphDef disallows consumer version )", bad,
R"(\. Please upgrade TensorFlow: this version is likely buggy\.$)"));
{strings::StrCat(
"GraphDef disallows consumer version ", bad,
". Please upgrade TensorFlow: this version is likely buggy.")});
}
TEST_F(GraphConstructorTest, SimpleModel) {
......@@ -260,7 +261,7 @@ TEST_F(GraphConstructorTest, Error_ControlEdgeBeforeRealInput) {
"node { name: 'input' op: 'TestInput' input: [ '^W1' ] }"
"node { name: 't1' op: 'TestMul' input: [ 'W1', 'input:1' ] }"
"node { name: 't2' op: 'TestMul' input: [ 'W1', '^t1', 'input:1' ] }",
"Node 't2': Control dependencies must come after regular dependencies");
{"Node 't2': Control dependencies must come after regular dependencies"});
}
TEST_F(GraphConstructorTest, CopyGraph) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册