未验证 提交 9735d1b8 编写于 作者: H Hui Zhang 提交者: GitHub

[jit] c++ property deserialization & Variable support vector of int, float (#44727)

* c++ property deserialization

* fix for comment

* more error info

* fix exception info

* fix ci

* fix compile

* fix layer test ci
上级 9f1616a0
...@@ -38,6 +38,9 @@ if(WIN32) ...@@ -38,6 +38,9 @@ if(WIN32)
set(GTEST_MAIN_LIBRARIES set(GTEST_MAIN_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/gtest_main.lib" "${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/gtest_main.lib"
CACHE FILEPATH "gtest main libraries." FORCE) CACHE FILEPATH "gtest main libraries." FORCE)
set(GMOCK_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgmock.lib"
CACHE FILEPATH "gmock libraries." FORCE)
string(REPLACE "/w " "" GTEST_CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") string(REPLACE "/w " "" GTEST_CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
string(REPLACE "/w " "" GTEST_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") string(REPLACE "/w " "" GTEST_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
string(REPLACE "/W0 " "" GTEST_CMAKE_C_FLAGS "${GTEST_CMAKE_C_FLAGS}") string(REPLACE "/W0 " "" GTEST_CMAKE_C_FLAGS "${GTEST_CMAKE_C_FLAGS}")
...@@ -49,6 +52,9 @@ else() ...@@ -49,6 +52,9 @@ else()
set(GTEST_MAIN_LIBRARIES set(GTEST_MAIN_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgtest_main.a" "${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgtest_main.a"
CACHE FILEPATH "gtest main libraries." FORCE) CACHE FILEPATH "gtest main libraries." FORCE)
set(GMOCK_LIBRARIES
"${GTEST_INSTALL_DIR}/${CMAKE_INSTALL_LIBDIR}/libgmock.a"
CACHE FILEPATH "gmock libraries." FORCE)
set(GTEST_CMAKE_C_FLAGS "${CMAKE_C_FLAGS}") set(GTEST_CMAKE_C_FLAGS "${CMAKE_C_FLAGS}")
set(GTEST_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}") set(GTEST_CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS}")
endif() endif()
...@@ -86,7 +92,8 @@ ExternalProject_Add( ...@@ -86,7 +92,8 @@ ExternalProject_Add(
-DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON -DCMAKE_POSITION_INDEPENDENT_CODE:BOOL=ON
-DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE} -DCMAKE_BUILD_TYPE:STRING=${THIRD_PARTY_BUILD_TYPE}
BUILD_BYPRODUCTS ${GTEST_LIBRARIES} BUILD_BYPRODUCTS ${GTEST_LIBRARIES}
BUILD_BYPRODUCTS ${GTEST_MAIN_LIBRARIES}) BUILD_BYPRODUCTS ${GTEST_MAIN_LIBRARIES}
BUILD_BYPRODUCTS ${GMOCK_LIBRARIES})
add_library(gtest STATIC IMPORTED GLOBAL) add_library(gtest STATIC IMPORTED GLOBAL)
set_property(TARGET gtest PROPERTY IMPORTED_LOCATION ${GTEST_LIBRARIES}) set_property(TARGET gtest PROPERTY IMPORTED_LOCATION ${GTEST_LIBRARIES})
...@@ -96,3 +103,7 @@ add_library(gtest_main STATIC IMPORTED GLOBAL) ...@@ -96,3 +103,7 @@ add_library(gtest_main STATIC IMPORTED GLOBAL)
set_property(TARGET gtest_main PROPERTY IMPORTED_LOCATION set_property(TARGET gtest_main PROPERTY IMPORTED_LOCATION
${GTEST_MAIN_LIBRARIES}) ${GTEST_MAIN_LIBRARIES})
add_dependencies(gtest_main extern_gtest) add_dependencies(gtest_main extern_gtest)
add_library(gmock STATIC IMPORTED GLOBAL)
set_property(TARGET gmock PROPERTY IMPORTED_LOCATION ${GMOCK_LIBRARIES})
add_dependencies(gmock extern_gtest)
...@@ -213,7 +213,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl< ...@@ -213,7 +213,9 @@ using VarTypeRegistry = detail::VarTypeRegistryImpl<
std::vector<std::unique_ptr<operators::CUDAGraphWithInOuts>>, std::vector<std::unique_ptr<operators::CUDAGraphWithInOuts>>,
int, int,
float, float,
Vocab>; Vocab,
std::vector<int>,
std::vector<float>>;
template <typename T> template <typename T>
struct VarTypeTrait { struct VarTypeTrait {
static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type"); static_assert(VarTypeRegistry::IsRegistered<T>(), "Must be registered type");
......
...@@ -34,6 +34,15 @@ TEST(Variable, GetMutable) { ...@@ -34,6 +34,15 @@ TEST(Variable, GetMutable) {
return; return;
} }
EXPECT_TRUE(false); EXPECT_TRUE(false);
std::unique_ptr<Variable> v_ints(new Variable());
auto* v_t = v_ints->GetMutable<std::vector<int>>();
v_t->push_back(1);
v_t->push_back(2);
const auto& cv_t = v_ints->Get<std::vector<int>>();
EXPECT_EQ(cv_t[0], 1);
EXPECT_EQ(cv_t[1], 2);
} }
} // namespace framework } // namespace framework
......
proto_library(paddle_jit_property_proto SRCS property.proto)
cc_library(
jit_property
SRCS property.cc
DEPS paddle_jit_property_proto tensor)
cc_library( cc_library(
jit_serializer jit_serializer
SRCS serializer.cc SRCS serializer.cc
DEPS lod_tensor device_context) DEPS lod_tensor device_context jit_property)
cc_library( cc_library(
jit_function_utils jit_function_utils
...@@ -32,9 +39,10 @@ cc_library( ...@@ -32,9 +39,10 @@ cc_library(
if(WITH_TESTING AND NOT WIN32) if(WITH_TESTING AND NOT WIN32)
add_custom_target( add_custom_target(
jit_download_program jit_download_program
COMMAND wget -nc -q --no-check-certificate COMMAND
https://paddle-ci.gz.bcebos.com/dy2st/multi_program_load.tar.gz wget -nc -q
COMMAND tar zxf multi_program_load.tar.gz) https://paddle-ci.gz.bcebos.com/dy2st/multi_program_load_with_property.tar.gz
COMMAND tar zxf multi_program_load_with_property.tar.gz)
set(JIT_DEPS set(JIT_DEPS
phi phi
phi_api phi_api
...@@ -52,10 +60,3 @@ if(WITH_TESTING AND NOT WIN32) ...@@ -52,10 +60,3 @@ if(WITH_TESTING AND NOT WIN32)
DEPS ${JIT_DEPS}) DEPS ${JIT_DEPS})
add_dependencies(layer_test jit_download_program) add_dependencies(layer_test jit_download_program)
endif() endif()
proto_library(paddle_jit_property_proto SRCS property.proto)
cc_library(
jit_property
SRCS property.cc
DEPS paddle_jit_property_proto)
...@@ -19,11 +19,16 @@ ...@@ -19,11 +19,16 @@
#include "paddle/fluid/jit/base_function.h" #include "paddle/fluid/jit/base_function.h"
#include "paddle/fluid/jit/compilation_unit.h" #include "paddle/fluid/jit/compilation_unit.h"
#include "paddle/fluid/jit/function_schema.h" #include "paddle/fluid/jit/function_schema.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h"
namespace paddle { namespace paddle {
namespace jit { namespace jit {
Layer::Layer(const Name2VariableMap& params_dict, const phi::Place& place)
: params_dict_(params_dict) { Layer::Layer(const Name2VariableMap& params_dict,
const Name2VariableMap& attrs_dict,
const phi::Place& place)
: params_dict_(params_dict), attrs_dict_(attrs_dict) {
unit_.reset(new CompilationUnit()); unit_.reset(new CompilationUnit());
} }
...@@ -57,5 +62,25 @@ const Name2FunctionMap& Layer::FunctionMap() const { ...@@ -57,5 +62,25 @@ const Name2FunctionMap& Layer::FunctionMap() const {
return unit_->FunctionMap(); return unit_->FunctionMap();
} }
#define PD_SPECIALZE_ATTRIBUTE_TYPE(T) \
template <> \
T Layer::Attribute<T>(const std::string& name) const { \
if (attrs_dict_.find(name) == attrs_dict_.end()) { \
PADDLE_THROW(phi::errors::NotFound( \
"Attribute can not found %s, please check if it exists.")); \
return T(); \
} \
auto var = attrs_dict_.at(name); \
T ret = var->Get<T>(); \
return ret; \
}
PD_SPECIALZE_ATTRIBUTE_TYPE(int)
PD_SPECIALZE_ATTRIBUTE_TYPE(float)
PD_SPECIALZE_ATTRIBUTE_TYPE(std::string)
PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<int>)
PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<float>)
PD_SPECIALZE_ATTRIBUTE_TYPE(std::vector<std::string>)
} // namespace jit } // namespace jit
} // namespace paddle } // namespace paddle
...@@ -21,7 +21,7 @@ ...@@ -21,7 +21,7 @@
#include "paddle/phi/api/include/tensor.h" #include "paddle/phi/api/include/tensor.h"
#include "paddle/phi/common/place.h" #include "paddle/phi/common/place.h"
#include "base_function.h" #include "base_function.h" //NOLINT
namespace paddle { namespace paddle {
...@@ -42,11 +42,14 @@ using Name2FunctionMap = ...@@ -42,11 +42,14 @@ using Name2FunctionMap =
class Layer { class Layer {
public: public:
Layer(const Name2VariableMap& params_dict, const phi::Place& place); Layer(const Name2VariableMap& params_dict,
const Name2VariableMap& attrs_dict_,
const phi::Place& place);
std::shared_ptr<BaseFunction> Function(const std::string& name) const; std::shared_ptr<BaseFunction> Function(const std::string& name) const;
Variable Attribute(const std::string& name) const; template <typename T>
T Attribute(const std::string& name) const;
std::vector<Tensor> forward(const std::vector<Tensor>& inputs); std::vector<Tensor> forward(const std::vector<Tensor>& inputs);
......
...@@ -14,6 +14,7 @@ ...@@ -14,6 +14,7 @@
#include <cmath> #include <cmath>
#include <string> #include <string>
#include <vector>
#include "gtest/gtest.h" #include "gtest/gtest.h"
...@@ -71,8 +72,31 @@ TEST(CpuLayerTest, Construct) { ...@@ -71,8 +72,31 @@ TEST(CpuLayerTest, Construct) {
auto place = phi::CPUPlace(); auto place = phi::CPUPlace();
std::string path = "./multi_program_load/export"; std::string path = "./multi_program_load/export";
auto layer = jit::Load(path, place); auto layer = jit::Load(path, place);
auto inputs = PrepareInputs(place);
float fbias = layer.Attribute<float>("fbias");
EXPECT_FLOAT_EQ(fbias, 1.4);
int ds = layer.Attribute<int>("down_sampling");
EXPECT_EQ(ds, 4);
std::string fstr = layer.Attribute<std::string>("fstr");
EXPECT_STREQ(fstr.c_str(), "save str property");
std::vector<int> ints = layer.Attribute<std::vector<int>>("ints");
EXPECT_EQ(ints[0], 10);
EXPECT_EQ(ints[1], 20);
std::vector<float> floats = layer.Attribute<std::vector<float>>("floats");
EXPECT_FLOAT_EQ(floats[0], 1.1);
EXPECT_FLOAT_EQ(floats[1], 2.2);
std::vector<std::string> strs =
layer.Attribute<std::vector<std::string>>("strs");
EXPECT_STREQ(strs[0].c_str(), "hello");
EXPECT_STREQ(strs[1].c_str(), "world");
// functions
auto inputs = PrepareInputs(place);
auto outs = layer.forward(inputs); auto outs = layer.forward(inputs);
auto out_data = outs[0].data<float>(); auto out_data = outs[0].data<float>();
EXPECT_NEAR(out_data[0], 0.02194316, 1e-6); EXPECT_NEAR(out_data[0], 0.02194316, 1e-6);
......
...@@ -12,16 +12,115 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,16 +12,115 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/jit/property.h" #include <fstream>
#include <streambuf>
#include <string>
#include "glog/logging.h" #include "glog/logging.h"
#include "paddle/fluid/framework/variable.h"
#include "paddle/fluid/jit/property.h"
#include "paddle/phi/core/enforce.h" #include "paddle/phi/core/enforce.h"
#include "paddle/phi/core/errors.h" #include "paddle/phi/core/errors.h"
namespace paddle { namespace paddle {
namespace jit { namespace jit {
using Variable = paddle::framework::Variable;
void Property::DeserializationFromString(const std::string &str) {
PADDLE_ENFORCE_EQ(
this->Proto()->ParsePartialFromString(str),
true,
phi::errors::InvalidArgument("Failed to parse pb from string"));
return;
}
std::string Property::SerializationToString() {
std::string retv;
PADDLE_ENFORCE_EQ(this->Proto()->SerializePartialToString(&retv),
true,
phi::errors::InvalidArgument(
"Failed to serialize input Desc to string."));
return retv;
}
void Property::Deserialization(const std::string &path) {
std::ifstream ifs(path, std::ios::binary | std::ios::in);
std::string str((std::istreambuf_iterator<char>(ifs)),
std::istreambuf_iterator<char>());
DeserializationFromString(str);
ifs.close();
return;
}
void Property::Serialization(const std::string &path) {
std::string str = SerializationToString();
std::ofstream ofs(path, std::ios::binary | std::ios::out);
ofs << str;
ofs.close();
return;
}
int Property::Size() const { return property_.entrys_size(); } int Property::Size() const { return property_.entrys_size(); }
std::vector<std::string> Property::Names() const {
std::vector<std::string> res;
for (int i = 0; i < Size(); i++) {
auto entry = property_.entrys(i);
if (entry.has_name()) {
res.push_back(entry.name());
} else {
LOG(WARNING) << "JIT::Property entry " << i
<< " not has name! Please check whether it is reasonable.";
}
}
return res;
}
std::unordered_map<std::string, std::shared_ptr<Variable>> Property::Values() {
std::unordered_map<std::string, std::shared_ptr<Variable>> res;
using ValueProto = proto::ValueProto;
for (int i = 0; i < Size(); i++) {
auto entry = property_.entrys(i);
if (entry.has_name()) {
auto &n = entry.name();
// remove Class Name suffix
auto key = n.substr(n.find_first_of(".") + 1);
std::shared_ptr<Variable> var(new Variable());
auto type = entry.type();
switch (type) {
case ValueProto::FLOAT:
*var->GetMutable<float>() = GetFloat(n);
break;
case ValueProto::INT:
*var->GetMutable<int>() = static_cast<int>(GetInt64(n));
break;
case ValueProto::STRING:
*var->GetMutable<std::string>() = GetString(n);
break;
case ValueProto::FLOATS:
*var->GetMutable<std::vector<float>>() = GetFloats(n);
break;
case ValueProto::INTS:
*var->GetMutable<std::vector<int>>() = GetInt64s(n);
break;
case ValueProto::STRINGS:
*var->GetMutable<std::vector<std::string>>() = GetStrings(n);
break;
default:
break;
}
res[key] = var;
VLOG(3) << "read property: " << n << " to " << key;
} else {
LOG(WARNING) << "JIT::Property entry " << i
<< " not has name! Please check whether it is reasonable.";
}
}
return res;
}
void Property::SetFloat(const float &f) { void Property::SetFloat(const float &f) {
auto type = proto::ValueProto::FLOAT; auto type = proto::ValueProto::FLOAT;
auto entry = property_.add_entrys(); auto entry = property_.add_entrys();
...@@ -42,7 +141,16 @@ void Property::SetFloat(const std::string &name, const float &f) { ...@@ -42,7 +141,16 @@ void Property::SetFloat(const std::string &name, const float &f) {
float Property::GetFloat(const std::string &name) const { float Property::GetFloat(const std::string &name) const {
for (int i = 0; i < Size(); i++) { for (int i = 0; i < Size(); i++) {
auto e = property_.entrys(i); auto e = property_.entrys(i);
if (e.has_name() && e.name() == name) { if (e.has_name() && e.name() == name) {
PADDLE_ENFORCE(
e.has_type() && e.type() == proto::ValueProto::FLOAT,
phi::errors::PreconditionNotMet("JIT::Property GetFloat: idx=%d type "
"is not float. Expect %d, but %d",
i,
proto::ValueProto::FLOAT,
e.type()));
return e.f(); return e.f();
} }
} }
...@@ -91,6 +199,26 @@ void Property::SetFloats(const std::string &name, const std::vector<float> &v) { ...@@ -91,6 +199,26 @@ void Property::SetFloats(const std::string &name, const std::vector<float> &v) {
<< " for name: " << name; << " for name: " << name;
} }
std::vector<float> Property::GetFloats(const std::string &name) {
for (int i = 0; i < Size(); i++) {
auto e = property_.entrys(i);
if (e.has_name() && e.name() == name) {
PADDLE_ENFORCE(
e.has_type() && e.type() == proto::ValueProto::FLOATS,
phi::errors::PreconditionNotMet(
"JIT::Property GetFloats: idx=%d type is not floats.", i));
auto items = e.floats();
return std::vector<float>(items.begin(), items.end());
}
}
PADDLE_THROW(phi::errors::NotFound(
"JIT::Property GetFloats: name: %s not found", name));
return std::vector<float>();
}
void Property::SetInt64(const int64_t &i) { void Property::SetInt64(const int64_t &i) {
auto type = proto::ValueProto::INT; auto type = proto::ValueProto::INT;
auto entry = property_.add_entrys(); auto entry = property_.add_entrys();
...@@ -108,6 +236,24 @@ void Property::SetInt64(const std::string &name, const int64_t &i) { ...@@ -108,6 +236,24 @@ void Property::SetInt64(const std::string &name, const int64_t &i) {
VLOG(3) << "Property: set_int " << i << " name: " << name; VLOG(3) << "Property: set_int " << i << " name: " << name;
} }
int64_t Property::GetInt64(const std::string &name) {
for (int i = 0; i < Size(); i++) {
auto e = property_.entrys(i);
if (e.has_name() && e.name() == name) {
PADDLE_ENFORCE(e.has_type() && e.type() == proto::ValueProto::INT,
phi::errors::PreconditionNotMet(
"JIT::Property GetInt64: idx=%d type is not int.", i));
return e.i();
}
}
PADDLE_THROW(phi::errors::NotFound(
"JIT::Property GetInt64: name: %s not found", name));
return 0;
}
void Property::SetInt64s(const std::vector<int64_t> &v) { void Property::SetInt64s(const std::vector<int64_t> &v) {
auto type = proto::ValueProto::INTS; auto type = proto::ValueProto::INTS;
auto entry = property_.add_entrys(); auto entry = property_.add_entrys();
...@@ -130,6 +276,31 @@ void Property::SetInt64s(const std::string &name, ...@@ -130,6 +276,31 @@ void Property::SetInt64s(const std::string &name,
VLOG(3) << "Property: set_ints " << v[0] << " name: " << name; VLOG(3) << "Property: set_ints " << v[0] << " name: " << name;
} }
std::vector<int> Property::GetInt64s(const std::string &name) {
for (int i = 0; i < Size(); i++) {
auto e = property_.entrys(i);
if (e.has_name() && e.name() == name) {
PADDLE_ENFORCE(
e.has_type() && e.type() == proto::ValueProto::INTS,
phi::errors::PreconditionNotMet(
"JIT::Property GetInt64s: idx=%d type is not ints.", i));
auto items = e.ints();
std::vector<int> res;
std::transform(items.begin(),
items.end(),
std::back_inserter(res),
[](const int64_t &v) { return static_cast<int>(v); });
return res;
}
}
PADDLE_THROW(phi::errors::NotFound(
"JIT::Property GetInt64s: name: %s not found", name));
return {};
}
void Property::SetString(const std::string &s) { void Property::SetString(const std::string &s) {
auto type = proto::ValueProto::STRING; auto type = proto::ValueProto::STRING;
auto entry = property_.add_entrys(); auto entry = property_.add_entrys();
...@@ -147,6 +318,24 @@ void Property::SetString(const std::string &name, const std::string &s) { ...@@ -147,6 +318,24 @@ void Property::SetString(const std::string &name, const std::string &s) {
VLOG(3) << "Property: set_string " << s << " name: " << name; VLOG(3) << "Property: set_string " << s << " name: " << name;
} }
std::string Property::GetString(const std::string &name) {
for (int i = 0; i < Size(); i++) {
auto e = property_.entrys(i);
if (e.has_name() && e.name() == name) {
PADDLE_ENFORCE(
e.has_type() && e.type() == proto::ValueProto::STRING,
phi::errors::PreconditionNotMet(
"JIT::Property GetString: idx=%d type is not string.", i));
return e.s();
}
}
PADDLE_THROW(phi::errors::NotFound(
"JIT::Property GetString: name: %s not found", name));
return {};
}
void Property::SetStrings(const std::vector<std::string> &v) { void Property::SetStrings(const std::vector<std::string> &v) {
auto type = proto::ValueProto::STRINGS; auto type = proto::ValueProto::STRINGS;
auto entry = property_.add_entrys(); auto entry = property_.add_entrys();
...@@ -169,5 +358,25 @@ void Property::SetStrings(const std::string &name, ...@@ -169,5 +358,25 @@ void Property::SetStrings(const std::string &name,
VLOG(3) << "Property: set_strings " << v[0] << " name: " << name; VLOG(3) << "Property: set_strings " << v[0] << " name: " << name;
} }
std::vector<std::string> Property::GetStrings(const std::string &name) {
for (int i = 0; i < Size(); i++) {
auto e = property_.entrys(i);
if (e.has_name() && e.name() == name) {
PADDLE_ENFORCE(
e.has_type() && e.type() == proto::ValueProto::STRINGS,
phi::errors::PreconditionNotMet(
"JIT::Property GetStrings: idx=%d type is not strings.", i));
auto items = e.strings();
return std::vector<std::string>(items.begin(), items.end());
}
}
PADDLE_THROW(phi::errors::NotFound(
"JIT::Property GetStrings: name: %s not found", name));
return {};
}
} // namespace jit } // namespace jit
} // namespace paddle } // namespace paddle
...@@ -17,13 +17,19 @@ ...@@ -17,13 +17,19 @@
#include <algorithm> #include <algorithm>
#include <atomic> #include <atomic>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/jit/property.pb.h" #include "paddle/fluid/jit/property.pb.h"
namespace paddle { namespace paddle {
namespace framework {
class Variable;
}
namespace jit { namespace jit {
using Variable = paddle::framework::Variable;
class Property { class Property {
public: public:
Property() {} Property() {}
...@@ -43,33 +49,54 @@ class Property { ...@@ -43,33 +49,54 @@ class Property {
const proto::PropertyVals *Proto() const { return &property_; } const proto::PropertyVals *Proto() const { return &property_; }
int Size() const; int Size() const;
std::vector<std::string> Names() const;
std::unordered_map<std::string, std::shared_ptr<Variable>> Values();
void SetFloat(const float &f); void SetFloat(const float &f);
void SetFloat(const std::string &name, const float &f); void SetFloat(const std::string &name, const float &f);
float GetFloat(const std::string &name) const;
float GetFloat(const int &idx) const;
void SetFloats(const std::vector<float> &v); void SetFloats(const std::vector<float> &v);
void SetFloats(const std::string &name, const std::vector<float> &v); void SetFloats(const std::string &name, const std::vector<float> &v);
float GetFloat(const std::string &name) const; std::vector<float> GetFloats(const std::string &name);
float GetFloat(const int &idx) const;
void SetInt64(const int64_t &i); void SetInt64(const int64_t &i);
void SetInt64(const std::string &name, const int64_t &i); void SetInt64(const std::string &name, const int64_t &i);
int64_t GetInt64(const std::string &name);
void SetInt64s(const std::vector<int64_t> &v); void SetInt64s(const std::vector<int64_t> &v);
void SetInt64s(const std::string &name, const std::vector<int64_t> &v); void SetInt64s(const std::string &name, const std::vector<int64_t> &v);
std::vector<int> GetInt64s(const std::string &name);
void SetString(const std::string &s); void SetString(const std::string &s);
void SetString(const std::string &name, const std::string &s); void SetString(const std::string &name, const std::string &s);
std::string GetString(const std::string &name);
void SetStrings(const std::vector<std::string> &v); void SetStrings(const std::vector<std::string> &v);
void SetStrings(const std::string &name, const std::vector<std::string> &v); void SetStrings(const std::string &name, const std::vector<std::string> &v);
std::vector<std::string> GetStrings(const std::string &name);
void Deserialization(const std::string &path);
void Serialization(const std::string &path);
// The Id() and OriginalId() are only used for auto parallel. // The Id() and OriginalId() are only used for auto parallel.
uint64_t Id() const { return id_; } uint64_t Id() const { return id_; }
uint64_t OriginalId() const { return original_id_; } uint64_t OriginalId() const { return original_id_; }
void SetOriginalId(uint64_t original_id) { original_id_ = original_id; } void SetOriginalId(uint64_t original_id) { original_id_ = original_id; }
private:
void DeserializationFromString(const std::string &str);
std::string SerializationToString();
private: private:
proto::PropertyVals property_; proto::PropertyVals property_;
......
...@@ -37,7 +37,6 @@ Layer Deserializer::operator()(const std::string& path, ...@@ -37,7 +37,6 @@ Layer Deserializer::operator()(const std::string& path,
// set is ordered // set is ordered
std::set<std::string> param_names_set; std::set<std::string> param_names_set;
std::vector<std::shared_ptr<FunctionInfo>> infos; std::vector<std::shared_ptr<FunctionInfo>> infos;
Name2VariableMap params_dict;
for (auto& it : pdmodel_paths) { for (auto& it : pdmodel_paths) {
auto& func_name = it.first; auto& func_name = it.first;
auto program_desc = LoadProgram(it.second); auto program_desc = LoadProgram(it.second);
...@@ -56,19 +55,27 @@ Layer Deserializer::operator()(const std::string& path, ...@@ -56,19 +55,27 @@ Layer Deserializer::operator()(const std::string& path,
func_name, persist_var_names, program_desc)); func_name, persist_var_names, program_desc));
} }
Name2VariableMap params_dict;
Name2VariableMap attrs_dict;
ReadTensorData(path + PDPARAMS_SUFFIX, param_names_set, place, &params_dict); ReadTensorData(path + PDPARAMS_SUFFIX, param_names_set, place, &params_dict);
// ReadAttributeData();
Layer layer = Layer(params_dict, place); if (utils::FileExists(path + PROPERTY_SUFFIX)) {
ReadAttributeData(path + PROPERTY_SUFFIX, &attrs_dict);
VLOG(3) << "Read Property Success!";
}
Layer layer = Layer(params_dict, attrs_dict, place);
for (auto& info : infos) { for (auto& info : infos) {
if (FLAGS_jit_engine_type == "Executor") { if (FLAGS_jit_engine_type == "Executor") {
VLOG(3) << "Add function type: ExecutorFunction."; VLOG(3) << "Add function type: ExecutorFunction. name: "
<< info->FunctionName();
layer.SetFunction( layer.SetFunction(
info->FunctionName(), info->FunctionName(),
utils::MakeFunction<ExecutorFunction>(info, params_dict, place)); utils::MakeFunction<ExecutorFunction>(info, params_dict, place));
} else if (FLAGS_jit_engine_type == "PE") { } else if (FLAGS_jit_engine_type == "PE") {
VLOG(3) << "Add function type: PEFunction."; VLOG(3) << "Add function type: PEFunction. name: "
<< info->FunctionName();
layer.SetFunction( layer.SetFunction(
info->FunctionName(), info->FunctionName(),
utils::MakeFunction<PEFunction>(info, params_dict, place)); utils::MakeFunction<PEFunction>(info, params_dict, place));
...@@ -99,7 +106,13 @@ void Deserializer::ReadTensorData(const std::string& file_name, ...@@ -99,7 +106,13 @@ void Deserializer::ReadTensorData(const std::string& file_name,
} }
void Deserializer::ReadAttributeData(const std::string& file_path, void Deserializer::ReadAttributeData(const std::string& file_path,
Name2VariableMap* attrs_dict) const {} Name2VariableMap* attrs_dict) const {
VLOG(3) << "ReadPropertyData from: " << file_path;
Property p;
p.Deserialization(file_path);
*attrs_dict = static_cast<Name2VariableMap>(p.Values());
return;
}
framework::ProgramDesc Deserializer::LoadProgram(const std::string& file_name) { framework::ProgramDesc Deserializer::LoadProgram(const std::string& file_name) {
VLOG(3) << "LoadProgram from: " << file_name; VLOG(3) << "LoadProgram from: " << file_name;
......
...@@ -26,6 +26,7 @@ class VarDesc; ...@@ -26,6 +26,7 @@ class VarDesc;
namespace jit { namespace jit {
static const char PDMODEL_SUFFIX[] = ".pdmodel"; static const char PDMODEL_SUFFIX[] = ".pdmodel";
static const char PDPARAMS_SUFFIX[] = ".pdiparams"; static const char PDPARAMS_SUFFIX[] = ".pdiparams";
static const char PROPERTY_SUFFIX[] = ".meta";
namespace utils { namespace utils {
bool IsPersistable(framework::VarDesc* desc_ptr); bool IsPersistable(framework::VarDesc* desc_ptr);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册