Unverified commit ab8c3179, authored by Huihuang Zheng, committed by GitHub

Update Save/Load Interface to 2.0 (#55836)

Update Save/Load Interface to 2.0
Parent 24c63733
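In short: Paddle 2.0 saves inference models as a `<prefix>.pdmodel` / `<prefix>.pdiparams` pair, while the 1.x format used a `__model__` program file plus a combined `params` file inside the model directory. This commit introduces a `FLAGS_cinn_infer_model_version` flag (default 2.0) that the CINN loaders branch on, migrates the Python model-dump scripts and tests to the 2.0 `static.io.save_inference_model` API, and pins the flag to 1.0 for CI tests that still consume 1.x-format models. Illustrative sketches of the new behavior follow the relevant hunks below.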
@@ -55,6 +55,7 @@ DEFINE_string(resnet50_model_dir,
 DEFINE_int32(evaluate_knobs,
              -1,
              "the options to control which schedule tests will be run.");
+DECLARE_double(cinn_infer_model_version);

 namespace cinn {
 namespace auto_schedule {
@@ -353,6 +354,7 @@ TEST_F(PerformanceTester, Gather) {
 // paddle model test
 TEST_F(PerformanceTester, ResNet50) {
   CHECK_NE(FLAGS_resnet50_model_dir, "");
+  FLAGS_cinn_infer_model_version = 1.0;
   std::unordered_map<std::string, std::vector<int64_t>> feeds = {
       {"inputs", {batch_size, 3, 224, 224}}};
   Evaluate(cinn::frontend::PaddleModelConvertor(common::DefaultNVGPUTarget())
......
@@ -43,7 +43,7 @@ class Interpreter final {
   */
  void LoadPaddleModel(const std::string& model_dir,
                       const Target& target,
-                      bool params_combined = false,
+                      bool params_combined,
                       const std::string& model_name = "");

  /**
......
@@ -24,11 +24,11 @@ namespace cinn::frontend {
 TEST(Interpreter, basic) {
   Interpreter executor({"A"}, {{1, 30}});
-  executor.LoadPaddleModel(FLAGS_model_dir, common::DefaultTarget());
+  executor.LoadPaddleModel(FLAGS_model_dir, common::DefaultTarget(), true);
   executor.Run();
   // fc_0.tmp_2 is eliminated by OpFusion, so here
   // change to get the tensor of the out variable
-  executor.GetTensor("fc_0.tmp_2");
+  executor.GetTensor("save_infer_model/scale_0.tmp_0");
 }

}  // namespace cinn::frontend
@@ -245,12 +245,8 @@ void LoadModelPb(const std::string &model_dir,
   VLOG(3) << "param_file is: " << param_file;
   // Load model
   VLOG(4) << "Start load model program...";
-  std::string prog_path = model_dir + "/__model__";
-  std::string param_file_temp = param_file;
-  if (combined) {
-    // prog_path = model_file;
-    param_file_temp = model_dir + "/params";
-  }
+  std::string prog_path = model_dir + model_file;
+  std::string param_file_temp = model_dir + param_file;
   framework_proto::ProgramDesc pb_proto_prog =
       *LoadProgram(prog_path, model_from_memory);
   pb::ProgramDesc pb_prog(&pb_proto_prog);
......
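With the `combined` special-casing removed, `LoadModelPb` now just concatenates `model_dir` with the caller-supplied file names. A minimal Python sketch of the resulting path resolution (the helper name is ours, for illustration only):

```python
def resolve_model_paths(model_dir: str, version: float = 2.0):
    """Mirror the path composition above: program and parameter paths
    are now model_dir + model_file and model_dir + param_file."""
    if version < 2.0:
        # 1.x layout: files live inside the model directory.
        return model_dir + "/__model__", model_dir + "/params"
    # 2.0 layout: model_dir acts as a file prefix.
    return model_dir + ".pdmodel", model_dir + ".pdiparams"

# resolve_model_paths("third_party/ResNet18", 1.0)
#   -> ("third_party/ResNet18/__model__", "third_party/ResNet18/params")
```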
@@ -24,7 +24,7 @@ namespace cinn::frontend::paddle {
 TEST(LoadModelPb, naive_model) {
   hlir::framework::Scope scope;
   cpp::ProgramDesc program_desc;
-  LoadModelPb(FLAGS_model_dir, "__model__", "", &scope, &program_desc, false);
+  LoadModelPb(FLAGS_model_dir, "/__model__", "", &scope, &program_desc, false);
   ASSERT_EQ(program_desc.BlocksSize(), 1UL);
......
@@ -27,6 +27,8 @@
 #include "paddle/cinn/frontend/var_type_utils.h"
 #include "paddle/cinn/hlir/op/use_ops.h"

+DECLARE_double(cinn_infer_model_version);
+
 namespace cinn {
 namespace frontend {
@@ -123,14 +125,25 @@ Program PaddleModelConvertor::LoadModel(
     bool is_combined,
     const std::unordered_map<std::string, std::vector<int64_t>>& feed) {
   paddle::cpp::ProgramDesc program_desc;
-  paddle::LoadModelPb(model_dir,
-                      "__model__",
-                      "",
-                      scope_.get(),
-                      &program_desc,
-                      is_combined,
-                      false,
-                      target_);
+  if (FLAGS_cinn_infer_model_version < 2.0) {
+    paddle::LoadModelPb(model_dir,
+                        "/__model__",
+                        "/params",
+                        scope_.get(),
+                        &program_desc,
+                        is_combined,
+                        false,
+                        target_);
+  } else {
+    paddle::LoadModelPb(model_dir,
+                        ".pdmodel",
+                        ".pdiparams",
+                        scope_.get(),
+                        &program_desc,
+                        is_combined,
+                        false,
+                        target_);
+  }
   CHECK_EQ(program_desc.BlocksSize(), 1)
       << "CINN can only support the model with a single block";
   auto* block_desc = program_desc.GetBlock<paddle::cpp::BlockDesc>(0);
......
@@ -21,6 +21,8 @@
 #include "paddle/cinn/frontend/paddle/pb/program_desc.h"
 #include "paddle/cinn/hlir/framework/node.h"

+DECLARE_double(cinn_infer_model_version);
+
 namespace cinn {
 namespace frontend {
 using utils::Join;
@@ -740,14 +742,25 @@ Variable PaddleModelToProgram::GetVar(const std::string& name) {
 std::unique_ptr<Program> PaddleModelToProgram::operator()(
     const std::string& model_dir, bool is_combined) {
   paddle::cpp::ProgramDesc program_desc;
-  paddle::LoadModelPb(model_dir,
-                      "__model__",
-                      "",
-                      scope_,
-                      &program_desc,
-                      is_combined,
-                      false,
-                      target_);
+  if (FLAGS_cinn_infer_model_version < 2.0) {
+    paddle::LoadModelPb(model_dir,
+                        "/__model__",
+                        "/params",
+                        scope_,
+                        &program_desc,
+                        is_combined,
+                        false,
+                        target_);
+  } else {
+    paddle::LoadModelPb(model_dir,
+                        ".pdmodel",
+                        ".pdiparams",
+                        scope_,
+                        &program_desc,
+                        is_combined,
+                        false,
+                        target_);
+  }
   CHECK_EQ(program_desc.BlocksSize(), 1)
       << "CINN can only support the model with a single block";
   auto* block_desc = program_desc.GetBlock<paddle::cpp::BlockDesc>(0);
......
@@ -33,6 +33,7 @@ DEFINE_bool(cinn_cudnn_deterministic,
 #endif
 using ::GFLAGS_NAMESPACE::BoolFromEnv;
+using ::GFLAGS_NAMESPACE::DoubleFromEnv;
 using ::GFLAGS_NAMESPACE::Int32FromEnv;
 using ::GFLAGS_NAMESPACE::Int64FromEnv;
 using ::GFLAGS_NAMESPACE::StringFromEnv;
@@ -180,6 +181,11 @@ DEFINE_int32(cinn_error_message_level,
             "Specify the level of printing error message in the schedule."
             "0 means short, 1 means detailed.");
+
+DEFINE_double(cinn_infer_model_version,
+              DoubleFromEnv("FLAGS_cinn_infer_model_version", 2.0),
+              "Paddle has different formats for inference models; we use "
+              "this flag to load different versions.");

 namespace cinn {
 namespace runtime {
......
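Since the default is read through `DoubleFromEnv`, the flag can also be set from the environment before the CINN runtime initializes; the CMake test targets below rely on exactly this to pin old-format models to 1.0. A sketch of the same override from Python (assuming the environment variable is read when the extension module loads, as the gflags wiring above implies):

```python
import os

# Assumption: this must run before the CINN extension module is imported,
# because gflags reads FLAGS_cinn_infer_model_version at flag-definition time.
os.environ["FLAGS_cinn_infer_model_version"] = "1.0"
```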
@@ -91,7 +91,8 @@ if(WITH_CUDNN)
     COMMAND
       ${CMAKE_COMMAND} -E env
       PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
-      python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet18.py
+      FLAGS_cinn_infer_model_version=1.0 python3
+      ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet18.py
       "${CMAKE_BINARY_DIR}/third_party/ResNet18" "${WITH_GPU}"
     WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
@@ -100,7 +101,8 @@ if(WITH_CUDNN)
     COMMAND
       ${CMAKE_COMMAND} -E env
       PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
-      python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv2.py
+      FLAGS_cinn_infer_model_version=1.0 python3
+      ${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv2.py
       "${CMAKE_BINARY_DIR}/third_party/MobileNetV2" "${WITH_GPU}"
     WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
@@ -109,7 +111,8 @@ if(WITH_CUDNN)
     COMMAND
       ${CMAKE_COMMAND} -E env
       PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
-      python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_efficientnet.py
+      FLAGS_cinn_infer_model_version=1.0 python3
+      ${CMAKE_CURRENT_SOURCE_DIR}/test_efficientnet.py
       "${CMAKE_BINARY_DIR}/third_party/EfficientNet" "${WITH_GPU}"
     WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
@@ -118,7 +121,8 @@ if(WITH_CUDNN)
     COMMAND
       ${CMAKE_COMMAND} -E env
       PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
-      python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv1.py
+      FLAGS_cinn_infer_model_version=1.0 python3
+      ${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv1.py
       "${CMAKE_BINARY_DIR}/third_party/MobilenetV1" "${WITH_GPU}"
     WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
@@ -127,7 +131,8 @@ if(WITH_CUDNN)
     COMMAND
       ${CMAKE_COMMAND} -E env
       PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
-      python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet50.py
+      FLAGS_cinn_infer_model_version=1.0 python3
+      ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet50.py
       "${CMAKE_BINARY_DIR}/third_party/ResNet50" "${WITH_GPU}"
     WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
@@ -136,7 +141,8 @@ if(WITH_CUDNN)
     COMMAND
       ${CMAKE_COMMAND} -E env
       PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
-      python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_squeezenet.py
+      FLAGS_cinn_infer_model_version=1.0 python3
+      ${CMAKE_CURRENT_SOURCE_DIR}/test_squeezenet.py
       "${CMAKE_BINARY_DIR}/third_party/SqueezeNet" "${WITH_GPU}"
     WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
@@ -145,8 +151,9 @@ if(WITH_CUDNN)
     COMMAND
       ${CMAKE_COMMAND} -E env
       PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
-      python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_paddle_model_convertor.py --path
-      "${CMAKE_BINARY_DIR}/third_party/resnet_model"
+      FLAGS_cinn_infer_model_version=1.0 python3
+      ${CMAKE_CURRENT_SOURCE_DIR}/test_paddle_model_convertor.py --path
+      "${CMAKE_BINARY_DIR}/third_party/resnet_model_1"
     WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
 endif()
......
@@ -14,7 +14,7 @@
 import paddle
-from paddle import fluid, static
+from paddle import static

 size = 30

 paddle.enable_static()
@@ -37,5 +37,5 @@ loss =
 exe = static.Executor(cpu)
 exe.run(static.default_startup_program())
-fluid.io.save_inference_model("./naive_mul_model", [a.name], [a1], exe)
+static.io.save_inference_model("./naive_mul_model", [a], [a1], exe)
 print('res is : ', a1.name)
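For context, the 2.0 save API takes `Variable`s rather than names and a path prefix rather than a directory, and it re-exports each fetch target through a scale op, which is why the tests above now fetch `save_infer_model/scale_0.tmp_0`. A self-contained sketch (the shape and the scale op are illustrative):

```python
import paddle
from paddle import static

paddle.enable_static()
a = static.data(name="A", shape=[1, 30], dtype="float32")
a1 = paddle.scale(a, scale=2.0)

exe = static.Executor(paddle.CPUPlace())
exe.run(static.default_startup_program())

# Writes ./naive_mul_model.pdmodel and ./naive_mul_model.pdiparams;
# the fetch variable is exposed as save_infer_model/scale_0.tmp_0.
static.io.save_inference_model("./naive_mul_model", [a], [a1], exe)
```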
@@ -17,7 +17,7 @@ A fake model with multiple FC layers to test CINN on a more complex model.
 import paddle
-from paddle import fluid, static
+from paddle import static

 size = 64
 num_layers = 6
@@ -54,5 +54,5 @@ loss =
 exe = static.Executor(cpu)
 exe.run(static.default_startup_program())
-fluid.io.save_inference_model("./multi_fc_model", [a.name], [fc_out], exe)
+static.io.save_inference_model("./multi_fc_model", [a], [fc_out], exe)
 print('res', fc_out.name)
@@ -47,7 +47,8 @@ exe = static.Executor(cpu)
 exe.run(static.default_startup_program())

+static.io.save_inference_model("./resnet_model", [resnet_input], [temp7], exe)
 fluid.io.save_inference_model(
-    "./resnet_model", [resnet_input.name], [temp7], exe
+    "./resnet_model_1", [resnet_input.name], [temp7], exe
 )
 print('res', temp7.name)
@@ -39,7 +39,9 @@ class TestLoadResnetModel(unittest.TestCase):
         self.x_shape = [1, 160, 7, 7]

     def get_paddle_inference_result(self, data):
-        config = fluid.core.AnalysisConfig(self.model_dir)
+        config = fluid.core.AnalysisConfig(
+            self.model_dir + ".pdmodel", self.model_dir + ".pdiparams"
+        )
         config.disable_gpu()
         config.switch_ir_optim(False)
         self.paddle_predictor = fluid.core.create_paddle_predictor(config)
@@ -51,11 +53,11 @@ class TestLoadResnetModel(unittest.TestCase):
         np.random.seed(0)
         x_data = np.random.random(self.x_shape).astype("float32")
         self.executor = Interpreter(["resnet_input"], [self.x_shape])
-        self.executor.load_paddle_model(self.model_dir, self.target, False)
+        self.executor.load_paddle_model(self.model_dir, self.target, True)

         a_t = self.executor.get_tensor("resnet_input")
         a_t.from_numpy(x_data, self.target)

-        out = self.executor.get_tensor("relu_0.tmp_0")
+        out = self.executor.get_tensor("save_infer_model/scale_0.tmp_0")
         out.from_numpy(np.zeros(out.shape(), dtype='float32'), self.target)
         self.executor.run()
......
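The `AnalysisConfig(prog_file, params_file)` constructor above is the 2.0-format load path on the Paddle side; the same check can be written against the current `paddle.inference` API, sketched here with an illustrative prefix:

```python
import numpy as np
from paddle.inference import Config, create_predictor

model_prefix = "./ResNet18/resnet18"  # hypothetical path prefix
config = Config(model_prefix + ".pdmodel", model_prefix + ".pdiparams")
config.disable_gpu()
config.switch_ir_optim(False)

predictor = create_predictor(config)
name = predictor.get_input_names()[0]
handle = predictor.get_input_handle(name)
handle.copy_from_cpu(np.random.random([1, 160, 7, 7]).astype("float32"))
predictor.run()
```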