Commit f8ed2c22 authored by S sneaxiy

try to fix ci error

test=develop
Parent: 072d95d8
@@ -193,8 +193,15 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
   return shrink_func(computation_op);
 }

-static bool CanPrecede(const std::string &var_name,
-                       std::unordered_set<ComputationOpHandle *> *op_handles) {
+/**
+ * Shrink op dependencies. If some ops do not need the Tensor buffer of any input,
+ * just remove the dependency of this op, i.e., decrease the reference count.
+ *
+ * Returns whether the dependency count decreases to 0.
+ */
+static bool ShrinkNoNeedBufferVarOpDependency(
+    const std::string &var_name,
+    std::unordered_set<ComputationOpHandle *> *op_handles) {
   std::vector<ComputationOpHandle *> skip_ops;
   for (auto *op_handle : *op_handles) {
     auto *op_base = op_handle->GetOp();
@@ -303,8 +310,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
   VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;

   size_t original_op_deps = result.size();
-  // If reference count can be calculated precedingly, just precede
-  if (CanPrecede(var_name, &result)) {
+  // If all ops do not need the buffer of var_name, calculate the reference
+  // count of the previous version of var_name.
+  if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) {
     VLOG(10) << "Try to precede reference count computing at var "
              << var_name;
     continue;
......
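The renamed helper only reports success when every dependent op can be skipped. Below is a minimal, self-contained sketch of that shrink logic, using a hypothetical simplified `Op` type in place of Paddle's real `ComputationOpHandle`/`OpInfo` machinery:

```cpp
// Minimal sketch of the shrink logic above; Op and its no_need_buffer_vars
// field are simplified stand-ins, not Paddle's actual API.
#include <string>
#include <unordered_set>
#include <vector>

struct Op {
  // Inputs whose buffer this op never reads (only shape/metadata).
  std::unordered_set<std::string> no_need_buffer_vars;
};

// Remove every op that does not need var_name's buffer from the dependency
// set. Returns true if the set becomes empty, i.e. the reference count of
// the previous version of var_name drops to 0.
static bool ShrinkDeps(const std::string &var_name,
                       std::unordered_set<Op *> *op_handles) {
  std::vector<Op *> skip_ops;
  for (auto *op : *op_handles) {
    if (op->no_need_buffer_vars.count(var_name) != 0) {
      skip_ops.push_back(op);
    }
  }
  if (skip_ops.size() == op_handles->size()) {
    op_handles->clear();
    return true;
  }
  for (auto *op : skip_ops) {
    op_handles->erase(op);
  }
  return false;
}
```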
@@ -64,7 +64,9 @@ struct OpInOutInfo {
   }

  private:
+  // A set to record the unused-buffer input vars of the op
   std::unordered_set<std::string> no_need_buffer_ins_;
+  // A set to record the other args of the op (including in, out)
   std::unordered_set<std::string> other_args_set_;
   bool is_built_{false};
 };
@@ -91,6 +93,7 @@ std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
     const BlockDesc &block,
     const std::vector<std::unique_ptr<OperatorBase>> &ops,
     const std::vector<std::string> &skip_var_list) {
+  UseGarbageCollectorGFlags();
   std::unordered_set<std::string> skip_vars(skip_var_list.begin(),
                                             skip_var_list.end());
@@ -112,6 +115,7 @@ std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
       }

       if (info.IsInArgBufferNeeded(name)) {
+        // Update the last living op of the variable to the current op
         var_op_idx_map[name] = i;
       } else {
         VLOG(10) << "Skip reference count computing of variable "
@@ -124,6 +128,7 @@ std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
     for (auto &name_pair : op->Outputs()) {
       for (auto &name : name_pair.second) {
         if (VarCanBeDeleted(name, block, skip_vars)) {
+          // Update the last living op of the variable to the current op
          var_op_idx_map[name] = i;
         }
       }
......
@@ -25,11 +25,13 @@
 namespace paddle {
 namespace framework {

+// Result map: op -> variable names that can be deleted after the op runs
 std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
     const BlockDesc &block,
     const std::vector<std::unique_ptr<OperatorBase>> &ops,
     const std::vector<std::string> &skip_vars);

+// Collect unused tensors after the op runs
 void DeleteUnusedTensors(
     const Scope &scope, OperatorBase *op,
     const std::unordered_map<OperatorBase *, std::vector<std::string>>
......
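For intuition, here is a self-contained sketch of the algorithm `GetUnusedVars` implements in the hunks above, with simplified stand-in types rather than Paddle's `OperatorBase`/`BlockDesc`: record the index of the last op that touches each variable, then invert that map so each op lists the variables it can free right after running.

```cpp
// Self-contained sketch of the GetUnusedVars idea; SimpleOp is a
// hypothetical stand-in, not Paddle's OperatorBase.
#include <cstddef>
#include <string>
#include <unordered_map>
#include <vector>

struct SimpleOp {
  std::vector<std::string> inputs;
  std::vector<std::string> outputs;
};

std::unordered_map<std::size_t, std::vector<std::string>> GetUnusedVarsSketch(
    const std::vector<SimpleOp> &ops) {
  // var name -> index of the last op that reads or writes it
  std::unordered_map<std::string, std::size_t> last_use;
  for (std::size_t i = 0; i < ops.size(); ++i) {
    for (const auto &in : ops[i].inputs) last_use[in] = i;
    for (const auto &out : ops[i].outputs) last_use[out] = i;
  }
  // Invert: op index -> variables that become garbage after that op runs.
  std::unordered_map<std::size_t, std::vector<std::string>> result;
  for (const auto &pair : last_use) {
    result[pair.second].push_back(pair.first);
  }
  return result;
}
```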
@@ -13,6 +13,11 @@
 // limitations under the License.

 #include <algorithm>
+#include <deque>
+#include <functional>
+#include <memory>
+#include <mutex>  // NOLINT
+#include <utility>
 #ifdef PADDLE_WITH_CUDA
 #include "paddle/fluid/platform/cuda_device_guard.h"
 #endif
@@ -21,6 +26,15 @@
 namespace paddle {
 namespace framework {

+DEFINE_double(
+    eager_delete_tensor_gb, -1.0,
+    "Memory size threshold (GB) at which the garbage collector clears "
+    "tensors. Disabled when this value is less than 0");
+
+DEFINE_bool(fast_eager_deletion_mode, true,
+            "Fast eager deletion mode. If enabled, memory would be released "
+            "immediately, without waiting for the GPU kernel to end.");
+
 GarbageCollector::GarbageCollector(const platform::Place &place,
                                    size_t max_memory_size)
     : max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
@@ -85,5 +99,16 @@ void StreamGarbageCollector::ClearCallback(
   callback_manager_->AddCallback(callback);
 }
 #endif

+void UseGarbageCollectorGFlags() {}
+
+int64_t GetEagerDeletionThreshold() {
+  return FLAGS_eager_delete_tensor_gb < 0
+             ? -1
+             : static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
+                                    (static_cast<int64_t>(1) << 30));
+}
+
+bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
+
 }  // namespace framework
 }  // namespace paddle
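A worked example of the GB-to-bytes conversion above (a standalone demo, not part of the patch): any negative flag value disables eager deletion, and non-negative values are scaled by 2^30.

```cpp
// Standalone demo of the threshold conversion used in
// GetEagerDeletionThreshold above.
#include <cstdint>
#include <iostream>

int64_t ThresholdBytes(double gb) {
  return gb < 0 ? -1
                : static_cast<int64_t>(gb * (static_cast<int64_t>(1) << 30));
}

int main() {
  std::cout << ThresholdBytes(-1.0) << "\n";  // -1: eager deletion disabled
  std::cout << ThresholdBytes(0.0) << "\n";   // 0: collect as eagerly as possible
  std::cout << ThresholdBytes(0.5) << "\n";   // 536870912 (0.5 * 2^30 bytes)
  return 0;
}
```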
@@ -18,6 +18,7 @@
 #include <functional>
 #include <memory>
 #include <mutex>  // NOLINT
+#include <utility>
 #include "paddle/fluid/platform/device_context.h"

 namespace paddle {
@@ -126,5 +127,10 @@ void GarbageCollector::Add(Container &&objs, Callback &&callback) {
   }
 }

+int64_t GetEagerDeletionThreshold();
+bool IsFastEagerDeletionModeEnabled();
+
+extern void UseGarbageCollectorGFlags();
+
 }  // namespace framework
 }  // namespace paddle
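Judging from its empty body and the explicit call added in pybind.cc further below, `UseGarbageCollectorGFlags()` appears to follow the same pattern as the existing `UseAllocatorStrategyGFlag()`: an empty, externally referenced function that prevents the linker from dropping the translation unit where the `FLAGS_*` definitions live, so the flags actually get registered. A generic sketch of that pattern, with purely illustrative names:

```cpp
// flags.cc -- defines the flags plus an empty "anchor" function.
#include <gflags/gflags.h>

DEFINE_bool(my_feature_flag, false, "demo flag");  // illustrative flag

void UseMyFlags() {}  // no-op; exists only to be referenced externally

// main.cc -- referencing the anchor guarantees flags.o is linked in,
// so FLAGS_my_feature_flag is registered before flags are parsed.
extern void UseMyFlags();

int main(int argc, char **argv) {
  UseMyFlags();
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  return 0;
}
```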
@@ -29,15 +29,6 @@ DEFINE_bool(
     "Delete local scope eagerly. It will reduce GPU memory usage but "
     "slow down the destruction of variables (around 1% performance harm).");

-DEFINE_double(
-    eager_delete_tensor_gb, -1.0,
-    "Memory size threshold (GB) at which the garbage collector clears "
-    "tensors. Disabled when this value is less than 0");
-
-DEFINE_bool(fast_eager_deletion_mode, true,
-            "Fast eager deletion mode. If enabled, memory would be released "
-            "immediately, without waiting for the GPU kernel to end.");
-
 // When in inference scenario, the scopes will not be written by two threads in
 // a mean time, but a scope may be read by multiple threads concurrently, and
 // the mutex will cause serious performance issue.
@@ -57,15 +48,6 @@ DEFINE_bool(fast_eager_deletion_mode, true,
 namespace paddle {
 namespace framework {

-int64_t GetEagerDeletionThreshold() {
-  return FLAGS_eager_delete_tensor_gb < 0
-             ? -1
-             : static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
-                                    (static_cast<int64_t>(1) << 30));
-}
-
-bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
-
 Scope::~Scope() { DropKids(); }

 Scope& Scope::NewScope() const {
......
@@ -32,9 +32,6 @@ extern "C" {
 namespace paddle {
 namespace framework {

-int64_t GetEagerDeletionThreshold();
-bool IsFastEagerDeletionModeEnabled();
-
 class Scope;

 /**
......
@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/concat_op.h"
+#include <memory>
 #include <string>
 #include <vector>
@@ -120,11 +121,7 @@ Examples:
 class ConcatOpGrad : public framework::OperatorWithKernel {
  public:
-  ConcatOpGrad(const std::string &type,
-               const framework::VariableNameMap &inputs,
-               const framework::VariableNameMap &outputs,
-               const framework::AttributeMap &attrs)
-      : OperatorWithKernel(type, inputs, outputs, attrs) {}
+  using framework::OperatorWithKernel::OperatorWithKernel;

   void InferShape(framework::InferShapeContext *ctx) const override {
     auto in_x = "X";
@@ -142,8 +139,19 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
       }
     }
   }
+
+ protected:
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.GetPlace());
+  }
 };

+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ConcatOpGradNoNeedBufferVarInference,
+                                      "X");
+
 }  // namespace operators
 }  // namespace paddle
@@ -151,7 +159,8 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<
                       false> /* set false to disable empty grad */);
-REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad);
+REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
+                  ops::ConcatOpGradNoNeedBufferVarInference);
 REGISTER_OP_CPU_KERNEL(
     concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
     ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
......
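The two concat changes work together: the grad kernel's dtype now comes from `Out@GRAD` instead of `X`, and `X` is declared a no-need-buffer input, so the gradient op only ever reads `X`'s shape and the garbage collector may free `X`'s buffer earlier. A simplified 1-D illustration (our own demo, not Paddle code) of why the data of `X` is never needed:

```cpp
// Simplified 1-D demo of concat_grad: splitting dOut back into per-input
// gradients needs each X[i]'s size along the concat axis, never X[i]'s data.
#include <cstddef>
#include <vector>

std::vector<std::vector<float>> ConcatGrad1D(
    const std::vector<std::size_t> &x_sizes,  // shapes of the forward inputs
    const std::vector<float> &d_out) {        // gradient of the concat output
  std::vector<std::vector<float>> d_xs;
  std::size_t offset = 0;
  for (std::size_t n : x_sizes) {
    // Copy the slice of dOut belonging to input i; no X[i] data is touched.
    d_xs.emplace_back(d_out.begin() + offset, d_out.begin() + offset + n);
    offset += n;
  }
  return d_xs;
}
```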
@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/crop_op.h"
-#include <boost/lexical_cast.hpp>
+#include <memory>
+#include <string>
+#include <vector>

 namespace paddle {
 namespace operators {
@@ -178,12 +180,31 @@ class CropOpGrad : public framework::OperatorWithKernel {
   }
 };

+class CropGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("crop_grad");
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetInput("X", Input("X"));
+    if (ForwardOp().Inputs().count("Offsets") > 0) {
+      op->SetInput("Offsets", Input("Offsets"));
+    }
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::CropGradOpDescMaker);
 REGISTER_OPERATOR(crop_grad, ops::CropOpGrad);
 REGISTER_OP_CPU_KERNEL(
     crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>);
......
@@ -14,6 +14,7 @@
 #include <set>
 #include <string>
+#include <unordered_map>
 #include <vector>

 #include "paddle/fluid/operators/distributed/parameter_prefetch.h"
@@ -218,7 +219,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
                  boost::get<platform::CUDAPlace>(id_tensor.place()),
                  id_tensor.data<int64_t>(), sizeof(int64_t) * id_tensor.numel(),
                  stream);
-    for (size_t i = 0; i < cpu_tensor.numel(); ++i) {
+    for (int64_t i = 0; i < cpu_tensor.numel(); ++i) {
       ids_vector.push_back(cpu_tensor_data[i]);
     }
 #endif
......
@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #include "paddle/fluid/operators/gather_op.h"
+#include <memory>
+#include <string>
+#include <vector>
 #include "paddle/fluid/framework/ddim.h"

 namespace paddle {
@@ -59,8 +62,9 @@ class GatherGradOp : public framework::OperatorWithKernel {
  protected:
   framework::OpKernelType GetExpectedKernelType(
       const framework::ExecutionContext& ctx) const override {
-    return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
-                                   ctx.device_context());
+    return framework::OpKernelType(
+        ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
+        ctx.device_context());
   }
 };
@@ -94,13 +98,34 @@ Out = [[3, 4],
 )DOC");
   }
 };

+class GatherGradOpDescMaker : public framework::SingleGradOpDescMaker {
+ public:
+  using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
+
+ protected:
+  std::unique_ptr<framework::OpDesc> Apply() const override {
+    std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
+    op->SetType("gather_grad");
+    op->SetInput("Index", Input("Index"));
+    op->SetInput("X", Input("X"));
+    op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
+    op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
+    op->SetAttrMap(Attrs());
+    return op;
+  }
+};
+
+DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GatherGradNoNeedBufferVarInference, "X");
+
 }  // namespace operators
 }  // namespace paddle

 namespace ops = paddle::operators;
 REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
-                  paddle::framework::DefaultGradOpDescMaker<true>);
+                  ops::GatherGradOpDescMaker);
-REGISTER_OPERATOR(gather_grad, ops::GatherGradOp);
+REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
+                  ops::GatherGradNoNeedBufferVarInference);
 REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
                        ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
                        ops::GatherOpKernel<uint8_t>,
......
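Gather gets the same treatment as concat: `dX` is produced from `Index` and `Out@GRAD` alone, while `X` contributes only its shape, which is why `X` can be a no-need-buffer input and the kernel dtype is derived from `Out@GRAD`. A simplified 1-D demo (our own illustration, not Paddle code):

```cpp
// Simplified 1-D demo of gather_grad's data flow: scatter-add dOut back
// into dX at the gathered positions; X's data is never read.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> GatherGrad1D(std::size_t x_size,  // shape of X, not its data
                                const std::vector<int64_t> &index,
                                const std::vector<float> &d_out) {
  std::vector<float> d_x(x_size, 0.0f);
  for (std::size_t i = 0; i < index.size(); ++i) {
    d_x[static_cast<std::size_t>(index[i])] += d_out[i];  // scatter-add
  }
  return d_x;
}
```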
@@ -21,6 +21,7 @@
 #include <cstdlib>
 #include <fstream>
 #include <iostream>
+#include <memory>
 #include <sstream>
 #include <string>
 #include <unordered_map>
@@ -152,7 +153,7 @@ class CTRReader : public framework::FileReader {
     queue_->ReOpen();
     VLOG(3) << "reopen success";
     VLOG(3) << "thread_num " << thread_num_;
-    for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
+    for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) {
       read_threads_.emplace_back(new std::thread(std::bind(
           &ReadThread, file_groups_[thread_id], data_desc_,
           static_cast<int>(thread_id), &read_thread_status_, queue_)));
......
@@ -24,6 +24,7 @@ limitations under the License. */
 #include "paddle/fluid/framework/executor.h"
 #include "paddle/fluid/framework/feed_fetch_method.h"
 #include "paddle/fluid/framework/framework.pb.h"
+#include "paddle/fluid/framework/garbage_collector.h"
 #include "paddle/fluid/framework/ir/pass_builder.h"
 #include "paddle/fluid/framework/lod_rank_table.h"
 #include "paddle/fluid/framework/lod_tensor.h"
@@ -133,6 +134,8 @@ PYBIND11_MODULE(core, m) {
   paddle::platform::CpuTotalPhysicalMemory();

   paddle::memory::allocation::UseAllocatorStrategyGFlag();
+  paddle::framework::UseGarbageCollectorGFlags();
+
   m.doc() = "C++ core of PaddlePaddle";

   // using framework in this function. Since it is inside a function, it will
......
@@ -15,6 +15,9 @@
 import os
 import numpy as np

 os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
+os.environ['FLAGS_fast_eager_deletion_mode'] = '1'
+os.environ['FLAGS_use_ngraph'] = '0'
+os.environ['FLAGS_use_mkldnn'] = '0'
 os.environ['CPU_NUM'] = '4'

 import paddle.fluid as fluid
@@ -58,18 +61,24 @@ def get_persistables_and_non_persistables(prog, fetch_list):
 class TestExecutor(unittest.TestCase):
-    def setUp(self):
-        self.place = fluid.CPUPlace()
-
     def test_executor_main(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            with fluid.scope_guard(fluid.Scope()):
-                self.executor_main()
-
-    def test_parallel_executor_main(self):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            with fluid.scope_guard(fluid.Scope()):
-                self.pe_main()
+        places = [fluid.CPUPlace()]
+        if fluid.core.is_compiled_with_cuda():
+            places.append(fluid.CUDAPlace(0))
+
+        for p in places:
+            self.place = p
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                with fluid.scope_guard(fluid.Scope()):
+                    with fluid.unique_name.guard():
+                        self.executor_main()
+
+        for p in places:
+            self.place = p
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                with fluid.scope_guard(fluid.Scope()):
+                    with fluid.unique_name.guard():
+                        self.pe_main()

     def prepare_feed(self, image, label, dev_cnt=1):
         batch_size = 32 * dev_cnt
@@ -83,25 +92,36 @@ class TestExecutor(unittest.TestCase):
         return image_np, label_np

     def assertScopeVar(self, scope, persitables, non_persistables):
+        outline_p_vars = []
         for name in persitables:
             var = scope.find_var(name)
             self.assertTrue(var is not None)
             t = var.get_tensor()
-            self.assertTrue(t._is_initialized())
+            if not t._is_initialized():
+                outline_p_vars.append(name)

+        outline_np_vars = []
         for name in non_persistables:
             var = scope.find_var(name)
             self.assertTrue(var is not None)
             t = var.get_tensor()
             if t._is_initialized():
-                print('WARNING: Variable {} is alive'.format(name))
-            self.assertTrue(not t._is_initialized())
+                outline_np_vars.append(name)
+
+        print('Non-alive persistable vars {} in {}'.format(outline_p_vars,
+                                                           persitables))
+        print('Alive non-persistable vars {} in {}'.format(outline_np_vars,
+                                                           non_persistables))
+        self.assertEqual(len(outline_p_vars), 0)
+        self.assertEqual(len(outline_np_vars), 0)

     def executor_main(self):
         image, label, loss = simple_fc_net()
         loss.persistable = False
         persistables, non_persistables = get_persistables_and_non_persistables(
             fluid.default_main_program(), [loss.name])
+        print('Non-persistable var number {}'.format(len(non_persistables)))
+        print(non_persistables)

         exe = fluid.Executor(self.place)
         exe.run(fluid.default_startup_program())
@@ -135,6 +155,10 @@ class TestExecutor(unittest.TestCase):
         exec_strategy = fluid.ExecutionStrategy()
         exec_strategy.num_iteration_per_drop_scope = 100

+        build_strategy = fluid.BuildStrategy()
+        build_strategy.memory_optimize = False
+        build_strategy.enable_inplace = False
+
         prog = fluid.CompiledProgram(fluid.default_main_program(
         )).with_data_parallel(
             loss_name=loss.name, exec_strategy=exec_strategy)
@@ -155,11 +179,5 @@ class TestExecutor(unittest.TestCase):
         self.assertScopeVar(kids[0], persistables, non_persistables)

-class TestExecutor2(TestExecutor):
-    def setUp(self):
-        self.place = fluid.CPUPlace() if not fluid.core.is_compiled_with_cuda() \
-            else fluid.CUDAPlace(0)
-

 if __name__ == '__main__':
     unittest.main()