提交 8913ba89 编写于 作者: T TensorFlower Gardener

Merge pull request #40776 from vnvo2409:gcs-string

PiperOrigin-RevId: 318167643
Change-Id: Id417d71f454eb1eadc3f41213c4887494ea810e7
......@@ -30,7 +30,6 @@ cc_library(
"//tensorflow/c:tf_status",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)
......@@ -50,7 +49,6 @@ tf_cc_test(
"gcs_filesystem.cc",
"gcs_filesystem_test.cc",
],
local_defines = ["TF_GCS_FILESYSTEM_TEST"],
tags = [
"manual",
"notap",
......
......@@ -38,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus(
// Memory-allocation hook for the plugin: hands back `size` bytes obtained from
// calloc (so the bytes are zero-initialized), or whatever calloc returns on
// failure.
static void* plugin_memory_allocate(size_t size) {
  void* zeroed_block = calloc(1, size);
  return zeroed_block;
}
// Memory-release hook for the plugin: forwards `ptr` straight to free().
static void plugin_memory_free(void* ptr) {
  free(ptr);
}
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
char** object, TF_Status* status) {
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
std::string* bucket, std::string* object, TF_Status* status) {
size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "gs://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
......@@ -48,33 +48,19 @@ void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
}
size_t bucket_end = fname.find("/", scheme_end + 1);
if (bucket_end == absl::string_view::npos) {
if (bucket_end == std::string::npos) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain a bucket name.");
return;
}
absl::string_view bucket_view =
fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
*bucket =
static_cast<char*>(plugin_memory_allocate(bucket_view.length() + 1));
memcpy(*bucket, bucket_view.data(), bucket_view.length());
(*bucket)[bucket_view.length()] = '\0';
absl::string_view object_view = fname.substr(bucket_end + 1);
if (object_view.empty()) {
if (object_empty_ok) {
*object = nullptr;
return;
} else {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
return;
}
*bucket = fname.substr(scheme_end + 1, bucket_end - scheme_end - 1);
*object = fname.substr(bucket_end + 1);
if (object->empty() && !object_empty_ok) {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
"GCS path doesn't contain an object name.");
}
*object =
static_cast<char*>(plugin_memory_allocate(object_view.length() + 1));
// object_view.data() is a null-terminated string_view because fname is.
strcpy(*object, object_view.data());
}
// SECTION 1. Implementation for `TF_RandomAccessFile`
......@@ -89,8 +75,8 @@ namespace tf_random_access_file {
// ----------------------------------------------------------------------------
namespace tf_writable_file {
typedef struct GCSFile {
const char* bucket;
const char* object;
const std::string bucket;
const std::string object;
gcs::Client* gcs_client; // not owned
TempFile outfile;
bool sync_need;
......@@ -98,8 +84,6 @@ typedef struct GCSFile {
static void Cleanup(TF_WritableFile* file) {
auto gcs_file = static_cast<GCSFile*>(file->plugin_file);
plugin_memory_free(const_cast<char*>(gcs_file->bucket));
plugin_memory_free(const_cast<char*>(gcs_file->object));
delete gcs_file;
}
......@@ -141,15 +125,14 @@ void Cleanup(TF_Filesystem* filesystem) {
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
char* bucket;
char* object;
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
auto gcs_client = static_cast<gcs::Client*>(filesystem->plugin_filesystem);
char* temp_file_name = TF_GetTempFileName("");
file->plugin_file = new tf_writable_file::GCSFile(
{bucket, object, gcs_client,
{std::move(bucket), std::move(object), gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::out), true});
// We are responsible for freeing the pointer returned by TF_GetTempFileName
free(temp_file_name);
......@@ -158,8 +141,7 @@ void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
char* bucket;
char* object;
std::string bucket, object;
ParseGCSPath(path, false, &bucket, &object, status);
if (TF_GetCode(status) != TF_OK) return;
......@@ -175,7 +157,7 @@ void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
// If this file does not exist on server, we will need to sync it.
bool sync_need = (status_code == TF_NOT_FOUND);
file->plugin_file = new tf_writable_file::GCSFile(
{bucket, object, gcs_client,
{std::move(bucket), std::move(object), gcs_client,
TempFile(temp_file_name, std::ios::binary | std::ios::app), sync_need});
free(temp_file_name);
TF_SetStatus(status, TF_OK, "");
......
......@@ -15,13 +15,12 @@
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
char** object, TF_Status* status);
// Splits a GCS path of the form "gs://<bucket>/<object>" into `*bucket` and
// `*object`.
//
// `status` is set to TF_INVALID_ARGUMENT when `fname` does not start with
// "gs://", when no "/" follows the bucket name, or when the object part is
// empty and `object_empty_ok` is false.
void ParseGCSPath(const std::string& fname, bool object_empty_ok,
std::string* bucket, std::string* object, TF_Status* status);
namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
......
......@@ -43,9 +43,27 @@ class GCSFilesystemTest : public ::testing::Test {
TF_Status* status_;
};
// We have to add this test here because there must be at least one test.
// This test will be removed in the future.
TEST_F(GCSFilesystemTest, TestInit) { ASSERT_TF_OK(status_); }
// Checks that ParseGCSPath splits "gs://<bucket>/<object>" into its bucket and
// object parts and reports TF_INVALID_ARGUMENT for malformed paths.
//
// `status_` is shared by every call below, so it is reset to TF_OK before each
// later negative case. Without the reset, the TF_INVALID_ARGUMENT left by an
// earlier failing call would satisfy the following assertions even if
// ParseGCSPath did not touch the status at all.
TEST_F(GCSFilesystemTest, ParseGCSPath) {
  std::string bucket, object;

  // Well-formed path with a nested object name.
  ParseGCSPath("gs://bucket/path/to/object", false, &bucket, &object, status_);
  ASSERT_TF_OK(status_);
  ASSERT_EQ(bucket, "bucket");
  ASSERT_EQ(object, "path/to/object");

  // An empty object name is accepted when object_empty_ok is true.
  ParseGCSPath("gs://bucket/", true, &bucket, &object, status_);
  ASSERT_TF_OK(status_);
  ASSERT_EQ(bucket, "bucket");
  ASSERT_TRUE(object.empty());

  // Missing "gs://" scheme.
  ParseGCSPath("bucket/path/to/object", false, &bucket, &object, status_);
  ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);

  // bucket name must end with "/"
  TF_SetStatus(status_, TF_OK, "");
  ParseGCSPath("gs://bucket", true, &bucket, &object, status_);
  ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);

  // An empty object name is rejected when object_empty_ok is false.
  TF_SetStatus(status_, TF_OK, "");
  ParseGCSPath("gs://bucket/", false, &bucket, &object, status_);
  ASSERT_EQ(TF_GetCode(status_), TF_INVALID_ARGUMENT);
}
} // namespace
} // namespace tensorflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册