From 1f8936bbb6c6e1939d3f14bcc11a0e61dfea555e Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" <gardener@tensorflow.org> Date: Tue, 1 Nov 2016 12:25:22 -0800 Subject: [PATCH] Basic slow versions of resource-based variable ops which have the right semantics. Change: 137863207 --- tensorflow/core/framework/tensor.h | 1 + .../core/kernels/resource_variable_ops.cc | 154 +++++++++++++++++- tensorflow/core/ops/resource_variable_ops.cc | 90 +++++++++- .../resource_variable_ops_test.py | 37 +++++ .../python/ops/resource_variable_ops.py | 3 + 5 files changed, 275 insertions(+), 10 deletions(-) diff --git a/tensorflow/core/framework/tensor.h b/tensorflow/core/framework/tensor.h index 47d74d4defc..43e44e7a96e 100644 --- a/tensorflow/core/framework/tensor.h +++ b/tensorflow/core/framework/tensor.h @@ -435,6 +435,7 @@ class Tensor { friend class VariableOp; // For access to set_shape friend class AutoReloadVariableOp; // For access to set_shape friend class TensorTestHelper; // For access to set_shape + template <typename Device, typename T> friend class CreateVariableOp; // Creates a tensor with the input datatype, shape and buf. diff --git a/tensorflow/core/kernels/resource_variable_ops.cc b/tensorflow/core/kernels/resource_variable_ops.cc index fbe66e83860..8809cba41d5 100644 --- a/tensorflow/core/kernels/resource_variable_ops.cc +++ b/tensorflow/core/kernels/resource_variable_ops.cc @@ -13,9 +13,12 @@ See the License for the specific language governing permissions and limitations under the License. 
==============================================================================*/ +#define EIGEN_USE_THREADS + #include "tensorflow/core/framework/op_kernel.h" #include "tensorflow/core/framework/register_types.h" #include "tensorflow/core/framework/resource_mgr.h" +#include "tensorflow/core/framework/tensor_types.h" #include "tensorflow/core/kernels/variable_ops.h" #include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/platform/mutex.h" @@ -25,25 +28,160 @@ namespace tensorflow { REGISTER_RESOURCE_HANDLE_KERNEL(Var); +template <typename Device, typename T> class CreateVariableOp : public OpKernel { public: CreateVariableOp(OpKernelConstruction* c) : OpKernel(c) { OP_REQUIRES_OK(c, c->GetAttr("dtype", &dtype_)); + OP_REQUIRES(c, DataTypeToEnum<T>::value == dtype_, + errors::InvalidArgument( + "Dtypes don't match; expected ", DataTypeString(dtype_), + " got ", DataTypeString(DataTypeToEnum<T>::value))); } - void Compute(OpKernelContext* c) override { + void Compute(OpKernelContext* context) override { Var* var = new Var(dtype_); - var->Ref(); - core::ScopedUnref ur(var); - OP_REQUIRES_OK(c, CreateResource(c, HandleFromInput(c, 0), var)); - // TODO(apassos): this currently does not initialize the tensor, so it's - // pointless, other than checking construction in tests. Fix this. + AllocatorAttributes attr; + attr.set_gpu_compatible(true); + attr.set_nic_compatible(true); + PersistentTensor copy; + Tensor value = context->input(1); + + // TODO(apassos): allocating and copying is unnecessary if we are the last + // user of the value tensor. This should essentially always be the case, yet + // the refcount is usually 2 instead of 1. Figure out what needs to change + // in the code to make this not be the case, so we can safely take + // ownership. 
+ Tensor* tmp_copy = nullptr; + OP_REQUIRES_OK(context, context->allocate_persistent( + dtype_, value.shape(), &copy, &tmp_copy, attr)); + *var->tensor() = *tmp_copy; + var->tensor()->flat<T>().device(context->eigen_device<Device>()) = + value.flat<T>(); + OP_REQUIRES_OK(context, CreateResource( + context, HandleFromInput(context, 0), var)); } private: DataType dtype_; }; -REGISTER_KERNEL_BUILDER(Name("CreateVariableOp").Device(DEVICE_CPU), - CreateVariableOp); + +// TODO(apassos) register for the GPU as well. +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("CreateVariableOp") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype"), \ + CreateVariableOp<Eigen::ThreadPoolDevice, type>); + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +template <typename Device, typename T> +class ReadVariableOp : public OpKernel { + public: + ReadVariableOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* ctx) { + Var* variable = nullptr; + OP_REQUIRES_OK(ctx, + LookupResource(ctx, HandleFromInput(ctx, 0), &variable)); + core::ScopedUnref s(variable); + // TODO(apassos): It's possible to do copy-on-write here instead of always + // copying by coordinating with the writing code. Do this. This will also + // obviate the need to hold a lock here. + mutex_lock ml(*variable->mu()); + Tensor* out = nullptr; + OP_REQUIRES_OK(ctx, + ctx->allocate_output(0, variable->tensor()->shape(), &out)); + out->flat<T>().device(ctx->eigen_device<Device>()) = + variable->tensor()->flat<T>(); + } +}; + +// TODO(apassos) register for the GPU as well. 
+#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER( \ + Name("ReadVariableOp").Device(DEVICE_CPU).TypeConstraint<type>("dtype"), \ + ReadVariableOp<Eigen::ThreadPoolDevice, type>); + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +template <typename Device, typename T> +class AssignVariableOp : public OpKernel { + public: + AssignVariableOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* context) override { + Var* variable = nullptr; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &variable)); + core::ScopedUnref s(variable); + + // TODO(apassos): holding a lock and copying is unnecessary if we are the + // last user of the value tensor. This should essentially always be the + // case, yet the refcount is usually 2 instead of 1. Figure out what needs + // to change in the code to make this not be the case, so we can safely take + // ownership. + mutex_lock ml(*variable->mu()); + Tensor value = context->input(1); + variable->tensor()->flat<T>().device(context->eigen_device<Device>()) = + value.flat<T>(); + } +}; + +// TODO(apassos) register for the GPU as well. +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("AssignVariableOp") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype"), \ + AssignVariableOp<Eigen::ThreadPoolDevice, type>); + +TF_CALL_ALL_TYPES(REGISTER_KERNELS); +TF_CALL_QUANTIZED_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS + +template <typename Device, typename T> +class AssignAddVariableOp : public OpKernel { + public: + AssignAddVariableOp(OpKernelConstruction* c) : OpKernel(c) {} + + void Compute(OpKernelContext* context) override { + Var* variable = nullptr; + OP_REQUIRES_OK(context, LookupResource(context, HandleFromInput(context, 0), + &variable)); + core::ScopedUnref s(variable); + + // TODO(apassos): holding a lock and copying is unnecessary if we are the + // last user of the value tensor. This should essentially always be the + // case, yet the refcount is usually 2 instead of 1. 
Figure out what needs + // to change in the code to make this not be the case, so we can safely take + // ownership. + mutex_lock ml(*variable->mu()); + Tensor value = context->input(1); + variable->tensor()->flat<T>().device(context->eigen_device<Device>()) += + value.flat<T>(); + + // TODO(apassos): this read can also be implemented efficiently so it is + // free if no one uses the resulting tensor. + Tensor* out = nullptr; + OP_REQUIRES_OK(context, context->allocate_output( + 0, variable->tensor()->shape(), &out)); + out->flat<T>().device(context->eigen_device<Device>()) = + variable->tensor()->flat<T>(); + } +}; + +// TODO(apassos) register for the GPU as well. +#define REGISTER_KERNELS(type) \ + REGISTER_KERNEL_BUILDER(Name("AssignAddVariableOp") \ + .Device(DEVICE_CPU) \ + .TypeConstraint<type>("dtype"), \ + AssignAddVariableOp<Eigen::ThreadPoolDevice, type>); + +TF_CALL_NUMBER_TYPES(REGISTER_KERNELS); +#undef REGISTER_KERNELS } // namespace tensorflow diff --git a/tensorflow/core/ops/resource_variable_ops.cc b/tensorflow/core/ops/resource_variable_ops.cc index 6211b07ac58..9e28291070c 100644 --- a/tensorflow/core/ops/resource_variable_ops.cc +++ b/tensorflow/core/ops/resource_variable_ops.cc @@ -49,10 +49,88 @@ dtype: the type of this variable. Must agree with the dtypes shape: The (possibly partially specified) shape of this variable. )"); +Status CreateAssignShapeFn(shape_inference::InferenceContext* c) { + DataType handle_dtype = c->input_handle_dtype(0); + DataType value_dtype; + c->GetAttr("dtype", &value_dtype); + if (handle_dtype != value_dtype) { + return errors::InvalidArgument( + "Trying to initialize handle for variable with wrong dtype. 
" + "Expected ", + handle_dtype, " got ", value_dtype); + } + shape_inference::ShapeHandle s = c->input_handle_shape(0); + shape_inference::ShapeHandle value_shape = c->input(1); + shape_inference::ShapeHandle unused; + TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused)); + return Status::OK(); +} + REGISTER_OP("CreateVariableOp") .Input("resource: resource") .Input("value: dtype") .Attr("dtype: type") + .SetShapeFn(CreateAssignShapeFn) + .Doc(R"( +Creates a variable resource. + +resource: handle to the resource in which to store the variable. +value: the value to set the new tensor to use. +dtype: the dtype of the value. +)"); + +REGISTER_OP("ReadVariableOp") + .Input("resource: resource") + .Output("value: dtype") + .Attr("dtype: type") + .SetShapeFn([](shape_inference::InferenceContext* c) { + DataType handle_dtype = c->input_handle_dtype(0); + DataType value_dtype; + c->GetAttr("dtype", &value_dtype); + if (handle_dtype != value_dtype) { + return errors::InvalidArgument( + "Trying to read variable with wrong dtype. " + "Expected ", + handle_dtype, " got ", value_dtype); + } + c->set_output(0, c->input_handle_shape(0)); + return Status::OK(); + }) + .Doc(R"( +Reads the value of a variable. + +The tensor returned by this operation is immutable. + +The value returned by this operation is guaranteed to be influenced by all the +writes on which this operation depends directly or indirectly, and to not be +influenced by any of the writes which depend directly or indirectly on this +operation. + +resource: handle to the resource in which to store the variable. +dtype: the dtype of the value. +)"); + +REGISTER_OP("AssignVariableOp") + .Input("resource: resource") + .Input("value: dtype") + .Attr("dtype: type") + .SetShapeFn(CreateAssignShapeFn) + .Doc(R"( +Assigns a new value to a variable. + +Any ReadVariableOp with a control dependency on this op is guaranteed to return +this value or a subsequent newer value of the variable. 
+ +resource: handle to the resource in which to store the variable. +value: the value to set the new tensor to use. +dtype: the dtype of the value. +)"); + +REGISTER_OP("AssignAddVariableOp") + .Input("resource: resource") + .Input("value: dtype") + .Output("new_value: dtype") + .Attr("dtype: type") .SetShapeFn([](shape_inference::InferenceContext* c) { DataType handle_dtype = c->input_handle_dtype(0); DataType value_dtype; @@ -67,13 +145,21 @@ REGISTER_OP("CreateVariableOp") shape_inference::ShapeHandle value_shape = c->input(1); shape_inference::ShapeHandle unused; TF_RETURN_IF_ERROR(c->Merge(s, value_shape, &unused)); + c->set_output(0, value_shape); return Status::OK(); }) .Doc(R"( -Creates a variable resource. +Adds a value to the current value of a variable. + +Any ReadVariableOp which depends directly or indirectly on this assign is +guaranteed to see the incremented value or a subsequent newer one. + +Outputs the incremented value, which can be used to totally order the +increments to this variable. resource: handle to the resource in which to store the variable. -value: the value to set the new tensor to use. +value: the value by which the variable will be incremented. +new_value: the new value of the variable. dtype: the dtype of the value. 
)"); diff --git a/tensorflow/python/kernel_tests/resource_variable_ops_test.py b/tensorflow/python/kernel_tests/resource_variable_ops_test.py index cb4375ce913..116939dc2d8 100644 --- a/tensorflow/python/kernel_tests/resource_variable_ops_test.py +++ b/tensorflow/python/kernel_tests/resource_variable_ops_test.py @@ -19,6 +19,7 @@ from __future__ import print_function from tensorflow.python.framework import constant_op from tensorflow.python.framework import dtypes +from tensorflow.python.framework import ops from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import resource_variable_ops @@ -46,6 +47,42 @@ class ResourceVariableOpsTest(test_util.TensorFlowTestCase): resource_variable_ops.create_variable_op( id_handle, constant_op.constant(0, dtype=dtypes.int32)).run() + def testCreateRead(self): + with self.test_session(): + handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) + resource_variable_ops.create_variable_op( + handle, constant_op.constant(1, dtype=dtypes.int32)).run() + value = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.int32).eval() + self.assertAllEqual(1, value) + + def testManyAssigns(self): + with self.test_session() as session: + handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) + create = resource_variable_ops.create_variable_op( + handle, constant_op.constant(1, dtype=dtypes.int32)) + with ops.control_dependencies([create]): + first_read = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.int32) + with ops.control_dependencies([first_read]): + write = resource_variable_ops.assign_variable_op( + handle, constant_op.constant(2, dtype=dtypes.int32)) + with ops.control_dependencies([write]): + second_read = resource_variable_ops.read_variable_op( + handle, dtype=dtypes.int32) + f, s = session.run([first_read, second_read]) + self.assertEqual(f, 1) + self.assertEqual(s, 2) + + def testAssignAdd(self): + 
with self.test_session(): + handle = resource_variable_ops.var_handle_op(dtype=dtypes.int32, shape=[]) + resource_variable_ops.create_variable_op( + handle, constant_op.constant(1, dtype=dtypes.int32)).run() + assign_add = resource_variable_ops.assign_add_variable_op( + handle, constant_op.constant(1, dtype=dtypes.int32)) + self.assertEqual(assign_add.eval(), 2) + if __name__ == "__main__": test.main() diff --git a/tensorflow/python/ops/resource_variable_ops.py b/tensorflow/python/ops/resource_variable_ops.py index 7db9731e198..0057f86486b 100644 --- a/tensorflow/python/ops/resource_variable_ops.py +++ b/tensorflow/python/ops/resource_variable_ops.py @@ -28,3 +28,6 @@ from tensorflow.python.ops.gen_resource_variable_ops import * ops.RegisterShape("VarHandleOp")(common_shapes.call_cpp_shape_fn) ops.RegisterShape("CreateVariableOp")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("ReadVariableOp")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("AssignVariableOp")(common_shapes.call_cpp_shape_fn) +ops.RegisterShape("AssignAddVariableOp")(common_shapes.call_cpp_shape_fn) -- GitLab