提交 8913ba89 编写于 作者: T TensorFlower Gardener

Merge pull request #40776 from vnvo2409:gcs-string

PiperOrigin-RevId: 318167643
Change-Id: Id417d71f454eb1eadc3f41213c4887494ea810e7
...@@ -30,7 +30,6 @@ cc_library( ...@@ -30,7 +30,6 @@ cc_library(
"//tensorflow/c:tf_status", "//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface", "//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client", "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
], ],
) )
...@@ -50,7 +49,6 @@ tf_cc_test( ...@@ -50,7 +49,6 @@ tf_cc_test(
"gcs_filesystem.cc", "gcs_filesystem.cc",
"gcs_filesystem_test.cc", "gcs_filesystem_test.cc",
], ],
local_defines = ["TF_GCS_FILESYSTEM_TEST"],
tags = [ tags = [
"manual", "manual",
"notap", "notap",
......
...@@ -38,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus( ...@@ -38,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus(
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); } static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); } static void plugin_memory_free(void* ptr) { free(ptr); }
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket, void ParseGCSPath(const std::string& fname, bool object_empty_ok,
char** object, TF_Status* status) { std::string* bucket, std::string* object, TF_Status* status) {
size_t scheme_end = fname.find("://") + 2; size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "gs://") { if (fname.substr(0, scheme_end + 1) != "gs://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT, TF_SetStatus(status, TF_INVALID_ARGUMENT,
...@@ -48,33 +48,19 @@ void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket, ...@@ -48,33 +48,19 @@ void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
} }
size_t bucket_end = fname.find("/", scheme_end + 1); size_t bucket_end = fname.find("/", scheme_end + 1);
if (bucket_end == absl::string_view::npos) { if (bucket_end == std::string::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT, TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain a bucket name."); "GCS path doesn't contain a bucket name.");
return; return;
} }
absl::string_view bucket_view =
fname.substr(scheme_end + 1, bucket_end - scheme_end - 1); *bucket = fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
*bucket = *object = fname.substr(bucket_end + 1);
static_cast<char*>(plugin_memory_allocate(bucket_view.length() + 1));
memcpy(*bucket, bucket_view.data(), bucket_view.length()); if (object->empty() && !object_empty_ok) {
(*bucket)[bucket_view.length()] = '\0'; TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
absl::string_view object_view = fname.substr(bucket_end + 1);
if (object_view.empty()) {
if (object_empty_ok) {
*object = nullptr;
return;
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
return;
}
} }
*object =
static_cast<char*>(plugin_memory_allocate(object_view.length() + 1));
// object_view.data() is a null-terminated string_view because fname is.
strcpy(*object, object_view.data());
} }
// SECTION 1. Implementation for `TF_RandomAccessFile` // SECTION 1. Implementation for `TF_RandomAccessFile`
...@@ -89,8 +75,8 @@ namespace tf_random_access_file { ...@@ -89,8 +75,8 @@ namespace tf_random_access_file {
// ---------------------------------------------------------------------------- // ----------------------------------------------------------------------------
namespace tf_writable_file { namespace tf_writable_file {
typedef struct GCSFile { typedef struct GCSFile {
const char* bucket; const std::string bucket;
const char* object; const std::string object;
gcs::Client* gcs_client; // not owned gcs::Client* gcs_client; // not owned
TempFile outfile; TempFile outfile;
bool sync_need; bool sync_need;
...@@ -98,8 +84,6 @@ typedef struct GCSFile { ...@@ -98,8 +84,6 @@ typedef struct GCSFile {
static void Cleanup(TF_WritableFile* file) { static void Cleanup(TF_WritableFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file); auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
plugin_memory_free(const_cast<char*>(gcs_file->bucket));
plugin_memory_free(const_cast<char*>(gcs_file->object));
delete gcs_file; delete gcs_file;
} }
...@@ -141,15 +125,14 @@ void Cleanup(TF_Filesystem* filesystem) { ...@@ -141,15 +125,14 @@ void Cleanup(TF_Filesystem* filesystem) {
void NewWritableFile(const TF_Filesystem* filesystem, const char* path, void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) { TF_WritableFile* file, TF_Status* status) {
char* bucket; std::string bucket, object;
char* object;
ParseGCSPath(path, false, &bucket, &object, status); ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem); auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName(""); char* temp_file_name = TF_GetTempFileName("");
file->plugin_file = new tf_writable_file::GCSFile( file->plugin_file = new tf_writable_file::GCSFile(
{bucket, object, gcs_client, {std::move(bucket), std::move(object), gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::out), true}); TempFile(temp_file_name, std::ios::binary | std::ios::out), true});
// We are responsible for freeing the pointer returned by TF_GetTempFileName // We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name); free(temp_file_name);
...@@ -158,8 +141,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path, ...@@ -158,8 +141,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) { TF_WritableFile* file, TF_Status* status) {
char* bucket; std::string bucket, object;
char* object;
ParseGCSPath(path, false, &bucket, &object, status); ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return; if (TF_GetCode(status) != TF_OK) return;
...@@ -175,7 +157,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path, ...@@ -175,7 +157,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
// If this file does not exist on server, we will need to sync it. // If this file does not exist on server, we will need to sync it.
bool sync_need = (status_code == TF_NOT_FOUND); bool sync_need = (status_code == TF_NOT_FOUND);
file->plugin_file = new tf_writable_file::GCSFile( file->plugin_file = new tf_writable_file::GCSFile(
{bucket, object, gcs_client, {std::move(bucket), std::move(object), gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need}); TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need});
free(temp_file_name); free(temp_file_name);
TF_SetStatus(status, TF_OK, ""); TF_SetStatus(status, TF_OK, "");
......
...@@ -15,13 +15,12 @@ ...@@ -15,13 +15,12 @@
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_ #ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_ #define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h" #include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h" #include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h" #include "tensorflow/c/tf_status.h"
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket, void ParseGCSPath(const std::string& fname, bool object_empty_ok,
char** object, TF_Status* status); std::string* bucket, std::string* object, TF_Status* status);
namespace tf_gcs_filesystem { namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status); void Init(TF_Filesystem* filesystem, TF_Status* status);
......
...@@ -43,9 +43,27 @@ class GCSFilesystemTest : public ::testing::Test { ...@@ -43,9 +43,27 @@ class GCSFilesystemTest : public ::testing::Test {
TF_Status* status_; TF_Status* status_;
}; };
// We have to add this test here because there must be at least one test. TEST_F(GCSFilesystemTest, ParseGCSPath) {
// This test will be removed in the future. std::string bucket, object;
TEST_F(GCSFilesystemTest, TestInit) { ASSERT_TF_OK(status_); } ParseGCSPath("gs://bucket/path/to/object", false, &bucket, &object, status_);
ASSERT_TF_OK(status_);
ASSERT_EQ(bucket, "bucket");
ASSERT_EQ(object, "path/to/object");
ParseGCSPath("gs://bucket/", true, &bucket, &object, status_);
ASSERT_TF_OK(status_);
ASSERT_EQ(bucket, "bucket");
ParseGCSPath("bucket/path/to/object", false, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
// bucket name must end with "/"
ParseGCSPath("gs://bucket", true, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
ParseGCSPath("gs://bucket/", false, &bucket, &object, status_);
ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
}
} // namespace } // namespace
} // namespace tensorflow } // namespace tensorflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册