未验证 提交 dd2a749a 编写于 作者: F Feiyu Chan 提交者: GitHub

1. modify set_value op, use Scalars to represent attr `values`, instead of a...

1. modify set_value op, use Scalars to represent attr `values`, instead of a bunch of attributs of various types; (#52408)

2. add program converter and set_value op as an example, which provides the functionality to convert `paddle::framework::ProgramDesc` between old and new formats(the differences are mainly some operators with incompatible updates in the definition);
3. program version and operator version map now are always saved when serializing `paddle::framework::ProgramDesc` to identify the version;
3. provide an option `legacy_format=false` in  serialization of `paddle::framework::ProgramDesc`, it decided whether to convert ProgramDesc back to a legacy format, which is compatible for paddle 2.4.2 or earlier versions to load and execute;
4. deserialization of `paddle::framework::ProgramDesc` is now automatically detecting whether the bytes it receives is in legacy format(contains any of the operators that has been incompatibly updated and have any attribute of type `Scalar`) and convert it to new format. But if you want a faithful deserialization without the automatic conversion, you can use protobuf's deserialization instead. Though it is not recommended, it can be used for the purpose of testing.
上级 b281b221
......@@ -532,7 +532,7 @@ cc_test(
cc_library(
proto_desc
SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc
SRCS var_desc.cc op_desc.cc block_desc.cc program_desc.cc program_converter.cc
DEPS attribute
ops_extra_info
shape_inference
......@@ -542,7 +542,9 @@ cc_library(
version
xxhash
dist_attr
scalar)
scalar
op_version_proto
op_version_registry)
cc_library(
op_registry
......
......@@ -13,3 +13,96 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_version_proto.h"
namespace paddle {
namespace framework {
namespace compatible {
namespace pb {
const std::unordered_map<std::string, uint32_t>& GetLegacyOpVersions() {
static std::unordered_map<std::string, uint32_t> op_versions = {
{"not_equal", 1},
{"fake_channel_wise_dequantize_max_abs", 2},
{"yolo_box", 1},
{"data_norm", 1},
{"cumsum", 1},
{"fake_channel_wise_quantize_abs_max", 1},
{"greater_equal", 1},
{"fill_constant", 2},
{"conv_transpose", 1},
{"fusion_gru", 1},
{"flip", 1},
{"elementwise_sub", 1},
{"dequantize", 1},
{"grid_sampler", 1},
{"expand_as_v2", 1},
{"linspace", 1},
{"moving_average_abs_max_scale", 2},
{"p_norm", 1},
{"instance_norm", 1},
{"lookup_table_v2", 1},
{"seed", 1},
{"softmax_with_cross_entropy", 1},
{"rank_attention", 1},
{"cudnn_lstm", 1},
{"clip", 1},
{"requantize", 1},
{"for_pybind_test__", 4},
{"print", 1},
{"transfer_layout", 1},
{"arg_min", 1},
{"roll", 2},
{"roi_pool", 2},
{"conv2d_transpose", 2},
{"roi_align", 3},
{"softplus", 1},
{"momentum", 1},
{"trace", 1},
{"matmul", 1},
{"lookup_table", 1},
{"lstsq", 1},
{"conv3d_transpose", 1},
{"depthwise_conv2d_transpose", 1},
{"conv2d", 1},
{"lamb", 1},
{"send_and_recv", 1},
{"gaussian_random", 1},
{"unique_consecutive", 1},
{"conv3d", 1},
{"pixel_shuffle", 1},
{"collect_fpn_proposals", 1},
{"coalesce_tensor", 2},
{"arg_max", 1},
{"allclose", 2},
{"matrix_nms", 1},
{"less_than", 1},
{"affine_grid", 1},
{"hard_shrink", 1},
{"set_value", 3},
{"mish", 1},
{"quantize", 2},
{"distribute_fpn_proposals", 2},
{"adam", 4},
{"elementwise_pow", 1},
{"elementwise_mul", 1},
{"elementwise_mod", 1},
{"auc", 1},
{"elementwise_min", 1},
{"elementwise_max", 1},
{"gather", 1},
{"elementwise_div", 1},
{"elementwise_add", 1},
{"leaky_relu", 1},
{"generate_proposal_labels", 2},
{"elementwise_floordiv", 1},
{"less_equal", 1},
{"generate_proposals", 2},
{"depthwise_conv2d", 1},
{"greater_than", 1},
{"generate_proposals_v2", 1},
{"equal", 1}};
return op_versions;
}
} // namespace pb
} // namespace compatible
} // namespace framework
} // namespace paddle
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <stdint.h>
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/framework.pb.h"
......@@ -53,6 +54,10 @@ class OpVersionMap {
proto::OpVersionMap* desc_;
};
// get version id for operators with version id in paddle 2.4.2, this is used
// for converting ProgramDesc in 2.4 comtabible format
const std::unordered_map<std::string, uint32_t>& GetLegacyOpVersions();
} // namespace pb
} // namespace compatible
} // namespace framework
......
......@@ -264,6 +264,13 @@ inline void SaveOpVersions(
}
}
inline void SaveOpVersions(const std::unordered_map<std::string, uint32_t>& src,
pb::OpVersionMap* dst) {
for (const auto& pair : src) {
(*dst)[pair.first].SetVersionID(pair.second);
}
}
class OpVersionComparator {
public:
virtual bool operator()() = 0;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/program_converter.h"
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/version.h"
namespace paddle {
namespace framework {
using paddle::experimental::ExtractPlainVector;
using paddle::experimental::WrapAsScalars;
std::pair<bool, std::unordered_map<std::string, uint32_t>> DetectLegacyOps(
ProgramDesc* program) {
bool is_legacy_program = false;
std::unordered_map<std::string, uint32_t> legacy_op_versions;
std::unordered_map<std::string, uint32_t> current_op_versions;
std::unordered_map<std::string, uint32_t> program_op_versions;
// get *all kinds* of formats of op versions and op version map to a unified
// representation before comparison can be done in a neat way
if (!program->HasOpVersionMap()) {
is_legacy_program = true;
} else {
for (const auto& pair :
paddle::framework::compatible::get_op_version_map()) {
current_op_versions.insert(
std::make_pair(pair.first, pair.second.version_id()));
}
const auto* _op_version_map = program->OpVersionMap();
for (int i = 0; i < _op_version_map->pair_size(); ++i) {
auto pair =
std::make_pair(_op_version_map->pair(i).op_name(),
static_cast<uint32_t>(
_op_version_map->pair(i).op_version().version()));
program_op_versions.insert(pair);
}
for (const auto& pair : program_op_versions) {
uint32_t program_op_version = pair.second;
if (!current_op_versions.count(pair.first)) {
// this means program_op_versions is more upated than
// current_op_versions it is loading a program from future versions of
// paddle
continue;
}
uint32_t current_op_version = current_op_versions.at(pair.first);
if (program_op_version < current_op_version) {
is_legacy_program = true;
legacy_op_versions.insert(
std::make_pair(pair.first, program_op_version));
}
}
}
return std::make_pair(is_legacy_program, legacy_op_versions);
}
namespace no_scalar {
void ConvertSetValueOp(OpDesc* op) {
std::vector<paddle::experimental::Scalar> values = PADDLE_GET_CONST(
std::vector<paddle::experimental::Scalar>, op->GetAttr("values", false));
op->RemoveAttr("values");
op->SetAttr("bool_values", std::vector<int>());
op->SetAttr("fp32_values", std::vector<float>());
op->SetAttr("int32_values", std::vector<int>());
op->SetAttr("int64_values", std::vector<int64_t>());
op->SetAttr("fp64_values", std::vector<double>());
op->SetAttr("fp16_values", std::vector<float>());
phi::DataType dtype = phi::DataType::FLOAT32;
if (values.size()) {
dtype = values.at(0).dtype();
}
switch (dtype) {
case phi::DataType::BOOL:
op->SetAttr("bool_values", ExtractPlainVector<int>(values));
break;
case phi::DataType::FLOAT32:
op->SetAttr("fp32_values", ExtractPlainVector<float>(values));
break;
case phi::DataType::INT32:
op->SetAttr("int32_values", ExtractPlainVector<int>(values));
break;
case phi::DataType::INT64:
op->SetAttr("int64_values", ExtractPlainVector<int64_t>(values));
break;
case phi::DataType::FLOAT64:
op->SetAttr("fp64_values", ExtractPlainVector<double>(values));
break;
case phi::DataType::FLOAT16:
op->SetAttr("fp16_values", ExtractPlainVector<float>(values));
break;
default:
PD_THROW("Invalid data type `", dtype, "`.");
}
}
void ConvertProgram(ProgramDesc* program) {
PADDLE_ENFORCE_NOT_NULL(
program,
paddle::platform::errors::InvalidArgument("program should not be null"));
VLOG(3) << "Setting Program Version and OpVersionMap to Legacy "
"settings(a.k.a 2.4.2)";
framework::compatible::pb::OpVersionMap op_version_map(
program->OpVersionMap());
program->SetVersion(paddle::framework::kLegacyProgramVersion);
paddle::framework::compatible::SaveOpVersions(
paddle::framework::compatible::pb::GetLegacyOpVersions(),
&op_version_map);
VLOG(3) << "Converting program from new(with scalar attributes) to old(no "
"scalar attributes)";
const size_t num_blocks = program->Size();
for (size_t i = 0; i < num_blocks; i++) {
BlockDesc* block = program->MutableBlock(i);
const size_t num_ops = block->OpSize();
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(j);
const std::string op_type = op->Type();
if (op_type == "set_value" || op_type == "set_value_grad") {
ConvertSetValueOp(op);
}
}
}
}
} // namespace no_scalar
namespace scalar {
void ConvertSetValueOp(OpDesc* op) {
std::vector<paddle::experimental::Scalar> values;
if (op->HasAttr("bool_values")) {
std::vector<int> bool_values =
PADDLE_GET_CONST(std::vector<int>, op->GetAttr("bool_values", false));
if (bool_values.size()) {
values = WrapAsScalars(bool_values);
}
op->RemoveAttr("bool_values");
}
if (op->HasAttr("fp32_values")) {
std::vector<float> fp32_values =
PADDLE_GET_CONST(std::vector<float>, op->GetAttr("fp32_values", false));
if (fp32_values.size()) {
values = WrapAsScalars(fp32_values);
}
op->RemoveAttr("fp32_values");
}
if (op->HasAttr("int32_values")) {
std::vector<int> int32_values =
PADDLE_GET_CONST(std::vector<int>, op->GetAttr("int32_values", false));
if (int32_values.size()) {
values = WrapAsScalars(int32_values);
}
op->RemoveAttr("int32_values");
}
if (op->HasAttr("int64_values")) {
std::vector<int64_t> int64_values = PADDLE_GET_CONST(
std::vector<int64_t>, op->GetAttr("int64_values", false));
if (int64_values.size()) {
values = WrapAsScalars(int64_values);
}
op->RemoveAttr("int64_values");
}
if (op->HasAttr("fp64_values")) {
std::vector<double> fp64_values = PADDLE_GET_CONST(
std::vector<double>, op->GetAttr("fp64_values", false));
if (fp64_values.size()) {
values = WrapAsScalars(fp64_values);
}
op->RemoveAttr("fp64_values");
}
if (op->HasAttr("fp16_values")) {
std::vector<float> fp16_values =
PADDLE_GET_CONST(std::vector<float>, op->GetAttr("fp16_values", false));
if (fp16_values.size()) {
values = WrapAsScalars(fp16_values);
}
op->RemoveAttr("fp16_values");
}
op->SetAttr("values", values);
}
void ConvertProgram(ProgramDesc* program) {
PADDLE_ENFORCE_NOT_NULL(
program,
paddle::platform::errors::InvalidArgument("program should not be null"));
auto legacy_op_results = DetectLegacyOps(program);
bool is_legacy_program = legacy_op_results.first;
const std::unordered_map<std::string, uint32_t>& legacy_op_versions =
legacy_op_results.second;
if (!is_legacy_program) return;
VLOG(3) << "Updating Program Version and OpVersionMap";
program->SetVersion(paddle::framework::kCurProgramVersion);
framework::compatible::pb::OpVersionMap op_version_map(
program->OpVersionMap());
paddle::framework::compatible::SaveOpVersions(
framework::compatible::get_op_version_map(), &op_version_map);
VLOG(3) << "Converting program from old(no scalar attributes) to new(with "
"scalar attributes)";
const size_t num_blocks = program->Size();
for (size_t i = 0; i < num_blocks; i++) {
BlockDesc* block = program->MutableBlock(i);
const size_t num_ops = block->OpSize();
for (size_t j = 0; j < num_ops; j++) {
OpDesc* op = block->Op(j);
const std::string op_type = op->Type();
if (!legacy_op_versions.count(op_type)) {
continue;
}
if (op_type == "set_value" || op_type == "set_value_grad") {
ConvertSetValueOp(op);
}
}
}
}
} // namespace scalar
} // namespace framework
} // namespace paddle
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/fluid/framework/op_desc.h"
namespace paddle {
namespace framework {
namespace no_scalar {
void ConvertProgram(ProgramDesc* program);
} // namespace no_scalar
namespace scalar {
void ConvertProgram(ProgramDesc* program);
} // namespace scalar
} // namespace framework
} // namespace paddle
......@@ -20,6 +20,9 @@ extern "C" {
#include <algorithm>
#include "paddle/fluid/framework/feed_fetch_type.h"
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_converter.h"
#include "paddle/fluid/framework/version.h"
namespace paddle {
......@@ -48,8 +51,12 @@ proto::OpVersionMap *ProgramDesc::OpVersionMap() {
return desc_.mutable_op_version_map();
}
bool ProgramDesc::HasOpVersionMap() const { return desc_.has_op_version_map(); }
int64_t ProgramDesc::Version() const { return desc_.version().version(); }
bool ProgramDesc::HasVersion() const { return desc_.has_version(); }
void ProgramDesc::SetVersion(const int64_t version) {
desc_.mutable_version()->set_version(version);
}
......@@ -142,6 +149,7 @@ ProgramDesc::ProgramDesc(const std::string &binary_str) {
platform::errors::InvalidArgument(
"Failed to parse program_desc from binary string."));
InitFromProto();
scalar::ConvertProgram(this);
}
void ProgramDesc::InitFromProto() {
......
......@@ -61,8 +61,12 @@ class ProgramDesc {
proto::OpVersionMap *OpVersionMap();
bool HasOpVersionMap() const;
int64_t Version() const;
bool HasVersion() const;
void SetVersion(const int64_t version);
// The output variable of feed_op is referenced as feed_target.
......
......@@ -36,6 +36,13 @@ constexpr int64_t kCurProgramVersion = PADDLE_VERSION_INTEGER;
constexpr int64_t kCurProgramVersion = 0;
#endif
// paddle in 2.4.2 and before does not support Scalar in op attributes
// and no op in program of 2.4.2 or earlier versions has attributes of type
// `Scalar` This version number is used for converting program to a legacy
// format to indentifiy their version id See Also
// paddle/fluid/framework/op_version_proto.cc
constexpr int64_t kLegacyProgramVersion = 2004002L;
// The program version that was generated by previous or current codes
// and supported by current codes.
constexpr int64_t kSupportedProgramVersion[] = {0};
......
......@@ -55,6 +55,43 @@ typename std::enable_if<!std::is_same<T, bool>::value>::type CopyVectorToTensor(
framework::TensorFromVector(values, ctx.device_context(), out);
}
template <typename T, typename Context>
typename std::enable_if<std::is_same<T, bool>::value>::type CopyVectorToTensor(
const Context& dev_ctx,
const std::vector<Scalar>& values,
phi::DenseTensor* out) {
// If attribute value dtype is vector<bool>, it will be converted to
// vector<int>. at the same time, we can not use vector<bool> to hold
// the value, because the c++ use bit value to replace byte value.
std::vector<int> assign_values;
assign_values.reserve(values.size());
for (const auto& val : values) {
assign_values.emplace_back(val.to<int>());
}
phi::TensorFromVector(assign_values, dev_ctx, out);
// use the array to replace to vector
bool* array_ptr = new T[assign_values.size()];
for (unsigned int i = 0; i < assign_values.size(); i++) {
array_ptr[i] = static_cast<T>(assign_values[i]);
}
phi::TensorFromArray(array_ptr, assign_values.size(), dev_ctx, out);
delete[] array_ptr;
}
template <typename T, typename Context>
typename std::enable_if<!std::is_same<T, bool>::value>::type CopyVectorToTensor(
const Context& dev_ctx,
const std::vector<Scalar>& values,
phi::DenseTensor* out) {
std::vector<T> assign_values;
assign_values.reserve(values.size());
for (const auto& val : values) {
assign_values.emplace_back(val.to<T>());
}
phi::TensorFromVector(assign_values, dev_ctx, out);
}
template <typename T>
class AssignValueKernel : public framework::OpKernel<T> {
public:
......
......@@ -110,7 +110,9 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
framework::proto::VarType::INT64,
framework::proto::VarType::FP32,
framework::proto::VarType::FP64,
framework::proto::VarType::FP16})
framework::proto::VarType::FP16,
framework::proto::VarType::COMPLEX64,
framework::proto::VarType::COMPLEX128})
.SetDefault(framework::proto::VarType::FP32);
AddAttr<std::vector<int64_t>>(
"axes", "(list<int64_t>) Axes that `starts` and `ends` apply to.");
......@@ -131,17 +133,7 @@ class SetValueMaker : public framework::OpProtoAndCheckerMaker {
AddAttr<std::vector<int64_t>>("none_axes", "(list<int>) The axes to none.")
.SetDefault({});
AddAttr<std::vector<int>>("bool_values", "Store the bool values.")
.SetDefault({});
AddAttr<std::vector<float>>("fp32_values", "Store the float32 values.")
.SetDefault({});
AddAttr<std::vector<int>>("int32_values", "Store the int32 values.")
.SetDefault({});
AddAttr<std::vector<int64_t>>("int64_values", "Store the int64 values.")
.SetDefault({});
AddAttr<std::vector<double>>("fp64_values", "Store the float64 values.")
.SetDefault({});
AddAttr<std::vector<float>>("fp16_values", "Store the float16 values.")
AddAttr<std::vector<paddle::experimental::Scalar>>("values", "values")
.SetDefault({});
AddAttr<std::vector<int64_t>>("shape", "(vector<int64_t>) Shape of values.")
......@@ -298,4 +290,15 @@ Upgrade set_value, add 1 attribute [decrease_axes].
Upgrade set_value, add 1 attribute [none_axes].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"none_axes", "The axes with none index.", std::vector<int64_t>{}));
"none_axes", "The axes with none index.", std::vector<int64_t>{}))
.AddCheckpoint(
R"ROC(Upgrade set_value to support generic Scalars as value and remove plain values, so as to support complex types.)ROC",
paddle::framework::compatible::OpVersionDesc()
.NewAttr("values",
"values",
std::vector<paddle::experimental::Scalar>())
.DeleteAttr("bool_values", "remove plain attributes")
.DeleteAttr("fp32_values", "remove plain attributes")
.DeleteAttr("int32_values", "remove plain attributes")
.DeleteAttr("int64_values", "remove plain attributes")
.DeleteAttr("fp64_values", "remove plain attributes"));
......@@ -33,34 +33,6 @@ namespace operators {
using DDim = framework::DDim;
inline std::string GetValueName(framework::proto::VarType::Type data_type) {
std::string value_name;
switch (data_type) {
case framework::proto::VarType::INT32:
value_name = "int32_values";
break;
case framework::proto::VarType::INT64:
value_name = "int64_values";
break;
case framework::proto::VarType::FP32:
value_name = "fp32_values";
break;
case framework::proto::VarType::FP64:
value_name = "fp64_values";
break;
case framework::proto::VarType::BOOL:
value_name = "bool_values";
break;
default:
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported data type(code %d) for SetValue operator, only "
"supports bool, int32, float32 and int64.",
data_type));
}
return value_name;
}
// check whether the tensor with dimension of second can assign to the
// tensor with dimension of first
inline void CheckIsDimsMatch(const framework::DDim first,
......
......@@ -1197,11 +1197,21 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
if (!py::isinstance<py::array_t<bool>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<bool>(value_obj_tmp);
}
} else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
if (!py::isinstance<py::array_t<std::complex<float>>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<std::complex<float>>(
value_obj_tmp);
}
} else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
if (!py::isinstance<py::array_t<std::complex<double>>>(value_obj_tmp)) {
value = pybind11::detail::CastNumpyArray<std::complex<double>>(
value_obj_tmp);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a numpy.np value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32 or int64, "
"float32, float64, complex64, complex128, int32 or int64, "
"please check the type of tensor."));
}
......@@ -1217,29 +1227,38 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
// convert the value to self data type
if (py::isinstance<py::float_>(value_obj_tmp) ||
py::isinstance<py::int_>(value_obj_tmp) ||
py::isinstance<py::bool_>(value_obj_tmp)) {
py::isinstance<py::bool_>(value_obj_tmp) ||
PyComplex_Check(value_obj)) {
if (self->tensor.dtype() == phi::DataType::FLOAT32) {
attrs["fp32_values"] =
std::vector<float>{value_obj_tmp.cast<float>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<float>()};
} else if (self->tensor.dtype() == phi::DataType::FLOAT64) {
attrs["fp64_values"] =
std::vector<double>{value_obj_tmp.cast<double>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<double>()};
} else if (self->tensor.dtype() == phi::DataType::INT32) {
attrs["int32_values"] =
std::vector<int32_t>{value_obj_tmp.cast<int32_t>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<int32_t>()};
} else if (self->tensor.dtype() == phi::DataType::INT64) {
attrs["int64_values"] =
std::vector<int64_t>{value_obj_tmp.cast<int64_t>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<int64_t>()};
} else if (self->tensor.dtype() == phi::DataType::BOOL) {
attrs["bool_values"] = std::vector<int>{value_obj_tmp.cast<bool>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<bool>()};
} else if (self->tensor.dtype() == phi::DataType::FLOAT16) {
attrs["fp16_values"] =
std::vector<float>{value_obj_tmp.cast<float>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<float>()};
} else if (self->tensor.dtype() == phi::DataType::COMPLEX64) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<std::complex<float>>()};
} else if (self->tensor.dtype() == phi::DataType::COMPLEX128) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj_tmp.cast<std::complex<double>>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32, int64 or float16, "
"float32, float64, complex64, complex128, int32, int64 or "
"float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
......@@ -1247,7 +1266,7 @@ static PyObject* tensor_method__setitem_eager_tensor(TensorObject* self,
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Value type error. The assign value allows "
"numpy.ndarray, integer, float or bool, "
"numpy.ndarray, integer, float, complex or bool, "
"but received %s.",
Py_TYPE(value_obj)));
}
......
......@@ -902,11 +902,28 @@ void BindImperative(py::module *m_ptr) {
if (!py::isinstance<py::array_t<bool>>(value_obj)) {
value = pybind11::detail::CastNumpyArray<bool>(value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX64) {
if (!py::isinstance<py::array_t<std::complex<float>>>(
value_obj)) {
value =
pybind11::detail::CastNumpyArray<std::complex<float>>(
value_obj);
}
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX128) {
if (!py::isinstance<py::array_t<std::complex<double>>>(
value_obj)) {
value =
pybind11::detail::CastNumpyArray<std::complex<double>>(
value_obj);
}
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a numpy.np value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32 or int64, "
"float32, float64, complex64, complex128, int32 or "
"int64, "
"please check the type of tensor."));
}
......@@ -921,35 +938,45 @@ void BindImperative(py::module *m_ptr) {
// convert the value to self data type
if (py::isinstance<py::float_>(value_obj) ||
py::isinstance<py::int_>(value_obj) ||
py::isinstance<py::bool_>(value_obj)) {
py::isinstance<py::bool_>(value_obj) ||
PyComplex_Check(value_obj.ptr())) {
if (self->DataType() == framework::proto::VarType::FP32) {
attrs["fp32_values"] =
std::vector<float>{value_obj.cast<float>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<float>()};
} else if (self->DataType() ==
framework::proto::VarType::FP64) {
attrs["fp64_values"] =
std::vector<double>{value_obj.cast<double>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<double>()};
} else if (self->DataType() ==
framework::proto::VarType::INT32) {
attrs["int32_values"] =
std::vector<int32_t>{value_obj.cast<int32_t>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<int32_t>()};
} else if (self->DataType() ==
framework::proto::VarType::INT64) {
attrs["int64_values"] =
std::vector<int64_t>{value_obj.cast<int64_t>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<int64_t>()};
} else if (self->DataType() ==
framework::proto::VarType::BOOL) {
attrs["bool_values"] =
std::vector<int>{value_obj.cast<bool>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<bool>()};
} else if (self->DataType() ==
framework::proto::VarType::FP16) {
attrs["fp16_values"] =
std::vector<float>{value_obj.cast<float>()};
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<float>()};
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX64) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<std::complex<float>>()};
} else if (self->DataType() ==
framework::proto::VarType::COMPLEX128) {
attrs["values"] = std::vector<paddle::experimental::Scalar>{
value_obj.cast<std::complex<double>>()};
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"When assign a value to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, "
"float32, int32, int64 or float16, "
"float32, float64, complex64, complex128, int32, int64 "
"or float16, "
"please check the type of tensor."));
}
attrs["shape"] = std::vector<int64_t>{1};
......
......@@ -23,6 +23,9 @@ limitations under the License. */
#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/ir/graph_helper.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_version_proto.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/framework/program_converter.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/var_desc.h"
#include "paddle/fluid/framework/version.h"
......@@ -56,6 +59,33 @@ static pybind11::bytes SerializeMessage(
return retv;
}
pybind11::bytes SerializeProgramDesc(
pd::ProgramDesc &self, // NOLINT due to pybind11 convention.
bool legacy_format = false) {
// Check IsInitialized in Python
std::string retv;
pd::ProgramDesc copy = self;
framework::compatible::pb::OpVersionMap op_version_map(copy.OpVersionMap());
if (legacy_format) {
pd::no_scalar::ConvertProgram(&copy);
copy.SetVersion(pd::kLegacyProgramVersion);
paddle::framework::compatible::SaveOpVersions(
paddle::framework::compatible::pb::GetLegacyOpVersions(),
&op_version_map);
} else {
copy.SetVersion(pd::kCurProgramVersion);
paddle::framework::compatible::SaveOpVersions(
framework::compatible::get_op_version_map(), &op_version_map);
}
PADDLE_ENFORCE_EQ(copy.Proto()->SerializePartialToString(&retv),
true,
platform::errors::InvalidArgument(
"Failed to serialize input Desc to string."));
return retv;
}
template <typename T>
static void DeserializeMessage(T *self, const std::string &str) {
PADDLE_ENFORCE_EQ(
......@@ -88,7 +118,9 @@ void BindProgramDesc(pybind11::module *m) {
.def("flush", &pd::ProgramDesc::Flush)
.def("get_feed_target_names", &pd::ProgramDesc::GetFeedTargetNames)
.def("get_fetch_target_names", &pd::ProgramDesc::GetFetchTargetNames)
.def("serialize_to_string", SerializeMessage<pd::ProgramDesc>)
.def("serialize_to_string",
&SerializeProgramDesc,
pybind11::arg("legacy_format") = false)
.def("need_update", &pd::ProgramDesc::NeedUpdate)
.def("parse_from_string",
[](pd::ProgramDesc &program_desc, const std::string &data) {
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include <memory>
#include <string>
#include <tuple>
#include <type_traits>
#include <utility>
#include <vector>
......@@ -63,12 +64,19 @@ constexpr int NPY_UINT16_ = 4;
constexpr int NPY_COMPLEX64 = 14;
constexpr int NPY_COMPLEX128 = 15;
template <typename T, typename S>
struct casting_complex_to_non_complex {
static const bool value = pybind11::detail::is_complex<S>::value &&
!pybind11::detail::is_complex<T>::value;
};
// cast numpy type form S to T, this may allocate new memory
template <class T, class S>
template <
class T,
class S,
std::enable_if_t<!std::is_same<T, S>::value &&
!casting_complex_to_non_complex<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
if (std::is_same<T, S>::value) {
return array;
}
auto dim = array.ndim();
std::vector<py::ssize_t> result_shape(dim);
for (auto i = 0; i < dim; i++) {
......@@ -80,6 +88,30 @@ static py::array_t<T> CastNumpyType(py::array_t<S> array) {
return py::vectorize([](S s) { return static_cast<T>(s); })(array);
}
template <
class T,
class S,
std::enable_if_t<(!std::is_same<T, S>::value) &&
casting_complex_to_non_complex<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
auto dim = array.ndim();
std::vector<py::ssize_t> result_shape(dim);
for (auto i = 0; i < dim; i++) {
result_shape[i] = array.shape(i);
}
py::array_t<T> result(result_shape);
return py::vectorize([](S s) { return static_cast<T>(s.real()); })(array);
}
template <class T,
class S,
std::enable_if_t<std::is_same<T, S>::value> * = nullptr>
static py::array_t<T> CastNumpyType(py::array_t<S> array) {
return array;
}
template <class T>
static py::array_t<T> CastNumpyArray(const py::object &array) {
if (py::isinstance<py::array_t<float>>(array)) {
......@@ -92,10 +124,14 @@ static py::array_t<T> CastNumpyArray(const py::object &array) {
return CastNumpyType<T>(array.cast<py::array_t<int64_t>>());
} else if (py::isinstance<py::array_t<bool>>(array)) {
return CastNumpyType<T>(array.cast<py::array_t<bool>>());
} else if (py::isinstance<py::array_t<std::complex<float>>>(array)) {
return CastNumpyType<T>(array.cast<py::array_t<std::complex<float>>>());
} else if (py::isinstance<py::array_t<std::complex<double>>>(array)) {
return CastNumpyType<T>(array.cast<py::array_t<std::complex<double>>>());
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"Value type error. The assign numpy value allows integer, float, "
"double and bool, "
"double, complex64, complex128, and bool, "
"but received %s.",
Py_TYPE(array.ptr())->tp_name));
}
......
......@@ -216,12 +216,14 @@ class DistributedSaver:
dist_filename = filename + "_dist" + str(rank_id)
dist_path = os.path.join(dirname, dist_filename)
legacy_format = kwargs.get("legacy_format", False)
paddle.static.save_inference_model(
dist_path,
dist_feed_vars,
dist_fetch_vars,
exe,
program=dist_main_prog,
legacy_format=legacy_format,
)
def _save_rank_mapping(self, dirname):
......
......@@ -718,6 +718,7 @@ class ParameterServerRuntime(RuntimeBase):
target_vars,
main_program=None,
export_for_deployment=True,
legacy_format=False,
):
"""
Prune the given `main_program` to build a new program especially for inference,
......@@ -743,6 +744,7 @@ class ParameterServerRuntime(RuntimeBase):
None,
None,
export_for_deployment,
legacy_format=legacy_format,
)
else:
paddle.fluid.io.save_inference_model(
......@@ -755,6 +757,7 @@ class ParameterServerRuntime(RuntimeBase):
None,
export_for_deployment,
True,
legacy_format=legacy_format,
)
model_basename = "__model__"
......
......@@ -5734,6 +5734,22 @@ class Program:
res_str = ""
for block in self.blocks:
res_str += block.to_string(throw_on_error, with_details)
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(bytes(protostr))
res_str += (
"version {\n "
+ textwrap.indent(
_debug_string_(proto.version, throw_on_error), " "
)
+ "}\n"
)
res_str += (
"op_version_map {\n "
+ textwrap.indent(
_debug_string_(proto.op_version_map, throw_on_error), " "
)
+ "}\n"
)
else:
protostr = self.desc.serialize_to_string()
proto = framework_pb2.ProgramDesc.FromString(bytes(protostr))
......
......@@ -135,6 +135,7 @@ def save_inference_model(
export_for_deployment=True,
program_only=False,
clip_extra=True,
legacy_format=False,
):
"""
Prune the given `main_program` to build a new program especially for inference,
......@@ -176,6 +177,8 @@ def save_inference_model(
program_only(bool, optional): If True, It will save inference program only, and do not
save params of Program.
Default: False.
legacy_format(bool, optional): Whether to save program in legacy format.
Default: False.
Returns:
list, The fetch variables' name list.
......@@ -314,8 +317,6 @@ def save_inference_model(
prepend_feed_ops(main_program, feeded_var_names)
append_fetch_ops(main_program, fetch_var_names)
main_program.desc._set_version()
paddle.fluid.core.save_op_version_info(main_program.desc)
with open(model_basename, "wb") as f:
f.write(
main_program._remove_training_info(
......
......@@ -390,9 +390,6 @@ def create_fake_model(program_config):
op_desc.set_output('Out', ["fetch"])
op_desc._set_attr("col", index)
main_program_desc._set_version()
paddle.fluid.core.save_op_version_info(main_program_desc)
model = main_program_desc.serialize_to_string()
util_program._sync_with_cpp()
......
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
import paddle
from paddle.fluid.proto import framework_pb2
class TestSetValue(unittest.TestCase):
def setUp(self):
paddle.enable_static()
def _test_for_new_program_format(self, program_bytes):
restored_prog_as_is = framework_pb2.ProgramDesc.FromString(
program_bytes
)
for block in restored_prog_as_is.blocks:
for op in block.ops:
if op.type in ("set_value", "set_value_grad"):
attr_names = [attr.name for attr in op.attrs]
self.assertTrue("values" in attr_names)
self.assertFalse("bool_values" in attr_names)
self.assertFalse("int32_values" in attr_names)
self.assertFalse("int64_values" in attr_names)
self.assertFalse("fp32_values" in attr_names)
self.assertFalse("fp64_values" in attr_names)
self.assertFalse("fp16_values" in attr_names)
def _test_for_legacy_program_format(self, program_bytes):
restored_prog_as_is = framework_pb2.ProgramDesc.FromString(
program_bytes
)
for block in restored_prog_as_is.blocks:
for op in block.ops:
if op.type in ("set_value", "set_value_grad"):
attr_names = [attr.name for attr in op.attrs]
self.assertFalse("values" in attr_names)
self.assertTrue("bool_values" in attr_names)
self.assertTrue("int32_values" in attr_names)
self.assertTrue("int64_values" in attr_names)
self.assertTrue("fp32_values" in attr_names)
self.assertTrue("fp64_values" in attr_names)
self.assertTrue("fp16_values" in attr_names)
def _test_equivalence(
self,
new_program_bytes,
legacy_program_bytes,
fetch_list,
expected_outputs,
):
normal_program = paddle.static.io.deserialize_program(new_program_bytes)
converted_back_program = paddle.static.io.deserialize_program(
legacy_program_bytes
)
exe = paddle.static.Executor(paddle.CPUPlace())
[out] = exe.run(normal_program, fetch_list=fetch_list)
np.testing.assert_allclose(out, expected_outputs[0])
[out] = exe.run(converted_back_program, fetch_list=fetch_list)
np.testing.assert_allclose(out, expected_outputs[0])
def test_int32(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.int32)
patch = np.array([41, 42]).astype(np.int32)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.int32)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_int64(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.int64)
patch = np.array(
[np.iinfo(np.int64).max, np.iinfo(np.int64).min]
).astype(np.int64)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.int64)
x_output = x_input.copy()
x_output[:1, :2] = patch
self.fetch_list = [x.name]
self.expected_outputs = [x_output]
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_float32(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.float32)
patch = np.array(
[np.finfo(np.float32).max, np.finfo(np.float32).min]
).astype(np.float32)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.float32)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_float64(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.float64)
patch = np.array(
[np.finfo(np.float64).max, np.finfo(np.float64).min]
).astype(np.float64)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.float64)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_float16(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.float16)
patch = np.array(
[np.finfo(np.float16).max, np.finfo(np.float16).min]
).astype(np.float16)
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=np.float16)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_bool(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.ones([3, 4], dtype=paddle.bool)
patch = np.array([True, False])
x[:1, :2] = patch
x_input = np.ones([3, 4], dtype=bool)
x_output = x_input.copy()
x_output[:1, :2] = patch
normal_program_bytes = mp._get_desc().serialize_to_string()
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
self.assertNotEqual(normal_program_bytes, legacy_program_bytes)
self._test_for_new_program_format(normal_program_bytes)
self._test_for_legacy_program_format(legacy_program_bytes)
self._test_equivalence(
normal_program_bytes,
legacy_program_bytes,
fetch_list=[x.name],
expected_outputs=[x_output],
)
def test_complex64(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.complex(
paddle.ones([3, 4], dtype=paddle.float32),
paddle.ones([3, 4], dtype=paddle.float32),
)
patch = np.array([42.1 + 42.1j, 42.2 + 42.2j]).astype(np.complex64)
x[:1, :2] = patch
x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex64)
x_output = x_input.copy()
x_output[:1, :2] = patch
with self.assertRaisesRegex(RuntimeError, "Invalid data type"):
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
def test_complex128(self):
mp = paddle.static.Program()
sp = paddle.static.Program()
with paddle.static.program_guard(mp, sp):
x = paddle.complex(
paddle.ones([3, 4], dtype=paddle.float64),
paddle.ones([3, 4], dtype=paddle.float64),
)
patch = np.array(
[
np.finfo(np.float64).max + 1j * np.finfo(np.float64).min,
np.finfo(np.float64).min + 1j * np.finfo(np.float64).max,
]
).astype(np.complex128)
x[:1, :2] = patch
x_input = (np.ones([3, 4]) + 1j * np.ones([3, 4])).astype(np.complex128)
x_output = x_input.copy()
x_output[:1, :2] = patch
with self.assertRaisesRegex(RuntimeError, "Invalid data type"):
legacy_program_bytes = mp._get_desc().serialize_to_string(
legacy_format=True
)
......@@ -711,6 +711,91 @@ create_test_value_numpy_bool(TestSetValueItemSlice3)
create_test_value_numpy_bool(TestSetValueItemSlice4)
def create_test_value_complex64(parent):
class TestValueInt(parent):
def set_value(self):
self.value = 42.1 + 42.1j
def set_dtype(self):
self.dtype = "complex64"
cls_name = "{}_{}".format(parent.__name__, "ValueComplex64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_complex64(TestSetValueItemInt)
create_test_value_complex64(TestSetValueItemSlice)
create_test_value_complex64(TestSetValueItemSlice2)
create_test_value_complex64(TestSetValueItemSlice3)
create_test_value_complex64(TestSetValueItemSlice4)
def create_test_value_complex128(parent):
class TestValueInt(parent):
def set_value(self):
self.value = complex(
np.finfo(np.float64).max + 1j * np.finfo(np.float64).min
)
def set_dtype(self):
self.dtype = "complex128"
cls_name = "{}_{}".format(parent.__name__, "ValueComplex128")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_complex128(TestSetValueItemInt)
create_test_value_complex128(TestSetValueItemSlice)
create_test_value_complex128(TestSetValueItemSlice2)
create_test_value_complex128(TestSetValueItemSlice3)
create_test_value_complex128(TestSetValueItemSlice4)
def create_test_value_numpy_complex64(parent):
class TestValueInt(parent):
def set_value(self):
self.value = np.array(42.1 + 42.1j)
def set_dtype(self):
self.dtype = "complex64"
cls_name = "{}_{}".format(parent.__name__, "ValueNumpyComplex64")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_complex64(TestSetValueItemInt)
create_test_value_numpy_complex64(TestSetValueItemSlice)
create_test_value_numpy_complex64(TestSetValueItemSlice2)
create_test_value_numpy_complex64(TestSetValueItemSlice3)
create_test_value_numpy_complex64(TestSetValueItemSlice4)
def create_test_value_numpy_complex128(parent):
class TestValueInt(parent):
def set_value(self):
v = complex(
np.finfo(np.float64).max + 1j * np.finfo(np.float64).min
)
self.value = np.array([v])
def set_dtype(self):
self.dtype = "complex128"
cls_name = "{}_{}".format(parent.__name__, "ValueNumpyComplex128")
TestValueInt.__name__ = cls_name
globals()[cls_name] = TestValueInt
create_test_value_numpy_complex128(TestSetValueItemInt)
create_test_value_numpy_complex128(TestSetValueItemSlice)
create_test_value_numpy_complex128(TestSetValueItemSlice2)
create_test_value_numpy_complex128(TestSetValueItemSlice3)
create_test_value_numpy_complex128(TestSetValueItemSlice4)
# 2.3 value is a Paddle Tensor (int32, int64, float32, float64, bool)
def create_test_value_tensor_int32(parent):
class TestValueInt(parent):
......
......@@ -782,38 +782,15 @@ def _setitem_impl_(var, item, value):
from .data_feeder import convert_dtype
# 2.1 value is an integer of float
if isinstance(value, (int, float)):
# 2.1 value is an integer, float or complex
if isinstance(value, (bool, int, float, complex)):
value = np.array([value]).astype(convert_dtype(dtype))
# 2.2 value is a np.ndarray
if isinstance(value, np.ndarray):
shape = list(value.shape)
if dtype == core.VarDesc.VarType.BOOL:
value_name = "bool_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP32:
value_name = "fp32_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP64:
value_name = "fp64_values"
values = [float(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT32:
value_name = "int32_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.INT64:
value_name = "int64_values"
values = [int(v) for v in value.flat]
elif dtype == core.VarDesc.VarType.FP16:
value_name = "fp16_values"
values = [float(v) for v in value.flat]
else:
raise TypeError(
"When assign a numpy.ndarray, integer or float to a paddle.Tensor, "
"the data type of the paddle.Tensor must be bool, float32, int32, int64 or float16, but "
"received %s." % convert_dtype(dtype)
)
attrs[value_name] = values
values = value.ravel().tolist()
attrs["values"] = values
attrs["shape"] = shape
elif isinstance(value, (Variable, core.eager.Tensor)):
......
......@@ -262,6 +262,7 @@ class Fleet(metaclass=abc.ABCMeta):
target_vars,
main_program=None,
export_for_deployment=True,
legacy_format=False,
):
pass
......
......@@ -82,6 +82,7 @@ class Collective(Fleet):
target_vars=None,
main_program=None,
export_for_deployment=True,
legacy_format=False,
):
"""
Prune the given `main_program` to build a new program especially for
......@@ -109,6 +110,7 @@ class Collective(Fleet):
None,
None,
export_for_deployment,
legacy_format=legacy_format,
)
def save_persistables(
......
......@@ -419,6 +419,7 @@ class FleetTranspiler(Fleet):
target_vars,
main_program=None,
export_for_deployment=True,
legacy_format=False,
):
"""
Prune the given `main_program` to build a new program especially for inference,
......@@ -453,6 +454,7 @@ class FleetTranspiler(Fleet):
None,
None,
export_for_deployment,
legacy_format=legacy_format,
)
else:
paddle.static.save_inference_model(
......@@ -465,6 +467,7 @@ class FleetTranspiler(Fleet):
None,
export_for_deployment,
True,
legacy_format=legacy_format,
)
model_basename = "__model__"
......
......@@ -1752,7 +1752,9 @@ class TracedLayer:
saved inference model. If None, all output variables of the
TracedLayer object would be the outputs of the saved inference
model. Default None.
kwargs: Supported keys including 'clip_extra'.set to True if you want to clip extra information for every operator.
kwargs: Supported keys including
- clip_extra(bool): whether to clip extra information for every operator. Defaults to True.
- legacy_format(bool): whether to save program in legacy format. Default to False.
Returns:
None
......@@ -1854,6 +1856,7 @@ class TracedLayer:
model_filename = file_prefix + INFER_MODEL_SUFFIX
params_filename = file_prefix + INFER_PARAMS_SUFFIX
legacy_format = kwargs.get('legacy_format', False)
save_inference_model(
dirname=dirname,
feeded_var_names=feeded_var_names,
......@@ -1863,4 +1866,5 @@ class TracedLayer:
model_filename=model_filename,
params_filename=params_filename,
clip_extra=clip_extra,
legacy_format=legacy_format,
)
......@@ -75,7 +75,7 @@ def _check_args(caller, args, supported_args=None, deprecated_args=None):
def _check_vars(name, var_list):
if not isinstance(var_list, list):
var_list = [var_list]
if not var_list or not all([isinstance(var, Variable) for var in var_list]):
if not all([isinstance(var, Variable) for var in var_list]):
raise ValueError(
f"'{name}' should be a Variable or a list of Variable."
)
......@@ -252,6 +252,7 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
kwargs: Supported keys including ``program``. Attention please, kwargs is used for backward compatibility mainly.
- program(Program): specify a program if you don't want to use default main program.
- legacy_format(bool): whether to save inference program in legacy format. Defaults to False.
Returns:
bytes: serialized program.
......@@ -289,14 +290,15 @@ def serialize_program(feed_vars, fetch_vars, **kwargs):
program = _get_valid_program(kwargs.get('program', None))
program = normalize_program(program, feed_vars, fetch_vars)
return _serialize_program(program)
legacy_format = kwargs.get('legacy_format', False)
return _serialize_program(program, legacy_format=legacy_format)
def _serialize_program(program):
def _serialize_program(program, legacy_format=False):
"""
serialize given program to bytes.
"""
return program.desc.serialize_to_string()
return program.desc.serialize_to_string(legacy_format=legacy_format)
@static_only
......@@ -459,6 +461,8 @@ def save_inference_model(
- clip_extra(bool): the flag indicating whether to clip extra information for every operator. Default: True.
- legacy_format(bool): whether to save inference model in legacy format. Default: False.
Returns:
None
......@@ -518,8 +522,10 @@ def save_inference_model(
clip_extra = kwargs.get('clip_extra', True)
program = normalize_program(program, feed_vars, fetch_vars)
# serialize and save program
legacy_format = kwargs.get('legacy_format', False)
program_bytes = _serialize_program(
program._remove_training_info(clip_extra=clip_extra)
program._remove_training_info(clip_extra=clip_extra),
legacy_format=legacy_format,
)
save_to_file(model_path, program_bytes)
# serialize and save params
......@@ -1371,8 +1377,6 @@ def save(program, model_path, protocol=4, **configs):
main_program = program.clone()
program.desc.flush()
main_program.desc._set_version()
paddle.fluid.core.save_op_version_info(program.desc)
with open(model_path + ".pdmodel", "wb") as f:
f.write(program.desc.serialize_to_string())
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册