提交 d4ea993c 编写于 作者: A Alexandre Passos 提交者: TensorFlower Gardener

Removes unnecessary eager-mode call to convert_to_tensor in record_gradient.

PiperOrigin-RevId: 170944265
上级 add6d2d0
......@@ -524,7 +524,7 @@ _grad_fn_accepts_none_for_indices = {
}
def _record_gradient(op_name, inputs, attrs, results, ctx, name):
def _record_gradient(op_name, inputs, attrs, results, name):
"""Records gradients for a TensorFlow operation.
Args:
......@@ -534,7 +534,6 @@ def _record_gradient(op_name, inputs, attrs, results, ctx, name):
attrs: A tuple with alternating string attr names and attr values for this
operation.
results: The results of the operation (as a flat list).
ctx: The value of context.context().
name: Customized name for the operation.
Returns:
......@@ -572,7 +571,6 @@ def _record_gradient(op_name, inputs, attrs, results, ctx, name):
"output_grads", orig_outputs, "gradients", result)
return result
inputs = [ops.internal_convert_to_tensor(x, ctx=ctx) for x in inputs]
tape.record_operation(op_name, results, inputs, [], grad_fn)
if _tracing:
print("Computed op", (name if name else op_name), "inputs", inputs,
......
......@@ -84,7 +84,7 @@ def execute(op_name, num_outputs, inputs, attrs, ctx, name=None):
def record_gradient(unused_op_name, unused_inputs, unused_attrs, unused_results,
unused_ctx, unused_name):
unused_name):
"""Import backprop if you want gradients recorded."""
pass
......
......@@ -412,7 +412,7 @@ string GenEagerPythonOp::Code() {
" if not _result:\n"
" return _op\n");
}
strings::StrAppend(&result_, " _inputs_flat = ", inputs, "\n");
strings::StrAppend(&result_, " _inputs_flat = _op.inputs\n");
// Compute graph-mode attrs.
if (op_def_.attr_size() > 0) {
......@@ -511,7 +511,7 @@ string GenEagerPythonOp::Code() {
if (num_outs_ > 0) {
strings::StrAppend(&result_, " _execute.record_gradient(\n", " \"",
op_def_.name(),
"\", _inputs_flat, _attrs, _result, _ctx, name)\n");
"\", _inputs_flat, _attrs, _result, name)\n");
if (num_outs_ == 1 && !output_sizes[0].empty()) {
// Single list result.
} else if (num_outs_ == 1) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册