提交 e8de775f 编写于 作者: Q qijun

rename trainer_count to device_count

上级 aa320894
...@@ -30,7 +30,7 @@ class GetPlacesOp : public framework::OperatorBase { ...@@ -30,7 +30,7 @@ class GetPlacesOp : public framework::OperatorBase {
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::DeviceContext &dev_ctx) const override {
std::string device_type = Attr<std::string>("device_type"); std::string device_type = Attr<std::string>("device_type");
auto trainer_count = Attr<int>("trainer_count"); auto device_count = Attr<int>("device_count");
auto out_var_name = Output("Out"); auto out_var_name = Output("Out");
auto *out_var = scope.FindVar(out_var_name); auto *out_var = scope.FindVar(out_var_name);
...@@ -38,18 +38,18 @@ class GetPlacesOp : public framework::OperatorBase { ...@@ -38,18 +38,18 @@ class GetPlacesOp : public framework::OperatorBase {
out_var_name); out_var_name);
auto &places = *(out_var->GetMutable<std::vector<platform::Place>>()); auto &places = *(out_var->GetMutable<std::vector<platform::Place>>());
places.resize(trainer_count); places.resize(device_count);
if (device_type == "CUDA") { if (device_type == "CUDA") {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_LT(trainer_count, platform::GetCUDADeviceCount()); PADDLE_ENFORCE_LT(device_count, platform::GetCUDADeviceCount());
for (int i = 0; i < trainer_count; i++) { for (int i = 0; i < device_count; i++) {
places.emplace_back(platform::GPUPlace(i)); places.emplace_back(platform::GPUPlace(i));
} }
#else #else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device."); PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif #endif
} else if (device_type == "CPU") { } else if (device_type == "CPU") {
for (int i = 0; i < trainer_count; i++) { for (int i = 0; i < device_count; i++) {
places.emplace_back(platform::CPUPlace()); places.emplace_back(platform::CPUPlace());
} }
} }
...@@ -62,7 +62,7 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -62,7 +62,7 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker) framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "vector of Place"); AddOutput("Out", "vector of Place");
AddAttr<int>("trainer_count", "(int)trainer count").SetDefault(1); AddAttr<int>("device_count", "(int)device count").SetDefault(1);
AddAttr<std::string>("device_type", AddAttr<std::string>("device_type",
"(string), deivce type can be \"CPU\" and \"CUDA\"") "(string), deivce type can be \"CPU\" and \"CUDA\"")
.InEnum({"CPU", "CUDA"}); .InEnum({"CPU", "CUDA"});
......
...@@ -8,7 +8,7 @@ from ..framework import Variable ...@@ -8,7 +8,7 @@ from ..framework import Variable
__all__ = ['get_places'] __all__ = ['get_places']
def get_places(trainer_count, device_type="CPU"): def get_places(device_count, device_type="CPU"):
helper = LayerHelper('get_places', **locals()) helper = LayerHelper('get_places', **locals())
out_places = helper.create_tmp_variable(dtype=helper.input_dtype()) out_places = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op( helper.append_op(
...@@ -16,7 +16,7 @@ def get_places(trainer_count, device_type="CPU"): ...@@ -16,7 +16,7 @@ def get_places(trainer_count, device_type="CPU"):
outputs={"Out": [out_places]}, outputs={"Out": [out_places]},
attrs={ attrs={
"device_type": device_type, "device_type": device_type,
'trainer_count': trainer_count, 'device_count': device_count,
}) })
return out_places return out_places
...@@ -199,7 +199,7 @@ class TestBook(unittest.TestCase): ...@@ -199,7 +199,7 @@ class TestBook(unittest.TestCase):
def test_get_places(self): def test_get_places(self):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
x = layers.get_places(trainer_count=4) x = layers.get_places(device_count=4)
print(str(program)) print(str(program))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册