提交 44bae42d 编写于 作者: Q qijun

follow comments

上级 31323f79
......@@ -14,6 +14,9 @@ limitations under the License. */
#include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/gpu_info.h"
#endif
namespace paddle {
namespace operators {
......@@ -26,7 +29,7 @@ class GetPlacesOp : public framework::OperatorBase {
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
auto use_gpu = Attr<bool>("use_gpu");
std::string device_type = Attr<std::string>("device_type");
auto trainer_count = Attr<int>("trainer_count");
auto out_var_name = Output("Out");
......@@ -36,11 +39,16 @@ class GetPlacesOp : public framework::OperatorBase {
auto &places = *(out_var->GetMutable<std::vector<platform::Place>>());
places.resize(trainer_count);
if (use_gpu) {
if (device_type == "CUDA") {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_LT(trainer_count, GetCUDADeviceCount());
for (int i = 0; i < trainer_count; i++) {
places.emplace_back(platform::GPUPlace(i));
}
} else {
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif
} else if (device_type == "CPU") {
for (int i = 0; i < trainer_count; i++) {
places.emplace_back(platform::CPUPlace());
}
......@@ -55,9 +63,11 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "vector of Place");
AddAttr<int>("trainer_count", "(int)trainer count").SetDefault(1);
AddAttr<bool>("use_gpu", "(bool)use gpu").SetDefault(false);
AddAttr<std::string>("device_type",
"(string), deivce type can be \"CPU\" and \"CUDA\"")
.InEnum({"CPU", "CUDA"});
AddComment(R"DOC(
GetPlaces Operator.
Returns a list of places based on flags. The list will be used for parallel execution.
)DOC");
}
......
......@@ -8,14 +8,14 @@ from ..framework import Variable
__all__ = ['get_places']
def get_places(use_gpu, trainer_count):
def get_places(trainer_count, device_type="CPU"):
helper = LayerHelper('get_places', **locals())
out_places = helper.create_tmp_variable(dtype=helper.input_dtype())
helper.append_op(
type='get_places',
outputs={"Out": [out_places]},
attrs={
"use_gpu": use_gpu,
"device_type": device_type,
'trainer_count': trainer_count,
})
......
......@@ -173,7 +173,7 @@ class TestBook(unittest.TestCase):
def test_get_places(self):
program = Program()
with program_guard(program):
x = layers.get_places(use_gpu=True, trainer_count=4)
x = layers.get_places(trainer_count=4)
print(str(program))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册