Unverified · Commit 31cd9145, authored by zmx, committed by GitHub

[heterps]bug fix for local training with --heter_worker_num (#37166)

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix. test=develop

* fix ut. test=develop

* fix ut. test=develop

* fix ut. test=develop
Parent 895692e3
@@ -59,12 +59,6 @@ void HeterPipelineTrainer::Initialize(const TrainerDesc& trainer_desc,
thread_num_ = trainer_desc.thread_num();
ParseDumpConfig(trainer_desc);
SetDebug(trainer_desc.debug());
-// for (int i = 0; i < trainer_desc.downpour_param().stat_var_names_size();
-// i++) {
-//   need_merge_var_names_.push_back(
-//       trainer_desc.downpour_param().stat_var_names(i));
-//}
// get filelist from trainer_desc here
const std::vector<paddle::framework::DataFeed*> readers =
dataset->GetReaders();
VLOG(3) << "readers num: " << readers.size();
......
@@ -12,8 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
-#if (defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)) && \
-    (defined PADDLE_WITH_PSCORE)
+#if (defined PADDLE_WITH_CUDA) && (defined PADDLE_WITH_PSCORE)
#include "gtest/gtest.h"
#include "paddle/fluid/framework/device_worker.h"
#include "paddle/fluid/framework/device_worker_factory.h"
......
@@ -347,11 +347,6 @@ class HeterPipelineTrainer : public TrainerBase {
int thread_num_;
std::vector<std::thread> threads_;
std::vector<std::string> need_merge_var_names_;
-#ifdef PADDLE_WITH_HETERPS
-  std::vector<platform::Place> places_;
-#endif
int num_microbatches_;
platform::Place place_;
TrainerDesc trainer_desc_;
......
@@ -12,8 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
-#if (defined PADDLE_WITH_CUDA || defined PADDLE_WITH_HIP) && \
-    (defined PADDLE_WITH_PSCORE)
+#if (defined PADDLE_WITH_CUDA) && (defined PADDLE_WITH_PSCORE)
#include <stdlib.h>
#include <memory>
......
@@ -199,17 +199,19 @@ see: http://www.paddlepaddle.org/documentation/docs/zh/1.6/user_guides/howto/tra
"--heter_workers",
type=str,
default="",
help="User defined heter workers ip:port")
help="User defined heter workers in each stage ip1:port1;ip2:port2")
ps_group.add_argument(
"--heter_devices",
type=str,
default="",
help="User defined heter devices")
help="User defined heter devices in each stage cpu;gpu;cpu")
ps_group.add_argument("--worker_num", type=int, help="number of workers")
ps_group.add_argument("--server_num", type=int, help="number of servers")
ps_group.add_argument(
"--heter_worker_num", type=int, help="number of heter_workers")
"--heter_worker_num",
type=str,
help="number of heter_workers in each stage 1;2;3")
ps_group.add_argument("--http_port", type=int, help="Gloo http Port")
# parameter elastic mode
@@ -496,13 +498,15 @@ def launch():
- ``--workers``: User defined workers ip:port, e.g., ``--workers="192.168.0.16:6171,192.168.0.16:6172,192.168.0.17:6171,192.168.0.17:6172"``
-- ``--heter_workers``: User defined heter workers ip:port, e.g., ``--heter_workers="192.168.0.16:6172,192.168.0.17:6172"``
+- ``--heter_workers``: User defined heter workers ip1:port1;ip2:port2, e.g., ``--heter_workers="192.168.0.16:6172;192.168.0.17:6172"``
- ``--worker_num``: Number of workers (it is recommended to set this when emulating a distributed environment on a single node)
- ``--server_num``: Number of servers (it is recommended to set this when emulating a distributed environment on a single node)
-- ``--heter_worker_num``: Number of heter_workers (It recommend to set when in the emulated distributed environment using single node)
+- ``--heter_worker_num``: Number of heter_workers in each stage (it is recommended to set this when emulating a distributed environment on a single node)
+- ``--heter_devices``: Type of heter_device in each stage
- ``--http_port``: Gloo http Port
......
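Putting the new per-stage syntax together: on a single node, a heter PS job can now be emulated by passing semicolon-separated, per-stage values to the heter flags documented above. The sketch below is illustrative only; the ports are made up, `fleet_heter_ps_training.py` is the script used by the shell test further down, and the multi-stage variant in the trailing comments is inferred from the updated help strings rather than taken verbatim from this commit.

```bash
# Single heter stage with two heter workers (mirrors the updated shell test).
fleetrun \
    --servers="127.0.0.1:6170,127.0.0.1:6171" \
    --workers="127.0.0.1:6172,127.0.0.1:6173" \
    --heter_workers="127.0.0.1:6174,127.0.0.1:6175" \
    --heter_devices="gpu" \
    --heter_worker_num="2" \
    fleet_heter_ps_training.py

# Assumed multi-stage form, inferred from the new help strings:
# stages are separated by ";", endpoints within a stage by ",".
#   --heter_workers="127.0.0.1:6174;127.0.0.1:6175"
#   --heter_devices="gpu;cpu"
#   --heter_worker_num="1;1"
```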
@@ -768,44 +768,44 @@ def get_custom_endpoints(origin_endpoints, offset=0):
return paddle_user_define_endpoints
-def cloud_ps_heter_env_set(args):
-environs = {}
-paddle_trainer_endpoints = os.getenv("TRAINER_IP_PORT_LIST", "")
-assert paddle_trainer_endpoints != None
-paddle_pserver_endpoints = os.getenv("PSERVER_IP_PORT_LIST", "")
-assert paddle_pserver_endpoints != None
-# hard code for paddlecloud custom-framework
-avilable_ports = os.getenv("TRAINER_PORTS", "").split(",")
-assert len(
-avilable_ports
-) >= 2, "set paddle_ports_num >= 2 in config.ini for paddlecloud job submit"
-# hard code for paddlecloud custom-framework
-trainers_num = len(paddle_pserver_endpoints.split(","))
-assert trainers_num != 0
-environs["PADDLE_TRAINERS_NUM"] = trainers_num
-environs["TRAINERS_NUM"] = trainers_num
-# hard code for paddlecloud custom-framework
-environs["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = paddle_trainer_endpoints
-environs["PADDLE_PSERVERS_IP_PORT_LIST"] = paddle_pserver_endpoints
-environs["PADDLE_TRAINER_ENDPOINTS"] = get_custom_endpoints(
-paddle_pserver_endpoints, 1)
-heter_worker_num = len(paddle_trainer_endpoints.split(","))
-if (args.heter_worker_num != None) and (
-heter_worker_num != args.heter_worker_num):
-warnings.warn(
-"Your fleetrun setting: heter_worker_num is {}, but we find {} device can be used, this setting has been changed.".
-format(args.heter_worker_num, heter_worker_num))
-args.heter_worker_num = heter_worker_num
-for k, v in environs.items():
-os.environ[k] = str(v)
-logger.info("Set heter parameter server env: {}".format(
-pretty_print_envs(environs)))
+#def cloud_ps_heter_env_set(args):
+# environs = {}
+#
+# paddle_trainer_endpoints = os.getenv("TRAINER_IP_PORT_LIST", "")
+# assert paddle_trainer_endpoints != None
+#
+# paddle_pserver_endpoints = os.getenv("PSERVER_IP_PORT_LIST", "")
+# assert paddle_pserver_endpoints != None
+#
+# # hard code for paddlecloud custom-framework
+# avilable_ports = os.getenv("TRAINER_PORTS", "").split(",")
+# assert len(
+# avilable_ports
+# ) >= 2, "set paddle_ports_num >= 2 in config.ini for paddlecloud job submit"
+#
+# # hard code for paddlecloud custom-framework
+# trainers_num = len(paddle_pserver_endpoints.split(","))
+# assert trainers_num != 0
+# environs["PADDLE_TRAINERS_NUM"] = trainers_num
+# environs["TRAINERS_NUM"] = trainers_num
+#
+# # hard code for paddlecloud custom-framework
+# environs["PADDLE_HETER_TRAINER_IP_PORT_LIST"] = paddle_trainer_endpoints
+# environs["PADDLE_PSERVERS_IP_PORT_LIST"] = paddle_pserver_endpoints
+# environs["PADDLE_TRAINER_ENDPOINTS"] = get_custom_endpoints(
+# paddle_pserver_endpoints, 1)
+# heter_worker_num = len(paddle_trainer_endpoints.split(","))
+# if (args.heter_worker_num != None) and (
+# heter_worker_num != args.heter_worker_num):
+# warnings.warn(
+# "Your fleetrun setting: heter_worker_num is {}, but we find {} device can be used, this setting has been changed.".
+# format(args.heter_worker_num, heter_worker_num))
+# args.heter_worker_num = heter_worker_num
+#
+# for k, v in environs.items():
+# os.environ[k] = str(v)
+# logger.info("Set heter parameter server env: {}".format(
+# pretty_print_envs(environs)))
def get_mapped_cluster(node_ips, node_ip, trainer_endpoints, device_mode,
@@ -997,7 +997,7 @@ class ParameterServerLauncher(object):
self.stage_heter_map[1] = self.worker_endpoints
if args.heter_worker_num:
self.stage_heter_trainer_num = args.heter_worker_num.split(",")
self.stage_heter_trainer_num = args.heter_worker_num.split(";")
self.stage_heter_trainer_num = [
int(trainer_num)
for trainer_num in self.stage_heter_trainer_num
......
@@ -48,6 +48,7 @@ function test_launch_ps_heter(){
--workers="127.0.0.1:${worker_port_01},127.0.0.1:${worker_port_11}" \
--heter_workers="127.0.0.1:${heter_worker_port_0},127.0.0.1:${heter_worker_port_1}" \
--heter_devices="gpu" \
--heter_worker_num="2" \
fleet_heter_ps_training.py 2> ut2.elog
if grep -q "server are killed" ut2.elog; then
echo "test heter trainer launch succeed"
......