提交 63ff0b4b 编写于 作者: Y Yang Yu

Refine get_places

上级 ed0cf3d6
...@@ -53,12 +53,12 @@ VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const { ...@@ -53,12 +53,12 @@ VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
return it->second.get(); return it->second.get();
} }
VarDesc *BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) { VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
VarDesc *res = FindVarRecursive(name_bytes); VarDesc *res = FindVarRecursive(name_bytes);
if (res == nullptr) { if (res == nullptr) {
res = Var(name_bytes); res = Var(name_bytes);
} }
return res; return *res;
} }
bool BlockDesc::HasVarRecursive(const std::string &name) const { bool BlockDesc::HasVarRecursive(const std::string &name) const {
......
...@@ -57,7 +57,7 @@ class BlockDesc { ...@@ -57,7 +57,7 @@ class BlockDesc {
VarDesc *FindVarRecursive(const std::string &name_bytes) const; VarDesc *FindVarRecursive(const std::string &name_bytes) const;
VarDesc *FindRecursiveOrCreateVar(const std::string &name_bytes); VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
bool HasVarRecursive(const std::string &var_name) const; bool HasVarRecursive(const std::string &var_name) const;
......
...@@ -21,6 +21,7 @@ limitations under the License. */ ...@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/framework/lod_rank_table.h" #include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h" #include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h"
DEFINE_bool(check_nan_inf, false, DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be " "Checking whether operator produce NAN/INF or not. It will be "
...@@ -49,10 +50,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) { ...@@ -49,10 +50,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
var->GetMutable<LoDRankTable>(); var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) { } else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>(); var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else { } else {
PADDLE_THROW( PADDLE_THROW(
"Variable type %d is not in " "Variable type %d is not in "
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE]", "[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE,"
" PLACE_LIST]",
var_type); var_type);
} }
} }
......
...@@ -123,6 +123,7 @@ message VarDesc { ...@@ -123,6 +123,7 @@ message VarDesc {
STEP_SCOPES = 5; STEP_SCOPES = 5;
LOD_RANK_TABLE = 6; LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7; LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
} }
required string name = 1; required string name = 1;
required VarType type = 2; required VarType type = 2;
......
...@@ -383,7 +383,7 @@ void OpDesc::InferVarType(BlockDesc *block) const { ...@@ -383,7 +383,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
for (auto &out_pair : this->outputs_) { for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) { for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name) block->FindRecursiveOrCreateVar(out_var_name)
->SetType(proto::VarDesc::LOD_TENSOR); .SetType(proto::VarDesc::LOD_TENSOR);
} }
} }
} }
......
...@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <thread>
#include "paddle/framework/op_registry.h" #include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/place.h" #include "paddle/platform/place.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
#include "paddle/platform/gpu_info.h" #include "paddle/platform/gpu_info.h"
...@@ -21,6 +23,14 @@ limitations under the License. */ ...@@ -21,6 +23,14 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
// Number of CUDA devices available to this build.
// When PADDLE_WITH_CUDA is defined the value comes from
// platform::GetCUDADeviceCount(); otherwise the result is always zero.
static size_t CUDADevCount() {
  size_t cuda_device_total = 0UL;
#ifdef PADDLE_WITH_CUDA
  cuda_device_total = platform::GetCUDADeviceCount();
#endif
  return cuda_device_total;
}
class GetPlacesOp : public framework::OperatorBase { class GetPlacesOp : public framework::OperatorBase {
public: public:
GetPlacesOp(const std::string &type, const framework::VariableNameMap &inputs, GetPlacesOp(const std::string &type, const framework::VariableNameMap &inputs,
...@@ -28,28 +38,34 @@ class GetPlacesOp : public framework::OperatorBase { ...@@ -28,28 +38,34 @@ class GetPlacesOp : public framework::OperatorBase {
const framework::AttributeMap &attrs) const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {} : OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope, void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override { const platform::Place &place) const override {
std::string device_type = Attr<std::string>("device_type"); std::string device_type = Attr<std::string>("device_type");
auto device_count = Attr<int>("device_count"); auto device_count = static_cast<size_t>(Attr<int>("device_count"));
if (device_count == 0) {
if (device_type == "CUDA") {
device_count = CUDADevCount();
} else if (device_type == "CPU") {
device_count = std::thread::hardware_concurrency();
}
}
PADDLE_ENFORCE_NE(device_count, 0, "Cannot indicate %s device count",
device_type);
auto out_var_name = Output("Out"); auto out_var_name = Output("Out");
auto *out_var = scope.FindVar(out_var_name); auto &places =
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found", *(detail::Ref(scope.FindVar(out_var_name),
out_var_name); "Output variable %s cannot be found", out_var_name)
.GetMutable<platform::PlaceList>());
auto &places = *(out_var->GetMutable<std::vector<platform::Place>>()); places.reserve(device_count);
places.resize(device_count);
if (device_type == "CUDA") { if (device_type == "CUDA") {
#ifdef PADDLE_WITH_CUDA PADDLE_ENFORCE_LE(device_count, CUDADevCount(),
PADDLE_ENFORCE_LT(device_count, platform::GetCUDADeviceCount()); "Only %d CUDA devices found, cannot set to %d",
for (int i = 0; i < device_count; i++) { CUDADevCount(), device_count);
places.emplace_back(platform::GPUPlace(i)); for (size_t i = 0; i < device_count; ++i) {
places.emplace_back(platform::CUDAPlace(i));
} }
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif
} else if (device_type == "CPU") { } else if (device_type == "CPU") {
for (int i = 0; i < device_count; i++) { for (size_t i = 0; i < device_count; ++i) {
places.emplace_back(platform::CPUPlace()); places.emplace_back(platform::CPUPlace());
} }
} }
...@@ -61,18 +77,38 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -61,18 +77,38 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker) GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) { : OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "vector of Place"); AddOutput("Out", "vector of Place");
AddAttr<int>("device_count", "(int)device count").SetDefault(1); AddAttr<int>("device_count", "device count").SetDefault(1);
AddAttr<std::string>("device_type", AddAttr<std::string>("device_type",
"(string), deivce type can be \"CPU\" and \"CUDA\"") R"(device type must be in ["CPU", "CUDA"])")
.InEnum({"CPU", "CUDA"}); .InEnum({"CPU", "CUDA"});
AddComment(R"DOC( AddComment(R"DOC(
Returns a list of places based on flags. The list will be used for parallel execution. Returns a list of places based on flags. The list will be used for parallel
execution.
)DOC"); )DOC");
} }
}; };
// Variable-type inference for the get_places operator: marks every variable
// listed in the op's "Out" slot as a PLACE_LIST.
class GetPlacesInferVarType : public framework::VarTypeInference {
 public:
  // For each output name, fetches the variable description from `block`
  // (FindRecursiveOrCreateVar creates it when the recursive lookup finds
  // nothing) and sets its type to PLACE_LIST.
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    const auto place_list_type = framework::proto::VarDesc::PLACE_LIST;
    for (auto &out_name : op_desc.Output("Out")) {
      auto &out_var = block->FindRecursiveOrCreateVar(out_name);
      out_var.SetType(place_list_type);
    }
  }
};
// Shape-inference functor for the get_places operator.
// The call operator is intentionally empty: it performs no work on the
// supplied context.
class GetPlacesInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext * /*context*/) const override {
    // Intentionally a no-op.
  }
};
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker); REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker,
ops::GetPlacesInferVarType, ops::GetPlacesInferShape);
...@@ -66,7 +66,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference { ...@@ -66,7 +66,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc, void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override { framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("Out")) { for (auto &o : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o)->SetType( block->FindRecursiveOrCreateVar(o).SetType(
framework::proto::VarDesc::LOD_RANK_TABLE); framework::proto::VarDesc::LOD_RANK_TABLE);
} }
} }
......
...@@ -122,17 +122,17 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -122,17 +122,17 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
for (auto& name : op_desc.Input("X")) { for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " " VLOG(10) << name << " "
<< block->FindRecursiveOrCreateVar(name)->GetType(); << block->FindRecursiveOrCreateVar(name).GetType();
} }
bool any_input_is_lod_tensor = std::any_of( bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) { inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name)->GetType() == return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR; framework::proto::VarDesc::LOD_TENSOR;
}); });
auto is_tensor_array = [block](const std::string& name) { auto is_tensor_array = [block](const std::string& name) {
return detail::Ref(block->FindRecursiveOrCreateVar(name)).GetType() == return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR_ARRAY; framework::proto::VarDesc::LOD_TENSOR_ARRAY;
}; };
...@@ -146,8 +146,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -146,8 +146,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
std::ostringstream os; std::ostringstream os;
for (auto& each : inputs) { for (auto& each : inputs) {
os << " " << each << " type is " os << " " << each << " type is "
<< detail::Ref(block->FindRecursiveOrCreateVar(each)).GetType() << block->FindRecursiveOrCreateVar(each).GetType() << "\n";
<< "\n";
} }
PADDLE_ENFORCE(all_inputs_are_tensor_array, PADDLE_ENFORCE(all_inputs_are_tensor_array,
"Not all inputs are tensor array:\n%s", os.str()); "Not all inputs are tensor array:\n%s", os.str());
...@@ -158,7 +157,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference { ...@@ -158,7 +157,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
} }
auto out_var_name = op_desc.Output("Out").front(); auto out_var_name = op_desc.Output("Out").front();
auto& out_var = detail::Ref(block->FindRecursiveOrCreateVar(out_var_name)); auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
out_var.SetType(var_type); out_var.SetType(var_type);
auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front())); auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front()));
out_var.SetDataType(in_var.GetDataType()); out_var.SetDataType(in_var.GetDataType());
......
...@@ -106,8 +106,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference { ...@@ -106,8 +106,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
auto x_name = op_desc.Input("X")[0]; auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0]; auto out_name = op_desc.Output("Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY"; VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
auto &out = detail::Ref(block->FindRecursiveOrCreateVar(out_name), auto &out = block->FindRecursiveOrCreateVar(out_name);
"Cannot found %s", out_name);
out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY); out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
auto *x = block->FindVarRecursive(x_name); auto *x = block->FindVarRecursive(x_name);
if (x != nullptr) { if (x != nullptr) {
......
...@@ -52,6 +52,8 @@ struct IsCUDAPlace : public boost::static_visitor<bool> { ...@@ -52,6 +52,8 @@ struct IsCUDAPlace : public boost::static_visitor<bool> {
typedef boost::variant<CUDAPlace, CPUPlace> Place; typedef boost::variant<CUDAPlace, CPUPlace> Place;
using PlaceList = std::vector<Place>;
void set_place(const Place &); void set_place(const Place &);
const Place &get_place(); const Place &get_place();
......
...@@ -231,7 +231,8 @@ void BindVarDsec(py::module &m) { ...@@ -231,7 +231,8 @@ void BindVarDsec(py::module &m) {
.value("FETCH_LIST", proto::VarDesc::FETCH_LIST) .value("FETCH_LIST", proto::VarDesc::FETCH_LIST)
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES) .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE) .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY); .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST);
} }
void BindOpDesc(py::module &m) { void BindOpDesc(py::module &m) {
......
...@@ -448,7 +448,7 @@ class Operator(object): ...@@ -448,7 +448,7 @@ class Operator(object):
no_kernel_op_set = { no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent', 'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send', 'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'get_places', 'parallel_do' 'recv', 'parallel_do'
} }
if type not in no_kernel_op_set: if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc) self.desc.infer_var_type(self.block.desc)
......
...@@ -3,14 +3,14 @@ All util layers. ...@@ -3,14 +3,14 @@ All util layers.
""" """
from ..layer_helper import LayerHelper from ..layer_helper import LayerHelper
from ..framework import Variable from ..framework import unique_name
__all__ = ['get_places'] __all__ = ['get_places']
def get_places(device_count, device_type="CPU"): def get_places(device_count=0, device_type="CPU"):
helper = LayerHelper('get_places', **locals()) helper = LayerHelper('get_places', **locals())
out_places = helper.create_tmp_variable(dtype=helper.input_dtype()) out_places = helper.create_variable(name=unique_name(helper.name + ".out"))
helper.append_op( helper.append_op(
type='get_places', type='get_places',
outputs={"Out": [out_places]}, outputs={"Out": [out_places]},
......
import paddle.v2.fluid as fluid
import decorators
import unittest
# Unit test for the get_places layer: calls it with default arguments, runs
# the default main program on a CPU executor, and checks the type of the
# variable the layer returned.
# NOTE(review): indentation of this block appears to have been stripped by
# the diff rendering; the statements below are kept exactly as shown.
class TestGetPlaces(unittest.TestCase):
# presumably prog_scope gives the test its own program scope -- confirm
# against the decorators module.
@decorators.prog_scope()
def test_get_places(self):
# Call the layer with its default arguments.
places = fluid.layers.get_places()
# Execute the default main program on a CPU place.
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_main_program())
# The variable returned by the layer must be typed as PLACE_LIST.
self.assertEqual(places.type, fluid.core.VarDesc.VarType.PLACE_LIST)
if __name__ == '__main__':
unittest.main()
...@@ -200,6 +200,7 @@ class TestBook(unittest.TestCase): ...@@ -200,6 +200,7 @@ class TestBook(unittest.TestCase):
program = Program() program = Program()
with program_guard(program): with program_guard(program):
x = layers.get_places(device_count=4) x = layers.get_places(device_count=4)
self.assertIsNotNone(x)
print(str(program)) print(str(program))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册