提交 807bf3bf 编写于 作者: H Henry Tan 提交者: TensorFlower Gardener

core/tpu/kernels/BUILD file proto target refactoring

PiperOrigin-RevId: 328222137
Change-Id: I1c4339867f6e887e3647f5f60c58a7cfd0885d3f
上级 e82addaf
......@@ -237,8 +237,8 @@ cc_library(
"tpu_compilation_cache_lookup.h",
],
deps = [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_proto_cc",
"//tensorflow/core/lib/core:refcount",
"//tensorflow/core/platform:status",
"//tensorflow/core/profiler/lib:traceme",
......@@ -250,11 +250,11 @@ cc_library(
srcs = ["tpu_compilation_cache_local_lookup.cc"],
hdrs = ["tpu_compilation_cache_local_lookup.h"],
deps = [
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_entry",
":tpu_compilation_cache_external",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_lookup",
":tpu_compilation_cache_proto_cc",
"//tensorflow/core/platform:status",
],
)
......@@ -342,9 +342,9 @@ cc_library(
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_metrics"],
}) + [
":compiled_subgraph",
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_entry",
":tpu_compilation_cache_key",
":tpu_compilation_cache_proto_cc",
":tpu_compilation_metrics_hdrs",
":tpu_util",
":tpu_util_hdrs",
......@@ -373,10 +373,10 @@ cc_library(
],
deps = [
":compiled_subgraph",
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_entry",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_key",
":tpu_compilation_cache_proto_cc",
":tpu_compilation_metrics", # buildcleaner: keep
":tpu_compilation_metrics_hdrs",
":tpu_compile_c_api_hdrs",
......@@ -515,8 +515,8 @@ cc_library(
DEFAULT: [],
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_compilation_cache_response_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_response_proto_cc"],
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
}) + [
":tpu_compilation_cache_entry",
":tpu_compilation_cache_interface",
......@@ -531,7 +531,12 @@ cc_library(
cc_library(
name = "tpu_compilation_cache_rpc_support",
srcs = ["tpu_compilation_cache_rpc_support.cc"],
copts = select({
WITH_TPU_SUPPORT: ["-DLIBTFTPU"],
DEFAULT: [],
}),
deps = [
":tpu_compilation_cache_proto_cc",
":tpu_compilation_cache_rpc_support_hdrs",
],
)
......@@ -551,7 +556,7 @@ cc_library(
":tpu_compilation_cache_grpc",
":tpu_compilation_cache_interface",
":tpu_compilation_cache_lookup",
":tpu_compilation_cache_proto_cc",
":tpu_compilation_cache_common_proto_cc",
":tpu_compilation_cache_rpc_support_hdrs",
":tpu_program_group_interface",
"@com_google_absl//absl/strings",
......@@ -562,27 +567,23 @@ cc_library(
],
)
# TODO(henrytan): rename the proto file.
tf_proto_library_cc(
name = "tpu_compilation_cache_response_proto",
srcs = ["tpu_compilation_cache_response.proto"],
has_services = 1,
name = "tpu_compilation_cache_proto",
srcs = ["tpu_compilation_cache.proto"],
has_services = True,
cc_api_version = 2,
create_java_proto = False,
protodeps = [
":tpu_compilation_cache_proto",
":tpu_compilation_cache_common_proto",
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
],
)
tf_proto_library_cc(
name = "tpu_compilation_cache_proto",
srcs = ["tpu_compilation_cache.proto"],
name = "tpu_compilation_cache_common_proto",
srcs = ["tpu_compilation_cache_common.proto"],
cc_api_version = 2,
create_java_proto = False,
protodeps = [
"//tensorflow/compiler/tf2xla:host_compute_metadata_proto",
],
)
cc_library(
......@@ -594,10 +595,10 @@ cc_library(
DEFAULT: [],
}),
deps = select({
WITH_TPU_SUPPORT: [":tpu_compilation_cache_response_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_response_proto_cc"],
WITH_TPU_SUPPORT: [":tpu_compilation_cache_proto_cc"],
DEFAULT: ["//tensorflow/core/tpu/kernels:tpu_compilation_cache_proto_cc"],
}) + [
":tpu_compilation_cache_proto_cc",
":tpu_compilation_cache_common_proto_cc",
tf_grpc_cc_dependency(),
],
)
......
......@@ -16,23 +16,27 @@ syntax = "proto3";
package tensorflow.tpu;
// Target type for compilation cache fetch operation.
enum CompilationCacheFetchTarget {
INVALID = 0;
MAIN = 1;
SHARDING = 2;
UNSHARDING = 3;
}
import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
import "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.proto";
// Response for GetTpuProgram RPC.
message GetTpuProgramResponseExternal {
message Blob {
bytes data = 1;
}
message TpuCompilationUidAndIndex {
int64 uid = 1;
int32 proto_index = 2;
Blob proto = 1;
tf2xla.HostComputeMetadata host_compute_metadata = 2;
bool may_modify_variables = 3;
Blob compiler_metadata = 4;
// Whether the program is empty, which could be true for sharding/unsharding
// entries.
bool is_empty = 5;
}
message GetTpuProgramRequest {
oneof key_oneof {
string key = 1;
TpuCompilationUidAndIndex uid_and_index = 2;
}
CompilationCacheFetchTarget fetch_target = 3;
service TpuCompilationCacheServiceExternal {
// This method requests the cached proto that the TPU execute op has been
// instructed to execute.
rpc GetTpuProgram(GetTpuProgramRequest)
returns (GetTpuProgramResponseExternal) {}
}
......@@ -16,26 +16,23 @@ syntax = "proto3";
package tensorflow.tpu;
import "tensorflow/core/tpu/kernels/tpu_compilation_cache.proto";
import "tensorflow/compiler/tf2xla/host_compute_metadata.proto";
// Response for GetTpuProgram RPC.
message GetTpuProgramResponse {
message Blob {
bytes data = 1;
}
// Target type for compilation cache fetch operation.
enum CompilationCacheFetchTarget {
INVALID = 0;
MAIN = 1;
SHARDING = 2;
UNSHARDING = 3;
}
Blob proto = 1;
tf2xla.HostComputeMetadata host_compute_metadata = 2;
bool may_modify_variables = 3;
Blob compiler_metadata = 4;
// Whether the program is empty, which could be true for sharding/unsharding
// entries.
bool is_empty = 5;
message TpuCompilationUidAndIndex {
int64 uid = 1;
int32 proto_index = 2;
}
service TpuCompilationCacheService {
// This method requests the cached proto that the TPU execute op has been
// instructed to execute.
rpc GetTpuProgram(GetTpuProgramRequest) returns (GetTpuProgramResponse) {}
message GetTpuProgramRequest {
oneof key_oneof {
string key = 1;
TpuCompilationUidAndIndex uid_and_index = 2;
}
CompilationCacheFetchTarget fetch_target = 3;
}
......@@ -30,7 +30,7 @@ limitations under the License.
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
......
......@@ -30,7 +30,11 @@ namespace tensorflow {
namespace tpu {
static const char* grpcTpuCompilationCacheService_method_names[] = {
#if defined(LIBTFTPU)
"/tensorflow.tpu.TpuCompilationCacheServiceExternal/GetTpuProgram",
#else // LIBTFTPU
"/tensorflow.tpu.TpuCompilationCacheService/GetTpuProgram",
#endif // LIBTFTPU
};
std::unique_ptr<grpc::TpuCompilationCacheService::Stub>
......
......@@ -35,8 +35,12 @@ limitations under the License.
#include <functional>
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_response.pb.h"
#if defined(LIBTFTPU)
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#else
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h" // copybara"
#endif
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
namespace tensorflow {
namespace tpu {
......@@ -44,14 +48,22 @@ namespace grpc {
class TpuCompilationCacheService final {
public:
using RequestType = ::tensorflow::tpu::GetTpuProgramRequest;
#if defined(LIBTFTPU)
using ResponseType = ::tensorflow::tpu::GetTpuProgramResponseExternal;
#else
using ResponseType = ::tensorflow::tpu::GetTpuProgramResponse;
#endif
// N.B. This must be synchronized with the method order in
// tpu_compilation_cache.proto.
enum class MethodId { kGetTpuProgram = 0 };
static constexpr char const* service_full_name() {
#if defined(LIBTFTPU)
return "tensorflow.tpu.TpuCompilationCacheServiceExternal";
#else
return "tensorflow.tpu.TpuCompilationCacheService";
#endif
}
class StubInterface {
public:
......
......@@ -31,7 +31,7 @@ limitations under the License.
#include "tensorflow/core/profiler/lib/traceme.h"
#include "tensorflow/core/protobuf/config.pb.h"
#include "tensorflow/core/tpu/kernels/compiled_subgraph.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_key.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_metrics.h"
......
......@@ -16,7 +16,7 @@ limitations under the License.
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILATION_CACHE_LOCAL_LOOKUP_H_
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_entry.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
......
......@@ -17,7 +17,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/refcount.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
namespace tensorflow {
......
......@@ -25,6 +25,12 @@ namespace tensorflow {
namespace tpu {
namespace {
#if defined(LIBTFTPU)
using ResponseType = GetTpuProgramResponseExternal;
#else
using ResponseType = GetTpuProgramResponse;
#endif
static constexpr absl::Duration kProtoTimeout = absl::Minutes(15);
static gpr_timespec TimeToGprTimespec(absl::Time time) {
if (time == absl::InfiniteFuture()) {
......@@ -147,7 +153,7 @@ Status TpuCompilationCacheRpcLookup::RemoteLookupLocked(
client_context.set_deadline(TimeToGprTimespec(::absl::Now() + kProtoTimeout));
client_context.set_compression_algorithm(GRPC_COMPRESS_GZIP);
tpu::GetTpuProgramResponse response;
ResponseType response;
Status s =
FromGrpcStatus(stub_->GetTpuProgram(&client_context, request, &response));
VLOG(1) << "Looked up key " << local_proto_key
......
......@@ -22,7 +22,7 @@ limitations under the License.
#include <vector>
#include "absl/synchronization/mutex.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_common.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_grpc.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_lookup.h"
......@@ -35,6 +35,8 @@ namespace tpu {
// Class for looking up and caching TPU program via RPC.
class TpuCompilationCacheRpcLookup : public TpuCompilationCacheLookup {
public:
using StubType = tpu::grpc::TpuCompilationCacheService::Stub;
TpuCompilationCacheRpcLookup(const string& server_address,
int64 max_cache_size);
~TpuCompilationCacheRpcLookup() override = default;
......@@ -69,7 +71,7 @@ class TpuCompilationCacheRpcLookup : public TpuCompilationCacheLookup {
// evicted.
const int64 max_cache_size_;
std::unique_ptr<tpu::grpc::TpuCompilationCacheService::Stub> stub_;
std::unique_ptr<StubType> stub_;
// Protect concurrent access to member variables below.
mutable absl::Mutex mu_;
......
......@@ -14,17 +14,31 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache_rpc_support.h"
#if defined(LIBTFTPU)
#include "tensorflow/core/tpu/kernels/tpu_compilation_cache.pb.h"
#endif // LIBTFTPU
namespace tensorflow {
namespace tpu {
std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials() {
return ::grpc::InsecureChannelCredentials();
}
Status FillCacheEntryFromGetTpuProgramResponse(
absl::string_view local_proto_key, GetTpuProgramResponse* response,
#if defined(LIBTFTPU)
template <>
Status FillCacheEntryFromGetTpuProgramResponse<GetTpuProgramResponseExternal>(
absl::string_view local_proto_key, GetTpuProgramResponseExternal* response,
std::shared_ptr<CacheEntry>* cache_entry) {
// TODO(b/162904194): implement this method.
LOG(FATAL) << "Not implemented yet.";
}
void SendGetTpuProgramResponseHelper(
const TpuCompilationCacheEntry& cache_entry,
std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn) {
// TODO(b/162904194): implement this method.
LOG(FATAL) << "Not implemented yet.";
}
#endif // LIBTFTPU
} // namespace tpu
} // namespace tensorflow
......@@ -17,6 +17,7 @@ limitations under the License.
#include <grpcpp/security/credentials.h>
#include <functional>
#include <memory>
#include <string>
......@@ -75,17 +76,20 @@ class CacheWrapper : public CompilationCacheEntryRef {
std::shared_ptr<CacheEntry> cache_entry_;
};
// Forward declaration.
class GetTpuProgramResponse;
// Creates gRPC channel credentials for the current runtime env.
std::shared_ptr<::grpc::ChannelCredentials> CreateChannelCredentials();
// Fills an uinitialized `CacheEntry` from `GetTpuProgramResponse` proto. The
// `cache_entry` will be instantiated by the function.
template <typename ResponseType>
Status FillCacheEntryFromGetTpuProgramResponse(
const absl::string_view local_proto_key, GetTpuProgramResponse* response,
const absl::string_view local_proto_key, ResponseType* response,
std::shared_ptr<CacheEntry>* cache_entry);
// A helper to send `TpuCompilationCacheEntry` payload through gRPC channel.
void SendGetTpuProgramResponseHelper(
const TpuCompilationCacheEntry& cache_entry,
std::function<void(::grpc::ByteBuffer*, ::grpc::Status)> call_fn);
} // namespace tpu
} // namespace tensorflow
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册