提交 ada1d9c7 编写于 作者: J Jian Li 提交者: TensorFlower Gardener

Add RISC Add Op register.

PiperOrigin-RevId: 339893949
Change-Id: I43c4d0c58a19efb52b66a2ccbf9f792f05e0c73a
上级 cb043911
......@@ -982,6 +982,7 @@ filegroup(
"random_ops_op_lib",
"remote_fused_graph_ops_op_lib",
"resource_variable_ops_op_lib",
"risc_ops_op_lib",
"rnn_ops_op_lib",
"rpc_ops_op_lib",
"scoped_allocator_ops_op_lib",
......
op {
graph_op_name: "RiscAdd"
visibility: HIDDEN
summary: "Returns x + y element-wise."
description: <<END
*NOTE*: `RiscAdd` does not supports broadcasting.
Given two input tensors, the `tf.risc_add` operation computes the sum for every element in the tensor.
Both input and output have a range `(-inf, inf)`.
END
}
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
tf_kernel_library(
name = "risc",
visibility = ["//visibility:public"],
deps = [
"//tensorflow/core/kernels/risc/experimental",
],
)
# TF-RISC
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
tf_kernel_library(
name = "risc_add_op",
srcs = ["risc_add_op.cc"],
deps = [
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
],
)
tf_kernel_library(
name = "experimental",
deps = [
":risc_add_op",
],
)
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
namespace risc {
namespace experimental {
template <typename T>
class RiscAddOp : public OpKernel {
public:
explicit RiscAddOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
void Compute(OpKernelContext* ctx) override {
// TODO(b/171294012): Implement RiscAdd op.
}
};
#define REGISTER_CPU(T) \
REGISTER_KERNEL_BUILDER( \
Name("RiscAdd").Device(DEVICE_CPU).TypeConstraint<T>("T"), \
RiscAddOp<T>);
REGISTER_CPU(bfloat16);
REGISTER_CPU(Eigen::half);
REGISTER_CPU(float);
REGISTER_CPU(double);
} // namespace experimental
} // namespace risc
} // namespace tensorflow
......@@ -83,6 +83,7 @@ tf_gen_op_libs(
"special_math_ops",
"stateful_random_ops",
"remote_fused_graph_ops",
"risc_ops",
"rnn_ops",
"rpc_ops",
"scoped_allocator_ops",
......@@ -279,6 +280,7 @@ cc_library(
":parsing_ops_op_lib",
":ragged_ops",
":random_ops_op_lib",
":risc_ops_op_lib",
":rnn_ops_op_lib",
":special_math_ops_op_lib",
":stateful_random_ops_op_lib",
......
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/shape_inference.h"
namespace tensorflow {
REGISTER_OP("RiscAdd")
.Input("x: T")
.Input("y: T")
.Output("z: T")
.Attr("T: {bfloat16, half, float, double}")
.SetShapeFn(shape_inference::UnchangedShape)
.SetIsAggregate()
.SetIsCommutative();
} // namespace tensorflow
......@@ -3359,6 +3359,16 @@ tf_gen_op_wrapper_private_py(
visibility = ["//tensorflow/python/ops/ragged:__pkg__"],
)
tf_gen_op_wrapper_private_py(
name = "risc_ops_gen",
visibility = [
"//tensorflow/python/ops/risc:__pkg__",
],
deps = [
"//tensorflow/core:risc_ops_op_lib",
],
)
tf_gen_op_wrapper_private_py(name = "rnn_ops_gen")
tf_gen_op_wrapper_private_py(
......
package(
default_visibility = ["//tensorflow:internal"],
licenses = ["notice"], # Apache 2.0
)
py_library(
name = "risc_grad",
srcs = ["risc_grad.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:framework_ops",
],
)
py_library(
name = "risc_ops",
srcs = ["risc_ops.py"],
srcs_version = "PY2AND3",
deps = [
"//tensorflow/python:risc_ops_gen",
],
)
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RISC operation gradient."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.framework import ops
@ops.RegisterGradient("RiscAdd")
def _RiscAddGrad(_, grad):
# pylint: disable=unused-argument
# TODO(b/171294012): Implement gradient of RISC with RISC ops.
return None, None
# Copyright 2020 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""RISC Operations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from tensorflow.python.ops import gen_risc_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.risc_ops_gen import *
# pylint: enable=wildcard-import
def risc_add(
input_lhs,
input_rhs,
name="RISC_ADD"):
return gen_risc_ops.risc_add(input_lhs, input_rhs, name=name)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册