未验证 提交 75936d83 编写于 作者: J JZ-LIANG 提交者: GitHub

Recompute Offload (#30233)

上级 2e808577
......@@ -22,7 +22,11 @@ enum Mode {
HETER = 4; // support XPU and GPU computing server
}
message RecomputeConfig { repeated string checkpoints = 1; }
message RecomputeConfig {
repeated string checkpoints = 1;
optional bool enable_offload = 2 [ default = false ];
repeated int32 checkpoint_shape = 3;
}
message ShardingConfig {
optional float fuse_broadcast_MB = 1 [ default = 32.0 ];
......
......@@ -394,5 +394,5 @@ REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass)
.EQ("square", 0)
.LE("elementwise_mul", 1)
.LE("elementwise_sub", 1)
.EQ("fill_constant", 1)
.LE("fill_constant", 2)
.EQ("fusion_squared_mat_sub", 0));
......@@ -116,6 +116,15 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker {
"memory. Otherwise, fill output variable to the running "
"device")
.SetDefault(false);
AddAttr<int>("place_type",
"(int, default -1) allow mamually setting place where the "
"variable should be hold. "
"-1: not set manually, determine the place by executor. "
"0: CPUPlace. "
"1: CUDAPlace. "
"2: CUDAPinnedPlace. "
"3: XPUPlace. ")
.SetDefault(-1);
AddOutput("Out",
"(Tensor) Tensor of specified shape will be filled "
"with the specified value");
......@@ -154,4 +163,11 @@ REGISTER_OP_VERSION(fill_constant)
)ROC",
paddle::framework::compatible::OpVersionDesc().NewInput(
"ValueTensor",
"In order to support new feature tensor support of Value"));
"In order to support new feature tensor support of Value"))
.AddCheckpoint(
R"ROC(
Upgrade fill_constant to add a new attribute [place_type].
)ROC",
paddle::framework::compatible::OpVersionDesc().NewAttr(
"place_type",
"In order to support tensor in CUDAPinnedPlace and XPUPlace", -1));
......@@ -39,6 +39,7 @@ class FillConstantKernel : public framework::OpKernel<T> {
auto str_value = ctx.Attr<std::string>("str_value");
auto float_value = ctx.Attr<float>("value");
auto force_cpu = ctx.Attr<bool>("force_cpu");
auto place_type = ctx.Attr<int>("place_type");
framework::Tensor *tensor = nullptr;
framework::Variable *out_var = ctx.OutputVar("Out");
......@@ -101,29 +102,59 @@ class FillConstantKernel : public framework::OpKernel<T> {
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
int actual_place = place_type;
if (actual_place == -1) {
bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace();
if (cpu_place) {
actual_place = 0;
} else if (platform::is_gpu_place(ctx.GetPlace())) {
actual_place = 1;
} else if (platform::is_xpu_place(ctx.GetPlace())) {
actual_place = 3;
}
}
if (actual_place == 0) {
tensor->mutable_data(platform::CPUPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
} else if (actual_place == 1) {
#ifdef PADDLE_WITH_CUDA
if (!cpu_place) {
tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::CUDADeviceContext, T> functor;
functor(reinterpret_cast<const platform::CUDADeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
} else if (actual_place == 2) {
#ifdef PADDLE_WITH_CUDA
tensor->mutable_data(platform::CUDAPinnedPlace(), data_type);
math::SetConstant<platform::CPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::CPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with GPU."));
#endif
} else if (actual_place == 3) {
#ifdef PADDLE_WITH_XPU
if (!cpu_place) {
tensor->mutable_data(ctx.GetPlace(), data_type);
math::SetConstant<platform::XPUDeviceContext, T> functor;
functor(reinterpret_cast<const platform::XPUDeviceContext &>(dev_ctx),
tensor, static_cast<T>(value));
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"PaddlePaddle should compile with XPU."));
#endif
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Could NOT determine the place of variable, place_type = %d .",
actual_place));
}
}
};
} // namespace operators
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/memcpy_op.h"
#include <string>
namespace paddle {
namespace framework {
class OpDesc;
class Variable;
} // namespace framework
namespace imperative {
class OpBase;
} // namespace imperative
namespace platform {
struct CPUPlace;
struct CUDAPlace;
struct float16;
} // namespace platform
} // namespace paddle
namespace paddle {
namespace operators {
class MemcpyOp : public framework::OperatorWithKernel {
public:
MemcpyOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorWithKernel(type, inputs, outputs, attrs) {}
void InferShape(framework::InferShapeContext *ctx) const override {
auto type = ctx->GetInputsVarType("X")[0];
if (type == framework::proto::VarType::SELECTED_ROWS ||
type == framework::proto::VarType::LOD_TENSOR) {
ctx->SetOutputDim("Out", ctx->GetInputDim("X"));
if (type == framework::proto::VarType::LOD_TENSOR) {
ctx->ShareLoD("X", /*->*/ "Out");
}
}
}
protected:
framework::OpKernelType GetKernelTypeForVar(
const std::string &var_name, const framework::Tensor &tensor,
const framework::OpKernelType &expected_kernel_type) const override {
return framework::OpKernelType(expected_kernel_type.data_type_,
expected_kernel_type.place_,
tensor.layout());
}
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
return framework::OpKernelType(
OperatorWithKernel::IndicateVarDataType(ctx, "X"),
ctx.device_context());
}
};
class MemcpyInferVarType : public framework::VarTypeInference {
public:
void operator()(framework::InferVarTypeContext *ctx) const override {
ctx->SyncTypeAndDataType("X", "Out");
}
};
class MemcpyKernel {
public:
void operator()(const framework::ExecutionContext &ctx) const {
auto *x = ctx.InputVar("X");
if (x == nullptr) {
return;
}
PADDLE_ENFORCE_EQ(
ctx.HasOutput("Out"), true,
platform::errors::NotFound("Output(Out) of memcpy_op is not found."));
auto *out = ctx.OutputVar("Out");
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(ctx.GetPlace());
auto dst_place_type = ctx.Attr<int>("dst_place_type");
framework::VisitVarType(*x, MemcpyFunctor(out, dev_ctx, dst_place_type));
}
};
class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput("X", "(LoDTensor) The input variable ");
AddOutput("Out",
"(LoDTensor) The type of output "
"is the same as input X.");
AddAttr<int>("dst_place_type",
"Determine the dst place of tensor copy. "
"By Now it ONLY support CUDAPlace and CUDAPinnedPlace. Other "
"place type is Unimplemented and will cause ERROR."
"0: dst is on CPUPlace. "
"1: dst is on CUDAPlace. "
"2: dst is on CUDAPinnedPlace. "
"3: dst is on XPUPlace. ");
AddComment(R"DOC(
Memcpy Operator.
By now, it ONLY supports the memcopy between CUDAPinnedPlace and CUDAPlace,
and used as an internal op by Recompute-Offload.
You would have to update it if you want other more capacities.
Out = X, when type in [LoDTensor]
raise error if the type is not listed above.
)DOC");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OPERATOR(
memcpy, ops::MemcpyOp, ops::MemcpyOpProtoMaker, ops::MemcpyInferVarType,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
ops::MemcpyKernel, int, ops::MemcpyKernel,
int64_t, ops::MemcpyKernel, bool,
ops::MemcpyKernel, plat::float16,
ops::MemcpyKernel);
#ifdef PADDLE_WITH_CUDA
REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double,
ops::MemcpyKernel, int, ops::MemcpyKernel,
int64_t, ops::MemcpyKernel, bool,
ops::MemcpyKernel, plat::float16,
ops::MemcpyKernel);
#endif
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/platform/device_context.h"
namespace paddle {
namespace framework {
class LoDTensor;
class Variable;
} // namespace framework
} // namespace paddle
namespace paddle {
namespace operators {
class MemcpyFunctor {
public:
MemcpyFunctor(framework::Variable *out,
const platform::DeviceContext &dev_ctx,
const int dst_place_type)
: out_(out), dev_ctx_(dev_ctx), dst_place_type_(dst_place_type) {}
void operator()(const framework::LoDTensor &lod_tensor) const {
auto &out_tensor = *out_->GetMutable<framework::LoDTensor>();
if (dst_place_type_ == 3) {
framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_,
&out_tensor);
} else if (dst_place_type_ == 2) {
framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_,
&out_tensor);
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"memcpy dst_place_type: %d is not supported yet.", dst_place_type_));
}
out_tensor.set_lod(lod_tensor.lod());
}
void operator()(const framework::SelectedRows &rows) const {
// (JZ-LIANG) to support SelectedRows
PADDLE_THROW(platform::errors::Unimplemented(
"Memcpy for SelectedRows is NOT support yet."));
}
template <typename T>
void operator()(const T &v) const {
PADDLE_ENFORCE_EQ(
true, false,
platform::errors::PermissionDenied(
"Not support type for Memcpy op with type %s", typeid(T).name()));
}
private:
framework::Variable *out_;
const platform::DeviceContext &dev_ctx_;
const int dst_place_type_;
};
} // namespace operators
} // namespace paddle
......@@ -632,8 +632,20 @@ class DistributedStrategy(object):
@property
def recompute_configs(self):
"""
Set recompute configurations. In general, the recompute strategy of current
implementation should have some manually assign checkpoints
Set recompute configurations.
**Note**:
checkpoints(list): list of string name of checkpoints. In general, the recompute
strategy of current implementation should have some manually assign checkpoints.
enable_offload(bool): enable recompute checkpoints offload feature. this feature
will offload the checkpoint to host memory to allow even larger batch size. since
the memcpy from host to device takes time, it is a trade off between larger batch
size and training speed.
checkpoint_shape(list): list of int that specific the shape of checkpoint. so far
recompute-offload requires that all checkpoint to be same shape, and every dimension
specific here should be determined ("-1" is not allowed).
Examples:
......@@ -642,7 +654,10 @@ class DistributedStrategy(object):
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.recompute = True
strategy.recompute_configs = {"checkpoints": ["x", "y"]}
strategy.recompute_configs = {
"checkpoints": ["x", "y"],
"enable_offload": True,
"checkpoint_shape": [100, 512, 1024] }
"""
return get_msg_dict(self.strategy.recompute_configs)
......@@ -692,6 +707,14 @@ class DistributedStrategy(object):
This configuration will affect the communication speed in sharding training,
and should be an empirical value decided by your model size and network topology.
hybrid_dp(bool): enable hybrid data parallelism above the sharding parallelism.
you are supposed to have at least double the number of gpu you have in normal sharding
training to enable this feature.
sharding_group_size(int): attribute of hybrid_dp. specific the the number of gpus within
each sharding group; and therefore, the number of hybrid data parallelism ways will be equal
to (global_size / sharding_group_size).
Examples:
.. code-block:: python
......@@ -699,7 +722,10 @@ class DistributedStrategy(object):
import paddle.distributed.fleet as fleet
strategy = fleet.DistributedStrategy()
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 32}
strategy.sharding_configs = {
"fuse_broadcast_MB": 32,
"hybrid_dp": True,
"sharding_group_size": 8}
"""
return get_msg_dict(self.strategy.sharding_configs)
......
......@@ -39,9 +39,13 @@ class RecomputeOptimizer(MetaOptimizerBase):
return
configs = self.user_defined_strategy.recompute_configs
self.wrapped_opt = RO(self.inner_opt)
self.wrapped_opt._set_checkpoints(list(configs["checkpoints"]))
if configs["enable_offload"]:
self.wrapped_opt._enable_offload()
# TODO(JZ-LIANG) might found a way to infer the checkpoint shape automatically
checkpoint_shapes = list(configs["checkpoint_shape"])
self.wrapped_opt.checkpoint_shape = checkpoint_shapes
def _can_apply(self):
if not self.role_maker._is_collective:
......
......@@ -99,8 +99,32 @@ class ProgramStats(object):
max_op_idx = max(max_op_idx, idx)
if min_op_idx >= max_op_idx:
return False, min_op_idx, max_op_idx
return True, min_op_idx, max_op_idx
def _update_segment_start(self, min_idx, pre_segment_end_idx):
"""
persist vars of amp-related cast should be included in recompute segment
"""
def is_amp_cast(op):
return op.desc.type() == 'cast' and self.block.var(
op.desc.input_arg_names()[0]).persistable
idx_ = min_idx - 1
updated_min_idx = min_idx
while idx_ > pre_segment_end_idx:
if is_amp_cast(self.ops[idx_]):
_logger.debug("found amp-cast op: {}, : {}".format(self.ops[
idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[
0]))
updated_min_idx = idx_
idx_ -= 1
else:
break
return updated_min_idx
def build_stats(self):
for i, op in enumerate(self.ops):
self.op_deps[i] = {"in_ops": [], "out_ops": []}
......@@ -751,20 +775,29 @@ def _append_backward_ops_with_checkpoints_(
if name not in program_stat.var_op_deps:
break
op_idx = program_stat.var_op_deps[name]["var_as_output_ops"]
# only count the last generate op
for idx in op_idx:
max_op_idx = max(max_op_idx, idx)
if max_op_idx > 0:
segments.append([0, max_op_idx + 1])
else:
start_idx = 0
pre_segment_end_idx = -1
while True:
_logger.debug("FW op range[0] - [{}]".format(len(ops)))
if start_idx >= len(checkpoints_name) - 1:
break
# min_idx: checkpoint_1' s input op
# max_idx: checkpoint_2' s output op
flag, min_idx, max_idx = program_stat.is_subgraph(
[checkpoints_name[start_idx]],
[checkpoints_name[start_idx + 1]])
if flag:
# max_idx + 1 since the exact and used segment end idx is max_idx
min_idx = program_stat._update_segment_start(
min_idx, pre_segment_end_idx)
segments.append([min_idx, max_idx + 1])
start_idx += 1
if segments != [] and segments[0][0] != 0:
......@@ -772,12 +805,31 @@ def _append_backward_ops_with_checkpoints_(
else:
recompute_segments = segments
for i, (idx1, idx2) in enumerate(recompute_segments):
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
_logger.debug("recompute segment[{}]".format(i))
_logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type(
), ops[idx1].desc.input_arg_names()))
_logger.debug("segment end op: [{}]: [{}]".format(ops[
idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names()))
# 2) go through all forward ops and induct all variables that will be hold in memory
vars_should_be_hold = []
# a. variables that are used across segments will be held in memory
for segment in recompute_segments:
vars_should_be_hold.extend(
program_stat.get_out_of_subgraph_vars(segment[0], segment[1]))
cross_vars = set(vars_should_be_hold) - set(checkpoints_name)
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
_logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \
len(cross_vars), cross_vars))
# b. output of seed op should be kept in memory
vars_should_be_hold.extend(program_stat.get_reserved_vars())
# c. input variables are checkpoints
......@@ -792,8 +844,6 @@ def _append_backward_ops_with_checkpoints_(
max_calculated_op_position = len(ops)
if recompute_segments == []:
# if there is no recompute segment, add backward ops like
# _append_backward_ops_ function
gap_ops = ops[0:max_calculated_op_position]
for op in reversed(gap_ops):
if op.has_attr("sub_block"):
......@@ -807,7 +857,6 @@ def _append_backward_ops_with_checkpoints_(
grad_to_var.update(op_grad_to_var)
for i, segment in enumerate(recompute_segments[::-1]):
# add grad op for ops not in any segments
gap_ops = ops[segment[1]:max_calculated_op_position]
max_calculated_op_position = segment[0]
for op in reversed(gap_ops):
......@@ -851,7 +900,7 @@ def _append_backward_ops_with_checkpoints_(
# added_descs should be in grad_op_descs because it is backward op desc
grad_op_descs.extend(buffer_descs)
# 3.c. add backward ops of current recomputation ops
# 3.c. add backward ops for all ops in current segment
for op_desc in reversed(added_descs):
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
op_desc, cpt.to_text(no_grad_dict[block.idx]), [])
......@@ -1480,9 +1529,11 @@ def append_backward(loss,
# TODO: support _append_backward_ops_with_checkpoints_ in
# sub-block (control flow)
is_recompute = False
if checkpoints != None and \
isinstance(checkpoints, list) and \
len(checkpoints) > 0:
is_recompute = True
program_stat, checkpoint_names, \
vars_should_be_hold, \
recompute_segments = \
......@@ -1577,6 +1628,9 @@ def append_backward(loss,
attr_val.extend(g.op.attr(op_role_var_attr_name))
g.op._set_attr(op_role_var_attr_name, attr_val)
if is_recompute:
return params_and_grads, checkpoint_names
else:
return params_and_grads
......
......@@ -4600,6 +4600,7 @@ class RecomputeOptimizer(Optimizer):
self._checkpoints = None
self._learning_rate = self._optimizer._learning_rate
self._learning_rate_map = self._optimizer._learning_rate_map
self.enable_offload = False
def _set_checkpoints(self, checkpoints):
"""
......@@ -4615,6 +4616,10 @@ class RecomputeOptimizer(Optimizer):
), "_checkpoints should be a list of Variable or a list of String"
self._checkpoints = checkpoints
# should enable offload before calling backward
def _enable_offload(self):
self.enable_offload = True
@framework.deprecate_stat_dict
def load(self, state_dict):
"""
......@@ -4703,6 +4708,358 @@ class RecomputeOptimizer(Optimizer):
return self._optimizer.apply_gradients(params_grads=params_grads)
def _creat_vars(self, varname):
pinned_var_name = unique_name.generate(varname + "@Pinned")
fetched_var_name = unique_name.generate(varname + "@Fetch")
pinned_var = self._main_program.global_block().create_var(
name=pinned_var_name,
shape=self.checkpoint_shape,
dtype=self._main_program.global_block().var(varname).dtype,
persistable=False,
stop_gradient=True)
fetch_var = self._main_program.global_block().create_var(
name=fetched_var_name,
shape=self.checkpoint_shape,
dtype=self._main_program.global_block().var(varname).dtype,
persistable=False,
stop_gradient=False)
return pinned_var_name, fetched_var_name
def _append_fill_constant_ops(self, startup_program):
"""
add fill_constant_ops to the end of the prog
we should fill the pinned vars before runing the main_prog
to instantiate their tensor hold_, which could tell us whether
the host memory could hold all the checkpoints from all the
GPU devices in this node.
"""
op_role = 0
block = startup_program.global_block()
fill_constant_vars = self.checkpoint_name2pinned_name.values()
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
for varname in fill_constant_vars:
var = self._main_program.global_block().var(varname)
# NOTE (JZ-LIANG) to pre-allocate the CUDAPinned MEM
pinned_var = block.create_var(
name=varname,
shape=self.checkpoint_shape,
dtype=self._main_program.global_block().var(var.name).dtype,
persistable=False,
stop_gradient=True)
block.append_op(
type='fill_constant',
outputs={'Out': varname},
attrs={
"shape": var.shape,
"dtype": var.dtype,
"value": 0.0,
"place_type": 2,
OP_ROLE_KEY: op_role,
})
return
def _insert_async_memcpy_op(self, insert_idx, src_varname, dst_varname,
op_role, kind):
OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName()
self.block._insert_op_without_sync(
insert_idx,
type='memcpy',
inputs={'X': [self._main_program.global_block().var(src_varname)]},
outputs={
'Out': [self._main_program.global_block().var(dst_varname)]
},
attrs={"dst_place_type": int(kind),
OP_ROLE_KEY: op_role})
def _insert_fetch_op(self, idx, varname):
assert varname in self.checkpoint_name2pinned_name, "Try to fetch {} from Pinned Memory, but it is NOT a checkpoint".format(
varname)
pinned_varname = self.checkpoint_name2pinned_name[varname]
fetch_varname = self.checkpoint_name2fetch_name[varname]
self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 2)
def _insert_offload_op(self, idx, varname):
assert varname in self.checkpoint_name2pinned_name, "Try to offload {} to Pinned Memory, but it is NOT a checkpoint".format(
varname)
pinned_varname = self.checkpoint_name2pinned_name[varname]
self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 3)
def _insert_sync_op(self, op_idx, checkpoint_name):
# single stream offload no need sync
pass
def _record_fetch_op(self, idx):
assert len(self.un_fetch_checkpoint_names
) > 0, "Could NOT found checkpoint to fetch"
checkpoint_name = self.un_fetch_checkpoint_names.pop(-1)
logging.debug("Record fetch [{}]".format(checkpoint_name))
self.idx2insertions[idx] = ("fetch", checkpoint_name)
return checkpoint_name
def _record_offload_op(self, idx, checkpoint_name):
expected_checkpoint_name = self.un_offload_checkpoint_names.pop(0)
assert checkpoint_name == expected_checkpoint_name, "expected to offload [{}] but got [{}]".format(
expected_checkpoint_name, checkpoint_name)
logging.debug("Record offload [{}]".format(checkpoint_name))
self.idx2insertions[idx] = ("offload", checkpoint_name)
def _record_sync_op(self, idx, checkpoint_name):
assert checkpoint_name not in self.synced_checkpoints, "Try to sync the checkpoint [{}] twice".format(
checkpoint_name)
self.synced_checkpoints.add(checkpoint_name)
logging.debug("Record offload sync [{}]".format(checkpoint_name))
self.idx2insertions[idx] = ("sync", checkpoint_name)
def _parse_backward(self):
self.idx2insertions = {}
# don't offload the last checkpoints, to favor throughput
self.un_fetch_checkpoint_names = self.sorted_checkpoint_names[:]
self.un_fetch_checkpoint_names.pop(-1)
need_fetch_checkpoint_names = self.un_fetch_checkpoint_names[:]
self.checkpoint_usage_count = {}
for checkpoint_name in self.un_fetch_checkpoint_names:
self.checkpoint_usage_count[checkpoint_name] = 0
self.bw_strart_op_idx = len(self.block.ops)
for idx, op in enumerate(self.block.ops):
if int(op.desc.attr("op_role")) == 1:
self.bw_strart_op_idx = idx
break
assert self.bw_strart_op_idx < len(
self.block.ops), "Could NOT found backword op in prog"
# fetch second to last checkpoint at the beginning of BW
fetched_checkpoint_varname = self._record_fetch_op(
self.bw_strart_op_idx)
last_last_fetch_checkpoint = None
for i, op in enumerate(self.block.ops[self.bw_strart_op_idx:]):
idx = self.bw_strart_op_idx + i
input_vars = op.desc.input_arg_names()
for input_var in input_vars:
if input_var in need_fetch_checkpoint_names:
if input_var not in self.un_fetch_checkpoint_names:
# fetch the offloade checkpoint when the first usage of its previous one
if self.checkpoint_usage_count[input_var] == 0:
# TODO (JZ-LIANG) sync memcpy_stream if extra stream for memcpy
second_to_last_fetch_checkpoint = fetched_checkpoint_varname
# there is NO fetch ahead the first checkpoint
if input_var != self.sorted_checkpoint_names[0]:
fetched_checkpoint_varname = self._record_fetch_op(
idx)
# should check the current used checkpoint is ths last fetch one
assert second_to_last_fetch_checkpoint == input_var, "Current recompute segment should use [{}] BUT got [{}]".format(
second_to_last_fetch_checkpoint, input_var)
# rename
self.block.ops[idx]._rename_input(
input_var,
self.checkpoint_name2fetch_name[input_var])
self.checkpoint_usage_count[input_var] += 1
else:
raise ValueError(
"use checkpoint [{}] before fetch in BW".format(
input_var))
assert len(self.un_fetch_checkpoint_names
) == 0, "{} checkpoints have NOT been Recorded".format(
self.un_fetch_checkpoint_names)
def _update_backward(self):
if len(self.idx2insertions) == 0:
return
total_op = len(self.block.ops)
for op_idx in reversed(range(self.bw_strart_op_idx, total_op)):
if op_idx in self.idx2insertions:
operation, checkpoint_name = self.idx2insertions[op_idx]
if operation == "fetch":
self._insert_fetch_op(op_idx, checkpoint_name)
logging.debug("Insert [{}] fetch op.".format(
checkpoint_name))
del self.idx2insertions[op_idx]
elif operation == "sync":
self._insert_sync_op(op_idx, checkpoint_name)
logging.debug("Sync [{}] fetch op.".format(checkpoint_name))
self.block._sync_with_cpp()
assert len(
self.idx2insertions) == 0, "{} checkpoints left un-Fecthed".format(
[ele[1] for ele in self.idx2insertions.values()])
def _parse_forward(self):
self.idx2insertions = {}
# don't offload the last checkpoints, faster, less memory saving
self.un_offload_checkpoint_names = self.sorted_checkpoint_names[:]
last_checkpoint = self.un_offload_checkpoint_names.pop(-1)
need_offload_checkpoint_names = self.un_offload_checkpoint_names[:]
self.checkpoint_usage_count_and_idx = {}
for checkpoint_name in self.un_offload_checkpoint_names:
self.checkpoint_usage_count_and_idx[checkpoint_name] = {
'count': 0,
'idx': -1
}
self.synced_checkpoints = set()
self.fw_strart_op_idx = len(self.block.ops)
for idx, op in enumerate(self.block.ops):
if int(op.desc.attr("op_role")) == 0:
self.fw_strart_op_idx = idx
break
assert self.fw_strart_op_idx < len(
self.block.ops), "Could NOT found Forward op in prog"
last_offload_checkpoint = None
for i, op in enumerate(self.block.ops[self.fw_strart_op_idx:
self.bw_strart_op_idx]):
idx = self.fw_strart_op_idx + i
output_vars = op.desc.output_arg_names()
input_vars = op.desc.input_arg_names()
for output_var in output_vars:
if output_var in need_offload_checkpoint_names:
assert len(
output_vars
) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
output_var, op)
if output_var in self.un_offload_checkpoint_names:
# insert sync op if last checkpoint has not been sync
if last_offload_checkpoint != None:
if self.checkpoint_usage_count_and_idx[
last_offload_checkpoint]['count'] == 0:
self._record_sync_op(idx,
last_offload_checkpoint)
else:
last_usage_idx = self.checkpoint_usage_count_and_idx[
last_offload_checkpoint]['idx']
assert last_usage_idx > 0, "last_usage_idx of checkpoint [{}] should large than 0".format(
last_offload_checkpoint)
self._record_sync_op(last_usage_idx + 1,
last_offload_checkpoint)
# insert offload op after the checkpoint's generation op
self._record_offload_op(idx + 1, output_var)
last_offload_checkpoint = output_var
else:
raise ValueError(
"There should be just ONE op that output checkpoint [{}]".
format(output_var))
# need to sync the last need to offload checkpoint before the last checkpoint as output op
if output_var == last_checkpoint:
assert len(
output_vars
) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format(
output_var, op)
assert last_offload_checkpoint == self.sorted_checkpoint_names[
-2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format(
last_checkpoint, self.sorted_checkpoint_names[-2],
last_offload_checkpoint)
# sync if last checkpoint has not been sync
if self.checkpoint_usage_count_and_idx[
last_offload_checkpoint]['idx'] == 0:
self._record_sync_op(idx, last_offload_checkpoint)
else:
last_usage_idx = self.checkpoint_usage_count_and_idx[
last_offload_checkpoint]['idx']
assert last_usage_idx > 0, "last_usage_idx of checkpoint [{}] should large than 0".format(
last_offload_checkpoint)
self._record_sync_op(last_usage_idx + 1,
last_offload_checkpoint)
# record checkpoint usage
for input_var in input_vars:
if input_var in need_offload_checkpoint_names:
assert input_var not in self.synced_checkpoints, "checkpoint [{}] used after sync".format(
input_var)
self.checkpoint_usage_count_and_idx[input_var]['count'] += 1
self.checkpoint_usage_count_and_idx[input_var]['idx'] = idx
assert len(self.un_offload_checkpoint_names
) == 0, "{} checkpoints have NOT been Recorded".format(
self.un_fetch_checkpoint_names)
assert len(self.synced_checkpoints) == len(
need_offload_checkpoint_names
), "{} checkpoints have NOT been Recorded".format(
set(need_offload_checkpoint_names) - set(self.synced_checkpoints))
def _update_forward(self):
if len(self.idx2insertions) == 0:
return
for op_idx in reversed(
range(self.fw_strart_op_idx, self.bw_strart_op_idx)):
if op_idx in self.idx2insertions:
operation, checkpoint_name = self.idx2insertions[op_idx]
if operation == "offload":
self._insert_offload_op(op_idx, checkpoint_name)
logging.debug("Insert [{}] offload op.".format(
checkpoint_name))
del self.idx2insertions[op_idx]
elif operation == "sync":
self._insert_sync_op(op_idx, checkpoint_name)
logging.debug("Insert [{}] offload_sync op.".format(
checkpoint_name))
del self.idx2insertions[op_idx]
self.block._sync_with_cpp()
assert len(self.idx2insertions
) == 0, "{} checkpoints left un-Offloaded".format(
[ele[1] for ele in self.idx2insertions.values()])
def _check_offload_fetch(self):
# TODO(JZ-LIANG) the single stream offload need no sync
pass
def _offload(self, loss, startup_program=None):
"""
core steps for recompute offload
1. create pinned vars and temp vars
2. parse & update Forward pass: offload, sync
3. parse & update Backward pass: rename, fetch, sync
4. verify the correctness
"""
self._main_program = loss.block.program
self.block = loss.block
if startup_program == None:
startup_program = fluid.default_startup_program()
with program_guard(self._main_program, startup_program):
assert len(self.checkpoint_shape) > 0, (
"checkpoints shape {} should be an non empty list like: [12, 512, 1024]".
format(self.checkpoint_shape))
assert all([ele > 0 for ele in self.checkpoint_shape]), (
"all ele in checkpoints shape {} should be a determined integer larger than 0".
format(self.checkpoint_shape))
self.checkpoint_name2pinned_name = dict()
self.checkpoint_name2fetch_name = dict()
for checkpoint_varname in self.sorted_checkpoint_names:
pinned_var_name, fetch_var_name = self._creat_vars(
checkpoint_varname)
self.checkpoint_name2pinned_name[
checkpoint_varname] = pinned_var_name
self.checkpoint_name2fetch_name[
checkpoint_varname] = fetch_var_name
self._append_fill_constant_ops(startup_program)
# TODO (JZ-LIANG) to provide two offload stragtegy in future
# step 2. parse & update FW: rename, offload, sync
self._parse_backward()
self._update_backward()
# step 3. parse & update BW: rename, offload, sync
self._parse_forward()
self._update_forward()
# step 4. verify the correctness
self._check_offload_fetch()
return
def backward(self,
loss,
startup_program=None,
......@@ -4767,8 +5124,24 @@ class RecomputeOptimizer(Optimizer):
else:
checkpoint_vars.append(loss.block.var(ckpt))
# allow return to non-recompute when checkpoints is empty
if len(checkpoint_vars) > 0:
params_grads, sorted_checkpoint_names = append_backward(
loss,
parameter_list,
no_grad_set,
checkpoints=checkpoint_vars)
else:
params_grads = append_backward(
loss, parameter_list, no_grad_set, checkpoints=checkpoint_vars)
loss,
parameter_list,
no_grad_set,
checkpoints=checkpoint_vars)
if self.enable_offload:
self.sorted_checkpoint_names = sorted_checkpoint_names
self._offload(loss, startup_program=startup_program)
return params_grads
def apply_optimize(self, loss, startup_program, params_grads):
......
......@@ -83,6 +83,7 @@ if(NOT WITH_GPU OR WIN32)
LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api)
LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api)
LIST(REMOVE_ITEM TEST_OPS test_memcpy_op)
endif()
if(WIN32)
......
......@@ -132,5 +132,12 @@ class TestFleetMetaOptimizer(unittest.TestCase):
elif name == "sharding":
strategy.sharding = True
strategy.sharding_configs = {"fuse_broadcast_MB": 0.2}
elif name == "recompute-offload":
strategy.recompute = True
strategy.recompute_configs = {
"checkpoints": ["fc_0.tmp_2", "fc_1.tmp_2"],
"enable_offload": True,
"checkpoint_shape": [256]
}
else:
raise NotImplementedError()
......@@ -153,6 +153,20 @@ class TestFleetRecomputeMetaOptimizer(TestFleetMetaOptimizer):
self.assertIn('subprog', ''.join(outs))
self.assertIn('lamb', ops)
def test_recompute_offload(self):
train_prog, startup_prog = fluid.Program(), fluid.Program()
avg_cost, strategy = self.net(train_prog, startup_prog)
self.set_strategy(strategy, 'recompute-offload')
self.optimizer(avg_cost, strategy, train_prog, startup_prog)
ops = [op.type for op in avg_cost.block.ops]
outs = [
op.output('Out')[0] for op in avg_cost.block.ops
if op.type == 'memcpy'
]
self.assertIn('memcpy', ops)
self.assertIn('@Pinned', ''.join(outs))
self.assertIn('@Fetch', ''.join(outs))
if __name__ == "__main__":
unittest.main()
......@@ -170,19 +170,19 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer):
self.assertEqual(ops, [
'cast', 'cast', 'cast', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant',
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast',
'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', 'cast', 'cast',
'mul', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul',
'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add',
'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul',
'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad',
'cross_entropy_grad2', 'cast', 'softmax_grad',
'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast',
'elementwise_add', 'cast', 'tanh_grad', 'cast',
'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream',
'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh',
'cast', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast',
'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2',
'mean', 'elementwise_mul', 'fill_constant', 'scale',
'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast',
'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast',
'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh_grad',
'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul',
'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast',
'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum',
......
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import op_test
import numpy as np
import unittest
import paddle
import paddle.fluid.core as core
from paddle.fluid.op import Operator
import paddle.fluid as fluid
from paddle.fluid import compiler, Program, program_guard
from paddle.fluid.backward import append_backward
class TestMemcpy_FillConstant(unittest.TestCase):
def get_prog(self):
paddle.enable_static()
main_program = Program()
with program_guard(main_program):
pinned_var_name = "tensor@Pinned"
gpu_var_name = "tensor@GPU"
pinned_var = main_program.global_block().create_var(
name=pinned_var_name,
shape=[10, 10],
dtype='float32',
persistable=False,
stop_gradient=True)
gpu_var = main_program.global_block().create_var(
name=gpu_var_name,
shape=[10, 10],
dtype='float32',
persistable=False,
stop_gradient=True)
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": gpu_var_name},
attrs={
"shape": [10, 10],
"dtype": gpu_var.dtype,
"value": 1.0,
"place_type": 1
})
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": pinned_var_name},
attrs={
"shape": [10, 10],
"dtype": gpu_var.dtype,
"value": 0.0,
"place_type": 2
})
return main_program, gpu_var, pinned_var
def test_gpu_cpoy_to_pinned(self):
main_program, gpu_var, pinned_var = self.get_prog()
main_program.global_block().append_op(
type='memcpy',
inputs={'X': gpu_var},
outputs={'Out': pinned_var},
attrs={'dst_place_type': 3})
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
gpu_, pinned_ = exe.run(main_program,
feed={},
fetch_list=[gpu_var.name, pinned_var.name])
self.assertTrue(np.allclose(gpu_, pinned_))
self.assertTrue(np.allclose(pinned_, np.ones((10, 10))))
def test_pinned_cpoy_gpu(self):
main_program, gpu_var, pinned_var = self.get_prog()
main_program.global_block().append_op(
type='memcpy',
inputs={'X': pinned_var},
outputs={'Out': gpu_var},
attrs={'dst_place_type': 2})
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
gpu_, pinned_ = exe.run(main_program,
feed={},
fetch_list=[gpu_var.name, pinned_var.name])
self.assertTrue(np.allclose(gpu_, pinned_))
self.assertTrue(np.allclose(gpu_, np.zeros((10, 10))))
class TestMemcpyOPError(unittest.TestCase):
def get_prog(self):
paddle.enable_static()
main_program = Program()
with program_guard(main_program):
pinned_var = main_program.global_block().create_var(
name="tensor@Pinned_0",
shape=[10, 10],
dtype='float32',
persistable=False,
stop_gradient=True)
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": "tensor@Pinned_0"},
attrs={
"shape": [10, 10],
"dtype": pinned_var.dtype,
"value": 0.0,
"place_type": 2
})
return main_program, pinned_var
def test_SELECTED_ROWS(self):
main_program, pinned_var = self.get_prog()
selected_row_var = main_program.global_block().create_var( \
name="selected_row_0", dtype="float32", persistable=False, \
type=fluid.core.VarDesc.VarType.SELECTED_ROWS, stop_gradient=True)
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": selected_row_var},
attrs={
"shape": selected_row_var.shape,
"dtype": selected_row_var.dtype,
"value": 1.0,
"place_type": 1
})
main_program.global_block().append_op(
type='memcpy',
inputs={'X': selected_row_var},
outputs={'Out': pinned_var},
attrs={'dst_place_type': 3})
with self.assertRaises(NotImplementedError):
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
selected_row_var_, pinned_ = exe.run(
main_program,
feed={},
fetch_list=[selected_row_var.name, pinned_var.name])
def test_OTHER_PLACE_NotImplementedError(self):
main_program, pinned_var = self.get_prog()
lod_tensor_var = main_program.global_block().create_var( \
name="lod_tensor_0", dtype="float32", persistable=False, stop_gradient=True)
main_program.global_block().append_op(
type="fill_constant",
outputs={"Out": lod_tensor_var},
attrs={
"shape": lod_tensor_var.shape,
"dtype": lod_tensor_var.dtype,
"value": 1.0,
"place_type": 0
})
main_program.global_block().append_op(
type='memcpy',
inputs={'X': pinned_var},
outputs={'Out': lod_tensor_var},
attrs={'dst_place_type': 0, })
with self.assertRaises(NotImplementedError):
place = fluid.CUDAPlace(0)
exe = fluid.Executor(place)
lod_tensor_var_, pinned_ = exe.run(
main_program,
feed={},
fetch_list=[lod_tensor_var.name, pinned_var.name])
if __name__ == '__main__':
paddle.enable_static()
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册