diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto index 2eaf08153e8ecc9beda4edd28b5dadf11c86cb2b..7cf8d55aeeb1d99acd2f501461f0563f87a25e78 100644 --- a/paddle/fluid/framework/distributed_strategy.proto +++ b/paddle/fluid/framework/distributed_strategy.proto @@ -22,7 +22,11 @@ enum Mode { HETER = 4; // support XPU and GPU computing server } -message RecomputeConfig { repeated string checkpoints = 1; } +message RecomputeConfig { + repeated string checkpoints = 1; + optional bool enable_offload = 2 [ default = false ]; + repeated int32 checkpoint_shape = 3; +} message ShardingConfig { optional float fuse_broadcast_MB = 1 [ default = 32.0 ]; diff --git a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc index c0420e6b5f3c212721b278ce04bf7ece090a5cc5..072fcd891e683b3b74082f2b5fa009cc689ec50e 100644 --- a/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc +++ b/paddle/fluid/framework/ir/squared_mat_sub_fuse_pass.cc @@ -394,5 +394,5 @@ REGISTER_PASS_CAPABILITY(squared_mat_sub_fuse_pass) .EQ("square", 0) .LE("elementwise_mul", 1) .LE("elementwise_sub", 1) - .EQ("fill_constant", 1) + .LE("fill_constant", 2) .EQ("fusion_squared_mat_sub", 0)); diff --git a/paddle/fluid/operators/fill_constant_op.cc b/paddle/fluid/operators/fill_constant_op.cc index aac0337fe307bc6451c012e7575ff2a1ad8df9d7..8a96d057cbe039d1577d4210c6df747d54796267 100644 --- a/paddle/fluid/operators/fill_constant_op.cc +++ b/paddle/fluid/operators/fill_constant_op.cc @@ -116,6 +116,15 @@ class FillConstantOpMaker : public framework::OpProtoAndCheckerMaker { "memory. Otherwise, fill output variable to the running " "device") .SetDefault(false); + AddAttr("place_type", + "(int, default -1) allow mamually setting place where the " + "variable should be hold. " + "-1: not set manually, determine the place by executor. " + "0: CPUPlace. " + "1: CUDAPlace. " + "2: CUDAPinnedPlace. " + "3: XPUPlace. ") + .SetDefault(-1); AddOutput("Out", "(Tensor) Tensor of specified shape will be filled " "with the specified value"); @@ -154,4 +163,11 @@ REGISTER_OP_VERSION(fill_constant) )ROC", paddle::framework::compatible::OpVersionDesc().NewInput( "ValueTensor", - "In order to support new feature tensor support of Value")); + "In order to support new feature tensor support of Value")) + .AddCheckpoint( + R"ROC( + Upgrade fill_constant to add a new attribute [place_type]. + )ROC", + paddle::framework::compatible::OpVersionDesc().NewAttr( + "place_type", + "In order to support tensor in CUDAPinnedPlace and XPUPlace", -1)); diff --git a/paddle/fluid/operators/fill_constant_op.h b/paddle/fluid/operators/fill_constant_op.h index cce28cae975001ff30040b0e81b2d90e82ed12e1..5d1f1fa781df2c1d9a9a9daaffdfa3add7285178 100644 --- a/paddle/fluid/operators/fill_constant_op.h +++ b/paddle/fluid/operators/fill_constant_op.h @@ -39,6 +39,7 @@ class FillConstantKernel : public framework::OpKernel { auto str_value = ctx.Attr("str_value"); auto float_value = ctx.Attr("value"); auto force_cpu = ctx.Attr("force_cpu"); + auto place_type = ctx.Attr("place_type"); framework::Tensor *tensor = nullptr; framework::Variable *out_var = ctx.OutputVar("Out"); @@ -101,29 +102,59 @@ class FillConstantKernel : public framework::OpKernel { platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); auto &dev_ctx = *pool.Get(ctx.GetPlace()); - bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace(); - if (cpu_place) { + int actual_place = place_type; + + if (actual_place == -1) { + bool cpu_place = force_cpu || ctx.GetPlace() == platform::CPUPlace(); + if (cpu_place) { + actual_place = 0; + } else if (platform::is_gpu_place(ctx.GetPlace())) { + actual_place = 1; + } else if (platform::is_xpu_place(ctx.GetPlace())) { + actual_place = 3; + } + } + + if (actual_place == 0) { tensor->mutable_data(platform::CPUPlace(), data_type); math::SetConstant functor; functor(reinterpret_cast(dev_ctx), tensor, static_cast(value)); - } + } else if (actual_place == 1) { #ifdef PADDLE_WITH_CUDA - if (!cpu_place) { tensor->mutable_data(ctx.GetPlace(), data_type); math::SetConstant functor; functor(reinterpret_cast(dev_ctx), tensor, static_cast(value)); - } +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU.")); #endif + } else if (actual_place == 2) { +#ifdef PADDLE_WITH_CUDA + tensor->mutable_data(platform::CUDAPinnedPlace(), data_type); + math::SetConstant functor; + functor(reinterpret_cast(dev_ctx), + tensor, static_cast(value)); +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with GPU.")); +#endif + } else if (actual_place == 3) { #ifdef PADDLE_WITH_XPU - if (!cpu_place) { tensor->mutable_data(ctx.GetPlace(), data_type); math::SetConstant functor; functor(reinterpret_cast(dev_ctx), tensor, static_cast(value)); - } +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "PaddlePaddle should compile with XPU.")); #endif + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "Could NOT determine the place of variable, place_type = %d .", + actual_place)); + } } }; } // namespace operators diff --git a/paddle/fluid/operators/memcpy_op.cc b/paddle/fluid/operators/memcpy_op.cc new file mode 100644 index 0000000000000000000000000000000000000000..5e195d70e92899e1c2f2cffaafebff728b82d4b0 --- /dev/null +++ b/paddle/fluid/operators/memcpy_op.cc @@ -0,0 +1,146 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#include "paddle/fluid/operators/memcpy_op.h" + +#include + +namespace paddle { +namespace framework { +class OpDesc; +class Variable; +} // namespace framework +namespace imperative { +class OpBase; +} // namespace imperative +namespace platform { +struct CPUPlace; +struct CUDAPlace; +struct float16; +} // namespace platform +} // namespace paddle + +namespace paddle { +namespace operators { + +class MemcpyOp : public framework::OperatorWithKernel { + public: + MemcpyOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorWithKernel(type, inputs, outputs, attrs) {} + + void InferShape(framework::InferShapeContext *ctx) const override { + auto type = ctx->GetInputsVarType("X")[0]; + if (type == framework::proto::VarType::SELECTED_ROWS || + type == framework::proto::VarType::LOD_TENSOR) { + ctx->SetOutputDim("Out", ctx->GetInputDim("X")); + if (type == framework::proto::VarType::LOD_TENSOR) { + ctx->ShareLoD("X", /*->*/ "Out"); + } + } + } + + protected: + framework::OpKernelType GetKernelTypeForVar( + const std::string &var_name, const framework::Tensor &tensor, + const framework::OpKernelType &expected_kernel_type) const override { + return framework::OpKernelType(expected_kernel_type.data_type_, + expected_kernel_type.place_, + tensor.layout()); + } + + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + return framework::OpKernelType( + OperatorWithKernel::IndicateVarDataType(ctx, "X"), + ctx.device_context()); + } +}; + +class MemcpyInferVarType : public framework::VarTypeInference { + public: + void operator()(framework::InferVarTypeContext *ctx) const override { + ctx->SyncTypeAndDataType("X", "Out"); + } +}; + +class MemcpyKernel { + public: + void operator()(const framework::ExecutionContext &ctx) const { + auto *x = ctx.InputVar("X"); + if (x == nullptr) { + return; + } + PADDLE_ENFORCE_EQ( + ctx.HasOutput("Out"), true, + platform::errors::NotFound("Output(Out) of memcpy_op is not found.")); + auto *out = ctx.OutputVar("Out"); + platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance(); + auto &dev_ctx = *pool.Get(ctx.GetPlace()); + auto dst_place_type = ctx.Attr("dst_place_type"); + framework::VisitVarType(*x, MemcpyFunctor(out, dev_ctx, dst_place_type)); + } +}; + +class MemcpyOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput("X", "(LoDTensor) The input variable "); + AddOutput("Out", + "(LoDTensor) The type of output " + "is the same as input X."); + AddAttr("dst_place_type", + "Determine the dst place of tensor copy. " + "By Now it ONLY support CUDAPlace and CUDAPinnedPlace. Other " + "place type is Unimplemented and will cause ERROR." + "0: dst is on CPUPlace. " + "1: dst is on CUDAPlace. " + "2: dst is on CUDAPinnedPlace. " + "3: dst is on XPUPlace. "); + AddComment(R"DOC( + Memcpy Operator. + By now, it ONLY supports the memcopy between CUDAPinnedPlace and CUDAPlace, + and used as an internal op by Recompute-Offload. + You would have to update it if you want other more capacities. + +Out = X, when type in [LoDTensor] +raise error if the type is not listed above. +)DOC"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +namespace plat = paddle::platform; +REGISTER_OPERATOR( + memcpy, ops::MemcpyOp, ops::MemcpyOpProtoMaker, ops::MemcpyInferVarType, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); + +REGISTER_OP_CPU_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double, + ops::MemcpyKernel, int, ops::MemcpyKernel, + int64_t, ops::MemcpyKernel, bool, + ops::MemcpyKernel, plat::float16, + ops::MemcpyKernel); + +#ifdef PADDLE_WITH_CUDA +REGISTER_OP_CUDA_KERNEL_FUNCTOR(memcpy, float, ops::MemcpyKernel, double, + ops::MemcpyKernel, int, ops::MemcpyKernel, + int64_t, ops::MemcpyKernel, bool, + ops::MemcpyKernel, plat::float16, + ops::MemcpyKernel); +#endif diff --git a/paddle/fluid/operators/memcpy_op.h b/paddle/fluid/operators/memcpy_op.h new file mode 100755 index 0000000000000000000000000000000000000000..ac190312653b7a6b93acd473ea114741851bf5be --- /dev/null +++ b/paddle/fluid/operators/memcpy_op.h @@ -0,0 +1,75 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#pragma once + +#include "paddle/fluid/framework/data_type.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/platform/device_context.h" + +namespace paddle { +namespace framework { +class LoDTensor; +class Variable; +} // namespace framework +} // namespace paddle + +namespace paddle { +namespace operators { +class MemcpyFunctor { + public: + MemcpyFunctor(framework::Variable *out, + const platform::DeviceContext &dev_ctx, + const int dst_place_type) + : out_(out), dev_ctx_(dev_ctx), dst_place_type_(dst_place_type) {} + + void operator()(const framework::LoDTensor &lod_tensor) const { + auto &out_tensor = *out_->GetMutable(); + + if (dst_place_type_ == 3) { + framework::TensorCopy(lod_tensor, platform::CUDAPinnedPlace(), dev_ctx_, + &out_tensor); + } else if (dst_place_type_ == 2) { + framework::TensorCopy(lod_tensor, dev_ctx_.GetPlace(), dev_ctx_, + &out_tensor); + } else { + PADDLE_THROW(platform::errors::Unimplemented( + "memcpy dst_place_type: %d is not supported yet.", dst_place_type_)); + } + out_tensor.set_lod(lod_tensor.lod()); + } + + void operator()(const framework::SelectedRows &rows) const { + // (JZ-LIANG) to support SelectedRows + PADDLE_THROW(platform::errors::Unimplemented( + "Memcpy for SelectedRows is NOT support yet.")); + } + + template + void operator()(const T &v) const { + PADDLE_ENFORCE_EQ( + true, false, + platform::errors::PermissionDenied( + "Not support type for Memcpy op with type %s", typeid(T).name())); + } + + private: + framework::Variable *out_; + const platform::DeviceContext &dev_ctx_; + const int dst_place_type_; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/distributed/fleet/base/distributed_strategy.py b/python/paddle/distributed/fleet/base/distributed_strategy.py index 658143d0a22b8a15a30806c04e7a6bbce0a9118b..f7a28f15e9b70be3280ce29eb97487a238e78ce6 100755 --- a/python/paddle/distributed/fleet/base/distributed_strategy.py +++ b/python/paddle/distributed/fleet/base/distributed_strategy.py @@ -632,8 +632,20 @@ class DistributedStrategy(object): @property def recompute_configs(self): """ - Set recompute configurations. In general, the recompute strategy of current - implementation should have some manually assign checkpoints + Set recompute configurations. + + **Note**: + checkpoints(list): list of string name of checkpoints. In general, the recompute + strategy of current implementation should have some manually assign checkpoints. + + enable_offload(bool): enable recompute checkpoints offload feature. this feature + will offload the checkpoint to host memory to allow even larger batch size. since + the memcpy from host to device takes time, it is a trade off between larger batch + size and training speed. + + checkpoint_shape(list): list of int that specific the shape of checkpoint. so far + recompute-offload requires that all checkpoint to be same shape, and every dimension + specific here should be determined ("-1" is not allowed). Examples: @@ -642,7 +654,10 @@ class DistributedStrategy(object): import paddle.distributed.fleet as fleet strategy = fleet.DistributedStrategy() strategy.recompute = True - strategy.recompute_configs = {"checkpoints": ["x", "y"]} + strategy.recompute_configs = { + "checkpoints": ["x", "y"], + "enable_offload": True, + "checkpoint_shape": [100, 512, 1024] } """ return get_msg_dict(self.strategy.recompute_configs) @@ -692,6 +707,14 @@ class DistributedStrategy(object): This configuration will affect the communication speed in sharding training, and should be an empirical value decided by your model size and network topology. + hybrid_dp(bool): enable hybrid data parallelism above the sharding parallelism. + you are supposed to have at least double the number of gpu you have in normal sharding + training to enable this feature. + + sharding_group_size(int): attribute of hybrid_dp. specific the the number of gpus within + each sharding group; and therefore, the number of hybrid data parallelism ways will be equal + to (global_size / sharding_group_size). + Examples: .. code-block:: python @@ -699,7 +722,10 @@ class DistributedStrategy(object): import paddle.distributed.fleet as fleet strategy = fleet.DistributedStrategy() strategy.sharding = True - strategy.sharding_configs = {"fuse_broadcast_MB": 32} + strategy.sharding_configs = { + "fuse_broadcast_MB": 32, + "hybrid_dp": True, + "sharding_group_size": 8} """ return get_msg_dict(self.strategy.sharding_configs) diff --git a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py index ea2b67ac4bd1f647718cf454d85e8888141bdf83..3a784c306257b20929ed0bc1e080b104a638b928 100644 --- a/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py +++ b/python/paddle/distributed/fleet/meta_optimizers/recompute_optimizer.py @@ -39,9 +39,13 @@ class RecomputeOptimizer(MetaOptimizerBase): return configs = self.user_defined_strategy.recompute_configs - self.wrapped_opt = RO(self.inner_opt) self.wrapped_opt._set_checkpoints(list(configs["checkpoints"])) + if configs["enable_offload"]: + self.wrapped_opt._enable_offload() + # TODO(JZ-LIANG) might found a way to infer the checkpoint shape automatically + checkpoint_shapes = list(configs["checkpoint_shape"]) + self.wrapped_opt.checkpoint_shape = checkpoint_shapes def _can_apply(self): if not self.role_maker._is_collective: diff --git a/python/paddle/fluid/backward.py b/python/paddle/fluid/backward.py index 742949c59ee8b0a1fe03932f4a4f7e8621700966..33e2e387a82758ba9cd59dc40d41fb5ad05ee29b 100644 --- a/python/paddle/fluid/backward.py +++ b/python/paddle/fluid/backward.py @@ -99,8 +99,32 @@ class ProgramStats(object): max_op_idx = max(max_op_idx, idx) if min_op_idx >= max_op_idx: return False, min_op_idx, max_op_idx + return True, min_op_idx, max_op_idx + def _update_segment_start(self, min_idx, pre_segment_end_idx): + """ + persist vars of amp-related cast should be included in recompute segment + """ + + def is_amp_cast(op): + return op.desc.type() == 'cast' and self.block.var( + op.desc.input_arg_names()[0]).persistable + + idx_ = min_idx - 1 + updated_min_idx = min_idx + while idx_ > pre_segment_end_idx: + if is_amp_cast(self.ops[idx_]): + _logger.debug("found amp-cast op: {}, : {}".format(self.ops[ + idx_].desc.type(), self.ops[idx_].desc.input_arg_names()[ + 0])) + updated_min_idx = idx_ + idx_ -= 1 + else: + break + + return updated_min_idx + def build_stats(self): for i, op in enumerate(self.ops): self.op_deps[i] = {"in_ops": [], "out_ops": []} @@ -751,20 +775,29 @@ def _append_backward_ops_with_checkpoints_( if name not in program_stat.var_op_deps: break op_idx = program_stat.var_op_deps[name]["var_as_output_ops"] + # only count the last generate op for idx in op_idx: max_op_idx = max(max_op_idx, idx) if max_op_idx > 0: segments.append([0, max_op_idx + 1]) else: start_idx = 0 + pre_segment_end_idx = -1 while True: + _logger.debug("FW op range[0] - [{}]".format(len(ops))) if start_idx >= len(checkpoints_name) - 1: break + # min_idx: checkpoint_1' s input op + # max_idx: checkpoint_2' s output op flag, min_idx, max_idx = program_stat.is_subgraph( [checkpoints_name[start_idx]], [checkpoints_name[start_idx + 1]]) if flag: + # max_idx + 1 since the exact and used segment end idx is max_idx + min_idx = program_stat._update_segment_start( + min_idx, pre_segment_end_idx) segments.append([min_idx, max_idx + 1]) + start_idx += 1 if segments != [] and segments[0][0] != 0: @@ -772,12 +805,31 @@ def _append_backward_ops_with_checkpoints_( else: recompute_segments = segments + for i, (idx1, idx2) in enumerate(recompute_segments): + _logger.debug("recompute segment[{}]".format(i)) + _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( + ), ops[idx1].desc.input_arg_names())) + _logger.debug("segment end op: [{}]: [{}]".format(ops[ + idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) + _logger.debug("recompute segment[{}]".format(i)) + _logger.debug("segment start op: [{}]: [{}]".format(ops[idx1].desc.type( + ), ops[idx1].desc.input_arg_names())) + _logger.debug("segment end op: [{}]: [{}]".format(ops[ + idx2 - 1].desc.type(), ops[idx2 - 1].desc.input_arg_names())) + # 2) go through all forward ops and induct all variables that will be hold in memory vars_should_be_hold = [] # a. variables that are used across segments will be held in memory for segment in recompute_segments: vars_should_be_hold.extend( program_stat.get_out_of_subgraph_vars(segment[0], segment[1])) + + cross_vars = set(vars_should_be_hold) - set(checkpoints_name) + _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ + len(cross_vars), cross_vars)) + _logger.debug("found [{}] vars which cross recompute segment: [{}], better checkpoints might be set to reduce those vars".format( \ + len(cross_vars), cross_vars)) + # b. output of seed op should be kept in memory vars_should_be_hold.extend(program_stat.get_reserved_vars()) # c. input variables are checkpoints @@ -792,8 +844,6 @@ def _append_backward_ops_with_checkpoints_( max_calculated_op_position = len(ops) if recompute_segments == []: - # if there is no recompute segment, add backward ops like - # _append_backward_ops_ function gap_ops = ops[0:max_calculated_op_position] for op in reversed(gap_ops): if op.has_attr("sub_block"): @@ -807,7 +857,6 @@ def _append_backward_ops_with_checkpoints_( grad_to_var.update(op_grad_to_var) for i, segment in enumerate(recompute_segments[::-1]): - # add grad op for ops not in any segments gap_ops = ops[segment[1]:max_calculated_op_position] max_calculated_op_position = segment[0] for op in reversed(gap_ops): @@ -851,7 +900,7 @@ def _append_backward_ops_with_checkpoints_( # added_descs should be in grad_op_descs because it is backward op desc grad_op_descs.extend(buffer_descs) - # 3.c. add backward ops of current recomputation ops + # 3.c. add backward ops for all ops in current segment for op_desc in reversed(added_descs): grad_op_desc, op_grad_to_var = core.get_grad_op_desc( op_desc, cpt.to_text(no_grad_dict[block.idx]), []) @@ -1480,9 +1529,11 @@ def append_backward(loss, # TODO: support _append_backward_ops_with_checkpoints_ in # sub-block (control flow) + is_recompute = False if checkpoints != None and \ isinstance(checkpoints, list) and \ len(checkpoints) > 0: + is_recompute = True program_stat, checkpoint_names, \ vars_should_be_hold, \ recompute_segments = \ @@ -1577,7 +1628,10 @@ def append_backward(loss, attr_val.extend(g.op.attr(op_role_var_attr_name)) g.op._set_attr(op_role_var_attr_name, attr_val) - return params_and_grads + if is_recompute: + return params_and_grads, checkpoint_names + else: + return params_and_grads def _as_list(x): diff --git a/python/paddle/fluid/optimizer.py b/python/paddle/fluid/optimizer.py index a7d6ef871749874f2aaf50662ef2ba53158aa767..3c560689e1210fcb312a2311da72c720afb2fe0a 100644 --- a/python/paddle/fluid/optimizer.py +++ b/python/paddle/fluid/optimizer.py @@ -4600,6 +4600,7 @@ class RecomputeOptimizer(Optimizer): self._checkpoints = None self._learning_rate = self._optimizer._learning_rate self._learning_rate_map = self._optimizer._learning_rate_map + self.enable_offload = False def _set_checkpoints(self, checkpoints): """ @@ -4615,6 +4616,10 @@ class RecomputeOptimizer(Optimizer): ), "_checkpoints should be a list of Variable or a list of String" self._checkpoints = checkpoints + # should enable offload before calling backward + def _enable_offload(self): + self.enable_offload = True + @framework.deprecate_stat_dict def load(self, state_dict): """ @@ -4703,6 +4708,358 @@ class RecomputeOptimizer(Optimizer): return self._optimizer.apply_gradients(params_grads=params_grads) + def _creat_vars(self, varname): + pinned_var_name = unique_name.generate(varname + "@Pinned") + fetched_var_name = unique_name.generate(varname + "@Fetch") + + pinned_var = self._main_program.global_block().create_var( + name=pinned_var_name, + shape=self.checkpoint_shape, + dtype=self._main_program.global_block().var(varname).dtype, + persistable=False, + stop_gradient=True) + + fetch_var = self._main_program.global_block().create_var( + name=fetched_var_name, + shape=self.checkpoint_shape, + dtype=self._main_program.global_block().var(varname).dtype, + persistable=False, + stop_gradient=False) + + return pinned_var_name, fetched_var_name + + def _append_fill_constant_ops(self, startup_program): + """ + add fill_constant_ops to the end of the prog + + we should fill the pinned vars before runing the main_prog + to instantiate their tensor hold_, which could tell us whether + the host memory could hold all the checkpoints from all the + GPU devices in this node. + """ + op_role = 0 + block = startup_program.global_block() + fill_constant_vars = self.checkpoint_name2pinned_name.values() + OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() + for varname in fill_constant_vars: + var = self._main_program.global_block().var(varname) + # NOTE (JZ-LIANG) to pre-allocate the CUDAPinned MEM + pinned_var = block.create_var( + name=varname, + shape=self.checkpoint_shape, + dtype=self._main_program.global_block().var(var.name).dtype, + persistable=False, + stop_gradient=True) + block.append_op( + type='fill_constant', + outputs={'Out': varname}, + attrs={ + "shape": var.shape, + "dtype": var.dtype, + "value": 0.0, + "place_type": 2, + OP_ROLE_KEY: op_role, + }) + + return + + def _insert_async_memcpy_op(self, insert_idx, src_varname, dst_varname, + op_role, kind): + OP_ROLE_KEY = core.op_proto_and_checker_maker.kOpRoleAttrName() + self.block._insert_op_without_sync( + insert_idx, + type='memcpy', + inputs={'X': [self._main_program.global_block().var(src_varname)]}, + outputs={ + 'Out': [self._main_program.global_block().var(dst_varname)] + }, + attrs={"dst_place_type": int(kind), + OP_ROLE_KEY: op_role}) + + def _insert_fetch_op(self, idx, varname): + assert varname in self.checkpoint_name2pinned_name, "Try to fetch {} from Pinned Memory, but it is NOT a checkpoint".format( + varname) + + pinned_varname = self.checkpoint_name2pinned_name[varname] + fetch_varname = self.checkpoint_name2fetch_name[varname] + self._insert_async_memcpy_op(idx, pinned_varname, fetch_varname, 1, 2) + + def _insert_offload_op(self, idx, varname): + assert varname in self.checkpoint_name2pinned_name, "Try to offload {} to Pinned Memory, but it is NOT a checkpoint".format( + varname) + pinned_varname = self.checkpoint_name2pinned_name[varname] + self._insert_async_memcpy_op(idx, varname, pinned_varname, 0, 3) + + def _insert_sync_op(self, op_idx, checkpoint_name): + # single stream offload no need sync + pass + + def _record_fetch_op(self, idx): + assert len(self.un_fetch_checkpoint_names + ) > 0, "Could NOT found checkpoint to fetch" + checkpoint_name = self.un_fetch_checkpoint_names.pop(-1) + logging.debug("Record fetch [{}]".format(checkpoint_name)) + self.idx2insertions[idx] = ("fetch", checkpoint_name) + + return checkpoint_name + + def _record_offload_op(self, idx, checkpoint_name): + expected_checkpoint_name = self.un_offload_checkpoint_names.pop(0) + assert checkpoint_name == expected_checkpoint_name, "expected to offload [{}] but got [{}]".format( + expected_checkpoint_name, checkpoint_name) + logging.debug("Record offload [{}]".format(checkpoint_name)) + self.idx2insertions[idx] = ("offload", checkpoint_name) + + def _record_sync_op(self, idx, checkpoint_name): + assert checkpoint_name not in self.synced_checkpoints, "Try to sync the checkpoint [{}] twice".format( + checkpoint_name) + self.synced_checkpoints.add(checkpoint_name) + logging.debug("Record offload sync [{}]".format(checkpoint_name)) + self.idx2insertions[idx] = ("sync", checkpoint_name) + + def _parse_backward(self): + + self.idx2insertions = {} + # don't offload the last checkpoints, to favor throughput + self.un_fetch_checkpoint_names = self.sorted_checkpoint_names[:] + self.un_fetch_checkpoint_names.pop(-1) + need_fetch_checkpoint_names = self.un_fetch_checkpoint_names[:] + self.checkpoint_usage_count = {} + for checkpoint_name in self.un_fetch_checkpoint_names: + self.checkpoint_usage_count[checkpoint_name] = 0 + + self.bw_strart_op_idx = len(self.block.ops) + for idx, op in enumerate(self.block.ops): + if int(op.desc.attr("op_role")) == 1: + self.bw_strart_op_idx = idx + break + + assert self.bw_strart_op_idx < len( + self.block.ops), "Could NOT found backword op in prog" + + # fetch second to last checkpoint at the beginning of BW + fetched_checkpoint_varname = self._record_fetch_op( + self.bw_strart_op_idx) + last_last_fetch_checkpoint = None + + for i, op in enumerate(self.block.ops[self.bw_strart_op_idx:]): + idx = self.bw_strart_op_idx + i + input_vars = op.desc.input_arg_names() + + for input_var in input_vars: + if input_var in need_fetch_checkpoint_names: + if input_var not in self.un_fetch_checkpoint_names: + # fetch the offloade checkpoint when the first usage of its previous one + if self.checkpoint_usage_count[input_var] == 0: + # TODO (JZ-LIANG) sync memcpy_stream if extra stream for memcpy + second_to_last_fetch_checkpoint = fetched_checkpoint_varname + # there is NO fetch ahead the first checkpoint + if input_var != self.sorted_checkpoint_names[0]: + fetched_checkpoint_varname = self._record_fetch_op( + idx) + + # should check the current used checkpoint is ths last fetch one + assert second_to_last_fetch_checkpoint == input_var, "Current recompute segment should use [{}] BUT got [{}]".format( + second_to_last_fetch_checkpoint, input_var) + # rename + self.block.ops[idx]._rename_input( + input_var, + self.checkpoint_name2fetch_name[input_var]) + self.checkpoint_usage_count[input_var] += 1 + else: + raise ValueError( + "use checkpoint [{}] before fetch in BW".format( + input_var)) + + assert len(self.un_fetch_checkpoint_names + ) == 0, "{} checkpoints have NOT been Recorded".format( + self.un_fetch_checkpoint_names) + + def _update_backward(self): + if len(self.idx2insertions) == 0: + return + total_op = len(self.block.ops) + for op_idx in reversed(range(self.bw_strart_op_idx, total_op)): + if op_idx in self.idx2insertions: + operation, checkpoint_name = self.idx2insertions[op_idx] + if operation == "fetch": + self._insert_fetch_op(op_idx, checkpoint_name) + logging.debug("Insert [{}] fetch op.".format( + checkpoint_name)) + del self.idx2insertions[op_idx] + elif operation == "sync": + self._insert_sync_op(op_idx, checkpoint_name) + logging.debug("Sync [{}] fetch op.".format(checkpoint_name)) + self.block._sync_with_cpp() + assert len( + self.idx2insertions) == 0, "{} checkpoints left un-Fecthed".format( + [ele[1] for ele in self.idx2insertions.values()]) + + def _parse_forward(self): + + self.idx2insertions = {} + # don't offload the last checkpoints, faster, less memory saving + self.un_offload_checkpoint_names = self.sorted_checkpoint_names[:] + last_checkpoint = self.un_offload_checkpoint_names.pop(-1) + need_offload_checkpoint_names = self.un_offload_checkpoint_names[:] + self.checkpoint_usage_count_and_idx = {} + for checkpoint_name in self.un_offload_checkpoint_names: + self.checkpoint_usage_count_and_idx[checkpoint_name] = { + 'count': 0, + 'idx': -1 + } + self.synced_checkpoints = set() + self.fw_strart_op_idx = len(self.block.ops) + for idx, op in enumerate(self.block.ops): + if int(op.desc.attr("op_role")) == 0: + self.fw_strart_op_idx = idx + break + + assert self.fw_strart_op_idx < len( + self.block.ops), "Could NOT found Forward op in prog" + last_offload_checkpoint = None + + for i, op in enumerate(self.block.ops[self.fw_strart_op_idx: + self.bw_strart_op_idx]): + + idx = self.fw_strart_op_idx + i + output_vars = op.desc.output_arg_names() + input_vars = op.desc.input_arg_names() + + for output_var in output_vars: + if output_var in need_offload_checkpoint_names: + assert len( + output_vars + ) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format( + output_var, op) + + if output_var in self.un_offload_checkpoint_names: + # insert sync op if last checkpoint has not been sync + if last_offload_checkpoint != None: + if self.checkpoint_usage_count_and_idx[ + last_offload_checkpoint]['count'] == 0: + self._record_sync_op(idx, + last_offload_checkpoint) + else: + last_usage_idx = self.checkpoint_usage_count_and_idx[ + last_offload_checkpoint]['idx'] + assert last_usage_idx > 0, "last_usage_idx of checkpoint [{}] should large than 0".format( + last_offload_checkpoint) + self._record_sync_op(last_usage_idx + 1, + last_offload_checkpoint) + # insert offload op after the checkpoint's generation op + self._record_offload_op(idx + 1, output_var) + last_offload_checkpoint = output_var + else: + raise ValueError( + "There should be just ONE op that output checkpoint [{}]". + format(output_var)) + # need to sync the last need to offload checkpoint before the last checkpoint as output op + if output_var == last_checkpoint: + assert len( + output_vars + ) == 1, "chekpoint should be the only Output of a certain op, but [{}] is from [{}]".format( + output_var, op) + assert last_offload_checkpoint == self.sorted_checkpoint_names[ + -2], "the last offload chekpoint before [{}] is suppose to be [{}], but got [{}]".format( + last_checkpoint, self.sorted_checkpoint_names[-2], + last_offload_checkpoint) + # sync if last checkpoint has not been sync + if self.checkpoint_usage_count_and_idx[ + last_offload_checkpoint]['idx'] == 0: + self._record_sync_op(idx, last_offload_checkpoint) + else: + last_usage_idx = self.checkpoint_usage_count_and_idx[ + last_offload_checkpoint]['idx'] + assert last_usage_idx > 0, "last_usage_idx of checkpoint [{}] should large than 0".format( + last_offload_checkpoint) + self._record_sync_op(last_usage_idx + 1, + last_offload_checkpoint) + # record checkpoint usage + for input_var in input_vars: + if input_var in need_offload_checkpoint_names: + assert input_var not in self.synced_checkpoints, "checkpoint [{}] used after sync".format( + input_var) + self.checkpoint_usage_count_and_idx[input_var]['count'] += 1 + self.checkpoint_usage_count_and_idx[input_var]['idx'] = idx + + assert len(self.un_offload_checkpoint_names + ) == 0, "{} checkpoints have NOT been Recorded".format( + self.un_fetch_checkpoint_names) + assert len(self.synced_checkpoints) == len( + need_offload_checkpoint_names + ), "{} checkpoints have NOT been Recorded".format( + set(need_offload_checkpoint_names) - set(self.synced_checkpoints)) + + def _update_forward(self): + if len(self.idx2insertions) == 0: + return + for op_idx in reversed( + range(self.fw_strart_op_idx, self.bw_strart_op_idx)): + if op_idx in self.idx2insertions: + operation, checkpoint_name = self.idx2insertions[op_idx] + if operation == "offload": + self._insert_offload_op(op_idx, checkpoint_name) + logging.debug("Insert [{}] offload op.".format( + checkpoint_name)) + del self.idx2insertions[op_idx] + elif operation == "sync": + self._insert_sync_op(op_idx, checkpoint_name) + logging.debug("Insert [{}] offload_sync op.".format( + checkpoint_name)) + del self.idx2insertions[op_idx] + + self.block._sync_with_cpp() + assert len(self.idx2insertions + ) == 0, "{} checkpoints left un-Offloaded".format( + [ele[1] for ele in self.idx2insertions.values()]) + + def _check_offload_fetch(self): + # TODO(JZ-LIANG) the single stream offload need no sync + pass + + def _offload(self, loss, startup_program=None): + """ + core steps for recompute offload + 1. create pinned vars and temp vars + 2. parse & update Forward pass: offload, sync + 3. parse & update Backward pass: rename, fetch, sync + 4. verify the correctness + """ + self._main_program = loss.block.program + self.block = loss.block + if startup_program == None: + startup_program = fluid.default_startup_program() + + with program_guard(self._main_program, startup_program): + assert len(self.checkpoint_shape) > 0, ( + "checkpoints shape {} should be an non empty list like: [12, 512, 1024]". + format(self.checkpoint_shape)) + assert all([ele > 0 for ele in self.checkpoint_shape]), ( + "all ele in checkpoints shape {} should be a determined integer larger than 0". + format(self.checkpoint_shape)) + self.checkpoint_name2pinned_name = dict() + self.checkpoint_name2fetch_name = dict() + for checkpoint_varname in self.sorted_checkpoint_names: + pinned_var_name, fetch_var_name = self._creat_vars( + checkpoint_varname) + self.checkpoint_name2pinned_name[ + checkpoint_varname] = pinned_var_name + self.checkpoint_name2fetch_name[ + checkpoint_varname] = fetch_var_name + self._append_fill_constant_ops(startup_program) + # TODO (JZ-LIANG) to provide two offload stragtegy in future + # step 2. parse & update FW: rename, offload, sync + self._parse_backward() + self._update_backward() + # step 3. parse & update BW: rename, offload, sync + self._parse_forward() + self._update_forward() + # step 4. verify the correctness + self._check_offload_fetch() + + return + def backward(self, loss, startup_program=None, @@ -4767,8 +5124,24 @@ class RecomputeOptimizer(Optimizer): else: checkpoint_vars.append(loss.block.var(ckpt)) - params_grads = append_backward( - loss, parameter_list, no_grad_set, checkpoints=checkpoint_vars) + # allow return to non-recompute when checkpoints is empty + if len(checkpoint_vars) > 0: + params_grads, sorted_checkpoint_names = append_backward( + loss, + parameter_list, + no_grad_set, + checkpoints=checkpoint_vars) + else: + params_grads = append_backward( + loss, + parameter_list, + no_grad_set, + checkpoints=checkpoint_vars) + + if self.enable_offload: + self.sorted_checkpoint_names = sorted_checkpoint_names + self._offload(loss, startup_program=startup_program) + return params_grads def apply_optimize(self, loss, startup_program, params_grads): diff --git a/python/paddle/fluid/tests/unittests/CMakeLists.txt b/python/paddle/fluid/tests/unittests/CMakeLists.txt index ab8256043b1c08682e311b95067a6ac3d076a9d0..2ec2ea2872894f44611399a6d21d6c95bf06f32a 100644 --- a/python/paddle/fluid/tests/unittests/CMakeLists.txt +++ b/python/paddle/fluid/tests/unittests/CMakeLists.txt @@ -83,6 +83,7 @@ if(NOT WITH_GPU OR WIN32) LIST(REMOVE_ITEM TEST_OPS test_collective_allreduce_api) LIST(REMOVE_ITEM TEST_OPS test_collective_broadcast_api) LIST(REMOVE_ITEM TEST_OPS test_collective_allgather_api) + LIST(REMOVE_ITEM TEST_OPS test_memcpy_op) endif() if(WIN32) diff --git a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py index b6ecc07fd9f89aa67eb76ea559e62c6037ab3c25..b5eacecd003be519772adf77213def21d3528c95 100755 --- a/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py +++ b/python/paddle/fluid/tests/unittests/fleet_meta_optimizer_base.py @@ -132,5 +132,12 @@ class TestFleetMetaOptimizer(unittest.TestCase): elif name == "sharding": strategy.sharding = True strategy.sharding_configs = {"fuse_broadcast_MB": 0.2} + elif name == "recompute-offload": + strategy.recompute = True + strategy.recompute_configs = { + "checkpoints": ["fc_0.tmp_2", "fc_1.tmp_2"], + "enable_offload": True, + "checkpoint_shape": [256] + } else: raise NotImplementedError() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py index 42b60cd3fad5a76aee851620c2348d2de2e024e3..790cd5f3efbb4e5b2afb556e2d0a477098397709 100644 --- a/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_recompute_meta_optimizer.py @@ -153,6 +153,20 @@ class TestFleetRecomputeMetaOptimizer(TestFleetMetaOptimizer): self.assertIn('subprog', ''.join(outs)) self.assertIn('lamb', ops) + def test_recompute_offload(self): + train_prog, startup_prog = fluid.Program(), fluid.Program() + avg_cost, strategy = self.net(train_prog, startup_prog) + self.set_strategy(strategy, 'recompute-offload') + self.optimizer(avg_cost, strategy, train_prog, startup_prog) + ops = [op.type for op in avg_cost.block.ops] + outs = [ + op.output('Out')[0] for op in avg_cost.block.ops + if op.type == 'memcpy' + ] + self.assertIn('memcpy', ops) + self.assertIn('@Pinned', ''.join(outs)) + self.assertIn('@Fetch', ''.join(outs)) + if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py index 01a7e25abb6d6161f745f5170b8d66cff7f1b6f3..5da7e627f8707d94cd0e01f17ff14484ac18f4a2 100755 --- a/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py +++ b/python/paddle/fluid/tests/unittests/test_fleet_sharding_meta_optimizer.py @@ -170,19 +170,19 @@ class TestFleetShardingMetaOptimizer(TestFleetMetaOptimizer): self.assertEqual(ops, [ 'cast', 'cast', 'cast', 'fill_constant', 'fill_constant', - 'fill_constant', 'fill_constant', 'fill_constant', + 'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant', 'c_sync_calc_stream', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_broadcast', - 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', 'cast', 'cast', - 'mul', 'cast', 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', - 'elementwise_add', 'cast', 'tanh', 'cast', 'mul', 'elementwise_add', - 'softmax', 'cast', 'cross_entropy2', 'mean', 'elementwise_mul', - 'fill_constant', 'scale', 'elementwise_mul_grad', 'mean_grad', - 'cross_entropy_grad2', 'cast', 'softmax_grad', - 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast', - 'elementwise_add', 'cast', 'tanh_grad', 'cast', - 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', 'cast', - 'elementwise_add', 'cast', 'tanh_grad', 'cast', + 'c_broadcast', 'c_broadcast', 'c_broadcast', 'c_sync_comm_stream', + 'cast', 'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh', + 'cast', 'cast', 'mul', 'elementwise_add', 'cast', 'tanh', 'cast', + 'mul', 'elementwise_add', 'softmax', 'cast', 'cross_entropy2', + 'mean', 'elementwise_mul', 'fill_constant', 'scale', + 'elementwise_mul_grad', 'mean_grad', 'cross_entropy_grad2', 'cast', + 'softmax_grad', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', + 'cast', 'mul', 'cast', 'elementwise_add', 'cast', 'tanh_grad', + 'cast', 'elementwise_add_grad', 'mul_grad', 'cast', 'cast', 'mul', + 'cast', 'elementwise_add', 'cast', 'tanh_grad', 'cast', 'elementwise_add_grad', 'mul_grad', 'c_sync_calc_stream', 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', 'c_allreduce_sum', diff --git a/python/paddle/fluid/tests/unittests/test_memcpy_op.py b/python/paddle/fluid/tests/unittests/test_memcpy_op.py new file mode 100755 index 0000000000000000000000000000000000000000..c6ecbcebcabce839aa8485ed75e9cf48d599a683 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_memcpy_op.py @@ -0,0 +1,176 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import op_test +import numpy as np +import unittest +import paddle +import paddle.fluid.core as core +from paddle.fluid.op import Operator +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard +from paddle.fluid.backward import append_backward + + +class TestMemcpy_FillConstant(unittest.TestCase): + def get_prog(self): + paddle.enable_static() + main_program = Program() + with program_guard(main_program): + pinned_var_name = "tensor@Pinned" + gpu_var_name = "tensor@GPU" + pinned_var = main_program.global_block().create_var( + name=pinned_var_name, + shape=[10, 10], + dtype='float32', + persistable=False, + stop_gradient=True) + gpu_var = main_program.global_block().create_var( + name=gpu_var_name, + shape=[10, 10], + dtype='float32', + persistable=False, + stop_gradient=True) + main_program.global_block().append_op( + type="fill_constant", + outputs={"Out": gpu_var_name}, + attrs={ + "shape": [10, 10], + "dtype": gpu_var.dtype, + "value": 1.0, + "place_type": 1 + }) + main_program.global_block().append_op( + type="fill_constant", + outputs={"Out": pinned_var_name}, + attrs={ + "shape": [10, 10], + "dtype": gpu_var.dtype, + "value": 0.0, + "place_type": 2 + }) + return main_program, gpu_var, pinned_var + + def test_gpu_cpoy_to_pinned(self): + main_program, gpu_var, pinned_var = self.get_prog() + main_program.global_block().append_op( + type='memcpy', + inputs={'X': gpu_var}, + outputs={'Out': pinned_var}, + attrs={'dst_place_type': 3}) + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + gpu_, pinned_ = exe.run(main_program, + feed={}, + fetch_list=[gpu_var.name, pinned_var.name]) + self.assertTrue(np.allclose(gpu_, pinned_)) + self.assertTrue(np.allclose(pinned_, np.ones((10, 10)))) + + def test_pinned_cpoy_gpu(self): + main_program, gpu_var, pinned_var = self.get_prog() + main_program.global_block().append_op( + type='memcpy', + inputs={'X': pinned_var}, + outputs={'Out': gpu_var}, + attrs={'dst_place_type': 2}) + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + gpu_, pinned_ = exe.run(main_program, + feed={}, + fetch_list=[gpu_var.name, pinned_var.name]) + self.assertTrue(np.allclose(gpu_, pinned_)) + self.assertTrue(np.allclose(gpu_, np.zeros((10, 10)))) + + +class TestMemcpyOPError(unittest.TestCase): + def get_prog(self): + paddle.enable_static() + main_program = Program() + with program_guard(main_program): + pinned_var = main_program.global_block().create_var( + name="tensor@Pinned_0", + shape=[10, 10], + dtype='float32', + persistable=False, + stop_gradient=True) + main_program.global_block().append_op( + type="fill_constant", + outputs={"Out": "tensor@Pinned_0"}, + attrs={ + "shape": [10, 10], + "dtype": pinned_var.dtype, + "value": 0.0, + "place_type": 2 + }) + return main_program, pinned_var + + def test_SELECTED_ROWS(self): + main_program, pinned_var = self.get_prog() + selected_row_var = main_program.global_block().create_var( \ + name="selected_row_0", dtype="float32", persistable=False, \ + type=fluid.core.VarDesc.VarType.SELECTED_ROWS, stop_gradient=True) + main_program.global_block().append_op( + type="fill_constant", + outputs={"Out": selected_row_var}, + attrs={ + "shape": selected_row_var.shape, + "dtype": selected_row_var.dtype, + "value": 1.0, + "place_type": 1 + }) + main_program.global_block().append_op( + type='memcpy', + inputs={'X': selected_row_var}, + outputs={'Out': pinned_var}, + attrs={'dst_place_type': 3}) + with self.assertRaises(NotImplementedError): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + selected_row_var_, pinned_ = exe.run( + main_program, + feed={}, + fetch_list=[selected_row_var.name, pinned_var.name]) + + def test_OTHER_PLACE_NotImplementedError(self): + main_program, pinned_var = self.get_prog() + lod_tensor_var = main_program.global_block().create_var( \ + name="lod_tensor_0", dtype="float32", persistable=False, stop_gradient=True) + main_program.global_block().append_op( + type="fill_constant", + outputs={"Out": lod_tensor_var}, + attrs={ + "shape": lod_tensor_var.shape, + "dtype": lod_tensor_var.dtype, + "value": 1.0, + "place_type": 0 + }) + main_program.global_block().append_op( + type='memcpy', + inputs={'X': pinned_var}, + outputs={'Out': lod_tensor_var}, + attrs={'dst_place_type': 0, }) + with self.assertRaises(NotImplementedError): + place = fluid.CUDAPlace(0) + exe = fluid.Executor(place) + lod_tensor_var_, pinned_ = exe.run( + main_program, + feed={}, + fetch_list=[lod_tensor_var.name, pinned_var.name]) + + +if __name__ == '__main__': + paddle.enable_static() + unittest.main()