diff --git a/paddle/fluid/operators/assign_value_op.cc b/paddle/fluid/operators/assign_value_op.cc index 2985fc28a086eeaf3512405937dd6bb1cb44edcc..c711fd802ffb8c3afae5473791444d9b022a69ba 100644 --- a/paddle/fluid/operators/assign_value_op.cc +++ b/paddle/fluid/operators/assign_value_op.cc @@ -36,7 +36,7 @@ class AssignValueOp : public framework::OperatorWithKernel { framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { return framework::OpKernelType( - framework::proto::DataType(ctx.Attr("dtype")), ctx.GetPlace()); + framework::proto::VarType::Type(ctx.Attr("dtype")), ctx.GetPlace()); } }; @@ -49,8 +49,8 @@ class AssignValueOpMaker : public framework::OpProtoAndCheckerMaker { "(vector) " "Shape of values."); AddAttr("dtype", "data type of values") - .InEnum({framework::proto::DataType::INT32, - framework::proto::DataType::FP32}); + .InEnum({framework::proto::VarType::Type::INT32, + framework::proto::VarType::Type::FP32}); AddAttr>("fp32_values", "store the float values") .SetDefault({}); AddAttr>("int32_values", "store the int values") diff --git a/paddle/fluid/operators/assign_value_op.h b/paddle/fluid/operators/assign_value_op.h index 90c9496a3c12419abaceb29ff09cf4aaf388eee1..6a62ca53d4782ce68dae249aeeb3e514d32e6d6d 100644 --- a/paddle/fluid/operators/assign_value_op.h +++ b/paddle/fluid/operators/assign_value_op.h @@ -30,10 +30,10 @@ class AssignValueKernel : public framework::OpKernel { int dtype = ctx.Attr("dtype"); const char* value_name = nullptr; switch (dtype) { - case framework::proto::DataType::INT32: + case framework::proto::VarType::Type::INT32: value_name = "int32_values"; break; - case framework::proto::DataType::FP32: + case framework::proto::VarType::Type::FP32: value_name = "fp32_values"; break; default: diff --git a/paddle/fluid/pybind/protobuf.cc b/paddle/fluid/pybind/protobuf.cc index 9f97cc5007ec00d0ad28ba755d24896017c9003f..1b8fc1698b94907519388d75f8d98972a0e8f2f5 100644 --- a/paddle/fluid/pybind/protobuf.cc +++ b/paddle/fluid/pybind/protobuf.cc @@ -195,14 +195,14 @@ void BindBlockDesc(py::module &m) { } void BindVarDsec(py::module &m) { - py::enum_(m, "DataType", "") - .value("BOOL", proto::DataType::BOOL) - .value("INT16", proto::DataType::INT16) - .value("INT32", proto::DataType::INT32) - .value("INT64", proto::DataType::INT64) - .value("FP16", proto::DataType::FP16) - .value("FP32", proto::DataType::FP32) - .value("FP64", proto::DataType::FP64); + py::enum_(m, "DataType", "") + .value("BOOL", proto::VarType::Type::BOOL) + .value("INT16", proto::VarType::Type::INT16) + .value("INT32", proto::VarType::Type::INT32) + .value("INT64", proto::VarType::Type::INT64) + .value("FP16", proto::VarType::Type::FP16) + .value("FP32", proto::VarType::Type::FP32) + .value("FP64", proto::VarType::Type::FP64); py::class_ var_desc(m, "VarDesc", ""); var_desc