未验证 提交 20e23e1b 编写于 作者: Y Yuang Liu 提交者: GitHub

[fleet_executor] Dist model run method Implementation (#39194)

上级 8bb509d5
...@@ -13,11 +13,13 @@ ...@@ -13,11 +13,13 @@
// limitations under the License. // limitations under the License.
#include <glog/logging.h> #include <glog/logging.h>
#include <chrono> // NOLINT
#include "paddle/fluid/distributed/fleet_executor/dist_model.h" #include "paddle/fluid/distributed/fleet_executor/dist_model.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor.h"
#include "paddle/fluid/distributed/fleet_executor/task_node.h" #include "paddle/fluid/distributed/fleet_executor/task_node.h"
#include "paddle/fluid/framework/block_desc.h" #include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/naive_executor.h" #include "paddle/fluid/framework/naive_executor.h"
#include "paddle/fluid/framework/op_proto_maker.h" #include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
...@@ -37,10 +39,110 @@ bool IsPersistable(const framework::VarDesc *var) { ...@@ -37,10 +39,110 @@ bool IsPersistable(const framework::VarDesc *var) {
} }
return false; return false;
} }
bool LoadDataFromDistModelTensor(const DistModelTensor &input_data,
framework::LoDTensor *input_tensor,
const platform::Place &place) {
VLOG(3) << "Loading data from DistModelTensor for " << input_data.name;
framework::DDim dims = framework::make_ddim(input_data.shape);
void *input_tensor_ptr;
if (input_data.dtype == DistModelDataType::INT64) {
input_tensor_ptr = input_tensor->mutable_data<int64_t>(dims, place);
} else if (input_data.dtype == DistModelDataType::FLOAT32) {
input_tensor_ptr = input_tensor->mutable_data<float>(dims, place);
} else if (input_data.dtype == DistModelDataType::INT32) {
input_tensor_ptr = input_tensor->mutable_data<int32_t>(dims, place);
} else {
// Q(fleet exe dev): for input/output, should we support fp16
LOG(ERROR) << "unsupported feed type " << input_data.dtype;
return false;
}
PADDLE_ENFORCE_NOT_NULL(
input_tensor_ptr,
paddle::platform::errors::Fatal(
"LoDTensor creation failed. DistModel loaded data failed."));
PADDLE_ENFORCE_NOT_NULL(input_data.data.data(),
paddle::platform::errors::InvalidArgument(
"DistModelTensor contains no data."));
if (platform::is_cpu_place(place)) {
VLOG(3) << "Loading data for CPU.";
std::memcpy(static_cast<void *>(input_tensor_ptr), input_data.data.data(),
input_data.data.length());
} else if (platform::is_gpu_place(place)) {
VLOG(3) << "Loading data for GPU.";
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto *dev_ctx =
dynamic_cast<const platform::CUDADeviceContext *>(pool.Get(place));
auto gpu_place = place;
memory::Copy(gpu_place, static_cast<void *>(input_tensor_ptr),
platform::CPUPlace(), input_data.data.data(),
input_data.data.length(), dev_ctx->stream());
#else
PADDLE_THROW(paddle::platform::errors::Fatal(
"Paddle wasn't compiled with CUDA, but place is GPU."));
#endif
} else {
PADDLE_THROW(paddle::platform::errors::InvalidArgument(
"DistModel only supports CPU and GPU."));
}
framework::LoD dst_lod;
for (auto &src_lod : input_data.lod) {
dst_lod.emplace_back(src_lod);
}
input_tensor->set_lod(dst_lod);
return true;
}
std::string DistModelDTypeToString(DistModelDataType dtype) {
switch (dtype) {
case DistModelDataType::FLOAT32:
return "float32";
case DistModelDataType::FLOAT16:
return "float16";
case DistModelDataType::INT64:
return "int64";
case DistModelDataType::INT32:
return "int32";
case DistModelDataType::INT8:
return "int8";
}
return "NOT SUPPORT DTYPE";
}
bool IsPPFirstStage(const DistModelConfig &config) {
return config.local_rank - config.mp_degree < 0;
}
bool IsPPLastStage(const DistModelConfig &config) {
return config.local_rank + config.mp_degree >= config.nranks;
}
class DistModelTimer {
public:
void tic() { tic_time = std::chrono::high_resolution_clock::now(); }
double toc() {
std::chrono::high_resolution_clock::time_point toc_time =
std::chrono::high_resolution_clock::now();
std::chrono::duration<double> time_elapse =
std::chrono::duration_cast<std::chrono::duration<double>>(toc_time -
tic_time);
double time_elapse_in_ms =
static_cast<double>(time_elapse.count()) * 1000.0;
return time_elapse_in_ms;
}
private:
std::chrono::high_resolution_clock::time_point tic_time;
};
} // namespace } // namespace
bool DistModel::Init() { bool DistModel::Init() {
/* TODO(fleet exe dev): implement this funct */ carrier_id_ = "inference";
bool init_method = (!config_.model_dir.empty() || config_.program_desc); bool init_method = (!config_.model_dir.empty() || config_.program_desc);
PADDLE_ENFORCE_EQ(init_method, true, PADDLE_ENFORCE_EQ(init_method, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
...@@ -127,10 +229,9 @@ bool DistModel::CommInit() { ...@@ -127,10 +229,9 @@ bool DistModel::CommInit() {
InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints, InsertCommOp("mp_comm_id", mp_group_nranks, mp_group_rank, peer_endpoints,
comm_init_block, config_.mp_ring_id); comm_init_block, config_.mp_ring_id);
} }
if (config_.pp_degree) { if (config_.pp_degree > 1) {
// NOTE: the last pp stage doesn't need init pp comm
VLOG(3) << "Init comm group for pp."; VLOG(3) << "Init comm group for pp.";
if (config_.local_rank - config_.mp_degree >= 0) { if (!IsPPFirstStage(config_)) {
PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true, PADDLE_ENFORCE_EQ(config_.pp_upstream_ring_id >= 0, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"pp upstream ring id must be provided for " "pp upstream ring id must be provided for "
...@@ -143,7 +244,7 @@ bool DistModel::CommInit() { ...@@ -143,7 +244,7 @@ bool DistModel::CommInit() {
comm_init_block, config_.pp_upstream_ring_id); comm_init_block, config_.pp_upstream_ring_id);
} }
if (config_.local_rank + config_.mp_degree < config_.nranks) { if (!IsPPLastStage(config_)) {
PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true, PADDLE_ENFORCE_EQ(config_.pp_downstream_ring_id >= 0, true,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"pp downstream ring id must be provided for " "pp downstream ring id must be provided for "
...@@ -326,7 +427,7 @@ bool DistModel::PrepareFleetExe() { ...@@ -326,7 +427,7 @@ bool DistModel::PrepareFleetExe() {
id_to_rank.insert({i, i}); id_to_rank.insert({i, i});
} }
fleet_exe.reset(new FleetExecutor(executor_desc_)); fleet_exe.reset(new FleetExecutor(executor_desc_));
fleet_exe->Init("inference", *(program_.get()), scope_.get(), place_, 1, fleet_exe->Init(carrier_id_, *(program_.get()), scope_.get(), place_, 1,
{task_node_.get()}, id_to_rank); {task_node_.get()}, id_to_rank);
return true; return true;
} }
...@@ -340,8 +441,27 @@ bool DistModel::PrepareFeedAndFetch() { ...@@ -340,8 +441,27 @@ bool DistModel::PrepareFeedAndFetch() {
feeds_.resize(idx + 1); feeds_.resize(idx + 1);
} }
feeds_[idx] = op; feeds_[idx] = op;
feed_names_[op->Output("Out")[0]] = idx; std::string var_name = op->Output("Out")[0];
idx_to_feeds_[idx] = op->Output("Out")[0]; feed_names_[var_name] = idx;
idx_to_feeds_[idx] = var_name;
framework::VarDesc *real_var = program_->Block(0).FindVar(var_name);
if (!real_var) {
LOG(ERROR)
<< "The output of feed ops [" << var_name
<< "] cannot be found in the program. Check the inference program.";
return false;
}
if (real_var->GetDataType() == framework::proto::VarType::FP32) {
feeds_to_dtype_.insert({var_name, DistModelDataType::FLOAT32});
} else if (real_var->GetDataType() == framework::proto::VarType::INT32) {
feeds_to_dtype_.insert({var_name, DistModelDataType::INT32});
} else if (real_var->GetDataType() == framework::proto::VarType::INT64) {
feeds_to_dtype_.insert({var_name, DistModelDataType::INT64});
} else {
LOG(ERROR) << "Don't support feed var dtype for: "
<< real_var->GetDataType();
return false;
}
} else if (op->Type() == "fetch") { } else if (op->Type() == "fetch") {
VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0]; VLOG(3) << "fetch op with fetch var: " << op->Input("X")[0];
int idx = BOOST_GET_CONST(int, op->GetAttr("col")); int idx = BOOST_GET_CONST(int, op->GetAttr("col"));
...@@ -349,15 +469,166 @@ bool DistModel::PrepareFeedAndFetch() { ...@@ -349,15 +469,166 @@ bool DistModel::PrepareFeedAndFetch() {
fetches_.resize(idx + 1); fetches_.resize(idx + 1);
} }
fetches_[idx] = op; fetches_[idx] = op;
id_to_fetches_[idx] = op->Input("X")[0]; idx_to_fetches_[idx] = op->Input("X")[0];
}
}
if (config_.pp_degree == 1) {
if (feeds_.size() == 0) {
LOG(ERROR) << "No feed ops in the inf program, please check the program.";
return false;
}
if (fetches_.size() == 0) {
LOG(ERROR) << "No fetch op in the inf program, please check the program.";
return false;
}
} else {
if (IsPPFirstStage(config_)) {
if (feeds_.size() == 0) {
LOG(ERROR) << "Feed ops are needed for the first pp stage.";
return false;
} else {
LOG(WARNING) << "No feed ops in non-first pp stage.";
} }
} else if (feeds_.size() > 0) {
LOG(WARNING) << "Feed op is found in the non-first stage of pp.";
}
if (IsPPLastStage(config_)) {
if (fetches_.size() == 0) {
LOG(ERROR) << "Fetch op is needed for the last pp stage.";
return false;
} else {
LOG(WARNING) << "No fetch op in non-last pp stage.";
}
} else if (fetches_.size() > 0) {
LOG(WARNING) << "Fetch op is found in the non-last stage of pp.";
}
}
return true;
}
bool DistModel::FeedData(const std::vector<DistModelTensor> &input_data,
framework::Scope *scope) {
VLOG(3) << "DistModel is feeding data.";
if (input_data.size() != feeds_.size()) {
LOG(ERROR) << "Should provide " << feeds_.size() << " feeds, but got "
<< input_data.size() << " data.";
return false;
}
feed_tensors_.resize(feeds_.size());
for (size_t i = 0; i < input_data.size(); ++i) {
// feed each data separately
framework::LoDTensor *input_tensor = &(feed_tensors_[i]);
if (!LoadDataFromDistModelTensor(input_data[i], input_tensor, place_)) {
LOG(ERROR) << "Fail to load data from tensor " << input_data[i].name;
return false;
}
std::string target_name = input_data[i].name;
if (feed_names_.find(target_name) == feed_names_.end()) {
LOG(ERROR) << "The input name [" << target_name
<< "] cannot be found in the program."
<< " DistModel loads data failed.";
return false;
}
if (input_data[i].dtype != feeds_to_dtype_[target_name]) {
LOG(ERROR) << "Feed var [" << target_name << "] expected dtype is: "
<< DistModelDTypeToString(feeds_to_dtype_[target_name])
<< ". But received dtype is: "
<< DistModelDTypeToString(input_data[i].dtype) << ".";
return false;
}
int feed_idx = feed_names_[target_name];
framework::SetFeedVariable(scope, *input_tensor, "feed", feed_idx);
}
return true;
}
bool DistModel::FetchResults(std::vector<DistModelTensor> *output_data,
framework::Scope *scope) {
VLOG(3) << "DistModel is fetch results.";
output_data->resize(fetches_.size());
for (size_t i = 0; i < fetches_.size(); ++i) {
int idx = BOOST_GET_CONST(int, fetches_[i]->GetAttr("col"));
VLOG(3) << "Fetching data for [" << idx_to_fetches_[idx] << "]";
PADDLE_ENFORCE_EQ(
static_cast<size_t>(idx), i,
platform::errors::InvalidArgument(
"Fetch op's col attr(%d) should be equal to the index(%d)", idx,
i));
framework::FetchType &fetch_var =
framework::GetFetchVariable(*scope, "fetch", idx);
auto &fetch = BOOST_GET(framework::LoDTensor, fetch_var);
auto type = fetch.type();
auto output = &(output_data->at(i));
output->name = idx_to_fetches_[idx];
bool rst = false;
if (type == framework::proto::VarType::FP32) {
rst = FetchResult<float>(fetch, output);
output->dtype = DistModelDataType::FLOAT32;
} else if (type == framework::proto::VarType::INT64) {
rst = FetchResult<int64_t>(fetch, output);
output->dtype = DistModelDataType::INT64;
} else if (type == framework::proto::VarType::INT32) {
rst = FetchResult<int32_t>(fetch, output);
output->dtype = DistModelDataType::INT32;
} else {
LOG(ERROR) << "DistModel meets unknown fetch data type. DistModel only "
"supports float32, int64 and int32 fetch type for now.";
}
if (!rst) {
LOG(ERROR) << "DistModel fails to fetch result " << idx_to_fetches_[idx];
return false;
}
}
return true;
}
template <typename T>
bool DistModel::FetchResult(const framework::LoDTensor &fetch,
DistModelTensor *output_data) {
auto shape = framework::vectorize(fetch.dims());
output_data->shape.assign(shape.begin(), shape.end());
const T *data = fetch.data<T>();
int64_t num_elems = fetch.numel();
output_data->data.Resize(num_elems * sizeof(T));
// The output of fetch op is always on the cpu, no need switch on place
memcpy(output_data->data.data(), data, num_elems * sizeof(T));
output_data->lod.clear();
for (auto &level : fetch.lod()) {
output_data->lod.emplace_back(level.begin(), level.end());
} }
return true; return true;
} }
void DistModel::Run(const std::vector<DistModelTensor> &input_data, bool DistModel::Run(const std::vector<DistModelTensor> &input_data,
std::vector<DistModelTensor> *output_data) { std::vector<DistModelTensor> *output_data) {
/* TODO(fleet exe dev): implement this funct */ // TODO(fleet exe dev): support pipeline inf mode
VLOG(3) << "DistModel run for once.";
DistModelTimer timer;
timer.tic();
if (!FeedData(input_data, scope_.get())) {
LOG(ERROR) << "DistModel failed at feeding data.";
return false;
}
double feed_elapse = timer.toc();
VLOG(3) << "Finish loading data, cost " << feed_elapse << "ms.";
fleet_exe->Run(carrier_id_);
double fleet_exe_elapse = timer.toc();
VLOG(3) << "Finish FleetExe running, cost " << fleet_exe_elapse - feed_elapse
<< "ms.";
if (!FetchResults(output_data, scope_.get())) {
LOG(ERROR) << "DistModel failed at fetching result.";
return false;
}
double fetch_elapse = timer.toc();
VLOG(3) << "Finish fetching data, cost " << fetch_elapse - fleet_exe_elapse
<< "ms.";
VLOG(3) << "DistModel finish inf, cost " << fetch_elapse << "ms";
return true;
} }
} // namespace distributed } // namespace distributed
......
...@@ -19,6 +19,7 @@ ...@@ -19,6 +19,7 @@
#include "paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h" #include "paddle/fluid/distributed/fleet_executor/dist_model_tensor_wrapper.h"
#include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h" #include "paddle/fluid/distributed/fleet_executor/fleet_executor_desc.pb.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/tensor.h" #include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/macros.h" #include "paddle/fluid/platform/macros.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
...@@ -57,7 +58,7 @@ class DistModel { ...@@ -57,7 +58,7 @@ class DistModel {
public: public:
explicit DistModel(const DistModelConfig& config) : config_(config) {} explicit DistModel(const DistModelConfig& config) : config_(config) {}
bool Init(); bool Init();
void Run(const std::vector<DistModelTensor>& input_data, bool Run(const std::vector<DistModelTensor>& input_data,
std::vector<DistModelTensor>* output_data); std::vector<DistModelTensor>* output_data);
~DistModel() = default; ~DistModel() = default;
...@@ -75,12 +76,22 @@ class DistModel { ...@@ -75,12 +76,22 @@ class DistModel {
void InsertCommOp(std::string tmp_var_name, int nranks, int rank, void InsertCommOp(std::string tmp_var_name, int nranks, int rank,
const std::vector<std::string>& peer_endpoints, const std::vector<std::string>& peer_endpoints,
framework::BlockDesc* block, int ring_id); framework::BlockDesc* block, int ring_id);
bool FeedData(const std::vector<DistModelTensor>& input_data,
framework::Scope* scope);
bool FetchResults(std::vector<DistModelTensor>* output_data,
framework::Scope* scope);
template <typename T>
bool FetchResult(const framework::LoDTensor& fetch,
DistModelTensor* output_data);
std::string carrier_id_;
std::vector<framework::LoDTensor> feed_tensors_;
std::vector<framework::OpDesc*> feeds_; std::vector<framework::OpDesc*> feeds_;
std::map<std::string, int64_t> feed_names_; std::map<std::string, int64_t> feed_names_;
std::map<int64_t, std::string> idx_to_feeds_; std::map<int64_t, std::string> idx_to_feeds_;
std::map<std::string, DistModelDataType> feeds_to_dtype_;
std::vector<framework::OpDesc*> fetches_; std::vector<framework::OpDesc*> fetches_;
std::map<int64_t, std::string> id_to_fetches_; std::map<int64_t, std::string> idx_to_fetches_;
DistModelConfig config_; DistModelConfig config_;
FleetExecutorDesc executor_desc_; FleetExecutorDesc executor_desc_;
std::shared_ptr<FleetExecutor> fleet_exe; std::shared_ptr<FleetExecutor> fleet_exe;
......
...@@ -62,7 +62,7 @@ class DistModelDataBuf { ...@@ -62,7 +62,7 @@ class DistModelDataBuf {
void Free(); void Free();
void* data_{nullptr}; void* data_{nullptr};
size_t length_{0}; size_t length_{0};
bool memory_owned_{false}; bool memory_owned_{true};
}; };
struct DistModelTensor { struct DistModelTensor {
......
...@@ -162,7 +162,12 @@ void BindFleetExecutor(py::module* m) { ...@@ -162,7 +162,12 @@ void BindFleetExecutor(py::module* m) {
py::class_<DistModel>(*m, "DistModel") py::class_<DistModel>(*m, "DistModel")
.def(py::init<const DistModelConfig&>()) .def(py::init<const DistModelConfig&>())
.def("init", &DistModel::Init) .def("init", &DistModel::Init)
.def("run", &DistModel::Run, py::call_guard<py::gil_scoped_release>()); .def("run",
[](DistModel& self, const std::vector<DistModelTensor>& inputs) {
std::vector<DistModelTensor> outputs;
self.Run(inputs, &outputs);
return outputs;
});
py::class_<DistModelDataBuf>(*m, "DistModelDataBuf") py::class_<DistModelDataBuf>(*m, "DistModelDataBuf")
.def(py::init<size_t>()) .def(py::init<size_t>())
......
...@@ -156,7 +156,8 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32) ...@@ -156,7 +156,8 @@ if(((NOT WITH_ROCM) AND (NOT WITH_GPU)) OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_origin_scheduler) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_origin_scheduler)
LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_mapper) LIST(REMOVE_ITEM TEST_OPS test_auto_parallel_mapper)
LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node) LIST(REMOVE_ITEM TEST_OPS test_fleet_executor_task_node)
LIST(REMOVE_ITEM TEST_OPS test_dist_model_tensor) LIST(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_run)
LIST(REMOVE_ITEM TEST_OPS test_fleet_exe_dist_model_tensor)
endif() endif()
# Temporally disable test_deprecated_decorator # Temporally disable test_deprecated_decorator
......
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle
import numpy as np
import os
from paddle.fluid import core
paddle.enable_static()
class TestDistModelRun(unittest.TestCase):
def test_dist_model_run(self):
# step 0: declare folder to save the model and params
folder = './dist_model_run_test/'
file = 'inf'
path_prefix = folder + file
# step 1: saving the inference model and params
x = paddle.static.data(name='x', shape=[28, 28], dtype='float32')
y = paddle.static.data(name='y', shape=[28, 1], dtype='int64')
predict = paddle.static.nn.fc(x, 10, activation='softmax')
loss = paddle.nn.functional.cross_entropy(predict, y)
avg_loss = paddle.tensor.stat.mean(loss)
exe = paddle.static.Executor(paddle.CUDAPlace(0))
exe.run(paddle.static.default_startup_program())
x_data = np.random.randn(28, 28).astype('float32')
y_data = np.random.randint(0, 9, size=[28, 1]).astype('int64')
exe.run(paddle.static.default_main_program(),
feed={'x': x_data,
'y': y_data},
fetch_list=[avg_loss])
paddle.static.save_inference_model(path_prefix, [x, y], [avg_loss], exe)
print('save model to', path_prefix)
# step 2: prepare fake data for the inference
x_tensor = np.random.randn(28, 28).astype('float32')
y_tensor = np.random.randint(0, 9, size=[28, 1]).astype('int64')
# step 3: init the dist model to inference with fake data
config = core.DistModelConfig()
config.model_dir = path_prefix
config.place = 'GPU'
dist = core.DistModel(config)
dist.init()
dist_x = core.DistModelTensor(x_tensor, 'x')
dist_y = core.DistModelTensor(y_tensor, 'y')
input_data = [dist_x, dist_y]
output_rst = dist.run(input_data)
dist_model_rst = output_rst[0].as_ndarray().ravel().tolist()
print("dist model rst:", dist_model_rst)
# step 4: use framework's api to inference with fake data
[inference_program, feed_target_names, fetch_targets] = (
paddle.static.load_inference_model(path_prefix, exe))
results = exe.run(inference_program,
feed={'x': x_tensor,
'y': y_tensor},
fetch_list=fetch_targets)
load_inference_model_rst = results[0]
print("load inference model api rst:", load_inference_model_rst)
# step 5: compare two results
self.assertTrue(np.allclose(dist_model_rst, load_inference_model_rst))
# step 6: clean up the env, delete the saved model and params
os.remove(path_prefix + '.pdiparams')
os.remove(path_prefix + '.pdmodel')
os.rmdir(folder)
print('cleaned up the env')
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved. # Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册