Commit bebb9870, authored by typhoonzero

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into fix_mac_whl_packaging

...@@ -31,7 +31,7 @@ script:
    if [[ "$JOB" != "doc" ]]; then exit 0; fi;
    # For document only
    if [[ "$TRAVIS_PULL_REQUEST" != "false" ]]; then exit 0; fi;
    if [[ "$TRAVIS_BRANCH" != "develop" && ! "$TRAVIS_BRANCH" =~ ^v|release/[[:digit:]]+\.[[:digit:]]+(\.[[:digit:]]+)?(-\S*)?$ ]]; then exit 0; fi;
    export DEPLOY_DOCS_SH=https://raw.githubusercontent.com/PaddlePaddle/PaddlePaddle.org/master/scripts/deploy/deploy_docs.sh
    export DOCS_DIR=`pwd`
    cd ..
......
# Design Doc: Distributed Lookup Table Operator

A distributed lookup table operator in PaddlePaddle, where the table could be out of the memory of a single computer.

## Background
...@@ -24,14 +24,14 @@ memory, so we'd need a distributed storage service, which supports the lookup of rows.

The following figure illustrates the multiplication of x with two non-zero elements, or say, two symbols, and a lookup table W:

![lookup table](./src/lookup_table.png)

### The Backward Algorithm

The backward algorithm computes W'(x) using W(x). W'(x) has the same scale of size as W(x) and is much smaller than W.

To optimize W given W', we can do a simple SGD update:

...@@ -44,111 +44,46 @@ $$W = f(W, W')$$

The following figure illustrates the backward pass of the lookup operator: ![lookup table training](./src/lookup_table_training.png)
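To make the sparse update concrete, here is a minimal sketch in plain Python/NumPy (illustrative only; `W`, `grad_rows`, and `lr` are made-up names, and plain SGD stands in for the generic optimizer $f$): only the rows of W that appear in W'(x) are touched, so the work is proportional to the number of input symbols, not to the size of W.

```python
import numpy as np

# Toy table: a vocabulary of 10 ids, embedding size 4.
W = np.random.rand(10, 4).astype(np.float32)

def sgd_update_sparse(W, grad_rows, lr=0.1):
    """Apply W = f(W, W') with f = plain SGD, touching only the rows in W'(x)."""
    for row_id, grad in grad_rows.items():
        W[row_id] -= lr * grad

# Gradients exist only for the two ids that appeared in the input x.
grad_rows = {
    3: np.ones(4, dtype=np.float32),
    7: 0.5 * np.ones(4, dtype=np.float32),
}
sgd_update_sparse(W, grad_rows)
```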
## Distributed Lookup Table

### Problem 1: The lookup table may be very large

In applications such as search engines and recommendation systems, the number of feature IDs may be very large, say 100,000,000,000. For a float-valued lookup table with an embedding size of 8, the total size of the table is:

```
100,000,000,000 * 8 * 4 (Bytes) = 2980.23 GB
```

### Solution: Distributed storage

1. Paddle uses [SelectedRows](https://github.com/PaddlePaddle/Paddle/blob/develop/doc/fluid/design/modules/selected_rows.md) as the storage format for the lookup table. The lookup table parameter is split across multiple machines according to the hash of the feature ID, and the input data is split and sent to the corresponding machines to prefetch the parameter, as sketched below.

1. For common parameters, the trainer fetches the whole parameter for training; but for a big lookup table, the trainer cannot hold the whole parameter. Because the input features are very sparse, each iteration only needs a few rows of the table, so we use `prefetch_op` to prefetch only the rows the trainer needs.
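The hash-based split described in the first item above can be pictured with a small self-contained sketch (plain Python; `shard_of` and `split_ids` are hypothetical stand-ins for the real `split_ids_op`, and a plain modulo plays the role of the hash function):

```python
def shard_of(feature_id, num_pservers):
    # Assumed hash: plain modulo, the same function on trainers and servers.
    return feature_id % num_pservers

def split_ids(ids, num_pservers):
    """Group a batch of feature ids by the parameter server that owns them."""
    buckets = {i: [] for i in range(num_pservers)}
    for fid in ids:
        buckets[shard_of(fid, num_pservers)].append(fid)
    return buckets

ids = [100000000007, 42, 7, 99999999998]
print(split_ids(ids, num_pservers=2))
# {0: [42, 99999999998], 1: [100000000007, 7]}
```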
### Problem 2: The IDs in the lookup table are not known before training

The feature IDs are produced by a hash function, and because the feature data source is so large, we cannot enumerate all of the IDs before training. So we cannot initialize the table before training.

### Solution: ID auto growth

At the beginning of training, Paddle only allocates the memory for the lookup table on the parameter server side; no ID or value is initialized. During training, when a parameter server receives an ID, it returns the existing parameter if the ID is already in the lookup table; if the ID does not exist, Paddle adds it to the lookup table and initializes its value, as in the sketch below.
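A toy sketch of the auto-growth behaviour on a single parameter server (plain Python; `LookupShard` and its Gaussian initializer are assumptions for illustration, not the real implementation): a row is allocated and initialized the first time its ID is requested.

```python
import numpy as np

class LookupShard:
    """One parameter server's part of the table, grown lazily."""
    def __init__(self, emb_size, init_std=0.01):
        self.emb_size = emb_size
        self.init_std = init_std
        self.rows = {}  # id -> np.ndarray, filled on first access

    def get(self, feature_id):
        if feature_id not in self.rows:
            # First time this id is seen: allocate and initialize its value.
            self.rows[feature_id] = np.random.normal(
                0.0, self.init_std, self.emb_size).astype(np.float32)
        return self.rows[feature_id]

shard = LookupShard(emb_size=8)
vec = shard.get(123456789)   # created on demand
same = shard.get(123456789)  # returns the existing row
```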
### Problem 3: Parameter save and load

For common parameters, Paddle uses the trainer to save and load them. But for the distributed lookup table, the trainer cannot do this because of the table's large size.

### Solution: Parameter-server-side save and load

Paddle supports parameter-server-side save and load for the distributed lookup table. Each parameter server machine saves and loads only its part of the whole table, for example as sketched below.
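A hedged sketch of the per-shard save and load (plain Python; the file naming and the `save_shard`/`load_shard` helpers are made up for illustration): each parameter server persists only the rows it owns, so no single machine ever has to hold the whole table.

```python
import pickle

def save_shard(rows, server_id, path_prefix="lookup_table"):
    # Each parameter server writes only its own part of the table.
    with open(f"{path_prefix}.shard_{server_id}", "wb") as f:
        pickle.dump(rows, f)

def load_shard(server_id, path_prefix="lookup_table"):
    with open(f"{path_prefix}.shard_{server_id}", "rb") as f:
        return pickle.load(f)
```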
## Architecture

The whole architecture of the distributed lookup table is as below (a small schematic sketch follows the figure):

### Training steps:
1. Read a batch of data; the data are feature IDs.
1. The input IDs are split by `split_ids_op` with the same hash function as the lookup table.
1. `prefetch_op` uses the split result to prefetch parameter rows back from the lookup table.
1. Run forward-backward to get the gradients of the lookup table.
1. `split_ids_op` splits the gradients, which are then sent to the parameter servers by `send_op`.
1. The parameter servers update the table with the received gradients.

![distribute lookup table](./src/distributed_lookup_table.jpeg)
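Putting the steps above together, a schematic of one training iteration (plain Python, reusing the `split_ids` and `LookupShard` sketches above; `send_grads`, the zero gradients, and the SGD step are stand-ins for the real `send_op` and the server-side optimizer):

```python
def send_grads(shard, grads, lr=0.1):
    # Stand-in for send_op plus the server-side optimizer: plain SGD per row.
    for fid, g in grads.items():
        shard.rows[fid] -= lr * g

def train_one_batch(batch_ids, pservers):
    # Steps 1-2: split the batch of feature ids with the table's hash function.
    buckets = split_ids(batch_ids, num_pservers=len(pservers))
    # Step 3: prefetch only the needed rows from each parameter server.
    params = {fid: pservers[sid].get(fid)
              for sid, fids in buckets.items() for fid in fids}
    # Step 4: forward-backward would produce a gradient per prefetched row
    # (stubbed with zeros here).
    grads = {fid: 0.0 * vec for fid, vec in params.items()}
    # Steps 5-6: split the gradients the same way and send them for the update.
    for sid, fids in buckets.items():
        send_grads(pservers[sid], {fid: grads[fid] for fid in fids})

pservers = [LookupShard(emb_size=8), LookupShard(emb_size=8)]
train_one_batch([100000000007, 42, 7], pservers)
```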
## Distributed Storage Service

The forward algorithm requires a distributed storage service for W. The backward algorithm prefers that the storage system can apply the optimization algorithm on W. The following two sections describe two solutions -- the former doesn't require that the storage service can do optimization, the latter does.

### Storage Service Doesn't Optimize

In this design, we use highly-optimized distributed storage, e.g., memcached, as the storage service, and we run the optimization algorithm on parameter servers of PaddlePaddle. The following figure illustrates the training process.

<!--
Note: please update the following URL when updating this digraph.
<img src='https://g.gravizo.com/svg?
digraph G {
  rankdir="LR";
  subgraph cluster1 {
    P1 [label="pserver 1"];
    P2 [label="pserver 2"];
    T1 [label="trainer 1"];
    T2 [label="trainer 2"];
    T3 [label="trainer 3"];
  }
  KV [label="memcached"];
  T1 -> P1;
  T1 -> P2;
  T2 -> P1;
  T2 -> P2;
  T3 -> P1;
  T3 -> P2;
  P1 -> KV [color=gray, weight=0.1];
  KV -> P1 [color=gray, weight=0.1];
  P2 -> KV [color=gray, weight=0.1];
  KV -> P2 [color=gray, weight=0.1];
  KV -> T1 [color=gray, weight=0.1];
  KV -> T2 [color=gray, weight=0.1];
  KV -> T3 [color=gray, weight=0.1];
}
)
'/>
-->
<img src='https://g.gravizo.com/svg?%20digraph%20G%20{%20rankdir=%22LR%22;%20subgraph%20cluster1%20{%20P1%20[label=%22pserver%201%22];%20P2%20[label=%22pserver%202%22];%20T1%20[label=%22trainer%201%22];%20T2%20[label=%22trainer%202%22];%20T3%20[label=%22trainer%203%22];%20}%20KV%20[label=%22memcached%22];%20T1%20-%3E%20P1;%20T1%20-%3E%20P2;%20T2%20-%3E%20P1;%20T2%20-%3E%20P2;%20T3%20-%3E%20P1;%20T3%20-%3E%20P2;%20P1%20-%3E%20KV%20[color=gray,%20weight=0.1];%20KV%20-%3E%20P1%20[color=gray,%20weight=0.1];%20P2%20-%3E%20KV%20[color=gray,%20weight=0.1];%20KV%20-%3E%20P2%20[color=gray,%20weight=0.1];%20KV%20-%3E%20T1%20[color=gray,%20weight=0.1];%20KV%20-%3E%20T2%20[color=gray,%20weight=0.1];%20KV%20-%3E%20T3%20[color=gray,%20weight=0.1];%20}'/>
Each trainer runs the forward and backward passes using their local
data:
1. In the forward pass, when a trainer runs the forward algorithm of a
lookup operator, it retrieves W(x) from the storage service.
1. The trainer computes W'(x) in the backward pass using W(x).
During the global update process:
1. Each trainer uploads its W'(x) to parameter servers.
1. The parameter server runs the optimization algorithm, e.g., the
Adam optimization algorithm, which requires that
1. The parameter server retrieves W(x) from memcached, and
   1. The parameter server pushes $\Delta W(x)=f(W(x), \lambda \sum_j
      W'_j(x))$ to memcached, where $f$ denotes the optimization
      algorithm.
### Storage Service Does Optimize
This design is very similar to the above one, except that the
optimization algorithm $f$ runs on the storage service.
- Pro: parameter servers do not retrieve W(x) from the storage
  service, thus saving half of the network communication.
- Con: the storage service needs to be able to run the optimization
algorithm.
## Distributed Sparse Table in Fluid
As an alternative design, we can implement a distributed sparse table in Fluid,
so we don't need to maintain an external storage component while training.
You may need to read Fluid [Distributed Training Architecture](./distributed_architecture.md)
and [Parameter Server](./parameter_server.md) before going on.
![fluid lookup remote table](./src/fluid_lookup_remote_table.png)
Partition a large table into multiple pserver instances:
1. `DistributeTranspiler` would split the table into small
   table blocks with some partition algorithms such as
   [RoundRobin](https://en.wikipedia.org/wiki/Round-robin_scheduling),
   [Hash](https://en.wikipedia.org/wiki/Hash), etc.
1. In some cases, the range of the input `Ids` is very wide and unpredictable, so the sparse
   table should be able to fill in a new value for an id that did not appear before, initialized
   with zeros, a uniform distribution, or a Gaussian distribution.
For each Trainer's training process:
1. In the forward pass, we use the `pre-fetch` op, instead of the local `lookup_table` op, to
   pre-fetch parameter blocks from the PServers according to the input `Ids`, and then merge
   the blocks into the parameter `W`.
1. In the backward pass, compute `GRAD@W'` using the pre-fetched `W` and send it to the PServers to
   execute the optimize pass.
## Conclusion
Let us do the "storage service does not optimize" solution first, as a
baseline at least, because it is easier to use a well-optimized
distributed storage service like memcached. We can do the "storage
service does optimize" solution later or at the same time, which, if
implemented carefully, should have better performance than the former.
...@@ -46,9 +46,14 @@ cc_library(paddle_inference_api
  SRCS paddle_inference_api.cc paddle_inference_api_impl.cc
  DEPS ${FLUID_CORE_MODULES} ${GLOB_OP_LIB})
# Here the shared library doesn't depend on other fluid libraries, or a double free will occur.
cc_library(paddle_inference_api_shared SHARED
  SRCS paddle_inference_api.cc paddle_inference_api_impl.cc)
set_target_properties(paddle_inference_api_shared PROPERTIES OUTPUT_NAME paddle_inference_api)
if(NOT APPLE)
set(LINK_FLAGS "-fPIC -fvisibility=hidden")
set_target_properties(paddle_inference_api_shared PROPERTIES LINK_FLAGS "${LINK_FLAGS}")
endif()
cc_test(test_paddle_inference_api
  SRCS test_paddle_inference_api.cc
......
...@@ -23,7 +23,6 @@ int PaddleDtypeSize(PaddleDType dtype) {
    case PaddleDType::INT64:
      return sizeof(int64_t);
    default:
      assert(false);
      return -1;
  }
......
...@@ -34,7 +34,7 @@ struct BuildStrategy {
  std::string debug_graphviz_path_{""};
  bool enable_data_balance_{false};
};
}  // namespace details
......
...@@ -86,9 +86,9 @@ std::vector<std::array<int, 3>> DataBalanceOpHandle::GetBalancePlan(
}

void DataBalanceOpHandle::RunImpl() {
  PADDLE_ENFORCE_GT(places_.size(), 1,
                    "Data balance can only be enabled when the number of "
                    "places to run larger than 1.");
  auto in_var_handles = DynamicCast<VarHandle>(inputs_);
  auto out_var_handles = DynamicCast<VarHandle>(outputs_);
  PADDLE_ENFORCE(in_var_handles.size() % places_.size() == 0);
......
...@@ -59,6 +59,11 @@ MultiDevSSAGraphBuilder::MultiDevSSAGraphBuilder(
    grad_names_.insert(GradVarName(p));
  }
  balance_vars_.resize(places_.size(), 0);
if (strategy_.enable_data_balance_ && places_.size() == 1) {
LOG(WARNING) << "It is no need to enable data balance when there is only "
"one place. enable_data_balance is set to False.";
strategy_.enable_data_balance_ = false;
}
}

void MultiDevSSAGraphBuilder::CreateOpHandleIOs(SSAGraph *result,
......
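Since `enable_data_balance_` now defaults to false and is reset when there is only one place, data balance has to be requested explicitly from Python. A minimal usage sketch, mirroring the unit test further below (`main_prog` is assumed to be built elsewhere):

```python
import paddle.fluid as fluid

build_strategy = fluid.BuildStrategy()
build_strategy.enable_data_balance = True   # now off by default

# Data balance is only meaningful when ParallelExecutor runs on more than
# one place; with a single place the builder resets it to False and warns.
parallel_exe = fluid.ParallelExecutor(
    use_cuda=True,
    main_program=main_prog,            # assumed to be built elsewhere
    build_strategy=build_strategy)
```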
...@@ -182,21 +182,15 @@ struct OpKernelRegistrarFunctorEx<PlaceType, false, I,
   VarTypeInference
   InferShapeBase
*/
#define REGISTER_OPERATOR(op_type, op_class, ...)                        \
  STATIC_ASSERT_GLOBAL_NAMESPACE(                                        \
      __reg_op__##op_type,                                               \
      "REGISTER_OPERATOR must be called in global namespace");           \
  static ::paddle::framework::OperatorRegistrar<op_class, ##__VA_ARGS__> \
      __op_registrar_##op_type##__(#op_type);                            \
  int TouchOpRegistrar_##op_type() {                                     \
    __op_registrar_##op_type##__.Touch();                                \
    return 0;                                                            \
  }
#define REGISTER_OP_WITHOUT_GRADIENT(op_type, op_class, op_maker_class) \
......
...@@ -193,15 +193,10 @@ TEST(OpRegistry, CustomChecker) {
  ASSERT_EQ(test_attr, 4);
}
class CosineOpComplete : public paddle::framework::CosineOp {
public:
DEFINE_OP_CONSTRUCTOR(CosineOpComplete, paddle::framework::CosineOp);
DEFINE_OP_CLONE_METHOD(CosineOpComplete);
};
TEST(OperatorRegistrar, Test) {
  paddle::framework::OperatorRegistrar<
      paddle::framework::CosineOp,
      paddle::framework::CosineOpProtoAndCheckerMaker>
      reg("cos");
}
......
...@@ -633,6 +633,16 @@ void OperatorWithKernel::RunImpl(const Scope& scope,
  VLOG(3) << "expected_kernel_key:" << expected_kernel_key;

  auto kernel_iter = kernels.find(expected_kernel_key);
#ifdef PADDLE_WITH_MKLDNN
// workaround for missing MKLDNN kernel when FLAGS_use_mkldnn env var is set
if (kernel_iter == kernels.end() &&
expected_kernel_key.library_type_ == LibraryType::kMKLDNN) {
VLOG(3) << "missing MKLDNN kernel: fallbacking to PLAIN one";
expected_kernel_key.library_type_ = LibraryType::kPlain;
expected_kernel_key.data_layout_ = DataLayout::kAnyLayout;
kernel_iter = kernels.find(expected_kernel_key);
}
#endif
  if (kernel_iter == kernels.end()) {
    PADDLE_THROW("op %s does not have kernel for %s", type_,
                 KernelTypeToString(expected_kernel_key));
......
...@@ -121,10 +121,6 @@ class OperatorBase {
  //! Get all outputs variable names
  virtual std::vector<std::string> OutputVars(bool has_intermediate) const;
// Return a new operator instance, which is as same as this.
// Use unique_ptr to prevent caller forget to delete this pointer.
virtual std::unique_ptr<OperatorBase> Clone() const = 0;
 protected:
  std::string type_;
  // NOTE: in case of OpGrad, inputs_ contains:
...@@ -145,37 +141,6 @@ class OperatorBase {
                       const platform::Place& place) const = 0;
};
// Macro for define a clone method.
// If you are writing an kernel operator, `Clone` will be defined when you
// register it. i.e. `Clone` method is not needed to define by yourself.
#define DEFINE_OP_CLONE_METHOD(cls) \
std::unique_ptr<::paddle::framework::OperatorBase> Clone() const final { \
return std::unique_ptr<::paddle::framework::OperatorBase>(new cls(*this)); \
}
// Macro for define a default constructor for Operator.
// You can also use
// using PARENT_CLASS::PARENT_CLASS;
// to use parent's constructor.
#define DEFINE_OP_CONSTRUCTOR(cls, parent_cls) \
cls(const std::string& type, \
const ::paddle::framework::VariableNameMap& inputs, \
const ::paddle::framework::VariableNameMap& outputs, \
const paddle::framework::AttributeMap& attrs) \
: parent_cls(type, inputs, outputs, attrs) {}
class NOP : public OperatorBase {
public:
using OperatorBase::OperatorBase;
std::unique_ptr<OperatorBase> Clone() const override {
return std::unique_ptr<OperatorBase>(new NOP(*this));
}
private:
void RunImpl(const Scope& scope,
const platform::Place& place) const override {}
};
class ExecutionContext {
 public:
  ExecutionContext(const OperatorBase& op, const Scope& scope,
......
...@@ -247,26 +247,3 @@ TEST(OpKernel, multi_inputs) {
  auto op = paddle::framework::OpRegistry::CreateOp(op_desc);
  op->Run(scope, cpu_place);
}
class OperatorClone : public paddle::framework::OperatorBase {
public:
DEFINE_OP_CLONE_METHOD(OperatorClone);
OperatorClone(const std::string& type,
const paddle::framework::VariableNameMap& inputs,
const paddle::framework::VariableNameMap& outputs,
const paddle::framework::AttributeMap& attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const paddle::framework::Scope& scope,
const paddle::platform::Place& place) const override {}
};
TEST(Operator, Clone) {
paddle::framework::InitDevices(true);
OperatorClone a("ABC", paddle::framework::VariableNameMap{},
paddle::framework::VariableNameMap{},
paddle::framework::AttributeMap{});
auto b = a.Clone();
ASSERT_EQ(a.Type(), b->Type());
}
...@@ -22,6 +22,17 @@ limitations under the License. */
namespace paddle {
namespace framework {
class NOP : public OperatorBase {
public:
NOP(const std::string &type, const VariableNameMap &inputs,
const VariableNameMap &outputs, const AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const Scope &scope,
const platform::Place &place) const override {}
};
class SumOpMaker : public OpProtoAndCheckerMaker {
 public:
  void Make() {
......
...@@ -85,7 +85,7 @@ class EltwiseAddMKLDNNKernel : public framework::OpKernel<T> {
                   "Wrong layout/format set for X tensor");
    PADDLE_ENFORCE(y->layout() == DataLayout::kMKLDNN &&
                       y->format() != memory::format::format_undef,
                   "Wrong layout/format set for Y tensor");

    std::vector<int> src_x_tz = framework::vectorize2int(x_dims);
    std::vector<int> src_y_tz = framework::vectorize2int(y_dims);
......
...@@ -92,9 +92,13 @@ class ReadOpMaker : public framework::OpProtoAndCheckerMaker {
  void Make() override {
    AddInput("Reader", "(ReaderHolder) The executed reader.");
    AddOutput("Out", "(LoDTensor) The output data.").AsDuplicable();
    AddAttr<bool>(
        "throw_eof_exp",
        "If set true, an exception will be thrown when the Reader "
        "yields empty (which means there is no next data).\n"
        "NOTES: This flag must be true always. It will be set to false"
        " only when the data-balance is enabled in ParallelExecutor"
        " and it is set by ParallelExecutor instance, not users.")
        .SetDefault(true);
    AddComment(R"DOC(
Read Operator
......
...@@ -532,6 +532,7 @@ void TrainerThread::computeThread() {
      break;
    }
  }
  hl_fini();
}

void TrainerThread::prefetch() {
...@@ -651,6 +652,7 @@ void TrainerThread::copyGradToBufferThread() {
    }
    partnerThread->notifyGradientCollect(pid);
  }
  hl_fini();
}

void TrainerThread::gradCollectThread() {
...@@ -693,6 +695,7 @@ void TrainerThread::gradCollectThread() {
      notifyCopyGradToBuffer(pid);
    }
  }
  hl_fini();
}

void TrainerThread::doCallback(int pid) {
...@@ -741,6 +744,7 @@ void TrainerThread::valueDispatchThread() {
    thread->notifyValueReady(pid);
  }
  hl_fini();
}

void TrainerThread::notifyValueReady(int paramId) {
......
...@@ -197,6 +197,7 @@ void ParallelThread::computeThread() {
      job_work.layer_->markAllInputGrad();
    }
  }
  hl_fini();
}

void ParallelThread::start() {
......
...@@ -103,8 +103,12 @@ class TestDataBalance(unittest.TestCase):
        exe = fluid.Executor(place)
        exe.run(startup_prog)

        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_data_balance = True
        parallel_exe = fluid.ParallelExecutor(
            use_cuda=self.use_cuda,
            main_program=main_prog,
            build_strategy=build_strategy)

        if (parallel_exe.device_count > self.batch_size):
            print("WARNING: Unittest TestDataBalance skipped. \
...@@ -145,9 +149,12 @@ class TestDataBalance(unittest.TestCase):
        place = fluid.CUDAPlace(0) if self.use_cuda else fluid.CPUPlace()
        exe = fluid.Executor(place)
        exe.run(startup_prog)

        build_strategy = fluid.BuildStrategy()
        build_strategy.enable_data_balance = True
        parallel_exe = fluid.ParallelExecutor(
            use_cuda=self.use_cuda,
            main_program=main_prog,
            build_strategy=build_strategy)

        if (parallel_exe.device_count > self.batch_size):
            print("WARNING: Unittest TestDataBalance skipped. \
......