Unverified commit dd2a749a, authored by Feiyu Chan and committed by GitHub

1. Modify the set_value op to use Scalars to represent the attribute `values`, instead of a bunch of attributes of various types. (#52408)

2. Add a program converter, with the set_value op as an example, which provides the functionality to convert `paddle::framework::ProgramDesc` between the old and new formats (the differences are mainly a few operators whose definitions received incompatible updates).
3. The program version and the operator version map are now always saved when serializing `paddle::framework::ProgramDesc`, so that the version can be identified.
4. Provide an option `legacy_format=false` in the serialization of `paddle::framework::ProgramDesc`; it decides whether the ProgramDesc is converted back to a legacy format that Paddle 2.4.2 and earlier versions can load and execute (a usage sketch follows below).
5. Deserialization of `paddle::framework::ProgramDesc` now automatically detects whether the bytes it receives are in the legacy format (i.e. whether they contain any of the operators that received an incompatible update and now take an attribute of type `Scalar`) and converts them to the new format. If you want a faithful deserialization without the automatic conversion, you can use protobuf's deserialization instead; this is not recommended, but it can be used for testing.
Parent commit b281b221
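A minimal usage sketch, condensed from the unit test added in this commit (`_get_desc()` is the private helper the test uses, and the numeric values are illustrative; treat this as a sketch rather than a public-API reference):

import numpy as np
import paddle
from paddle.fluid.proto import framework_pb2

paddle.enable_static()
main_prog, startup_prog = paddle.static.Program(), paddle.static.Program()
with paddle.static.program_guard(main_prog, startup_prog):
    x = paddle.ones([3, 4], dtype=paddle.float32)
    x[:1, :2] = np.array([41.0, 42.0]).astype(np.float32)  # lowered to a set_value op

# New format: set_value carries a single attribute `values` of type vector<Scalar>.
new_bytes = main_prog._get_desc().serialize_to_string()
# Legacy (2.4.2-compatible) format: plain fp32_values/int32_values/... attributes.
legacy_bytes = main_prog._get_desc().serialize_to_string(legacy_format=True)

# Deserialization detects the legacy format and converts it back automatically.
program = paddle.static.io.deserialize_program(legacy_bytes)

# For a faithful view of the bytes without the automatic conversion,
# parse them with protobuf directly (useful for testing only).
raw_desc = framework_pb2.ProgramDesc.FromString(bytes(legacy_bytes))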
@@ -532,7 +532,7 @@ cc_test(
cc_library(
proto_desc
SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc program_converter.cc
DEPS attribute
ops_extra_info
shape_inference
@@ -542,7 +542,9 @@ cc_library(
version
xxhash
dist_attr
scalar)
scalar
op_version_proto
op_version_registry)
cc_library(
op_registry
......
@@ -13,3 +13,96 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_proto.h"
namespace paddle {
namespace framework {
namespace compatible {
namespace pb {
const std::unordered_map<std::string, uint32_t>& GetLegacyOpVersions() {
static std::unordered_map<std::string, uint32_t> op_versions = {
{"not_equal", 1},
{"fake_channel_wise_dequantize_max_abs", 2},
{"yolo_box", 1},
{"data_norm", 1},
{"cumsum", 1},
{"fake_channel_wise_quantize_abs_max", 1},
{"greater_equal", 1},
{"fill_constant", 2},
{"conv_transpose", 1},
{"fusion_gru", 1},
{"flip", 1},
{"elementwise_sub", 1},
{"dequantize", 1},
{"grid_sampler", 1},
{"expand_as_v2", 1},
{"linspace", 1},
{"moving_average_abs_max_scale", 2},
{"p_norm", 1},
{"instance_norm", 1},
{"lookup_table_v2", 1},
{"seed", 1},
{"softmax_with_cross_entropy", 1},
{"rank_attention", 1},
{"cudnn_lstm", 1},
{"clip", 1},
{"requantize", 1},
{"for_pybind_test__", 4},
{"print", 1},
{"transfer_layout", 1},
{"arg_min", 1},
{"roll", 2},
{"roi_pool", 2},
{"conv2d_transpose", 2},
{"roi_align", 3},
{"softplus", 1},
{"momentum", 1},
{"trace", 1},
{"matmul", 1},
{"lookup_table", 1},
{"lstsq", 1},
{"conv3d_transpose", 1},
{"depthwise_conv2d_transpose", 1},
{"conv2d", 1},
{"lamb", 1},
{"send_and_recv", 1},
{"gaussian_random", 1},
{"unique_consecutive", 1},
{"conv3d", 1},
{"pixel_shuffle", 1},
{"collect_fpn_proposals", 1},
{"coalesce_tensor", 2},
{"arg_max", 1},
{"allclose", 2},
{"matrix_nms", 1},
{"less_than", 1},
{"affine_grid", 1},
{"hard_shrink", 1},
{"set_value", 3},
{"mish", 1},
{"quantize", 2},
{"distribute_fpn_proposals", 2},
{"adam", 4},
{"elementwise_pow", 1},
{"elementwise_mul", 1},
{"elementwise_mod", 1},
{"auc", 1},
{"elementwise_min", 1},
{"elementwise_max", 1},
{"gather", 1},
{"elementwise_div", 1},
{"elementwise_add", 1},
{"leaky_relu", 1},
{"generate_proposal_labels", 2},
{"elementwise_floordiv", 1},
{"less_equal", 1},
{"generate_proposals", 2},
{"depthwise_conv2d", 1},
{"greater_than", 1},
{"generate_proposals_v2", 1},
{"equal", 1}};
return op_versions;
}
} // namespace pb
} // namespace compatible
} // namespace framework
} // namespace paddle
@@ -17,6 +17,7 @@ limitations under the License. */
#include <stdint.h>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/framework.pb.h"
@@ -53,6 +54,10 @@ class OpVersionMap {
proto::OpVersionMap* desc_;
};
// Get the version ids of operators that carried a version id in paddle 2.4.2.
// This is used when converting a ProgramDesc to the 2.4-compatible format.
const std::unordered_map<std::string, uint32_t>& GetLegacyOpVersions();
} // namespace pb
} // namespace compatible
} // namespace framework
......
@@ -264,6 +264,13 @@ inline void SaveOpVersions(
}
}
inline void SaveOpVersions(const std::unordered_map<std::string, uint32_t>& src,
pb::OpVersionMap* dst) {
for (const auto& pair : src) {
(*dst)[pair.first].SetVersionID(pair.second);
}
}
class OpVersionComparator {
public:
virtual bool operator()() = 0;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/program_converter.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/version.h"
namespace paddle {
namespace framework {
using paddle::experimental::ExtractPlainVector;
using paddle::experimental::WrapAsScalars;
std::pair<bool, std::unordered_map<std::string, uint32_t>> DetectLegacyOps(
ProgramDesc* program) {
bool is_legacy_program = false;
std::unordered_map<std::string, uint32_t> legacy_op_versions;
std::unordered_map<std::string, uint32_t> current_op_versions;
std::unordered_map<std::string, uint32_t> program_op_versions;
// Normalize the various formats of op versions and the op version map into a
// unified representation so that they can be compared directly.
if (!program->HasOpVersionMap()) {
is_legacy_program = true;
} else {
for (const auto& pair :
paddle::framework::compatible::get_op_version_map()) {
current_op_versions.insert(
std::make_pair(pair.first, pair.second.version_id()));
}
const auto* _op_version_map = program->OpVersionMap();
for (int i = 0; i < _op_version_map->pair_size(); ++i) {
auto pair =
std::make_pair(_op_version_map->pair(i).op_name(),
static_cast<uint32_t>(
_op_version_map->pair(i).op_version().version()));
program_op_versions.insert(pair);
}
for (const auto& pair : program_op_versions) {
uint32_t program_op_version = pair.second;
if (!current_op_versions.count(pair.first)) {
// This means program_op_versions is newer than current_op_versions: we are
// loading a program produced by a future version of paddle.
continue;
}
uint32_t current_op_version = current_op_versions.at(pair.first);
if (program_op_version < current_op_version) {
is_legacy_program = true;
legacy_op_versions.insert(
std::make_pair(pair.first, program_op_version));
}
}
}
return std::make_pair(is_legacy_program, legacy_op_versions);
}
namespace no_scalar {
void ConvertSetValueOp(OpDesc* op) {
std::vector<paddle::experimental::Scalar> values = PADDLE_GET_CONST(
std::vector<paddle::experimental::Scalar>, op->GetAttr("values", false));
op->RemoveAttr("values");
op->SetAttr("bool_values", std::vector<int>());
op->SetAttr("fp32_values", std::vector<float>());
op->SetAttr("int32_values", std::vector<int>());
op->SetAttr("int64_values", std::vector<int64_t>());
op->SetAttr("fp64_values", std::vector<double>());
op->SetAttr("fp16_values", std::vector<float>());
phi::DataType dtype = phi::DataType::FLOAT32;
if (values.size()) {
dtype = values.at(0).dtype();
}
switch (dtype) {
case phi::DataType::BOOL:
op->SetAttr("bool_values", ExtractPlainVector<int>(values));
break;
case phi::DataType::FLOAT32:
op->SetAttr("fp32_values", ExtractPlainVector<float>(values));
break;
case phi::DataType::INT32:
op->SetAttr("int32_values", ExtractPlainVector<int>(values));
break;
case phi::DataType::INT64:
op->SetAttr("int64_values", ExtractPlainVector<int64_t>(values));
break;
case phi::DataType::FLOAT64:
op->SetAttr("fp64_values", ExtractPlainVector<double>(values));
break;
case phi::DataType::FLOAT16:
op->SetAttr("fp16_values", ExtractPlainVector<float>(values));
break;
default:
PD_THROW("Invalid data type `", dtype, "`.");
}
}
void ConvertProgram(ProgramDesc* program) {
PADDLE_ENFORCE_NOT_NULL(
program,
paddle::platform::errors::InvalidArgument("program should not be null"));
VLOG(3) << "Setting Program Version and OpVersionMap to Legacy "
"settings(a.k.a 2.4.2)";
framework::compatible::pb::OpVersionMap op_version_map(
program->OpVersionMap());
program->SetVersion(paddle::framework::kLegacyProgramVersion);
paddle::framework::compatible::SaveOpVersions(
paddle::framework::compatible::pb::GetLegacyOpVersions(),
&op_version_map);
VLOG(3) << "Converting program from new(with scalar attributes) to old(no "
"scalar attributes)";
const size_t num_blocks = program->Size();
for (size_t i = 0; i < num_blocks; i++) {
BlockDesc* block = program->MutableBlock(i);
const size_t num_ops = block->OpSize();
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(j);
const std::string op_type = op->Type();
if (op_type == "set_value" || op_type == "set_value_grad") {
ConvertSetValueOp(op);
}
}
}
}
} // namespace no_scalar
namespace scalar {
void ConvertSetValueOp(OpDesc* op) {
std::vector<paddle::experimental::Scalar> values;
if (op->HasAttr("bool_values")) {
std::vector<int> bool_values =
PADDLE_GET_CONST(std::vector<int>, op->GetAttr("bool_values", false));
if (bool_values.size()) {
values = WrapAsScalars(bool_values);
}
op->RemoveAttr("bool_values");
}
if (op->HasAttr("fp32_values")) {
std::vector<float> fp32_values =
PADDLE_GET_CONST(std::vector<float>, op->GetAttr("fp32_values", false));
if (fp32_values.size()) {
values = WrapAsScalars(fp32_values);
}
op->RemoveAttr("fp32_values");
}
if (op->HasAttr("int32_values")) {
std::vector<int> int32_values =
PADDLE_GET_CONST(std::vector<int>, op->GetAttr("int32_values", false));
if (int32_values.size()) {
values = WrapAsScalars(int32_values);
}
op->RemoveAttr("int32_values");
}
if (op->HasAttr("int64_values")) {
std::vector<int64_t> int64_values = PADDLE_GET_CONST(
std::vector<int64_t>, op->GetAttr("int64_values", false));
if (int64_values.size()) {
values = WrapAsScalars(int64_values);
}
op->RemoveAttr("int64_values");
}
if (op->HasAttr("fp64_values")) {
std::vector<double> fp64_values = PADDLE_GET_CONST(
std::vector<double>, op->GetAttr("fp64_values", false));
if (fp64_values.size()) {
values = WrapAsScalars(fp64_values);
}
op->RemoveAttr("fp64_values");
}
if (op->HasAttr("fp16_values")) {
std::vector<float> fp16_values =
PADDLE_GET_CONST(std::vector<float>, op->GetAttr("fp16_values", false));
if (fp16_values.size()) {
values = WrapAsScalars(fp16_values);
}
op->RemoveAttr("fp16_values");
}
op->SetAttr("values", values);
}
void ConvertProgram(ProgramDesc* program) {
PADDLE_ENFORCE_NOT_NULL(
program,
paddle::platform::errors::InvalidArgument("program should not be null"));
auto legacy_op_results = DetectLegacyOps(program);
bool is_legacy_program = legacy_op_results.first;
const std::unordered_map<std::string, uint32_t>& legacy_op_versions =
legacy_op_results.second;
if (!is_legacy_program) return;
VLOG(3) << "Updating Program Version and OpVersionMap";
program->SetVersion(paddle::framework::kCurProgramVersion);
framework::compatible::pb::OpVersionMap op_version_map(
program->OpVersionMap());
paddle::framework::compatible::SaveOpVersions(
framework::compatible::get_op_version_map(), &op_version_map);
VLOG(3) << "Converting program from old(no scalar attributes) to new(with "
"scalar attributes)";
const size_t num_blocks = program->Size();
for (size_t i = 0; i < num_blocks; i++) {
BlockDesc* block = program->MutableBlock(i);
const size_t num_ops = block->OpSize();
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(j);
const std::string op_type = op->Type();
if (!legacy_op_versions.count(op_type)) {
continue;
}
if (op_type == "set_value" || op_type == "set_value_grad") {
ConvertSetValueOp(op);
}
}
}
}
} // namespace scalar
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_desc.h"
namespace paddle {
namespace framework {
namespace no_scalar {
void ConvertProgram(ProgramDesc* program);
} // namespace no_scalar
namespace scalar {
void ConvertProgram(ProgramDesc* program);
} // namespace scalar
} // namespace framework
} // namespace paddle
@@ -20,6 +20,9 @@ extern "C" {
#include <algorithm>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_converter.h"
#include "paddle/fluid/framework/version.h"
namespace paddle {
@@ -48,8 +51,12 @@ proto::OpVersionMap *ProgramDesc::OpVersionMap() {
return desc_.mutable_op_version_map();
}
bool ProgramDesc::HasOpVersionMap() const { return desc_.has_op_version_map(); }
int64_t ProgramDesc::Version() const { return desc_.version().version(); }
bool ProgramDesc::HasVersion() const { return desc_.has_version(); }
void ProgramDesc::SetVersion(const int64_t version) {
desc_.mutable_version()->set_version(version);
}
@@ -142,6 +149,7 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
platform::errors::InvalidArgument(
"Failed to parse program_desc from binary string."));
InitFromProto();
scalar::ConvertProgram(this);
}
void ProgramDesc::InitFromProto() {
......
@@ -61,8 +61,12 @@ class ProgramDesc {
proto::OpVersionMap *OpVersionMap();
bool HasOpVersionMap() const;
int64_t Version() const;
bool HasVersion() const;
void SetVersion(const int64_t version);
// The output variable of feed_op is referenced as feed_target.
......
@@ -36,6 +36,13 @@ constexpr int64_t kCurProgramVersion = PADDLE_VERSION_INTEGER;
constexpr int64_t kCurProgramVersion = 0;
#endif
// Paddle 2.4.2 and earlier do not support Scalar in op attributes, so no op
// in a program produced by those versions has attributes of type `Scalar`.
// This version number is used to identify programs converted to the legacy
// format. See also paddle/fluid/framework/op_version_proto.cc.
constexpr int64_t kLegacyProgramVersion = 2004002L;
// The program version that was generated by previous or current codes
// and supported by current codes.
constexpr int64_t kSupportedProgramVersion[] = {0};
......
@@ -55,6 +55,43 @@ typename std::enable_if<!std::is_same<T, bool>::value>::type CopyVectorToTensor(
framework::TensorFromVector(values, ctx.device_context(), out);
}
template <typename T, typename Context>
typename std::enable_if<std::is_same<T, bool>::value>::type CopyVectorToTensor(
const Context& dev_ctx,
const std::vector<Scalar>& values,
phi::DenseTensor* out) {
// If the attribute's dtype is vector<bool>, it is converted to vector<int>.
// We cannot use vector<bool> to hold the values, because std::vector<bool>
// stores bits rather than addressable byte values.
std::vector<int> assign_values;
assign_values.reserve(values.size());
for (const auto& val : values) {
assign_values.emplace_back(val.to<int>());
}
phi::TensorFromVector(assign_values, dev_ctx, out);
// use a plain array instead of the vector
bool* array_ptr = new T[assign_values.size()];
for (unsigned int i = 0; i < assign_values.size(); i++) {
array_ptr[i] = static_cast<T>(assign_values[i]);
}
phi::TensorFromArray(array_ptr, assign_values.size(), dev_ctx, out);
delete[] array_ptr;
}
template <typename T, typename Context>
typename std::enable_if<!std::is_same<T, bool>::value>::type CopyVectorToTensor(
const Context& dev_ctx,
const std::vector<Scalar>& values,
phi::DenseTensor* out) {
std::vector<T> assign_values;
assign_values.reserve(values.size());
for (const auto& val : values) {
assign_values.emplace_back(val.to<T>());
}
phi::TensorFromVector(assign_values, dev_ctx, out);
}
template <typename T>
class AssignValueKernel : public framework::OpKernel<T> {
public:
......
@@ -110,7 +110,9 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
framework::proto::VarType::INT64,
framework::proto::VarType::FP32,
framework::proto::VarType::FP64,
framework::proto::VarType::FP16})
framework::proto::VarType::FP16,
framework::proto::VarType::COMPLEX64,
framework::proto::VarType::COMPLEX128})
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int64_t>>(
"axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
@@ -131,17 +133,7 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int64_t>>("none_axes", "(list<int>) The axes to none.")
.SetDefault({});
AddAttr<std::vector<int>>("bool_values", "Store the bool values.")
AddAttr<std::vector<paddle::experimental::Scalar>>("values", "values")
.SetDefault({});
AddAttr<std::vector<float>>("fp32_values", "Store the float32 values.")
.SetDefault({});
AddAttr<std::vector<int>>("int32_values", "Store the int32 values.")
.SetDefault({});
AddAttr<std::vector<int64_t>>("int64_values", "Store the int64 values.")
.SetDefault({});
AddAttr<std::vector<double>>("fp64_values", "Store the float64 values.")
.SetDefault({});
AddAttr<std::vector<float>>("fp16_values", "Store the float16 values.")
.SetDefault({});
AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
@@ -298,4 +290,15 @@ Upgrade set_value, add 1 attribute [decrease_axes].
Upgrade set_value, add 1 attribute [none_axes].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"none_axes", "The axes with none index.", std::vector<int64_t>{}));
"none_axes", "The axes with none index.", std::vector<int64_t>{}))
.AddCheckpoint(
R"ROC(Upgrade set_value to support generic Scalars as value and remove plain values, so as to support complex types.)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("values",
"values",
std::vector<paddle::experimental::Scalar>())
.DeleteAttr("bool_values", "remove plain attributes")
.DeleteAttr("fp32_values", "remove plain attributes")
.DeleteAttr("int32_values", "remove plain attributes")
.DeleteAttr("int64_values", "remove plain attributes")
.DeleteAttr("fp64_values", "remove plain attributes"));
@@ -33,34 +33,6 @@ namespace operators {
using DDim = framework::DDim;
inline std::string GetValueName(framework::proto::VarType::Type data_type) {
std::string value_name;
switch (data_type) {
case framework::proto::VarType::INT32:
value_name = "int32_values";
break;
case framework::proto::VarType::INT64:
value_name = "int64_values";
break;
case framework::proto::VarType::FP32:
value_name = "fp32_values";
break;
case framework::proto::VarType::FP64:
value_name = "fp64_values";
break;
case framework::proto::VarType::BOOL:
value_name = "bool_values";
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type(code %d) for SetValue operator, only "
"supports bool, int32, float32 and int64.",
data_type));
}
return value_name;
}
// check whether the tensor with dimension of second can assign to the
// tensor with dimension of first
inline void CheckIsDimsMatch(const framework::DDim first,
......
@@ -1197,11 +1197,21 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
}
} else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<std::complex<float>>(
value_obj_tmp);
}
} else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<std::complex<double>>(
value_obj_tmp);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a numpy.np value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32 or int64, "
"float32, float64, complex64, complex128, int32 or int64, "
"please check the type of tensor."));
}
@@ -1217,29 +1227,38 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
// convert the value to self data type
if (py::isinstance<py::float_>(value_obj_tmp) ||
py::isinstance<py::int_>(value_obj_tmp) ||
py::isinstance<py::bool_>(value_obj_tmp)) {
py::isinstance<py::bool_>(value_obj_tmp) ||
PyComplex_Check(value_obj)) {
if (self->tensor.dtype() == phi::DataType::FLOAT32) {
attrs["fp32_values"] =
std::vector<float>{value_obj_tmp.cast<float>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<float>()};
} else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
attrs["fp64_values"] =
std::vector<double>{value_obj_tmp.cast<double>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<double>()};
} else if (self->tensor.dtype() == phi::DataType::INT32) {
attrs["int32_values"] =
std::vector<int32_t>{value_obj_tmp.cast<int32_t>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<int32_t>()};
} else if (self->tensor.dtype() == phi::DataType::INT64) {
attrs["int64_values"] =
std::vector<int64_t>{value_obj_tmp.cast<int64_t>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<int64_t>()};
} else if (self->tensor.dtype() == phi::DataType::BOOL) {
attrs["bool_values"] = std::vector<int>{value_obj_tmp.cast<bool>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<bool>()};
} else if (self->tensor.dtype() == phi::DataType::FLOAT16) {
attrs["fp16_values"] =
std::vector<float>{value_obj_tmp.cast<float>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<float>()};
} else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<std::complex<float>>()};
} else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<std::complex<double>>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32, int64 or float16, "
"float32, float64, complex64, complex128, int32, int64 or "
"float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
@@ -1247,7 +1266,7 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Value type error. The assign value allows "
"numpy.ndarray, integer, float or bool, "
"numpy.ndarray, integer, float, complex or bool, "
"but received %s.",
Py_TYPE(value_obj)));
}
......
@@ -902,11 +902,28 @@ void BindImperative(py::module *m_ptr) {
if (!py::isinstance<py::array_t<bool>>(value_obj)) {
value = pybind11::detail::CastNumpyArray<bool>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX64) {
if (!py::isinstance<py::array_t<std::complex<float>>>(
value_obj)) {
value =
pybind11::detail::CastNumpyArray<std::complex<float>>(
value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX128) {
if (!py::isinstance<py::array_t<std::complex<double>>>(
value_obj)) {
value =
pybind11::detail::CastNumpyArray<std::complex<double>>(
value_obj);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a numpy.np value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32 or int64, "
"float32, float64, complex64, complex128, int32 or "
"int64, "
"please check the type of tensor."));
}
@@ -921,35 +938,45 @@ void BindImperative(py::module *m_ptr) {
// convert the value to self data type
if (py::isinstance<py::float_>(value_obj) ||
py::isinstance<py::int_>(value_obj) ||
py::isinstance<py::bool_>(value_obj)) {
py::isinstance<py::bool_>(value_obj) ||
PyComplex_Check(value_obj.ptr())) {
if (self->DataType() == framework::proto::VarType::FP32) {
attrs["fp32_values"] =
std::vector<float>{value_obj.cast<float>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<float>()};
} else if (self->DataType() ==
framework::proto::VarType::FP64) {
attrs["fp64_values"] =
std::vector<double>{value_obj.cast<double>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<double>()};
} else if (self->DataType() ==
framework::proto::VarType::INT32) {
attrs["int32_values"] =
std::vector<int32_t>{value_obj.cast<int32_t>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<int32_t>()};
} else if (self->DataType() ==
framework::proto::VarType::INT64) {
attrs["int64_values"] =
std::vector<int64_t>{value_obj.cast<int64_t>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<int64_t>()};
} else if (self->DataType() ==
framework::proto::VarType::BOOL) {
attrs["bool_values"] =
std::vector<int>{value_obj.cast<bool>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<bool>()};
} else if (self->DataType() ==
framework::proto::VarType::FP16) {
attrs["fp16_values"] =
std::vector<float>{value_obj.cast<float>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<float>()};
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX64) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<std::complex<float>>()};
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX128) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<std::complex<double>>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32, int64 or float16, "
"float32, float64, complex64, complex128, int32, int64 "
"or float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
......
@@ -23,6 +23,9 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_converter.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/version.h"
@@ -56,6 +59,33 @@ static pybind11::bytes SerializeMessage(
return retv;
}
pybind11::bytes SerializeProgramDesc(
pd::ProgramDesc &self, // NOLINT due to pybind11 convention.
bool legacy_format = false) {
// Check IsInitialized in Python
std::string retv;
pd::ProgramDesc copy = self;
framework::compatible::pb::OpVersionMap op_version_map(copy.OpVersionMap());
if (legacy_format) {
pd::no_scalar::ConvertProgram(&copy);
copy.SetVersion(pd::kLegacyProgramVersion);
paddle::framework::compatible::SaveOpVersions(
paddle::framework::compatible::pb::GetLegacyOpVersions(),
&op_version_map);
} else {
copy.SetVersion(pd::kCurProgramVersion);
paddle::framework::compatible::SaveOpVersions(
framework::compatible::get_op_version_map(), &op_version_map);
}
PADDLE_ENFORCE_EQ(copy.Proto()->SerializePartialToString(&retv),
true,
platform::errors::InvalidArgument(
"Failed to serialize input Desc to string."));
return retv;
}
template <typename T>
static void DeserializeMessage(T *self, const std::string &str) {
PADDLE_ENFORCE_EQ(
@@ -88,7 +118,9 @@ void BindProgramDesc(pybind11::module *m) {
.def("flush", &pd::ProgramDesc::Flush)
.def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
.def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
.def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
.def("serialize_to_string",
&SerializeProgramDesc,
pybind11::arg("legacy_format") = false)
.def("need_update", &pd::ProgramDesc::NeedUpdate)
.def("parse_from_string",
[](pd::ProgramDesc &program_desc, const std::string &data) {
......
@@ -24,6 +24,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
@@ -63,12 +64,19 @@ constexpr int NPY_UINT16_ = 4;
constexpr int NPY_COMPLEX64 = 14;
constexpr int NPY_COMPLEX128 = 15;
template <typename T, typename S>
struct casting_complex_to_non_complex {
static const bool value = pybind11::detail::is_complex<S>::value &&
!pybind11::detail::is_complex<T>::value;
};
// cast numpy type from S to T, this may allocate new memory
template <class T, class S>
template <
class T,
class S,
std::enable_if_t<!std::is_same<T, S>::value &&
!casting_complex_to_non_complex<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
if (std::is_same<T, S>::value) {
return array;
}
auto dim = array.ndim();
std::vector<py::ssize_t> result_shape(dim);
for (auto i = 0; i < dim; i++) {
@@ -80,6 +88,30 @@ static py::array_t<T> CastNumpyType(py::array_t<S> array) {
return py::vectorize([](S s) { return static_cast<T>(s); })(array);
}
template <
class T,
class S,
std::enable_if_t<(!std::is_same<T, S>::value) &&
casting_complex_to_non_complex<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
auto dim = array.ndim();
std::vector<py::ssize_t> result_shape(dim);
for (auto i = 0; i < dim; i++) {
result_shape[i] = array.shape(i);
}
py::array_t<T> result(result_shape);
return py::vectorize([](S s) { return static_cast<T>(s.real()); })(array);
}
template <class T,
class S,
std::enable_if_t<std::is_same<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
return array;
}
template <class T>
static py::array_t<T> CastNumpyArray(const py::object &array) {
if (py::isinstance<py::array_t<float>>(array)) {
@@ -92,10 +124,14 @@ static py::array_t<T> CastNumpyArray(const py::object &array) {
return CastNumpyType<T>(array.cast<py::array_t<int64_t>>());
} else if (py::isinstance<py::array_t<bool>>(array)) {
return CastNumpyType<T>(array.cast<py::array_t<bool>>());
} else if (py::isinstance<py::array_t<std::complex<float>>>(array)) {
return CastNumpyType<T>(array.cast<py::array_t<std::complex<float>>>());
} else if (py::isinstance<py::array_t<std::complex<double>>>(array)) {
return CastNumpyType<T>(array.cast<py::array_t<std::complex<double>>>());
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Value type error. The assign numpy value allows integer, float, "
"double and bool, "
"double, complex64, complex128, and bool, "
"but received %s.",
Py_TYPE(array.ptr())->tp_name));
}
......
@@ -216,12 +216,14 @@ class DistributedSaver:
dist_filename = filename + "_dist" + str(rank_id)
dist_path = os.path.join(dirname, dist_filename)
legacy_format = kwargs.get("legacy_format", False)
paddle.static.save_inference_model(
dist_path,
dist_feed_vars,
dist_fetch_vars,
exe,
program=dist_main_prog,
legacy_format=legacy_format,
)
def _save_rank_mapping(self, dirname):
......
@@ -718,6 +718,7 @@ class ParameterServerRuntime(RuntimeBase):
target_vars,
main_program=None,
export_for_deployment=True,
legacy_format=False,
):
"""
Prune the given `main_program` to build a new program especially for inference,
@@ -743,6 +744,7 @@ class ParameterServerRuntime(RuntimeBase):
None,
None,
export_for_deployment,
legacy_format=legacy_format,
)
else:
paddle.fluid.io.save_inference_model(
@@ -755,6 +757,7 @@ class ParameterServerRuntime(RuntimeBase):
None,
None,
export_for_deployment,
True,
legacy_format=legacy_format,
)
model_basename = "__model__"
@@ -5734,6 +5734,22 @@ class Program:
res_str = ""
for block in self.blocks:
res_str += block.to_string(throw_on_error, with_details)
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(bytes(protostr))
res_str += (
"version {\n "
+ textwrap.indent(
_debug_string_(proto.version, throw_on_error), " "
)
+ "}\n"
)
res_str += (
"op_version_map {\n "
+ textwrap.indent(
_debug_string_(proto.op_version_map, throw_on_error), " "
)
+ "}\n"
)
else:
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(bytes(protostr))
......
@@ -135,6 +135,7 @@ def save_inference_model(
export_for_deployment=True,
program_only=False,
clip_extra=True,
legacy_format=False,
):
"""
Prune the given `main_program` to build a new program especially for inference,
@@ -176,6 +177,8 @@ def save_inference_model(
program_only(bool, optional): If True, It will save inference program only, and do not
save params of Program.
Default: False.
legacy_format(bool, optional): Whether to save program in legacy format.
Default: False.
Returns:
list, The fetch variables' name list.
@@ -314,8 +317,6 @@ def save_inference_model(
prepend_feed_ops(main_program, feeded_var_names)
append_fetch_ops(main_program, fetch_var_names)
main_program.desc._set_version()
paddle.fluid.core.save_op_version_info(main_program.desc)
with open(model_basename, "wb") as f:
f.write(
main_program._remove_training_info(
......
@@ -390,9 +390,6 @@ def create_fake_model(program_config):
op_desc.set_output('Out', ["fetch"])
op_desc._set_attr("col", index)
main_program_desc._set_version()
paddle.fluid.core.save_op_version_info(main_program_desc)
model = main_program_desc.serialize_to_string()
util_program._sync_with_cpp()
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid.proto import framework_pb2
class TestSetValue(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def _test_for_new_program_format(self, program_bytes):
restored_prog_as_is = framework_pb2.ProgramDesc.FromString(
program_bytes
)
for block in restored_prog_as_is.blocks:
for op in block.ops:
if op.type in ("set_value", "set_value_grad"):
attr_names = [attr.name for attr in op.attrs]
self.assertTrue("values" in attr_names)
self.assertFalse("bool_values" in attr_names)
self.assertFalse("int32_values" in attr_names)
self.assertFalse("int64_values" in attr_names)
self.assertFalse("fp32_values" in attr_names)
self.assertFalse("fp64_values" in attr_names)
self.assertFalse("fp16_values" in attr_names)
def _test_for_legacy_program_format(self, program_bytes):
restored_prog_as_is = framework_pb2.ProgramDesc.FromString(
program_bytes
)
for block in restored_prog_as_is.blocks:
for op in block.ops:
if op.type in ("set_value", "set_value_grad"):
attr_names = [attr.name for attr in op.attrs]
self.assertFalse("values" in attr_names)
self.assertTrue("bool_values" in attr_names)
self.assertTrue("int32_values" in attr_names)
self.assertTrue("int64_values" in attr_names)
self.assertTrue("fp32_values" in attr_names)
self.assertTrue("fp64_values" in attr_names)
self.assertTrue("fp16_values" in attr_names)
def _test_equivalence(
self,
new_program_bytes,
legacy_program_bytes,
fetch_list,
expected_outputs,
):
normal_program = paddle.static.io.deserialize_program(new_program_bytes)
converted_back_program = paddle.static.io.deserialize_program(
legacy_program_bytes
)
exe = paddle.static.Executor(paddle.CPUPlace())
[out] = exe.run(normal_program, fetch_list=fetch_list)
np.testing.assert_allclose(out, expected_outputs[0])
[out] = exe.run(converted_back_program, fetch_list=fetch_list)
np.testing.assert_allclose(out, expected_outputs[0])
def test_int32(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.int32)
patch = np.array([41, 42]).astype(np.int32)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.int32)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_int64(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.int64)
patch = np.array(
[np.iinfo(np.int64).max, np.iinfo(np.int64).min]
).astype(np.int64)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.int64)
x_output = x_input.copy()
x_output[:1, :2] = patch
self.fetch_list = [x.name]
self.expected_outputs = [x_output]
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_float32(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.float32)
patch = np.array(
[np.finfo(np.float32).max, np.finfo(np.float32).min]
).astype(np.float32)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.float32)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_float64(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.float64)
patch = np.array(
[np.finfo(np.float64).max, np.finfo(np.float64).min]
).astype(np.float64)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.float64)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_float16(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.float16)
patch = np.array(
[np.finfo(np.float16).max, np.finfo(np.float16).min]
).astype(np.float16)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.float16)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_bool(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.bool)
patch = np.array([True, False])
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=bool)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_complex64(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.complex(
paddle.ones([3, 4], dtype=paddle.float32),
paddle.ones([3, 4], dtype=paddle.float32),
)
patch = np.array([42.1 + 42.1j, 42.2 + 42.2j]).astype(np.complex64)
x[:1, :2] = patch
x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex64)
x_output = x_input.copy()
x_output[:1, :2] = patch
with self.assertRaisesRegex(RuntimeError, "Invalid data type"):
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
def test_complex128(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.complex(
paddle.ones([3, 4], dtype=paddle.float64),
paddle.ones([3, 4], dtype=paddle.float64),
)
patch = np.array(
[
np.finfo(np.float64).max + 1j * np.finfo(np.float64).min,
np.finfo(np.float64).min + 1j * np.finfo(np.float64).max,
]
).astype(np.complex128)
x[:1, :2] = patch
x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex128)
x_output = x_input.copy()
x_output[:1, :2] = patch
with self.assertRaisesRegex(RuntimeError, "Invalid data type"):
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
@@ -711,6 +711,91 @@ create_test_value_numpy_bool(TestSetValueItemSlice3)
create_test_value_numpy_bool(TestSetValueItemSlice4)
def create_test_value_complex64(parent):
class TestValueInt(parent):
def set_value(self):
self.value = 42.1 + 42.1j
def set_dtype(self):
self.dtype = "complex64"
cls_name = "{}_{}".format(parent.__name__, "ValueComplex64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_complex64(TestSetValueItemInt)
create_test_value_complex64(TestSetValueItemSlice)
create_test_value_complex64(TestSetValueItemSlice2)
create_test_value_complex64(TestSetValueItemSlice3)
create_test_value_complex64(TestSetValueItemSlice4)
def create_test_value_complex128(parent):
class TestValueInt(parent):
def set_value(self):
self.value = complex(
np.finfo(np.float64).max + 1j * np.finfo(np.float64).min
)
def set_dtype(self):
self.dtype = "complex128"
cls_name = "{}_{}".format(parent.__name__, "ValueComplex128")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_complex128(TestSetValueItemInt)
create_test_value_complex128(TestSetValueItemSlice)
create_test_value_complex128(TestSetValueItemSlice2)
create_test_value_complex128(TestSetValueItemSlice3)
create_test_value_complex128(TestSetValueItemSlice4)
def create_test_value_numpy_complex64(parent):
class TestValueInt(parent):
def set_value(self):
self.value = np.array(42.1 + 42.1j)
def set_dtype(self):
self.dtype = "complex64"
cls_name = "{}_{}".format(parent.__name__, "ValueNumpyComplex64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_complex64(TestSetValueItemInt)
create_test_value_numpy_complex64(TestSetValueItemSlice)
create_test_value_numpy_complex64(TestSetValueItemSlice2)
create_test_value_numpy_complex64(TestSetValueItemSlice3)
create_test_value_numpy_complex64(TestSetValueItemSlice4)
def create_test_value_numpy_complex128(parent):
class TestValueInt(parent):
def set_value(self):
v = complex(
np.finfo(np.float64).max + 1j * np.finfo(np.float64).min
)
self.value = np.array([v])
def set_dtype(self):
self.dtype = "complex128"
cls_name = "{}_{}".format(parent.__name__, "ValueNumpyComplex128")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_complex128(TestSetValueItemInt)
create_test_value_numpy_complex128(TestSetValueItemSlice)
create_test_value_numpy_complex128(TestSetValueItemSlice2)
create_test_value_numpy_complex128(TestSetValueItemSlice3)
create_test_value_numpy_complex128(TestSetValueItemSlice4)
# 2.3 value is a Paddle Tensor (int32, int64, float32, float64, bool)
def create_test_value_tensor_int32(parent):
class TestValueInt(parent):
......
@@ -782,38 +782,15 @@ def _setitem_impl_(var, item, value):
from .data_feeder import convert_dtype
# 2.1 value is an integer of float
# 2.1 value is an integer, float or complex
if isinstance(value, (int, float)):
if isinstance(value, (bool, int, float, complex)):
value = np.array([value]).astype(convert_dtype(dtype))
# 2.2 value is a np.ndarray
if isinstance(value, np.ndarray):
shape = list(value.shape)
if dtype == core.VarDesc.VarType.BOOL:
values = value.ravel().tolist()
value_name = "bool_values"
attrs["values"] = values
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP64:
value_name = "fp64_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT64:
value_name = "int64_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP16:
value_name = "fp16_values"
values = [float(v) for v in value.flat]
else:
raise TypeError(
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, float32, int32, int64 or float16, but "
"received %s." % convert_dtype(dtype)
)
attrs[value_name] = values
attrs["shape"] = shape attrs["shape"] = shape
elif isinstance(value, (Variable, core.eager.Tensor)): elif isinstance(value, (Variable, core.eager.Tensor)):
......
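With this change a NumPy value is packed into one dtype-agnostic `values` attribute (held as Scalars on the C++ side) plus a `shape` attribute, replacing the per-dtype attributes such as `fp32_values` or `int64_values`. A small stand-alone sketch of the packing; the dict below only mirrors the `attrs` built in `_setitem_impl_`:

import numpy as np

value = np.array([[1.5 + 2.0j, 3.5 - 1.0j]])    # any supported dtype, including complex
attrs = {
    "values": value.ravel().tolist(),           # one flat list instead of fp32_values / int32_values / ...
    "shape": list(value.shape),                 # the original shape is kept as a separate attribute
}
print(attrs)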
...@@ -262,6 +262,7 @@ class Fleet(metaclass=abc.ABCMeta): ...@@ -262,6 +262,7 @@ class Fleet(metaclass=abc.ABCMeta):
target_vars, target_vars,
main_program=None, main_program=None,
export_for_deployment=True, export_for_deployment=True,
legacy_format=False,
): ):
pass pass
......
...@@ -82,6 +82,7 @@ class Collective(Fleet): ...@@ -82,6 +82,7 @@ class Collective(Fleet):
target_vars=None, target_vars=None,
main_program=None, main_program=None,
export_for_deployment=True, export_for_deployment=True,
legacy_format=False,
): ):
""" """
Prune the given `main_program` to build a new program especially for Prune the given `main_program` to build a new program especially for
...@@ -109,6 +110,7 @@ class Collective(Fleet): ...@@ -109,6 +110,7 @@ class Collective(Fleet):
None, None,
None, None,
export_for_deployment, export_for_deployment,
legacy_format=legacy_format,
) )
def save_persistables( def save_persistables(
......
...@@ -419,6 +419,7 @@ class FleetTranspiler(Fleet): ...@@ -419,6 +419,7 @@ class FleetTranspiler(Fleet):
target_vars, target_vars,
main_program=None, main_program=None,
export_for_deployment=True, export_for_deployment=True,
legacy_format=False,
): ):
""" """
Prune the given `main_program` to build a new program especially for inference, Prune the given `main_program` to build a new program especially for inference,
...@@ -453,6 +454,7 @@ class FleetTranspiler(Fleet): ...@@ -453,6 +454,7 @@ class FleetTranspiler(Fleet):
None, None,
None, None,
export_for_deployment, export_for_deployment,
legacy_format=legacy_format,
) )
else: else:
paddle.static.save_inference_model( paddle.static.save_inference_model(
...@@ -465,6 +467,7 @@ class FleetTranspiler(Fleet): ...@@ -465,6 +467,7 @@ class FleetTranspiler(Fleet):
None, None,
export_for_deployment, export_for_deployment,
True, True,
legacy_format=legacy_format,
) )
model_basename = "__model__" model_basename = "__model__"
......
...@@ -1752,7 +1752,9 @@ class TracedLayer: ...@@ -1752,7 +1752,9 @@ class TracedLayer:
saved inference model. If None, all output variables of the saved inference model. If None, all output variables of the
TracedLayer object would be the outputs of the saved inference TracedLayer object would be the outputs of the saved inference
model. Default None. model. Default None.
kwargs: Supported keys including 'clip_extra'.set to True if you want to clip extra information for every operator. kwargs: Supported keys including
- clip_extra(bool): whether to clip extra information for every operator. Defaults to True.
- legacy_format(bool): whether to save the program in legacy format. Defaults to False.
Returns: Returns:
None None
...@@ -1854,6 +1856,7 @@ class TracedLayer: ...@@ -1854,6 +1856,7 @@ class TracedLayer:
model_filename = file_prefix + INFER_MODEL_SUFFIX model_filename = file_prefix + INFER_MODEL_SUFFIX
params_filename = file_prefix + INFER_PARAMS_SUFFIX params_filename = file_prefix + INFER_PARAMS_SUFFIX
legacy_format = kwargs.get('legacy_format', False)
save_inference_model( save_inference_model(
dirname=dirname, dirname=dirname,
feeded_var_names=feeded_var_names, feeded_var_names=feeded_var_names,
...@@ -1863,4 +1866,5 @@ class TracedLayer: ...@@ -1863,4 +1866,5 @@ class TracedLayer:
model_filename=model_filename, model_filename=model_filename,
params_filename=params_filename, params_filename=params_filename,
clip_extra=clip_extra, clip_extra=clip_extra,
legacy_format=legacy_format,
) )
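`TracedLayer.save_inference_model` now forwards `legacy_format` from its kwargs to the underlying `save_inference_model` call. A hedged usage sketch, assuming the paddle 2.x `TracedLayer` API; the network and save path below are illustrative, not taken from this diff:

import paddle
from paddle.jit import TracedLayer

class SimpleNet(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc = paddle.nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)

net = SimpleNet()
in_var = paddle.rand([3, 4])
# trace returns the dygraph output and the traced (static) layer
out, static_layer = TracedLayer.trace(net, inputs=[in_var])
# clip_extra and legacy_format both travel through **kwargs, as documented above
static_layer.save_inference_model("./traced_model", clip_extra=True, legacy_format=True)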
...@@ -75,7 +75,7 @@ def _check_args(caller, args, supported_args=None, deprecated_args=None): ...@@ -75,7 +75,7 @@ def _check_args(caller, args, supported_args=None, deprecated_args=None):
def _check_vars(name, var_list): def _check_vars(name, var_list):
if not isinstance(var_list, list): if not isinstance(var_list, list):
var_list = [var_list] var_list = [var_list]
if not var_list or not all([isinstance(var, Variable) for var in var_list]): if not all([isinstance(var, Variable) for var in var_list]):
raise ValueError( raise ValueError(
f"'{name}' should be a Variable or a list of Variable." f"'{name}' should be a Variable or a list of Variable."
) )
...@@ -252,6 +252,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs): ...@@ -252,6 +252,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
kwargs: Supported keys including ``program``. Attention please, kwargs is used for backward compatibility mainly. kwargs: Supported keys including ``program``. Attention please, kwargs is used for backward compatibility mainly.
- program(Program): specify a program if you don't want to use default main program. - program(Program): specify a program if you don't want to use default main program.
- legacy_format(bool): whether to save inference program in legacy format. Defaults to False.
Returns: Returns:
bytes: serialized program. bytes: serialized program.
...@@ -289,14 +290,15 @@ def serialize_program(feed_vars, fetch_vars, **kwargs): ...@@ -289,14 +290,15 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
program = _get_valid_program(kwargs.get('program', None)) program = _get_valid_program(kwargs.get('program', None))
program = normalize_program(program, feed_vars, fetch_vars) program = normalize_program(program, feed_vars, fetch_vars)
return _serialize_program(program) legacy_format = kwargs.get('legacy_format', False)
return _serialize_program(program, legacy_format=legacy_format)
def _serialize_program(program): def _serialize_program(program, legacy_format=False):
""" """
serialize given program to bytes. serialize given program to bytes.
""" """
return program.desc.serialize_to_string() return program.desc.serialize_to_string(legacy_format=legacy_format)
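`serialize_program` now threads `legacy_format` down to `ProgramDesc.serialize_to_string`. A short sketch of the new kwarg (the toy network is an illustrative assumption):

import paddle

paddle.enable_static()
x = paddle.static.data(name="x", shape=[None, 4], dtype="float32")
y = paddle.static.nn.fc(x, size=2)
# legacy_format=True requests bytes that Paddle 2.4.2 and earlier can still parse;
# False (the default) keeps the new Scalar-based format.
program_bytes = paddle.static.serialize_program([x], [y], legacy_format=True)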
@static_only @static_only
...@@ -459,6 +461,8 @@ def save_inference_model( ...@@ -459,6 +461,8 @@ def save_inference_model(
- clip_extra(bool): the flag indicating whether to clip extra information for every operator. Default: True. - clip_extra(bool): the flag indicating whether to clip extra information for every operator. Default: True.
- legacy_format(bool): whether to save inference model in legacy format. Default: False.
Returns: Returns:
None None
...@@ -518,8 +522,10 @@ def save_inference_model( ...@@ -518,8 +522,10 @@ def save_inference_model(
clip_extra = kwargs.get('clip_extra', True) clip_extra = kwargs.get('clip_extra', True)
program = normalize_program(program, feed_vars, fetch_vars) program = normalize_program(program, feed_vars, fetch_vars)
# serialize and save program # serialize and save program
legacy_format = kwargs.get('legacy_format', False)
program_bytes = _serialize_program( program_bytes = _serialize_program(
program._remove_training_info(clip_extra=clip_extra) program._remove_training_info(clip_extra=clip_extra),
legacy_format=legacy_format,
) )
save_to_file(model_path, program_bytes) save_to_file(model_path, program_bytes)
# serialize and save params # serialize and save params
...@@ -1371,8 +1377,6 @@ def save(program, model_path, protocol=4, **configs): ...@@ -1371,8 +1377,6 @@ def save(program, model_path, protocol=4, **configs):
main_program = program.clone() main_program = program.clone()
program.desc.flush() program.desc.flush()
main_program.desc._set_version()
paddle.fluid.core.save_op_version_info(program.desc)
with open(model_path + ".pdmodel", "wb") as f: with open(model_path + ".pdmodel", "wb") as f:
f.write(program.desc.serialize_to_string()) f.write(program.desc.serialize_to_string())
......
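Because version information is now written by `serialize_to_string` itself, the explicit `_set_version()` / `save_op_version_info` calls above could be dropped. A hedged round-trip sketch, assuming the public `paddle.static.deserialize_program` helper detects the serialized format on load:

import paddle

paddle.enable_static()
prog = paddle.static.default_main_program()
# the program version and op version map are embedded during serialization
data = prog.desc.serialize_to_string(legacy_format=False)
restored = paddle.static.deserialize_program(data)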