提交 5e513928 编写于 作者: X xjqbest

fix runtime error

test=develop
上级 514d727a
......@@ -237,6 +237,7 @@ void FleetWrapper::PushDenseParamSync(
std::vector<paddle::ps::Region> regions;
for (auto& t : var_names) {
Variable* var = scope.FindVar(t);
CHECK(var != nullptr) << "var[" << t << "] not found";
LoDTensor* tensor = var->GetMutable<LoDTensor>();
float* g = tensor->mutable_data<float>(place);
paddle::ps::Region reg(g, tensor->numel());
......
......@@ -126,7 +126,7 @@ static int shell_popen_fork_internal(const char* real_cmd, bool do_read,
}
close_open_fds_internal();
if (execl("/bin/sh", "sh", "-c", real_cmd, NULL) < 0) {
if (execl("/bin/bash", "bash", "-c", real_cmd, NULL) < 0) {
return -1;
}
exit(127);
......
......@@ -712,7 +712,7 @@ class Executor(object):
if dataset == None:
raise RuntimeError("dataset is needed and should be initialized")
if self.place == paddle.fluid.CUDAPlace():
if not isinstance(self.place, core.CPUPlace):
raise RuntimeError("infer_from_dataset is verified on CPUPlace"
"We will open CUDAPlace in the future")
......
......@@ -123,18 +123,23 @@ class Fleet(object):
print("You should run DistributedOptimizer.minimize() first")
sys.exit(-1)
def init_worker(self, programs):
def init_worker(self, programs, scopes=None):
"""
init_worker(): will be called by user. When a user knows current process is_server(), he/she
should call init_worker() to initialize global information about worker and connect
worker with pserver.
worker with pserver. You should run startup program before init_worker.
Args:
programs(Program|list): a Program or a list of Programs
scopes(Scope|list): a Scope or a list of Scopes, default None.
"""
if not isinstance(programs, list):
programs = [programs]
if scopes is None:
scopes = [fluid.global_scope()] * len(programs)
if len(scopes) != len(programs):
print("You should make sure len(scopes) == len(programs) or set scopes None")
sys.exit(-1)
if self._opt_info:
if "fleet_desc" in self._opt_info:
self._dist_desc_str = text_format.MessageToString(
......@@ -160,7 +165,7 @@ class Fleet(object):
self.role_maker_._barrier_worker()
if self.role_maker_._is_first_worker():
tables = self._dist_desc.trainer_param.dense_table
for prog in programs:
for prog, scope in zip(programs, scopes):
prog_id = str(id(prog))
prog_conf = self._opt_info['program_configs'][prog_id]
prog_tables = {}
......@@ -174,8 +179,13 @@ class Fleet(object):
continue
var_name_list = []
for i in range(0, len(table.dense_variable_name)):
var_name_list.append(table.dense_variable_name[i])
self._fleet_ptr.init_model(prog.desc,
var_name = table.dense_variable_name[i]
if scope.find_var(var_name) is None:
print("var " + var_name + " not found in scope, "
"you should run startup program first")
sys.exit(-1)
var_name_list.append(var_name)
self._fleet_ptr.init_model(scope,
int(table.table_id),
var_name_list)
# barrier for init model done
......
......@@ -107,10 +107,12 @@ class TestDataset(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for i in range(2):
try:
#try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except:
self.assertTrue(False)
#except ImportError as e:
# pass
#except Exception as e:
# self.assertTrue(False)
os.remove("./test_in_memory_dataset_run_a.txt")
os.remove("./test_in_memory_dataset_run_b.txt")
......@@ -149,10 +151,12 @@ class TestDataset(unittest.TestCase):
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_startup_program())
for i in range(2):
try:
#try:
exe.train_from_dataset(fluid.default_main_program(), dataset)
except:
self.assertTrue(False)
#except ImportError as e:
# pass
#except Exception as e:
# self.assertTrue(False)
os.remove("./test_queue_dataset_run_a.txt")
os.remove("./test_queue_dataset_run_b.txt")
......
......@@ -23,7 +23,7 @@ class TrainerDesc(object):
with open(proto_file, 'r') as f:
text_format.Parse(f.read(), self.proto_desc)
'''
from proto import trainer_desc_pb2
from .proto import trainer_desc_pb2
self.proto_desc = trainer_desc_pb2.TrainerDesc()
import multiprocessing as mp
# set default thread num == cpu count
......
......@@ -12,6 +12,9 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from .trainer_desc import MultiTrainer, DistMultiTrainer
from .device_worker import Hogwild, DownpourSGD
__all__ = ["TrainerFactory"]
......@@ -20,8 +23,6 @@ class TrainerFactory(object):
pass
def _create_trainer(self, opt_info=None):
from .trainer_desc import MultiTrainer, DistMultiTrainer
from .device_worker import Hogwild, DownpourSGD
trainer = None
device_worker = None
if opt_info == None:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册