未验证 提交 f1b09ba3 编写于 作者: Y Yi Liu 提交者: GitHub

adapt test_collective_base.py for only two GPU cards available. (#21307)

* adapt test_collective_base.py for only two GPU cards available.
test=develop

* fix bug of issue #21259
test=develop
上级 ed2a1852
......@@ -104,16 +104,16 @@ def one_hot(input, depth, allow_out_of_range=False):
if in_dygraph_mode():
inputs = {'X': input}
attrs = {'depth': depth}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
if not isinstance(depth, Variable):
# user attribute
inputs = {'X': input}
attrs = {'depth': depth}
attrs = {'depth': depth, 'allow_out_of_range': allow_out_of_range}
else:
depth.stop_gradient = True
inputs = {'X': input, 'depth_tensor': depth}
attrs = {}
attrs = {'allow_out_of_range': allow_out_of_range}
helper.append_op(
type="one_hot_v2",
inputs=inputs,
......
......@@ -163,7 +163,7 @@ class TestDistBase(unittest.TestCase):
w0_ep, w1_ep = worker_endpoints
#print("w0_ep:",w0_ep," w1_ep:",w1_ep)
env0 = {
"FLAGS_selected_gpus": "2",
"FLAGS_selected_gpus": "0",
"PADDLE_TRAINER_ID": "0",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
......@@ -171,7 +171,7 @@ class TestDistBase(unittest.TestCase):
}
env1 = {
"FLAGS_selected_gpus": "3",
"FLAGS_selected_gpus": "1",
"PADDLE_TRAINER_ID": "1",
"PADDLE_TRAINERS_NUM": "2",
"PADDLE_TRAINER_ENDPOINTS": self._ps_endpoints,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册