/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

// NOTE(review): in this copy of the file the operands of the following
// #include directives are missing, and many template argument lists further
// down appear to be missing as well (e.g. `py::class_(m, ...)`,
// `std::vector &dim`, `py::init()`). All code tokens are kept exactly as
// found; restore the missing pieces from version control before building.
#include
#include
#include
#include
#include  // NOLINT // for call_once
#include
#include
#include
#include

#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
#include "paddle/fluid/framework/lod_tensor_array.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/parallel_executor.h"
#include "paddle/fluid/framework/prune.h"
#include "paddle/fluid/framework/reader.h"
#include "paddle/fluid/framework/selected_rows.h"
#include "paddle/fluid/framework/version.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/operators/reader/lod_tensor_blocking_queue.h"
#include "paddle/fluid/platform/enforce.h"
#include "paddle/fluid/platform/init.h"
#include "paddle/fluid/platform/place.h"
#include "paddle/fluid/platform/profiler.h"
#include "paddle/fluid/pybind/const_value.h"
#include "paddle/fluid/pybind/exception.h"
#include "paddle/fluid/pybind/protobuf.h"
#include "paddle/fluid/pybind/pybind.h"  // NOLINT
#include "paddle/fluid/pybind/recordio.h"
#include "paddle/fluid/pybind/tensor_py.h"

#include "paddle/fluid/string/to_string.h"

#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/operators/nccl/nccl_gpu_common.h"
#include "paddle/fluid/platform/cuda_profiler.h"
#include "paddle/fluid/platform/gpu_info.h"
#endif

#include "pybind11/stl.h"

// disable auto conversion to list in Python
PYBIND11_MAKE_OPAQUE(paddle::framework::LoDTensorArray);

