提交 5e513928 作者: xjqbest

fix runtime error

test=develop
上级 514d727a
...@@ -237,6 +237,7 @@ void FleetWrapper::PushDenseParamSync( ...@@ -237,6 +237,7 @@ void FleetWrapper::PushDenseParamSync(
std::vector<paddle::ps::Region> regions; std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) { for (auto& t : var_names) {
Variable* var = scope.FindVar(t); Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>(); LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place); float* g = tensor->mutable_data<float>(place);
paddle::ps::Region reg(g, tensor->numel()); paddle::ps::Region reg(g, tensor->numel());
......
...@@ -126,7 +126,7 @@ static int shell_popen_fork_internal(const char* real_cmd, bool do_read, ...@@ -126,7 +126,7 @@ static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
} }
close_open_fds_internal(); close_open_fds_internal();
if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) { if (execl("/bin/bash", "bash", "-c", real_cmd, NULL) < 0) {
return -1; return -1;
} }
exit(127); exit(127);
......
...@@ -712,7 +712,7 @@ class Executor(object): ...@@ -712,7 +712,7 @@ class Executor(object):
if dataset == None: if dataset == None:
raise RuntimeError("dataset is needed and should be initialized") raise RuntimeError("dataset is needed and should be initialized")
if self.place == paddle.fluid.CUDAPlace(): if not isinstance(self.place, core.CPUPlace):
raise RuntimeError("infer_from_dataset is verified on CPUPlace" raise RuntimeError("infer_from_dataset is verified on CPUPlace"
"We will open CUDAPlace in the future") "We will open CUDAPlace in the future")
......
...@@ -123,18 +123,23 @@ class Fleet(object): ...@@ -123,18 +123,23 @@ class Fleet(object):
print("You should run DistributedOptimizer.minimize() first") print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1) sys.exit(-1)
def init_worker(self, programs): def init_worker(self, programs, scopes=None):
""" """
init_worker(): will be called by user. When a user knows current process is_server(), he/she init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect should call init_worker() to initialize global information about worker and connect
worker with pserver. worker with pserver. You should run startup program before init_worker.
Args: Args:
programs(Program|list): a Program or a list of Programs programs(Program|list): a Program or a list of Programs
scopes(Scope|list): a Scope or a list of Scopes, default None.
""" """
if not isinstance(programs, list): if not isinstance(programs, list):
programs = [programs] programs = [programs]
if scopes is None:
scopes = [fluid.global_scope()] * len(programs)
if len(scopes) != len(programs):
print("You should make sure len(scopes) == len(programs) or set scopes None")
sys.exit(-1)
if self._opt_info: if self._opt_info:
if "fleet_desc" in self._opt_info: if "fleet_desc" in self._opt_info:
self._dist_desc_str = text_format.MessageToString( self._dist_desc_str = text_format.MessageToString(
...@@ -160,7 +165,7 @@ class Fleet(object): ...@@ -160,7 +165,7 @@ class Fleet(object):
self.role_maker_._barrier_worker() self.role_maker_._barrier_worker()
if self.role_maker_._is_first_worker(): if self.role_maker_._is_first_worker():
tables = self._dist_desc.trainer_param.dense_table tables = self._dist_desc.trainer_param.dense_table
for prog in programs: for prog, scope in zip(programs, scopes):
prog_id = str(id(prog)) prog_id = str(id(prog))
prog_conf = self._opt_info['program_configs'][prog_id] prog_conf = self._opt_info['program_configs'][prog_id]
prog_tables = {} prog_tables = {}
...@@ -174,8 +179,13 @@ class Fleet(object): ...@@ -174,8 +179,13 @@ class Fleet(object):
continue continue
var_name_list = [] var_name_list = []
for i in range(0, len(table.dense_variable_name)): for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i]) var_name = table.dense_variable_name[i]
self._fleet_ptr.init_model(prog.desc, if scope.find_var(var_name) is None:
print("var " + var_name + " not found in scope, "
"you should run startup program first")
sys.exit(-1)
var_name_list.append(var_name)
self._fleet_ptr.init_model(scope,
int(table.table_id), int(table.table_id),
var_name_list) var_name_list)
# barrier for init model done # barrier for init model done
......
...@@ -107,10 +107,12 @@ class TestDataset(unittest.TestCase): ...@@ -107,10 +107,12 @@ class TestDataset(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for i in range(2): for i in range(2):
try: #try:
exe.train_from_dataset(fluid.default_main_program(), dataset) exe.train_from_dataset(fluid.default_main_program(), dataset)
except: #except ImportError as e:
self.assertTrue(False) # pass
#except Exception as e:
# self.assertTrue(False)
os.remove("./test_in_memory_dataset_run_a.txt") os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_in_memory_dataset_run_b.txt") os.remove("./test_in_memory_dataset_run_b.txt")
...@@ -149,10 +151,12 @@ class TestDataset(unittest.TestCase): ...@@ -149,10 +151,12 @@ class TestDataset(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace()) exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for i in range(2): for i in range(2):
try: #try:
exe.train_from_dataset(fluid.default_main_program(), dataset) exe.train_from_dataset(fluid.default_main_program(), dataset)
except: #except ImportError as e:
self.assertTrue(False) # pass
#except Exception as e:
# self.assertTrue(False)
os.remove("./test_queue_dataset_run_a.txt") os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt") os.remove("./test_queue_dataset_run_b.txt")
......
...@@ -23,7 +23,7 @@ class TrainerDesc(object): ...@@ -23,7 +23,7 @@ class TrainerDesc(object):
with open(proto_file, 'r') as f: with open(proto_file, 'r') as f:
text_format.Parse(f.read(), self.proto_desc) text_format.Parse(f.read(), self.proto_desc)
''' '''
from proto import trainer_desc_pb2 from .proto import trainer_desc_pb2
self.proto_desc = trainer_desc_pb2.TrainerDesc() self.proto_desc = trainer_desc_pb2.TrainerDesc()
import multiprocessing as mp import multiprocessing as mp
# set default thread num == cpu count # set default thread num == cpu count
......
...@@ -12,6 +12,9 @@ ...@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from .trainer_desc import MultiTrainer, DistMultiTrainer
from .device_worker import Hogwild, DownpourSGD
__all__ = ["TrainerFactory"] __all__ = ["TrainerFactory"]
...@@ -20,8 +23,6 @@ class TrainerFactory(object): ...@@ -20,8 +23,6 @@ class TrainerFactory(object):
pass pass
def _create_trainer(self, opt_info=None): def _create_trainer(self, opt_info=None):
from .trainer_desc import MultiTrainer, DistMultiTrainer
from .device_worker import Hogwild, DownpourSGD
trainer = None trainer = None
device_worker = None device_worker = None
if opt_info == None: if opt_info == None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册