提交 63ff0b4b 编写于 作者: Y Yang Yu

Refine get_places

上级 ed0cf3d6
......@@ -53,12 +53,12 @@ VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {
return it->second.get();
}
VarDesc *BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
VarDesc &BlockDesc::FindRecursiveOrCreateVar(const std::string &name_bytes) {
VarDesc *res = FindVarRecursive(name_bytes);
if (res == nullptr) {
res = Var(name_bytes);
}
return res;
return *res;
}
bool BlockDesc::HasVarRecursive(const std::string &name) const {
......
......@@ -57,7 +57,7 @@ class BlockDesc {
VarDesc *FindVarRecursive(const std::string &name_bytes) const;
VarDesc *FindRecursiveOrCreateVar(const std::string &name_bytes);
VarDesc &FindRecursiveOrCreateVar(const std::string &name_bytes);
bool HasVarRecursive(const std::string &var_name) const;
......
......@@ -21,6 +21,7 @@ limitations under the License. */
#include "paddle/framework/lod_rank_table.h"
#include "paddle/framework/lod_tensor_array.h"
#include "paddle/framework/op_registry.h"
#include "paddle/platform/place.h"
DEFINE_bool(check_nan_inf, false,
"Checking whether operator produce NAN/INF or not. It will be "
......@@ -49,10 +50,13 @@ static void CreateTensor(Variable* var, proto::VarDesc::VarType var_type) {
var->GetMutable<LoDRankTable>();
} else if (var_type == proto::VarDesc::LOD_TENSOR_ARRAY) {
var->GetMutable<LoDTensorArray>();
} else if (var_type == proto::VarDesc::PLACE_LIST) {
var->GetMutable<platform::PlaceList>();
} else {
PADDLE_THROW(
"Variable type %d is not in "
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE]",
"[LoDTensor, SelectedRows, FEED_MINIBATCH, FETCH_LIST, LOD_RANK_TABLE,"
" PLACE_LIST]",
var_type);
}
}
......
......@@ -123,6 +123,7 @@ message VarDesc {
STEP_SCOPES = 5;
LOD_RANK_TABLE = 6;
LOD_TENSOR_ARRAY = 7;
PLACE_LIST = 8;
}
required string name = 1;
required VarType type = 2;
......
......@@ -383,7 +383,7 @@ void OpDesc::InferVarType(BlockDesc *block) const {
for (auto &out_pair : this->outputs_) {
for (auto &out_var_name : out_pair.second) {
block->FindRecursiveOrCreateVar(out_var_name)
->SetType(proto::VarDesc::LOD_TENSOR);
.SetType(proto::VarDesc::LOD_TENSOR);
}
}
}
......
......@@ -12,7 +12,9 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <thread>
#include "paddle/framework/op_registry.h"
#include "paddle/operators/detail/safe_ref.h"
#include "paddle/platform/place.h"
#ifdef PADDLE_WITH_CUDA
#include "paddle/platform/gpu_info.h"
......@@ -21,6 +23,14 @@ limitations under the License. */
namespace paddle {
namespace operators {
static size_t CUDADevCount() {
#ifdef PADDLE_WITH_CUDA
return platform::GetCUDADeviceCount();
#else
return 0UL;
#endif
}
class GetPlacesOp : public framework::OperatorBase {
public:
GetPlacesOp(const std::string &type, const framework::VariableNameMap &inputs,
......@@ -28,28 +38,34 @@ class GetPlacesOp : public framework::OperatorBase {
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
void Run(const framework::Scope &scope,
const platform::DeviceContext &dev_ctx) const override {
const platform::Place &place) const override {
std::string device_type = Attr<std::string>("device_type");
auto device_count = Attr<int>("device_count");
auto device_count = static_cast<size_t>(Attr<int>("device_count"));
if (device_count == 0) {
if (device_type == "CUDA") {
device_count = CUDADevCount();
} else if (device_type == "CPU") {
device_count = std::thread::hardware_concurrency();
}
}
PADDLE_ENFORCE_NE(device_count, 0, "Cannot indicate %s device count",
device_type);
auto out_var_name = Output("Out");
auto *out_var = scope.FindVar(out_var_name);
PADDLE_ENFORCE(out_var != nullptr, "Output variable %s cannot be found",
out_var_name);
auto &places = *(out_var->GetMutable<std::vector<platform::Place>>());
places.resize(device_count);
auto &places =
*(detail::Ref(scope.FindVar(out_var_name),
"Output variable %s cannot be found", out_var_name)
.GetMutable<platform::PlaceList>());
places.reserve(device_count);
if (device_type == "CUDA") {
#ifdef PADDLE_WITH_CUDA
PADDLE_ENFORCE_LT(device_count, platform::GetCUDADeviceCount());
for (int i = 0; i < device_count; i++) {
places.emplace_back(platform::GPUPlace(i));
PADDLE_ENFORCE_LE(device_count, CUDADevCount(),
"Only %d CUDA devices found, cannot set to %d",
CUDADevCount(), device_count);
for (size_t i = 0; i < device_count; ++i) {
places.emplace_back(platform::CUDAPlace(i));
}
#else
PADDLE_THROW("'GPUPlace' is not supported in CPU only device.");
#endif
} else if (device_type == "CPU") {
for (int i = 0; i < device_count; i++) {
for (size_t i = 0; i < device_count; ++i) {
places.emplace_back(platform::CPUPlace());
}
}
......@@ -61,18 +77,38 @@ class GetPlacesOpProtoMaker : public framework::OpProtoAndCheckerMaker {
GetPlacesOpProtoMaker(OpProto *proto, OpAttrChecker *op_checker)
: OpProtoAndCheckerMaker(proto, op_checker) {
AddOutput("Out", "vector of Place");
AddAttr<int>("device_count", "(int)device count").SetDefault(1);
AddAttr<int>("device_count", "device count").SetDefault(1);
AddAttr<std::string>("device_type",
"(string), deivce type can be \"CPU\" and \"CUDA\"")
R"(device type must be in ["CPU", "CUDA"])")
.InEnum({"CPU", "CUDA"});
AddComment(R"DOC(
Returns a list of places based on flags. The list will be used for parallel execution.
Returns a list of places based on flags. The list will be used for parallel
execution.
)DOC");
}
};
class GetPlacesInferVarType : public framework::VarTypeInference {
public:
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o_name : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o_name).SetType(
framework::proto::VarDesc::PLACE_LIST);
}
}
};
class GetPlacesInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
// Do nothing
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker);
REGISTER_OPERATOR(get_places, ops::GetPlacesOp, ops::GetPlacesOpProtoMaker,
ops::GetPlacesInferVarType, ops::GetPlacesInferShape);
......@@ -66,7 +66,7 @@ class LoDRankTableInferVarType : public framework::VarTypeInference {
void operator()(const framework::OpDesc &op_desc,
framework::BlockDesc *block) const override {
for (auto &o : op_desc.Output("Out")) {
block->FindRecursiveOrCreateVar(o)->SetType(
block->FindRecursiveOrCreateVar(o).SetType(
framework::proto::VarDesc::LOD_RANK_TABLE);
}
}
......
......@@ -122,17 +122,17 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
for (auto& name : op_desc.Input("X")) {
VLOG(10) << name << " "
<< block->FindRecursiveOrCreateVar(name)->GetType();
<< block->FindRecursiveOrCreateVar(name).GetType();
}
bool any_input_is_lod_tensor = std::any_of(
inputs.begin(), inputs.end(), [block](const std::string& name) {
return block->FindRecursiveOrCreateVar(name)->GetType() ==
return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR;
});
auto is_tensor_array = [block](const std::string& name) {
return detail::Ref(block->FindRecursiveOrCreateVar(name)).GetType() ==
return block->FindRecursiveOrCreateVar(name).GetType() ==
framework::proto::VarDesc::LOD_TENSOR_ARRAY;
};
......@@ -146,8 +146,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
std::ostringstream os;
for (auto& each : inputs) {
os << " " << each << " type is "
<< detail::Ref(block->FindRecursiveOrCreateVar(each)).GetType()
<< "\n";
<< block->FindRecursiveOrCreateVar(each).GetType() << "\n";
}
PADDLE_ENFORCE(all_inputs_are_tensor_array,
"Not all inputs are tensor array:\n%s", os.str());
......@@ -158,7 +157,7 @@ class SumOpVarTypeInference : public framework::VarTypeInference {
}
auto out_var_name = op_desc.Output("Out").front();
auto& out_var = detail::Ref(block->FindRecursiveOrCreateVar(out_var_name));
auto& out_var = block->FindRecursiveOrCreateVar(out_var_name);
out_var.SetType(var_type);
auto& in_var = detail::Ref(block->FindVarRecursive(inputs.front()));
out_var.SetDataType(in_var.GetDataType());
......
......@@ -106,8 +106,7 @@ class WriteToArrayInferVarType : public framework::VarTypeInference {
auto x_name = op_desc.Input("X")[0];
auto out_name = op_desc.Output("Out")[0];
VLOG(10) << "Set Variable " << out_name << " as LOD_TENSOR_ARRAY";
auto &out = detail::Ref(block->FindRecursiveOrCreateVar(out_name),
"Cannot found %s", out_name);
auto &out = block->FindRecursiveOrCreateVar(out_name);
out.SetType(framework::proto::VarDesc::LOD_TENSOR_ARRAY);
auto *x = block->FindVarRecursive(x_name);
if (x != nullptr) {
......
......@@ -52,6 +52,8 @@ struct IsCUDAPlace : public boost::static_visitor<bool> {
typedef boost::variant<CUDAPlace, CPUPlace> Place;
using PlaceList = std::vector<Place>;
void set_place(const Place &);
const Place &get_place();
......
......@@ -231,7 +231,8 @@ void BindVarDsec(py::module &m) {
.value("FETCH_LIST", proto::VarDesc::FETCH_LIST)
.value("STEP_SCOPES", proto::VarDesc::STEP_SCOPES)
.value("LOD_RANK_TABLE", proto::VarDesc::LOD_RANK_TABLE)
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY);
.value("LOD_TENSOR_ARRAY", proto::VarDesc::LOD_TENSOR_ARRAY)
.value("PLACE_LIST", proto::VarDesc::PLACE_LIST);
}
void BindOpDesc(py::module &m) {
......
......@@ -448,7 +448,7 @@ class Operator(object):
no_kernel_op_set = {
'feed', 'fetch', 'save', 'load', 'recurrent',
'rnn_memory_helper_grad', 'conditional_block', 'while', 'send',
'recv', 'get_places', 'parallel_do'
'recv', 'parallel_do'
}
if type not in no_kernel_op_set:
self.desc.infer_var_type(self.block.desc)
......
......@@ -3,14 +3,14 @@ All util layers.
"""
from ..layer_helper import LayerHelper
from ..framework import Variable
from ..framework import unique_name
__all__ = ['get_places']
def get_places(device_count, device_type="CPU"):
def get_places(device_count=0, device_type="CPU"):
helper = LayerHelper('get_places', **locals())
out_places = helper.create_tmp_variable(dtype=helper.input_dtype())
out_places = helper.create_variable(name=unique_name(helper.name + ".out"))
helper.append_op(
type='get_places',
outputs={"Out": [out_places]},
......
import paddle.v2.fluid as fluid
import decorators
import unittest
class TestGetPlaces(unittest.TestCase):
@decorators.prog_scope()
def test_get_places(self):
places = fluid.layers.get_places()
cpu = fluid.CPUPlace()
exe = fluid.Executor(cpu)
exe.run(fluid.default_main_program())
self.assertEqual(places.type, fluid.core.VarDesc.VarType.PLACE_LIST)
if __name__ == '__main__':
unittest.main()
......@@ -200,6 +200,7 @@ class TestBook(unittest.TestCase):
program = Program()
with program_guard(program):
x = layers.get_places(device_count=4)
self.assertIsNotNone(x)
print(str(program))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册