namespace paddle {
namespace pybind {

// Returns true when this file was compiled with PADDLE_WITH_CUDA defined,
// false otherwise.
bool IsCompiledWithCUDA() {
#ifndef PADDLE_WITH_CUDA
  return false;
#else
  return true;
#endif
}

// Returns true when this file was compiled with PADDLE_WITH_DISTRIBUTE
// defined, false otherwise.
bool IsCompiledWithDIST() {
#ifdef PADDLE_WITH_DISTRIBUTE
  return true;
#else
  return false;
#endif
}

// Entry point of the Python extension module named "core". Every class,
// enum and free function exposed to Python is registered on `m` below, and
// the module object is returned via m.ptr().
PYBIND11_PLUGIN(core) {
  py::module m("core", "C++ core of PaddlePaddle");

  // using framework in this function. Since it is inside a function, it will
  // not cause namespace pollution.
  using namespace paddle::framework;  // NOLINT

  BindException(&m);

  // ---- Tensor -------------------------------------------------------------
  // Registered with py::buffer_protocol() and a def_buffer callback that
  // returns CastToPyBuffer(self).
  py::class_(m, "Tensor", py::buffer_protocol())
      .def_buffer(
          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
      .def("_get_dims",
           [](const Tensor &self) { return vectorize(self.dims()); })
      .def("_set_dims",
           [](Tensor &self, const std::vector &dim) {
             self.Resize(make_ddim(dim));
           })
      .def("_set_layout",
           [](Tensor &self, const std::string &layout) {
             self.set_layout(StringToDataLayout(layout));
           })
      // _alloc_float / _alloc_int are overloaded on the place type
      // (CPUPlace, CUDAPlace, CUDAPinnedPlace); each calls mutable_data on
      // the given place.
      .def("_alloc_float",
           [](Tensor &self, paddle::platform::CUDAPlace &place) {
             self.mutable_data(place);
           })
      .def("_alloc_float",
           [](Tensor &self, paddle::platform::CPUPlace &place) {
             self.mutable_data(place);
           })
      .def("_alloc_int",
           [](Tensor &self, paddle::platform::CPUPlace &place) {
             self.mutable_data(place);
           })
      .def("_alloc_int",
           [](Tensor &self, paddle::platform::CUDAPlace &place) {
             self.mutable_data(place);
           })
      .def("_alloc_int",
           [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
             self.mutable_data(place);
           })
      .def("_alloc_float",
           [](Tensor &self, paddle::platform::CUDAPinnedPlace &place) {
             self.mutable_data(place);
           })
      // Eight "set" overloads backed by PyCPUTensorSetFromArray.
      // NOTE(review): as written the eight registrations are textually
      // identical — presumably each had a distinct element-type template
      // argument; confirm against version control.
      .def("set", PyCPUTensorSetFromArray)
      .def("set", PyCPUTensorSetFromArray)
      .def("set", PyCPUTensorSetFromArray)
      .def("set", PyCPUTensorSetFromArray)
      .def("set", PyCPUTensorSetFromArray)
      .def("set", PyCPUTensorSetFromArray)
      .def("set", PyCPUTensorSetFromArray)
      .def("set", PyCPUTensorSetFromArray)
#ifdef PADDLE_WITH_CUDA
      // CUDA builds additionally register "set" overloads backed by
      // PyCUDATensorSetFromArray and PyCUDAPinnedTensorSetFromArray.
      .def("set", PyCUDATensorSetFromArray)
      .def("set", PyCUDATensorSetFromArray)
      .def("set", PyCUDATensorSetFromArray)
      .def("set", PyCUDATensorSetFromArray)
      .def("set", PyCUDATensorSetFromArray)
      .def("set", PyCUDATensorSetFromArray)
      .def("set", PyCUDATensorSetFromArray)
      .def("set", PyCUDATensorSetFromArray)
      .def("set", PyCUDAPinnedTensorSetFromArray)
      .def("set", PyCUDAPinnedTensorSetFromArray)
      .def("set", PyCUDAPinnedTensorSetFromArray)
      .def("set", PyCUDAPinnedTensorSetFromArray)
      .def("set", PyCUDAPinnedTensorSetFromArray)
      .def("set", PyCUDAPinnedTensorSetFromArray)
      .def("set", PyCUDAPinnedTensorSetFromArray)
      .def("set", PyCUDAPinnedTensorSetFromArray)
#endif
      .def("shape", [](Tensor &self) { return vectorize(self.dims()); })
      .def("_set_float_element", TensorSetElement)
      .def("_get_float_element", TensorGetElement)
      .def("_set_double_element", TensorSetElement)
      .def("_get_double_element", TensorGetElement)
      .def("_dtype", [](Tensor &self) { return ToDataType(self.type()); });

  // ---- LoDTensor ----------------------------------------------------------
  // NOTE(review): the docstring below is a runtime string and is left
  // untouched; it contains the typo "equivlent".
  py::class_(m, "LoDTensor", R"DOC( LoDTensor is a Tensor with optional LoD information. np.array(lod_tensor) can convert LoDTensor to numpy array. lod_tensor.lod() can retrieve the LoD information. LoD is short for Level of Details and is usually used for varied sequence length. You can skip the following comment if you don't need optional LoD. For example: A LoDTensor X can look like the example below. It contains 2 sequences. The first has length 2 and the second has length 3, as described by x.lod. The first tensor dimension 6=2+3 is calculated from LoD if it's available. It means the total number of sequence element. In X, each element has 2 columns, hence [6, 2]. 
x.lod = [[2, 3]] x.data = [[1, 2], [3, 4], [5, 6], [7, 8], [9, 10], [11, 12]] x.shape = [6, 2] LoD can have multiple levels (for example, a paragraph can have multiple sentences and a sentence can have multiple words). In the following LodTensor Y, the lod_level is 2. It means there are 2 sequence, the first sequence length is 2 (has 2 sub-sequences), the second one's length is 1. The first sequence's 2 sub-sequences have length 2 and 2, respectively. And the second sequence's 1 sub-sequence has length 3. y.lod = [[2 1], [2 2 3]] y.shape = [2+2+3, ...] Note: In above description, LoD is length-based. In Paddle internal implementation, lod is offset-based. Hence, internally, y.lod is represented as [[0, 2, 3], [0, 2, 4, 7]] (length-based equivlent would be [[2-0, 3-2], [2-0, 4-2, 7-4]]). Sometimes LoD is called recursive_sequence_length to be more self-explanatory. In this case, it must be length-based. Due to history reasons. when LoD is called lod in public API, it might be offset-based. Users should be careful about it. )DOC")
      .def_buffer(
          [](Tensor &self) -> py::buffer_info { return CastToPyBuffer(self); })
      // Constructor from length-based LoD: the lengths are copied into a LoD,
      // converted with ConvertToOffsetBasedLoD, validated with
      // CheckLoD(..., -1), and the tensor is placement-constructed from the
      // offset-based result.
      .def("__init__",
           [](LoDTensor &instance, const std::vector>
                                       &recursive_sequence_lengths) {
             LoD new_lod;
             new_lod.reserve(recursive_sequence_lengths.size());
             std::copy(recursive_sequence_lengths.begin(),
                       recursive_sequence_lengths.end(),
                       std::back_inserter(new_lod));
             LoD new_offset_lod = ConvertToOffsetBasedLoD(new_lod);
             PADDLE_ENFORCE(
                 CheckLoD(new_offset_lod, -1),
                 "the provided recursive_sequence_lengths info is invalid");
             new (&instance) LoDTensor(new_offset_lod);
           })
      // Default constructor.
      .def("__init__", [](LoDTensor &instance) { new (&instance) LoDTensor(); })
      // We implement offset based LOD in C++ while we use length based with
      // Python API. So we changed set_lod to set_recursive_sequence_lengths to
      // avoid misuse.
      // The discussion is here:
      // https://github.com/PaddlePaddle/Paddle/issues/10855
      //
      // NOTE(review): set_lod, set_recursive_sequence_lengths and
      // has_valid_recursive_sequence_lengths all evaluate
      // vectorize(self.dims()).front() unconditionally. If dims can be empty
      // at this point, front() is called on an empty vector — confirm that
      // callers always set the dims first.
      .def("set_lod",
           [](LoDTensor &self, const std::vector> &lod) {
             // the input lod is offset-based level-of-detail info
             LoD new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
             PADDLE_ENFORCE(CheckLoD(new_lod, vectorize(self.dims()).front()),
                            "the provided lod info is invalid");
             self.set_lod(new_lod);
           })
      .def("set_recursive_sequence_lengths",
           [](LoDTensor &self, const std::vector>
                                   &recursive_sequence_lengths) {
             // the input recursive_sequence_lengths is length-based
             // level-of-detail info
             LoD new_lod;
             new_lod.reserve(recursive_sequence_lengths.size());
             std::copy(recursive_sequence_lengths.begin(),
                       recursive_sequence_lengths.end(),
                       std::back_inserter(new_lod));
             LoD new_offset_lod = ConvertToOffsetBasedLoD(new_lod);
             PADDLE_ENFORCE(
                 CheckLoD(new_offset_lod, vectorize(self.dims()).front()),
                 "the provided recursive_sequence_lengths info is invalid");
             self.set_lod(new_offset_lod);
           })
      .def("lod",
           [](LoDTensor &self) -> std::vector> {
             // output the offset-based lod info
             LoD lod = self.lod();
             std::vector> new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
             return new_lod;
           })
      // See the comments above set_lod.
      .def("recursive_sequence_lengths",
           [](LoDTensor &self) -> std::vector> {
             // output the length-based lod info
             LoD lod = ConvertToLengthBasedLoD(self.lod());
             std::vector> new_lod;
             new_lod.reserve(lod.size());
             std::copy(lod.begin(), lod.end(), std::back_inserter(new_lod));
             return new_lod;
           })
      .def("has_valid_recursive_sequence_lengths", [](LoDTensor &self) -> bool {
        // Check that the lod info is valid and match the outermost
        // dimension of the LoDTensor data
        return CheckLoD(self.lod(), vectorize(self.dims()).front());
      });

  // ---- SelectedRows -------------------------------------------------------
  py::class_(m, "SelectedRows")
      .def("__init__",
           [](SelectedRows &instance) { new (&instance) SelectedRows(); })
      .def("__init__",
           [](SelectedRows &instance, const std::vector rows,
              const int64_t &height) {
             new (&instance) SelectedRows(rows, height);
           })
      // Returned with the `reference` policy, i.e. Python does not take
      // ownership of the returned object.
      .def("get_tensor",
           [](SelectedRows &self) { return self.mutable_value(); },
           py::return_value_policy::reference)
      .def("set_height", &SelectedRows::set_height)
      .def("height", &SelectedRows::height)
      .def("set_rows",
           [](SelectedRows &self, std::vector rows) {
#ifndef PADDLE_WITH_CUDA
             self.set_rows(rows);
#else
             // CUDA builds first copy the rows into a Vector before passing
             // them to set_rows.
             Vector new_rows(rows);
             self.set_rows(new_rows);
#endif
           })
      .def("sync_index", [](SelectedRows &instance) { instance.SyncIndex(); })
      // Copies self.rows() into a fresh std::vector and returns the copy.
      .def("rows", [](SelectedRows &self) {
        auto rows = self.rows();
        std::vector new_rows;
        new_rows.reserve(rows.size());
        std::copy(rows.begin(), rows.end(), std::back_inserter(new_rows));
        return new_rows;
      });

  // ---- Variable -----------------------------------------------------------
  // The get_* accessors below all call GetMutable on the variable and return
  // the result with the `reference` policy (no ownership transfer to Python).
  py::class_(m, "Variable", R"DOC(Variable Class. All parameter, weight, gradient are variables in Paddle. 
)DOC")
      .def("is_int", [](const Variable &var) { return var.IsType(); })
      .def("set_int",
           [](Variable &var, int val) -> void { *var.GetMutable() = val; })
      .def("get_int", [](const Variable &var) -> int { return var.Get(); })
      .def("is_float", [](const Variable &var) { return var.IsType(); })
      .def("set_float",
           [](Variable &var, float val) -> void { *var.GetMutable() = val; })
      .def("get_float", [](const Variable &var) -> float { return var.Get(); })
      .def("get_tensor",
           [](Variable &self) -> LoDTensor * { return self.GetMutable(); },
           py::return_value_policy::reference)
      .def("get_lod_rank_table",
           [](Variable &self) { return self.GetMutable(); },
           py::return_value_policy::reference)
      .def("get_selected_rows",
           [](Variable &self) -> SelectedRows * { return self.GetMutable(); },
           py::return_value_policy::reference)
      .def("get_lod_tensor_array",
           [](Variable &self) { return self.GetMutable(); },
           py::return_value_policy::reference)
#ifdef PADDLE_WITH_CUDA
      .def("get_communicator",
           [](Variable &self) -> platform::Communicator * {
             return self.GetMutable();
           },
           py::return_value_policy::reference)
#endif
      // Unlike the other accessors, get_reader enforces the variable's type
      // (via IsType) before calling GetMutable.
      .def("get_reader",
           [](Variable &self) -> framework::ReaderHolder * {
             PADDLE_ENFORCE(self.IsType());
             return self.GetMutable();
           },
           py::return_value_policy::reference);

  // ---- Reader -------------------------------------------------------------
  py::class_(m, "Reader", "")
      .def("reset", &framework::ReaderHolder::ResetAll);

  // ---- LoDTensorBlockingQueue ---------------------------------------------
  using LoDTensorBlockingQueue =
      ::paddle::operators::reader::LoDTensorBlockingQueue;
  using LoDTensorBlockingQueueHolder =
      ::paddle::operators::reader::LoDTensorBlockingQueueHolder;
  py::class_>(
      m, "LoDTensorBlockingQueue", "")
      .def("push",
           [](LoDTensorBlockingQueue &self,
              const std::vector &lod_tensor_vec) {
             // The GIL is released for the duration of Push.
             pybind11::gil_scoped_release release;
             return self.Push(lod_tensor_vec);
           })
      .def("size", &LoDTensorBlockingQueue::Size)
      .def("capacity", &LoDTensorBlockingQueue::Cap)
      .def("close", &LoDTensorBlockingQueue::Close)
      .def("is_closed", &LoDTensorBlockingQueue::IsClosed);

  // Converts each shape in `shapes` with make_ddim, calls
  // InitOnce(capacity, dims) on the holder obtained from `var`, and returns
  // the holder's queue.
  m.def("init_lod_tensor_blocking_queue",
        [](Variable &var, size_t capacity,
           const std::vector> &shapes) -> std::shared_ptr {
          std::vector dims(shapes.size());
          std::transform(shapes.begin(), shapes.end(), dims.begin(),
                         [](const std::vector &shape) {
                           return make_ddim(shape);
                         });
          auto *holder = var.GetMutable();
          holder->InitOnce(capacity, dims);
          return holder->GetQueue();
        },
        py::return_value_policy::copy);

  // ---- Scope --------------------------------------------------------------
  // var / find_var / new_scope return pointers with the `reference` policy.
  py::class_(m, "Scope", "")
      .def("var",
           [](Scope &self, const std::string &name) -> Variable * {
             return self.Var(name);
           },
           py::return_value_policy::reference)
      .def("find_var", &Scope::FindVar, py::return_value_policy::reference)
      .def(py::init<>())
      .def("new_scope", [](Scope &self) -> Scope * { return &self.NewScope(); },
           py::return_value_policy::reference)
      .def("drop_kids", &Scope::DropKids);

  //! @note: Be careful! PyBind will return std::string as an unicode, not
  //! Python str. If you want a str object, you should cast them in Python.
  //
  // Serializes the proto of every entry in OpInfoMap for which
  // HasOpProtoAndChecker() is true and returns the serialized strings.
  m.def("get_all_op_protos", []() -> std::vector {
    std::vector ret_values;
    for (auto &iter : OpInfoMap::Instance().map()) {
      auto &info = iter.second;
      if (info.HasOpProtoAndChecker()) {
        std::string str;
        PADDLE_ENFORCE(
            info.Proto().SerializeToString(&str),
            "Serialize OpProto Error. This could be a bug of Paddle.");
        ret_values.emplace_back(str);
      }
    }
    return ret_values;
  });

  // Invokes the GradOpMaker registered for op_desc.Type() and returns a pair
  // of (grad op desc pointers, grad_to_var map). The pointers are obtained by
  // release()-ing the unique_ptrs produced by the maker, so the returned raw
  // pointers are no longer owned on the C++ side.
  m.def(
      "get_grad_op_desc",
      [](const OpDesc &op_desc, const std::unordered_set &no_grad_set,
         const std::vector &grad_sub_block) {
        std::unordered_map grad_to_var;
        std::vector> grad_op_descs =
            framework::OpInfoMap::Instance()
                .Get(op_desc.Type())
                .GradOpMaker()(op_desc, no_grad_set, &grad_to_var,
                               grad_sub_block);
        std::vector grad_op_desc_ptrs(grad_op_descs.size());
        std::transform(grad_op_descs.begin(), grad_op_descs.end(),
                       grad_op_desc_ptrs.begin(),
                       [](std::unique_ptr &p) { return p.release(); });
        return std::make_pair(grad_op_desc_ptrs, grad_to_var);
      });

  // Copies `origin`, marks each (t[0] = block index, t[1] = op index) entry
  // of `targets` as a target on the copy, runs Prune, and returns a newly
  // allocated ProgramDesc built from the pruned proto.
  m.def("prune", [](const ProgramDesc &origin,
                    const std::vector> &targets) {
    ProgramDesc prog_with_targets(origin);
    for (const auto &t : targets) {
      prog_with_targets.MutableBlock(t[0])->Op(t[1])->SetIsTarget(true);
    }
    proto::ProgramDesc pruned_desc;
    Prune(*prog_with_targets.Proto(), &pruned_desc);
    return new ProgramDesc(pruned_desc);
  });
  m.def("empty_var_name",
        []() { return std::string(framework::kEmptyVarName); });
  m.def("grad_var_suffix",
        []() { return std::string(framework::kGradVarSuffix); });
  m.def_submodule(
       "var_names",
       "The module will return special predefined variable name in Paddle")
      .def("empty", []() { return kEmptyVarName; })
      .def("temp", []() { return kTempVarName; });

  // ---- DeviceContext ------------------------------------------------------
  // `create` is overloaded on the place type. In builds without
  // PADDLE_WITH_CUDA the CUDAPlace / CUDAPinnedPlace overloads throw via
  // PADDLE_THROW instead of creating a context.
  // NOTE(review): the chain ends with ";;" — the second semicolon is an
  // empty statement.
  // clang-format off
  py::class_(m, "DeviceContext")
    .def_static("create",
                [](paddle::platform::CPUPlace& place)
                    -> paddle::platform::DeviceContext* {
                  return new paddle::platform::CPUDeviceContext();
                })
    .def_static("create",
                [](paddle::platform::CUDAPlace& place)
                    -> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
                  PADDLE_THROW("CUDAPlace is not supported in CPU device.");
#else
                  return new paddle::platform::CUDADeviceContext(place);
#endif
                })
    .def_static("create",
                [](paddle::platform::CUDAPinnedPlace& place)
                    -> paddle::platform::DeviceContext* {
#ifndef PADDLE_WITH_CUDA
                  PADDLE_THROW(
                      "CUDAPinnedPlace is not supported in CPU device.");
#else
                  return new paddle::platform::CUDAPinnedDeviceContext(place);
#endif
                });;
// clang-format on

  // ---- Places -------------------------------------------------------------
#ifdef PADDLE_WITH_CUDA
  py::class_(m, "Communicator").def(py::init<>());
#endif
  py::class_(m, "CUDAPlace")
      .def(py::init())
      .def("__str__", string::to_string);

  py::class_(m, "CPUPlace")
      .def(py::init<>())
      .def("__str__", string::to_string);

  py::class_(m, "CUDAPinnedPlace")
      .def(py::init<>())
      .def("__str__", string::to_string);

  // Place.set_place is overloaded on the concrete place type and assigns it
  // to `self`.
  py::class_(m, "Place")
      .def(py::init<>())
      .def("set_place",
           [](platform::Place &self, const platform::CPUPlace &cpu_place) {
             self = cpu_place;
           })
      .def("set_place",
           [](platform::Place &self, const platform::CUDAPlace &gpu_place) {
             self = gpu_place;
           })
      .def("set_place", [](platform::Place &self,
                           const platform::CUDAPinnedPlace &cuda_pinned_place) {
        self = cuda_pinned_place;
      });

  // ---- Operator -----------------------------------------------------------
  py::class_(m, "Operator")
      // Parses `protobin` into a proto::OpDesc, enforces that parsing
      // succeeded and that the message is initialized, then returns
      // OpRegistry::CreateOp(desc).
      .def_static("create",
                  [](py::bytes protobin) {
                    proto::OpDesc desc;
                    PADDLE_ENFORCE(desc.ParsePartialFromString(protobin),
                                   "Cannot parse user input to OpDesc");
                    PADDLE_ENFORCE(desc.IsInitialized(),
                                   "User OpDesc is not initialized, reason %s",
                                   desc.InitializationErrorString());
                    return OpRegistry::CreateOp(desc);
                  })
      // `run` is overloaded on the place type; each forwards to
      // self.Run(scope, place).
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CPUPlace &place) { self.Run(scope, place); })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CUDAPlace &place) { self.Run(scope, place); })
      .def("run",
           [](OperatorBase &self, const Scope &scope,
              const platform::CUDAPinnedPlace &place) {
             self.Run(scope, place);
           })
      .def("type",
           [](const OperatorBase &op) -> std::string { return op.Type(); })
      .def("outputs",
           [](const OperatorBase &op) -> std::map> {
             return op.Outputs();
           })
      // output_vars calls OutputVars(true); no_intermediate_outputs below
      // calls OutputVars(false).
      .def("output_vars",
           [](const OperatorBase &op) { return op.OutputVars(true); })
      .def("inputs", [](const OperatorBase &op) { return op.Inputs(); })
      .def("input_vars", [](const OperatorBase &op) { return op.InputVars(); })
      .def("__str__", &OperatorBase::DebugString)
      .def("no_intermediate_outputs",
           [](const OperatorBase &op) { return op.OutputVars(false); })
      .def("support_gpu", &OperatorBase::SupportGPU);

  // ---- Executor -----------------------------------------------------------
  py::class_(m, "Executor")
      .def(py::init())
      .def("close", &Executor::Close)
      .def("run", [](Executor &self, const ProgramDesc &prog, Scope *scope,
                     int block_id, bool create_local_scope, bool create_vars) {
        // The GIL is released while the program runs.
        pybind11::gil_scoped_release release;
        self.Run(prog, scope, block_id, create_local_scope, create_vars);
      });

  // ---- Module-level helpers -----------------------------------------------
  m.def("init_gflags", framework::InitGflags);
  m.def("init_glog", framework::InitGLOG);
  m.def("init_devices",
        [](bool init_p2p) { framework::InitDevices(init_p2p); });

  m.def("is_compiled_with_cuda", IsCompiledWithCUDA);
  m.def("is_compiled_with_dist", IsCompiledWithDIST);
