From 807bf3bfb0cc06834838c89ee2d0f6c22b1ad6dd Mon Sep 17 00:00:00 2001 From: Henry Tan Date: Mon, 24 Aug 2020 15:32:53 -0700 Subject: [PATCH] core/tpu/kernels/BUILD file proto target refactoring PiperOrigin-RevId: 328222137 Change-Id: I1c4339867f6e887e3647f5f60c58a7cfd0885d3f --- tensorflow/core/tpu/kernels/BUILD | 41 ++++++++++--------- .../tpu/kernels/tpu_compilation_cache.proto | 36 ++++++++-------- ...oto => tpu_compilation_cache_common.proto} | 35 ++++++++-------- .../kernels/tpu_compilation_cache_external.h | 2 +- .../tpu/kernels/tpu_compilation_cache_grpc.cc | 4 ++ .../tpu/kernels/tpu_compilation_cache_grpc.h | 14 ++++++- .../kernels/tpu_compilation_cache_interface.h | 2 +- .../tpu_compilation_cache_local_lookup.h | 2 +- .../kernels/tpu_compilation_cache_lookup.h | 2 +- .../tpu_compilation_cache_rpc_lookup.cc | 8 +++- .../tpu_compilation_cache_rpc_lookup.h | 6 ++- .../tpu_compilation_cache_rpc_support.cc | 18 +++++++- .../tpu_compilation_cache_rpc_support.h | 12 ++++-- 13 files changed, 113 insertions(+), 69 deletions(-) rename tensorflow/core/tpu/kernels/{tpu_compilation_cache_response.proto => tpu_compilation_cache_common.proto} (50%) diff --git a/tensorflow/core/tpu/kernels/BUILD b/tensorflow/core/tpu/kernels/BUILD index be6a9d4d864..1372dcf7033 100644 --- a/tensorflow/core/tpu/kernels/BUILD +++ b/tensorflow/core/tpu/kernels/BUILD @@ -237,8 +237,8 @@ cc_library( "tpu_compilation_cache_lookup.h", ], deps = [ + ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_interface", - ":tpu_compilation_cache_proto_cc", "//tensorflow/core/lib/core:refcount", "//tensorflow/core/platform:status", "//tensorflow/core/profiler/lib:traceme", @@ -250,11 +250,11 @@ cc_library( srcs = ["tpu_compilation_cache_local_lookup.cc"], hdrs = ["tpu_compilation_cache_local_lookup.h"], deps = [ + ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_entry", ":tpu_compilation_cache_external", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_lookup", - ":tpu_compilation_cache_proto_cc", "//tensorflow/core/platform:status", ], ) @@ -342,9 +342,9 @@ cc_library( DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"], }) + [ ":compiled_subgraph", + ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_entry", ":tpu_compilation_cache_key", - ":tpu_compilation_cache_proto_cc", ":tpu_compilation_metrics_hdrs", ":tpu_util", ":tpu_util_hdrs", @@ -373,10 +373,10 @@ cc_library( ], deps = [ ":compiled_subgraph", + ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_entry", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_key", - ":tpu_compilation_cache_proto_cc", ":tpu_compilation_metrics", # buildcleaner: keep ":tpu_compilation_metrics_hdrs", ":tpu_compile_c_api_hdrs", @@ -515,8 +515,8 @@ cc_library( DEFAULT: [], }), deps = select({ - WITH_TPU_SUPPORT: [":tpu_compilation_cache_response_proto_cc"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_response_proto_cc"], + WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], + DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], }) + [ ":tpu_compilation_cache_entry", ":tpu_compilation_cache_interface", @@ -531,7 +531,12 @@ cc_library( cc_library( name = "tpu_compilation_cache_rpc_support", srcs = ["tpu_compilation_cache_rpc_support.cc"], + copts = select({ + WITH_TPU_SUPPORT: ["-DLIBTFTPU"], + DEFAULT: [], + }), deps = [ + ":tpu_compilation_cache_proto_cc", ":tpu_compilation_cache_rpc_support_hdrs", ], ) @@ -551,7 +556,7 @@ cc_library( ":tpu_compilation_cache_grpc", ":tpu_compilation_cache_interface", ":tpu_compilation_cache_lookup", - ":tpu_compilation_cache_proto_cc", + ":tpu_compilation_cache_common_proto_cc", ":tpu_compilation_cache_rpc_support_hdrs", ":tpu_program_group_interface", "@com_google_absl//absl/strings", @@ -562,27 +567,23 @@ cc_library( ], ) -# TODO(henrytan): rename the proto file. tf_proto_library_cc( - name = "tpu_compilation_cache_response_proto", - srcs = ["tpu_compilation_cache_response.proto"], - has_services = 1, + name = "tpu_compilation_cache_proto", + srcs = ["tpu_compilation_cache.proto"], + has_services = True, cc_api_version = 2, create_java_proto = False, protodeps = [ - ":tpu_compilation_cache_proto", + ":tpu_compilation_cache_common_proto", "//tensorflow/compiler/tf2xla:host_compute_metadata_proto", ], ) tf_proto_library_cc( - name = "tpu_compilation_cache_proto", - srcs = ["tpu_compilation_cache.proto"], + name = "tpu_compilation_cache_common_proto", + srcs = ["tpu_compilation_cache_common.proto"], cc_api_version = 2, create_java_proto = False, - protodeps = [ - "//tensorflow/compiler/tf2xla:host_compute_metadata_proto", - ], ) cc_library( @@ -594,10 +595,10 @@ cc_library( DEFAULT: [], }), deps = select({ - WITH_TPU_SUPPORT: [":tpu_compilation_cache_response_proto_cc"], - DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_response_proto_cc"], + WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"], + DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"], }) + [ - ":tpu_compilation_cache_proto_cc", + ":tpu_compilation_cache_common_proto_cc", tf_grpc_cc_dependency(), ], ) diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto index 89b92ae9157..f4529224109 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache.proto @@ -16,23 +16,27 @@ syntax = "proto3"; package tensorflow.tpu; -// Target type for compilation cache fetch operation. -enum CompilationCacheFetchTarget { - INVALID = 0; - MAIN = 1; - SHARDING = 2; - UNSHARDING = 3; -} +import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; +import "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto"; + +// Response for GetTpuProgram RPC. +message GetTpuProgramResponseExternal { + message Blob { + bytes data = 1; + } -message TpuCompilationUidAndIndex { - int64 uid = 1; - int32 proto_index = 2; + Blob proto = 1; + tf2xla.HostComputeMetadata host_compute_metadata = 2; + bool may_modify_variables = 3; + Blob compiler_metadata = 4; + // Whether the program is empty, which could be true for sharding/unsharding + // entries. + bool is_empty = 5; } -message GetTpuProgramRequest { - oneof key_oneof { - string key = 1; - TpuCompilationUidAndIndex uid_and_index = 2; - } - CompilationCacheFetchTarget fetch_target = 3; +service TpuCompilationCacheServiceExternal { + // This method requests the cached proto that the TPU execute op has been + // instructed to execute. + rpc GetTpuProgram(GetTpuProgramRequest) + returns (GetTpuProgramResponseExternal) {} } diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_response.proto b/tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto similarity index 50% rename from tensorflow/core/tpu/kernels/tpu_compilation_cache_response.proto rename to tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto index 2b3d404e308..89b92ae9157 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_response.proto +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto @@ -16,26 +16,23 @@ syntax = "proto3"; package tensorflow.tpu; -import "tensorflow/core/tpu/kernels/tpu_compilation_cache.proto"; -import "tensorflow/compiler/tf2xla/host_compute_metadata.proto"; - -// Response for GetTpuProgram RPC. -message GetTpuProgramResponse { - message Blob { - bytes data = 1; - } +// Target type for compilation cache fetch operation. +enum CompilationCacheFetchTarget { + INVALID = 0; + MAIN = 1; + SHARDING = 2; + UNSHARDING = 3; +} - Blob proto = 1; - tf2xla.HostComputeMetadata host_compute_metadata = 2; - bool may_modify_variables = 3; - Blob compiler_metadata = 4; - // Whether the program is empty, which could be true for sharding/unsharding - // entries. - bool is_empty = 5; +message TpuCompilationUidAndIndex { + int64 uid = 1; + int32 proto_index = 2; } -service TpuCompilationCacheService { - // This method requests the cached proto that the TPU execute op has been - // instructed to execute. - rpc GetTpuProgram(GetTpuProgramRequest) returns (GetTpuProgramResponse) {} +message GetTpuProgramRequest { + oneof key_oneof { + string key = 1; + TpuCompilationUidAndIndex uid_and_index = 2; + } + CompilationCacheFetchTarget fetch_target = 3; } diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h index 51b5ffbed0d..c3f95e7e09d 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_external.h @@ -30,7 +30,7 @@ limitations under the License. #include "tensorflow/core/platform/refcount.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/tpu/kernels/compiled_subgraph.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc index a44518c0be6..207a60e7b48 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.cc @@ -30,7 +30,11 @@ namespace tensorflow { namespace tpu { static const char* grpcTpuCompilationCacheService_method_names[] = { +#if defined(LIBTFTPU) + "/tensorflow.tpu.TpuCompilationCacheServiceExternal/GetTpuProgram", +#else // LIBTFTPU "/tensorflow.tpu.TpuCompilationCacheService/GetTpuProgram", +#endif // LIBTFTPU }; std::unique_ptr diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h index 39e37ad3722..324fc9e6f08 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h @@ -35,8 +35,12 @@ limitations under the License. #include -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_response.pb.h" +#if defined(LIBTFTPU) #include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#else +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" // copybara" +#endif +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" namespace tensorflow { namespace tpu { @@ -44,14 +48,22 @@ namespace grpc { class TpuCompilationCacheService final { public: using RequestType = ::tensorflow::tpu::GetTpuProgramRequest; +#if defined(LIBTFTPU) + using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal; +#else using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse; +#endif // N.B. This must be synchronized with the method order in // tpu_compilation_cache.proto. enum class MethodId { kGetTpuProgram = 0 }; static constexpr char const* service_full_name() { +#if defined(LIBTFTPU) + return "tensorflow.tpu.TpuCompilationCacheServiceExternal"; +#else return "tensorflow.tpu.TpuCompilationCacheService"; +#endif } class StubInterface { public: diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h index 7b206fb1cf4..e1e7cf2eddb 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h @@ -31,7 +31,7 @@ limitations under the License. #include "tensorflow/core/profiler/lib/traceme.h" #include "tensorflow/core/protobuf/config.pb.h" #include "tensorflow/core/tpu/kernels/compiled_subgraph.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h" diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h index 6f1fe9bdf87..96f92358241 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_local_lookup.h @@ -16,7 +16,7 @@ limitations under the License. #define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOCAL_LOOKUP_H_ #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h index ab476322a8a..fc819700204 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h @@ -17,7 +17,7 @@ limitations under the License. #include "tensorflow/core/lib/core/refcount.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" namespace tensorflow { diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc index 743229d91cf..e3560de0c44 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.cc @@ -25,6 +25,12 @@ namespace tensorflow { namespace tpu { namespace { +#if defined(LIBTFTPU) +using ResponseType = GetTpuProgramResponseExternal; +#else +using ResponseType = GetTpuProgramResponse; +#endif + static constexpr absl::Duration kProtoTimeout = absl::Minutes(15); static gpr_timespec TimeToGprTimespec(absl::Time time) { if (time == absl::InfiniteFuture()) { @@ -147,7 +153,7 @@ Status TpuCompilationCacheRpcLookup::RemoteLookupLocked( client_context.set_deadline(TimeToGprTimespec(::absl::Now() + kProtoTimeout)); client_context.set_compression_algorithm(GRPC_COMPRESS_GZIP); - tpu::GetTpuProgramResponse response; + ResponseType response; Status s = FromGrpcStatus(stub_->GetTpuProgram(&client_context, request, &response)); VLOG(1) << "Looked up key " << local_proto_key diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h index 4fbda6083ab..d5449a05371 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_lookup.h @@ -22,7 +22,7 @@ limitations under the License. #include #include "absl/synchronization/mutex.h" -#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h" #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h" @@ -35,6 +35,8 @@ namespace tpu { // Class for looking up and caching TPU program via RPC. class TpuCompilationCacheRpcLookup : public TpuCompilationCacheLookup { public: + using StubType = tpu::grpc::TpuCompilationCacheService::Stub; + TpuCompilationCacheRpcLookup(const string& server_address, int64 max_cache_size); ~TpuCompilationCacheRpcLookup() override = default; @@ -69,7 +71,7 @@ class TpuCompilationCacheRpcLookup : public TpuCompilationCacheLookup { // evicted. const int64 max_cache_size_; - std::unique_ptr stub_; + std::unique_ptr stub_; // Protect concurrent access to member variables below. mutable absl::Mutex mu_; diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc index 62df149c87a..b880b7ac1a2 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.cc @@ -14,17 +14,31 @@ limitations under the License. ==============================================================================*/ #include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h" +#if defined(LIBTFTPU) +#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" +#endif // LIBTFTPU + namespace tensorflow { namespace tpu { std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() { return ::grpc::InsecureChannelCredentials(); } -Status FillCacheEntryFromGetTpuProgramResponse( - absl::string_view local_proto_key, GetTpuProgramResponse* response, +#if defined(LIBTFTPU) +template <> +Status FillCacheEntryFromGetTpuProgramResponse( + absl::string_view local_proto_key, GetTpuProgramResponseExternal* response, std::shared_ptr* cache_entry) { // TODO(b/162904194): implement this method. LOG(FATAL) << "Not implemented yet."; } + +void SendGetTpuProgramResponseHelper( + const TpuCompilationCacheEntry& cache_entry, + std::function call_fn) { + // TODO(b/162904194): implement this method. + LOG(FATAL) << "Not implemented yet."; +} +#endif // LIBTFTPU } // namespace tpu } // namespace tensorflow diff --git a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h index 5d717df392b..6749138d710 100644 --- a/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h +++ b/tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include @@ -75,17 +76,20 @@ class CacheWrapper : public CompilationCacheEntryRef { std::shared_ptr cache_entry_; }; -// Forward declaration. -class GetTpuProgramResponse; - // Creates gRPC channel credentials for the current runtime env. std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials(); // Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The // `cache_entry` will be instantiated by the function. +template Status FillCacheEntryFromGetTpuProgramResponse( - const absl::string_view local_proto_key, GetTpuProgramResponse* response, + const absl::string_view local_proto_key, ResponseType* response, std::shared_ptr* cache_entry); + +// A helper to send `TpuCompilationCacheEntry` payload through gRPC channel. +void SendGetTpuProgramResponseHelper( + const TpuCompilationCacheEntry& cache_entry, + std::function call_fn); } // namespace tpu } // namespace tensorflow -- GitLab