Unverified commit ab8c3179 authored by H Huihuang Zheng, committed by GitHub

Update Save/Load Interface to 2.0 (#55836)

Update Save/Load Interface to 2.0
Parent 24c63733
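
For context on the interface change: the legacy 1.0 format produced by fluid.io.save_inference_model writes a __model__ program file (plus a combined params file when parameters are saved combined) into a directory, while the 2.0 interface paddle.static.io.save_inference_model writes a <prefix>.pdmodel / <prefix>.pdiparams pair. The sketch below is a minimal illustration of the two save paths, not code from this commit; the toy network and output names are assumptions.

import paddle
from paddle import static

paddle.enable_static()

# Toy static-graph program; layer sizes are illustrative only.
x = static.data(name="x", shape=[None, 30], dtype="float32")
out = static.nn.fc(x, size=30)

exe = static.Executor(paddle.CPUPlace())
exe.run(static.default_startup_program())

# 2.0 interface (the one the updated scripts below switch to):
# writes ./naive_mul_model.pdmodel and ./naive_mul_model.pdiparams.
static.io.save_inference_model("./naive_mul_model", [x], [out], exe)

# Legacy 1.0 interface (what the scripts used before this commit):
# writes ./naive_mul_model_legacy/__model__ plus parameter files.
# Kept as a comment because fluid is deprecated:
# paddle.fluid.io.save_inference_model(
#     "./naive_mul_model_legacy", [x.name], [out], exe)
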
......@@ -55,6 +55,7 @@ DEFINE_string(resnet50_model_dir,
DEFINE_int32(evaluate_knobs,
-1,
"the options to control which schedule tests will be run.");
DECLARE_double(cinn_infer_model_version);
namespace cinn {
namespace auto_schedule {
......@@ -353,6 +354,7 @@ TEST_F(PerformanceTester, Gather) {
// paddle model test
TEST_F(PerformanceTester, ResNet50) {
CHECK_NE(FLAGS_resnet50_model_dir, "");
FLAGS_cinn_infer_model_version = 1.0;
std::unordered_map<std::string, std::vector<int64_t>> feeds = {
{"inputs", {batch_size, 3, 224, 224}}};
Evaluate(cinn::frontend::PaddleModelConvertor(common::DefaultNVGPUTarget())
......
......@@ -43,7 +43,7 @@ class Interpreter final {
*/
void LoadPaddleModel(const std::string& model_dir,
const Target& target,
bool params_combined = false,
bool params_combined,
const std::string& model_name = "");
/**
......
......@@ -24,11 +24,11 @@ namespace cinn::frontend {
TEST(Interpreter, basic) {
Interpreter executor({"A"}, {{1, 30}});
executor.LoadPaddleModel(FLAGS_model_dir, common::DefaultTarget());
executor.LoadPaddleModel(FLAGS_model_dir, common::DefaultTarget(), true);
executor.Run();
// fc_0.tmp_2 is eliminated by OpFusion, so here
// change to get tensor of the out variable
executor.GetTensor("fc_0.tmp_2");
executor.GetTensor("save_infer_model/scale_0.tmp_0");
}
} // namespace cinn::frontend
......@@ -245,12 +245,8 @@ void LoadModelPb(const std::string &model_dir,
VLOG(3) << "param_file is: " << param_file;
// Load model
VLOG(4) << "Start load model program...";
std::string prog_path = model_dir + "/__model__";
std::string param_file_temp = param_file;
if (combined) {
// prog_path = model_file;
param_file_temp = model_dir + "/params";
}
std::string prog_path = model_dir + model_file;
std::string param_file_temp = model_dir + param_file;
framework_proto::ProgramDesc pb_proto_prog =
*LoadProgram(prog_path, model_from_memory);
pb::ProgramDesc pb_prog(&pb_proto_prog);
......
......@@ -24,7 +24,7 @@ namespace cinn::frontend::paddle {
TEST(LoadModelPb, naive_model) {
hlir::framework::Scope scope;
cpp::ProgramDesc program_desc;
LoadModelPb(FLAGS_model_dir, "__model__", "", &scope, &program_desc, false);
LoadModelPb(FLAGS_model_dir, "/__model__", "", &scope, &program_desc, false);
ASSERT_EQ(program_desc.BlocksSize(), 1UL);
......
......@@ -27,6 +27,8 @@
#include "paddle/cinn/frontend/var_type_utils.h"
#include "paddle/cinn/hlir/op/use_ops.h"
DECLARE_double(cinn_infer_model_version);
namespace cinn {
namespace frontend {
......@@ -123,14 +125,25 @@ Program PaddleModelConvertor::LoadModel(
bool is_combined,
const std::unordered_map<std::string, std::vector<int64_t>>& feed) {
paddle::cpp::ProgramDesc program_desc;
paddle::LoadModelPb(model_dir,
"__model__",
"",
scope_.get(),
&program_desc,
is_combined,
false,
target_);
if (FLAGS_cinn_infer_model_version < 2.0) {
paddle::LoadModelPb(model_dir,
"/__model__",
"/params",
scope_.get(),
&program_desc,
is_combined,
false,
target_);
} else {
paddle::LoadModelPb(model_dir,
".pdmodel",
".pdiparams",
scope_.get(),
&program_desc,
is_combined,
false,
target_);
}
CHECK_EQ(program_desc.BlocksSize(), 1)
<< "CINN can only support the model with a single block";
auto* block_desc = program_desc.GetBlock<paddle::cpp::BlockDesc>(0);
......
......@@ -21,6 +21,8 @@
#include "paddle/cinn/frontend/paddle/pb/program_desc.h"
#include "paddle/cinn/hlir/framework/node.h"
DECLARE_double(cinn_infer_model_version);
namespace cinn {
namespace frontend {
using utils::Join;
......@@ -740,14 +742,25 @@ Variable PaddleModelToProgram::GetVar(const std::string& name) {
std::unique_ptr<Program> PaddleModelToProgram::operator()(
const std::string& model_dir, bool is_combined) {
paddle::cpp::ProgramDesc program_desc;
paddle::LoadModelPb(model_dir,
"__model__",
"",
scope_,
&program_desc,
is_combined,
false,
target_);
if (FLAGS_cinn_infer_model_version < 2.0) {
paddle::LoadModelPb(model_dir,
"/__model__",
"/params",
scope_,
&program_desc,
is_combined,
false,
target_);
} else {
paddle::LoadModelPb(model_dir,
".pdmodel",
".pdiparams",
scope_,
&program_desc,
is_combined,
false,
target_);
}
CHECK_EQ(program_desc.BlocksSize(), 1)
<< "CINN can only support the model with a single block";
auto* block_desc = program_desc.GetBlock<paddle::cpp::BlockDesc>(0);
......
......@@ -33,6 +33,7 @@ DEFINE_bool(cinn_cudnn_deterministic,
#endif
using ::GFLAGS_NAMESPACE::BoolFromEnv;
using ::GFLAGS_NAMESPACE::DoubleFromEnv;
using ::GFLAGS_NAMESPACE::Int32FromEnv;
using ::GFLAGS_NAMESPACE::Int64FromEnv;
using ::GFLAGS_NAMESPACE::StringFromEnv;
......@@ -180,6 +181,11 @@ DEFINE_int32(cinn_error_message_level,
"Specify the level of printing error message in the schedule."
"0 means short, 1 means detailed.");
DEFINE_double(cinn_infer_model_version,
DoubleFromEnv("FLAGS_cinn_infer_model_version", 2.0),
"Paddle has different model format in inference model. We use "
"a flag to load different versions.");
namespace cinn {
namespace runtime {
......
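
The flag defined above defaults to 2.0 and falls back to the FLAGS_cinn_infer_model_version environment variable, which is why the CMake targets below prefix the legacy-format model tests with FLAGS_cinn_infer_model_version=1.0. Below is a hedged sketch of the same opt-in from Python, assuming the cinn package exposes Interpreter and DefaultHostTarget as in the CINN test suite and that the environment variable is read when the CINN extension is loaded; the model path is hypothetical.

import os

# Set before importing cinn so DoubleFromEnv sees it when the flag is defined.
os.environ["FLAGS_cinn_infer_model_version"] = "1.0"

from cinn.common import DefaultHostTarget
from cinn.frontend import Interpreter

executor = Interpreter(["inputs"], [[1, 3, 224, 224]])
# params_combined lost its default value in this commit, so pass it explicitly:
# True when the parameters were saved into a single combined file,
# False when each parameter has its own file.
executor.load_paddle_model("/path/to/legacy_model_dir", DefaultHostTarget(), False)
executor.run()
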
......@@ -91,7 +91,8 @@ if(WITH_CUDNN)
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet18.py
FLAGS_cinn_infer_model_version=1.0 python3
${CMAKE_CURRENT_SOURCE_DIR}/test_resnet18.py
"${CMAKE_BINARY_DIR}/third_party/ResNet18" "${WITH_GPU}"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
......@@ -100,7 +101,8 @@ if(WITH_CUDNN)
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv2.py
FLAGS_cinn_infer_model_version=1.0 python3
${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv2.py
"${CMAKE_BINARY_DIR}/third_party/MobileNetV2" "${WITH_GPU}"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
......@@ -109,7 +111,8 @@ if(WITH_CUDNN)
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_efficientnet.py
FLAGS_cinn_infer_model_version=1.0 python3
${CMAKE_CURRENT_SOURCE_DIR}/test_efficientnet.py
"${CMAKE_BINARY_DIR}/third_party/EfficientNet" "${WITH_GPU}"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
......@@ -118,7 +121,8 @@ if(WITH_CUDNN)
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv1.py
FLAGS_cinn_infer_model_version=1.0 python3
${CMAKE_CURRENT_SOURCE_DIR}/test_mobilenetv1.py
"${CMAKE_BINARY_DIR}/third_party/MobilenetV1" "${WITH_GPU}"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
......@@ -127,7 +131,8 @@ if(WITH_CUDNN)
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_resnet50.py
FLAGS_cinn_infer_model_version=1.0 python3
${CMAKE_CURRENT_SOURCE_DIR}/test_resnet50.py
"${CMAKE_BINARY_DIR}/third_party/ResNet50" "${WITH_GPU}"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
......@@ -136,7 +141,8 @@ if(WITH_CUDNN)
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_squeezenet.py
FLAGS_cinn_infer_model_version=1.0 python3
${CMAKE_CURRENT_SOURCE_DIR}/test_squeezenet.py
"${CMAKE_BINARY_DIR}/third_party/SqueezeNet" "${WITH_GPU}"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
......@@ -145,8 +151,9 @@ if(WITH_CUDNN)
COMMAND
${CMAKE_COMMAND} -E env
PYTHONPATH=${CMAKE_BINARY_DIR}:${CMAKE_BINARY_DIR}/python/cinn:$ENV{PYTHONPATH}
python3 ${CMAKE_CURRENT_SOURCE_DIR}/test_paddle_model_convertor.py --path
"${CMAKE_BINARY_DIR}/third_party/resnet_model"
FLAGS_cinn_infer_model_version=1.0 python3
${CMAKE_CURRENT_SOURCE_DIR}/test_paddle_model_convertor.py --path
"${CMAKE_BINARY_DIR}/third_party/resnet_model_1"
WORKING_DIRECTORY ${CMAKE_BINARY_DIR})
endif()
......
......@@ -14,7 +14,7 @@
import paddle
from paddle import fluid, static
from paddle import static
size = 30
paddle.enable_static()
......@@ -37,5 +37,5 @@ loss = exe = static.Executor(cpu)
exe.run(static.default_startup_program())
fluid.io.save_inference_model("./naive_mul_model", [a.name], [a1], exe)
static.io.save_inference_model("./naive_mul_model", [a], [a1], exe)
print('res is : ', a1.name)
......@@ -17,7 +17,7 @@ A fake model with multiple FC layers to test CINN on a more complex model.
import paddle
from paddle import fluid, static
from paddle import static
size = 64
num_layers = 6
......@@ -54,5 +54,5 @@ loss = exe = static.Executor(cpu)
exe.run(static.default_startup_program())
fluid.io.save_inference_model("./multi_fc_model", [a.name], [fc_out], exe)
static.io.save_inference_model("./multi_fc_model", [a], [fc_out], exe)
print('res', fc_out.name)
......@@ -47,7 +47,8 @@ exe = static.Executor(cpu)
exe.run(static.default_startup_program())
static.io.save_inference_model("./resnet_model", [resnet_input], [temp7], exe)
fluid.io.save_inference_model(
"./resnet_model", [resnet_input.name], [temp7], exe
"./resnet_model_1", [resnet_input.name], [temp7], exe
)
print('res', temp7.name)
......@@ -39,7 +39,9 @@ class TestLoadResnetModel(unittest.TestCase):
self.x_shape = [1, 160, 7, 7]
def get_paddle_inference_result(self, data):
config = fluid.core.AnalysisConfig(self.model_dir)
config = fluid.core.AnalysisConfig(
self.model_dir + ".pdmodel", self.model_dir + ".pdiparams"
)
config.disable_gpu()
config.switch_ir_optim(False)
self.paddle_predictor = fluid.core.create_paddle_predictor(config)
......@@ -51,11 +53,11 @@ class TestLoadResnetModel(unittest.TestCase):
np.random.seed(0)
x_data = np.random.random(self.x_shape).astype("float32")
self.executor = Interpreter(["resnet_input"], [self.x_shape])
self.executor.load_paddle_model(self.model_dir, self.target, False)
self.executor.load_paddle_model(self.model_dir, self.target, True)
a_t = self.executor.get_tensor("resnet_input")
a_t.from_numpy(x_data, self.target)
out = self.executor.get_tensor("relu_0.tmp_0")
out = self.executor.get_tensor("save_infer_model/scale_0.tmp_0")
out.from_numpy(np.zeros(out.shape(), dtype='float32'), self.target)
self.executor.run()
......