diff --git a/tensorflow/compiler/tf2xla/tf2xla.cc b/tensorflow/compiler/tf2xla/tf2xla.cc
index 9ced6e682fc4022a5780bfe387e45ae164d3d518..bcdfd1c6a8ec5d0e842385d41dce1433b9a8bafa 100644
--- a/tensorflow/compiler/tf2xla/tf2xla.cc
+++ b/tensorflow/compiler/tf2xla/tf2xla.cc
@@ -35,6 +35,7 @@ limitations under the License.
 #include "tensorflow/core/framework/function.h"
 #include "tensorflow/core/framework/graph.pb.h"
 #include "tensorflow/core/framework/graph_def_util.h"
+#include "tensorflow/core/framework/node_def.pb.h"
 #include "tensorflow/core/framework/op.h"
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/versions.pb.h"
@@ -42,6 +43,7 @@ limitations under the License.
 #include "tensorflow/core/graph/graph.h"
 #include "tensorflow/core/graph/node_builder.h"
 #include "tensorflow/core/lib/core/errors.h"
+#include "tensorflow/core/platform/errors.h"
 #include "tensorflow/core/platform/logging.h"
 #include "tensorflow/core/platform/types.h"
 #include "tensorflow/core/util/dump_graph.h"
@@ -128,19 +130,31 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
   return Status::OK();
 }
 
-void ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
-  for (auto& node : *graph_def->mutable_node()) {
+Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
+  auto update_var_handle_op_node = [](NodeDef& node) -> Status {
     if (node.op() == "VarHandleOp") {
       node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
+      const auto& it = node.attr().find("allowed_devices");
+      if (it != node.attr().end()) {
+        if (!it->second.list().s().empty()) {
+          // TODO(b/149512838): Support non-empty allowed devices.
+          return errors::InvalidArgument(
+              "VarHandleOp with non-empty allowed devices is not supported.");
+        }
+        node.mutable_attr()->erase("allowed_devices");
+      }
     }
+    return Status::OK();
+  };
+  for (auto& node : *graph_def->mutable_node()) {
+    TF_RETURN_IF_ERROR(update_var_handle_op_node(node));
   }
   for (auto& fn : *graph_def->mutable_library()->mutable_function()) {
     for (auto& node : *fn.mutable_node_def()) {
-      if (node.op() == "VarHandleOp") {
-        node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
-      }
+      TF_RETURN_IF_ERROR(update_var_handle_op_node(node));
     }
   }
+  return Status::OK();
 }
 
 }  // namespace
@@ -149,7 +163,7 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
                             xla::Client* client,
                             xla::XlaComputation* computation) {
   std::unique_ptr<Graph> graph;
-  ConvertVarHandlesToAotVarHandles(&graph_def);
+  TF_RETURN_IF_ERROR(ConvertVarHandlesToAotVarHandles(&graph_def));
   TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
   TF_RETURN_IF_ERROR(
       ConvertGraphToXla(std::move(graph), config, client, computation));
diff --git a/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt b/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt
index 0a4caa06bdb2f9f92ed43f4d4658c7101e622885..39606a071845a0093dcdc3be7d587a65e38eae7d 100644
--- a/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt
+++ b/tensorflow/core/api_def/base_api/api_def_VarHandleOp.pbtxt
@@ -23,6 +23,13 @@ END
   attr {
     name: "shape"
     description: <<END
 The shape of this variable.
+END
+  }
+  attr {
+    name: "allowed_devices"
+    description: <<END
+The allowed devices containing the resource variable. Set when the output
+ResourceHandle represents a per-replica/partitioned resource variable.
 END
   }
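For orientation, here is a standalone sketch of the per-node rewrite above. It is not part of the patch; it assumes TensorFlow's NodeDef/AttrValue protos, reduces Status to a bool, and the constant's string value is an assumption (the real kXlaAotOnlyVarHandleOp lives in the tfcompile namespace).

```cpp
#include <string>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"

// Assumed stand-in for tfcompile::kXlaAotOnlyVarHandleOp.
constexpr char kXlaAotOnlyVarHandleOp[] = "_XlaAotOnlyVarHandleOp";

// Mirrors the lambda in ConvertVarHandlesToAotVarHandles: rewrite the op name
// and reject a non-empty "allowed_devices" list, since AOT compilation cannot
// honor per-replica placement.
bool UpdateVarHandleOpNode(tensorflow::NodeDef& node) {
  if (node.op() != "VarHandleOp") return true;
  node.set_op(kXlaAotOnlyVarHandleOp);
  auto it = node.attr().find("allowed_devices");
  if (it == node.attr().end()) return true;
  if (!it->second.list().s().empty()) return false;  // unsupported for AOT
  node.mutable_attr()->erase("allowed_devices");
  return true;
}
```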
diff --git a/tensorflow/core/common_runtime/eager/execute.cc b/tensorflow/core/common_runtime/eager/execute.cc
--- a/tensorflow/core/common_runtime/eager/execute.cc
+++ b/tensorflow/core/common_runtime/eager/execute.cc
@@ ... @@
-    std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
-    TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
-        &resource_dtypes_and_shapes));
-    if (!resource_dtypes_and_shapes.empty()) {
+    TensorHandle::ResourceHandleInfo resource_handle_info;
+    TF_RETURN_IF_ERROR(input->GetResourceHandleInfo(&resource_handle_info));
+    std::vector<DtypeAndPartialTensorShape>* resource_dtypes_and_shapes =
+        &resource_handle_info.dtypes_and_shapes;
+    if (!resource_dtypes_and_shapes->empty()) {
       const DtypeAndPartialTensorShape& dtype_and_shape =
-          resource_dtypes_and_shapes.at(0);
+          resource_dtypes_and_shapes->at(0);
       input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
 
       // Add _Arg index, dtype and shape to "cache_key".
@@ -629,8 +630,13 @@ Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
   TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
   const AttrValue* shape;
   TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
-  retvals[0]->SetResourceHandleDtypeAndShape(
-      {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
+  TensorHandle::ResourceHandleInfo resource_handle_info = {
+      {DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}, {}};
+  // "allowed_devices" is set only when the output represents a
+  // per-replica/partitioned resource variable.
+  TryGetNodeAttr(attr_slice, "allowed_devices",
+                 &resource_handle_info.allowed_devices);
+  retvals[0]->SetResourceHandleInfo(std::move(resource_handle_info));
   }
   return Status::OK();
 }
@@ -856,6 +862,19 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
     // is a resource we must pin it to prevent different device selection.
     // TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
     if (resource_device != op_device || op->Device() == nullptr) {
+      std::vector<string> allowed_devices;
+      TF_RETURN_IF_ERROR(
+          tensor_handle->GetResourceAllowedDevices(&allowed_devices));
+      if (!allowed_devices.empty()) {
+        // TODO(b/145922293): Support allowed_devices specified in wildcard
+        // patterns.
+        std::vector<string> device_names;
+        if (std::find(allowed_devices.begin(), allowed_devices.end(),
+                      op->GetDeviceName()) != allowed_devices.end()) {
+          TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(
+              op->GetDeviceName().c_str(), &resource_device));
+        }
+      }
       DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
                << "device of operation " << op->Name() << " to "
                << resource_device->name() << " because input #" << i
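In plain terms, the MaybeUpdateOpDevice change relaxes resource pinning: an op consuming a resource is normally pinned to the resource's device, but a handle that lists allowed_devices may be served from any of them. A standalone sketch of the decision, with plain strings standing in for TF's device types:

```cpp
#include <algorithm>
#include <string>
#include <vector>

// Returns the device an op touching the resource should run on. When the
// handle carries a non-empty allowed-device list and the op's requested
// device is on it, the request wins; otherwise the op is pinned to the
// variable's device, as before this patch.
std::string PickResourceDevice(const std::string& resource_device,
                               const std::string& requested_device,
                               const std::vector<std::string>& allowed) {
  if (!allowed.empty() && std::find(allowed.begin(), allowed.end(),
                                    requested_device) != allowed.end()) {
    return requested_device;  // per-replica handle: honor the placement
  }
  return resource_device;  // default pinning behavior
}
```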
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.cc b/tensorflow/core/common_runtime/eager/tensor_handle.cc
index 9b626d18b39198eb1164fca678dbeb2cf275553e..fe93b647a7c3c2f5d394a3640c70901052857246 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.cc
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.cc
@@ -62,13 +62,13 @@ const int32 kInvalidOutputNum = -1;
 #endif
 }  // namespace
 
-void TensorHandle::SetResourceHandleDtypeAndShape(
-    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
-  handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
+void TensorHandle::SetResourceHandleInfo(
+    ResourceHandleInfo resource_handle_info) {
+  resource_handle_info_ = std::move(resource_handle_info);
 }
 
-Status TensorHandle::GetResourceHandleDtypesAndShapes(
-    std::vector<DtypeAndPartialTensorShape>* result) {
+Status TensorHandle::GetResourceHandleInfoImpl(
+    std::function<void()> set_resource_info) {
   if (dtype != DT_RESOURCE) {
     return errors::InvalidArgument(
         "TensorHandle::GetResourceDtypeAndShape should be called on tensor "
@@ -77,7 +77,7 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
   }
 
   if (IsRemote()) {
-    *result = handle_dtypes_and_shapes_;
+    set_resource_info();
     return Status::OK();
   }
 
@@ -88,10 +88,32 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
   TF_RETURN_IF_ERROR(
       WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
 
-  *result = handle_dtypes_and_shapes_;
+  set_resource_info();
   return Status::OK();
 }
 
+Status TensorHandle::GetResourceHandleInfo(ResourceHandleInfo* result) {
+  auto get_resource_info = [result, this]() {
+    *result = resource_handle_info_;
+  };
+  return GetResourceHandleInfoImpl(get_resource_info);
+}
+
+Status TensorHandle::GetResourceHandleDtypesAndShapes(
+    std::vector<DtypeAndPartialTensorShape>* result) {
+  auto get_resource_info = [result, this]() {
+    *result = resource_handle_info_.dtypes_and_shapes;
+  };
+  return GetResourceHandleInfoImpl(get_resource_info);
+}
+
+Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
+  auto get_resource_info = [result, this]() {
+    *result = resource_handle_info_.allowed_devices;
+  };
+  return GetResourceHandleInfoImpl(get_resource_info);
+}
+
 Status TensorHandle::CreateLocalHandle(const class Tensor& t,
                                        TensorHandle** h) {
   // TODO(b/136608821): Move away from nullptr
@@ -165,7 +187,8 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
       is_async_(false),
       implicit_mirroring_(false),
       is_ready_(true),
-      handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
+      resource_handle_info_({resource_handle.dtypes_and_shapes(),
+                             resource_handle.allowed_devices()}),
       tensor_handle_data_(std::move(t)) {
   DVLOG(3) << "Creating Local TensorHandle: " << this
            << " device: " << VariantDeviceDebugString(device_);
@@ -681,7 +704,8 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
 
     if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
       auto& resource_handle = tensor.flat<ResourceHandle>()(0);
-      handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
+      resource_handle_info_ = {resource_handle.dtypes_and_shapes(),
+                               resource_handle.allowed_devices()};
     }
     tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
    if (is_async_) {
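The refactor above funnels all preconditions through one helper, GetResourceHandleInfoImpl, which takes a closure that copies out just the field the caller wants. A simplified model of the pattern, with the readiness checks and TF types reduced to standard-library stand-ins:

```cpp
#include <functional>
#include <string>
#include <vector>

struct ResourceHandleInfo {
  std::vector<std::string> allowed_devices;  // dtypes/shapes omitted here
};

class HandleModel {
 public:
  // One place owns the preconditions (in TF: dtype == DT_RESOURCE, the remote
  // fast path, WaitReady); on success the closure snapshots the wanted field.
  bool GetResourceHandleInfoImpl(const std::function<void()>& set_info) {
    if (!ready_) return false;  // stand-in for the dtype/WaitReady checks
    set_info();
    return true;
  }

  bool GetResourceAllowedDevices(std::vector<std::string>* result) {
    return GetResourceHandleInfoImpl(
        [result, this]() { *result = info_.allowed_devices; });
  }

 private:
  bool ready_ = true;
  ResourceHandleInfo info_;
};
```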
diff --git a/tensorflow/core/common_runtime/eager/tensor_handle.h b/tensorflow/core/common_runtime/eager/tensor_handle.h
index dd6171d1ee01b139e6fb7721228fd82892f4f538..04b0091aa99beae542e8c04dd4dee40ce9858d2e 100644
--- a/tensorflow/core/common_runtime/eager/tensor_handle.h
+++ b/tensorflow/core/common_runtime/eager/tensor_handle.h
@@ -227,13 +227,19 @@ class TensorHandle : public core::RefCounted {
 
   string DebugString() const;
 
-  void SetResourceHandleDtypeAndShape(
-      std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);
+  struct ResourceHandleInfo {
+    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
+    std::vector<string> allowed_devices;
+  };
+
+  void SetResourceHandleInfo(ResourceHandleInfo resource_handle_info);
 
   // If this TensorHandle is 1) a local tensor, and 2) a resource handle,
-  // return data types and shapes of the underlying resource.
+  // return data types, shapes and allowed devices of the underlying resource.
+  Status GetResourceHandleInfo(ResourceHandleInfo* result);
   Status GetResourceHandleDtypesAndShapes(
       std::vector<DtypeAndPartialTensorShape>* result);
+  Status GetResourceAllowedDevices(std::vector<string>* result);
 
  private:
   // The TensorHandleData can either represent a local or remote tensor handle.
@@ -247,6 +253,8 @@ class TensorHandle : public core::RefCounted {
   // done and the handle is "ready".
   Status WaitReady(const char* caller) const;
 
+  Status GetResourceHandleInfoImpl(std::function<void()> set_resource_info);
+
   // TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
   // This was expedient, but perhaps worth revisiting ('device_' should always
   // be a valid pointer?)
@@ -309,9 +317,9 @@ class TensorHandle : public core::RefCounted {
   bool is_ready_ GUARDED_BY(mu_);
 
   // If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
-  // refers to a remote resource handle, we store data types and shapes for
-  // the underlying resource.
-  std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
+  // refers to a remote resource handle, we store data types, shapes and
+  // allowed devices for the underlying resource.
+  ResourceHandleInfo resource_handle_info_;
 
   // Does not need synchronization because it can be accessed only after
   // WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
diff --git a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
index aefe86c654d03a3614b335616c3cddc2474cd60b..e68e5d46f3f681d00748d7248ea60e2cac7c6a9d 100644
--- a/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
+++ b/tensorflow/core/distributed_runtime/eager/remote_mgr.cc
@@ -173,22 +173,24 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
     remote_handle_data->ReleaseRemoteTensorHandle();
     TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
         std::move(remote_handle_data), in.dtype(), device, parent_, out));
-    std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
+    TensorHandle::ResourceHandleInfo resource_handle_info;
+    std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
+        &resource_handle_info.dtypes_and_shapes;
     if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
-                                  &dtypes_and_shapes)
+                                  dtypes_and_shapes)
             .ok()) {
       for (const auto& dtype_and_shape_proto :
            in.resource_dtypes_and_shapes()) {
-        dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
+        dtypes_and_shapes->push_back(DtypeAndPartialTensorShape{
             dtype_and_shape_proto.dtype(),
             TensorShape(dtype_and_shape_proto.shape())});
       }
       mutex_lock l(mirrored_resource_shape_mu_);
       mirrored_resource_shape_map_.emplace(
           RemoteTensorHandleInternal(in.op_id(), in.output_num()),
-          dtypes_and_shapes);
+          *dtypes_and_shapes);
     }
-    (*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
+    (*out)->SetResourceHandleInfo(std::move(resource_handle_info));
   }
   return Status::OK();
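Note that DeserializeRemoteTensorHandle only populates dtypes_and_shapes; allowed_devices stays empty on the receiving side, since the RemoteTensorHandle proto in this patch does not carry it. The cache-then-proto lookup it performs reduces to a familiar pattern, sketched here with simplified types:

```cpp
#include <cstdint>
#include <map>
#include <utility>
#include <vector>

// (op_id, output_num) identifies a remote tensor, as in
// RemoteTensorHandleInternal; the payload types are simplified stand-ins.
using Key = std::pair<int64_t, int32_t>;
using DtypeAndShape = std::pair<int, std::vector<int>>;

// First consult the mirrored-shape cache; on a miss, adopt the shapes decoded
// from the wire proto and remember them for later deserializations.
std::vector<DtypeAndShape> MirroredShapes(
    std::map<Key, std::vector<DtypeAndShape>>* cache, const Key& key,
    const std::vector<DtypeAndShape>& from_proto) {
  auto it = cache->find(key);
  if (it != cache->end()) return it->second;  // cached mirror
  (*cache)[key] = from_proto;
  return from_proto;
}
```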
diff --git a/tensorflow/core/framework/resource_mgr.h b/tensorflow/core/framework/resource_mgr.h
index a22ddc1f0ea03834f3642ff5f18923161d6ebd57..9fe3ba02e3f5000ff085693d2e4edfdecf9a5d10 100644
--- a/tensorflow/core/framework/resource_mgr.h
+++ b/tensorflow/core/framework/resource_mgr.h
@@ -314,11 +314,13 @@ ResourceHandle MakeResourceHandle(
 template <typename T>
 ResourceHandle MakeResourceHandle(
     OpKernelConstruction* ctx, const string& container, const string& name,
-    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
-  return MakeResourceHandle(
-      container.empty() ? ctx->resource_manager()->default_container()
-                        : container,
-      name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
+    const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
+    const std::vector<string>& allowed_devices = {}) {
+  return MakeResourceHandle(container.empty()
+                                ? ctx->resource_manager()->default_container()
+                                : container,
+                            name, *ctx->device(), MakeTypeIndex<T>(),
+                            dtypes_and_shapes, allowed_devices);
 }
 
 Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc
index 80ca00388ffa4cbd62fb2cbe436d389bb5c9e7a0..351d9b580a85380f1b7e493bb693fec6260f3250 100644
--- a/tensorflow/core/kernels/resource_variable_ops.cc
+++ b/tensorflow/core/kernels/resource_variable_ops.cc
@@ -229,6 +229,8 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
   OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
   PartialTensorShape shape;
   OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
+  OP_REQUIRES_OK(context,
+                 context->GetAttr("allowed_devices", &allowed_devices_));
 
   is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
 
@@ -239,7 +241,8 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
                                 &resource_, attr));
     resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
         context, container_, name_,
-        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
+        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
+        allowed_devices_);
   }
 }
 
@@ -252,7 +255,8 @@ void VarHandleOp::Compute(OpKernelContext* ctx) {
         ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
     handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
         ctx, container_, name_,
-        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
+        std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
+        allowed_devices_);
     ctx->set_output(0, handle);
   } else {
     ctx->set_output(0, resource_);
diff --git a/tensorflow/core/kernels/resource_variable_ops.h b/tensorflow/core/kernels/resource_variable_ops.h
index 1bb70b537c16d9ec8d206a41aa248cbf5841659a..5935fa91d212d27c2963039f5838a4b14e21bf2e 100644
--- a/tensorflow/core/kernels/resource_variable_ops.h
+++ b/tensorflow/core/kernels/resource_variable_ops.h
@@ -36,6 +36,10 @@ class VarHandleOp : public OpKernel {
 
   Tensor resource_;
   DtypeAndPartialTensorShape dtype_and_shape_;
+
+  // A set of devices containing the resource variable. Set when the output
+  // ResourceHandle represents a per-replica/partitioned resource variable.
+  std::vector<string> allowed_devices_;
 };
 
 class ReadVariableOp : public OpKernel {
diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc
index 696a69eff809cd1b2bd29d73b05c7d55a4e463c6..77ab5f604c8f947b5328eb2c6fd69bf483d0c0f5 100644
--- a/tensorflow/core/ops/resource_variable_ops.cc
+++ b/tensorflow/core/ops/resource_variable_ops.cc
@@ -80,6 +80,7 @@ REGISTER_OP("VarHandleOp")
     .Attr("shared_name: string = ''")
     .Attr("dtype: type")
     .Attr("shape: shape")
+    .Attr("allowed_devices: list(string) = []")
     .Output("resource: resource")
     .SetIsStateful()
     .SetShapeFn([](InferenceContext* c) {
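Registration, kernel state, and handle construction now line up: the list(string) attr defaults to the empty list, so existing graphs are unaffected, and VarHandleOp threads the list into MakeResourceHandle. A hedged sketch of calling the extended overload from a kernel constructor; the container, name, and device strings are illustrative only:

```cpp
#include <vector>

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/framework/resource_var.h"

tensorflow::ResourceHandle MakePerReplicaHandle(
    tensorflow::OpKernelConstruction* context,
    const tensorflow::DtypeAndPartialTensorShape& dtype_and_shape) {
  // The trailing argument is the new allowed_devices parameter; with its
  // default {} this call would behave exactly as before the patch.
  return tensorflow::MakeResourceHandle<tensorflow::Var>(
      context, /*container=*/"", /*name=*/"v",
      std::vector<tensorflow::DtypeAndPartialTensorShape>{dtype_and_shape},
      /*allowed_devices=*/{"/job:localhost/replica:0/task:0/device:CPU:0",
                           "/job:localhost/replica:0/task:0/device:CPU:1"});
}
```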
diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
index f20e54d18a556064b11878539bec98ca71e5a568..cbd8f6a2ebea0ae38d78c6894bffe4fde50d1911 100644
--- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py
+++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py
@@ -30,6 +30,7 @@ from tensorflow.core.framework import tensor_pb2
 from tensorflow.python.eager import backprop
 from tensorflow.python.eager import context
 from tensorflow.python.eager import def_function
+from tensorflow.python.framework import config
 from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import cpp_shape_inference_pb2
 from tensorflow.python.framework import dtypes
@@ -1488,5 +1489,40 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
     self.assertAllEqual(expected, result)
 
 
+class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
+
+  def setUp(self):
+    super(PerReplicaResourceHandleTest, self).setUp()
+    cpus = config.list_physical_devices("CPU")
+    # Set 2 virtual CPUs
+    config.set_logical_device_configuration(cpus[0], [
+        context.LogicalDeviceConfiguration(),
+        context.LogicalDeviceConfiguration(),
+    ])
+
+  def testAllowedDevices(self):
+    device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
+    device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
+    value0 = 1
+    value1 = 2
+    with context.eager_mode():
+      handle = resource_variable_ops.var_handle_op(
+          dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1])
+      with ops.device(device0):
+        assign0 = resource_variable_ops.assign_variable_op(handle, value0)
+      with ops.device(device1):
+        assign1 = resource_variable_ops.assign_variable_op(handle, value1)
+      with ops.control_dependencies([assign0, assign1]):
+        with ops.device(device0):
+          read0 = resource_variable_ops.read_variable_op(
+              handle, dtype=dtypes.int32)
+        with ops.device(device1):
+          read1 = resource_variable_ops.read_variable_op(
+              handle, dtype=dtypes.int32)
+
+      self.assertAllEqual(value0, read0)
+      self.assertAllEqual(value1, read1)
+
+
 if __name__ == "__main__":
   test.main()
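The test drives two assignments and two reads through a single handle on two virtual CPUs and checks that each device keeps its own value. The golden-file updates below merely record the regenerated raw-op signature; the new argument's default is the registered empty list, which can be confirmed from the OpDef. A sketch, assuming the usual registry lookup:

```cpp
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_def.pb.h"

// Looks up VarHandleOp's registered OpDef and reports whether the
// allowed_devices attr carries a default (the empty list added above).
bool AllowedDevicesHasDefault() {
  const tensorflow::OpDef* op_def = nullptr;
  if (!tensorflow::OpRegistry::Global()
           ->LookUpOpDef("VarHandleOp", &op_def)
           .ok()) {
    return false;
  }
  for (const auto& attr : op_def->attr()) {
    if (attr.name() == "allowed_devices") return attr.has_default_value();
  }
  return false;
}
```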
"VarIsInitializedOp" diff --git a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt index 9b68caecf9a21a55d99a216387ddce19c9425f22..07c3fa8ef08dfbbe76871d0b861b62ddef2250d1 100644 --- a/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt +++ b/tensorflow/tools/api/golden/v2/tensorflow.raw_ops.pbtxt @@ -4926,7 +4926,7 @@ tf_module { } member_method { name: "VarHandleOp" - argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], " + argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'allowed_devices\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'[]\', \'None\'], " } member_method { name: "VarIsInitializedOp"