diff --git a/paddle/operators/get_places_op.cc b/paddle/operators/get_places_op.cc
new file mode 100644
index 0000000000000000000000000000000000000000..dd937488f4cb2b8ac53a6b0b179416ab1faedee5
--- /dev/null
+++ b/paddle/operators/get_places_op.cc
@@ -0,0 +1,69 @@
+/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#include "paddle/framework/op_registry.h"
+#include "paddle/platform/place.h"
+
+namespace paddle {
+namespace operators {
+
+// GetPlacesOp is a kernel-less operator: it fills its "Out" variable with a
+// std::vector<platform::Place> (trainer_count entries, GPU or CPU depending
+// on the use_gpu attribute) instead of computing on tensors.
+class GetPlacesOp : public framework::OperatorBase {
+ public:
+  GetPlacesOp(const std::string &type, const framework::VariableNameMap &inputs,
+              const framework::VariableNameMap &outputs,
+              const framework::AttributeMap &attrs)
+      : OperatorBase(type, inputs, outputs, attrs) {}
+  void Run(const framework::Scope &scope,
+           const platform::DeviceContext &dev_ctx) const override {
+    auto use_gpu = Attr<bool>("use_gpu");
+    auto trainer_count = Attr<int>("trainer_count");
+
+    auto out_var_name = Output("Out");
+    auto *out_var = scope.FindVar(out_var_name);
+    PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
+                   out_var_name);
+
+    auto &places = *(out_var->GetMutable<std::vector<platform::Place>>());
+    places.reserve(trainer_count);
+    if (use_gpu) {
+      // One GPU place per trainer, indexed by device ordinal.
+      for (int i = 0; i < trainer_count; i++) {
+        places.emplace_back(platform::GPUPlace(i));
+      }
+    } else {
+      for (int i = 0; i < trainer_count; i++) {
+        places.emplace_back(platform::CPUPlace());
+      }
+    }
+  }
+};
+
+class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
+ public:
+  GetPlacesOpProtoMaker(framework::OpProto *proto,
+                        framework::OpAttrChecker *op_checker)
+      : OpProtoAndCheckerMaker(proto, op_checker) {
+    AddOutput("Out", "vector of Place");
+    AddAttr<int>("trainer_count", "(int)trainer count").SetDefault(1);
+    AddAttr<bool>("use_gpu", "(bool)use gpu").SetDefault(false);
+    AddComment(R"DOC(
+GetPlaces Operator.
+
+)DOC");
+  }
+};
+}  // namespace operators
+}  // namespace paddle
+namespace ops = paddle::operators;
+
+REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker);
diff --git a/python/paddle/v2/fluid/framework.py b/python/paddle/v2/fluid/framework.py
index bf0cd275b62ae2c4d7312592b8a730291c59a071..e9319cbe2a4f390f591de4d3b7426453a9ed5d1b 100644
--- a/python/paddle/v2/fluid/framework.py
+++ b/python/paddle/v2/fluid/framework.py
@@ -424,7 +424,7 @@ class Operator(object):
             self.desc.check_attrs()
             no_kernel_op_set = {
                 'feed', 'fetch', 'save', 'load', 'recurrent',
-                'rnn_memory_helper_grad', 'conditional_block', 'while'
+                'rnn_memory_helper_grad', 'conditional_block', 'while', 'get_places'
             }
             if type not in no_kernel_op_set:
                 self.desc.infer_var_type(self.block.desc)
diff --git a/python/paddle/v2/fluid/layers/__init__.py b/python/paddle/v2/fluid/layers/__init__.py
index 249f570e13b7a1b50397fb971d1c6f77e0359a5e..cb8853062879de64a229943ce91cf30e0026a136 100644
--- a/python/paddle/v2/fluid/layers/__init__.py
+++ b/python/paddle/v2/fluid/layers/__init__.py
@@ -8,6 +8,8 @@ import tensor
 from tensor import *
 import control_flow
 from control_flow import *
+import utils
+from utils import *
 __all__ = []
 __all__ += nn.__all__
@@ -15,3 +17,4 @@ __all__ += io.__all__
 __all__ += tensor.__all__
 __all__ += control_flow.__all__
 __all__ += ops.__all__
+__all__ += utils.__all__
diff --git a/python/paddle/v2/fluid/layers/utils.py b/python/paddle/v2/fluid/layers/utils.py
new file mode 100644
index 0000000000000000000000000000000000000000..aa6857e088b721bfef78e4ae684622e862e6838f
--- /dev/null
+++ b/python/paddle/v2/fluid/layers/utils.py
@@ -0,0 +1,34 @@
+"""
+All util layers.
+"""
+
+from ..layer_helper import LayerHelper
+from ..framework import Variable
+
+__all__ = ['get_places']
+
+
+def get_places(use_gpu, trainer_count):
+    """
+    Create a get_places operator and return its output variable.
+
+    Args:
+        use_gpu (bool): enumerate GPU places when True, CPU places otherwise.
+        trainer_count (int): number of places to generate.
+
+    Returns:
+        Variable: holds the resulting vector of Place.
+    """
+    helper = LayerHelper('get_places', **locals())
+    # NOTE(review): the get_places op declares no inputs, so
+    # helper.input_dtype() may not resolve a dtype here -- confirm that
+    # LayerHelper handles input-less ops before relying on this.
+    out_places = helper.create_tmp_variable(dtype=helper.input_dtype())
+    helper.append_op(
+        type='get_places',
+        outputs={'Out': [out_places]},
+        attrs={
+            'use_gpu': use_gpu,
+            'trainer_count': trainer_count,
+        })
+
+    return out_places