提交 acdb441e 编写于 作者: W wangguibao

pybind with async_executor

上级 24fa8eb6
...@@ -36,6 +36,7 @@ add_subdirectory(details) ...@@ -36,6 +36,7 @@ add_subdirectory(details)
endif (NOT WIN32) endif (NOT WIN32)
# ddim lib # ddim lib
proto_library(framework_proto SRCS framework.proto) proto_library(framework_proto SRCS framework.proto)
proto_library(async_executor_param SRCS async_executor_param.proto)
cc_library(ddim SRCS ddim.cc DEPS eigen3 boost) cc_library(ddim SRCS ddim.cc DEPS eigen3 boost)
cc_test(ddim_test SRCS ddim_test.cc DEPS ddim) cc_test(ddim_test SRCS ddim_test.cc DEPS ddim)
...@@ -177,7 +178,8 @@ endif() # NOT WIN32 ...@@ -177,7 +178,8 @@ endif() # NOT WIN32
cc_library(async_executor cc_library(async_executor
SRCS async_executor.cc data_feed.cc datafeed_creator.cc SRCS async_executor.cc data_feed.cc datafeed_creator.cc
DEPS op_registry device_context scope framework_proto glog DEPS op_registry device_context scope framework_proto glog
lod_rank_table feed_fetch_method graph_to_program_pass) lod_rank_table feed_fetch_method graph_to_program_pass
async_executor_param)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -16,6 +16,8 @@ limitations under the License. */ ...@@ -16,6 +16,8 @@ limitations under the License. */
#include <stdio.h> #include <stdio.h>
#include <string.h> #include <string.h>
#include <fcntl.h> #include <fcntl.h>
#include <sys/types.h>
#include <sys/stat.h>
#include <unistd.h> #include <unistd.h>
#include <fstream> #include <fstream>
#include <iostream> #include <iostream>
...@@ -33,6 +35,7 @@ limitations under the License. */ ...@@ -33,6 +35,7 @@ limitations under the License. */
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/reader.h" #include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/platform/place.h" #include "paddle/fluid/platform/place.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/pybind/pybind.h" #include "paddle/fluid/pybind/pybind.h"
namespace paddle { namespace paddle {
...@@ -543,7 +546,7 @@ void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) { ...@@ -543,7 +546,7 @@ void AsyncExecutor::PrepareThreads(const ProgramDesc& host_program) {
for (unsigned i = 0; i < thread_num_; ++i) { for (unsigned i = 0; i < thread_num_; ++i) {
// new a datafeed here // new a datafeed here
std::shared_ptr<DataFeed> local_feed = CreateDataFeed(feed_name_.c_str()); std::shared_ptr<DataFeed> local_feed = CreateDataFeed(feed_name_.c_str());
local_feed->Init(data_feed_param_); local_feed->Init();
local_feed->SetBatchSize(batch_size_); local_feed->SetBatchSize(batch_size_);
workers_[i]->SetDataFeed(local_feed); workers_[i]->SetDataFeed(local_feed);
workers_[i]->BindingDataFeedMemory(); workers_[i]->BindingDataFeedMemory();
...@@ -564,7 +567,26 @@ void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) { ...@@ -564,7 +567,26 @@ void AsyncExecutor::RunAsyncExecutor(const ProgramDesc& host_program) {
} }
} }
} // end namespace framework void AsyncExecutor::LoadInitModel() {
auto place = paddle::platform::CPUPlace();
auto* executor = new paddle::framework::Executor(place);
std::string init_prog_file = model_path_ + "/" + init_prog_file_;
std::string init_model_file = model_path_ + "/" + init_model_file_;
struct stat stat_buf;
if (stat(init_prog_file.c_str(), &stat_buf) == 0 &&
S_ISREG(stat_buf.st_mode) &&
stat(init_model_file.c_str(), &stat_buf) == 0 &&
S_ISREG(stat_buf.st_mode)) {
paddle::inference::Load(executor,
GetRootScope(),
model_path_ + "/" + init_prog_file_,
model_path_ + "/" + init_model_file_);
}
}
} // einit_modelnd namespace framework
} // end namespace paddle } // end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */ /* vim: set expandtab ts=2 sw=2 sts=2 tw=100: */
...@@ -35,32 +35,33 @@ void CreateTensor(Variable* var, proto::VarType::Type var_type); ...@@ -35,32 +35,33 @@ void CreateTensor(Variable* var, proto::VarType::Type var_type);
class ExecutorThreadWorker { class ExecutorThreadWorker {
public: public:
ExecutorThreadWorker() {} ExecutorThreadWorker() {}
virtual ~ExecutorThreadWorker() {} ~ExecutorThreadWorker() {}
void CreateThreadScope(const framework::ProgramDesc& program); void CreateThreadScope(const framework::ProgramDesc& program);
void SetDataFeed(const DataFeed& datafeed); void SetDataFeed(const DataFeed& datafeed);
void SetThreadId(int tid); void SetThreadId(int tid);
void CreateThreadOperators(const framework::ProgramDesc& program); void CreateThreadOperators(const framework::ProgramDesc& program);
void SetRootScope(Scope* g_scope); void SetRootScope(Scope* g_scope);
void SetDevice(); void SetDevice();
virtual void AddFidSet(); void AddFidSet();
void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; } void SetCommBatch(int comm_batch) { comm_batch_ = comm_batch; }
void AddTrainFile(const std::string& filename); void AddTrainFile(const std::string& filename);
void SetMainProgram(const ProgramDesc& main_program_desc); void SetMainProgram(const ProgramDesc& main_program_desc);
void SetPlace(const paddle::platform::Place& place); void SetPlace(const paddle::platform::Place& place);
void SetMaxTrainingEpoch(const int max_epoch); void SetMaxTrainingEpoch(const int max_epoch);
void BindingDataFeedMemory(); void BindingDataFeedMemory();
void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; } void SetModelPrefix(const std::string& prefix) { model_prefix_ = prefix; }
void SetInspectVarName(const std::string& inspect_var_name); void SetInspectVarName(const std::string& inspect_var_name);
void SetModelParamNames(const std::vector<std::string>& param_names); void SetModelParamNames(const std::vector<std::string>& param_names);
void SetSparseCommData(const std::map<std::string, int>& param_names); void SetSparseCommData(const std::map<std::string, int>& param_names);
void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed); void SetDataFeed(const std::shared_ptr<DataFeed>& datafeed);
void Train(); void Train();
virtual const char* PickOneFile(); const char* PickOneFile();
void UpdateEpochNum(); void UpdateEpochNum();
virtual void SetDenseCommTensor( void SetDenseCommTensor(const std::vector<std::string>& param_names) {}
const std::vector<std::string>& param_names) {} void Initialize() {}
virtual void Initialize() {}
public: public:
static std::mutex s_locker_for_pick_file_; static std::mutex s_locker_for_pick_file_;
...@@ -122,13 +123,20 @@ class AsyncExecutor { ...@@ -122,13 +123,20 @@ class AsyncExecutor {
void SetFileList(const char* filelist); void SetFileList(const char* filelist);
void SetFileList(const std::vector<std::string> filelist); void SetFileList(const std::vector<std::string> filelist);
void SetDataFeedName(const char* feedname); void SetDataFeedName(const char* feedname);
void SetCommBatch(int comm_batch) {
comm_batch_ = comm_batch;
}
void SetDataFeedParam(const datafeed::DataFeedParameter& feed_param) { void SetModelPath(const std::string& model_path) {
data_feed_param_ = feed_param; model_path_ = model_path;
} }
void SetCommBatch(int comm_batch) { void SetInitProgFile(const std::string& init_prog_file) {
comm_batch_ = comm_batch; init_prog_file_ = init_prog_file;
}
void SetInitModelFile(const std::string& init_model_file) {
init_model_file_ = init_model_file;
} }
void SetModelPrefix(const std::string& model_prefix); void SetModelPrefix(const std::string& model_prefix);
...@@ -141,9 +149,10 @@ class AsyncExecutor { ...@@ -141,9 +149,10 @@ class AsyncExecutor {
framework::Scope* scope); framework::Scope* scope);
void RunAsyncExecutor(const ProgramDesc& host_program); void RunAsyncExecutor(const ProgramDesc& host_program);
void LoadInitModel();
public: public:
unsigned int thread_num_; unsigned int thread_num_;
datafeed::DataFeedParameter data_feed_param_;
int max_epoch_; int max_epoch_;
int batch_size_; int batch_size_;
int comm_batch_; int comm_batch_;
...@@ -156,6 +165,9 @@ class AsyncExecutor { ...@@ -156,6 +165,9 @@ class AsyncExecutor {
std::vector<std::string> sparse_comm_tensor_; std::vector<std::string> sparse_comm_tensor_;
std::map<std::string, int> sparse_comm_data_; std::map<std::string, int> sparse_comm_data_;
std::string model_prefix_; std::string model_prefix_;
std::string model_path_;
std::string init_prog_file_;
std::string init_model_file_;
std::string feed_name_; std::string feed_name_;
Scope* root_scope_; Scope* root_scope_;
platform::Place place_; platform::Place place_;
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package paddle;
message AsyncExecutorParameter {
optional BaseParameter base_param = 1;
repeated AucCalculatorParameter auc_calculator_parameter = 5;
repeated PnCalculatorParameter pn_calculator_parameter = 6;
}
message JobParameter {
optional string job_name = 1;
optional string startup_prog_file = 2;
optional string main_prog_file = 3;
repeated int32 push_sparse_table_id = 4;
repeated int32 push_dense_table_id = 5;
repeated int32 pull_sparse_table_id = 6;
repeated int32 pull_dense_table_id = 7;
optional int32 slot_dim = 8;
optional int32 fea_dim = 9;
optional bool use_cvm_feature = 10;
optional string inspect_var_name = 20;
optional string auc_calculator_name = 21;
optional string mean_var_name = 22;
optional string pn_calculator_name = 23;
repeated string debug_layer_name = 24;
}
message BaseParameter {
optional int32 thread_num = 1 [default = 1];
optional string datafeed_class = 2;
optional string startup_prog_file = 3;
optional string main_prog_file = 4;
optional string filelist = 5;
repeated string model_param_names = 6;
optional int32 max_epoch = 7;
optional string model_path = 8;
optional string model_prefix = 9;
optional int32 batch_size = 10;
optional string training_dir = 11; //local data path
optional string data_feed_param = 12;
optional string init_prog_file = 13;
optional string init_model_file = 14;
optional string inspect_var_name = 15;
repeated string input_variable_name = 16;
optional int32 download_thread_num = 17 [default = 12];
repeated JobParameter train_job = 19;
optional bool save_by_hour = 20 [default = true];
optional bool need_global_shuffle = 21 [default = true];
optional string converter_name = 22;
optional string data_converter = 23;
optional int32 checkpoint_per_pass = 24;
}
message AucCalculatorParameter {
optional string name = 1;
optional string output = 2;
}
message PnCalculatorParameter {
optional string name = 1;
optional string output_1 = 2;
optional string output_2 = 3;
}
...@@ -38,7 +38,7 @@ DEFINE_bool(is_text_feed, false, "is_text_feed"); ...@@ -38,7 +38,7 @@ DEFINE_bool(is_text_feed, false, "is_text_feed");
namespace paddle { namespace paddle {
namespace framework { namespace framework {
void TextClassDataFeed::Init(const datafeed::DataFeedParameter& feed_param) { void TextClassDataFeed::Init() {
// hard coding for a specific datafeed // hard coding for a specific datafeed
feed_vec_.resize(2); feed_vec_.resize(2);
// feed_vec_[0].reset(new LoDTensor); // feed_vec_[0].reset(new LoDTensor);
......
...@@ -31,7 +31,6 @@ limitations under the License. */ ...@@ -31,7 +31,6 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h" #include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/program_desc.h" #include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/scope.h" #include "paddle/fluid/framework/scope.h"
#include "proto/FeedDataParameter.pb.h"
namespace paddle { namespace paddle {
namespace framework { namespace framework {
...@@ -52,7 +51,7 @@ class DataFeed { ...@@ -52,7 +51,7 @@ class DataFeed {
public: public:
DataFeed() {} DataFeed() {}
virtual ~DataFeed() {} virtual ~DataFeed() {}
virtual void Init(const datafeed::DataFeedParameter& feed_param) = 0; virtual void Init() = 0;
/* /*
* This function will be used to check file format. * This function will be used to check file format.
* Considering that this function may be used alone, * Considering that this function may be used alone,
...@@ -101,7 +100,7 @@ class DataFeed { ...@@ -101,7 +100,7 @@ class DataFeed {
class TextClassDataFeed : public DataFeed { class TextClassDataFeed : public DataFeed {
public: public:
virtual ~TextClassDataFeed() {} virtual ~TextClassDataFeed() {}
virtual void Init(const datafeed::DataFeedParameter& feed_param); virtual void Init();
virtual bool ReadBatch(); virtual bool ReadBatch();
virtual void AddFeedVar(Variable* feed, const std::string& name); virtual void AddFeedVar(Variable* feed, const std::string& name);
virtual void BindScope(Scope* scope) {} virtual void BindScope(Scope* scope) {}
......
cc_library(inference_io SRCS io.cc)
# analysis and tensorrt must be added before creating static library, # analysis and tensorrt must be added before creating static library,
# otherwise, there would be undefined reference to them in static library. # otherwise, there would be undefined reference to them in static library.
add_subdirectory(analysis) add_subdirectory(analysis)
......
set(PYBIND_DEPS pybind python proto_desc memory executor async_executor prune feed_fetch_method pass_builder) set(PYBIND_DEPS pybind python proto_desc memory executor async_executor
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc) inference_io prune feed_fetch_method pass_builder)
set(PYBIND_SRCS pybind.cc exception.cc protobuf.cc const_value.cc async_executor_py.cc)
if(NOT WIN32) if(NOT WIN32)
list(APPEND PYBIND_DEPS parallel_executor profiler) list(APPEND PYBIND_DEPS parallel_executor profiler)
list(APPEND PYBIND_SRCS recordio.cc) list(APPEND PYBIND_SRCS recordio.cc)
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <fcntl.h>
// To avoid conflicting definition in gcc-4.8.2 headers and pyconfig.h (2.7.3)
#ifdef _POSIX_C_SOURCE
#undef _POSIX_C_SOURCE
#endif
#ifdef _XOPEN_SOURCE
#undef _XOPEN_SOURCE
#endif
#include "google/protobuf/text_format.h"
#include "google/protobuf/io/zero_copy_stream_impl.h"
#include "paddle/fluid/inference/io.h"
#include "paddle/fluid/platform/variant.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/framework/async_executor_param.pb.h"
#include "paddle/fluid/framework/async_executor.h"
#include "paddle/fluid/pybind/async_executor_py.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindAsyncExecutor(py::module* m) {
py::class_<paddle::AsyncExecutorParameter>(*m, "AsyncExecutorParameter")
.def(py::init<>())
.def("parse",
[](paddle::AsyncExecutorParameter &self, const std::string &conf_file) {
int file_descriptor = open(conf_file.c_str(), O_RDONLY);
google::protobuf::io::FileInputStream file_input(file_descriptor);
google::protobuf::TextFormat::Parse(&file_input, &self);
close(file_descriptor);
}
);
py::class_<framework::AsyncExecutor>(*m, "AsyncExecutor")
.def(py::init<const platform::Place&>())
.def("init",
[](framework::AsyncExecutor &self,
paddle::AsyncExecutorParameter &parameter,
framework::Scope *scope) {
paddle::BaseParameter base_param = parameter.base_param();
// TODO Extract parameter list from python side, instead of
// providing them in confgurations manually
std::vector<std::string> param_names;
for (int i = 0; i < base_param.model_param_names_size(); ++i) {
param_names.push_back(base_param.model_param_names(i));
}
#ifdef FORK_V1
paddle::framework::InitDevices();
#else
paddle::framework::InitDevices(false);
#endif
self.InitRootScope(scope);
self.SetThreadNum(base_param.thread_num());
self.SetMaxTrainingEpoch(base_param.max_epoch());
self.SetFileList(base_param.filelist().c_str());
self.SetBatchSize(base_param.batch_size());
self.SetDataFeedName(base_param.datafeed_class().c_str());
self.SetInspectVarName(base_param.inspect_var_name());
self.SetParamNames(param_names);
self.SetModelPath(base_param.model_path());
self.SetModelPrefix(base_param.model_prefix());
self.SetInitProgFile(base_param.init_prog_file());
self.SetInitModelFile(base_param.init_model_file());
return;
}
)
.def("run_startup_program", &framework::AsyncExecutor::RunStartupProgram)
.def("load_init_model", &framework::AsyncExecutor::LoadInitModel)
.def("run", &framework::AsyncExecutor::RunAsyncExecutor);
} // end BindAsyncExecutor
} // end namespace framework
} // end namespace paddle
/* vim: set expandtab ts=2 sw=2 sts=2 tw=80: */
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#define PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
#include "pybind11/pybind11.h"
namespace py = pybind11;
namespace paddle {
namespace pybind {
void BindAsyncExecutor(py::module* m);
} // namespace pybind
} // namespace paddle
#endif // PADDLE_FLUID_PYBIND_ASYNC_EXECUTOR_PY_H_
...@@ -46,6 +46,7 @@ limitations under the License. */ ...@@ -46,6 +46,7 @@ limitations under the License. */
#include "paddle/fluid/pybind/pybind.h" // NOLINT #include "paddle/fluid/pybind/pybind.h" // NOLINT
#include "paddle/fluid/pybind/recordio.h" #include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h" #include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/pybind/async_executor_py.h"
#include "paddle/fluid/string/to_string.h" #include "paddle/fluid/string/to_string.h"
...@@ -860,6 +861,7 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -860,6 +861,7 @@ All parameter, weight, gradient are variables in Paddle.
}); });
BindRecordIOWriter(&m); BindRecordIOWriter(&m);
BindAsyncExecutor(&m);
return m.ptr(); return m.ptr();
} }
} // namespace pybind } // namespace pybind
......
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
syntax = "proto2";
package datafeed;
message DataFeedParameter {
optional FeedDataParameter feed_data_param = 1;
optional JointOneHotParameter joint_onehot_data_param = 2;
optional ACDXParameter acdx_data_param = 3;
}
message FeedDataParameter {
repeated int32 slot_id = 1;
repeated int32 use_slot_id = 2;
repeated string use_slot_alias = 3;
repeated uint64 use_slot_mod = 4;
repeated int32 use_slot_type = 5;
optional int32 max_batch_num = 6 [default = 128];
optional int32 max_feasign_num = 7 [default = 1000];
}
message JointOneHotParameter {
optional int32 max_batch_num = 1 [default = 128];
optional int32 max_title_num = 2 [default = 400];
optional int32 max_term_num = 3 [default = 1024];
required float sampling_rate = 4;
repeated int32 slot_id = 5;
repeated int32 use_slot_id = 6;
repeated string use_slot_alias = 7;
repeated uint64 use_slot_mod = 8;
repeated int32 use_slot_type = 9;
}
message ACDXParameter {
optional int32 max_batch_num = 1 [default = 128];
optional int32 max_term_num = 3 [default = 512];
}
...@@ -19,6 +19,10 @@ from .framework import * ...@@ -19,6 +19,10 @@ from .framework import *
# import all class inside executor into fluid module # import all class inside executor into fluid module
from . import executor from . import executor
from .executor import * from .executor import *
from . import async_executor
from .async_executor import *
from . import trainer from . import trainer
from . import inferencer from . import inferencer
...@@ -52,7 +56,8 @@ Tensor = LoDTensor ...@@ -52,7 +56,8 @@ Tensor = LoDTensor
__all__ = framework.__all__ + executor.__all__ + \ __all__ = framework.__all__ + executor.__all__ + \
trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \ trainer.__all__ + inferencer.__all__ + transpiler.__all__ + \
parallel_executor.__all__ + lod_tensor.__all__ + [ parallel_executor.__all__ + lod_tensor.__all__ + \
async_executor.__all__ + [
'io', 'io',
'initializer', 'initializer',
'layers', 'layers',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import contextlib
import six
from .framework import Program, default_main_program, Variable
from . import core
from . import Executor
__all__ = ['AsyncExecutorParameter', 'AsyncExecutor']
g_scope = core.Scope()
class AsyncExecutorParameter(object):
"""
AsyncExecutor configure parameter
Args:
None
"""
def __init__(self):
self.parameter = core.AsyncExecutorParameter()
def parse(self, conf_file):
self.parameter.parse(conf_file)
class AsyncExecutor(object):
"""
An asynchronous Executor in Python
Args:
place(core.CPUPlace|core.CUDAPlace(n)): indicate the executor run on which device
Note: For debugging complicated network in parallel-GPUs, you can test it on the executor.
They has the exactly same arguments, and expected the same results.
"""
def __init__(self,
async_executor_parameter,
place,
scope):
if not isinstance(async_executor_parameter, AsyncExecutorParameter):
raise TypeError(
"AsyncExecutor requires AsyncExecutorParameter as its parameter. "
"But you passed in %s" %s (type(async_executor_parameter))
)
self.place = place
p = core.Place()
p.set_place(place)
self.executor = core.AsyncExecutor(p)
self.executor.init(async_executor_parameter.parameter, scope)
self._closed = False
self.parameter = async_executor_parameter.parameter
def close(self):
"""
Close this executor.
You can no long use this executor after calling this method.
For the distributed training, this method would free the resource on PServers related to
the current Trainer.
Example:
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> ...
>>> exe.close()
"""
if not self._closed:
self._closed = True
def run_startup_program(self,
program=None,
scope=None):
if program is None:
program = default_startup_program()
program_desc = program._get_desc()
if scope is None:
scope = g_scope
self.executor.run_startup_program(program_desc, scope)
def run(self, program=None, scope=None):
"""
Run program by this Executor. Feed data by feed map, fetch result by fetch_list.
Python executor takes a program, add feed operators and fetch operators to this program according
to feed map and fetch_list. Feed map provides input data for the program. fetch_list provides
the variables(or names) that user want to get after program run.
Note: the executor will run all
operators in the program but not only the operators dependent by the fetch_list
Args:
program(Program): the program that need to run, if not provied, then default_main_program will be used.
feed(dict): feed variable map, e.g. {"image": ImageData, "label": LableData}
fetch_list(list): a list of variable or variable names that user want to get, run will return them according to this list.
feed_var_name(str): the name for the input variable of feed Operator.
fetch_var_name(str): the name for the output variable of fetch Operator.
scope(Scope): the scope used to run this program, you can switch it to different scope. default is global_scope
return_numpy(bool): if convert the fetched tensor to numpy
use_program_cache(bool): set use_program_cache to true if program not changed compare to the last step.
Returns:
list(numpy.array): fetch result according to fetch_list.
Examples:
>>> data = layers.data(name='X', shape=[1], dtype='float32')
>>> hidden = layers.fc(input=data, size=10)
>>> layers.assign(hidden, out)
>>> loss = layers.mean(out)
>>> adam = fluid.optimizer.Adam()
>>> adam.minimize(loss)
>>> cpu = core.CPUPlace()
>>> exe = Executor(cpu)
>>> exe.run(default_startup_program())
>>> x = numpy.random.random(size=(10, 1)).astype('float32')
>>> outs = exe.run(
>>> feed={'X': x},
>>> fetch_list=[loss.name])
"""
if self._closed:
raise RuntimeError("Attempted to use a closed Executor")
if program is None:
program = default_main_program()
program_desc = program.desc
if not isinstance(program, Program):
raise TypeError(
"Executor requires Program as its Parameter. But you passed in %s"
% (type(program)))
if scope is None:
scope = g_scope
self.executor.run(program.desc)
def load_init_model(self):
return self.executor.load_init_model()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册