提交 e8de775f 编写于 作者: Q qijun

rename trainer_count to device_count

上级 aa320894
......@@ -30,7 +30,7 @@ class GetPlacesOp : public framework::OperatorBase {
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
std::string device_type = Attr<std::string>("device_type");
auto trainer_count = Attr<int>("trainer_count");
auto device_count = Attr<int>("device_count");
auto out_var_name = Output("Out");
auto *out_var = scope.FindVar(out_var_name);
......@@ -38,18 +38,18 @@ class GetPlacesOp : public framework::OperatorBase {
out_var_name);
auto &places = *(out_var->GetMutable<std::vector<platform::Place>>());
places.resize(trainer_count);
places.resize(device_count);
if (device_type == "CUDA") {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_LT(trainer_count, platform::GetCUDADeviceCount());
for (int i = 0; i < trainer_count; i++) {
PADDLE_ENFORCE_LT(device_count, platform::GetCUDADeviceCount());
for (int i = 0; i < device_count; i++) {
places.emplace_back(platform::GPUPlace(i));
}
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif
} else if (device_type == "CPU") {
for (int i = 0; i < trainer_count; i++) {
for (int i = 0; i < device_count; i++) {
places.emplace_back(platform::CPUPlace());
}
}
......@@ -62,7 +62,7 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
framework::OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "vector of Place");
AddAttr<int>("trainer_count", "(int)trainer count").SetDefault(1);
AddAttr<int>("device_count", "(int)device count").SetDefault(1);
AddAttr<std::string>("device_type",
"(string), deivce type can be \"CPU\" and \"CUDA\"")
.InEnum({"CPU", "CUDA"});
......
......@@ -8,7 +8,7 @@ from ..framework import Variable
__all__ = ['get_places']
def get_places(trainer_count, device_type="CPU"):
def get_places(device_count, device_type="CPU"):
helper = LayerHelper('get_places', **locals())
out_places = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
......@@ -16,7 +16,7 @@ def get_places(trainer_count, device_type="CPU"):
outputs={"Out": [out_places]},
attrs={
"device_type": device_type,
'trainer_count': trainer_count,
'device_count': device_count,
})
return out_places
......@@ -199,7 +199,7 @@ class TestBook(unittest.TestCase):
def test_get_places(self):
program = Program()
with program_guard(program):
x = layers.get_places(trainer_count=4)
x = layers.get_places(device_count=4)
print(str(program))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册