未验证 提交 c36dd3b3 编写于 作者: Y Yu Yang 提交者: GitHub

Merge pull request #11114 from reyoung/feature/yep

Try to speed up parallel executor
...@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope ...@@ -87,7 +87,7 @@ cc_library(executor SRCS executor.cc DEPS op_registry device_context scope
framework_proto glog lod_rank_table feed_fetch_method) framework_proto glog lod_rank_table feed_fetch_method)
cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor) cc_library(parallel_executor SRCS parallel_executor.cc DEPS multi_devices_graph_builder threaded_ssa_graph_executor scope_buffered_ssa_graph_executor)
cc_library(prune SRCS prune.cc DEPS framework_proto) cc_library(prune SRCS prune.cc DEPS framework_proto)
cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context) cc_test(prune_test SRCS prune_test.cc DEPS op_info prune recurrent_op device_context)
......
...@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha ...@@ -36,5 +36,6 @@ cc_test(broadcast_op_test SRCS broadcast_op_handle_test.cc DEPS var_handle op_ha
device_context broadcast_op_handle) device_context broadcast_op_handle)
cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory cc_test(gather_op_test SRCS gather_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
device_context gather_op_handle) device_context gather_op_handle)
cc_library(scope_buffered_ssa_graph_executor SRCS scope_buffered_ssa_graph_executor.cc DEPS ssa_graph_executor)
#cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory #cc_test(reduce_op_handle_test SRCS reduce_op_handle_test.cc DEPS var_handle op_handle_base scope ddim memory
# device_context reduce_op_handle ) # device_context reduce_op_handle )
...@@ -22,6 +22,7 @@ struct ExecutionStrategy { ...@@ -22,6 +22,7 @@ struct ExecutionStrategy {
size_t num_threads_{0}; size_t num_threads_{0};
bool use_event_{true}; bool use_event_{true};
bool allow_op_delay_{false}; bool allow_op_delay_{false};
size_t num_iteration_per_drop_scope_{100};
}; };
} // namespace details } // namespace details
......
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include <string>
#include <vector>
#include "paddle/fluid/framework/executor.h"
namespace paddle {
namespace framework {
namespace details {
ScopeBufferedSSAGraphExecutor::ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope *> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor> &&underlying_executor)
: strategy_(std::move(strategy)),
underlying_executor_(std::move(underlying_executor)),
local_scopes_(std::move(local_scopes)),
var_infos_(std::move(var_infos)),
places_(std::move(places)) {}
FeedFetchList ScopeBufferedSSAGraphExecutor::Run(
const std::vector<std::string> &fetch_tensors) {
if (drop_scope_counter_ == 0) {
// Create local scopes.
for (auto it = local_scopes_.rbegin(); it != local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &info : var_infos_) {
if (scope->FindVar(info.name_) != nullptr) {
continue;
}
if (info.persistable_) { // Persistable
InitializeVariable(scope->Var(info.name_), info.type_);
} else {
InitializeVariable(local_scope.Var(info.name_), info.type_);
}
}
}
}
auto fetch_data = underlying_executor_->Run(fetch_tensors);
drop_scope_counter_ += 1;
if (!fetch_tensors.empty() ||
drop_scope_counter_ == strategy_.num_iteration_per_drop_scope_) {
drop_scope_counter_ = 0;
// Wait All computational streams
for (auto p : places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
}
}
return fetch_data;
}
} // namespace details
} // namespace framework
} // namespace paddle
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/details/execution_strategy.h"
#include "paddle/fluid/framework/details/ssa_graph_executor.h"
#include "paddle/fluid/framework/scope.h"
#include "paddle/fluid/platform/place.h"
namespace paddle {
namespace framework {
namespace details {
struct VariableInfo {
std::string name_;
proto::VarType::Type type_;
bool persistable_;
};
class ScopeBufferedSSAGraphExecutor : public SSAGraphExecutor {
public:
ScopeBufferedSSAGraphExecutor(
ExecutionStrategy strategy, std::vector<Scope*> local_scopes,
std::vector<VariableInfo> var_infos, std::vector<platform::Place> places,
std::unique_ptr<SSAGraphExecutor>&& underlying_executor);
FeedFetchList Run(const std::vector<std::string>& fetch_tensors) override;
private:
size_t drop_scope_counter_{0};
ExecutionStrategy strategy_;
std::unique_ptr<SSAGraphExecutor> underlying_executor_;
std::vector<Scope*> local_scopes_;
std::vector<VariableInfo> var_infos_;
std::vector<platform::Place> places_;
};
} // namespace details
} // namespace framework
} // namespace paddle
...@@ -17,10 +17,6 @@ ...@@ -17,10 +17,6 @@
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace details { namespace details {
SSAGraphExecutor::SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph)
: graph_(std::move(graph)) {}
SSAGraphExecutor::~SSAGraphExecutor() {} SSAGraphExecutor::~SSAGraphExecutor() {}
} // namespace details } // namespace details
......
...@@ -28,15 +28,11 @@ class SSAGraphExecutor { ...@@ -28,15 +28,11 @@ class SSAGraphExecutor {
DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor); DISABLE_COPY_AND_ASSIGN(SSAGraphExecutor);
public: public:
// Steal graph inside SSAGraphExecutor() {}
explicit SSAGraphExecutor(std::unique_ptr<SSAGraph> &&graph);
virtual ~SSAGraphExecutor(); virtual ~SSAGraphExecutor();
virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0; virtual FeedFetchList Run(const std::vector<std::string> &fetch_tensors) = 0;
protected:
std::unique_ptr<SSAGraph> graph_;
}; };
} // namespace details } // namespace details
} // namespace framework } // namespace framework
......
...@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor( ...@@ -21,7 +21,7 @@ ThreadedSSAGraphExecutor::ThreadedSSAGraphExecutor(
const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes, const ExecutionStrategy &strategy, const std::vector<Scope *> &local_scopes,
const std::vector<platform::Place> &places, const std::vector<platform::Place> &places,
std::unique_ptr<SSAGraph> &&graph) std::unique_ptr<SSAGraph> &&graph)
: SSAGraphExecutor(std::move(graph)), : graph_(std::move(graph)),
pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_) pool_(strategy.num_threads_ >= 2 ? new ::ThreadPool(strategy.num_threads_)
: nullptr), : nullptr),
local_scopes_(local_scopes), local_scopes_(local_scopes),
...@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp( ...@@ -189,7 +189,9 @@ void ThreadedSSAGraphExecutor::RunOp(
BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) { BlockingQueue<VarHandleBase *> *ready_var_q, details::OpHandleBase *op) {
auto op_run = [ready_var_q, op, this] { auto op_run = [ready_var_q, op, this] {
try { try {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString(); if (VLOG_IS_ON(10)) {
VLOG(10) << op << " " << op->Name() << " : " << op->DebugString();
}
op->Run(strategy_.use_event_); op->Run(strategy_.use_event_);
VLOG(10) << op << " " << op->Name() << " Done "; VLOG(10) << op << " " << op->Name() << " Done ";
running_ops_--; running_ops_--;
......
...@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor { ...@@ -51,6 +51,7 @@ class ThreadedSSAGraphExecutor : public SSAGraphExecutor {
details::OpHandleBase *op); details::OpHandleBase *op);
private: private:
std::unique_ptr<SSAGraph> graph_;
std::unique_ptr<::ThreadPool> pool_; std::unique_ptr<::ThreadPool> pool_;
std::vector<Scope *> local_scopes_; std::vector<Scope *> local_scopes_;
std::vector<platform::Place> places_; std::vector<platform::Place> places_;
......
...@@ -23,6 +23,7 @@ limitations under the License. */ ...@@ -23,6 +23,7 @@ limitations under the License. */
#endif #endif
#include "paddle/fluid/framework/details/multi_devices_graph_builder.h" #include "paddle/fluid/framework/details/multi_devices_graph_builder.h"
#include "paddle/fluid/framework/details/scope_buffered_ssa_graph_executor.h"
#include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h" #include "paddle/fluid/framework/details/threaded_ssa_graph_executor.h"
#include "paddle/fluid/platform/profiler.h" #include "paddle/fluid/platform/profiler.h"
...@@ -42,8 +43,6 @@ class ParallelExecutorPrivate { ...@@ -42,8 +43,6 @@ class ParallelExecutorPrivate {
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_; std::unique_ptr<platform::NCCLContextMap> nccl_ctxs_;
#endif #endif
std::vector<std::tuple<std::string, proto::VarType::Type, bool>> var_types_;
bool own_local_scope; bool own_local_scope;
}; };
...@@ -92,9 +91,18 @@ ParallelExecutor::ParallelExecutor( ...@@ -92,9 +91,18 @@ ParallelExecutor::ParallelExecutor(
local_scopes.empty()) { // Is CUDA local_scopes.empty()) { // Is CUDA
BCastParamsToGPUs(bcast_vars); BCastParamsToGPUs(bcast_vars);
} }
// Startup Program has been run. All local scopes has correct parameters. // Startup Program has been run. All local scopes has correct parameters.
// Step 2. Create vars in each scope;
std::vector<details::VariableInfo> var_infos;
for (auto *var : main_program.Block(0).AllVars()) {
var_infos.emplace_back();
var_infos.back().name_ = var->Name();
var_infos.back().type_ = var->GetType();
var_infos.back().persistable_ = var->Persistable();
}
// Step 2. Convert main_program to SSA form and dependency graph. Also, insert // Step 3. Convert main_program to SSA form and dependency graph. Also, insert
// ncclOp // ncclOp
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
details::MultiDevSSAGraphBuilder builder( details::MultiDevSSAGraphBuilder builder(
...@@ -105,16 +113,15 @@ ParallelExecutor::ParallelExecutor( ...@@ -105,16 +113,15 @@ ParallelExecutor::ParallelExecutor(
params, member_->local_scopes_, params, member_->local_scopes_,
build_strategy); build_strategy);
#endif #endif
auto graph = builder.Build(main_program); auto graph = builder.Build(main_program);
member_->executor_.reset(new details::ThreadedSSAGraphExecutor( member_->executor_.reset(new details::ThreadedSSAGraphExecutor(
exec_strategy, member_->local_scopes_, places, std::move(graph))); exec_strategy, member_->local_scopes_, places, std::move(graph)));
// Step 3. Create vars in each scope; member_->executor_.reset(new details::ScopeBufferedSSAGraphExecutor(
for (auto *var : main_program.Block(0).AllVars()) { exec_strategy, member_->local_scopes_, std::move(var_infos),
member_->var_types_.emplace_back(var->Name(), var->GetType(), member_->places_, std::move(member_->executor_)));
var->Persistable());
}
} }
void ParallelExecutor::BCastParamsToGPUs( void ParallelExecutor::BCastParamsToGPUs(
...@@ -169,42 +176,9 @@ void ParallelExecutor::BCastParamsToGPUs( ...@@ -169,42 +176,9 @@ void ParallelExecutor::BCastParamsToGPUs(
void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors, void ParallelExecutor::Run(const std::vector<std::string> &fetch_tensors,
const std::string &fetched_var_name) { const std::string &fetched_var_name) {
platform::RecordBlock b(0); platform::RecordBlock b(0);
// Create local scopes.
for (auto it = member_->local_scopes_.rbegin();
it != member_->local_scopes_.rend(); ++it) {
auto &scope = *it;
Scope &local_scope = scope->NewScope();
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>() =
&local_scope;
for (auto &name_type_pair : member_->var_types_) {
if (scope->FindVar(std::get<0>(name_type_pair)) != nullptr) {
continue;
}
if (std::get<2>(name_type_pair)) { // Persistable
InitializeVariable(scope->Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
} else {
InitializeVariable(local_scope.Var(std::get<0>(name_type_pair)),
std::get<1>(name_type_pair));
}
}
}
auto fetch_data = member_->executor_->Run(fetch_tensors); auto fetch_data = member_->executor_->Run(fetch_tensors);
*member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() = *member_->global_scope_->Var(fetched_var_name)->GetMutable<FeedFetchList>() =
fetch_data; fetch_data;
// Wait All computational streams
for (auto p : member_->places_) {
platform::DeviceContextPool::Instance().Get(p)->Wait();
}
for (auto &scope : member_->local_scopes_) {
auto &local_scope =
*scope->Var(details::kLocalExecScopeName)->GetMutable<Scope *>();
scope->DeleteScope(local_scope);
}
} }
void ParallelExecutor::FeedTensorsIntoLocalScopes( void ParallelExecutor::FeedTensorsIntoLocalScopes(
......
...@@ -175,7 +175,6 @@ CUDADeviceContext::~CUDADeviceContext() { ...@@ -175,7 +175,6 @@ CUDADeviceContext::~CUDADeviceContext() {
Place CUDADeviceContext::GetPlace() const { return place_; } Place CUDADeviceContext::GetPlace() const { return place_; }
void CUDADeviceContext::Wait() const { void CUDADeviceContext::Wait() const {
std::lock_guard<std::recursive_mutex> guard(mutex_);
PADDLE_ENFORCE(cudaStreamSynchronize(stream_)); PADDLE_ENFORCE(cudaStreamSynchronize(stream_));
PADDLE_ENFORCE(cudaGetLastError()); PADDLE_ENFORCE(cudaGetLastError());
} }
......
...@@ -100,7 +100,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -100,7 +100,6 @@ class CUDADeviceContext : public DeviceContext {
template <typename Callback> template <typename Callback>
void RecordEvent(cudaEvent_t ev, Callback callback) { void RecordEvent(cudaEvent_t ev, Callback callback) {
std::lock_guard<std::recursive_mutex> guard(mutex_);
callback(); callback();
PADDLE_ENFORCE(cudaEventRecord(ev, stream_)); PADDLE_ENFORCE(cudaEventRecord(ev, stream_));
} }
...@@ -110,8 +109,6 @@ class CUDADeviceContext : public DeviceContext { ...@@ -110,8 +109,6 @@ class CUDADeviceContext : public DeviceContext {
std::unique_ptr<Eigen::GpuDevice> eigen_device_; std::unique_ptr<Eigen::GpuDevice> eigen_device_;
std::unique_ptr<EigenCudaStreamDevice> eigen_stream_; std::unique_ptr<EigenCudaStreamDevice> eigen_stream_;
mutable std::recursive_mutex mutex_;
cudaStream_t stream_; cudaStream_t stream_;
cudnnHandle_t cudnn_handle_; cudnnHandle_t cudnn_handle_;
cublasHandle_t cublas_handle_; cublasHandle_t cublas_handle_;
......
...@@ -45,7 +45,7 @@ extern void *cublas_dso_handle; ...@@ -45,7 +45,7 @@ extern void *cublas_dso_handle;
std::call_once(cublas_dso_flag, []() { \ std::call_once(cublas_dso_flag, []() { \
cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \ cublas_dso_handle = paddle::platform::dynload::GetCublasDsoHandle(); \
}); \ }); \
void *p_##__name = dlsym(cublas_dso_handle, #__name); \ static void *p_##__name = dlsym(cublas_dso_handle, #__name); \
return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \ return reinterpret_cast<FUNC_TYPE>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -39,7 +39,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name); ...@@ -39,7 +39,7 @@ extern void EnforceCUDNNLoaded(const char* fn_name);
cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \ cudnn_dso_handle = paddle::platform::dynload::GetCUDNNDsoHandle(); \
}); \ }); \
EnforceCUDNNLoaded(#__name); \ EnforceCUDNNLoaded(#__name); \
void* p_##__name = dlsym(cudnn_dso_handle, #__name); \ static void* p_##__name = dlsym(cudnn_dso_handle, #__name); \
return reinterpret_cast<cudnn_func>(p_##__name)(args...); \ return reinterpret_cast<cudnn_func>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -45,7 +45,7 @@ extern void *cupti_dso_handle; ...@@ -45,7 +45,7 @@ extern void *cupti_dso_handle;
std::call_once(cupti_dso_flag, []() { \ std::call_once(cupti_dso_flag, []() { \
cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \ cupti_dso_handle = paddle::platform::dynload::GetCUPTIDsoHandle(); \
}); \ }); \
void *p_##__name = dlsym(cupti_dso_handle, #__name); \ static void *p_##__name = dlsym(cupti_dso_handle, #__name); \
return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \ return reinterpret_cast<cuptiFunc>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -34,7 +34,7 @@ extern void *curand_dso_handle; ...@@ -34,7 +34,7 @@ extern void *curand_dso_handle;
std::call_once(curand_dso_flag, []() { \ std::call_once(curand_dso_flag, []() { \
curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \ curand_dso_handle = paddle::platform::dynload::GetCurandDsoHandle(); \
}); \ }); \
void *p_##__name = dlsym(curand_dso_handle, #__name); \ static void *p_##__name = dlsym(curand_dso_handle, #__name); \
return reinterpret_cast<curandFunc>(p_##__name)(args...); \ return reinterpret_cast<curandFunc>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -37,7 +37,7 @@ extern void* nccl_dso_handle; ...@@ -37,7 +37,7 @@ extern void* nccl_dso_handle;
std::call_once(nccl_dso_flag, []() { \ std::call_once(nccl_dso_flag, []() { \
nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \ nccl_dso_handle = paddle::platform::dynload::GetNCCLDsoHandle(); \
}); \ }); \
void* p_##__name = dlsym(nccl_dso_handle, #__name); \ static void* p_##__name = dlsym(nccl_dso_handle, #__name); \
return reinterpret_cast<nccl_func>(p_##__name)(args...); \ return reinterpret_cast<nccl_func>(p_##__name)(args...); \
} \ } \
}; \ }; \
......
...@@ -40,7 +40,7 @@ extern void* tensorrt_dso_handle; ...@@ -40,7 +40,7 @@ extern void* tensorrt_dso_handle;
paddle::platform::dynload::GetTensorRtDsoHandle(); \ paddle::platform::dynload::GetTensorRtDsoHandle(); \
PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \ PADDLE_ENFORCE(tensorrt_dso_handle, "load tensorrt so failed"); \
}); \ }); \
void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \ static void* p_##__name = dlsym(tensorrt_dso_handle, #__name); \
PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \ PADDLE_ENFORCE(p_##__name, "load %s failed", #__name); \
return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \ return reinterpret_cast<tensorrt_func>(p_##__name)(args...); \
} \ } \
......
...@@ -40,7 +40,7 @@ extern void* warpctc_dso_handle; ...@@ -40,7 +40,7 @@ extern void* warpctc_dso_handle;
std::call_once(warpctc_dso_flag, []() { \ std::call_once(warpctc_dso_flag, []() { \
warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \ warpctc_dso_handle = paddle::platform::dynload::GetWarpCTCDsoHandle(); \
}); \ }); \
void* p_##_name = dlsym(warpctc_dso_handle, #__name); \ static void* p_##_name = dlsym(warpctc_dso_handle, #__name); \
return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \ return reinterpret_cast<warpctcFunc>(p_##_name)(args...); \
} \ } \
}; \ }; \
......
...@@ -519,6 +519,14 @@ All parameter, weight, gradient are variables in Paddle. ...@@ -519,6 +519,14 @@ All parameter, weight, gradient are variables in Paddle.
[](const ExecutionStrategy &self) { return self.allow_op_delay_; }, [](const ExecutionStrategy &self) { return self.allow_op_delay_; },
[](ExecutionStrategy &self, bool allow_op_delay) { [](ExecutionStrategy &self, bool allow_op_delay) {
self.allow_op_delay_ = allow_op_delay; self.allow_op_delay_ = allow_op_delay;
})
.def_property(
"num_iteration_per_drop_scope",
[](const ExecutionStrategy &self) {
return self.num_iteration_per_drop_scope_;
},
[](ExecutionStrategy &self, size_t num_iteration_per_drop_scope) {
self.num_iteration_per_drop_scope_ = num_iteration_per_drop_scope;
}); });
py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy"); py::class_<BuildStrategy> build_strategy(pe, "BuildStrategy");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册