Commit 7e8135cc authored by Vijay Vasudevan, committed by TensorFlower Gardener

Expose fake quantization operators.

Add shape functions for the FakeQuant gradient ops.

Add a test for one of the more complicated ones.
Change: 138087659
Parent 80ea8adf
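For context, the newly exposed operators can be exercised from Python roughly like this (a minimal sketch assuming the TF 1.x graph API of this era; shapes and the quantization range are illustrative):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 3])
# Quantize-and-dequantize x into 8-bit levels over the fixed range [-6, 6];
# values outside the range are clamped.
y = tf.fake_quant_with_min_max_args(x, min=-6.0, max=6.0)
```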
@@ -4486,6 +4486,7 @@ REGISTER_OP("FakeQuantWithMinMaxArgsGradient")
.Input("gradients: float")
.Input("inputs: float")
.Output("backprops: float")
.SetShapeFn(shape_inference::UnchangedShape)
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxArgs operation.
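`shape_inference::UnchangedShape` simply forwards the shape of the first input to the output, so `backprops` always takes the shape of `gradients`. A rough Python-level check of the effect (hypothetical shapes, TF 1.x graph API assumed):

```python
import tensorflow as tf

g = tf.placeholder(tf.float32, [5, 7])   # incoming gradients
x = tf.placeholder(tf.float32, [5, 7])   # original inputs
backprops = tf.fake_quant_with_min_max_args_gradient(g, x)
print(backprops.shape)  # (5, 7) -- unchanged from the gradients input
```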
@@ -4527,6 +4528,21 @@ REGISTER_OP("FakeQuantWithMinMaxVarsGradient")
.Output("backprops_wrt_input: float")
.Output("backprop_wrt_min: float")
.Output("backprop_wrt_max: float")
.SetShapeFn([](InferenceContext* c) {
// gradients and inputs must have the same shape.
ShapeHandle inputs;
TF_RETURN_IF_ERROR(c->Merge(c->input(0), c->input(1), &inputs));
// min and max must be scalars.
ShapeHandle min_max;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 0, &min_max));
TF_RETURN_IF_ERROR(c->Merge(min_max, c->input(3), &min_max));
c->set_output(0, inputs);
c->set_output(1, min_max);
c->set_output(2, min_max);
return Status::OK();
})
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxVars operation.
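The lambda above merges the `gradients` and `inputs` shapes and constrains `min`/`max` to scalars, so the three outputs come out as (input shape, scalar, scalar). A Python-level illustration (hypothetical shapes, TF 1.x graph API assumed):

```python
import tensorflow as tf

g = tf.placeholder(tf.float32, [2, 3])
x = tf.placeholder(tf.float32, [2, 3])
mn = tf.placeholder(tf.float32, [])  # scalar min
mx = tf.placeholder(tf.float32, [])  # scalar max
d_x, d_min, d_max = tf.fake_quant_with_min_max_vars_gradient(g, x, mn, mx)
print(d_x.shape)    # (2, 3) -- merged shape of gradients and inputs
print(d_min.shape)  # ()     -- scalar, like min
print(d_max.shape)  # ()     -- scalar, like max
```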
@@ -4580,6 +4596,24 @@ REGISTER_OP("FakeQuantWithMinMaxVarsPerChannelGradient")
.Output("backprops_wrt_input: float")
.Output("backprop_wrt_min: float")
.Output("backprop_wrt_max: float")
.SetShapeFn([](InferenceContext* c) {
// gradients and inputs must agree in shape, with rank between 1 and 4.
ShapeHandle inputs;
TF_RETURN_IF_ERROR(c->WithRankAtLeast(c->input(0), 1, &inputs));
TF_RETURN_IF_ERROR(c->WithRankAtMost(inputs, 4, &inputs));
TF_RETURN_IF_ERROR(c->Merge(inputs, c->input(1), &inputs));
// min and max are vectors matching the last dimension of inputs.
ShapeHandle last_dim = c->Vector(c->Dim(inputs, -1));
ShapeHandle min_max;
TF_RETURN_IF_ERROR(c->WithRank(c->input(2), 1, &min_max));
TF_RETURN_IF_ERROR(c->Merge(min_max, last_dim, &min_max));
TF_RETURN_IF_ERROR(c->Merge(c->input(3), min_max, &min_max));
c->set_output(0, inputs);
c->set_output(1, min_max);
c->set_output(2, min_max);
return Status::OK();
})
.Doc(R"doc(
Compute gradients for a FakeQuantWithMinMaxVarsPerChannel operation.
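Here `inputs` is constrained to rank 1 through 4 and `min`/`max` to rank-1 vectors whose length must merge with the last dimension of `inputs`; the outputs therefore mirror (inputs, per-channel vector, per-channel vector). A sketch under the same TF 1.x assumptions:

```python
import tensorflow as tf

g = tf.placeholder(tf.float32, [8, 4])
x = tf.placeholder(tf.float32, [8, 4])
mn = tf.placeholder(tf.float32, [4])  # one min per channel (last dim of x)
mx = tf.placeholder(tf.float32, [4])  # one max per channel
d_x, d_min, d_max = tf.fake_quant_with_min_max_vars_per_channel_gradient(
    g, x, mn, mx)
print(d_x.shape)    # (8, 4)
print(d_min.shape)  # (4,) -- matches the last dimension of inputs
```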
@@ -1533,4 +1533,23 @@ TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannel) {
INFER_ERROR("must be equal", op, "[5];[4];[?]");
}
TEST(ArrayOpsTest, FakeQuantWithMinMaxVarsPerChannelGradient) {
ShapeInferenceTestOp op("FakeQuantWithMinMaxVarsPerChannelGradient");
INFER_OK(op, "?;?;?;?", "?;[?];[?]");
INFER_OK(op, "[3];[3];[3];[3]", "in0;in3;in3");
INFER_OK(op, "[1,3];[1,3];[3];[3]", "in0;in3;in3");
INFER_OK(op, "[1,2,3,4];[1,2,3,4];[4];[4]", "in0;in3;in3");
// min and max must be rank-1 vectors.
INFER_ERROR("be equal rank", op, "[1,?,3];[1,?,3];[3];[]");
INFER_ERROR("be rank 1", op, "[1,?,3];[1,?,3];[];[3]");
// inputs must have rank between 1 and 4.
INFER_ERROR("be at least rank 1", op, "[];[];[1];[1]");
INFER_ERROR("be at most rank 4", op, "[1,2,3,4,5];[1,2,3,4,5];[1];[1]");
// Vectors must match each other, and match last dim of input.
INFER_ERROR("must be equal", op, "[1,3];[1,3];[2];[3]");
INFER_ERROR("must be equal", op, "[1,3];[1,3];[3];[2]");
}
} // end namespace tensorflow
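The INFER_ERROR cases above surface to Python users as graph-construction failures; for example, a min vector that does not match the channel count is rejected before the graph ever runs (hypothetical shapes; the exact message text may differ):

```python
import tensorflow as tf

g = tf.placeholder(tf.float32, [1, 3])
x = tf.placeholder(tf.float32, [1, 3])
mn = tf.placeholder(tf.float32, [2])  # wrong: last dim of inputs is 3
mx = tf.placeholder(tf.float32, [3])
try:
    tf.fake_quant_with_min_max_vars_per_channel_gradient(g, x, mn, mx)
except ValueError as err:
    print(err)  # shape-inference error, e.g. "Dimensions must be equal ..."
```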
@@ -82,6 +82,15 @@ or join multiple tensors together.
@@quantized_concat
@@setdiff1d
## Fake quantization

Operations used to help train models for better quantization accuracy.

@@fake_quant_with_min_max_args
@@fake_quant_with_min_max_args_gradient
@@fake_quant_with_min_max_vars
@@fake_quant_with_min_max_vars_gradient
@@fake_quant_with_min_max_vars_per_channel
@@fake_quant_with_min_max_vars_per_channel_gradient
"""
from __future__ import absolute_import
from __future__ import division
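With the ops listed in the module docstring and the gradient registration below, a fake-quantized layer can be trained end to end; a minimal sketch (TF 1.x graph API assumed, names illustrative):

```python
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 4])
mn = tf.Variable(-1.0)  # trainable quantization range
mx = tf.Variable(1.0)
# Forward pass simulates 8-bit quantization over [mn, mx].
y = tf.fake_quant_with_min_max_vars(x, mn, mx)
loss = tf.reduce_sum(tf.square(y))
# tf.gradients routes through the FakeQuantWithMinMaxVars gradient
# registered below, yielding gradients w.r.t. x, mn, and mx.
d_x, d_mn, d_mx = tf.gradients(loss, [x, mn, mx])
```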
@@ -2044,9 +2053,15 @@ def _FakeQuantWithMinMaxArgsGradient(op, grad):
ops.RegisterShape("FakeQuantWithMinMaxArgs")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("FakeQuantWithMinMaxArgsGradient")(
common_shapes.call_cpp_shape_fn)
ops.RegisterShape("FakeQuantWithMinMaxVars")(common_shapes.call_cpp_shape_fn)
ops.RegisterShape("FakeQuantWithMinMaxVarsGradient")(
common_shapes.call_cpp_shape_fn)
ops.RegisterShape("FakeQuantWithMinMaxVarsPerChannel")(
common_shapes.call_cpp_shape_fn)
ops.RegisterShape("FakeQuantWithMinMaxVarsPerChannelGradient")(
common_shapes.call_cpp_shape_fn)
@ops.RegisterGradient("FakeQuantWithMinMaxVars")