Commit 22f80136 authored by Artem Belevich, committed by TensorFlower Gardener

Simplify gradient exclusions data to speed up compilation w/ clang on windows.

PiperOrigin-RevId: 294998085
Change-Id: Ie56b8f2cf4ed1e5fd8e2a641947b2d69f316e86a
Parent: c8f74b62
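
The heart of the change is the representation of the generated exclusion tables. The old generated file initialized one huge FlatMap<string, std::pair<bool, FlatSet<int>>> in a single brace-initialized expression, which is the kind of construct that is likely expensive for clang to compile; the new file stores the same data as a static array of the trivial OpIndexInfo aggregate and converts it to a FlatMap once at runtime. A minimal before/after sketch of the two layouts; the ops and index values shown are made-up examples, not entries from the real generated table:

    #include <array>

    // Before: one giant brace-initialized map expression. Every entry runs
    // string and FlatSet constructors, and the compiler must type-check and
    // emit the whole initializer as a single unit.
    //
    //   static auto *m = new tensorflow::gtl::FlatMap<
    //       string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
    //       {"Acosh", {true, {}}},      // example: all indices unused
    //       {"AddN", {false, {0, 1}}},  // example: indices 0 and 1 unused
    //   });

    // After: a flat array of a trivial aggregate. This is plain static data
    // with no constructors to run; the FlatMap is built once at first use.
    struct OpIndexInfo {
      const char *op_name;
      int num_indices;
      std::array<int, 4> unused_indices;
    };
    static std::array<OpIndexInfo, 2> kExamples = {{
        {"Acosh"},            // num_indices == 0: all indices unused
        {"AddN", 2, {0, 1}},  // indices 0 and 1 unused
    }};

The std::array<int, 4> bound implies that no generated entry needs more than four unused indices; that is an inference from the struct definition in this diff, not something the commit states.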
@@ -54,6 +54,7 @@ cc_library(
         "@com_google_absl//absl/container:flat_hash_map",
         "@com_google_absl//absl/hash",
         "@com_google_absl//absl/strings",
+        "@com_google_absl//absl/types:optional",
         "@com_google_absl//absl/types:variant",
     ],
 )
@@ -63,10 +63,33 @@ limitations under the License.
 _INCLUDES = """
 #include "tensorflow/python/eager/pywrap_gradient_exclusions.h"
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 
 using tensorflow::string;
+
+namespace {
+// Keep static data in a format that's easy to init statically.
+struct OpIndexInfo {
+  const char *op_name;
+  int num_indices;
+  std::array<int, 4> unused_indices;
+};
+
+// Helper function to initialize FlatMap<string,FlatSet> from OpIndexInfo.
+template <typename T>
+auto OpGradientInfoInit(const T &a) {
+  auto *m = new tensorflow::gtl::FlatMap<string, tensorflow::gtl::FlatSet<int>>;
+  for (const auto &item : a) {
+    m->emplace(string(item.op_name),
+               tensorflow::gtl::FlatSet<int>(
+                   item.unused_indices.begin(),
+                   item.unused_indices.begin() + item.num_indices));
+  }
+  return m;
+}
+}  // namespace
 """
 
 _EXCLUDED_OPS = [
@@ -281,7 +304,6 @@ def get_entries(attr_name):
   """
   assert attr_name in ["inputs", "outputs"]
   entries = {}
-  spaces = "          "
   for op_type in ops._gradient_registry.list():  # pylint: disable=protected-access
     if op_type in _EXCLUDED_OPS:
       continue
@@ -291,72 +313,57 @@ def get_entries(attr_name):
     if gradient_fn is None:
       # NotDifferentiable
       if num_values != -1:
-        entries[op_type] = spaces + "{\"%s\", {true, {}}}," % op_type
+        entries[op_type] = "{\"%s\"}," % op_type
       continue
 
     used_tensors = _live_tensors(gradient_fn, attr_name=attr_name)
     if used_tensors is _ALL:
       continue
     elif not used_tensors:
-      entries[op_type] = spaces + "{\"%s\", {true, {}}}," % op_type
+      entries[op_type] = "{\"%s\"}," % op_type
     else:
       all_tensors = set(range(num_values))
       unused_tensors = all_tensors - used_tensors
       if unused_tensors:
-        entries[op_type] = spaces + "{\"%s\", {false, {%s}}}," % (
-            op_type, ", ".join(str(i) for i in sorted(list(unused_tensors))))
+        unused_tensor_list = sorted(list(unused_tensors))
+        entries[op_type] = "{\"%s\", %d, {%s}}," % (
+            op_type, len(unused_tensor_list), ", ".join(
+                str(i) for i in unused_tensor_list))
   return entries
 
 
+def get_function(name, entries):
+  """Generates lookup function with given name and lookup table entries."""
+  contents = """
+absl::optional<tensorflow::gtl::FlatSet<int>> {name}(
+    const tensorflow::string &op_name) {{
+  static std::array<OpIndexInfo, {count}> a = {{{{
+""".format(
+      name=name, count=len(entries) + 1)
+  contents += "      "
+  contents += "\n      ".join(entries[op_type] for op_type in sorted(entries))
+  contents += "\n      {\"VarHandleOp\"},"
+  contents += """
+  }};
+  static const auto &m = *OpGradientInfoInit(a);
+
+  auto it = m.find(op_name);
+  if (it != m.end()) {
+    return it->second;
+  }
+  return absl::nullopt;
+}
+"""
+  return contents
+
+
 def get_contents():
   """Returns contents for the generated file."""
   contents = ""
   contents += _GENERATED_FILE_HEADER + _INCLUDES
-  contents += """
-bool OpGradientDoesntRequireInputIndices(
-    const string& op_name,
-    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
-  static tensorflow::gtl::FlatMap<
-      string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
-      new tensorflow::gtl::FlatMap<
-          string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
-"""
-  entries = get_entries("inputs")
-  contents += "\n".join(entries[op_type] for op_type in sorted(entries))
-  contents += "\n          {\"VarHandleOp\", {true, {}}},\n"
-  contents += """      });
-
-  auto it = m->find(op_name);
-  if (it == m->end()) return false;
-
-  *output = &it->second;
-  return true;
-}
-"""
-  contents += """
-bool OpGradientDoesntRequireOutputIndices(
-    const string& op_name,
-    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output) {
-  static tensorflow::gtl::FlatMap<
-      string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>* m =
-      new tensorflow::gtl::FlatMap<
-          string, std::pair<bool, tensorflow::gtl::FlatSet<int>>>({
-"""
-  entries = get_entries("outputs")
-  contents += "\n".join(entries[op_type] for op_type in sorted(entries))
-  contents += "\n          {\"VarHandleOp\", {true, {}}},\n"
-  contents += """      });
-
-  auto it = m->find(op_name);
-  if (it == m->end()) return false;
-
-  *output = &it->second;
-  return true;
-}
-"""
+  contents += get_function("OpGradientUnusedInputIndices",
+                           get_entries("inputs"))
+  contents += get_function("OpGradientUnusedOutputIndices",
+                           get_entries("outputs"))
   return contents
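For concreteness, the template above expands to generated code along the following lines. The two example entries and the array size are hypothetical placeholders (the real table lists every registered op with exclusions, plus the hard-coded VarHandleOp entry), and OpIndexInfo and OpGradientInfoInit come from the _INCLUDES prelude shown earlier:

    absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
        const tensorflow::string &op_name) {
      // Hypothetical entries; the generated table covers hundreds of ops.
      static std::array<OpIndexInfo, 3> a = {{
          {"Acosh"},            // empty set: all inputs unused
          {"AddN", 2, {0, 1}},  // inputs 0 and 1 unused
          {"VarHandleOp"},
      }};
      // Built once, on first lookup, from the static array above.
      static const auto &m = *OpGradientInfoInit(a);

      auto it = m.find(op_name);
      if (it != m.end()) {
        return it->second;
      }
      return absl::nullopt;
    }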
@@ -15,15 +15,24 @@ limitations under the License.
 #ifndef TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_
 #define TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_
 
+#include "absl/types/optional.h"
 #include "tensorflow/core/lib/gtl/flatmap.h"
 #include "tensorflow/core/lib/gtl/flatset.h"
 
-bool OpGradientDoesntRequireInputIndices(
-    const tensorflow::string& op_name,
-    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output);
+// Looks up whether the op with the given op_name has unused input indices.
+// Returns absl::nullopt if all inputs are used; otherwise returns the set of
+// unused input indices. An empty set means all indices are unused; this case
+// exists because it is not always possible to enumerate all indices from the
+// OpDef alone, e.g. for ops with `list(T)` or `N * T` inputs.
+absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedInputIndices(
+    const tensorflow::string& op_name);
 
-bool OpGradientDoesntRequireOutputIndices(
-    const tensorflow::string& op_name,
-    std::pair<bool, tensorflow::gtl::FlatSet<int>>** output);
+// Looks up whether the op with the given op_name has unused output indices.
+// Returns absl::nullopt if all outputs are used; otherwise returns the set of
+// unused output indices. An empty set means all indices are unused; this case
+// exists because it is not always possible to enumerate all indices from the
+// OpDef alone, e.g. for ops with `list(T)` or `N * T` outputs.
+absl::optional<tensorflow::gtl::FlatSet<int>> OpGradientUnusedOutputIndices(
+    const tensorflow::string& op_name);
 
 #endif  // TENSORFLOW_PYTHON_EAGER_PYWRAP_GRADIENT_EXCLUSIONS_H_
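
The return value is effectively tri-state, and call sites must distinguish all three outcomes. A minimal usage sketch against this header; "SomeOp" is a placeholder op name:

    if (const auto unused = OpGradientUnusedInputIndices("SomeOp")) {
      if (unused->empty()) {
        // Sentinel: every input is unused by the gradient; drop them all.
      } else {
        // Only the indices in *unused can be dropped.
      }
    } else {
      // absl::nullopt: all inputs may be needed by the gradient function.
    }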
@@ -2944,15 +2944,15 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   PyObject* op_outputs;
   bool op_outputs_tuple_created = false;
-  std::pair<bool, tensorflow::gtl::FlatSet<int>>* outputs_not_required;
-  if (OpGradientDoesntRequireOutputIndices(c_op_name, &outputs_not_required)) {
-    if (outputs_not_required->first) {
+  if (const auto unused_output_indices =
+          OpGradientUnusedOutputIndices(c_op_name)) {
+    if (unused_output_indices->empty()) {
       op_outputs = Py_None;
     } else {
       op_outputs_tuple_created = true;
-      op_outputs = CopySequenceSettingIndicesToNull(
-          results, outputs_not_required->second);
+      op_outputs =
+          CopySequenceSettingIndicesToNull(results, *unused_output_indices);
     }
   } else {
     op_outputs = results;
@@ -2960,15 +2960,15 @@ PyObject* RecordGradient(PyObject* op_name, PyObject* inputs, PyObject* attrs,
   PyObject* op_inputs;
   bool op_inputs_tuple_created = false;
-  std::pair<bool, tensorflow::gtl::FlatSet<int>>* inputs_not_required;
-  if (OpGradientDoesntRequireInputIndices(c_op_name, &inputs_not_required)) {
-    if (inputs_not_required->first) {
+  if (const auto unused_input_indices =
+          OpGradientUnusedInputIndices(c_op_name)) {
+    if (unused_input_indices->empty()) {
       op_inputs = Py_None;
     } else {
       op_inputs_tuple_created = true;
       op_inputs =
-          CopySequenceSettingIndicesToNull(inputs, inputs_not_required->second);
+          CopySequenceSettingIndicesToNull(inputs, *unused_input_indices);
     }
   } else {
     op_inputs = inputs;