Commit 09adaff5 authored by Peter Hawkins, committed by TensorFlower Gardener


[TF:XLA] Implement ResourceApplyAdagrad. Split XLA implementation of training ops into their own file.
Change: 150125044
Parent ac8d8465
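
For orientation before the diff: the new ResourceApplyAdagrad kernel below computes accum += grad^2 followed by var -= lr * grad / sqrt(accum). A minimal NumPy sketch of that update (illustrative only; the helper name adagrad_step and its signature are hypothetical, not TensorFlow or XLA API):

```python
def adagrad_step(var, accum, grad, lr):
  """One Adagrad step, mirroring the accum/var math in the new kernel:
  accum <- accum + grad**2;  var <- var - lr * grad / sqrt(accum)."""
  import numpy as np
  accum = accum + grad ** 2
  var = var - lr * grad / np.sqrt(accum)
  return var, accum
```
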
@@ -38,6 +38,20 @@ cc_library(
deps = ["//tensorflow/core:framework_lite"],
)
tf_xla_py_test(
name = "adagrad_test",
size = "small",
srcs = ["adagrad_test.py"],
deps = [
":xla_test",
"//tensorflow/python:array_ops",
"//tensorflow/python:framework_for_generated_wrappers",
"//tensorflow/python:math_ops",
"//tensorflow/python:platform_test",
"//tensorflow/python:training",
],
)
tf_xla_py_test(
name = "binary_ops_test",
size = "small",
......
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Functional tests for aggregate operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import numpy as np
from tensorflow.compiler.tests.xla_test import XLATestCase
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import resource_variable_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import test
from tensorflow.python.training import adagrad
class AdagradOptimizerTest(XLATestCase):
def testBasic(self):
for dtype in self.float_types:
with self.test_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
ada_opt = adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
ada_update = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
# Run 3 steps of adagrad
for _ in range(3):
ada_update.run()
# Validate updated params
self.assertAllCloseAccordingToType(
np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
self.assertAllCloseAccordingToType(
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testTensorLearningRate(self):
for dtype in self.float_types:
with self.test_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
ada_opt = adagrad.AdagradOptimizer(
constant_op.constant(3.0), initial_accumulator_value=0.1)
ada_update = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
variables.global_variables_initializer().run()
# Fetch params to validate initial values
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
# Run 3 steps of adagrad
for _ in range(3):
ada_update.run()
# Validate updated params
self.assertAllCloseAccordingToType(
np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
self.assertAllCloseAccordingToType(
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
def testSharing(self):
for dtype in self.float_types:
with self.test_session(), self.test_scope():
var0 = resource_variable_ops.ResourceVariable([1.0, 2.0], dtype=dtype)
var1 = resource_variable_ops.ResourceVariable([3.0, 4.0], dtype=dtype)
grads0 = constant_op.constant([0.1, 0.1], dtype=dtype)
grads1 = constant_op.constant([0.01, 0.01], dtype=dtype)
ada_opt = adagrad.AdagradOptimizer(3.0)
# Apply the optimizer twice. Both applications will use
# the same accums.
ada_update1 = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
ada_update2 = ada_opt.apply_gradients(
zip([grads0, grads1], [var0, var1]))
self.assertEqual(["accumulator"], ada_opt.get_slot_names())
slot0 = ada_opt.get_slot(var0, "accumulator")
self.assertEqual(slot0.get_shape(), var0.get_shape())
slot1 = ada_opt.get_slot(var1, "accumulator")
self.assertEqual(slot1.get_shape(), var1.get_shape())
variables.global_variables_initializer().run()
# Fetch params to validate initial values.
self.assertAllClose([1.0, 2.0], var0.eval())
self.assertAllClose([3.0, 4.0], var1.eval())
# Mix the first and the second adagrad for 3 steps.
ada_update1.run()
ada_update2.run()
ada_update1.run()
# Validate updated params (the same as with only 1 Adagrad).
self.assertAllCloseAccordingToType(
np.array([-1.6026098728179932, -0.6026098728179932]), var0.eval())
self.assertAllCloseAccordingToType(
np.array([2.715679168701172, 3.715679168701172]), var1.eval())
if __name__ == "__main__":
test.main()
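
A note on the expected values in these tests: the constants near [-1.60261, -0.60261] and [2.71568, 3.71568] follow from three Adagrad steps with learning rate 3.0 and an accumulator initialized to 0.1. A small NumPy check (a verification sketch only, not part of the test file) reproduces them:

```python
import numpy as np

lr = 3.0
for var, grad in ((np.array([1.0, 2.0]), 0.1),
                  (np.array([3.0, 4.0]), 0.01)):
  accum = np.full_like(var, 0.1)  # initial_accumulator_value
  for _ in range(3):              # three optimizer steps
    accum += grad ** 2            # accumulate squared gradients
    var -= lr * grad / np.sqrt(accum)
  print(var)
# Prints approximately [-1.60261 -0.60261] and [2.71568 3.71568],
# matching the assertAllCloseAccordingToType expectations above.
```
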
@@ -52,6 +52,7 @@ tf_kernel_library(
"split_op.cc",
"strided_slice_op.cc",
"tile_ops.cc",
"training_ops.cc",
"transpose_op.cc",
"unary_ops.cc",
"unpack_op.cc",
......
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/tf2xla/kernels/cwise_ops.h"
#include "tensorflow/compiler/tf2xla/xla_helpers.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/compiler/xla/client/computation_builder.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/core/framework/kernel_def_builder.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/no_op.h"
namespace tensorflow {
namespace {
class ResourceApplyGradientDescent : public XlaOpKernel {
public:
explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
xla::ComputationBuilder* b = ctx->builder();
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
}
};
REGISTER_XLA_OP("ResourceApplyGradientDescent", ResourceApplyGradientDescent);
class ResourceApplyMomentum : public XlaOpKernel {
public:
explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
DataType type = ctx->input_type(2);
DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
OP_REQUIRES_OK(ctx,
ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
OP_REQUIRES(
ctx, type == var_type && type == accum_type,
errors::InvalidArgument(
"Types of variable arguments to ResourceApplyMomentum must match: ",
DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
DataTypeString(accum_type)));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
"var and accum do not have the same shape",
var_shape.DebugString(), " ", accum_shape.DebugString()));
TensorShape lr_shape = ctx->InputShape(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
errors::InvalidArgument("lr is not a scalar: ",
lr_shape.DebugString()));
TensorShape grad_shape = ctx->InputShape(3);
OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
errors::InvalidArgument(
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
TensorShape momentum_shape = ctx->InputShape(4);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
xla::ComputationDataHandle var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
xla::ComputationDataHandle momentum = ctx->Input(4);
accum = b->Add(b->Mul(accum, momentum), grad);
if (use_nesterov_) {
// See https://github.com/tensorflow/tensorflow/pull/2798 for an
// explanation of the reparameterization used here.
var = b->Sub(
var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr)));
} else {
var = b->Sub(var, b->Mul(accum, lr));
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
private:
bool use_nesterov_;
};
REGISTER_XLA_OP("ResourceApplyMomentum", ResourceApplyMomentum);
class ResourceApplyAdagrad : public XlaOpKernel {
public:
explicit ResourceApplyAdagrad(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
DataType type = ctx->input_type(2);
DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
OP_REQUIRES_OK(ctx,
ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
OP_REQUIRES(
ctx, type == var_type && type == accum_type,
errors::InvalidArgument(
"Types of variable arguments to ResourceApplyAdagrad must match: ",
DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
DataTypeString(accum_type)));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
"var and accum do not have the same shape",
var_shape.DebugString(), " ", accum_shape.DebugString()));
TensorShape lr_shape = ctx->InputShape(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
errors::InvalidArgument("lr is not a scalar: ",
lr_shape.DebugString()));
TensorShape grad_shape = ctx->InputShape(3);
OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
errors::InvalidArgument(
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
xla::ComputationDataHandle var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
accum = b->Add(accum, b->Pow(grad, XlaHelpers::FloatLiteral(b, type, 2.0)));
var = b->Sub(
var, b->Mul(b->Mul(grad, lr),
b->Pow(accum, XlaHelpers::FloatLiteral(b, type, -0.5))));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
};
REGISTER_XLA_OP("ResourceApplyAdagrad", ResourceApplyAdagrad);
} // namespace
} // namespace tensorflow
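
Regarding the Nesterov branch of ResourceApplyMomentum above: rather than evaluating the gradient at a look-ahead point, the kernel uses the reparameterization discussed in tensorflow/tensorflow PR #2798, stepping by lr*grad plus lr*momentum*accum after the accumulator update. A minimal NumPy-style sketch of both branches (the helper name momentum_step is hypothetical; it mirrors the kernel's arithmetic, not any TensorFlow API):

```python
def momentum_step(var, accum, grad, lr, momentum, use_nesterov=False):
  """Sketch of the arithmetic in the XLA ResourceApplyMomentum kernel."""
  accum = momentum * accum + grad
  if use_nesterov:
    # Reparameterized Nesterov update (see tensorflow/tensorflow PR #2798):
    # step by the gradient plus the momentum-scaled accumulator.
    var = var - (lr * grad + lr * momentum * accum)
  else:
    var = var - lr * accum
  return var, accum
```
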
@@ -82,90 +82,5 @@ class AssignSubVariableOp : public XlaOpKernel {
};
REGISTER_XLA_OP("AssignSubVariableOp", AssignSubVariableOp);
class ResourceApplyGradientDescent : public XlaOpKernel {
public:
explicit ResourceApplyGradientDescent(OpKernelConstruction* ctx)
: XlaOpKernel(ctx) {}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationDataHandle handle;
xla::ComputationBuilder* b = ctx->builder();
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &handle));
handle = b->Sub(handle, b->Mul(ctx->Input(1), ctx->Input(2)));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, ctx->input_type(1), handle));
}
};
REGISTER_XLA_OP("ResourceApplyGradientDescent", ResourceApplyGradientDescent);
class ResourceApplyMomentum : public XlaOpKernel {
public:
explicit ResourceApplyMomentum(OpKernelConstruction* ctx) : XlaOpKernel(ctx) {
OP_REQUIRES_OK(ctx, ctx->GetAttr("use_nesterov", &use_nesterov_));
}
void Compile(XlaOpKernelContext* ctx) override {
xla::ComputationBuilder* b = ctx->builder();
DataType type = ctx->input_type(2);
DataType var_type, accum_type;
TensorShape var_shape, accum_shape;
OP_REQUIRES_OK(ctx, ctx->GetVariableTypeAndShape(0, &var_type, &var_shape));
OP_REQUIRES_OK(ctx,
ctx->GetVariableTypeAndShape(1, &accum_type, &accum_shape));
OP_REQUIRES(
ctx, type == var_type && type == accum_type,
errors::InvalidArgument(
"Types of variable arguments to ResourceApplyMomentum must match: ",
DataTypeString(type), " vs. ", DataTypeString(var_type), " and ",
DataTypeString(accum_type)));
OP_REQUIRES(ctx, var_shape.IsSameSize(accum_shape),
errors::InvalidArgument(
"var and accum do not have the same shape",
var_shape.DebugString(), " ", accum_shape.DebugString()));
TensorShape lr_shape = ctx->InputShape(2);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(lr_shape),
errors::InvalidArgument("lr is not a scalar: ",
lr_shape.DebugString()));
TensorShape grad_shape = ctx->InputShape(3);
OP_REQUIRES(ctx, var_shape.IsSameSize(grad_shape),
errors::InvalidArgument(
"var and grad do not have the same shape",
var_shape.DebugString(), " ", grad_shape.DebugString()));
TensorShape momentum_shape = ctx->InputShape(4);
OP_REQUIRES(ctx, TensorShapeUtils::IsScalar(momentum_shape),
errors::InvalidArgument("momentum is not a scalar: ",
momentum_shape.DebugString()));
xla::ComputationDataHandle var, accum;
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(0, &var));
OP_REQUIRES_OK(ctx, ctx->ReadVariableInput(1, &accum));
xla::ComputationDataHandle lr = ctx->Input(2);
xla::ComputationDataHandle grad = ctx->Input(3);
xla::ComputationDataHandle momentum = ctx->Input(4);
accum = b->Add(b->Mul(accum, momentum), grad);
if (use_nesterov_) {
// See https://github.com/tensorflow/tensorflow/pull/2798 for an
// explanation of the reparameterization used here.
var = b->Sub(
var, b->Add(b->Mul(grad, lr), b->Mul(b->Mul(accum, momentum), lr)));
} else {
var = b->Sub(var, b->Mul(accum, lr));
}
OP_REQUIRES_OK(ctx, ctx->AssignVariable(0, type, var));
OP_REQUIRES_OK(ctx, ctx->AssignVariable(1, type, accum));
}
private:
bool use_nesterov_;
};
REGISTER_XLA_OP("ResourceApplyMomentum", ResourceApplyMomentum);
} // namespace
} // namespace tensorflow
@@ -224,6 +224,9 @@ REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
Name("Relu6Grad").TypeConstraint("T", kCpuFloatTypes));
REGISTER_XLA_KERNEL(DEVICE_CPU_XLA_JIT,
Name("Reshape").TypeConstraint("T", kCpuAllTypes));
REGISTER_XLA_KERNEL(
DEVICE_CPU_XLA_JIT,
Name("ResourceApplyAdagrad").TypeConstraint("T", kCpuFloatTypes));
REGISTER_XLA_KERNEL(
DEVICE_CPU_XLA_JIT,
Name("ResourceApplyGradientDescent").TypeConstraint("T", kCpuFloatTypes));
@@ -515,6 +518,9 @@ REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("Relu6Grad").TypeConstraint("T", kGpuFloatTypes));
REGISTER_XLA_KERNEL(DEVICE_GPU_XLA_JIT,
Name("Reshape").TypeConstraint("T", kGpuAllTypes));
REGISTER_XLA_KERNEL(
DEVICE_GPU_XLA_JIT,
Name("ResourceApplyAdagrad").TypeConstraint("T", kGpuFloatTypes));
REGISTER_XLA_KERNEL(
DEVICE_GPU_XLA_JIT,
Name("ResourceApplyGradientDescent").TypeConstraint("T", kGpuFloatTypes));
......