Unverified commit 219fbd51, authored by Yu Yang and committed by GitHub

Merge pull request #6732 from QiJune/get_places_op

add GetPlaces operator
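
In short, this PR adds a `get_places` operator, a new `PLACE_LIST` variable type to hold its output, and a `fluid.layers.get_places` Python wrapper; the resulting place list is meant to drive parallel execution later. A minimal usage sketch, pieced together from the layer and tests added below (the concrete count is illustrative):

import paddle.v2.fluid as fluid

# Ask for two CPU places; passing device_count=0 would auto-detect instead.
places = fluid.layers.get_places(device_count=2, device_type="CPU")

# The output variable holds a PLACE_LIST, not a tensor.
assert places.type == fluid.core.VarDesc.VarType.PLACE_LIST

# Running the program executes the get_places op and fills the list.
exe = fluid.Executor(fluid.CPUPlace())
exe.run(fluid.default_main_program())
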
@@ -53,12 +53,12 @@ VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
   return it->second.get();
 }
 
-VarDesc *BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
+VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
   VarDesc *res = FindVarRecursive(name_bytes);
   if (res == nullptr) {
     res = Var(name_bytes);
   }
-  return res;
+  return *res;
 }
 
 bool BlockDesc::HasVarRecursive(const std::string &name) const {

@@ -57,7 +57,7 @@ class BlockDesc {
   VarDesc *FindVarRecursive(const std::string &name_bytes) const;
 
-  VarDesc *FindRecursiveOrCreateVar(const std::string &name_bytes);
+  VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
 
   bool HasVarRecursive(const std::string &var_name) const;

@@ -21,6 +21,7 @@ limitations under the License. */
 #include "paddle/framework/lod_rank_table.h"
 #include "paddle/framework/lod_tensor_array.h"
 #include "paddle/framework/op_registry.h"
+#include "paddle/platform/place.h"
 
 DEFINE_bool(check_nan_inf, false,
             "Checking whether operator produce NAN/INF or not. It will be "
@@ -49,10 +50,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
     var->GetMutable<LoDRankTable>();
   } else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
     var->GetMutable<LoDTensorArray>();
+  } else if (var_type == proto::VarDesc::PLACE_LIST) {
+    var->GetMutable<platform::PlaceList>();
   } else {
     PADDLE_THROW(
         "Variable type %d is not in "
-        "[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE]",
+        "[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE,"
+        " PLACE_LIST]",
         var_type);
   }
 }

@@ -123,6 +123,7 @@ message VarDesc {
     STEP_SCOPES = 5;
     LOD_RANK_TABLE = 6;
     LOD_TENSOR_ARRAY = 7;
+    PLACE_LIST = 8;
   }
   required string name = 1;
   required VarType type = 2;

@@ -384,7 +384,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
   for (auto &out_pair : this->outputs_) {
     for (auto &out_var_name : out_pair.second) {
       block->FindRecursiveOrCreateVar(out_var_name)
-          ->SetType(proto::VarDesc::LOD_TENSOR);
+          .SetType(proto::VarDesc::LOD_TENSOR);
     }
   }
 }

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thread>

#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/gpu_info.h"
#endif

namespace paddle {
namespace operators {

// Number of visible CUDA devices; 0 when Paddle is built without CUDA.
static size_t CUDADevCount() {
#ifdef PADDLE_WITH_CUDA
  return platform::GetCUDADeviceCount();
#else
  return 0UL;
#endif
}

class GetPlacesOp : public framework::OperatorBase {
 public:
  GetPlacesOp(const std::string &type, const framework::VariableNameMap &inputs,
              const framework::VariableNameMap &outputs,
              const framework::AttributeMap &attrs)
      : OperatorBase(type, inputs, outputs, attrs) {}

  void Run(const framework::Scope &scope,
           const platform::Place &place) const override {
    std::string device_type = Attr<std::string>("device_type");
    auto device_count = static_cast<size_t>(Attr<int>("device_count"));
    if (device_count == 0) {
      // 0 means auto-detect: use every available device of the given type.
      if (device_type == "CUDA") {
        device_count = CUDADevCount();
      } else if (device_type == "CPU") {
        device_count = std::thread::hardware_concurrency();
      }
    }
    PADDLE_ENFORCE_NE(device_count, 0, "Cannot indicate %s device count",
                      device_type);

    auto out_var_name = Output("Out");
    auto &places =
        *(detail::Ref(scope.FindVar(out_var_name),
                      "Output variable %s cannot be found", out_var_name)
              .GetMutable<platform::PlaceList>());
    places.reserve(device_count);
    if (device_type == "CUDA") {
      PADDLE_ENFORCE_LE(device_count, CUDADevCount(),
                        "Only %d CUDA devices found, cannot set to %d",
                        CUDADevCount(), device_count);
      for (size_t i = 0; i < device_count; ++i) {
        places.emplace_back(platform::CUDAPlace(i));
      }
    } else if (device_type == "CPU") {
      for (size_t i = 0; i < device_count; ++i) {
        places.emplace_back(platform::CPUPlace());
      }
    }
  }
};

class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
 public:
  GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
      : OpProtoAndCheckerMaker(proto, op_checker) {
    AddOutput("Out", "vector of Place");
    AddAttr<int>("device_count", "device count").SetDefault(1);
    AddAttr<std::string>("device_type",
                         R"(device type must be in ["CPU", "CUDA"])")
        .InEnum({"CPU", "CUDA"});
    AddComment(R"DOC(
Returns a list of places based on flags. The list will be used for parallel
execution.
)DOC");
  }
};

class GetPlacesInferVarType : public framework::VarTypeInference {
 public:
  void operator()(const framework::OpDesc &op_desc,
                  framework::BlockDesc *block) const override {
    for (auto &o_name : op_desc.Output("Out")) {
      block->FindRecursiveOrCreateVar(o_name).SetType(
          framework::proto::VarDesc::PLACE_LIST);
    }
  }
};

class GetPlacesInferShape : public framework::InferShapeBase {
 public:
  void operator()(framework::InferShapeContext *context) const override {
    // Do nothing: the output is a place list, so there is no shape to infer.
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;

REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker,
                  ops::GetPlacesInferVarType, ops::GetPlacesInferShape);
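
Note the convention in `Run` above: `device_count == 0` means auto-detect, falling back to `CUDADevCount()` for CUDA or `std::thread::hardware_concurrency()` for CPU, and `PADDLE_ENFORCE_NE` rejects the request if the count is still zero. A small Python sketch of the same fallback rule (a hypothetical helper for illustration, not part of this PR):

import multiprocessing


def resolve_device_count(device_count, device_type, cuda_device_count):
    """Mirrors GetPlacesOp: a count of 0 means 'use all devices of this type'."""
    if device_count == 0:
        if device_type == "CUDA":
            device_count = cuda_device_count
        elif device_type == "CPU":
            device_count = multiprocessing.cpu_count()
    if device_count == 0:
        raise ValueError("Cannot determine %s device count" % device_type)
    return device_count


# E.g. on a machine without CUDA devices:
assert resolve_device_count(0, "CPU", 0) == multiprocessing.cpu_count()
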
@@ -66,7 +66,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
   void operator()(const framework::OpDesc &op_desc,
                   framework::BlockDesc *block) const override {
     for (auto &o : op_desc.Output("Out")) {
-      block->FindRecursiveOrCreateVar(o)->SetType(
+      block->FindRecursiveOrCreateVar(o).SetType(
           framework::proto::VarDesc::LOD_RANK_TABLE);
     }
   }

@@ -122,17 +122,17 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     for (auto& name : op_desc.Input("X")) {
       VLOG(10) << name << " "
-               << block->FindRecursiveOrCreateVar(name)->GetType();
+               << block->FindRecursiveOrCreateVar(name).GetType();
     }
 
     bool any_input_is_lod_tensor = std::any_of(
         inputs.begin(), inputs.end(), [block](const std::string& name) {
-          return block->FindRecursiveOrCreateVar(name)->GetType() ==
+          return block->FindRecursiveOrCreateVar(name).GetType() ==
                  framework::proto::VarDesc::LOD_TENSOR;
         });
 
     auto is_tensor_array = [block](const std::string& name) {
-      return detail::Ref(block->FindRecursiveOrCreateVar(name)).GetType() ==
+      return block->FindRecursiveOrCreateVar(name).GetType() ==
              framework::proto::VarDesc::LOD_TENSOR_ARRAY;
     };
@@ -146,8 +146,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
       std::ostringstream os;
       for (auto& each : inputs) {
         os << " " << each << " type is "
-           << detail::Ref(block->FindRecursiveOrCreateVar(each)).GetType()
-           << "\n";
+           << block->FindRecursiveOrCreateVar(each).GetType() << "\n";
       }
       PADDLE_ENFORCE(all_inputs_are_tensor_array,
                      "Not all inputs are tensor array:\n%s", os.str());
@@ -158,7 +157,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
     }
 
     auto out_var_name = op_desc.Output("Out").front();
-    auto& out_var = detail::Ref(block->FindRecursiveOrCreateVar(out_var_name));
+    auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
     out_var.SetType(var_type);
     auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front()));
     out_var.SetDataType(in_var.GetDataType());

@@ -106,8 +106,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
     auto x_name = op_desc.Input("X")[0];
     auto out_name = op_desc.Output("Out")[0];
     VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
-    auto &out = detail::Ref(block->FindRecursiveOrCreateVar(out_name),
-                            "Cannot found %s", out_name);
+    auto &out = block->FindRecursiveOrCreateVar(out_name);
     out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
     auto *x = block->FindVarRecursive(x_name);
     if (x != nullptr) {

@@ -52,6 +52,8 @@ struct IsCUDAPlace : public boost::static_visitor<bool> {
 
 typedef boost::variant<CUDAPlace, CPUPlace> Place;
 
+using PlaceList = std::vector<Place>;
+
 void set_place(const Place &);
 const Place &get_place();

@@ -231,7 +231,8 @@ void BindVarDsec(py::module &m) {
       .value("FETCH_LIST", proto::VarDesc::FETCH_LIST)
       .value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
       .value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
-      .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY);
+      .value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
+      .value("PLACE_LIST", proto::VarDesc::PLACE_LIST);
 }
 
 void BindOpDesc(py::module &m) {

@@ -8,6 +8,8 @@ import tensor
 from tensor import *
 import control_flow
 from control_flow import *
+import device
+from device import *
 
 __all__ = []
 __all__ += nn.__all__
@@ -15,3 +17,4 @@ __all__ += io.__all__
 __all__ += tensor.__all__
 __all__ += control_flow.__all__
 __all__ += ops.__all__
+__all__ += device.__all__

"""
All util layers.
"""
from ..layer_helper import LayerHelper
from ..framework import unique_name
__all__ = ['get_places']
def get_places(device_count=0, device_type="CPU"):
helper = LayerHelper('get_places', **locals())
out_places = helper.create_variable(name=unique_name(helper.name + ".out"))
helper.append_op(
type='get_places',
outputs={"Out": [out_places]},
attrs={
"device_type": device_type,
'device_count': device_count,
})
return out_places
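
For reference, a quick sketch of what this layer emits when called inside a program (the `program_guard` import mirrors the `test_layers.py` change below; the printed text is indicative only):

import paddle.v2.fluid as fluid
from paddle.v2.fluid.framework import Program, program_guard

prog = Program()
with program_guard(prog):
    places = fluid.layers.get_places(device_count=4, device_type="CPU")

# Printing the program should show a single get_places op whose
# Out variable has type PLACE_LIST.
print(str(prog))
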
import paddle.v2.fluid as fluid
import decorators
import unittest


class TestGetPlaces(unittest.TestCase):
    @decorators.prog_scope()
    def test_get_places(self):
        places = fluid.layers.get_places()
        cpu = fluid.CPUPlace()
        exe = fluid.Executor(cpu)
        exe.run(fluid.default_main_program())
        self.assertEqual(places.type, fluid.core.VarDesc.VarType.PLACE_LIST)


if __name__ == '__main__':
    unittest.main()
@@ -196,6 +196,13 @@ class TestBook(unittest.TestCase):
         self.assertIsNotNone(layers.sequence_softmax(x=seq))
         print(str(program))
 
+    def test_get_places(self):
+        program = Program()
+        with program_guard(program):
+            x = layers.get_places(device_count=4)
+            self.assertIsNotNone(x)
+        print(str(program))
+
 if __name__ == '__main__':
     unittest.main()