#ifdef PADDLE_WITH_CUDA
  m.def("is_float16_supported", [](const platform::CUDAPlace &place) -> bool {
    // Only GPUs with Compute Capability >= 53 support float16
    return platform::GetCUDAComputeCapability(place.device) >= 53;
  });
#endif

  m.def("set_feed_variable", framework::SetFeedVariable);
  m.def("get_fetch_variable", framework::GetFetchVariable);

  m.def("_is_program_version_supported", IsProgramVersionSupported);

  // Bindings defined in other translation units.
  BindProgramDesc(&m);
  BindBlockDesc(&m);
  BindVarDsec(&m);
  BindOpDesc(&m);
  BindConstValue(&m);

  // ---- LoDRankTable -------------------------------------------------------
  // `items` returns one {index, length} entry per item in table.items().
  py::class_(m, "LodRankTable")
      .def("items", [](framework::LoDRankTable &table) {
        std::vector> res;
        for (auto &item : table.items()) {
          res.push_back({item.index, item.length});
        }
        return res;
      });

  // ---- LoDTensorArray -----------------------------------------------------
  py::class_(m, "LoDTensorArray")
      .def("__init__",
           [](LoDTensorArray &instance) { new (&instance) LoDTensorArray(); })
      // Uses at(i) and returns the element's address with the `reference`
      // policy.
      .def("__getitem__",
           [](LoDTensorArray &self, size_t i) { return &self.at(i); },
           py::return_value_policy::reference)
      .def("__len__", [](LoDTensorArray &self) { return self.size(); })
      // Enforces i < size, then makes element i share t's data and copies
      // t's lod onto it.
      .def("__setitem__",
           [](LoDTensorArray &self, size_t i, const LoDTensor &t) {
             PADDLE_ENFORCE_LT(i, self.size());
             self[i].ShareDataWith(t);
             self[i].set_lod(t.lod());
           })
      // Appends a new element that shares t's data and has t's lod.
      .def("append", [](LoDTensorArray &self, const LoDTensor &t) {
        self.emplace_back();
        self.back().ShareDataWith(t);
        self.back().set_lod(t.lod());
      });

  m.def("IsInplace",
        [](std::string op) -> bool { return operators::IsInplace(op); });

  m.def("op_support_gpu", OpSupportGPU);
