未验证 提交 31cd9145 编写于 作者: Z zmx 提交者: GitHub

[heterps]bug fix for local training with --heter_worker_num (#37166)

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop
上级 895692e3
...@@ -59,12 +59,6 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc, ...@@ -59,12 +59,6 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num(); thread_num_ = trainer_desc.thread_num();
ParseDumpConfig(trainer_desc); ParseDumpConfig(trainer_desc);
SetDebug(trainer_desc.debug()); SetDebug(trainer_desc.debug());
// for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
// i++) {
// need_merge_var_names_.push_back(
// trainer_desc.downpour_param().stat_var_names(i));
//}
// get filelist from trainer_desc here
const std::vector<paddle::framework::DataFeed*> readers = const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders(); dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size(); VLOG(3) << "readers num: " << readers.size();
......
...@@ -12,8 +12,7 @@ ...@@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \ #if (defined PADDLE_WITH_CUDA) && (defined PADDLE_WITH_PSCORE)
(defined PADDLE_WITH_PSCORE)
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include "paddle/fluid/framework/device_worker.h" #include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h" #include "paddle/fluid/framework/device_worker_factory.h"
......
...@@ -347,11 +347,6 @@ class HeterPipelineTrainer : public TrainerBase { ...@@ -347,11 +347,6 @@ class HeterPipelineTrainer : public TrainerBase {
int thread_num_; int thread_num_;
std::vector<std::thread> threads_; std::vector<std::thread> threads_;
std::vector<std::string> need_merge_var_names_;
#ifdef PADDLE_WITH_HETERPS
std::vector<platform::Place> places_;
#endif
int num_microbatches_; int num_microbatches_;
platform::Place place_; platform::Place place_;
TrainerDesc trainer_desc_; TrainerDesc trainer_desc_;
......
...@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_HIP) && \ #if (defined PADDLE_WITH_CUDA) && (defined PADDLE_WITH_PSCORE)
(defined PADDLE_WITH_PSCORE)
#include <stdlib.h> #include <stdlib.h>
#include <memory> #include <memory>
......
...@@ -199,17 +199,19 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra ...@@ -199,17 +199,19 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
"--heter_workers", "--heter_workers",
type=str, type=str,
default="", default="",
help="User defined heter workers ip:port") help="User defined heter workers in each stage ip1:port1;ip2:port2")
ps_group.add_argument( ps_group.add_argument(
"--heter_devices", "--heter_devices",
type=str, type=str,
default="", default="",
help="User defined heter devices") help="User defined heter devices in each stage cpu;gpu;cpu")
ps_group.add_argument("--worker_num", type=int, help="number of workers") ps_group.add_argument("--worker_num", type=int, help="number of workers")
ps_group.add_argument("--server_num", type=int, help="number of servers") ps_group.add_argument("--server_num", type=int, help="number of servers")
ps_group.add_argument( ps_group.add_argument(
"--heter_worker_num", type=int, help="number of heter_workers") "--heter_worker_num",
type=str,
help="number of heter_workers in each stage 1;2;3")
ps_group.add_argument("--http_port", type=int, help="Gloo http Port") ps_group.add_argument("--http_port", type=int, help="Gloo http Port")
# parameter elastic mode # parameter elastic mode
...@@ -496,13 +498,15 @@ def launch(): ...@@ -496,13 +498,15 @@ def launch():
- ``--workers``: User defined workers ip:port, e.g., ``--workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172"`` - ``--workers``: User defined workers ip:port, e.g., ``--workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172"``
- ``--heter_workers``: User defined heter workers ip:port, e.g., ``--heter_workers="192.168.0.16:6172,192.168.0.17:6172"`` - ``--heter_workers``: User defined heter workers ip1:port1;ip2:port2, e.g., ``--heter_workers="192.168.0.16:6172;192.168.0.17:6172"``
- ``--worker_num``: Number of workers (It is recommended to set this when emulating a distributed environment on a single node) - ``--worker_num``: Number of workers (It is recommended to set this when emulating a distributed environment on a single node)
- ``--server_num``: Number of servers (It is recommended to set this when emulating a distributed environment on a single node) - ``--server_num``: Number of servers (It is recommended to set this when emulating a distributed environment on a single node)
- ``--heter_worker_num``: Number of heter_workers (It is recommended to set this when emulating a distributed environment on a single node) - ``--heter_worker_num``: Number of heter_workers in each stage (It is recommended to set this when emulating a distributed environment on a single node)
- ``--heter_devices``: Type of heter_device in each stage
- ``--http_port``: Gloo http Port - ``--http_port``: Gloo http Port
......
...@@ -768,44 +768,44 @@ def get_custom_endpoints(origin_endpoints, offset=0): ...@@ -768,44 +768,44 @@ def get_custom_endpoints(origin_endpoints, offset=0):
return paddle_user_define_endpoints return paddle_user_define_endpoints
def cloud_ps_heter_env_set(args): #def cloud_ps_heter_env_set(args):
environs = {} # environs = {}
#
paddle_trainer_endpoints = os.getenv("TRAINER_IP_PORT_LIST", "") # paddle_trainer_endpoints = os.getenv("TRAINER_IP_PORT_LIST", "")
assert paddle_trainer_endpoints != None # assert paddle_trainer_endpoints != None
#
paddle_pserver_endpoints = os.getenv("PSERVER_IP_PORT_LIST", "") # paddle_pserver_endpoints = os.getenv("PSERVER_IP_PORT_LIST", "")
assert paddle_pserver_endpoints != None # assert paddle_pserver_endpoints != None
#
# hard code for paddlecloud custom-framework # # hard code for paddlecloud custom-framework
avilable_ports = os.getenv("TRAINER_PORTS", "").split(",") # avilable_ports = os.getenv("TRAINER_PORTS", "").split(",")
assert len( # assert len(
avilable_ports # avilable_ports
) >= 2, "set paddle_ports_num >= 2 in config.ini for paddlecloud job submit" # ) >= 2, "set paddle_ports_num >= 2 in config.ini for paddlecloud job submit"
#
# hard code for paddlecloud custom-framework # # hard code for paddlecloud custom-framework
trainers_num = len(paddle_pserver_endpoints.split(",")) # trainers_num = len(paddle_pserver_endpoints.split(","))
assert trainers_num != 0 # assert trainers_num != 0
environs["PADDLE_TRAINERS_NUM"] = trainers_num # environs["PADDLE_TRAINERS_NUM"] = trainers_num
environs["TRAINERS_NUM"] = trainers_num # environs["TRAINERS_NUM"] = trainers_num
#
# hard code for paddlecloud custom-framework # # hard code for paddlecloud custom-framework
environs["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = paddle_trainer_endpoints # environs["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = paddle_trainer_endpoints
environs["PADDLE_PSERVERS_IP_PORT_LIST"] = paddle_pserver_endpoints # environs["PADDLE_PSERVERS_IP_PORT_LIST"] = paddle_pserver_endpoints
environs["PADDLE_TRAINER_ENDPOINTS"] = get_custom_endpoints( # environs["PADDLE_TRAINER_ENDPOINTS"] = get_custom_endpoints(
paddle_pserver_endpoints, 1) # paddle_pserver_endpoints, 1)
heter_worker_num = len(paddle_trainer_endpoints.split(",")) # heter_worker_num = len(paddle_trainer_endpoints.split(","))
if (args.heter_worker_num != None) and ( # if (args.heter_worker_num != None) and (
heter_worker_num != args.heter_worker_num): # heter_worker_num != args.heter_worker_num):
warnings.warn( # warnings.warn(
"Your fleetrun setting: heter_worker_num is {}, but we find {} device can be used, this setting has been changed.". # "Your fleetrun setting: heter_worker_num is {}, but we find {} device can be used, this setting has been changed.".
format(args.heter_worker_num, heter_worker_num)) # format(args.heter_worker_num, heter_worker_num))
args.heter_worker_num = heter_worker_num # args.heter_worker_num = heter_worker_num
#
for k, v in environs.items(): # for k, v in environs.items():
os.environ[k] = str(v) # os.environ[k] = str(v)
logger.info("Set heter parameter server env: {}".format( # logger.info("Set heter parameter server env: {}".format(
pretty_print_envs(environs))) # pretty_print_envs(environs)))
def get_mapped_cluster(node_ips, node_ip, trainer_endpoints, device_mode, def get_mapped_cluster(node_ips, node_ip, trainer_endpoints, device_mode,
...@@ -997,7 +997,7 @@ class ParameterServerLauncher(object): ...@@ -997,7 +997,7 @@ class ParameterServerLauncher(object):
self.stage_heter_map[1] = self.worker_endpoints self.stage_heter_map[1] = self.worker_endpoints
if args.heter_worker_num: if args.heter_worker_num:
self.stage_heter_trainer_num = args.heter_worker_num.split(",") self.stage_heter_trainer_num = args.heter_worker_num.split(";")
self.stage_heter_trainer_num = [ self.stage_heter_trainer_num = [
int(trainer_num) int(trainer_num)
for trainer_num in self.stage_heter_trainer_num for trainer_num in self.stage_heter_trainer_num
......
...@@ -48,6 +48,7 @@ function test_launch_ps_heter(){ ...@@ -48,6 +48,7 @@ function test_launch_ps_heter(){
--workers="127.0.0.1:${worker_port_01},127.0.0.1:${worker_port_11}" \ --workers="127.0.0.1:${worker_port_01},127.0.0.1:${worker_port_11}" \
--heter_workers="127.0.0.1:${heter_worker_port_0},127.0.0.1:${heter_worker_port_1}" \ --heter_workers="127.0.0.1:${heter_worker_port_0},127.0.0.1:${heter_worker_port_1}" \
--heter_devices="gpu" \ --heter_devices="gpu" \
--heter_worker_num="2" \
fleet_heter_ps_training.py 2> ut2.elog fleet_heter_ps_training.py 2> ut2.elog
if grep -q "server are killed" ut2.elog; then if grep -q "server are killed" ut2.elog; then
echo "test heter trainer launch succeed" echo "test heter trainer launch succeed"
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册