Commit ed0e46c2 authored by Yujing Zhang, committed by TensorFlower Gardener

Add attribute allowed_devices to VarHandleOp.

Support per-replica ResourceHandle in eager op-by-op mode.

PiperOrigin-RevId: 295059587
Change-Id: I69e0ada1418092bfb319d6a5530f69ccd78777b1
Parent 362914be
......@@ -35,6 +35,7 @@ limitations under the License.
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/versions.pb.h"
......@@ -42,6 +43,7 @@ limitations under the License.
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/dump_graph.h"
......@@ -128,19 +130,31 @@ Status ConvertGraphToXla(std::unique_ptr<Graph> graph,
return Status::OK();
}
void ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
for (auto& node : *graph_def->mutable_node()) {
Status ConvertVarHandlesToAotVarHandles(GraphDef* graph_def) {
auto update_var_handle_op_node = [](NodeDef& node) -> Status {
if (node.op() == "VarHandleOp") {
node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
const auto& it = node.attr().find("allowed_devices");
if (it != node.attr().end()) {
if (!it->second.list().s().empty()) {
// TODO(b/149512838): Support non-empty allowed devices.
return errors::InvalidArgument(
"VarHandleOp with non-empty allowed devices is not supported.");
}
node.mutable_attr()->erase("allowed_devices");
}
}
return Status::OK();
};
for (auto& node : *graph_def->mutable_node()) {
TF_RETURN_IF_ERROR(update_var_handle_op_node(node));
}
for (auto& fn : *graph_def->mutable_library()->mutable_function()) {
for (auto& node : *fn.mutable_node_def()) {
if (node.op() == "VarHandleOp") {
node.set_op(tfcompile::kXlaAotOnlyVarHandleOp);
}
TF_RETURN_IF_ERROR(update_var_handle_op_node(node));
}
}
return Status::OK();
}
} // namespace
......@@ -149,7 +163,7 @@ Status ConvertGraphDefToXla(GraphDef graph_def, const tf2xla::Config& config,
xla::Client* client,
xla::XlaComputation* computation) {
std::unique_ptr<Graph> graph;
ConvertVarHandlesToAotVarHandles(&graph_def);
TF_RETURN_IF_ERROR(ConvertVarHandlesToAotVarHandles(&graph_def));
TF_RETURN_IF_ERROR(InitGraph(graph_def, config, &graph));
TF_RETURN_IF_ERROR(
ConvertGraphToXla(std::move(graph), config, client, computation));
......
......@@ -23,6 +23,13 @@ END
name: "shape"
description: <<END
The (possibly partially specified) shape of this variable.
END
}
attr {
name: "allowed_devices"
description: <<END
The set of devices permitted to contain the resource variable. Set only when
the output ResourceHandle represents a per-replica or partitioned resource
variable.
END
}
summary: "Creates a handle to a Variable resource."
......
......@@ -420,12 +420,13 @@ Status EagerLocalExecute(EagerOperation* op, TensorHandle** retvals,
// looking it up in ResourceMgr, which is slow). So we just get
// resource_dtypes_and_shapes for all DT_RESOURCE inputs. If
// resource_dtypes_and_shapes is not empty, take the first element.
std::vector<DtypeAndPartialTensorShape> resource_dtypes_and_shapes;
TF_RETURN_IF_ERROR(input->GetResourceHandleDtypesAndShapes(
&resource_dtypes_and_shapes));
if (!resource_dtypes_and_shapes.empty()) {
TensorHandle::ResourceHandleInfo resource_handle_info;
TF_RETURN_IF_ERROR(input->GetResourceHandleInfo(&resource_handle_info));
std::vector<DtypeAndPartialTensorShape>* resource_dtypes_and_shapes =
&resource_handle_info.dtypes_and_shapes;
if (!resource_dtypes_and_shapes->empty()) {
const DtypeAndPartialTensorShape& dtype_and_shape =
resource_dtypes_and_shapes.at(0);
resource_dtypes_and_shapes->at(0);
input_resource_variable_dtypes_and_shapes[i] = dtype_and_shape;
// Add _Arg index, dtype and shape to "cache_key".
......@@ -629,8 +630,13 @@ Status StoreResourceDtypesAndShapes(const eager::Operation& remote_op,
TF_RETURN_IF_ERROR(attr_slice.Find("dtype", &dtype));
const AttrValue* shape;
TF_RETURN_IF_ERROR(attr_slice.Find("shape", &shape));
retvals[0]->SetResourceHandleDtypeAndShape(
{DtypeAndPartialTensorShape{dtype->type(), shape->shape()}});
TensorHandle::ResourceHandleInfo resource_handle_info = {
{DtypeAndPartialTensorShape{dtype->type(), shape->shape()}}, {}};
// "allowed_devices" is set only when the output represents a
// per-replica/partitioned resource variable.
TryGetNodeAttr(attr_slice, "allowed_devices",
&resource_handle_info.allowed_devices);
retvals[0]->SetResourceHandleInfo(std::move(resource_handle_info));
}
return Status::OK();
}
......@@ -856,6 +862,19 @@ Status MaybeUpdateOpDevice(EagerOperation* op) {
// is a resource we must pin it to prevent different device selection.
// TODO(iga): null device can mean "unspecified" or "CPU". Clean this up.
if (resource_device != op_device || op->Device() == nullptr) {
std::vector<string> allowed_devices;
TF_RETURN_IF_ERROR(
tensor_handle->GetResourceAllowedDevices(&allowed_devices));
if (!allowed_devices.empty()) {
// TODO(b/145922293): Support allowed_devices specified in wildcard
// patterns.
std::vector<string> device_names;
if (std::find(allowed_devices.begin(), allowed_devices.end(),
op->GetDeviceName()) != allowed_devices.end()) {
TF_RETURN_IF_ERROR(ctx.FindDeviceFromName(
op->GetDeviceName().c_str(), &resource_device));
}
}
DVLOG(1) << (resource_device != op_device ? "Changing " : "Setting ")
<< "device of operation " << op->Name() << " to "
<< resource_device->name() << " because input #" << i
......
......@@ -62,13 +62,13 @@ const int32 kInvalidOutputNum = -1;
#endif
} // namespace
void TensorHandle::SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes) {
handle_dtypes_and_shapes_ = std::move(dtypes_and_shapes);
void TensorHandle::SetResourceHandleInfo(
ResourceHandleInfo resource_handle_info) {
resource_handle_info_ = std::move(resource_handle_info);
}
Status TensorHandle::GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result) {
Status TensorHandle::GetResourceHandleInfoImpl(
std::function<void()> set_resource_info) {
if (dtype != DT_RESOURCE) {
return errors::InvalidArgument(
"TensorHandle::GetResourceDtypeAndShape should be called on tensor "
......@@ -77,7 +77,7 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
}
if (IsRemote()) {
*result = handle_dtypes_and_shapes_;
set_resource_info();
return Status::OK();
}
......@@ -88,10 +88,32 @@ Status TensorHandle::GetResourceHandleDtypesAndShapes(
TF_RETURN_IF_ERROR(
WaitReady("TensorHandle::GetResourceHandleDtypesAndShapes"));
*result = handle_dtypes_and_shapes_;
set_resource_info();
return Status::OK();
}
// Copies out the full resource metadata (dtypes/shapes and allowed devices),
// going through the common validation/readiness path in
// GetResourceHandleInfoImpl.
Status TensorHandle::GetResourceHandleInfo(ResourceHandleInfo* result) {
  return GetResourceHandleInfoImpl(
      [result, this]() { *result = resource_handle_info_; });
}
// Copies out only the dtypes/shapes portion of the resource metadata, via
// the shared GetResourceHandleInfoImpl validation path.
Status TensorHandle::GetResourceHandleDtypesAndShapes(
    std::vector<DtypeAndPartialTensorShape>* result) {
  return GetResourceHandleInfoImpl(
      [result, this]() { *result = resource_handle_info_.dtypes_and_shapes; });
}
// Copies out only the allowed-devices portion of the resource metadata, via
// the shared GetResourceHandleInfoImpl validation path.
Status TensorHandle::GetResourceAllowedDevices(std::vector<string>* result) {
  return GetResourceHandleInfoImpl(
      [result, this]() { *result = resource_handle_info_.allowed_devices; });
}
Status TensorHandle::CreateLocalHandle(const class Tensor& t,
TensorHandle** h) {
// TODO(b/136608821): Move away from nullptr
......@@ -165,7 +187,8 @@ TensorHandle::TensorHandle(std::unique_ptr<LocalTensorHandleData> t,
is_async_(false),
implicit_mirroring_(false),
is_ready_(true),
handle_dtypes_and_shapes_(resource_handle.dtypes_and_shapes()),
resource_handle_info_({resource_handle.dtypes_and_shapes(),
resource_handle.allowed_devices()}),
tensor_handle_data_(std::move(t)) {
DVLOG(3) << "Creating Local TensorHandle: " << this
<< " device: " << VariantDeviceDebugString(device_);
......@@ -681,7 +704,8 @@ Status TensorHandle::SetTensor(tensorflow::Tensor&& tensor, const Device* d) {
if (tensor.dtype() == DT_RESOURCE && tensor.NumElements() > 0) {
auto& resource_handle = tensor.flat<class ResourceHandle>()(0);
handle_dtypes_and_shapes_ = resource_handle.dtypes_and_shapes();
resource_handle_info_ = {resource_handle.dtypes_and_shapes(),
resource_handle.allowed_devices()};
}
tensor_handle_data_ = absl::make_unique<LocalTensorHandleData>(tensor);
if (is_async_) {
......
......@@ -227,13 +227,19 @@ class TensorHandle : public core::RefCounted {
string DebugString() const;
void SetResourceHandleDtypeAndShape(
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes);
struct ResourceHandleInfo {
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
std::vector<string> allowed_devices;
};
void SetResourceHandleInfo(ResourceHandleInfo resource_handle_info);
// If this TensorHandle is 1) a local tensor, and 2) a resource handle,
// return data types and shapes of the underlying resource.
// return data types, shapes and allowed devices of the underlying resource.
Status GetResourceHandleInfo(ResourceHandleInfo* result);
Status GetResourceHandleDtypesAndShapes(
std::vector<DtypeAndPartialTensorShape>* result);
Status GetResourceAllowedDevices(std::vector<string>* result);
private:
// The TensorHandleData can either represent a local or remote tensor handle.
......@@ -247,6 +253,8 @@ class TensorHandle : public core::RefCounted {
// done and the handle is "ready".
Status WaitReady(const char* caller) const;
Status GetResourceHandleInfoImpl(std::function<void()> set_resource_info);
// TODO(b/136608821): device_ == nullptr (Device*) iff Host CPU:0
// This was expedient, but perhaps worth revisiting ('device_' should always
// be a valid pointer?)
......@@ -309,9 +317,9 @@ class TensorHandle : public core::RefCounted {
bool is_ready_ GUARDED_BY(mu_);
// If this TensorHandle 1) is a local tensor, and 2) is a resource handle or
// refers to a remote resource handle, we store data types and shapes for
// the underlying resource.
std::vector<DtypeAndPartialTensorShape> handle_dtypes_and_shapes_;
// refers to a remote resource handle, we store data types, shapes and allowed
// devices for the underlying resource.
ResourceHandleInfo resource_handle_info_;
// Does not need synchronization because it can be accessed only after
// WaitReady() has returned. At that point, tensor_handle_data_ is immutable.
......
......@@ -173,22 +173,24 @@ Status RemoteMgr::DeserializeRemoteTensorHandle(const RemoteTensorHandle& in,
remote_handle_data->ReleaseRemoteTensorHandle();
TF_RETURN_IF_ERROR(TensorHandle::CreateUnshapedRemoteHandle(
std::move(remote_handle_data), in.dtype(), device, parent_, out));
std::vector<DtypeAndPartialTensorShape> dtypes_and_shapes;
TensorHandle::ResourceHandleInfo resource_handle_info;
std::vector<DtypeAndPartialTensorShape>* dtypes_and_shapes =
&resource_handle_info.dtypes_and_shapes;
if (!GetMirroredResourceShape(RemoteTensorHandleInternal(in),
&dtypes_and_shapes)
dtypes_and_shapes)
.ok()) {
for (const auto& dtype_and_shape_proto :
in.resource_dtypes_and_shapes()) {
dtypes_and_shapes.push_back(DtypeAndPartialTensorShape{
dtypes_and_shapes->push_back(DtypeAndPartialTensorShape{
dtype_and_shape_proto.dtype(),
TensorShape(dtype_and_shape_proto.shape())});
}
mutex_lock l(mirrored_resource_shape_mu_);
mirrored_resource_shape_map_.emplace(
RemoteTensorHandleInternal(in.op_id(), in.output_num()),
dtypes_and_shapes);
*dtypes_and_shapes);
}
(*out)->SetResourceHandleDtypeAndShape(std::move(dtypes_and_shapes));
(*out)->SetResourceHandleInfo(std::move(resource_handle_info));
}
return Status::OK();
......
......@@ -314,11 +314,13 @@ ResourceHandle MakeResourceHandle(
template <typename T>
ResourceHandle MakeResourceHandle(
OpKernelConstruction* ctx, const string& container, const string& name,
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {}) {
return MakeResourceHandle(
container.empty() ? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
const std::vector<DtypeAndPartialTensorShape>& dtypes_and_shapes = {},
const std::vector<string>& allowed_devices = {}) {
return MakeResourceHandle(container.empty()
? ctx->resource_manager()->default_container()
: container,
name, *ctx->device(), MakeTypeIndex<T>(),
dtypes_and_shapes, allowed_devices);
}
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
......
......@@ -229,6 +229,8 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("dtype", &dtype_and_shape_.dtype));
PartialTensorShape shape;
OP_REQUIRES_OK(context, context->GetAttr("shape", &dtype_and_shape_.shape));
OP_REQUIRES_OK(context,
context->GetAttr("allowed_devices", &allowed_devices_));
is_anonymous_ = name_ == ResourceHandle::ANONYMOUS_NAME;
......@@ -239,7 +241,8 @@ VarHandleOp::VarHandleOp(OpKernelConstruction* context) : OpKernel(context) {
&resource_, attr));
resource_.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
context, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
allowed_devices_);
}
}
......@@ -252,7 +255,8 @@ void VarHandleOp::Compute(OpKernelContext* ctx) {
ctx, ctx->allocate_temp(DT_RESOURCE, TensorShape({}), &handle, attr));
handle.scalar<ResourceHandle>()() = MakeResourceHandle<Var>(
ctx, container_, name_,
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_});
std::vector<DtypeAndPartialTensorShape>{dtype_and_shape_},
allowed_devices_);
ctx->set_output(0, handle);
} else {
ctx->set_output(0, resource_);
......
......@@ -36,6 +36,10 @@ class VarHandleOp : public OpKernel {
Tensor resource_;
DtypeAndPartialTensorShape dtype_and_shape_;
// A set of devices containing the resource variable. Set when the output
// ResourceHandle represents a per-replica/partitioned resource variable.
std::vector<string> allowed_devices_;
};
class ReadVariableOp : public OpKernel {
......
......@@ -80,6 +80,7 @@ REGISTER_OP("VarHandleOp")
.Attr("shared_name: string = ''")
.Attr("dtype: type")
.Attr("shape: shape")
.Attr("allowed_devices: list(string) = []")
.Output("resource: resource")
.SetIsStateful()
.SetShapeFn([](InferenceContext* c) {
......
......@@ -30,6 +30,7 @@ from tensorflow.core.framework import tensor_pb2
from tensorflow.python.eager import backprop
from tensorflow.python.eager import context
from tensorflow.python.eager import def_function
from tensorflow.python.framework import config
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import cpp_shape_inference_pb2
from tensorflow.python.framework import dtypes
......@@ -1488,5 +1489,40 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase,
self.assertAllEqual(expected, result)
class PerReplicaResourceHandleTest(test_util.TensorFlowTestCase):
  """Tests per-replica resource handles created with `allowed_devices`."""

  def setUp(self):
    super(PerReplicaResourceHandleTest, self).setUp()
    cpus = config.list_physical_devices("CPU")
    # Set 2 virtual CPUs so ops can be placed on two distinct devices.
    config.set_logical_device_configuration(cpus[0], [
        context.LogicalDeviceConfiguration(),
        context.LogicalDeviceConfiguration(),
    ])

  def testAllowedDevices(self):
    """Assigns and reads a per-replica variable on each allowed device.

    Each device should observe the value assigned on that device rather
    than a single shared value.
    """
    device0 = "/job:localhost/replica:0/task:0/device:CPU:0"
    device1 = "/job:localhost/replica:0/task:0/device:CPU:1"
    value0 = 1
    value1 = 2
    with context.eager_mode():
      handle = resource_variable_ops.var_handle_op(
          dtype=dtypes.int32, shape=[], allowed_devices=[device0, device1])
      with ops.device(device0):
        assign0 = resource_variable_ops.assign_variable_op(handle, value0)
      with ops.device(device1):
        assign1 = resource_variable_ops.assign_variable_op(handle, value1)
      with ops.control_dependencies([assign0, assign1]):
        with ops.device(device0):
          read0 = resource_variable_ops.read_variable_op(
              handle, dtype=dtypes.int32)
        with ops.device(device1):
          read1 = resource_variable_ops.read_variable_op(
              handle, dtype=dtypes.int32)
      self.assertAllEqual(value0, read0)
      self.assertAllEqual(value1, read1)
if __name__ == "__main__":
test.main()
......@@ -4926,7 +4926,7 @@ tf_module {
}
member_method {
name: "VarHandleOp"
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'allowed_devices\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'[]\', \'None\'], "
}
member_method {
name: "VarIsInitializedOp"
......
......@@ -4926,7 +4926,7 @@ tf_module {
}
member_method {
name: "VarHandleOp"
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'None\'], "
argspec: "args=[\'dtype\', \'shape\', \'container\', \'shared_name\', \'allowed_devices\', \'name\'], varargs=None, keywords=None, defaults=[\'\', \'\', \'[]\', \'None\'], "
}
member_method {
name: "VarIsInitializedOp"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册