From 13aef97043b732a59b0481486952895c713f54bf Mon Sep 17 00:00:00 2001 From: Adam Osewski Date: Thu, 31 Dec 2020 04:25:38 +0100 Subject: [PATCH] operator checkpoints for new attributes. (#29832) * Add operator checkpoints for new attributes. * Fix adding subsequent checkpoint to quantize op. --- paddle/fluid/operators/dequantize_op.cc | 8 ++++++++ paddle/fluid/operators/quantize_op.cc | 7 ++++++- paddle/fluid/operators/requantize_op.cc | 10 ++++++++++ 3 files changed, 24 insertions(+), 1 deletion(-) diff --git a/paddle/fluid/operators/dequantize_op.cc b/paddle/fluid/operators/dequantize_op.cc index 8c2aeb1f8e..876bd1199a 100644 --- a/paddle/fluid/operators/dequantize_op.cc +++ b/paddle/fluid/operators/dequantize_op.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. */ #include "paddle/fluid/operators/dequantize_op.h" +#include "paddle/fluid/framework/op_version_registry.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -44,3 +45,10 @@ void DeQuantOpMaker::Make() { namespace ops = paddle::operators; REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker); + +REGISTER_OP_VERSION(dequantize) + .AddCheckpoint( + R"ROC( Add a new attribute [Shift])ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "Shift", "Dequantize data to uint8 if provided non-zero value.", + 0.0f)); diff --git a/paddle/fluid/operators/quantize_op.cc b/paddle/fluid/operators/quantize_op.cc index f21243de83..951951253c 100644 --- a/paddle/fluid/operators/quantize_op.cc +++ b/paddle/fluid/operators/quantize_op.cc @@ -61,4 +61,9 @@ REGISTER_OP_VERSION(quantize) R"ROC( Add a new attribute [bfloat16])ROC", paddle::framework::compatible::OpVersionDesc().NewAttr( "bfloat16", "If true, float32 input is converted to bfloat16", - false)); + false)) + .AddCheckpoint( + R"ROC( Add a new attribute [Shift])ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "Shift", "Quantize data to uint8 if provided non-zero value.", + 0.0f)); diff --git a/paddle/fluid/operators/requantize_op.cc b/paddle/fluid/operators/requantize_op.cc index ea3058c5ae..2d87ae91fb 100644 --- a/paddle/fluid/operators/requantize_op.cc +++ b/paddle/fluid/operators/requantize_op.cc @@ -13,6 +13,7 @@ * limitations under the License. */ #include "paddle/fluid/operators/requantize_op.h" +#include "paddle/fluid/framework/op_version_registry.h" #ifdef PADDLE_WITH_MKLDNN #include "paddle/fluid/platform/mkldnn_helper.h" #endif @@ -46,3 +47,12 @@ void ReQuantOpMaker::Make() { namespace ops = paddle::operators; REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker); + +REGISTER_OP_VERSION(requantize) + .AddCheckpoint( + R"ROC( Add new attributes [Shift_in, Shift_out])ROC", + paddle::framework::compatible::OpVersionDesc() + .NewAttr("Shift_in", + "Provide quantization shift value for input data", 1.0f) + .NewAttr("Shift_out", + "Provide quantization shift value for output data", 1.0f)); -- GitLab