diff --git a/paddle/fluid/eager/custom_operator/custom_operator_node.cc b/paddle/fluid/eager/custom_operator/custom_operator_node.cc index df95975981644617f767ceb26dbe81edd2bdd85c..c99f524b246a7c74bb2fef904c295f1f91942471 100644 --- a/paddle/fluid/eager/custom_operator/custom_operator_node.cc +++ b/paddle/fluid/eager/custom_operator/custom_operator_node.cc @@ -174,6 +174,9 @@ RunCustomOpNode::operator()(paddle::small_vector, egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); auto grad_outputs_names = paddle::framework::OpMetaInfoHelper::GetOutputs( egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); + const auto& grad_inplace_map = + paddle::framework::OpMetaInfoHelper::GetInplaceMap( + egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap(); @@ -205,6 +208,9 @@ RunCustomOpNode::operator()(paddle::small_vector, } VLOG(6) << "Prepare Grad attrs"; ctx.EmplaceBackAttrs(attrs_); + // NOTE(HongyuJia): grad_outputs_names.size() <= OutputMeta().size(): + // OutputMeta().size() indicates input size of forward op, + // grad_outputs_names.size() indicates output size of backward op. paddle::small_vector, kSlotSmallVectorSize> outs( OutputMeta().size()); paddle::small_vector, kSlotSmallVectorSize> @@ -234,8 +240,10 @@ RunCustomOpNode::operator()(paddle::small_vector, } VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad"; + ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map); (*paddle::framework::OpMetaInfoHelper::GetKernelFn( kernel_map.at(op_type_)[1]))(&ctx); + ctx.AssignInplaceOutputs(); VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op"; std::vector> ins_auto_grad_metas; @@ -353,6 +361,8 @@ RunCustomOpDoubleGradNode::operator()( paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[2]); auto grad_outputs_names = paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[2]); + const auto& grad_inplace_map = + paddle::framework::OpMetaInfoHelper::GetInplaceMap(vec_map[2]); auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap(); @@ -419,8 +429,10 @@ RunCustomOpDoubleGradNode::operator()( } VLOG(7) << "Run Kernel of Grad Custom Op: " << name(); + ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map); (*paddle::framework::OpMetaInfoHelper::GetKernelFn( kernel_map.at(op_type_)[2]))(&ctx); + ctx.AssignInplaceOutputs(); return outs; } diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc index 9a834e583239f8173d3869d573202560994445a6..2ddbf738787906cd449e116cb3a0371bf801683f 100644 --- a/paddle/fluid/framework/custom_operator.cc +++ b/paddle/fluid/framework/custom_operator.cc @@ -130,11 +130,13 @@ static std::vector ParseAttrStr(const std::string& attr) { ////////////////// Kernel Define //////////////////// // custom op kernel call function define -static void RunKernelFunc(const framework::ExecutionContext& ctx, - const paddle::KernelFunc& func, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& attrs) { +static void RunKernelFunc( + const framework::ExecutionContext& ctx, + const paddle::KernelFunc& func, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& attrs, + const std::unordered_map& inplace_map) { VLOG(3) << "Custom Operator: Start run KernelFunc."; // prepare CustomOpKernelContext paddle::CustomOpKernelContext kernel_ctx; @@ -283,7 +285,10 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, VLOG(4) << "Initialize phi tensor operants successfully"; } + // handle inplace case + kernel_ctx.MapPlainOutputs(inputs, outputs, inplace_map); func(&kernel_ctx); + kernel_ctx.AssignInplaceOutputs(); // sync output tensor data into original output auto* calc_outs = kernel_ctx.AllMutableOutput(); @@ -686,12 +691,14 @@ static void RegisterOperatorKernelWithPlace( OperatorWithKernel::AllOpKernels()[name][key] = op_kernel_func; } -static void RegisterOperatorKernel(const std::string& name, - const paddle::KernelFunc& kernel_func, - const std::vector& inputs, - const std::vector& outputs, - const std::vector& attrs, - void* dso_handle) { +static void RegisterOperatorKernel( + const std::string& name, + const paddle::KernelFunc& kernel_func, + const std::vector& inputs, + const std::vector& outputs, + const std::vector& attrs, + const std::unordered_map& inplace_map, + void* dso_handle) { VLOG(3) << "Custom Operator: op name in kernel: " << name; // NOTE [ Dummy Op Kernel Key ] // TODO(chenweihang): Because execute engine need get device context based @@ -701,10 +708,10 @@ static void RegisterOperatorKernel(const std::string& name, OperatorWithKernel::OpKernelFunc op_kernel_func; if (kernel_func) { VLOG(3) << "Register custom operator " << name << " with kernel func"; - op_kernel_func = [kernel_func, inputs, outputs, attrs]( + op_kernel_func = [kernel_func, inputs, outputs, attrs, inplace_map]( const framework::ExecutionContext& ctx) { VLOG(3) << "Custom Operator: run custom kernel func in lambda."; - RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs); + RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs, inplace_map); }; } else { VLOG(3) << "Register custom operator " << name @@ -760,6 +767,7 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta); auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta); auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta); + auto& op_inplace_map = OpMetaInfoHelper::GetInplaceMap(base_op_meta); auto& kernel_fn = OpMetaInfoHelper::GetKernelFn(base_op_meta); auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta); auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta); @@ -771,6 +779,12 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, << string::join_strings(op_outputs, ','); VLOG(3) << "Custom Operator: forward, op attrs: " << string::join_strings(op_attrs, ','); + if (!op_inplace_map.empty()) { + VLOG(3) << "Custom Operator: forward, op inplace_map: " + << string::join_strings(op_inplace_map, ',', [](auto& pair) { + return pair.first + ": " + pair.second; + }); + } // Op info.creator_ = [](const std::string& op_name, @@ -795,6 +809,13 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, op_name, info.proto_->InitializationErrorString())); + // Inplace + if (!op_inplace_map.empty()) { + info.infer_inplace_ = [op_inplace_map](bool use_cuda) { + return op_inplace_map; + }; + } + // InferShape if (infer_shape_func == nullptr) { // use default InferShape @@ -908,8 +929,13 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, } // Kernel func - RegisterOperatorKernel( - op_name, kernel_fn, op_inputs, op_outputs, op_attrs, dso_handle); + RegisterOperatorKernel(op_name, + kernel_fn, + op_inputs, + op_outputs, + op_attrs, + op_inplace_map, + dso_handle); // If grad op or double grad op exists std::string cur_op_name = op_name; @@ -920,6 +946,7 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op); auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op); auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op); + auto& grad_op_inplace_map = OpMetaInfoHelper::GetInplaceMap(cur_grad_op); auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op); auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op); @@ -928,6 +955,14 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, << string::join_strings(grad_op_inputs, ','); VLOG(3) << "Custom Operator: backward, op outputs: " << string::join_strings(grad_op_outputs, ','); + VLOG(3) << "Custom Operator: backward, op attrs: " + << string::join_strings(grad_op_attrs, ','); + if (!op_inplace_map.empty()) { + VLOG(3) << "Custom Operator: backward, op inplace_map: " + << string::join_strings(grad_op_inplace_map, ',', [](auto& pair) { + return pair.first + ": " + pair.second; + }); + } bool is_double_grad = (i == 2); @@ -1040,6 +1075,7 @@ void RegisterOperatorWithMetaInfo(const std::vector& op_meta_infos, grad_op_inputs, grad_op_outputs, grad_op_attrs, + grad_op_inplace_map, dso_handle); // update current info diff --git a/paddle/fluid/framework/op_meta_info_helper.h b/paddle/fluid/framework/op_meta_info_helper.h index b93e0ab0f551183924a7dfcb79601a6e94345089..20154e1ee385b494fe3b49429791a1170900d213 100644 --- a/paddle/fluid/framework/op_meta_info_helper.h +++ b/paddle/fluid/framework/op_meta_info_helper.h @@ -39,6 +39,10 @@ class OpMetaInfoHelper { const paddle::OpMetaInfo& info) { return info.attrs_; } + static const std::unordered_map& GetInplaceMap( + const paddle::OpMetaInfo& info) { + return info.inplace_map_; + } static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) { return info.kernel_fn_; } diff --git a/paddle/fluid/pybind/eager_functions.cc b/paddle/fluid/pybind/eager_functions.cc index 9a3229f3210173b72a12d816c14dc9298853b226..50508d1db5cd708e71fb5cccfa8df70538495dad 100644 --- a/paddle/fluid/pybind/eager_functions.cc +++ b/paddle/fluid/pybind/eager_functions.cc @@ -531,7 +531,18 @@ static PyObject* eager_api_run_custom_op(PyObject* self, meta_info_map.at(op_type)[0])); ctx.EmplaceBackAttrs(res_attrs); const auto& vec_map = meta_info_map.at(op_type); + + // handle inplace case + const auto& inputs = paddle::framework::OpMetaInfoHelper::GetInputs( + meta_info_map.at(op_type)[0]); + const auto& outputs = paddle::framework::OpMetaInfoHelper::GetOutputs( + meta_info_map.at(op_type)[0]); + const auto& inplace_map = + paddle::framework::OpMetaInfoHelper::GetInplaceMap( + meta_info_map.at(op_type)[0]); + ctx.MapPlainOutputs(inputs, outputs, inplace_map); (*paddle::framework::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); + ctx.AssignInplaceOutputs(); VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op"; std::vector> ins_auto_grad_metas; @@ -557,12 +568,43 @@ static PyObject* eager_api_run_custom_op(PyObject* self, require_any_grad || egr::EagerUtils::ComputeRequireGrad( trace_backward, &(ins_auto_grad_metas[i])); } + + // handle inplace case + for (size_t i = 0; i < ctx.InputRange().size(); i++) { + if (inplace_map.find(inputs[i]) != inplace_map.end()) { + size_t input_size = + ctx.InputRangeAt(i).second - ctx.InputRangeAt(i).first; + size_t start_idx = ctx.InputRangeAt(i).first; + for (size_t j = 0; j < input_size; j++) { + egr::EagerUtils::CheckInplace(ctx.InputAt(start_idx + j), + ins_auto_grad_metas[i][j], + require_any_grad); + // Bump Inplace Version + ctx.MutableInputAt(start_idx + j).bump_inplace_version(); + VLOG(3) << "Custom operator: Tensor(" + << ctx.InputAt(start_idx + j).name() + << ") uses Inplace Strategy."; + } + } + } + if (require_any_grad && (vec_map.size() > 1)) { VLOG(6) << " Construct Grad for Custom Op: " << op_type; ConstructFwdAndBwdMap(vec_map, op_type); for (size_t i = 0; i < outs_auto_grad_metas.size(); i++) { egr::EagerUtils::PassStopGradient(false, &(outs_auto_grad_metas[i])); } + // Note(HongyuJia): In dygraph eager mode, CheckInplace makes sure leaf + // nodes set stop_gradient=True. However, dygraph mode can also outputs + // lead nodes' gradients (For example, we can get x.grad after x.add_(y)). + // To be consistent with dygraph mode, we have to PassStopGradient for all + // inplaced ins_auto_grad_metas. + std::unordered_map inplace_tensor_map = + ctx.GetInplaceTensorMap(); + for (auto pair : inplace_tensor_map) { + egr::EagerUtils::PassStopGradient(false, + &(ins_auto_grad_metas[pair.first])); + } auto grad_node = std::make_shared( outs_auto_grad_metas.size(), ins_auto_grad_metas.size(), op_type); auto slot_map = diff --git a/paddle/fluid/pybind/eager_utils.cc b/paddle/fluid/pybind/eager_utils.cc index 8637b12f8aae9747c973551fb3075ffc73da3c21..fc7d359afd7d3b05dcfbf4010d958cee1cd479a4 100644 --- a/paddle/fluid/pybind/eager_utils.cc +++ b/paddle/fluid/pybind/eager_utils.cc @@ -609,8 +609,7 @@ paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj, return ::pybind11::handle(obj).cast(); } else { PADDLE_THROW(platform::errors::InvalidArgument( - "argument (position %d) must be " - "one of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace), " + "argument (position %d) must be CustomOpKernelContext, " "but got %s", arg_pos + 1, reinterpret_cast(obj->ob_type)->tp_name)); diff --git a/paddle/phi/api/ext/op_meta_info.h b/paddle/phi/api/ext/op_meta_info.h index 95eddb9745596c5aa6788a52f4b3d128a8ddcaf5..77ec8c417da33cd06d5e9c2c72c5fcd0d04ab150 100644 --- a/paddle/phi/api/ext/op_meta_info.h +++ b/paddle/phi/api/ext/op_meta_info.h @@ -108,6 +108,7 @@ class PADDLE_API CustomOpKernelContext { const Tensor& InputAt(size_t idx) const; std::vector InputsBetween(size_t start, size_t end) const; + Tensor& MutableInputAt(size_t idx); const std::vector& Attrs() const { return attrs_; } const std::vector>& InputRange() { return input_range_; @@ -129,11 +130,23 @@ class PADDLE_API CustomOpKernelContext { } } + // handle inplace case + void MapPlainOutputs( + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map); + void AssignInplaceOutputs(); + std::vector* AllMutablePlainOutput(); + std::unordered_map GetInplaceTensorMap(); + private: // TODO(chenweihang): replaced be SmallVector std::vector inputs_; std::vector outputs_; std::vector attrs_; + // handle inplace case + std::vector plain_outputs_; + std::unordered_map inplace_tensor_map_; std::vector> input_range_; std::vector> output_range_; @@ -148,8 +161,7 @@ using KernelFunc = void (*)(CustomOpKernelContext*); template \ struct ComputeCallHelper { \ template \ - static void Compute(CustomOpKernelContext* ctx, \ - const PreviousArgs&... pargs) { \ + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { \ attr_type arg = ctx->AttrAt(attr_idx); \ ComputeCallHelper< \ Tail...>::template Compute(ctx, \ @@ -177,10 +189,9 @@ struct KernelFuncImpl { template struct ComputeCallHelper { template - static void Compute(CustomOpKernelContext* ctx, - const PreviousArgs&... pargs) { + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { auto& range = ctx->InputRangeAt(in_idx); - auto& arg = ctx->InputAt(range.first); + auto& arg = ctx->MutableInputAt(range.first); ComputeCallHelper< Tail...>::template Compute(ctx, pargs..., @@ -191,8 +202,7 @@ struct KernelFuncImpl { template struct ComputeCallHelper&, Tail...> { template - static void Compute(CustomOpKernelContext* ctx, - const PreviousArgs&... pargs) { + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { auto& range = ctx->InputRangeAt(in_idx); auto arg = ctx->InputsBetween(range.first, range.second); ComputeCallHelper< @@ -232,11 +242,12 @@ struct KernelFuncImpl { PD_SPECIALIZE_ComputeCallHelper(std::vector); PD_SPECIALIZE_ComputeCallHelper(std::vector); + // Used to be compatible with 2.3 released internal inplace interface, not + // recommended template struct ComputeCallHelper { template - static void Compute(CustomOpKernelContext* ctx, - const PreviousArgs&... pargs) { + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { auto& range = ctx->OutputRangeAt(out_idx); auto* arg = ctx->MutableOutputAt(range.first); ComputeCallHelper< @@ -246,13 +257,14 @@ struct KernelFuncImpl { } }; + // Used to be compatible with 2.3 released internal inplace interface, not + // recommended // TODO(chenweihang): What is the appropriate output form? // std::vector*? or std::vector? or std::vector* template struct ComputeCallHelper, Tail...> { template - static void Compute(CustomOpKernelContext* ctx, - const PreviousArgs&... pargs) { + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { auto& range = ctx->OutputRangeAt(out_idx); auto arg = ctx->MutableOutputBetweeen(range.first, range.second); ComputeCallHelper< @@ -262,18 +274,32 @@ struct KernelFuncImpl { } }; + // Handle Tensor& for inplace case + template + struct ComputeCallHelper { + template + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { + auto& range = ctx->InputRangeAt(in_idx); + auto& arg = ctx->MutableInputAt(range.first); + ComputeCallHelper< + Tail...>::template Compute(ctx, + pargs..., + arg); + } + }; + template struct ComputeReturnHelper; // For compatibility with the original custom op form template struct ComputeReturnHelper> { - static void Compute(CustomOpKernelContext* ctx, const Args&... args) { + static void Compute(CustomOpKernelContext* ctx, Args&... args) { static_assert(out_idx == 0, "If return std::vector in Custom OpKernel, " "you cannot pass output by kernel function argument."); auto outs = impl_fn(args...); - auto* orig_outs = ctx->AllMutableOutput(); + auto* orig_outs = ctx->AllMutablePlainOutput(); PD_CHECK(orig_outs->size() == outs.size(), "The number of element in custom operator outputs is wrong, " "expected contains ", @@ -282,15 +308,14 @@ struct KernelFuncImpl { outs.size(), " Tensors."); for (size_t i = 0; i < outs.size(); ++i) { - AssignTensorImpl(outs.at(i), &(orig_outs->at(i))); + AssignTensorImpl(outs.at(i), orig_outs->at(i)); } } }; template struct ComputeReturnHelper { - static void Compute(CustomOpKernelContext* ctx, const Args&... args) { - static_assert(out_idx > 0, "Custom OpKernel has no output."); + static void Compute(CustomOpKernelContext* ctx, Args&... args) { impl_fn(args...); } }; @@ -299,8 +324,7 @@ struct KernelFuncImpl { template struct ComputeCallHelper> { template - static void Compute(CustomOpKernelContext* ctx, - const PreviousArgs&... pargs) { + static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { ComputeReturnHelper::Compute(ctx, pargs...); } }; @@ -547,9 +571,14 @@ class PADDLE_API OpMetaInfo { // format: {"", "", ...} OpMetaInfo& Outputs(std::vector&& outputs); - // format: {":", ":", ...} + // format: {":", ":", ...} OpMetaInfo& Attrs(std::vector&& attrs); + // format: {":", + // ":",...} + OpMetaInfo& Inplace( + std::unordered_map&& inplace_map); + // format: PD_KERNEL(...) OpMetaInfo& SetKernelFn(KernelFunc&& func); @@ -567,6 +596,7 @@ class PADDLE_API OpMetaInfo { std::vector inputs_; std::vector outputs_; std::vector attrs_; + std::unordered_map inplace_map_; // 2. func info KernelFunc kernel_fn_{nullptr}; InferShapeFunc infer_shape_fn_{nullptr}; @@ -605,6 +635,8 @@ class PADDLE_API OpMetaInfoBuilder { OpMetaInfoBuilder& Inputs(std::vector&& inputs); OpMetaInfoBuilder& Outputs(std::vector&& outputs); OpMetaInfoBuilder& Attrs(std::vector&& attrs); + OpMetaInfoBuilder& Inplace( + std::unordered_map&& inplace_map); OpMetaInfoBuilder& SetKernelFn(KernelFunc func); OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func); OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func); diff --git a/paddle/phi/api/lib/op_meta_info.cc b/paddle/phi/api/lib/op_meta_info.cc index a6b7921c30c61cd142b72cbda3f59b7db005fb08..487308ea56823ee9b756b8186a26df14f32b9674 100644 --- a/paddle/phi/api/lib/op_meta_info.cc +++ b/paddle/phi/api/lib/op_meta_info.cc @@ -94,6 +94,10 @@ std::vector CustomOpKernelContext::InputsBetween(size_t start, return rlt; } +Tensor& CustomOpKernelContext::MutableInputAt(size_t idx) { + return inputs_.at(idx); +} + Tensor* CustomOpKernelContext::MutableOutputAt(size_t idx) { return &(outputs_.at(idx)); } @@ -128,6 +132,71 @@ const std::pair& CustomOpKernelContext::OutputRangeAt( return output_range_.at(idx); } +// handle inplace mechanism +// Find out non-inplace output tensors. +void CustomOpKernelContext::MapPlainOutputs( + const std::vector& inputs, + const std::vector& outputs, + const std::unordered_map& inplace_map) { + for (size_t in_idx = 0; in_idx < inputs.size(); ++in_idx) { + auto& input = inputs[in_idx]; + if (inplace_map.find(input) == inplace_map.end()) { + continue; + } + auto out_iter = find(outputs.begin(), outputs.end(), inplace_map.at(input)); + PADDLE_ENFORCE( + out_iter != outputs.end(), + phi::errors::NotFound("Can't find the mapped value of %s, please check " + "the input of `Inplace` again and make " + "sure you registered your op accurately. ", + input)); + inplace_tensor_map_[in_idx] = distance(outputs.begin(), out_iter); + } + for (size_t i = 0; i < outputs.size(); ++i) { + if (std::any_of( + inplace_tensor_map_.begin(), + inplace_tensor_map_.end(), + [i](std::unordered_map::const_reference pair) { + return pair.second == i; + })) { + continue; + } + size_t output_start_idx = output_range_[i].first; + size_t output_end_idx = output_range_[i].second; + for (size_t idx = output_start_idx; idx < output_end_idx; ++idx) { + plain_outputs_.push_back(&outputs_[idx]); + } + } + VLOG(4) << "Custom opertor update inplace input-output map successfully."; +} +// Assign input tensor to inplace output tensors. +void CustomOpKernelContext::AssignInplaceOutputs() { + for (auto pair : inplace_tensor_map_) { + size_t in_start_idx = input_range_[pair.first].first; + size_t in_end_idx = input_range_[pair.first].second; + size_t out_start_idx = output_range_[pair.second].first; + size_t out_end_idx = output_range_[pair.second].second; + size_t assign_tensor_size = in_end_idx - in_start_idx; + PADDLE_ENFORCE( + assign_tensor_size == out_end_idx - out_start_idx, + phi::errors::OutOfRange("When assigning inplaced tensor, Input vector " + "size %d mismatch output vector size %d", + in_end_idx - in_start_idx, + out_end_idx - out_start_idx)); + for (size_t i = 0; i < assign_tensor_size; ++i) { + AssignTensorImpl(inputs_[in_start_idx + i], &outputs_[out_start_idx + i]); + } + VLOG(4) + << "Custom opertor update inplace input-output tensor successfully."; + } +} +std::vector* CustomOpKernelContext::AllMutablePlainOutput() { + return &plain_outputs_; +} +std::unordered_map +CustomOpKernelContext::GetInplaceTensorMap() { + return inplace_tensor_map_; +} ////////////////////// Op Meta Info ////////////////////// OpMetaInfo& OpMetaInfo::Inputs(std::vector&& inputs) { @@ -142,6 +211,12 @@ OpMetaInfo& OpMetaInfo::Attrs(std::vector&& attrs) { attrs_ = std::forward>(attrs); return *this; } +OpMetaInfo& OpMetaInfo::Inplace( + std::unordered_map&& inplace_map) { + inplace_map_ = + std::forward>(inplace_map); + return *this; +} OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) { kernel_fn_ = std::forward(func); return *this; @@ -222,6 +297,13 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector&& attrs) { return *this; } +OpMetaInfoBuilder& OpMetaInfoBuilder::Inplace( + std::unordered_map&& inplace_map) { + info_ptr_->Inplace( + std::forward>(inplace_map)); + return *this; +} + OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) { info_ptr_->SetKernelFn(std::forward(func)); return *this; diff --git a/python/paddle/fluid/tests/custom_op/CMakeLists.txt b/python/paddle/fluid/tests/custom_op/CMakeLists.txt index 17dc7468853880880ab68185bc5d9a8c5a788e49..7fc26aed21ddb07d82b6cd1eeafeba9274cd0738 100644 --- a/python/paddle/fluid/tests/custom_op/CMakeLists.txt +++ b/python/paddle/fluid/tests/custom_op/CMakeLists.txt @@ -50,6 +50,7 @@ py_test(test_custom_conj SRCS test_custom_conj.py) py_test(test_custom_linear SRCS test_custom_linear.py) py_test(test_custom_simple_slice SRCS test_custom_simple_slice.py) py_test(test_custom_tanh_double_grad SRCS test_custom_tanh_double_grad.py) +py_test(test_custom_inplace SRCS test_custom_inplace.py) # other tests py_test(test_sysconfig SRCS test_sysconfig.py) diff --git a/python/paddle/fluid/tests/custom_op/custom_inplace.cc b/python/paddle/fluid/tests/custom_op/custom_inplace.cc new file mode 100644 index 0000000000000000000000000000000000000000..7b57a632ca64821d8da84427077204eed427a11b --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/custom_inplace.cc @@ -0,0 +1,136 @@ +// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WIdata_tHOUdata_t WARRANdata_tIES OR CONDIdata_tIONS OF ANY KIND, either +// express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include +#include + +#include "paddle/extension.h" + +template +void add_forward_kernel(data_t* x_data, const data_t* y_data, int64_t numel) { + for (size_t i = 0; i < numel; ++i) { + x_data[i] += y_data[i]; + } +} + +template +void add_backward_kernel(data_t* y_grad_data, + const data_t* out_grad_data, + int64_t numel) { + for (size_t i = 0; i < numel; ++i) { + y_grad_data[i] = out_grad_data[i]; + } +} + +template +void relu_forward_kernel(data_t* x_data, int64_t numel) { + for (size_t i = 0; i < numel; ++i) { + x_data[i] = x_data[i] > 0 ? x_data[i] : 0; + } +} + +template +void relu_backward_kernel(const data_t* out_data, + data_t* grad_out_data, + int64_t out_numel) { + for (int64_t i = 0; i < out_numel; ++i) { + grad_out_data[i] = + grad_out_data[i] * (out_data[i] > static_cast(0) ? 1. : 0.); + } +} + +void AddForward(paddle::Tensor& x, const paddle::Tensor& y) { // NOLINT + PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + + PD_DISPATCH_FLOATING_TYPES(x.type(), "AddForward", ([&] { + add_forward_kernel(x.data(), + y.data(), + x.size()); + })); +} + +std::vector AddInferDtype(const paddle::DataType& x_dtype, + const paddle::DataType& y_dtype) { + return {x_dtype}; +} + +std::vector> AddInferShape( + const std::vector& x_shape, const std::vector& y_shape) { + return {x_shape}; +} + +std::vector AddBackward(const paddle::Tensor& x, + const paddle::Tensor& y, + paddle::Tensor& out_grad) { // NOLINT + PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + PD_CHECK(y.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + + paddle::Tensor y_grad = paddle::empty(x.shape(), x.dtype(), x.place()); + + PD_DISPATCH_FLOATING_TYPES( + out_grad.type(), "AddBackward", ([&] { + add_backward_kernel( + y_grad.data(), out_grad.data(), out_grad.size()); + })); + + return {y_grad}; +} + +PD_BUILD_OP(custom_add) + .Inputs({"X", "Y"}) + .Outputs({"Out"}) + .Inplace({{"X", "Out"}}) + .SetKernelFn(PD_KERNEL(AddForward)) + .SetInferShapeFn(PD_INFER_SHAPE(AddInferShape)) + .SetInferDtypeFn(PD_INFER_DTYPE(AddInferDtype)); + +PD_BUILD_GRAD_OP(custom_add) + .Inputs({"X", "Y", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X"), paddle::Grad("Y")}) + .Inplace({{paddle::Grad("Out"), paddle::Grad("X")}}) + .SetKernelFn(PD_KERNEL(AddBackward)); + +void ReluForwardInplace(paddle::Tensor& x) { // NOLINT + PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + + PD_DISPATCH_FLOATING_TYPES(x.type(), "ReluForward", ([&] { + relu_forward_kernel(x.data(), + x.size()); + })); +} + +void ReluBackwardInplace(const paddle::Tensor& x, + const paddle::Tensor& out, + paddle::Tensor& grad_out) { // NOLINT + PD_CHECK(out.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor."); + + PD_DISPATCH_FLOATING_TYPES( + grad_out.type(), "ReluBackward", ([&] { + relu_backward_kernel( + out.data(), grad_out.data(), grad_out.size()); + })); +} + +PD_BUILD_OP(custom_relu_inplace) + .Inputs({"X"}) + .Outputs({"Out"}) + .Inplace({{"X", "Out"}}) + .SetKernelFn(PD_KERNEL(ReluForwardInplace)); + +PD_BUILD_GRAD_OP(custom_relu_inplace) + .Inputs({"X", "Out", paddle::Grad("Out")}) + .Outputs({paddle::Grad("X")}) + .Inplace({{paddle::Grad("Out"), paddle::Grad("X")}}) + .SetKernelFn(PD_KERNEL(ReluBackwardInplace)); diff --git a/python/paddle/fluid/tests/custom_op/test_custom_inplace.py b/python/paddle/fluid/tests/custom_op/test_custom_inplace.py new file mode 100644 index 0000000000000000000000000000000000000000..d3a8959410139d1b10ad63b78dffe33d2e625e67 --- /dev/null +++ b/python/paddle/fluid/tests/custom_op/test_custom_inplace.py @@ -0,0 +1,333 @@ +# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import os +import unittest + +import numpy as np +from utils import extra_cc_args, extra_nvcc_args, paddle_includes + +import paddle +import paddle.static as static +from paddle.utils.cpp_extension import get_build_directory, load +from paddle.utils.cpp_extension.extension_utils import run_cmd + +# Because Windows don't use docker, the shared lib already exists in the +# cache dir, it will not be compiled again unless the shared lib is removed. +file = '{}\\custom_inplace\\custom_inplace.pyd'.format(get_build_directory()) +if os.name == 'nt' and os.path.isfile(file): + cmd = 'del {}'.format(file) + run_cmd(cmd, True) + +# Compile and load custom op Just-In-Time. +custom_inplace = load( + name='custom_inplace', + sources=['custom_inplace.cc'], + extra_include_paths=paddle_includes, # add for Coverage CI + extra_cxx_cflags=extra_cc_args, # test for cflags + extra_cuda_cflags=extra_nvcc_args, # test for cflags + verbose=True, +) + + +def inplace_dynamic_add(phi_func, device, dtype, np_x, np_y): + paddle.set_device(device) + x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True) + y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) + if phi_func: + out = custom_inplace.custom_add(x, y) + else: + out = x.add_(y) + + out.backward() + return x.numpy(), y.numpy(), out.numpy(), x.grad.numpy(), y.grad.numpy() + + +def inplace_static_add(func, device, dtype, np_x, np_y): + paddle.enable_static() + paddle.set_device(device) + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype) + y = static.data(name="y", shape=[None, np_y.shape[1]], dtype=dtype) + x.stop_gradient = False + y.stop_gradient = False + out = func(x, y) + mean_out = paddle.mean(out) + static.append_backward(mean_out) + + exe = static.Executor() + exe.run(static.default_startup_program()) + + x_v, out_v, x_grad_v, y_grad_v, out_grad_v = exe.run( + static.default_main_program(), + feed={ + "x": np_x.astype(dtype), + "y": np_y.astype(dtype), + }, + fetch_list=[ + x.name, + out.name, + x.name + "@GRAD", + y.name + "@GRAD", + out.name + "@GRAD", + ], + ) + paddle.disable_static() + return x_v, out_v, x_grad_v, y_grad_v, out_grad_v + + +def inplace_dynamic_relu(phi_func, device, dtype, np_x, np_y, np_z): + paddle.set_device(device) + x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False) + y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False) + z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False) + out_xy = x + y + if phi_func: + out_xy = custom_inplace.custom_relu_inplace(out_xy) + out_xyz = out_xy + z + out = custom_inplace.custom_relu_inplace(out_xyz) + else: + out_xy = paddle.nn.functional.relu_(out_xy) + out_xyz = out_xy + z + out = paddle.nn.functional.relu_(out_xyz) + + out.backward() + return x.numpy(), y.numpy(), out.numpy(), x.grad.numpy(), y.grad.numpy() + + +def inplace_static_relu(func, device, dtype, np_x, np_y, np_z): + paddle.enable_static() + paddle.set_device(device) + with static.scope_guard(static.Scope()): + with static.program_guard(static.Program()): + x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype) + y = static.data(name="y", shape=[None, np_y.shape[1]], dtype=dtype) + z = static.data(name="z", shape=[None, np_z.shape[1]], dtype=dtype) + x.stop_gradient = False + y.stop_gradient = False + z.stop_gradient = False + out_xy = x + y + out_xy = func(out_xy) + out_xyz = out_xy + z + out = func(out_xyz) + mean_out = paddle.mean(out) + static.append_backward(mean_out) + + exe = static.Executor() + exe.run(static.default_startup_program()) + + x_v, y_v, out_v, x_grad_v, y_grad_v = exe.run( + static.default_main_program(), + feed={ + "x": np_x.astype(dtype), + "y": np_y.astype(dtype), + "z": np_z.astype(dtype), + }, + fetch_list=[ + x.name, + y.name, + out.name, + x.name + "@GRAD", + y.name + "@GRAD", + ], + ) + paddle.disable_static() + return x_v, y_v, out_v, x_grad_v, y_grad_v + + +class TestCustomInplaceJit(unittest.TestCase): + def setUp(self): + self.dtypes = ['float32', 'float64'] + self.devices = ['cpu'] + self.np_x = np.random.random((3, 2)).astype("float32") + self.np_y = np.random.random((3, 2)).astype("float32") + self.np_z = np.random.random((3, 2)).astype("float32") + + def check_output(self, out, pd_out, name): + np.testing.assert_array_equal( + out, + pd_out, + err_msg='custom op {}: {},\n paddle api {}: {}'.format( + name, out, name, pd_out + ), + ) + + def check_output_allclose(self, out, pd_out, name): + np.testing.assert_allclose( + out, + pd_out, + rtol=5e-5, + atol=1e-2, + err_msg='custom op {}: {},\n paddle api {}: {}'.format( + name, out, name, pd_out + ), + ) + + def test_static_add(self): + for device in self.devices: + for dtype in self.dtypes: + ( + pd_x, + pd_out, + pd_x_grad, + pd_y_grad, + pd_out_grad, + ) = inplace_static_add( + paddle.add, + device, + dtype, + self.np_x, + self.np_y, + ) + ( + phi_x, + phi_out, + phi_x_grad, + phi_y_grad, + phi_out_grad, + ) = inplace_static_add( + custom_inplace.custom_add, + device, + dtype, + self.np_x, + self.np_y, + ) + self.check_output(phi_x, phi_out, "inplace_phi_x") + self.check_output( + phi_x_grad, phi_out_grad, "inplace_phi_x_grad" + ) + + self.check_output(phi_out, pd_out, "out") + self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(phi_y_grad, pd_y_grad, "y_grad") + self.check_output(phi_out_grad, pd_out_grad, "out_grad") + + def test_dynamic_add(self): + for device in self.devices: + for dtype in self.dtypes: + ( + pd_x, + pd_y, + pd_out, + pd_x_grad, + pd_y_grad, + ) = inplace_dynamic_add( + False, + device, + dtype, + self.np_x, + self.np_y, + ) + ( + phi_x, + phi_y, + phi_out, + phi_x_grad, + phi_y_grad, + ) = inplace_dynamic_add( + True, + device, + dtype, + self.np_x, + self.np_y, + ) + + self.check_output(phi_x, phi_out, "inplace_phi_x") + self.check_output(pd_x, pd_out, "inplace_pd_x") + + self.check_output(phi_x, pd_x, "x") + self.check_output(phi_y, pd_y, "y") + self.check_output(phi_out, pd_out, "out") + self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(phi_y_grad, pd_y_grad, "y_grad") + + def test_static_multiple_inplace_relu(self): + for device in self.devices: + for dtype in self.dtypes: + ( + pd_x, + pd_y, + pd_out, + pd_x_grad, + pd_y_grad, + ) = inplace_static_relu( + paddle.nn.functional.relu, + device, + dtype, + self.np_x, + self.np_y, + self.np_z, + ) + ( + phi_x, + phi_y, + phi_out, + phi_x_grad, + phi_y_grad, + ) = inplace_static_relu( + custom_inplace.custom_relu_inplace, + device, + dtype, + self.np_x, + self.np_y, + self.np_z, + ) + self.check_output_allclose(phi_x, pd_x, "x") + self.check_output_allclose(phi_y, pd_y, "y") + self.check_output_allclose(phi_out, pd_out, "out") + self.check_output_allclose(phi_x_grad, pd_x_grad, "x_grad") + self.check_output_allclose(phi_y_grad, pd_y_grad, "y_grad") + + def test_dynamic_multiple_inplace_relu(self): + for device in self.devices: + for dtype in self.dtypes: + ( + pd_x, + pd_y, + pd_out, + pd_x_grad, + pd_y_grad, + ) = inplace_dynamic_relu( + False, + device, + dtype, + self.np_x, + self.np_y, + self.np_z, + ) + ( + phi_x, + phi_y, + phi_out, + phi_x_grad, + phi_y_grad, + ) = inplace_dynamic_relu( + True, + device, + dtype, + self.np_x, + self.np_y, + self.np_z, + ) + + self.check_output(phi_x, pd_x, "x") + self.check_output(phi_y, pd_y, "y") + self.check_output(phi_out, pd_out, "out") + self.check_output(phi_x_grad, pd_x_grad, "x_grad") + self.check_output(phi_y_grad, pd_y_grad, "y_grad") + + +if __name__ == "__main__": + unittest.main()