Commit f8ed2c22 authored by sneaxiy

try to fix ci error

test=develop
Parent 072d95d8
......@@ -193,8 +193,15 @@ ExtractComputationOpFromLastLivedVar(VarHandle *var, size_t scope_idx,
return shrink_func(computation_op);
}
static bool CanPrecede(const std::string &var_name,
std::unordered_set<ComputationOpHandle *> *op_handles) {
/**
* Shrink op dependencies. If some ops do not need the Tensor buffer of any input,
* just remove the dependency of these ops, i.e., decrease the reference count.
*
* Returns whether the dependency count decreases to 0.
*/
static bool ShrinkNoNeedBufferVarOpDependency(
const std::string &var_name,
std::unordered_set<ComputationOpHandle *> *op_handles) {
std::vector<ComputationOpHandle *> skip_ops;
for (auto *op_handle : *op_handles) {
auto *op_base = op_handle->GetOp();
......@@ -303,8 +310,9 @@ std::unique_ptr<ir::Graph> ReferenceCountPass::ApplyImpl(
VLOG(10) << "Extract " << result.size() << " ops of var " << var_name;
size_t original_op_deps = result.size();
// If reference count can be calculated precedingly, just precede
if (CanPrecede(var_name, &result)) {
// If no op needs the buffer of var_name, calculate the reference count
// of the previous version of var_name instead.
if (ShrinkNoNeedBufferVarOpDependency(var_name, &result)) {
VLOG(10) << "Try to precede reference count computing at var "
<< var_name;
continue;
......
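The helper renamed from CanPrecede makes the intent explicit: drop every last-lived op that never reads the variable's Tensor buffer, and report whether the dependency count reaches zero so the reference count can be attached to the previous version of the variable. A minimal standalone sketch of that logic follows; FakeOp and its no_need_buffer_ins field are illustrative stand-ins, not the real ComputationOpHandle API.
#include <string>
#include <unordered_set>
#include <vector>

// Illustrative stand-in for a graph op node; the real pass works on
// ComputationOpHandle and the framework's no-need-buffer-vars inference.
struct FakeOp {
  std::unordered_set<std::string> no_need_buffer_ins;  // inputs whose buffer is never read
};

// Remove ops that only need the metadata (shape/lod) of var_name, and return
// whether no dependency is left, i.e. the reference count can be moved to the
// previous version of var_name.
static bool ShrinkNoNeedBufferVarOpDependency(
    const std::string &var_name, std::unordered_set<FakeOp *> *op_handles) {
  std::vector<FakeOp *> skip_ops;
  for (auto *op : *op_handles) {
    if (op->no_need_buffer_ins.count(var_name)) skip_ops.push_back(op);
  }
  for (auto *op : skip_ops) op_handles->erase(op);
  return op_handles->empty();
}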
......@@ -64,7 +64,9 @@ struct OpInOutInfo {
}
private:
// A set to record unused buffer input vars of op
std::unordered_set<std::string> no_need_buffer_ins_;
// A set to record other args of op (including in, out)
std::unordered_set<std::string> other_args_set_;
bool is_built_{false};
};
......@@ -91,6 +93,7 @@ std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
const BlockDesc &block,
const std::vector<std::unique_ptr<OperatorBase>> &ops,
const std::vector<std::string> &skip_var_list) {
UseGarbageCollectorGFlags();
std::unordered_set<std::string> skip_vars(skip_var_list.begin(),
skip_var_list.end());
......@@ -112,6 +115,7 @@ std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
}
if (info.IsInArgBufferNeeded(name)) {
// Update the last living op of variable to current op
var_op_idx_map[name] = i;
} else {
VLOG(10) << "Skip reference count computing of variable "
......@@ -124,6 +128,7 @@ std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
for (auto &name_pair : op->Outputs()) {
for (auto &name : name_pair.second) {
if (VarCanBeDeleted(name, block, skip_vars)) {
// Update the last living op of variable to current op
var_op_idx_map[name] = i;
}
}
......
......@@ -25,11 +25,13 @@
namespace paddle {
namespace framework {
// Result map: op -> variable names that can be deleted after op runs
std::unordered_map<OperatorBase *, std::vector<std::string>> GetUnusedVars(
const BlockDesc &block,
const std::vector<std::unique_ptr<OperatorBase>> &ops,
const std::vector<std::string> &skip_vars);
// Collect unused tensors after op runs
void DeleteUnusedTensors(
const Scope &scope, OperatorBase *op,
const std::unordered_map<OperatorBase *, std::vector<std::string>>
......
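GetUnusedVars is computed once per block and maps each op to the variables whose last reader it is; DeleteUnusedTensors then consumes one entry after each op runs (the trailing garbage-collector parameter of DeleteUnusedTensors is cut off in this hunk). A self-contained toy analogue of that control flow, with made-up op and variable names:
#include <iostream>
#include <map>
#include <string>
#include <vector>

int main() {
  // Toy mirror of the GetUnusedVars result: after each "op" runs, the listed
  // variables have no later readers and may be reclaimed immediately.
  std::vector<std::string> ops = {"mul", "relu", "softmax"};
  std::map<std::string, std::vector<std::string>> unused_after = {
      {"mul", {}}, {"relu", {"mul_out"}}, {"softmax", {"relu_out"}}};

  for (const auto &op : ops) {
    // ... run the op here ...
    for (const auto &var : unused_after[op]) {
      std::cout << "after " << op << ": delete " << var << '\n';  // DeleteUnusedTensors analogue
    }
  }
  return 0;
}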
......@@ -13,6 +13,11 @@
// limitations under the License.
#include <algorithm>
#include <deque>
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <utility>
#ifdef PADDLE_WITH_CUDA
#include "paddle/fluid/platform/cuda_device_guard.h"
#endif
......@@ -21,6 +26,15 @@
namespace paddle {
namespace framework {
DEFINE_double(
eager_delete_tensor_gb, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
DEFINE_bool(fast_eager_deletion_mode, true,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends.");
GarbageCollector::GarbageCollector(const platform::Place &place,
size_t max_memory_size)
: max_memory_size_((std::max)(max_memory_size, static_cast<size_t>(1))) {
......@@ -85,5 +99,16 @@ void StreamGarbageCollector::ClearCallback(
callback_manager_->AddCallback(callback);
}
#endif
void UseGarbageCollectorGFlags() {}
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
: static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
(static_cast<int64_t>(1) << 30));
}
bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
} // namespace framework
} // namespace paddle
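GetEagerDeletionThreshold converts FLAGS_eager_delete_tensor_gb from GB to bytes by multiplying with 1 << 30, and keeps -1 as the "disabled" sentinel. A quick standalone check of that arithmetic (the flag values below are just examples):
#include <cstdint>
#include <iostream>

int64_t ThresholdBytes(double gb) {
  return gb < 0 ? -1
                : static_cast<int64_t>(gb * (static_cast<int64_t>(1) << 30));
}

int main() {
  std::cout << ThresholdBytes(-1.0) << '\n';  // -1        -> eager deletion disabled
  std::cout << ThresholdBytes(0.0) << '\n';   // 0         -> delete as early as possible
  std::cout << ThresholdBytes(0.5) << '\n';   // 536870912 -> 0.5 GB threshold
  return 0;
}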
......@@ -18,6 +18,7 @@
#include <functional>
#include <memory>
#include <mutex> // NOLINT
#include <utility>
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
......@@ -126,5 +127,10 @@ void GarbageCollector::Add(Container &&objs, Callback &&callback) {
}
}
int64_t GetEagerDeletionThreshold();
bool IsFastEagerDeletionModeEnabled();
extern void UseGarbageCollectorGFlags();
} // namespace framework
} // namespace paddle
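UseGarbageCollectorGFlags() is deliberately empty: referencing it from other translation units (GetUnusedVars above, and the pybind module below) forces the linker to keep garbage_collector.cc, which now owns the DEFINE_double/DEFINE_bool flag definitions moved out of scope.cc, even when the framework is linked as a static library. The same trick is already used via UseAllocatorStrategyGFlag. A minimal illustration of the pattern with made-up file and flag names:
// flags.cc -- owns the flag definition plus an empty "anchor" function.
#include <gflags/gflags.h>
DEFINE_double(my_threshold_gb, -1.0, "Illustrative flag, not a real Paddle flag.");
void UseMyFlags() {}  // referenced elsewhere only to keep this object file linked

// user.cc -- referencing the anchor pulls flags.cc into the final binary,
// so FLAGS_my_threshold_gb is guaranteed to be defined at runtime.
#include <gflags/gflags.h>
DECLARE_double(my_threshold_gb);
extern void UseMyFlags();

double ReadThreshold() {
  UseMyFlags();
  return FLAGS_my_threshold_gb;
}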
......@@ -29,15 +29,6 @@ DEFINE_bool(
"Delete local scope eagerly. It will reduce GPU memory usage but "
"slow down the destruction of variables.(around 1% performance harm)");
DEFINE_double(
eager_delete_tensor_gb, -1.0,
"Memory size threshold (GB) when the garbage collector clear tensors."
"Disabled when this value is less than 0");
DEFINE_bool(fast_eager_deletion_mode, true,
"Fast eager deletion mode. If enabled, memory would release "
"immediately without waiting GPU kernel ends.");
// In inference scenarios, the scopes will not be written by two threads at
// the same time, but a scope may be read by multiple threads concurrently,
// and the mutex will cause serious performance issues.
......@@ -57,15 +48,6 @@ DEFINE_bool(fast_eager_deletion_mode, true,
namespace paddle {
namespace framework {
int64_t GetEagerDeletionThreshold() {
return FLAGS_eager_delete_tensor_gb < 0
? -1
: static_cast<int64_t>(FLAGS_eager_delete_tensor_gb *
(static_cast<int64_t>(1) << 30));
}
bool IsFastEagerDeletionModeEnabled() { return FLAGS_fast_eager_deletion_mode; }
Scope::~Scope() { DropKids(); }
Scope& Scope::NewScope() const {
......
......@@ -32,9 +32,6 @@ extern "C" {
namespace paddle {
namespace framework {
int64_t GetEagerDeletionThreshold();
bool IsFastEagerDeletionModeEnabled();
class Scope;
/**
......
......@@ -13,6 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/concat_op.h"
#include <memory>
#include <string>
#include <vector>
......@@ -120,11 +121,7 @@ Examples:
class ConcatOpGrad : public framework::OperatorWithKernel {
public:
ConcatOpGrad(const std::string &type,
const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext *ctx) const override {
auto in_x = "X";
......@@ -142,8 +139,19 @@ class ConcatOpGrad : public framework::OperatorWithKernel {
}
}
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.GetPlace());
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(ConcatOpGradNoNeedBufferVarInference,
"X");
} // namespace operators
} // namespace paddle
......@@ -151,7 +159,8 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(concat, ops::ConcatOp, ops::ConcatOpMaker,
paddle::framework::DefaultGradOpDescMaker<
false> /* set false to disable empty grad */);
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad);
REGISTER_OPERATOR(concat_grad, ops::ConcatOpGrad,
ops::ConcatOpGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(
concat, ops::ConcatKernel<paddle::platform::CPUDeviceContext, double>,
ops::ConcatKernel<paddle::platform::CPUDeviceContext, float>,
......
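The two concat_grad changes work together: GetExpectedKernelType now infers the kernel data type from Out@GRAD rather than X, and DECLARE_NO_NEED_BUFFER_VARS_INFERENCE declares that concat_grad only needs the metadata (shape/lod) of X, so X's data buffer may be freed right after the forward op. Conceptually, the declaration contributes an entry to a per-op-type table like the hand-written sketch below (the registry and helper are illustrative, not the framework's actual data structures):
#include <string>
#include <unordered_map>
#include <unordered_set>

// Illustrative registry: op type -> input slots whose Tensor buffer is never read.
// DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(..., "X") effectively contributes entries
// such as {"concat_grad", {"X"}}.
static const std::unordered_map<std::string, std::unordered_set<std::string>>
    kNoNeedBufferIns = {{"concat_grad", {"X"}}, {"gather_grad", {"X"}}};

// If every remaining consumer of a variable only needs its metadata, the buffer
// can be reclaimed right after the forward op instead of after the backward op.
bool BufferNeededBy(const std::string &op_type, const std::string &input_slot) {
  auto it = kNoNeedBufferIns.find(op_type);
  return it == kNoNeedBufferIns.end() || it->second.count(input_slot) == 0;
}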
......@@ -13,7 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/crop_op.h"
#include <boost/lexical_cast.hpp>
#include <memory>
#include <string>
#include <vector>
namespace paddle {
namespace operators {
......@@ -178,12 +180,31 @@ class CropOpGrad : public framework::OperatorWithKernel {
}
};
class CropGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("crop_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetInput("X", Input("X"));
if (ForwardOp().Inputs().count("Offsets") > 0) {
op->SetInput("Offsets", Input("Offsets"));
}
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(crop, ops::CropOp, ops::CropOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
ops::CropGradOpDescMaker);
REGISTER_OPERATOR(crop_grad, ops::CropOpGrad);
REGISTER_OP_CPU_KERNEL(
crop, ops::CropKernel<paddle::platform::CPUDeviceContext, float>);
......
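CropGradOpDescMaker replaces DefaultGradOpDescMaker<true>, so the generated crop_grad op only receives the variables its backward actually uses (Out@GRAD, X, and Offsets when the forward op had it), instead of every forward input and output. A toy sketch of the OpDesc it ends up producing; ToyOpDesc and the variable names are illustrative only:
#include <map>
#include <memory>
#include <string>
#include <vector>

// Toy stand-in for framework::OpDesc, just to show what Apply() wires up.
struct ToyOpDesc {
  std::string type;
  std::map<std::string, std::vector<std::string>> inputs, outputs;
};

std::unique_ptr<ToyOpDesc> MakeCropGrad(bool forward_had_offsets) {
  auto op = std::make_unique<ToyOpDesc>();
  op->type = "crop_grad";
  op->inputs["Out@GRAD"] = {"out@GRAD"};
  op->inputs["X"] = {"x"};  // needed for the shape of X@GRAD
  if (forward_had_offsets) op->inputs["Offsets"] = {"offsets"};  // optional forward input
  op->outputs["X@GRAD"] = {"x@GRAD"};
  return op;  // attributes are copied from the forward op in the real maker
}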
......@@ -14,6 +14,7 @@
#include <set>
#include <string>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/distributed/parameter_prefetch.h"
......@@ -218,7 +219,7 @@ void prefetch(const std::string& id_name, const std::string& out_name,
boost::get<platform::CUDAPlace>(id_tensor.place()),
id_tensor.data<int64_t>(), sizeof(int64_t) * id_tensor.numel(),
stream);
for (size_t i = 0; i < cpu_tensor.numel(); ++i) {
for (int64_t i = 0; i < cpu_tensor.numel(); ++i) {
ids_vector.push_back(cpu_tensor_data[i]);
}
#endif
......
......@@ -13,6 +13,9 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/gather_op.h"
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/ddim.h"
namespace paddle {
......@@ -59,8 +62,9 @@ class GatherGradOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return framework::OpKernelType(ctx.Input<Tensor>("X")->type(),
ctx.device_context());
return framework::OpKernelType(
ctx.Input<Tensor>(framework::GradVarName("Out"))->type(),
ctx.device_context());
}
};
......@@ -94,13 +98,34 @@ Out = [[3, 4],
)DOC");
}
};
class GatherGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType("gather_grad");
op->SetInput("Index", Input("Index"));
op->SetInput("X", Input("X"));
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
return op;
}
};
DECLARE_NO_NEED_BUFFER_VARS_INFERENCE(GatherGradNoNeedBufferVarInference, "X");
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp);
ops::GatherGradOpDescMaker);
REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
ops::GatherGradNoNeedBufferVarInference);
REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
ops::GatherOpKernel<uint8_t>,
......
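The gather_grad changes follow the same pattern as concat_grad: a dedicated GatherGradOpDescMaker, a no-need-buffer declaration for X, and a kernel type derived from Out@GRAD instead of X. The last point matters because once X's buffer can be eagerly deleted, reading its dtype in GetExpectedKernelType would touch an unallocated tensor; Out@GRAD is always materialized when the backward op runs. A hedged standalone sketch of that choice (ToyTensor is illustrative):
// Illustrative tensor whose buffer may already have been eagerly deleted.
struct ToyTensor {
  bool has_data = false;
  int dtype = -1;  // only meaningful when has_data is true
};

// Before the patch the grad kernel dtype was read from x, which may be
// buffer-free under eager deletion; after the patch it is read from out_grad,
// which always holds data when the backward op runs.
int PickGradKernelDtype(const ToyTensor& /*x*/, const ToyTensor& out_grad) {
  return out_grad.dtype;
}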
......@@ -21,6 +21,7 @@
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_map>
......@@ -152,7 +153,7 @@ class CTRReader : public framework::FileReader {
queue_->ReOpen();
VLOG(3) << "reopen success";
VLOG(3) << "thread_num " << thread_num_;
for (int thread_id = 0; thread_id < thread_num_; thread_id++) {
for (size_t thread_id = 0; thread_id < thread_num_; thread_id++) {
read_threads_.emplace_back(new std::thread(std::bind(
&ReadThread, file_groups_[thread_id], data_desc_,
static_cast<int>(thread_id), &read_thread_status_, queue_)));
......
......@@ -24,6 +24,7 @@ limitations under the License. */
#include "paddle/fluid/framework/executor.h"
#include "paddle/fluid/framework/feed_fetch_method.h"
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/framework/garbage_collector.h"
#include "paddle/fluid/framework/ir/pass_builder.h"
#include "paddle/fluid/framework/lod_rank_table.h"
#include "paddle/fluid/framework/lod_tensor.h"
......@@ -133,6 +134,8 @@ PYBIND11_MODULE(core, m) {
paddle::platform::CpuTotalPhysicalMemory();
paddle::memory::allocation::UseAllocatorStrategyGFlag();
paddle::framework::UseGarbageCollectorGFlags();
m.doc() = "C++ core of PaddlePaddle";
// using framework in this function. Since it is inside a function, it will
......
......@@ -15,6 +15,9 @@
import os
import numpy as np
os.environ['FLAGS_eager_delete_tensor_gb'] = '0.0'
os.environ['FLAGS_fast_eager_deletion_mode'] = '1'
os.environ['FLAGS_use_ngraph'] = '0'
os.environ['FLAGS_use_mkldnn'] = '0'
os.environ['CPU_NUM'] = '4'
import paddle.fluid as fluid
......@@ -58,18 +61,24 @@ def get_persistables_and_non_persistables(prog, fetch_list):
class TestExecutor(unittest.TestCase):
def setUp(self):
self.place = fluid.CPUPlace()
def test_executor_main(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
self.executor_main()
def test_parallel_executor_main(self):
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
self.pe_main()
places = [fluid.CPUPlace()]
if fluid.core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.place = p
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
with fluid.unique_name.guard():
self.executor_main()
for p in places:
self.place = p
with fluid.program_guard(fluid.Program(), fluid.Program()):
with fluid.scope_guard(fluid.Scope()):
with fluid.unique_name.guard():
self.pe_main()
def prepare_feed(self, image, label, dev_cnt=1):
batch_size = 32 * dev_cnt
......@@ -83,25 +92,36 @@ class TestExecutor(unittest.TestCase):
return image_np, label_np
def assertScopeVar(self, scope, persistables, non_persistables):
outline_p_vars = []
for name in persistables:
var = scope.find_var(name)
self.assertTrue(var is not None)
t = var.get_tensor()
self.assertTrue(t._is_initialized())
if not t._is_initialized():
outline_p_vars.append(name)
outline_np_vars = []
for name in non_persistables:
var = scope.find_var(name)
self.assertTrue(var is not None)
t = var.get_tensor()
if t._is_initialized():
print('WARNING: Variable {} is alive'.format(name))
self.assertTrue(not t._is_initialized())
outline_np_vars.append(name)
print('Non-alive persistable vars {} in {}'.format(outline_p_vars,
persistables))
print('Alive non-persistable vars {} in {}'.format(outline_np_vars,
non_persistables))
self.assertEqual(len(outline_p_vars), 0)
self.assertEqual(len(outline_np_vars), 0)
def executor_main(self):
image, label, loss = simple_fc_net()
loss.persistable = False
persistables, non_persistables = get_persistables_and_non_persistables(
fluid.default_main_program(), [loss.name])
print('Non-persistable var number {}'.format(len(non_persistables)))
print(non_persistables)
exe = fluid.Executor(self.place)
exe.run(fluid.default_startup_program())
......@@ -135,6 +155,10 @@ class TestExecutor(unittest.TestCase):
exec_strategy = fluid.ExecutionStrategy()
exec_strategy.num_iteration_per_drop_scope = 100
build_strategy = fluid.BuildStrategy()
build_strategy.memory_optimize = False
build_strategy.enable_inplace = False
prog = fluid.CompiledProgram(fluid.default_main_program(
)).with_data_parallel(
loss_name=loss.name, exec_strategy=exec_strategy)
......@@ -155,11 +179,5 @@ class TestExecutor(unittest.TestCase):
self.assertScopeVar(kids[0], persistables, non_persistables)
class TestExecutor2(TestExecutor):
def setUp(self):
self.place = fluid.CPUPlace() if not fluid.core.is_compiled_with_cuda() \
else fluid.CUDAPlace(0)
if __name__ == '__main__':
unittest.main()