From 4059c9ca7fcff892ab198350e1decda71e516e01 Mon Sep 17 00:00:00 2001 From: Yang Yu Date: Wed, 10 Jan 2018 13:18:40 +0800 Subject: [PATCH] Polish GetPlacesOp --- paddle/operators/get_places_op.cc | 33 ++++++++++++++----------- python/paddle/v2/fluid/layers/device.py | 17 +++++++------ 2 files changed, 28 insertions(+), 22 deletions(-) diff --git a/paddle/operators/get_places_op.cc b/paddle/operators/get_places_op.cc index 291bbbcb3a7..24fafb23074 100644 --- a/paddle/operators/get_places_op.cc +++ b/paddle/operators/get_places_op.cc @@ -39,17 +39,19 @@ class GetPlacesOp : public framework::OperatorBase { : OperatorBase(type, inputs, outputs, attrs) {} void Run(const framework::Scope &scope, const platform::Place &place) const override { - std::string device_type = Attr("device_type"); + bool is_gpu; + if (Attr("device_type") == "AUTO") { + is_gpu = platform::is_gpu_place(place); + } else { + is_gpu = Attr("device_type") == "CUDA"; + } auto device_count = static_cast(Attr("device_count")); if (device_count == 0) { - if (device_type == "CUDA") { - device_count = CUDADevCount(); - } else if (device_type == "CPU") { - device_count = std::thread::hardware_concurrency(); - } + device_count = + is_gpu ? CUDADevCount() : std::thread::hardware_concurrency(); } PADDLE_ENFORCE_NE(device_count, 0, "Cannot indicate %s device count", - device_type); + is_gpu ? "GPU" : "CPU"); auto out_var_name = Output("Out"); auto &places = @@ -57,14 +59,14 @@ class GetPlacesOp : public framework::OperatorBase { "Output variable %s cannot be found", out_var_name) .GetMutable()); places.reserve(device_count); - if (device_type == "CUDA") { + if (is_gpu) { PADDLE_ENFORCE_LE(device_count, CUDADevCount(), "Only %d CUDA devices found, cannot set to %d", CUDADevCount(), device_count); for (size_t i = 0; i < device_count; ++i) { - places.emplace_back(platform::CUDAPlace(i)); + places.emplace_back(platform::CUDAPlace(static_cast(i))); } - } else if (device_type == "CPU") { + } else { for (size_t i = 0; i < device_count; ++i) { places.emplace_back(platform::CPUPlace()); } @@ -77,10 +79,10 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker { GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) : OpProtoAndCheckerMaker(proto, op_checker) { AddOutput("Out", "vector of Place"); - AddAttr("device_count", "device count").SetDefault(1); - AddAttr("device_type", - R"(device type must be in ["CPU", "CUDA"])") - .InEnum({"CPU", "CUDA"}); + AddAttr("device_count", "device count").SetDefault(0); + AddAttr("device_type", "device type") + .InEnum({"CUDA", "CPU", "AUTO"}) + .SetDefault("AUTO"); AddComment(R"DOC( Returns a list of places based on flags. The list will be used for parallel execution. @@ -111,4 +113,5 @@ class GetPlacesInferShape : public framework::InferShapeBase { namespace ops = paddle::operators; REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker, - ops::GetPlacesInferVarType, ops::GetPlacesInferShape); + ops::GetPlacesInferVarType, ops::GetPlacesInferShape, + paddle::framework::EmptyGradOpMaker); diff --git a/python/paddle/v2/fluid/layers/device.py b/python/paddle/v2/fluid/layers/device.py index c2355ed8020..775d40e5b5e 100644 --- a/python/paddle/v2/fluid/layers/device.py +++ b/python/paddle/v2/fluid/layers/device.py @@ -4,19 +4,22 @@ All util layers. from ..layer_helper import LayerHelper from ..framework import unique_name +from ..registry import autodoc __all__ = ['get_places'] -def get_places(device_count=0, device_type="CPU"): +@autodoc +def get_places(device_count=None, device_type=None): helper = LayerHelper('get_places', **locals()) out_places = helper.create_variable(name=unique_name(helper.name + ".out")) + attrs = dict() + if device_count is not None: + attrs['device_count'] = int(device_count) + if device_type is not None: + attrs['device_type'] = str(device_type) + helper.append_op( - type='get_places', - outputs={"Out": [out_places]}, - attrs={ - "device_type": device_type, - 'device_count': device_count, - }) + type='get_places', outputs={"Out": [out_places]}, attrs=attrs) return out_places -- GitLab