#ifdef PADDLE_WITH_CUDA
  m.def("get_cuda_device_count", platform::GetCUDADeviceCount);

  m.def("nvprof_init", platform::CudaProfilerInit);
  m.def("nvprof_start", platform::CudaProfilerStart);
  m.def("nvprof_stop", platform::CudaProfilerStop);
#endif

  // ---- Profiler -----------------------------------------------------------
  py::enum_(m, "ProfilerState", py::arithmetic())
      .value("kDisabled", platform::ProfilerState::kDisabled)
      .value("kCPU", platform::ProfilerState::kCPU)
      .value("kCUDA", platform::ProfilerState::kCUDA)
      .value("kAll", platform::ProfilerState::kAll)
      .export_values();

  py::enum_(m, "EventSortingKey", py::arithmetic())
      .value("kDefault", platform::EventSortingKey::kDefault)
      .value("kCalls", platform::EventSortingKey::kCalls)
      .value("kTotal", platform::EventSortingKey::kTotal)
      .value("kMin", platform::EventSortingKey::kMin)
      .value("kMax", platform::EventSortingKey::kMax)
      .value("kAve", platform::EventSortingKey::kAve)
      .export_values();

  m.def("enable_profiler", platform::EnableProfiler);
  m.def("disable_profiler", platform::DisableProfiler);
  m.def("is_profiler_enabled", platform::IsProfileEnabled);
  m.def("reset_profiler", platform::ResetProfiler);

  // ---- ir::Pass / ir::PassBuilder -----------------------------------------
  py::class_> pass(m, "Pass");
  pass.def(py::init())
      // Passes a heap-allocated copy of `attr` to Pass::Set.
      // NOTE(review): presumably the pass takes ownership of that pointer —
      // confirm against ir::Pass::Set.
      .def("set_str", [](ir::Pass &self, const std::string &name,
                         const std::string &attr) {
        self.Set(name, new std::string(attr));
      });

  py::class_> pb(
      m, "PassBuilder");
  pb.def(py::init())
      .def("append_pass",
           [](ir::PassBuilder &self,
              const std::string &pass_type) -> std::shared_ptr {
             return self.AppendPass(pass_type);
           })
      .def("all_passes", [](ir::PassBuilder &self) { return self.AllPasses(); })
      .def("insert_pass",
           [](ir::PassBuilder &self, size_t idx, const std::string &pass_type) {
             return self.InsertPass(idx, pass_type);
           })
      .def("remove_pass",
           [](ir::PassBuilder &self, size_t idx) { self.RemovePass(idx); });

  // -- python binds for parallel executor.
  // ExecutionStrategy and BuildStrategy are registered with `pe` as their
  // scope, i.e. nested under ParallelExecutor on the Python side.
  // NOTE(review): both docstrings below are runtime strings and are left
  // untouched; they say "preciously" where "precisely" is presumably meant.
  py::class_ pe(m, "ParallelExecutor");
  py::class_ exec_strategy(pe, "ExecutionStrategy", R"DOC( ExecutionStrategy allows the user to more preciously control how to run the program in ParallelExecutor by setting the property. The available properties include: use_cuda (bool): Whether to use CUDA or not. Default True. num_threads (int): The number of threads that used to run the operators in ParallelExecutor. If it is not set, it will be set in ParallelExecutor according to the device count. Default 0. allow_op_delay (bool): Whether to delay the communication operators to run. Default False. num_iteration_per_drop_scope (int): how many iterations between the two dropping local scopes. Default 100. )DOC");

  // Each property is a getter/setter pair over the same-named member of
  // ExecutionStrategy.
  exec_strategy.def(py::init())
      .def_property(
          "num_threads",
          [](const ExecutionStrategy &self) { return self.num_threads_; },
          [](ExecutionStrategy &self, size_t num_threads) {
            self.num_threads_ = num_threads;
          })
      .def_property(
          "use_cuda",
          [](const ExecutionStrategy &self) { return self.use_cuda_; },
          [](ExecutionStrategy &self, bool use_cuda) {
            self.use_cuda_ = use_cuda;
          })
      .def_property(
          "allow_op_delay",
          [](const ExecutionStrategy &self) { return self.allow_op_delay_; },
          [](ExecutionStrategy &self, bool allow_op_delay) {
            self.allow_op_delay_ = allow_op_delay;
          })
      .def_property(
          "num_iteration_per_drop_scope",
          [](const ExecutionStrategy &self) {
            return self.num_iteration_per_drop_scope_;
          },
          [](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) {
            self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope;
          });
  // Boolean view over type_: true maps to kExperimental, false to kDefault.
  exec_strategy.def_property(
      "use_experimental_executor",
      [](const ExecutionStrategy &self) {
        return self.type_ == ExecutionStrategy::kExperimental;
      },
      [](ExecutionStrategy &self, bool experimental) {
        self.type_ = experimental ? ExecutionStrategy::kExperimental
                                  : ExecutionStrategy::kDefault;
      });

  py::class_ build_strategy(pe, "BuildStrategy", R"DOC( BuildStrategy allows the user to more preciously control how to build the SSA Graph in ParallelExecutor by setting the property. The available properties include: reduce_strategy (str): There are two reduce strategies, 'AllReduce' and 'Reduce'. If you want that all parameters will be optimized on all devices, you can choose 'AllReduce'; if you choose 'Reduce', all parameters will be evenly allocated to different devices for optimization, and then broadcast the optimized parameter to other devices. Default 'AllReduce'. gradient_scale_strategy (str): There are two ways of defining loss@grad, 'CoeffNumDevice' and 'Customized'. By default, ParallelExecutor sets the loss@grad according to the number of devices. If you want to customize loss@grad, you can choose 'Customized'. Default 'CoeffNumDevice'. debug_graphviz_path (str): Whether to write the SSA Graph to file in the form of graphviz. It is useful for debugging. Default "". 
)DOC");

  py::enum_(build_strategy, "ReduceStrategy")
      .value("Reduce", BuildStrategy::ReduceStrategy::kReduce)
      .value("AllReduce", BuildStrategy::ReduceStrategy::kAllReduce);
  py::enum_(build_strategy, "GradientScaleStrategy")
      .value("CoeffNumDevice",
             BuildStrategy::GradientScaleStrategy::kCoeffNumDevice)
      .value("One", BuildStrategy::GradientScaleStrategy::kOne)
      .value("Customized", BuildStrategy::GradientScaleStrategy::kCustomized);

  // Each property is a getter/setter pair over the corresponding member of
  // BuildStrategy.
  build_strategy.def(py::init())
      .def_property(
          "reduce_strategy",
          [](const BuildStrategy &self) { return self.reduce_; },
          [](BuildStrategy &self, BuildStrategy::ReduceStrategy strategy) {
            self.reduce_ = strategy;
          })
      .def_property(
          "gradient_scale_strategy",
          [](const BuildStrategy &self) { return self.gradient_scale_; },
          [](BuildStrategy &self,
             BuildStrategy::GradientScaleStrategy strategy) {
            self.gradient_scale_ = strategy;
          })
      .def_property(
          "debug_graphviz_path",
          [](const BuildStrategy &self) { return self.debug_graphviz_path_; },
          [](BuildStrategy &self, const std::string &path) {
            self.debug_graphviz_path_ = path;
          })
      .def_property(
          "enable_data_balance",
          [](const BuildStrategy &self) { return self.enable_data_balance_; },
          [](BuildStrategy &self, bool b) { self.enable_data_balance_ = b; })
      .def_property("fuse_elewise_add_act_ops",
                    [](const BuildStrategy &self) {
                      return self.fuse_elewise_add_act_ops_;
                    },
                    [](BuildStrategy &self, bool b) {
                      self.fuse_elewise_add_act_ops_ = b;
                    })
      .def("_create_passes_from_strategy",
           [](BuildStrategy &self) -> std::shared_ptr {
             return self.CreatePassesFromStrategy();
           });

  pe.def(py::init &,
                  const std::unordered_set &,
                  const std::unordered_set &, const ProgramDesc &,
                  const std::string &, Scope *, std::vector &,
                  const ExecutionStrategy &, const BuildStrategy &, size_t,
                  size_t>())
      // NOTE: even we return a vec* to Python use reference policy.
      // We still cannot get local_scope from this vector, since the element
      // of vec will be freed by Python GC. We can only return Scope*
      // one by one and mark them as reference.
      .def("local_scopes",
           [](ParallelExecutor &self) -> std::vector * {
             return &self.GetLocalScopes();
           },
           py::return_value_policy::reference)
      .def("feed_tensors_into_local_scopes",
           &ParallelExecutor::FeedTensorsIntoLocalScopes)
      .def("feed_and_split_tensor_into_local_scopes",
           &ParallelExecutor::FeedAndSplitTensorIntoLocalScopes)
      .def("run", [](ParallelExecutor &self,
                     const std::vector &fetch_tensors,
                     const std::string &fetched_var_name) {
        // The GIL is released while the parallel executor runs.
        pybind11::gil_scoped_release release;
        self.Run(fetch_tensors, fetched_var_name);
      });

  BindRecordIOWriter(&m);
  return m.ptr();
}
}  // namespace pybind
}  // namespace paddle