diff --git a/tensorflow/compiler/jit/xla_device_context.cc b/tensorflow/compiler/jit/xla_device_context.cc index 6e6532731e64bd42ee56aa719748988f321e0f17..1f3afe8822d441a5ce37617fe18d7767e9bc72e4 100644 --- a/tensorflow/compiler/jit/xla_device_context.cc +++ b/tensorflow/compiler/jit/xla_device_context.cc @@ -79,6 +79,13 @@ XlaDeviceContext::XlaDeviceContext( } } +void XlaDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("XLA->XLA same-device copies not implemented.")); +} + void XlaDeviceContext::CopyCPUTensorToDevice(const Tensor* cpu_tensor, Device* device, Tensor* device_tensor, diff --git a/tensorflow/compiler/jit/xla_device_context.h b/tensorflow/compiler/jit/xla_device_context.h index 1e18df197a2dd65590c5181b4dae4481dca36641..e45db989fac720df6c3458c93a6b8dbb0919f930 100644 --- a/tensorflow/compiler/jit/xla_device_context.h +++ b/tensorflow/compiler/jit/xla_device_context.h @@ -62,6 +62,9 @@ class XlaDeviceContext : public DeviceContext { void CopyDeviceTensorToCPU(const Tensor* device_tensor, absl::string_view tensor_name, Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; xla::LocalClient* client() const { return client_; } se::Stream* stream() const { return stream_.get(); } diff --git a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc index 4bc88ffc8c3950176ae05f32c774f2f2971a4e34..0ef39fb3d78044a8611b315afbdeb4975a3af15f 100644 --- a/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc +++ b/tensorflow/core/common_runtime/gpu/gpu_util_platform_specific.cc @@ -37,6 +37,14 @@ void GPUDeviceContext::CopyDeviceTensorToCPU(const Tensor* device_tensor, GPUUtil::CopyGPUTensorToCPU(device, this, device_tensor, cpu_tensor, done); } +void GPUDeviceContext::CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, + Tensor* output_tensor, + StatusCallback done) const { + GPUUtil::CopyGPUTensorToSameGPU(device, this, input_tensor, output_tensor, + done); +} + Status GPUDeviceContext::ThenExecute(Device* device, se::Stream* stream, std::function func) { const DeviceBase::GpuDeviceInfo* gpu_info = diff --git a/tensorflow/core/common_runtime/gpu_device_context.h b/tensorflow/core/common_runtime/gpu_device_context.h index 3603808152748009f29d1d01f0eeee0dd8b6ab0e..f5135267241db94a0afdd9845b09dbfdda242ecc 100644 --- a/tensorflow/core/common_runtime/gpu_device_context.h +++ b/tensorflow/core/common_runtime/gpu_device_context.h @@ -57,6 +57,10 @@ class GPUDeviceContext : public DeviceContext { Device* device, Tensor* cpu_tensor, StatusCallback done) override; + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override; + void MaintainLifetimeOnStream(const Tensor* t, se::Stream* stream) const override {} diff --git a/tensorflow/core/framework/device_base.h b/tensorflow/core/framework/device_base.h index 446c31b17f2904da3143438304d6407bd65c450c..321947aca8e06008c3291fa43befa389b53f998c 100644 --- a/tensorflow/core/framework/device_base.h +++ b/tensorflow/core/framework/device_base.h @@ -82,6 +82,13 @@ class DeviceContext : public core::RefCounted { done(errors::Internal("Unrecognized device type in CPU-to-device Copy")); } + // Copies a tensor in this 
device. + virtual void CopyTensorInSameDevice(const Tensor* input_tensor, + Device* device, Tensor* output_tensor, + StatusCallback done) const { + done(errors::Unimplemented("Copy in same device not implemented.")); + } + // "device_tensor" is a tensor on a non-CPU device. Copies // device_tensor into "cpu_tensor". "cpu_tensor" must be allocated // to be of the same size as "device_tensor". diff --git a/tensorflow/core/framework/rendezvous_test.cc b/tensorflow/core/framework/rendezvous_test.cc index de148f0bd3474421c1361cf7ae4aa681107aa883..7a777f064c7b517de9f9c1c14648e5ff32ca4b5e 100644 --- a/tensorflow/core/framework/rendezvous_test.cc +++ b/tensorflow/core/framework/rendezvous_test.cc @@ -278,6 +278,12 @@ class DummyDeviceContext : public DeviceContext { ~DummyDeviceContext() override {} int stream_id() const { return stream_id_; } + void CopyTensorInSameDevice(const Tensor* input_tensor, Device* device, + Tensor* output_tensor, + StatusCallback done) const override { + done(Status::OK()); + } + private: const int stream_id_; }; diff --git a/tensorflow/core/framework/resource_var.h b/tensorflow/core/framework/resource_var.h index ff7b3e78a711a717d44e1e2ca307d6fef05243d9..f5de5dba8854adcfd5b94447da3ba42566a26bd8 100644 --- a/tensorflow/core/framework/resource_var.h +++ b/tensorflow/core/framework/resource_var.h @@ -20,14 +20,46 @@ limitations under the License. namespace tensorflow { -// Resource stored by variables in the resource manager -// (new, resource-style version). +// Resource stored by variables in the resource manager (new, resource-style +// version). +// +// These variables have a mixed access mode: they can operate on copy-on-write +// mode (the default) or copy-on-read mode (used only for sparse access). +// +// When copy-on-write mode is enabled reading the value of the variable involves +// grabbing its mutex in shared mode and aliasing the internal tensor as the +// output of the read operation, increasing its reference count. Writing, +// conversely, works by, under an exclusive lock, detecting whether there are +// outstanding aliases of the tensor, using the reference count, copying the +// tensor if they exist, and writing to either the original or a copy with no +// outstanding aliases. Sparse operations are not supported in copy-on-write +// mode. +// +// When a variable is accessed sparsely it switches to copy-on-read mode. To +// switch we need to grab an exclusive lock and might (if there are aliases) +// need to copy the entire tensor. Once copy-on-read mode is enabled, no tensor +// is allowed to alias the variable's internal tensor. This means dense reads +// must return a copy of the variable, done while holding a shared lock. Dense +// writes do not need to check whether aliases exist, and can always write +// directly to the buffer without making a copy, while holding an exclusive +// lock. Sparse reads and sparse writes, on the other hand, can be done under a +// shared or exclusive mutex (the damage from writes under a shared mutex is +// limited since no other buffer is allowed to alias the variable's +// buffer). Using an exclusive mutex disallows concurrent writes and concurrent +// sparse reads, providing some extra safety at the expense of performance, +// while shared mutex allow for "hogwild" behavior. Doing sparse writes under a +// shared mutex prevents them from overlapping with dense writes, which is +// necessary as dense writes can change the shape the of the tensor. 
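
[Reviewer note] To make the access-mode rules in the comment above concrete, here is a minimal standalone sketch of the two modes. It is an illustrative model only, not the TensorFlow implementation: `ToyVar` and its members are invented for this example, and `std::shared_ptr::use_count()` stands in for the tensor buffer's reference count.

#include <atomic>
#include <memory>
#include <mutex>
#include <shared_mutex>
#include <vector>

// Illustrative model of the access modes described above; not TensorFlow code.
class ToyVar {
 public:
  explicit ToyVar(std::vector<float> v)
      : buf_(std::make_shared<std::vector<float>>(std::move(v))) {}

  // Dense read. In copy-on-write mode the internal buffer is aliased under a
  // shared lock; in copy-on-read mode a copy is returned so no alias escapes.
  std::shared_ptr<const std::vector<float>> Read() {
    std::shared_lock<std::shared_mutex> l(mu_);
    if (copy_on_read_.load()) {
      return std::make_shared<const std::vector<float>>(*buf_);
    }
    return buf_;  // Aliases the buffer; its use_count goes up.
  }

  // Dense write, under an exclusive lock. In copy-on-write mode, copy first if
  // some outstanding read still aliases the buffer; in copy-on-read mode no
  // alias can exist, so the write always goes directly to the buffer.
  void Write(const std::vector<float>& v) {
    std::unique_lock<std::shared_mutex> l(mu_);
    if (!copy_on_read_.load() && buf_.use_count() > 1) {
      buf_ = std::make_shared<std::vector<float>>(*buf_);
    }
    *buf_ = v;
  }

  // Sparse update: first switch to copy-on-read mode. An exclusive lock is
  // used here; the scheme described above also permits a shared lock
  // ("hogwild") at the cost of racy element updates.
  void SparseAdd(std::size_t i, float delta) {
    EnsureCopyOnRead();
    std::unique_lock<std::shared_mutex> l(mu_);
    (*buf_)[i] += delta;
  }

 private:
  // Fast path on the atomic flag, then copy away any aliases under the lock.
  void EnsureCopyOnRead() {
    if (copy_on_read_.load()) return;
    std::unique_lock<std::shared_mutex> l(mu_);
    if (buf_.use_count() > 1) {
      buf_ = std::make_shared<std::vector<float>>(*buf_);
    }
    copy_on_read_.store(true);
  }

  std::shared_mutex mu_;
  std::shared_ptr<std::vector<float>> buf_;
  std::atomic<bool> copy_on_read_{false};
};

Once a ToyVar has switched to copy-on-read mode it never switches back, matching the "transition not supported" caveat below.
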
+// +// Transitioning a variable from copy-on-read mode to copy-on-write mode is +// currently not supported. To upgrade a variable from copy-on-write to +// copy-on-read use `EnsureSparseVariableAccess()`, and then grab the variable's +// mutex as desired. To access the variable in dense mode grab the mutex either +// directly or via `MaybeLockVariableInputMutexesInOrder` on all variables being +// modified and then call `PrepareToUpdateVariable` on them in any order. class Var : public ResourceBase { public: explicit Var(DataType dtype) : tensor_(dtype) {} - // Not copyable or movable. - Var(const Var&) = delete; - Var& operator=(const Var&) = delete; // When locking multiple variables, the locks must be acquired in order of // increasing mu() address. @@ -48,11 +80,19 @@ class Var : public ResourceBase { bool is_initialized = false; // GUARDED_BY(mu_) but annotalysis doesn't like // it. + // Also fake-guarded by mu_. Should be set to True whenever any sparse + // operation uses the variable. Once this is true no tensor is allowed to + // alias the memory of the variable, and we always copy the variable on + // reads. This allows sparse operations to happen with only a shared lock if + // so desired. + std::atomic copy_on_read_mode{false}; + private: mutex mu_; Tensor tensor_; ~Var() override {} + TF_DISALLOW_COPY_AND_ASSIGN(Var); }; } // end namespace tensorflow diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 6e03cf9f6f47c89289ffaec507f56d8c734e52a9..009dd0846d2639eb9cf1ef47f8f12c10994dcb3b 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -45,6 +45,7 @@ class TensorBuffer; class TensorCApi; class TensorDescription; class TensorProto; +class Var; namespace batch_util { Status CopyElementToSlice(Tensor element, Tensor* parent, int64 index); @@ -581,11 +582,16 @@ class Tensor { friend class XlaTensor; // For access to RefCountIsOne(). friend class XlaTensorBuffer; // For access to the private constructor taking // the buffer + friend class Var; template friend class AssignVariableOp; // For access to RefCountIsOne(). template friend Status PrepareToUpdateVariable( - OpKernelContext* ctx, Tensor* tensor); // For access to RefCountIsOne(). + OpKernelContext* ctx, Tensor* tensor, + bool copy_on_read_mode); // For access to RefCountIsOne(). + template + friend Status EnsureSparseVariableAccess( + OpKernelContext* ctx, Var* var); // For access to RefCountIsOne(). friend Status batch_util::CopyElementToSlice( Tensor element, Tensor* parent, int64 index); // For access to RefCountIsOne(). diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 0e5d8d765a6bfde3a0e187c0b386174d3b20a098..e8b1dd270ff3053a906a1d4f29b632cd4635775f 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -2196,6 +2196,7 @@ tf_kernel_library( ":state", ":training_op_helpers", ":variable_ops", + "//tensorflow/core:core_cpu_lib", "//tensorflow/core:framework", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index 170b08b4b7f6c8a6842dd12ad7389900b2d83b86..4167b6005194409d780b3698fda688728a50b3cc 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -55,6 +55,7 @@ limitations under the License. 
#include #include "absl/strings/str_join.h" +#include "tensorflow/core/common_runtime/device.h" #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" @@ -84,6 +85,47 @@ ReadVariableOp::ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); } +namespace { +Status CopyVariable(int output_idx, OpKernelContext* ctx, const Tensor* t) { + Tensor* output; + Notification n; + Status status; + AllocatorAttributes attr; + if (t->dtype() == DT_VARIANT) { + attr.set_on_host(true); + } + TF_RETURN_IF_ERROR( + ctx->allocate_output(output_idx, t->shape(), &output, attr)); + if (t->dtype() == DT_VARIANT) { + output->flat() = t->flat(); + } else if (ctx->op_device_context() != nullptr) { + // TODO(apassos): remove the down_cast by just returning Device* from + // OpKernelContext + Device* device = static_cast(ctx->device()); + ctx->op_device_context()->CopyTensorInSameDevice( + t, device, output, [&n, &status](const Status& s) { + status = s; + n.Notify(); + }); + n.WaitForNotification(); + return status; + } else { + switch (t->dtype()) { +#define HANDLER(type) \ + case DataTypeToEnum::value: \ + output->flat() = t->flat(); \ + break; + TF_CALL_ALL_TYPES(HANDLER); +#undef HANDLER + default: + return errors::Internal("Unsupported dtype", t->dtype()); + } + } + return Status::OK(); +} + +} // namespace + void ReadVariableOp::Compute(OpKernelContext* ctx) { Var* variable = nullptr; const ResourceHandle& handle = HandleFromInput(ctx, 0); @@ -100,12 +142,16 @@ void ReadVariableOp::Compute(OpKernelContext* ctx) { // holding a shared lock to guarantee ordering of reads and // writes. tf_shared_lock ml(*variable->mu()); - const Tensor& t = *variable->tensor(); - OP_REQUIRES(ctx, dtype_ == t.dtype(), + const Tensor* t = variable->tensor(); + OP_REQUIRES(ctx, dtype_ == t->dtype(), errors::InvalidArgument( "Trying to read variable with wrong dtype. Expected ", - DataTypeString(dtype_), " got ", DataTypeString(t.dtype()))); - ctx->set_output(0, t); + DataTypeString(dtype_), " got ", DataTypeString(t->dtype()))); + if (variable->copy_on_read_mode.load()) { + OP_REQUIRES_OK(ctx, CopyVariable(0, ctx, t)); + } else { + ctx->set_output(0, *t); + } } ReadVariablesOp::ReadVariablesOp(OpKernelConstruction* c) : OpKernel(c) { @@ -146,14 +192,18 @@ void ReadVariablesOp::Compute(OpKernelContext* ctx) { // holding a shared lock to guarantee ordering of reads and // writes. tf_shared_lock ml(*variables[i]->mu()); - const Tensor& t = *variables[i]->tensor(); - OP_REQUIRES(ctx, dtypes_[i] == t.dtype(), + OP_REQUIRES(ctx, dtypes_[i] == variables[i]->tensor()->dtype(), errors::InvalidArgument( "Trying to read variable ", handles[i]->name(), " from Container: ", handles[i]->container(), " with wrong dtype. Expected ", DataTypeString(dtypes_[i]), - " got ", DataTypeString(t.dtype()))); - ctx->set_output(i, t); + " got ", DataTypeString(variables[i]->tensor()->dtype()))); + if (variables[i]->copy_on_read_mode.load()) { + OP_REQUIRES_OK(ctx, CopyVariable(i, ctx, variables[i]->tensor())); + } else { + const Tensor& t = *variables[i]->tensor(); + ctx->set_output(i, t); + } } } @@ -308,8 +358,23 @@ class AssignVariableOp : public OpKernel { "Trying to assign variable with wrong dtype. 
Expected ", DataTypeString(variable->tensor()->dtype()), " got ", DataTypeString(dtype_))); + if (variable->copy_on_read_mode.load()) { + PersistentTensor unused; + Tensor* tmp; + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + OP_REQUIRES_OK(context, + context->allocate_persistent(value.dtype(), value.shape(), + &unused, &tmp, attr)); + functor::DenseUpdate copy_functor; + copy_functor(context->eigen_device(), tmp->flat(), + value.flat()); + *variable->tensor() = *tmp; + } else { + *variable->tensor() = value; + } variable->is_initialized = true; - *variable->tensor() = value; } private: @@ -442,8 +507,9 @@ class AssignUpdateVariableOp : public OpKernel { " using a Tensor with shape ", value.shape().DebugString(), ", shapes must be equal.")); - OP_REQUIRES_OK(context, - PrepareToUpdateVariable(context, var_tensor)); + OP_REQUIRES_OK( + context, PrepareToUpdateVariable( + context, var_tensor, variable->copy_on_read_mode.load())); functor::DenseUpdate update_functor; update_functor(context->eigen_device(), var_tensor->flat(), value.flat()); @@ -524,6 +590,7 @@ class ResourceGatherOp : public OpKernel { Var* v = nullptr; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); core::ScopedUnref su(v); + OP_REQUIRES_OK(c, EnsureSparseVariableAccess(c, v)); // NOTE: We hold the lock for the whole gather operation instead // of increasing the reference count of v->tensor() to avoid a // situation where a write to the same variable will see a @@ -639,9 +706,9 @@ class ResourceScatterUpdateOp : public OpKernel { Var* v = nullptr; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); core::ScopedUnref unref_v(v); - mutex_lock ml(*v->mu()); + OP_REQUIRES_OK(c, EnsureSparseVariableAccess(c, v)); + tf_shared_lock ml(*v->mu()); Tensor* params = v->tensor(); - OP_REQUIRES_OK(c, PrepareToUpdateVariable(c, params)); const Tensor& indices = c->input(1); const Tensor& updates = c->input(2); diff --git a/tensorflow/core/kernels/scatter_nd_op.cc b/tensorflow/core/kernels/scatter_nd_op.cc index 63bb793fdcb7eb20daeee1708cb4ba78274cb9f7..b466e572495ae709d0fb05d58d964ee358077558 100644 --- a/tensorflow/core/kernels/scatter_nd_op.cc +++ b/tensorflow/core/kernels/scatter_nd_op.cc @@ -231,6 +231,7 @@ class ScatterNdUpdateOp : public OpKernel { Var* v; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); core::ScopedUnref scoped_unref(v); + OP_REQUIRES_OK(c, EnsureSparseVariableAccess(c, v)); mutex_lock m(*v->mu()); DoCompute(c); } else if (use_exclusive_lock_) { @@ -258,7 +259,6 @@ class ScatterNdUpdateOp : public OpKernel { Var* v; OP_REQUIRES_OK(c, LookupResource(c, HandleFromInput(c, 0), &v)); Tensor* t = v->tensor(); - OP_REQUIRES_OK(c, PrepareToUpdateVariable(c, t)); params = *t; params_shape = params.shape(); } else if (IsRefType(c->input_dtype(0))) { diff --git a/tensorflow/core/kernels/strided_slice_op.cc b/tensorflow/core/kernels/strided_slice_op.cc index 70a7ddbd0643e88655e1c0e1ad197316078267de..6db68f937def6fb4827b7fc85bff873b651a0002 100644 --- a/tensorflow/core/kernels/strided_slice_op.cc +++ b/tensorflow/core/kernels/strided_slice_op.cc @@ -307,9 +307,9 @@ class StridedSliceAssignOp : public OpKernel { OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), &v)); core::ScopedUnref scoped_unref(v); - mutex_lock ml(*v->mu()); OP_REQUIRES_OK(context, - PrepareToUpdateVariable(context, v->tensor())); + EnsureSparseVariableAccess(context, v)); + mutex_lock ml(*v->mu()); old_lhs = v->tensor(); 
OP_REQUIRES(context, old_lhs->dtype() == DataTypeToEnum::value, errors::InvalidArgument( diff --git a/tensorflow/core/kernels/training_op_helpers.cc b/tensorflow/core/kernels/training_op_helpers.cc index 4262a5404b6ac233d0fe7a8453e3e875eb9caf1f..20c08cf8fbb6b911c8b89b719237ac4677151e3c 100644 --- a/tensorflow/core/kernels/training_op_helpers.cc +++ b/tensorflow/core/kernels/training_op_helpers.cc @@ -19,70 +19,6 @@ limitations under the License. namespace tensorflow { -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, - Var** maybe_resource) { - *maybe_resource = nullptr; - if (ctx->input_dtype(input) == DT_RESOURCE) { - if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) { - return (*maybe_resource)->mu(); - } else { - ctx->CtxFailureWithWarning( - errors::Internal("Invalid variable reference.")); - return nullptr; - } - } - return ctx->input_ref_mutex(input); -} - -// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes -// in address order to mitigate deadlock. Returns a structure that, when -// deleted, will release the acquired mutexes. Safe to pass duplicates - will -// only lock each distinct mutex once. If do_lock is false, returns -// immediately. Note that this silently doesn't lock mutexes for invalid -// variable references; in all usages this is followed by GetInputTensor which -// will signal a failure. -VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( - OpKernelContext* ctx, bool do_lock, const std::vector& input_ids) { - bool any_resource = false; - for (auto i : input_ids) { - if (ctx->input_dtype(i) == DT_RESOURCE) { - any_resource = true; - break; - } - } - if (!do_lock && !any_resource) { - return VariableInputLockHolder({}, {}); - } - std::vector vars; - std::vector mutexes; - std::vector acquire_order; - for (auto input : input_ids) { - Var* var; - mutex* mutex = GetTrainingVariableMutex(ctx, input, &var); - if (var) vars.push_back(var); - // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). - if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { - acquire_order.push_back(mutexes.size()); - mutexes.push_back(mutex); - } - } - std::sort(acquire_order.begin(), acquire_order.end(), - [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); - - std::unique_ptr> locks = - MakeUnique>(); - locks->reserve(acquire_order.size()); - - for (auto input : acquire_order) { - Var* var; - mutex* mu = GetTrainingVariableMutex(ctx, input, &var); - core::ScopedUnref scoped_unref(var); - if (mu != nullptr) { - locks->emplace_back(*mu); - } - } - return VariableInputLockHolder(std::move(vars), std::move(locks)); -} void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, int output) { diff --git a/tensorflow/core/kernels/training_op_helpers.h b/tensorflow/core/kernels/training_op_helpers.h index 9f173a80f74612beaa4da265658eafb5b9e92360..e96cd023fc25d7fb18632d2ced2a35fdf4e70973 100644 --- a/tensorflow/core/kernels/training_op_helpers.h +++ b/tensorflow/core/kernels/training_op_helpers.h @@ -17,30 +17,72 @@ limitations under the License. #define TENSORFLOW_CORE_KERNELS_TRAINING_OP_HELPERS_H_ #include "tensorflow/core/framework/op_kernel.h" +#include "tensorflow/core/framework/tensor.h" #include "tensorflow/core/framework/variant_op_registry.h" #include "tensorflow/core/kernels/dense_update_functor.h" #include "tensorflow/core/kernels/variable_ops.h" namespace tensorflow { -// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`. 
-// -// If `input` corresponds to a `DT_RESOURCE`-type variable input, -// `*maybe_resource` will be updated to contain the underlying resource, and the -// caller will be responsible for calling `Unref()` on that resource. -mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, - Var** maybe_resource); +// Must be called before performing a sparse operation on a variable. Ensures +// that no concurrent dense operations can happen while holding the variable's +// lock. +template +Status EnsureSparseVariableAccess(OpKernelContext* ctx, Var* var) { + if (var->copy_on_read_mode.load()) { + return Status::OK(); + } + mutex_lock ml(*var->mu()); + // Once copy-on-read mode is True the refcount is guaranteed to be 1. This can + // also happen if there are no concurrent reads of the variable and + // copy-on-read mode is false. + if (var->tensor()->RefCountIsOne()) { + var->copy_on_read_mode.store(true); + return Status::OK(); + } + PersistentTensor unused; + Tensor* tmp; + if (std::is_same::value) { + AllocatorAttributes attr; + attr.set_on_host(true); + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr)); + + const auto elements_in = var->tensor()->flat(); + auto elements_out = tmp->flat(); + for (int64 i = 0; i < elements_in.size(); ++i) { + elements_out(i) = elements_in(i); + } + } else { + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + TF_RETURN_IF_ERROR(ctx->allocate_persistent( + var->tensor()->dtype(), var->tensor()->shape(), &unused, &tmp, attr)); + functor::DenseUpdate copy_functor; + copy_functor(ctx->eigen_device(), tmp->flat(), + const_cast(var->tensor())->flat()); + } + *var->tensor() = *tmp; + var->copy_on_read_mode.store(true); + return Status::OK(); +} // Utility structure that releases a sequence of borrowed mutexes when it is // deleted. struct VariableInputLockHolder { public: - VariableInputLockHolder(std::vector vars, - std::unique_ptr> locks) - : vars_(std::move(vars)), locks_(std::move(locks)) {} + VariableInputLockHolder( + std::vector vars, std::unique_ptr> locks, + std::unique_ptr> shared_locks) + : vars_(std::move(vars)), + locks_(std::move(locks)), + shared_locks_(std::move(shared_locks)) {} VariableInputLockHolder(VariableInputLockHolder&& other) - : vars_(std::move(other.vars_)), locks_(std::move(other.locks_)) {} + : vars_(std::move(other.vars_)), + locks_(std::move(other.locks_)), + shared_locks_(std::move(other.shared_locks_)) {} ~VariableInputLockHolder() { // Release the locks before unreffing the Vars, because each lock @@ -56,10 +98,95 @@ struct VariableInputLockHolder { // NOTE: Use a `std::unique_ptr` instead of moving in a vector directly, // because a `std::vector` is not movable on all platforms. std::unique_ptr> locks_; + std::unique_ptr> shared_locks_; }; +// Returns a borrowed pointer to the mutex for the variable `input` in `ctx`. +// +// If `input` corresponds to a `DT_RESOURCE`-type variable input, +// `*maybe_resource` will be updated to contain the underlying resource, and the +// caller will be responsible for calling `Unref()` on that resource. 
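
[Reviewer note] The templated `GetTrainingVariableMutex` and `MaybeLockVariableInputMutexesInOrder` that follow acquire every distinct variable mutex in increasing address order to avoid deadlock, exclusively or shared depending on `do_lock`. The ordering idea in a standalone sketch, with invented names (`LockAllInOrder`, `OrderedLocks`), not the TensorFlow helper itself:

#include <algorithm>
#include <mutex>
#include <shared_mutex>
#include <vector>

// Acquire a set of possibly-duplicated mutexes without deadlock by locking
// each distinct mutex exactly once, in increasing address order. Destroying
// the returned holder releases everything. Illustrative only.
struct OrderedLocks {
  std::vector<std::unique_lock<std::shared_mutex>> exclusive;
  std::vector<std::shared_lock<std::shared_mutex>> shared;
};

inline OrderedLocks LockAllInOrder(std::vector<std::shared_mutex*> mutexes,
                                   bool exclusive) {
  // Sort by address so every caller agrees on the order, then deduplicate so
  // each mutex is locked at most once.
  std::sort(mutexes.begin(), mutexes.end());
  mutexes.erase(std::unique(mutexes.begin(), mutexes.end()), mutexes.end());

  OrderedLocks locks;
  for (std::shared_mutex* mu : mutexes) {
    if (mu == nullptr) continue;  // Mirrors "silently skip invalid references".
    if (exclusive) {
      locks.exclusive.emplace_back(*mu);
    } else {
      locks.shared.emplace_back(*mu);
    }
  }
  return locks;
}
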
+template +mutex* GetTrainingVariableMutex(OpKernelContext* ctx, int input, bool sparse, + Var** maybe_resource) { + *maybe_resource = nullptr; + if (ctx->input_dtype(input) == DT_RESOURCE) { + if (LookupResource(ctx, HandleFromInput(ctx, input), maybe_resource).ok()) { + if (sparse) { + EnsureSparseVariableAccess(ctx, *maybe_resource); + } + return (*maybe_resource)->mu(); + } else { + ctx->CtxFailureWithWarning( + errors::Internal("Invalid variable reference.")); + return nullptr; + } + } + return ctx->input_ref_mutex(input); +} + +// MaybeLockVariableInputMutexesInOrder is a helper function to acquire mutexes +// in address order to mitigate deadlock. Returns a structure that, when +// deleted, will release the acquired mutexes. Safe to pass duplicates - will +// only lock each distinct mutex once. If sparse is true will ensure the +// variable gets switched to copy-on-read mode before trying to acquire the +// locks. If do_lock is false, returns immediately for reference variables. For +// resource variables in copy-on-read-mode it will grab a shared lock if do_lock +// is false, exclusive lock otherwise. Note that this silently doesn't lock +// mutexes for invalid variable references; in all usages this is followed by +// GetInputTensor which will signal a failure. +template VariableInputLockHolder MaybeLockVariableInputMutexesInOrder( - OpKernelContext* ctx, bool do_lock, const std::vector& input_ids); + OpKernelContext* ctx, bool do_lock, bool sparse, + const std::vector& input_ids) { + bool any_resource = false; + for (auto i : input_ids) { + if (ctx->input_dtype(i) == DT_RESOURCE) { + any_resource = true; + break; + } + } + if (!do_lock && !any_resource) { + return VariableInputLockHolder({}, {}, {}); + } + std::vector vars; + std::vector mutexes; + std::vector acquire_order; + for (auto input : input_ids) { + Var* var; + mutex* mutex = + GetTrainingVariableMutex(ctx, input, sparse, &var); + if (var) vars.push_back(var); + // Only lock each mutex once if duplicates exist (n^2 but n is 2 or 3). + if (std::find(mutexes.begin(), mutexes.end(), mutex) == mutexes.end()) { + acquire_order.push_back(mutexes.size()); + mutexes.push_back(mutex); + } + } + std::sort(acquire_order.begin(), acquire_order.end(), + [&mutexes](int a, int b) { return mutexes[a] < mutexes[b]; }); + + std::unique_ptr> locks = + absl::make_unique>(); + std::unique_ptr> shared_locks = + absl::make_unique>(); + locks->reserve(acquire_order.size()); + + for (auto input : acquire_order) { + Var* var; + mutex* mu = GetTrainingVariableMutex(ctx, input, sparse, &var); + core::ScopedUnref scoped_unref(var); + if (mu != nullptr) { + if (do_lock) { + locks->emplace_back(*mu); + } else { + shared_locks->emplace_back(*mu); + } + } + } + return VariableInputLockHolder(std::move(vars), std::move(locks), + std::move(shared_locks)); +} void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, int output); @@ -68,8 +195,9 @@ void MaybeForwardRefInputToRefOutput(OpKernelContext* ctx, int input, // reference count of 1 before you update it. // REQUIRES: If you pass in variable->tensor(), *variable->mu() must be held. template -Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) { - if (!tensor->RefCountIsOne()) { +Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor, + bool copy_on_read_mode) { + if (copy_on_read_mode || !tensor->RefCountIsOne()) { // Tensor's buffer is in use by some read, so we need to copy before // updating. 
PersistentTensor unused; @@ -100,12 +228,14 @@ Status PrepareToUpdateVariable(OpKernelContext* ctx, Tensor* tensor) { return Status::OK(); } -// This gives you `*out`, a tensor you can update, corresponding to a -// variable passed as input index `input`. This handles the -// differences between reference and resource variables. For resource -// variables, we ensure `*out` has a reference count of 1 (using -// PrepareToUpdateVariable() to copy if necessary) unless -// sparse && !lock_held, in which case it never copies. +// This gives you `*out`, a tensor you can update, corresponding to a variable +// passed as input index `input`. This handles the differences between +// reference and resource variables. For reference variables we can just grab +// the tensor, grabbing the lock if lock_held is False. +// +// For resource variables we, if sparse is true, ensure it's in copy-on-read +// mode, and then, regardless of the value of sparse, ensure its refcount is 1 +// (by potentially copying its contents). In this case lock_held is ignored. template Status GetInputTensorFromVariable(OpKernelContext* ctx, int input, bool lock_held, bool sparse, Tensor* out) { @@ -113,7 +243,13 @@ Status GetInputTensorFromVariable(OpKernelContext* ctx, int input, Var* var; TF_RETURN_IF_ERROR(LookupResource(ctx, HandleFromInput(ctx, input), &var)); core::ScopedUnref unref_var(var); - TF_RETURN_IF_ERROR(PrepareToUpdateVariable(ctx, var->tensor())); + if (sparse) { + TF_RETURN_IF_ERROR(EnsureSparseVariableAccess(ctx, var)); + *out = *var->tensor(); + return Status::OK(); + } + TF_RETURN_IF_ERROR(PrepareToUpdateVariable( + ctx, var->tensor(), var->copy_on_read_mode.load())); *out = *var->tensor(); return Status::OK(); } diff --git a/tensorflow/core/kernels/training_ops.cc b/tensorflow/core/kernels/training_ops.cc index 6504ad1b09c089cafec8c2b0ce0f2971aa506b52..b2239ab5c39fea33fc70b6aaf170d456cd1ba3fe 100644 --- a/tensorflow/core/kernels/training_ops.cc +++ b/tensorflow/core/kernels/training_ops.cc @@ -465,11 +465,12 @@ class ApplyGradientDescentOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); OP_REQUIRES( ctx, var.IsInitialized(), @@ -506,11 +507,12 @@ class ApplyGradientDescentOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); OP_REQUIRES( ctx, var.IsInitialized(), @@ -600,7 +602,8 @@ class ApplyAdadeltaOp : public OpKernel { void Compute(OpKernelContext* ctx) override { Var* resource; - mutex* mu = GetTrainingVariableMutex(ctx, 0, &resource); + const bool sparse = false; + mutex* mu = GetTrainingVariableMutex(ctx, 0, sparse, &resource); core::ScopedUnref scoped_unref(resource); if (use_exclusive_lock_ && mu != nullptr) { mutex_lock l1(*mu); @@ -624,14 +627,16 @@ class ApplyAdadeltaOp : public OpKernel { void DoValidate(OpKernelContext* ctx) { Tensor var; + 
const bool sparse = false; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); Tensor accum_update; - OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &accum_update)); + OP_REQUIRES_OK( + ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, + sparse, &accum_update)); OP_REQUIRES( ctx, var.IsInitialized(), @@ -678,14 +683,16 @@ class ApplyAdadeltaOp : public OpKernel { void DoCompute(OpKernelContext* ctx) { const Device& device = ctx->template eigen_device(); Tensor var; + const bool sparse = false; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); Tensor accum_update; - OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &accum_update)); + OP_REQUIRES_OK( + ctx, GetInputTensorFromVariable(ctx, 2, use_exclusive_lock_, + sparse, &accum_update)); const Tensor& lr = ctx->input(3); const Tensor& rho = ctx->input(4); @@ -751,7 +758,8 @@ class SparseApplyAdadeltaOp : public OpKernel { void Compute(OpKernelContext* ctx) override { Var* var; - mutex* mu = GetTrainingVariableMutex(ctx, 0, &var); + const bool sparse = true; + mutex* mu = GetTrainingVariableMutex(ctx, 0, sparse, &var); core::ScopedUnref scoped_unref(var); // mu_accum is actually the same mutex as mu_var since currently we use a // global mutex. 
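
[Reviewer note] The call-site pattern repeated throughout the rest of training_ops.cc is the same everywhere: name the `sparse` decision once, acquire the locks, then fetch each variable tensor. A schematic fragment for a hypothetical two-slot kernel (variable at input 0, accumulator at input 1) is shown below; it assumes a kernel class templated on `<typename Device, typename T>` whose `use_exclusive_lock_` comes from the "use_locking" attribute, and it writes out the template arguments explicitly since the flattened diff text above drops the angle-bracketed parameters. Not compilable on its own.

// Inside some Compute(OpKernelContext* ctx); schematic only.
const bool sparse = true;  // false for the dense Apply* variants.
auto locks = MaybeLockVariableInputMutexesInOrder<Device, T>(
    ctx, use_exclusive_lock_, sparse, {0, 1});
Tensor var;
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                        ctx, 0, use_exclusive_lock_, sparse, &var));
Tensor accum;
OP_REQUIRES_OK(ctx, GetInputTensorFromVariable<Device, T>(
                        ctx, 1, use_exclusive_lock_, sparse, &accum));
// `locks` releases all acquired mutexes when it goes out of scope.
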
@@ -767,14 +775,16 @@ class SparseApplyAdadeltaOp : public OpKernel { void DoCompute(OpKernelContext* ctx) { Tensor var; + const bool sparse = true; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum_grad; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &accum_grad)); + ctx, 1, use_exclusive_lock_, sparse, &accum_grad)); Tensor accum_update; - OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, true, &accum_update)); + OP_REQUIRES_OK(ctx, + GetInputTensorFromVariable( + ctx, 2, use_exclusive_lock_, sparse, &accum_update)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -907,11 +917,12 @@ class ApplyProximalGradientDescentOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); OP_REQUIRES( ctx, var.IsInitialized(), @@ -976,11 +987,12 @@ class SparseApplyProximalGradientDescentOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); OP_REQUIRES(ctx, TensorShapeUtils::IsVectorOrHigher(var.shape()), errors::InvalidArgument("var must be at least 1 dimensional")); @@ -1121,14 +1133,15 @@ class ApplyAdagradOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -1214,14 +1227,15 @@ class ApplyProximalAdagradOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -1316,14 +1330,15 @@ class SparseApplyAdagradOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = - 
MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -1456,14 +1471,15 @@ class SparseApplyProximalAdagradOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -1628,19 +1644,20 @@ class ApplyAdagradDAOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor gradient_accum; OP_REQUIRES_OK( ctx, GetInputTensorFromVariable(ctx, 1, use_exclusive_lock_, - false, &gradient_accum)); + sparse, &gradient_accum)); Tensor gradient_squared_accum; OP_REQUIRES_OK( ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &gradient_squared_accum)); + ctx, 2, use_exclusive_lock_, sparse, &gradient_squared_accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -1729,19 +1746,20 @@ class SparseApplyAdagradDAOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor gradient_accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &gradient_accum)); + ctx, 1, use_exclusive_lock_, sparse, &gradient_accum)); Tensor gradient_squared_accum; OP_REQUIRES_OK( ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, true, &gradient_squared_accum)); + ctx, 2, use_exclusive_lock_, sparse, &gradient_squared_accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -1927,18 +1945,19 @@ class ApplyFtrlOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 
2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); Tensor linear; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &linear)); + ctx, 2, use_exclusive_lock_, sparse, &linear)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -2079,17 +2098,18 @@ class SparseApplyFtrlOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); Tensor linear; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, true, &linear)); + ctx, 2, use_exclusive_lock_, sparse, &linear)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -2353,15 +2373,16 @@ class ApplyMomentumOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -2454,15 +2475,16 @@ class SparseApplyMomentumOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -2572,15 +2594,16 @@ class ApplyKerasMomentumOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); OP_REQUIRES( 
ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -2671,15 +2694,16 @@ class SparseApplyKerasMomentumOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor accum; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &accum)); + ctx, 1, use_exclusive_lock_, sparse, &accum)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -2783,18 +2807,19 @@ class ApplyAdamOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor m; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &m)); + ctx, 1, use_exclusive_lock_, sparse, &m)); Tensor v; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &v)); + ctx, 2, use_exclusive_lock_, sparse, &v)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -2873,18 +2898,19 @@ class ApplyAdamOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor m; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &m)); + ctx, 1, use_exclusive_lock_, sparse, &m)); Tensor v; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &v)); + ctx, 2, use_exclusive_lock_, sparse, &v)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -3043,21 +3069,22 @@ class ApplyAdamWithAmsgradOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor m; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &m)); + ctx, 1, use_exclusive_lock_, sparse, &m)); Tensor v; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &v)); + ctx, 2, use_exclusive_lock_, sparse, &v)); Tensor vhat; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 3, use_exclusive_lock_, false, &vhat)); + ctx, 3, use_exclusive_lock_, sparse, &vhat)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -3184,18 +3211,19 @@ class ApplyAdaMaxOp : public 
OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor m; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &m)); + ctx, 1, use_exclusive_lock_, sparse, &m)); Tensor v; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &v)); + ctx, 2, use_exclusive_lock_, sparse, &v)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -3312,18 +3340,19 @@ class ApplyRMSPropOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor ms; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &ms)); + ctx, 1, use_exclusive_lock_, sparse, &ms)); Tensor mom; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &mom)); + ctx, 2, use_exclusive_lock_, sparse, &mom)); OP_REQUIRES( ctx, var.IsInitialized(), @@ -3394,21 +3423,22 @@ class ApplyCenteredRMSPropOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2, 3}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2, 3}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor mg; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &mg)); + ctx, 1, use_exclusive_lock_, sparse, &mg)); Tensor ms; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, false, &ms)); + ctx, 2, use_exclusive_lock_, sparse, &ms)); Tensor mom; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 3, use_exclusive_lock_, false, &mom)); + ctx, 3, use_exclusive_lock_, sparse, &mom)); OP_REQUIRES( ctx, var.IsInitialized(), @@ -3553,18 +3583,19 @@ class SparseApplyRMSPropOp : public OpKernel { } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor ms; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &ms)); + ctx, 1, use_exclusive_lock_, sparse, &ms)); Tensor mom; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, true, &mom)); + ctx, 2, use_exclusive_lock_, sparse, &mom)); OP_REQUIRES( ctx, var.IsInitialized(), @@ -3682,21 +3713,22 @@ class SparseApplyCenteredRMSPropOp : public OpKernel 
{ } void Compute(OpKernelContext* ctx) override NO_THREAD_SAFETY_ANALYSIS { - auto locks = MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, - {0, 1, 2, 3}); + const bool sparse = true; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1, 2, 3}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, true, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor mg; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, true, &mg)); + ctx, 1, use_exclusive_lock_, sparse, &mg)); Tensor ms; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 2, use_exclusive_lock_, true, &ms)); + ctx, 2, use_exclusive_lock_, sparse, &ms)); Tensor mom; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 3, use_exclusive_lock_, true, &mom)); + ctx, 3, use_exclusive_lock_, sparse, &mom)); OP_REQUIRES( ctx, var.IsInitialized(), @@ -3852,15 +3884,16 @@ class ApplyAddSignOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor m; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &m)); + ctx, 1, use_exclusive_lock_, sparse, &m)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( @@ -3958,15 +3991,16 @@ class ApplyPowerSignOp : public OpKernel { } void Compute(OpKernelContext* ctx) override { - auto locks = - MaybeLockVariableInputMutexesInOrder(ctx, use_exclusive_lock_, {0, 1}); + const bool sparse = false; + auto locks = MaybeLockVariableInputMutexesInOrder( + ctx, use_exclusive_lock_, sparse, {0, 1}); Tensor var; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 0, use_exclusive_lock_, false, &var)); + ctx, 0, use_exclusive_lock_, sparse, &var)); Tensor m; OP_REQUIRES_OK(ctx, GetInputTensorFromVariable( - ctx, 1, use_exclusive_lock_, false, &m)); + ctx, 1, use_exclusive_lock_, sparse, &m)); OP_REQUIRES( ctx, var.IsInitialized(), errors::FailedPrecondition( diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index 433957fd1d38890c0952c443097e4955e1eb99cb..1dabcbb5c3f55029fbc590f94308c69fa4d5e326 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -36,6 +36,7 @@ from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import custom_gradient from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import init_ops +from tensorflow.python.ops import list_ops from tensorflow.python.ops import math_ops from tensorflow.python.ops import resource_variable_ops from tensorflow.python.ops import state_ops @@ -953,6 +954,19 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): state_ops.scatter_sub(v, [1], [3]) self.assertAllEqual([1.0, -1.0], v.numpy()) + def testScatterUpdateVariant(self): + with context.eager_mode(): + v = resource_variable_ops.ResourceVariable([ + list_ops.empty_tensor_list( + element_dtype=dtypes.float32, element_shape=[]) + ]) + v.scatter_update( + ops.IndexedSlices( + list_ops.tensor_list_from_tensor([1., 2.], 
element_shape=[]), 0)) + self.assertAllEqual( + list_ops.tensor_list_get_item(v[0], 0, element_dtype=dtypes.float32), + 1.) + def testScatterNdAddStateOps(self): with context.eager_mode(): v = resource_variable_ops.ResourceVariable(