/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#include <algorithm>
#include <map>
#include <memory>
#include <mutex>  // NOLINT  // for call_once
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>

#include "paddle/fluid/framework/channel.h"
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/init.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/protobuf.h"
#include "paddle/fluid/pybind/pybind.h"  // NOLINT
#include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h"
#include "paddle/fluid/string/to_string.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/cuda_profiler.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif

// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

namespace paddle {
namespace pybind {
bool IsCompiledWithCUDA() {
#ifndef PADDLE_WITH_CUDA
  return false;
#else
  return true;
#endif
}

PYBIND11_PLUGIN(core) {
  py::module m("core", "C++ core of PaddlePaddle");

  // Use paddle::framework only inside this function; since the using-directive
  // is local to the function, it does not pollute the enclosing namespaces.
  using namespace paddle::framework;  // NOLINT

  BindException(&m);

  py::class_<Tensor>(m, "Tensor", py::buffer_protocol())
      .def_buffer(
          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
      .def("get_dims",
           [](const Tensor &self) { return vectorize(self.dims()); })
      .def("set_dims",
           [](Tensor &self, const std::vector<int64_t> &dim) {
             self.Resize(make_ddim(dim));
           })
      .def("set_layout",
           [](Tensor &self, const std::string &layout) {
             self.set_layout(StringToDataLayout(layout));
           })
      .def("alloc_float",
           [](Tensor &self, paddle::platform::CUDAPlace &place) {
             self.mutable_data<float>(place);
           })
      .def("alloc_float",
           [](Tensor &self, paddle::platform::CPUPlace &place) {
             self.mutable_data<float>(place);
           })
      .def("alloc_int",
           [](Tensor &self, paddle::platform::CPUPlace &place) {
             self.mutable_data<int>(place);
           })
      .def("alloc_int",
           [](Tensor &self, paddle::platform::CUDAPlace &place) {
             self.mutable_data<int>(place);
           })
      .def("alloc_int",
           [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
             self.mutable_data<int>(place);
           })
      .def("alloc_float",
           [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
             self.mutable_data<float>(place);
           })
      .def("set", PyCPUTensorSetFromArray<float>)
      .def("set", PyCPUTensorSetFromArray<int>)
      .def("set", PyCPUTensorSetFromArray<double>)
      .def("set", PyCPUTensorSetFromArray<int64_t>)
      .def("set", PyCPUTensorSetFromArray<bool>)
      .def("set", PyCPUTensorSetFromArray<uint16_t>)
      .def("set", PyCPUTensorSetFromArray<uint8_t>)
#ifdef PADDLE_WITH_CUDA
      .def("set", PyCUDATensorSetFromArray<float>)
      .def("set", PyCUDATensorSetFromArray<int>)
      .def("set", PyCUDATensorSetFromArray<double>)
      .def("set", PyCUDATensorSetFromArray<int64_t>)
      .def("set", PyCUDATensorSetFromArray<bool>)
      .def("set", PyCUDATensorSetFromArray<uint16_t>)
      .def("set", PyCUDATensorSetFromArray<uint8_t>)
      .def("set", PyCUDAPinnedTensorSetFromArray<float>)
      .def("set", PyCUDAPinnedTensorSetFromArray<int>)
      .def("set", PyCUDAPinnedTensorSetFromArray<double>)
      .def("set", PyCUDAPinnedTensorSetFromArray<int64_t>)
      .def("set", PyCUDAPinnedTensorSetFromArray<bool>)
      .def("set", PyCUDAPinnedTensorSetFromArray<uint16_t>)
      .def("set", PyCUDAPinnedTensorSetFromArray<uint8_t>)
#endif
      .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
      .def("set_float_element", TensorSetElement<float>)
      .def("get_float_element", TensorGetElement<float>)
      .def("set_double_element", TensorSetElement<double>)
      .def("get_double_element", TensorGetElement<double>)
      .def("dtype", [](Tensor &self) { return ToDataType(self.type()); });

  py::class_<LoDTensor, Tensor>(m, "LoDTensor")
      .def_buffer(
          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
      .def("__init__",
           [](LoDTensor &instance,
              const std::vector<std::vector<size_t>> &lod) {
             LoD new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
             new (&instance) LoDTensor(new_lod);
           })
      .def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
      .def("set_lod",
           [](LoDTensor &self, const std::vector<std::vector<size_t>> &lod) {
             LoD new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
             self.set_lod(new_lod);
           })
      .def("lod", [](LoDTensor &self) -> std::vector<std::vector<size_t>> {
        auto lod = self.lod();
        std::vector<std::vector<size_t>> new_lod;
        new_lod.reserve(lod.size());
        std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
        return new_lod;
      });
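  // SelectedRows stores a subset of rows of a dense tensor: the dense values
  // (get_tensor), the indices of the selected rows (rows/set_rows), and the
  // logical height of the full tensor (height/set_height).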
  py::class_<SelectedRows>(m, "SelectedRows")
      .def("__init__",
           [](SelectedRows &instance) { new (&instance) SelectedRows(); })
      .def("__init__",
           [](SelectedRows &instance, const std::vector<int64_t> rows,
              const int64_t &height) {
             new (&instance) SelectedRows(rows, height);
           })
      .def("get_tensor",
           [](SelectedRows &self) { return self.mutable_value(); },
           py::return_value_policy::reference)
      .def("set_height", &SelectedRows::set_height)
      .def("height", &SelectedRows::height)
      .def("set_rows",
           [](SelectedRows &self, std::vector<int64_t> rows) {
#ifndef PADDLE_WITH_CUDA
             self.set_rows(rows);
#else
             Vector<int64_t> new_rows(rows);
             self.set_rows(new_rows);
#endif
           })
      .def("rows", [](SelectedRows &self) {
#ifndef PADDLE_WITH_CUDA
        return self.rows();
#else
        auto rows = self.rows();
        std::vector<int64_t> new_rows;
        new_rows.reserve(rows.size());
        std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
        return new_rows;
#endif
      });

  py::class_<Variable>(m, "Variable", R"DOC(Variable Class.

All parameters, weights, and gradients are variables in Paddle.
)DOC")
      .def("is_int", [](const Variable &var) { return var.IsType<int>(); })
      .def("set_int",
           [](Variable &var, int val) -> void { *var.GetMutable<int>() = val; })
      .def("get_int", [](const Variable &var) -> int { return var.Get<int>(); })
      .def("is_float", [](const Variable &var) { return var.IsType<float>(); })
      .def("set_float",
           [](Variable &var, float val) -> void {
             *var.GetMutable<float>() = val;
           })
      .def("get_float",
           [](const Variable &var) -> float { return var.Get<float>(); })
      .def("get_tensor",
           [](Variable &self) -> LoDTensor * {
             return self.GetMutable<LoDTensor>();
           },
           py::return_value_policy::reference)
      .def("get_lod_rank_table",
           [](Variable &self) { return self.GetMutable<LoDRankTable>(); },
           py::return_value_policy::reference)
      .def("get_selected_rows",
           [](Variable &self) -> SelectedRows * {
             return self.GetMutable<SelectedRows>();
           },
           py::return_value_policy::reference)
      .def("get_lod_tensor_array",
           [](Variable &self) { return self.GetMutable<LoDTensorArray>(); },
           py::return_value_policy::reference)
#ifdef PADDLE_WITH_CUDA
      .def("get_communicator",
           [](Variable &self) -> platform::Communicator * {
             return self.GetMutable<platform::Communicator>();
           },
           py::return_value_policy::reference)
#endif
      .def("get_reader",
           [](Variable &self) -> framework::ReaderHolder * {
             PADDLE_ENFORCE(self.IsType<framework::ReaderHolder>());
             return self.GetMutable<framework::ReaderHolder>();
           },
           py::return_value_policy::reference);

  py::class_<framework::ReaderHolder>(m, "Reader", "")
      .def("reset", &framework::ReaderHolder::ReInit);

  py::class_<Scope>(m, "Scope", "")
      .def("var",
           [](Scope &self, const std::string &name) -> Variable * {
             return self.Var(name);
           },
           py::return_value_policy::reference)
      .def("find_var", &Scope::FindVar, py::return_value_policy::reference)
      .def(py::init<>())
      .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
           py::return_value_policy::reference)
      .def("drop_kids", &Scope::DropKids);

  //! @note: Be careful! PyBind returns std::string as a unicode object, not a
  //! Python str. If you want a str object, cast it on the Python side.
  m.def("get_all_op_protos", []() -> std::vector<py::bytes> {
    std::vector<py::bytes> ret_values;
    for (auto &iter : OpInfoMap::Instance().map()) {
      auto &info = iter.second;
      if (info.HasOpProtoAndChecker()) {
        std::string str;
        PADDLE_ENFORCE(
            info.Proto().SerializeToString(&str),
            "Serialize OpProto Error. This could be a bug of Paddle.");
        ret_values.emplace_back(str);
      }
    }
    return ret_values;
  });

  m.def(
      "get_grad_op_desc",
      [](const OpDesc &op_desc,
         const std::unordered_set<std::string> &no_grad_set,
         const std::vector<BlockDesc *> &grad_sub_block) {
        std::unordered_map<std::string, std::string> grad_to_var;
        std::vector<std::unique_ptr<OpDesc>> grad_op_descs =
            framework::OpInfoMap::Instance()
                .Get(op_desc.Type())
                .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
                               grad_sub_block);
        std::vector<OpDesc *> grad_op_desc_ptrs(grad_op_descs.size());
        std::transform(grad_op_descs.begin(), grad_op_descs.end(),
                       grad_op_desc_ptrs.begin(),
                       [](std::unique_ptr<OpDesc> &p) { return p.release(); });
        return std::make_pair(grad_op_desc_ptrs, grad_to_var);
      });

  m.def("prune", [](const ProgramDesc &origin,
                    const std::vector<std::array<size_t, 2>> &targets) {
    ProgramDesc prog_with_targets(origin);
    for (const auto &t : targets) {
      prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
    }
    proto::ProgramDesc pruned_desc;
    Prune(*prog_with_targets.Proto(), &pruned_desc);
    return new ProgramDesc(pruned_desc);
  });

  m.def("inference_optimize", [](ProgramDesc &origin) {
    proto::ProgramDesc pruned_desc;
    InferenceOptimize(*(origin.Proto()), &pruned_desc);
    return new ProgramDesc(pruned_desc);
  });

  m.def("empty_var_name", []() { return framework::kEmptyVarName; });
  m.def("grad_var_suffix", []() { return framework::kGradVarSuffix; });

  m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")
      .def("empty", []() { return kEmptyVarName; })
      .def("temp", []() { return kTempVarName; });

  // clang-format off
  py::class_<paddle::platform::DeviceContext>(m, "DeviceContext")
      .def_static("create",
                  [](paddle::platform::CPUPlace& place)
                      -> paddle::platform::DeviceContext* {
                    return new paddle::platform::CPUDeviceContext();
                  })
      .def_static("create",
                  [](paddle::platform::CUDAPlace& place)
                      -> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
                    PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#else
                    return new paddle::platform::CUDADeviceContext(place);
#endif
                  })
      .def_static("create",
                  [](paddle::platform::CUDAPinnedPlace& place)
                      -> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
                    PADDLE_THROW(
                        "CUDAPinnedPlace is not supported in CPU device.");
#else
                    return new paddle::platform::CUDAPinnedDeviceContext(place);
#endif
                  });
  // clang-format on

#ifdef PADDLE_WITH_CUDA
  py::class_<platform::Communicator>(m, "Communicator").def(py::init<>());
#endif

  py::class_<platform::CUDAPlace>(m, "CUDAPlace")
      .def(py::init<int>())
      .def("__str__", string::to_string<const platform::CUDAPlace &>);

  py::class_<platform::CPUPlace>(m, "CPUPlace")
      .def(py::init<>())
      .def("__str__", string::to_string<const platform::CPUPlace &>);

  py::class_<platform::CUDAPinnedPlace>(m, "CUDAPinnedPlace")
      .def(py::init<>())
      .def("__str__", string::to_string<const platform::CUDAPinnedPlace &>);

  py::class_<platform::Place>(m, "Place")
      .def(py::init<>())
      .def("set_place",
           [](platform::Place &self, const platform::CPUPlace &cpu_place) {
             self = cpu_place;
           })
      .def("set_place",
           [](platform::Place &self, const platform::CUDAPlace &gpu_place) {
             self = gpu_place;
           })
      .def("set_place",
           [](platform::Place &self,
              const platform::CUDAPinnedPlace &cuda_pinned_place) {
             self = cuda_pinned_place;
           });
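  // An Operator is created from a serialized proto::OpDesc and can be run
  // directly against a Scope on a given place (CPU, CUDA, or CUDA pinned
  // memory), independently of an Executor.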
  py::class_<OperatorBase>(m, "Operator")
      .def_static("create",
                  [](py::bytes protobin) {
                    proto::OpDesc desc;
                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                   "Cannot parse user input to OpDesc");
                    PADDLE_ENFORCE(desc.IsInitialized(),
                                   "User OpDesc is not initialized, reason %s",
                                   desc.InitializationErrorString());
                    return OpRegistry::CreateOp(desc);
                  })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CPUPlace &place) { self.Run(scope, place); })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CUDAPlace &place) { self.Run(scope, place); })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CUDAPinnedPlace &place) {
             self.Run(scope, place);
           })
      .def("type",
           [](const OperatorBase &op) -> std::string { return op.Type(); })
      .def("outputs",
           [](const OperatorBase &op)
               -> std::map<std::string, std::vector<std::string>> {
             return op.Outputs();
           })
      .def("output_vars",
           [](const OperatorBase &op) { return op.OutputVars(true); })
      .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
      .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
      .def("__str__", &OperatorBase::DebugString)
      .def("no_intermediate_outputs",
           [](const OperatorBase &op) { return op.OutputVars(false); })
      .def("support_gpu", &OperatorBase::SupportGPU);

  py::class_<Executor>(m, "Executor")
      .def(py::init<const platform::Place &>())
      .def("run",
           (void (Executor::*)(const ProgramDesc &, Scope *, int, bool, bool)) &
               Executor::Run);

  m.def("init_gflags", framework::InitGflags);
  m.def("init_glog", framework::InitGLOG);
  m.def("init_devices",
        [](bool init_p2p) { framework::InitDevices(init_p2p); });

  m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
#ifdef PADDLE_WITH_CUDA
  m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
    // Only GPUs with Compute Capability >= 53 support float16.
    return platform::GetCUDAComputeCapability(place.device) >= 53;
  });
#endif

  m.def("set_feed_variable", framework::SetFeedVariable);
  m.def("get_fetch_variable", framework::GetFetchVariable);

  BindProgramDesc(&m);
  BindBlockDesc(&m);
  BindVarDsec(&m);
  BindOpDesc(&m);
  BindConstValue(&m);

  py::class_<LoDRankTable>(m, "LodRankTable")
      .def("items", [](framework::LoDRankTable &table) {
        std::vector<std::pair<size_t, size_t>> res;
        for (auto &item : table.items()) {
          res.push_back({item.index, item.length});
        }
        return res;
      });

  py::class_<LoDTensorArray>(m, "LoDTensorArray")
      .def("__getitem__",
           [](LoDTensorArray &self, size_t i) { return &self.at(i); },
           py::return_value_policy::reference)
      .def("__len__", [](LoDTensorArray &self) { return self.size(); })
      .def("__setitem__",
           [](LoDTensorArray &self, size_t i, const LoDTensor &t) {
             PADDLE_ENFORCE_LT(i, self.size());
             self[i].ShareDataWith(t);
             self[i].set_lod(t.lod());
           })
      .def("append", [](LoDTensorArray &self, const LoDTensor &t) {
        self.emplace_back();
        self.back().ShareDataWith(t);
        self.back().set_lod(t.lod());
      });

  m.def("IsInplace",
        [](std::string op) -> bool { return operators::IsInplace(op); });

  m.def("op_support_gpu", OpSupportGPU);
#ifdef PADDLE_WITH_CUDA
  m.def("get_cuda_device_count", platform::GetCUDADeviceCount);

  m.def("nvprof_init", platform::CudaProfilerInit);
  m.def("nvprof_start", platform::CudaProfilerStart);
  m.def("nvprof_stop", platform::CudaProfilerStop);
#endif

  py::enum_<platform::ProfilerState>(m, "ProfilerState", py::arithmetic())
      .value("kDisabled", platform::ProfilerState::kDisabled)
      .value("kCPU", platform::ProfilerState::kCPU)
      .value("kCUDA", platform::ProfilerState::kCUDA)
      .value("kAll", platform::ProfilerState::kAll)
      .export_values();

  py::enum_<platform::EventSortingKey>(m, "EventSortingKey", py::arithmetic())
      .value("kDefault", platform::EventSortingKey::kDefault)
      .value("kCalls", platform::EventSortingKey::kCalls)
      .value("kTotal", platform::EventSortingKey::kTotal)
      .value("kMin", platform::EventSortingKey::kMin)
      .value("kMax", platform::EventSortingKey::kMax)
      .value("kAve", platform::EventSortingKey::kAve)
      .export_values();

  m.def("enable_profiler", platform::EnableProfiler);
  m.def("disable_profiler", platform::DisableProfiler);
  m.def("is_profiler_enabled", platform::IsProfileEnabled);
  m.def("reset_profiler", platform::ResetProfiler);

  // -- python binds for parallel executor.
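  // ExecutionStrategy and BuildStrategy are exposed as nested classes of
  // ParallelExecutor: ExecutionStrategy tunes runtime behaviour (thread count,
  // event use, op delay, how often local scopes are dropped), while
  // BuildStrategy controls how the multi-device graph is built (gradient
  // reduction and gradient scaling strategies).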
py::class_ pe(m, "ParallelExecutor"); py::class_(pe, "ExecutionStrategy") .def(py::init()) .def_property( "num_threads", [](const ExecutionStrategy &self) { return self.num_threads_; }, [](ExecutionStrategy &self, size_t num_threads) { self.num_threads_ = num_threads; }) .def_property( "use_event", [](const ExecutionStrategy &self) { return self.use_event_; }, [](ExecutionStrategy &self, bool use_event) { self.use_event_ = use_event; }) .def_property( "allow_op_delay", [](const ExecutionStrategy &self) { return self.allow_op_delay_; }, [](ExecutionStrategy &self, bool allow_op_delay) { self.allow_op_delay_ = allow_op_delay; }) .def_property( "num_iteration_per_drop_scope", [](const ExecutionStrategy &self) { return self.num_iteration_per_drop_scope_; }, [](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) { self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope; }); py::class_ build_strategy(pe, "BuildStrategy"); py::enum_(build_strategy, "ReduceStrategy") .value("Reduce", BuildStrategy::ReduceStrategy::kReduce) .value("AllReduce", BuildStrategy::ReduceStrategy::kAllReduce); py::enum_(build_strategy, "GradientScaleStrategy") .value("CoeffNumDevice", BuildStrategy::GradientScaleStrategy::kCoeffNumDevice) .value("One", BuildStrategy::GradientScaleStrategy::kOne) .value("Customized", BuildStrategy::GradientScaleStrategy::kCustomized); build_strategy.def(py::init()) .def_property( "reduce_strategy", [](const BuildStrategy &self) { return self.reduce_; }, [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) { self.reduce_ = strategy; }) .def_property( "gradient_scale_strategy", [](const BuildStrategy &self) { return self.gradient_scale_; }, [](BuildStrategy &self, BuildStrategy::GradientScaleStrategy strategy) { self.gradient_scale_ = strategy; }); pe.def(py::init &, const std::unordered_set &, const std::unordered_set &, const ProgramDesc &, const std::string &, Scope *, std::vector &, const ExecutionStrategy &, const BuildStrategy &, size_t, size_t>()) .def("bcast_params", &ParallelExecutor::BCastParamsToGPUs) // NOTE: even we return a vec* to Python use reference policy. // We still cannot get local_scope from this vector, since the element // of vec will be freed by Python GC. We can only return Scope* // one by one and mark them as reference. .def("local_scopes", [](ParallelExecutor &self) -> std::vector * { return &self.GetLocalScopes(); }, py::return_value_policy::reference) .def("feed_tensors_into_local_scopes", &ParallelExecutor::FeedTensorsIntoLocalScopes) .def("feed_and_split_tensor_into_local_scopes", &ParallelExecutor::FeedAndSplitTensorIntoLocalScopes) .def("run", &ParallelExecutor::Run); BindRecordIOWriter(&m); return m.ptr(); } } // namespace pybind } // namespace paddle