Commit f5c6a14b authored by xujiaqi01, committed by dongdaxiang

fix runtime error

Parent a5b1a0e1
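Summary of the diff below: on the C++ side, the per-call local std::vector<paddle::ps::Region> in FleetWrapper::PullDenseVarsAsync and FleetWrapper::PullDenseVarsSync is replaced by a member cache _regions keyed by table id; on the Python side, MPIRoleMaker.__init__ gains the missing base-class constructor call, and DistributedAdam is fixed to index the program_configs dict instead of the stale program_config name.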
@@ -170,7 +170,8 @@ void FleetWrapper::PullDenseVarsAsync(
     const std::vector<std::string>& var_names,
     std::vector<::std::future<int32_t>>* pull_dense_status) {
 #ifdef PADDLE_WITH_PSLIB
-  std::vector<paddle::ps::Region> regions;
+  auto& regions = _regions[tid];
+  regions.clear();
   regions.resize(var_names.size());
   for (auto i = 0u; i < var_names.size(); ++i) {
     Variable* var = scope.FindVar(var_names[i]);
@@ -189,7 +190,8 @@ void FleetWrapper::PullDenseVarsSync(
     const Scope& scope, const uint64_t tid,
     const std::vector<std::string>& var_names) {
 #ifdef PADDLE_WITH_PSLIB
-  std::vector<paddle::ps::Region> regions;
+  auto& regions = _regions[tid];
+  regions.clear();
   regions.reserve(var_names.size());
   for (auto& t : var_names) {
     Variable* var = scope.FindVar(t);
...
@@ -146,6 +146,7 @@ class FleetWrapper {
  private:
   static std::shared_ptr<FleetWrapper> s_instance_;
+  std::map<uint64_t, std::vector<paddle::ps::Region>> _regions;

 protected:
   static bool is_initialized_;
...
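A minimal, self-contained sketch of the pattern in the three hunks above. Region, FleetWrapperSketch, and the fake asynchronous "pull" are simplified stand-ins, not the real PSLib API. The likely failure mode being fixed: a function-local regions vector could be destroyed while an in-flight asynchronous pull still referenced it; storing the buffers in a member keyed by table id gives them the wrapper's lifetime and also avoids reallocating on every call.

```cpp
#include <cstddef>
#include <cstdint>
#include <future>
#include <map>
#include <vector>

// Stand-in for paddle::ps::Region (an assumption; the real type lives in PSLib).
struct Region {
  float* data = nullptr;
  size_t size = 0;
};

// Hypothetical class mirroring the FleetWrapper change above.
class FleetWrapperSketch {
 public:
  // Before the fix, `regions` was a function-local vector, so it was
  // destroyed on return even if the asynchronous pull still referenced it.
  std::future<int32_t> PullDenseVarsAsync(uint64_t table_id,
                                          const std::vector<float*>& vars,
                                          const std::vector<size_t>& lens) {
    // After the fix: reuse a per-table buffer owned by the wrapper.
    auto& regions = _regions[table_id];
    regions.clear();
    regions.resize(vars.size());
    for (size_t i = 0; i < vars.size(); ++i) {
      regions[i].data = vars[i];
      regions[i].size = lens[i];
    }
    // The worker may read `regions` after this function returns; because
    // the vector is a class member, that access stays valid.
    return std::async(std::launch::async, [&regions]() -> int32_t {
      for (auto& r : regions) {
        for (size_t j = 0; j < r.size; ++j) r.data[j] = 0.0f;  // fake "pull"
      }
      return 0;
    });
  }

 private:
  // Mirrors the member added to FleetWrapper in the header hunk.
  std::map<uint64_t, std::vector<Region>> _regions;
};

int main() {
  FleetWrapperSketch fleet;
  std::vector<float> param(4, 1.0f);
  auto status =
      fleet.PullDenseVarsAsync(/*table_id=*/0, {param.data()}, {param.size()});
  return status.get();  // wait for the fake pull to complete
}
```

A side effect of keying the cache by table id is that concurrent pulls for different tables use disjoint buffers, while repeated pulls for the same table reuse one allocation.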
@@ -74,6 +74,7 @@ class MPIRoleMaker(RoleMakerBase):
     """

     def __init__(self):
+        super(MPIRoleMaker, self).__init__()
         from mpi4py import MPI
         self.comm_ = MPI.COMM_WORLD
         self.MPI = MPI
...
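A note on this hunk: Python does not call a base-class constructor implicitly, so without the explicit super(...) call any attributes set up by RoleMakerBase.__init__ would never exist on MPIRoleMaker instances, and later accesses to them would fail at runtime.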
@@ -141,9 +141,9 @@ class DistributedAdam(DistributedOptimizerImplBase):
                 data_norm_params, data_norm_grads)
             #program_config.pull_dense_table_id.extend([dense_table_index])
             #program_config.push_dense_table_id.extend([dense_table_index])
-            program_config[program_id]["pull_dense"].extend(
+            program_configs[program_id]["pull_dense"].extend(
                 [dense_table_index])
-            program_config[program_id]["push_dense"].extend(
+            program_configs[program_id]["push_dense"].extend(
                 [dense_table_index])
             dense_table_index += 1
             #program_configs.append(program_config)
...
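A note on this hunk: the commented-out lines show an earlier scheme where a single program_config object was built and appended to a list; the surrounding code now builds program_configs as a dict keyed by program id, so the old name program_config was either undefined or pointed at the wrong object here, and indexing program_configs[program_id] matches the structure the method actually builds.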