未验证 提交 13aef970 编写于 作者: A Adam Osewski 提交者: GitHub

operator checkpoints for new attributes. (#29832)

* Add operator checkpoints for new attributes.

* Fix adding subsequent checkpoint to quantize op.
上级 844d8e0c
...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and ...@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/dequantize_op.h" #include "paddle/fluid/operators/dequantize_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
...@@ -44,3 +45,10 @@ void DeQuantOpMaker::Make() { ...@@ -44,3 +45,10 @@ void DeQuantOpMaker::Make() {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker); REGISTER_OPERATOR(dequantize, ops::DeQuantOp, ops::DeQuantOpMaker);
REGISTER_OP_VERSION(dequantize)
.AddCheckpoint(
R"ROC( Add a new attribute [Shift])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Shift", "Dequantize data to uint8 if provided non-zero value.",
0.0f));
...@@ -61,4 +61,9 @@ REGISTER_OP_VERSION(quantize) ...@@ -61,4 +61,9 @@ REGISTER_OP_VERSION(quantize)
R"ROC( Add a new attribute [bfloat16])ROC", R"ROC( Add a new attribute [bfloat16])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr( paddle::framework::compatible::OpVersionDesc().NewAttr(
"bfloat16", "If true, float32 input is converted to bfloat16", "bfloat16", "If true, float32 input is converted to bfloat16",
false)); false))
.AddCheckpoint(
R"ROC( Add a new attribute [Shift])ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"Shift", "Quantize data to uint8 if provided non-zero value.",
0.0f));
...@@ -13,6 +13,7 @@ ...@@ -13,6 +13,7 @@
* limitations under the License. */ * limitations under the License. */
#include "paddle/fluid/operators/requantize_op.h" #include "paddle/fluid/operators/requantize_op.h"
#include "paddle/fluid/framework/op_version_registry.h"
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
#include "paddle/fluid/platform/mkldnn_helper.h" #include "paddle/fluid/platform/mkldnn_helper.h"
#endif #endif
...@@ -46,3 +47,12 @@ void ReQuantOpMaker::Make() { ...@@ -46,3 +47,12 @@ void ReQuantOpMaker::Make() {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker); REGISTER_OPERATOR(requantize, ops::ReQuantOp, ops::ReQuantOpMaker);
REGISTER_OP_VERSION(requantize)
.AddCheckpoint(
R"ROC( Add new attributes [Shift_in, Shift_out])ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("Shift_in",
"Provide quantization shift value for input data", 1.0f)
.NewAttr("Shift_out",
"Provide quantization shift value for output data", 1.0f));
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册