未验证 提交 f824bc0d 编写于 作者: H HongyuJia 提交者: GitHub

[Custom Operator] Custom op support inplace mechanism (#51620)

* init unit test commit, contains register thinking

* support inplace

* get inplaced x.grad

* Try support inplace and hook at the same time

* Support inplace, need debug

* Support inplace successfully

* Inplace use Tensor&, consistent with Tensor*

* fix MapPlainOutputs bug

* fix double grad inplace error
上级 0b778bdc
...@@ -174,6 +174,9 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>, ...@@ -174,6 +174,9 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
auto grad_outputs_names = paddle::framework::OpMetaInfoHelper::GetOutputs( auto grad_outputs_names = paddle::framework::OpMetaInfoHelper::GetOutputs(
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]); egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
const auto& grad_inplace_map =
paddle::framework::OpMetaInfoHelper::GetInplaceMap(
egr::Controller::Instance().GetOpMetaInfoMap().at(op_type_)[1]);
auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_);
auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap(); auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap();
...@@ -205,6 +208,9 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>, ...@@ -205,6 +208,9 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
} }
VLOG(6) << "Prepare Grad attrs"; VLOG(6) << "Prepare Grad attrs";
ctx.EmplaceBackAttrs(attrs_); ctx.EmplaceBackAttrs(attrs_);
// NOTE(HongyuJia): grad_outputs_names.size() <= OutputMeta().size():
// OutputMeta().size() indicates input size of forward op,
// grad_outputs_names.size() indicates output size of backward op.
paddle::small_vector<std::vector<paddle::Tensor>, kSlotSmallVectorSize> outs( paddle::small_vector<std::vector<paddle::Tensor>, kSlotSmallVectorSize> outs(
OutputMeta().size()); OutputMeta().size());
paddle::small_vector<std::vector<paddle::Tensor>, kSlotSmallVectorSize> paddle::small_vector<std::vector<paddle::Tensor>, kSlotSmallVectorSize>
...@@ -234,8 +240,10 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>, ...@@ -234,8 +240,10 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
} }
VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad"; VLOG(7) << "Run Kernel of Grad Custom Op: " << op_type_ << "_grad";
ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map);
(*paddle::framework::OpMetaInfoHelper::GetKernelFn( (*paddle::framework::OpMetaInfoHelper::GetKernelFn(
kernel_map.at(op_type_)[1]))(&ctx); kernel_map.at(op_type_)[1]))(&ctx);
ctx.AssignInplaceOutputs();
VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op"; VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op";
std::vector<std::vector<egr::AutogradMeta*>> ins_auto_grad_metas; std::vector<std::vector<egr::AutogradMeta*>> ins_auto_grad_metas;
...@@ -353,6 +361,8 @@ RunCustomOpDoubleGradNode::operator()( ...@@ -353,6 +361,8 @@ RunCustomOpDoubleGradNode::operator()(
paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[2]); paddle::framework::OpMetaInfoHelper::GetInputs(vec_map[2]);
auto grad_outputs_names = auto grad_outputs_names =
paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[2]); paddle::framework::OpMetaInfoHelper::GetOutputs(vec_map[2]);
const auto& grad_inplace_map =
paddle::framework::OpMetaInfoHelper::GetInplaceMap(vec_map[2]);
auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_); auto map = egr::Controller::Instance().GetCustomEdgesSlotMap().at(op_type_);
auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap(); auto kernel_map = egr::Controller::Instance().GetOpMetaInfoMap();
...@@ -419,8 +429,10 @@ RunCustomOpDoubleGradNode::operator()( ...@@ -419,8 +429,10 @@ RunCustomOpDoubleGradNode::operator()(
} }
VLOG(7) << "Run Kernel of Grad Custom Op: " << name(); VLOG(7) << "Run Kernel of Grad Custom Op: " << name();
ctx.MapPlainOutputs(grad_inputs_name, grad_outputs_names, grad_inplace_map);
(*paddle::framework::OpMetaInfoHelper::GetKernelFn( (*paddle::framework::OpMetaInfoHelper::GetKernelFn(
kernel_map.at(op_type_)[2]))(&ctx); kernel_map.at(op_type_)[2]))(&ctx);
ctx.AssignInplaceOutputs();
return outs; return outs;
} }
......
...@@ -130,11 +130,13 @@ static std::vector<std::string> ParseAttrStr(const std::string& attr) { ...@@ -130,11 +130,13 @@ static std::vector<std::string> ParseAttrStr(const std::string& attr) {
////////////////// Kernel Define //////////////////// ////////////////// Kernel Define ////////////////////
// custom op kernel call function define // custom op kernel call function define
static void RunKernelFunc(const framework::ExecutionContext& ctx, static void RunKernelFunc(
const paddle::KernelFunc& func, const framework::ExecutionContext& ctx,
const std::vector<std::string>& inputs, const paddle::KernelFunc& func,
const std::vector<std::string>& outputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& attrs) { const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs,
const std::unordered_map<std::string, std::string>& inplace_map) {
VLOG(3) << "Custom Operator: Start run KernelFunc."; VLOG(3) << "Custom Operator: Start run KernelFunc.";
// prepare CustomOpKernelContext // prepare CustomOpKernelContext
paddle::CustomOpKernelContext kernel_ctx; paddle::CustomOpKernelContext kernel_ctx;
...@@ -283,7 +285,10 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx, ...@@ -283,7 +285,10 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
VLOG(4) << "Initialize phi tensor operants successfully"; VLOG(4) << "Initialize phi tensor operants successfully";
} }
// handle inplace case
kernel_ctx.MapPlainOutputs(inputs, outputs, inplace_map);
func(&kernel_ctx); func(&kernel_ctx);
kernel_ctx.AssignInplaceOutputs();
// sync output tensor data into original output // sync output tensor data into original output
auto* calc_outs = kernel_ctx.AllMutableOutput(); auto* calc_outs = kernel_ctx.AllMutableOutput();
...@@ -686,12 +691,14 @@ static void RegisterOperatorKernelWithPlace( ...@@ -686,12 +691,14 @@ static void RegisterOperatorKernelWithPlace(
OperatorWithKernel::AllOpKernels()[name][key] = op_kernel_func; OperatorWithKernel::AllOpKernels()[name][key] = op_kernel_func;
} }
static void RegisterOperatorKernel(const std::string& name, static void RegisterOperatorKernel(
const paddle::KernelFunc& kernel_func, const std::string& name,
const std::vector<std::string>& inputs, const paddle::KernelFunc& kernel_func,
const std::vector<std::string>& outputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& attrs, const std::vector<std::string>& outputs,
void* dso_handle) { const std::vector<std::string>& attrs,
const std::unordered_map<std::string, std::string>& inplace_map,
void* dso_handle) {
VLOG(3) << "Custom Operator: op name in kernel: " << name; VLOG(3) << "Custom Operator: op name in kernel: " << name;
// NOTE [ Dummy Op Kernel Key ] // NOTE [ Dummy Op Kernel Key ]
// TODO(chenweihang): Because execute engine need get device context based // TODO(chenweihang): Because execute engine need get device context based
...@@ -701,10 +708,10 @@ static void RegisterOperatorKernel(const std::string& name, ...@@ -701,10 +708,10 @@ static void RegisterOperatorKernel(const std::string& name,
OperatorWithKernel::OpKernelFunc op_kernel_func; OperatorWithKernel::OpKernelFunc op_kernel_func;
if (kernel_func) { if (kernel_func) {
VLOG(3) << "Register custom operator " << name << " with kernel func"; VLOG(3) << "Register custom operator " << name << " with kernel func";
op_kernel_func = [kernel_func, inputs, outputs, attrs]( op_kernel_func = [kernel_func, inputs, outputs, attrs, inplace_map](
const framework::ExecutionContext& ctx) { const framework::ExecutionContext& ctx) {
VLOG(3) << "Custom Operator: run custom kernel func in lambda."; VLOG(3) << "Custom Operator: run custom kernel func in lambda.";
RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs); RunKernelFunc(ctx, kernel_func, inputs, outputs, attrs, inplace_map);
}; };
} else { } else {
VLOG(3) << "Register custom operator " << name VLOG(3) << "Register custom operator " << name
...@@ -760,6 +767,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -760,6 +767,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta); auto& op_inputs = OpMetaInfoHelper::GetInputs(base_op_meta);
auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta); auto& op_outputs = OpMetaInfoHelper::GetOutputs(base_op_meta);
auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta); auto& op_attrs = OpMetaInfoHelper::GetAttrs(base_op_meta);
auto& op_inplace_map = OpMetaInfoHelper::GetInplaceMap(base_op_meta);
auto& kernel_fn = OpMetaInfoHelper::GetKernelFn(base_op_meta); auto& kernel_fn = OpMetaInfoHelper::GetKernelFn(base_op_meta);
auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta); auto& infer_shape_func = OpMetaInfoHelper::GetInferShapeFn(base_op_meta);
auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta); auto& infer_dtype_func = OpMetaInfoHelper::GetInferDtypeFn(base_op_meta);
...@@ -771,6 +779,12 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -771,6 +779,12 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
<< string::join_strings(op_outputs, ','); << string::join_strings(op_outputs, ',');
VLOG(3) << "Custom Operator: forward, op attrs: " VLOG(3) << "Custom Operator: forward, op attrs: "
<< string::join_strings(op_attrs, ','); << string::join_strings(op_attrs, ',');
if (!op_inplace_map.empty()) {
VLOG(3) << "Custom Operator: forward, op inplace_map: "
<< string::join_strings(op_inplace_map, ',', [](auto& pair) {
return pair.first + ": " + pair.second;
});
}
// Op // Op
info.creator_ = [](const std::string& op_name, info.creator_ = [](const std::string& op_name,
...@@ -795,6 +809,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -795,6 +809,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
op_name, op_name,
info.proto_->InitializationErrorString())); info.proto_->InitializationErrorString()));
// Inplace
if (!op_inplace_map.empty()) {
info.infer_inplace_ = [op_inplace_map](bool use_cuda) {
return op_inplace_map;
};
}
// InferShape // InferShape
if (infer_shape_func == nullptr) { if (infer_shape_func == nullptr) {
// use default InferShape // use default InferShape
...@@ -908,8 +929,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -908,8 +929,13 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
} }
// Kernel func // Kernel func
RegisterOperatorKernel( RegisterOperatorKernel(op_name,
op_name, kernel_fn, op_inputs, op_outputs, op_attrs, dso_handle); kernel_fn,
op_inputs,
op_outputs,
op_attrs,
op_inplace_map,
dso_handle);
// If grad op or double grad op exists // If grad op or double grad op exists
std::string cur_op_name = op_name; std::string cur_op_name = op_name;
...@@ -920,6 +946,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -920,6 +946,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op); auto& grad_op_inputs = OpMetaInfoHelper::GetInputs(cur_grad_op);
auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op); auto& grad_op_outputs = OpMetaInfoHelper::GetOutputs(cur_grad_op);
auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op); auto& grad_op_attrs = OpMetaInfoHelper::GetAttrs(cur_grad_op);
auto& grad_op_inplace_map = OpMetaInfoHelper::GetInplaceMap(cur_grad_op);
auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op); auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op); auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op);
...@@ -928,6 +955,14 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -928,6 +955,14 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
<< string::join_strings(grad_op_inputs, ','); << string::join_strings(grad_op_inputs, ',');
VLOG(3) << "Custom Operator: backward, op outputs: " VLOG(3) << "Custom Operator: backward, op outputs: "
<< string::join_strings(grad_op_outputs, ','); << string::join_strings(grad_op_outputs, ',');
VLOG(3) << "Custom Operator: backward, op attrs: "
        << string::join_strings(grad_op_attrs, ',');
// Log the grad op's own inplace map; the guard must test the backward map
// (grad_op_inplace_map), not the forward op's map, otherwise the log is
// skipped/emitted based on the wrong operator's registration.
if (!grad_op_inplace_map.empty()) {
  VLOG(3) << "Custom Operator: backward, op inplace_map: "
          << string::join_strings(grad_op_inplace_map, ',', [](auto& pair) {
               return pair.first + ": " + pair.second;
             });
}
bool is_double_grad = (i == 2); bool is_double_grad = (i == 2);
...@@ -1040,6 +1075,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -1040,6 +1075,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
grad_op_inputs, grad_op_inputs,
grad_op_outputs, grad_op_outputs,
grad_op_attrs, grad_op_attrs,
grad_op_inplace_map,
dso_handle); dso_handle);
// update current info // update current info
......
...@@ -39,6 +39,10 @@ class OpMetaInfoHelper { ...@@ -39,6 +39,10 @@ class OpMetaInfoHelper {
const paddle::OpMetaInfo& info) { const paddle::OpMetaInfo& info) {
return info.attrs_; return info.attrs_;
} }
// Returns the registered inplace mapping of this op: keys are input
// parameter names, values are the output parameter names they share
// storage with (set via OpMetaInfo::Inplace).
static const std::unordered_map<std::string, std::string>& GetInplaceMap(
    const paddle::OpMetaInfo& info) {
  return info.inplace_map_;
}
static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) { static const KernelFunc& GetKernelFn(const paddle::OpMetaInfo& info) {
return info.kernel_fn_; return info.kernel_fn_;
} }
......
...@@ -531,7 +531,18 @@ static PyObject* eager_api_run_custom_op(PyObject* self, ...@@ -531,7 +531,18 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
meta_info_map.at(op_type)[0])); meta_info_map.at(op_type)[0]));
ctx.EmplaceBackAttrs(res_attrs); ctx.EmplaceBackAttrs(res_attrs);
const auto& vec_map = meta_info_map.at(op_type); const auto& vec_map = meta_info_map.at(op_type);
// handle inplace case
const auto& inputs = paddle::framework::OpMetaInfoHelper::GetInputs(
meta_info_map.at(op_type)[0]);
const auto& outputs = paddle::framework::OpMetaInfoHelper::GetOutputs(
meta_info_map.at(op_type)[0]);
const auto& inplace_map =
paddle::framework::OpMetaInfoHelper::GetInplaceMap(
meta_info_map.at(op_type)[0]);
ctx.MapPlainOutputs(inputs, outputs, inplace_map);
(*paddle::framework::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx); (*paddle::framework::OpMetaInfoHelper::GetKernelFn(vec_map[0]))(&ctx);
ctx.AssignInplaceOutputs();
VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op"; VLOG(7) << "Get AutogradMeta for inputs and outputs for Custom Op";
std::vector<std::vector<egr::AutogradMeta*>> ins_auto_grad_metas; std::vector<std::vector<egr::AutogradMeta*>> ins_auto_grad_metas;
...@@ -557,12 +568,43 @@ static PyObject* eager_api_run_custom_op(PyObject* self, ...@@ -557,12 +568,43 @@ static PyObject* eager_api_run_custom_op(PyObject* self,
require_any_grad || egr::EagerUtils::ComputeRequireGrad( require_any_grad || egr::EagerUtils::ComputeRequireGrad(
trace_backward, &(ins_auto_grad_metas[i])); trace_backward, &(ins_auto_grad_metas[i]));
} }
// handle inplace case
for (size_t i = 0; i < ctx.InputRange().size(); i++) {
if (inplace_map.find(inputs[i]) != inplace_map.end()) {
size_t input_size =
ctx.InputRangeAt(i).second - ctx.InputRangeAt(i).first;
size_t start_idx = ctx.InputRangeAt(i).first;
for (size_t j = 0; j < input_size; j++) {
egr::EagerUtils::CheckInplace(ctx.InputAt(start_idx + j),
ins_auto_grad_metas[i][j],
require_any_grad);
// Bump Inplace Version
ctx.MutableInputAt(start_idx + j).bump_inplace_version();
VLOG(3) << "Custom operator: Tensor("
<< ctx.InputAt(start_idx + j).name()
<< ") uses Inplace Strategy.";
}
}
}
if (require_any_grad && (vec_map.size() > 1)) { if (require_any_grad && (vec_map.size() > 1)) {
VLOG(6) << " Construct Grad for Custom Op: " << op_type; VLOG(6) << " Construct Grad for Custom Op: " << op_type;
ConstructFwdAndBwdMap(vec_map, op_type); ConstructFwdAndBwdMap(vec_map, op_type);
for (size_t i = 0; i < outs_auto_grad_metas.size(); i++) { for (size_t i = 0; i < outs_auto_grad_metas.size(); i++) {
egr::EagerUtils::PassStopGradient(false, &(outs_auto_grad_metas[i])); egr::EagerUtils::PassStopGradient(false, &(outs_auto_grad_metas[i]));
} }
// Note(HongyuJia): In dygraph eager mode, CheckInplace makes sure leaf
// nodes set stop_gradient=True. However, dygraph mode can also output
// leaf nodes' gradients (For example, we can get x.grad after x.add_(y)).
// To be consistent with dygraph mode, we have to PassStopGradient for all
// inplaced ins_auto_grad_metas.
std::unordered_map<size_t, size_t> inplace_tensor_map =
ctx.GetInplaceTensorMap();
for (auto pair : inplace_tensor_map) {
egr::EagerUtils::PassStopGradient(false,
&(ins_auto_grad_metas[pair.first]));
}
auto grad_node = std::make_shared<egr::RunCustomOpNode>( auto grad_node = std::make_shared<egr::RunCustomOpNode>(
outs_auto_grad_metas.size(), ins_auto_grad_metas.size(), op_type); outs_auto_grad_metas.size(), ins_auto_grad_metas.size(), op_type);
auto slot_map = auto slot_map =
......
...@@ -609,8 +609,7 @@ paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj, ...@@ -609,8 +609,7 @@ paddle::CustomOpKernelContext CastPyArg2CustomOpKernelContext(PyObject* obj,
return ::pybind11::handle(obj).cast<paddle::CustomOpKernelContext>(); return ::pybind11::handle(obj).cast<paddle::CustomOpKernelContext>();
} else { } else {
PADDLE_THROW(platform::errors::InvalidArgument( PADDLE_THROW(platform::errors::InvalidArgument(
"argument (position %d) must be " "argument (position %d) must be CustomOpKernelContext, "
"one of(Place,CUDAPlace,CPUPlace,XPUPlace,NPUPlace,CUDAPinnedPlace), "
"but got %s", "but got %s",
arg_pos + 1, arg_pos + 1,
reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name)); reinterpret_cast<PyTypeObject*>(obj->ob_type)->tp_name));
......
...@@ -108,6 +108,7 @@ class PADDLE_API CustomOpKernelContext { ...@@ -108,6 +108,7 @@ class PADDLE_API CustomOpKernelContext {
const Tensor& InputAt(size_t idx) const; const Tensor& InputAt(size_t idx) const;
std::vector<Tensor> InputsBetween(size_t start, size_t end) const; std::vector<Tensor> InputsBetween(size_t start, size_t end) const;
Tensor& MutableInputAt(size_t idx);
const std::vector<paddle::any>& Attrs() const { return attrs_; } const std::vector<paddle::any>& Attrs() const { return attrs_; }
const std::vector<std::pair<size_t, size_t>>& InputRange() { const std::vector<std::pair<size_t, size_t>>& InputRange() {
return input_range_; return input_range_;
...@@ -129,11 +130,23 @@ class PADDLE_API CustomOpKernelContext { ...@@ -129,11 +130,23 @@ class PADDLE_API CustomOpKernelContext {
} }
} }
// handle inplace case
void MapPlainOutputs(
const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs,
const std::unordered_map<std::string, std::string>& inplace_map);
void AssignInplaceOutputs();
std::vector<Tensor*>* AllMutablePlainOutput();
std::unordered_map<size_t, size_t> GetInplaceTensorMap();
private: private:
// TODO(chenweihang): replaced be SmallVector // TODO(chenweihang): replaced be SmallVector
std::vector<Tensor> inputs_; std::vector<Tensor> inputs_;
std::vector<Tensor> outputs_; std::vector<Tensor> outputs_;
std::vector<paddle::any> attrs_; std::vector<paddle::any> attrs_;
// handle inplace case
std::vector<Tensor*> plain_outputs_;
std::unordered_map<size_t, size_t> inplace_tensor_map_;
std::vector<std::pair<size_t, size_t>> input_range_; std::vector<std::pair<size_t, size_t>> input_range_;
std::vector<std::pair<size_t, size_t>> output_range_; std::vector<std::pair<size_t, size_t>> output_range_;
...@@ -148,8 +161,7 @@ using KernelFunc = void (*)(CustomOpKernelContext*); ...@@ -148,8 +161,7 @@ using KernelFunc = void (*)(CustomOpKernelContext*);
template <typename... Tail> \ template <typename... Tail> \
struct ComputeCallHelper<attr_type, Tail...> { \ struct ComputeCallHelper<attr_type, Tail...> { \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \ template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Compute(CustomOpKernelContext* ctx, \ static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) { \
const PreviousArgs&... pargs) { \
attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \ attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \
ComputeCallHelper< \ ComputeCallHelper< \
Tail...>::template Compute<in_idx, attr_idx + 1, out_idx>(ctx, \ Tail...>::template Compute<in_idx, attr_idx + 1, out_idx>(ctx, \
...@@ -177,10 +189,9 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -177,10 +189,9 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
template <typename... Tail> template <typename... Tail>
struct ComputeCallHelper<const Tensor&, Tail...> { struct ComputeCallHelper<const Tensor&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
const PreviousArgs&... pargs) {
auto& range = ctx->InputRangeAt(in_idx); auto& range = ctx->InputRangeAt(in_idx);
auto& arg = ctx->InputAt(range.first); auto& arg = ctx->MutableInputAt(range.first);
ComputeCallHelper< ComputeCallHelper<
Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx, Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
pargs..., pargs...,
...@@ -191,8 +202,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -191,8 +202,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
template <typename... Tail> template <typename... Tail>
struct ComputeCallHelper<const std::vector<Tensor>&, Tail...> { struct ComputeCallHelper<const std::vector<Tensor>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
const PreviousArgs&... pargs) {
auto& range = ctx->InputRangeAt(in_idx); auto& range = ctx->InputRangeAt(in_idx);
auto arg = ctx->InputsBetween(range.first, range.second); auto arg = ctx->InputsBetween(range.first, range.second);
ComputeCallHelper< ComputeCallHelper<
...@@ -232,11 +242,12 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -232,11 +242,12 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>); PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>); PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
// Used to be compatible with 2.3 released internal inplace interface, not
// recommended
template <typename... Tail> template <typename... Tail>
struct ComputeCallHelper<Tensor*, Tail...> { struct ComputeCallHelper<Tensor*, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
const PreviousArgs&... pargs) {
auto& range = ctx->OutputRangeAt(out_idx); auto& range = ctx->OutputRangeAt(out_idx);
auto* arg = ctx->MutableOutputAt(range.first); auto* arg = ctx->MutableOutputAt(range.first);
ComputeCallHelper< ComputeCallHelper<
...@@ -246,13 +257,14 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -246,13 +257,14 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
} }
}; };
// Used to be compatible with 2.3 released internal inplace interface, not
// recommended
// TODO(chenweihang): What is the appropriate output form? // TODO(chenweihang): What is the appropriate output form?
// std::vector<Tensor>*? or std::vector<Tensor*>? or std::vector<Tensor*>* // std::vector<Tensor>*? or std::vector<Tensor*>? or std::vector<Tensor*>*
template <typename... Tail> template <typename... Tail>
struct ComputeCallHelper<std::vector<Tensor*>, Tail...> { struct ComputeCallHelper<std::vector<Tensor*>, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
const PreviousArgs&... pargs) {
auto& range = ctx->OutputRangeAt(out_idx); auto& range = ctx->OutputRangeAt(out_idx);
auto arg = ctx->MutableOutputBetweeen(range.first, range.second); auto arg = ctx->MutableOutputBetweeen(range.first, range.second);
ComputeCallHelper< ComputeCallHelper<
...@@ -262,18 +274,32 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -262,18 +274,32 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
} }
}; };
// Handle Tensor& for inplace case.
// A non-const Tensor& kernel argument is treated as an *input* slot (it
// consumes in_idx, not out_idx): the kernel mutates the input tensor in
// place, so we hand out a mutable reference obtained via MutableInputAt.
template <typename... Tail>
struct ComputeCallHelper<Tensor&, Tail...> {
  template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
  static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
    // range.first is the flat index of this input slot's first tensor.
    auto& range = ctx->InputRangeAt(in_idx);
    auto& arg = ctx->MutableInputAt(range.first);
    // Recurse with in_idx advanced; attr_idx/out_idx are untouched because
    // this argument consumed an input slot only.
    ComputeCallHelper<
        Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
                                                                  pargs...,
                                                                  arg);
  }
};
template <int out_idx, typename T> template <int out_idx, typename T>
struct ComputeReturnHelper; struct ComputeReturnHelper;
// For compatibility with the original custom op form // For compatibility with the original custom op form
template <int out_idx> template <int out_idx>
struct ComputeReturnHelper<out_idx, std::vector<Tensor>> { struct ComputeReturnHelper<out_idx, std::vector<Tensor>> {
static void Compute(CustomOpKernelContext* ctx, const Args&... args) { static void Compute(CustomOpKernelContext* ctx, Args&... args) {
static_assert(out_idx == 0, static_assert(out_idx == 0,
"If return std::vector<Tensor> in Custom OpKernel, " "If return std::vector<Tensor> in Custom OpKernel, "
"you cannot pass output by kernel function argument."); "you cannot pass output by kernel function argument.");
auto outs = impl_fn(args...); auto outs = impl_fn(args...);
auto* orig_outs = ctx->AllMutableOutput(); auto* orig_outs = ctx->AllMutablePlainOutput();
PD_CHECK(orig_outs->size() == outs.size(), PD_CHECK(orig_outs->size() == outs.size(),
"The number of element in custom operator outputs is wrong, " "The number of element in custom operator outputs is wrong, "
"expected contains ", "expected contains ",
...@@ -282,15 +308,14 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -282,15 +308,14 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
outs.size(), outs.size(),
" Tensors."); " Tensors.");
for (size_t i = 0; i < outs.size(); ++i) { for (size_t i = 0; i < outs.size(); ++i) {
AssignTensorImpl(outs.at(i), &(orig_outs->at(i))); AssignTensorImpl(outs.at(i), orig_outs->at(i));
} }
} }
}; };
template <int out_idx> template <int out_idx>
struct ComputeReturnHelper<out_idx, void> { struct ComputeReturnHelper<out_idx, void> {
static void Compute(CustomOpKernelContext* ctx, const Args&... args) { static void Compute(CustomOpKernelContext* ctx, Args&... args) {
static_assert(out_idx > 0, "Custom OpKernel has no output.");
impl_fn(args...); impl_fn(args...);
} }
}; };
...@@ -299,8 +324,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -299,8 +324,7 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
template <typename T> template <typename T>
struct ComputeCallHelper<TypeTag<T>> { struct ComputeCallHelper<TypeTag<T>> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
const PreviousArgs&... pargs) {
ComputeReturnHelper<out_idx, Return>::Compute(ctx, pargs...); ComputeReturnHelper<out_idx, Return>::Compute(ctx, pargs...);
} }
}; };
...@@ -547,9 +571,14 @@ class PADDLE_API OpMetaInfo { ...@@ -547,9 +571,14 @@ class PADDLE_API OpMetaInfo {
// format: {"<name1>", "<name2>", ...} // format: {"<name1>", "<name2>", ...}
OpMetaInfo& Outputs(std::vector<std::string>&& outputs); OpMetaInfo& Outputs(std::vector<std::string>&& outputs);
// format: {"<name1>:<type1>", "<name1>:<type1>", ...} // format: {"<name1>:<type1>", "<name2>:<type2>", ...}
OpMetaInfo& Attrs(std::vector<std::string>&& attrs); OpMetaInfo& Attrs(std::vector<std::string>&& attrs);
// format: {"<input_name1>:<output_name1>",
// "<input_name2>:<output_name2>",...}
OpMetaInfo& Inplace(
std::unordered_map<std::string, std::string>&& inplace_map);
// format: PD_KERNEL(...) // format: PD_KERNEL(...)
OpMetaInfo& SetKernelFn(KernelFunc&& func); OpMetaInfo& SetKernelFn(KernelFunc&& func);
...@@ -567,6 +596,7 @@ class PADDLE_API OpMetaInfo { ...@@ -567,6 +596,7 @@ class PADDLE_API OpMetaInfo {
std::vector<std::string> inputs_; std::vector<std::string> inputs_;
std::vector<std::string> outputs_; std::vector<std::string> outputs_;
std::vector<std::string> attrs_; std::vector<std::string> attrs_;
std::unordered_map<std::string, std::string> inplace_map_;
// 2. func info // 2. func info
KernelFunc kernel_fn_{nullptr}; KernelFunc kernel_fn_{nullptr};
InferShapeFunc infer_shape_fn_{nullptr}; InferShapeFunc infer_shape_fn_{nullptr};
...@@ -605,6 +635,8 @@ class PADDLE_API OpMetaInfoBuilder { ...@@ -605,6 +635,8 @@ class PADDLE_API OpMetaInfoBuilder {
OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs); OpMetaInfoBuilder& Inputs(std::vector<std::string>&& inputs);
OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs); OpMetaInfoBuilder& Outputs(std::vector<std::string>&& outputs);
OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs); OpMetaInfoBuilder& Attrs(std::vector<std::string>&& attrs);
OpMetaInfoBuilder& Inplace(
std::unordered_map<std::string, std::string>&& inplace_map);
OpMetaInfoBuilder& SetKernelFn(KernelFunc func); OpMetaInfoBuilder& SetKernelFn(KernelFunc func);
OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func); OpMetaInfoBuilder& SetInferShapeFn(InferShapeFunc func);
OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func); OpMetaInfoBuilder& SetInferDtypeFn(InferDtypeFunc func);
......
...@@ -94,6 +94,10 @@ std::vector<Tensor> CustomOpKernelContext::InputsBetween(size_t start, ...@@ -94,6 +94,10 @@ std::vector<Tensor> CustomOpKernelContext::InputsBetween(size_t start,
return rlt; return rlt;
} }
// Returns a mutable reference to the idx-th input tensor (idx is the flat
// index across all input slots). Bounds-checked via vector::at. Used by the
// inplace path (ComputeCallHelper<Tensor&>), which writes through the input.
Tensor& CustomOpKernelContext::MutableInputAt(size_t idx) {
  return inputs_.at(idx);
}
Tensor* CustomOpKernelContext::MutableOutputAt(size_t idx) { Tensor* CustomOpKernelContext::MutableOutputAt(size_t idx) {
return &(outputs_.at(idx)); return &(outputs_.at(idx));
} }
...@@ -128,6 +132,71 @@ const std::pair<size_t, size_t>& CustomOpKernelContext::OutputRangeAt( ...@@ -128,6 +132,71 @@ const std::pair<size_t, size_t>& CustomOpKernelContext::OutputRangeAt(
return output_range_.at(idx); return output_range_.at(idx);
} }
// Handle the inplace mechanism.
// Builds inplace_tensor_map_ (input slot index -> output slot index) from the
// registered inplace_map, then collects every output slot that is NOT the
// target of an inplace pair into plain_outputs_, so those "plain" outputs can
// later receive the kernel's returned tensors.
void CustomOpKernelContext::MapPlainOutputs(
    const std::vector<std::string>& inputs,
    const std::vector<std::string>& outputs,
    const std::unordered_map<std::string, std::string>& inplace_map) {
  for (size_t in_idx = 0; in_idx < inputs.size(); ++in_idx) {
    const auto& input = inputs[in_idx];
    if (inplace_map.find(input) == inplace_map.end()) {
      continue;
    }
    // Qualify std::find/std::distance explicitly rather than relying on ADL.
    auto out_iter =
        std::find(outputs.begin(), outputs.end(), inplace_map.at(input));
    PADDLE_ENFORCE(
        out_iter != outputs.end(),
        phi::errors::NotFound("Can't find the mapped value of %s, please check "
                              "the input of `Inplace` again and make "
                              "sure you registered your op accurately. ",
                              input));
    inplace_tensor_map_[in_idx] = std::distance(outputs.begin(), out_iter);
  }
  // Any output slot not mapped as an inplace target is a plain output.
  for (size_t i = 0; i < outputs.size(); ++i) {
    if (std::any_of(
            inplace_tensor_map_.begin(),
            inplace_tensor_map_.end(),
            [i](std::unordered_map<size_t, size_t>::const_reference pair) {
              return pair.second == i;
            })) {
      continue;
    }
    size_t output_start_idx = output_range_[i].first;
    size_t output_end_idx = output_range_[i].second;
    for (size_t idx = output_start_idx; idx < output_end_idx; ++idx) {
      plain_outputs_.push_back(&outputs_[idx]);
    }
  }
  VLOG(4) << "Custom operator update inplace input-output map successfully.";
}
// Assign input tensor to inplace output tensors.
// For every (input index, output index) pair recorded by MapPlainOutputs,
// assigns each input tensor in the pair's range into the matching output
// slot so the kernel's writes land in the input's storage.
void CustomOpKernelContext::AssignInplaceOutputs() {
  for (const auto& pair : inplace_tensor_map_) {
    size_t in_start_idx = input_range_[pair.first].first;
    size_t in_end_idx = input_range_[pair.first].second;
    size_t out_start_idx = output_range_[pair.second].first;
    size_t out_end_idx = output_range_[pair.second].second;
    size_t assign_tensor_size = in_end_idx - in_start_idx;
    // An inplace pair must map input and output ranges of identical length
    // (relevant when the inplaced slot is a Vec<Tensor>).
    PADDLE_ENFORCE(
        assign_tensor_size == out_end_idx - out_start_idx,
        phi::errors::OutOfRange("When assigning inplaced tensor, Input vector "
                                "size %d mismatch output vector size %d",
                                in_end_idx - in_start_idx,
                                out_end_idx - out_start_idx));
    for (size_t i = 0; i < assign_tensor_size; ++i) {
      AssignTensorImpl(inputs_[in_start_idx + i], &outputs_[out_start_idx + i]);
    }
  }
  // Log once after all pairs are wired (the original logged inside the loop,
  // printing "successfully" once per pair). Typo "opertor" also fixed.
  VLOG(4) << "Custom operator update inplace input-output tensor successfully.";
}
// Returns a pointer to the list of non-inplace ("plain") output tensors
// collected by MapPlainOutputs; callers may mutate the pointed-to tensors.
std::vector<Tensor*>* CustomOpKernelContext::AllMutablePlainOutput() {
  return &plain_outputs_;
}
// Returns the forward-input index -> output index inplace mapping.
// NOTE(review): this returns the map by value, copying it on every call;
// the declaration lives in the header (not visible here), so consider
// changing both sites to return a const reference if this is hot.
std::unordered_map<size_t, size_t>
CustomOpKernelContext::GetInplaceTensorMap() {
  return inplace_tensor_map_;
}
////////////////////// Op Meta Info ////////////////////// ////////////////////// Op Meta Info //////////////////////
OpMetaInfo& OpMetaInfo::Inputs(std::vector<std::string>&& inputs) { OpMetaInfo& OpMetaInfo::Inputs(std::vector<std::string>&& inputs) {
...@@ -142,6 +211,12 @@ OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) { ...@@ -142,6 +211,12 @@ OpMetaInfo& OpMetaInfo::Attrs(std::vector<std::string>&& attrs) {
attrs_ = std::forward<std::vector<std::string>>(attrs); attrs_ = std::forward<std::vector<std::string>>(attrs);
return *this; return *this;
} }
// Records the input-name -> output-name inplace mapping for this op,
// e.g. {{"X", "Out"}} means output "Out" reuses input "X"'s storage.
// Returns *this for builder-style chaining.
OpMetaInfo& OpMetaInfo::Inplace(
    std::unordered_map<std::string, std::string>&& inplace_map) {
  // `inplace_map` is a non-deduced rvalue reference, so std::move is the
  // idiomatic transfer (equivalent to the previous std::forward cast).
  inplace_map_ = std::move(inplace_map);
  return *this;
}
OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) { OpMetaInfo& OpMetaInfo::SetKernelFn(KernelFunc&& func) {
kernel_fn_ = std::forward<KernelFunc>(func); kernel_fn_ = std::forward<KernelFunc>(func);
return *this; return *this;
...@@ -222,6 +297,13 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) { ...@@ -222,6 +297,13 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::Attrs(std::vector<std::string>&& attrs) {
return *this; return *this;
} }
// Forwards the inplace name mapping to the underlying OpMetaInfo.
// Returns *this for builder-style chaining.
OpMetaInfoBuilder& OpMetaInfoBuilder::Inplace(
    std::unordered_map<std::string, std::string>&& inplace_map) {
  // Non-deduced rvalue reference: std::move, not std::forward.
  info_ptr_->Inplace(std::move(inplace_map));
  return *this;
}
OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) { OpMetaInfoBuilder& OpMetaInfoBuilder::SetKernelFn(KernelFunc func) {
info_ptr_->SetKernelFn(std::forward<KernelFunc>(func)); info_ptr_->SetKernelFn(std::forward<KernelFunc>(func));
return *this; return *this;
......
...@@ -50,6 +50,7 @@ py_test(test_custom_conj SRCS test_custom_conj.py) ...@@ -50,6 +50,7 @@ py_test(test_custom_conj SRCS test_custom_conj.py)
py_test(test_custom_linear SRCS test_custom_linear.py) py_test(test_custom_linear SRCS test_custom_linear.py)
py_test(test_custom_simple_slice SRCS test_custom_simple_slice.py) py_test(test_custom_simple_slice SRCS test_custom_simple_slice.py)
py_test(test_custom_tanh_double_grad SRCS test_custom_tanh_double_grad.py) py_test(test_custom_tanh_double_grad SRCS test_custom_tanh_double_grad.py)
py_test(test_custom_inplace SRCS test_custom_inplace.py)
# other tests # other tests
py_test(test_sysconfig SRCS test_sysconfig.py) py_test(test_sysconfig SRCS test_sysconfig.py)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
// In-place elementwise add: x_data[i] += y_data[i] for i in [0, numel).
template <typename data_t>
void add_forward_kernel(data_t* x_data, const data_t* y_data, int64_t numel) {
  // Use int64_t for the index to match `numel` and avoid the original
  // signed/unsigned comparison (size_t i vs int64_t numel).
  for (int64_t i = 0; i < numel; ++i) {
    x_data[i] += y_data[i];
  }
}
// Backward of the inplace add w.r.t. y: y_grad[i] = out_grad[i]
// (the add's gradient to y is the identity).
template <typename data_t>
void add_backward_kernel(data_t* y_grad_data,
                         const data_t* out_grad_data,
                         int64_t numel) {
  // int64_t index matches `numel` (original mixed size_t and int64_t).
  for (int64_t i = 0; i < numel; ++i) {
    y_grad_data[i] = out_grad_data[i];
  }
}
// In-place ReLU: x_data[i] = max(x_data[i], 0).
template <typename data_t>
void relu_forward_kernel(data_t* x_data, int64_t numel) {
  // int64_t index matches `numel` (original mixed size_t and int64_t).
  for (int64_t i = 0; i < numel; ++i) {
    x_data[i] = x_data[i] > 0 ? x_data[i] : 0;
  }
}
// In-place ReLU backward: scales each incoming gradient by the ReLU mask,
// i.e. grad_out[i] is kept where the forward output was positive and
// zeroed where the output was clamped.
template <typename data_t>
void relu_backward_kernel(const data_t* out_data,
                          data_t* grad_out_data,
                          int64_t out_numel) {
  for (int64_t idx = 0; idx < out_numel; ++idx) {
    // 1/0 mask derived from the forward output's sign.
    const data_t mask = out_data[idx] > static_cast<data_t>(0) ? 1. : 0.;
    grad_out_data[idx] = grad_out_data[idx] * mask;
  }
}
// Forward of the inplace add custom op: x += y, written directly into x's
// buffer (x is taken by mutable reference; the registration below maps
// input "X" onto output "Out").
// Note: no Tensor is returned -- the inplaced input IS the output.
void AddForward(paddle::Tensor& x, const paddle::Tensor& y) {  // NOLINT
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
  PD_DISPATCH_FLOATING_TYPES(x.type(), "AddForward", ([&] {
                               add_forward_kernel<data_t>(x.data<data_t>(),
                                                          y.data<data_t>(),
                                                          x.size());
                             }));
}
// Output dtype of the inplace add: identical to x's dtype, since the
// output reuses x's buffer (y_dtype is intentionally unused).
std::vector<paddle::DataType> AddInferDtype(const paddle::DataType& x_dtype,
                                            const paddle::DataType& y_dtype) {
  return {x_dtype};
}
// Output shape of the inplace add: exactly x's shape, since the output
// reuses x's buffer (y_shape is intentionally unused).
std::vector<std::vector<int64_t>> AddInferShape(
    const std::vector<int64_t>& x_shape, const std::vector<int64_t>& y_shape) {
  std::vector<std::vector<int64_t>> shapes;
  shapes.push_back(x_shape);
  return shapes;
}
// Backward of the inplace add. Only Grad(Y) is materialized here; Grad(X)
// is produced by the Inplace registration mapping Grad(Out) -> Grad(X),
// which is why out_grad is taken by mutable reference.
std::vector<paddle::Tensor> AddBackward(const paddle::Tensor& x,
                                        const paddle::Tensor& y,
                                        paddle::Tensor& out_grad) {  // NOLINT
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
  // Fixed error message: this check is on `y`, not `x`.
  PD_CHECK(y.place() == paddle::PlaceType::kCPU, "y must be a CPU Tensor.");

  paddle::Tensor y_grad = paddle::empty(x.shape(), x.dtype(), x.place());

  PD_DISPATCH_FLOATING_TYPES(
      out_grad.type(), "AddBackward", ([&] {
        add_backward_kernel<data_t>(
            y_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
      }));

  return {y_grad};
}
// Register the inplace add op: output "Out" shares input "X"'s storage.
PD_BUILD_OP(custom_add)
    .Inputs({"X", "Y"})
    .Outputs({"Out"})
    .Inplace({{"X", "Out"}})
    .SetKernelFn(PD_KERNEL(AddForward))
    .SetInferShapeFn(PD_INFER_SHAPE(AddInferShape))
    .SetInferDtypeFn(PD_INFER_DTYPE(AddInferDtype));

// Backward registration: Grad(X) is produced inplace from Grad(Out);
// Grad(Y) is the tensor returned by AddBackward.
PD_BUILD_GRAD_OP(custom_add)
    .Inputs({"X", "Y", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X"), paddle::Grad("Y")})
    .Inplace({{paddle::Grad("Out"), paddle::Grad("X")}})
    .SetKernelFn(PD_KERNEL(AddBackward));
// Forward of the inplace ReLU custom op: clamps x to max(x, 0) directly
// in x's buffer (input "X" is mapped onto output "Out" by registration).
void ReluForwardInplace(paddle::Tensor& x) {  // NOLINT
  PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
  PD_DISPATCH_FLOATING_TYPES(x.type(), "ReluForward", ([&] {
                               relu_forward_kernel<data_t>(x.data<data_t>(),
                                                           x.size());
                             }));
}
// Backward of the inplace ReLU: grad_out is overwritten in place with
// grad_out * (out > 0), and mapped onto Grad(X) via the Inplace
// registration (hence the mutable reference and no return value).
void ReluBackwardInplace(const paddle::Tensor& x,
                         const paddle::Tensor& out,
                         paddle::Tensor& grad_out) {  // NOLINT
  // Fixed error message: this check is on `out`, not `x`.
  PD_CHECK(out.place() == paddle::PlaceType::kCPU,
           "out must be a CPU Tensor.");

  PD_DISPATCH_FLOATING_TYPES(
      grad_out.type(), "ReluBackward", ([&] {
        relu_backward_kernel<data_t>(
            out.data<data_t>(), grad_out.data<data_t>(), grad_out.size());
      }));
}
// Register the inplace ReLU op: output "Out" shares input "X"'s storage.
PD_BUILD_OP(custom_relu_inplace)
    .Inputs({"X"})
    .Outputs({"Out"})
    .Inplace({{"X", "Out"}})
    .SetKernelFn(PD_KERNEL(ReluForwardInplace));

// Backward registration: Grad(X) is produced inplace from Grad(Out);
// the forward output "Out" is an extra input because the ReLU gradient
// depends on where the output was clamped.
PD_BUILD_GRAD_OP(custom_relu_inplace)
    .Inputs({"X", "Out", paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .Inplace({{paddle::Grad("Out"), paddle::Grad("X")}})
    .SetKernelFn(PD_KERNEL(ReluBackwardInplace));
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from utils import extra_cc_args, extra_nvcc_args, paddle_includes
import paddle
import paddle.static as static
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd
# Because Windows don't use docker, the shared lib already exists in the
# cache dir, it will not be compiled again unless the shared lib is removed.
file = '{}\\custom_inplace\\custom_inplace.pyd'.format(get_build_directory())
if os.name == 'nt' and os.path.isfile(file):
    cmd = 'del {}'.format(file)
    run_cmd(cmd, True)

# Compile and load custom op Just-In-Time.
# Exposes custom_inplace.custom_add and custom_inplace.custom_relu_inplace
# (both registered with inplace input/output mappings in custom_inplace.cc).
custom_inplace = load(
    name='custom_inplace',
    sources=['custom_inplace.cc'],
    extra_include_paths=paddle_includes,  # add for Coverage CI
    extra_cxx_cflags=extra_cc_args,  # test for cflags
    extra_cuda_cflags=extra_nvcc_args,  # test for cflags
    verbose=True,
)
def inplace_dynamic_add(phi_func, device, dtype, np_x, np_y):
    """Run an inplace add in dynamic (eager) mode and return its results.

    Uses the custom ``custom_add`` op when ``phi_func`` is truthy, otherwise
    paddle's built-in inplace ``x.add_(y)``. Returns the post-op values of
    x, y, out and the gradients of x and y, all as numpy arrays.
    """
    paddle.set_device(device)
    # NOTE(review): x is created with stop_gradient=True yet x.grad is read
    # after backward() -- presumably the inplace write (out aliases x) makes
    # the gradient reachable; confirm against paddle's inplace autograd rules.
    x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=True)
    y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)

    if phi_func:
        out = custom_inplace.custom_add(x, y)
    else:
        out = x.add_(y)

    out.backward()
    return x.numpy(), y.numpy(), out.numpy(), x.grad.numpy(), y.grad.numpy()
def inplace_static_add(func, device, dtype, np_x, np_y):
    """Run an add in static-graph mode and fetch values plus gradients.

    ``func`` is either ``paddle.add`` or the custom inplace op; the graph
    computes out = func(x, y), takes mean(out) as the loss, and fetches
    x, out, x@GRAD, y@GRAD and out@GRAD after one executor run.
    Static mode is enabled on entry and disabled again before returning.
    """
    paddle.enable_static()
    paddle.set_device(device)
    with static.scope_guard(static.Scope()):
        with static.program_guard(static.Program()):
            x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype)
            y = static.data(name="y", shape=[None, np_y.shape[1]], dtype=dtype)
            x.stop_gradient = False
            y.stop_gradient = False
            out = func(x, y)
            mean_out = paddle.mean(out)
            static.append_backward(mean_out)

            exe = static.Executor()
            exe.run(static.default_startup_program())

            x_v, out_v, x_grad_v, y_grad_v, out_grad_v = exe.run(
                static.default_main_program(),
                feed={
                    "x": np_x.astype(dtype),
                    "y": np_y.astype(dtype),
                },
                fetch_list=[
                    x.name,
                    out.name,
                    x.name + "@GRAD",
                    y.name + "@GRAD",
                    out.name + "@GRAD",
                ],
            )
    paddle.disable_static()
    return x_v, out_v, x_grad_v, y_grad_v, out_grad_v
def inplace_dynamic_relu(phi_func, device, dtype, np_x, np_y, np_z):
    """Apply two chained inplace ReLUs in dynamic mode.

    Computes relu_(relu_(x + y) + z), using the custom op when
    ``phi_func`` is truthy and paddle's built-in ``relu_`` otherwise.
    Returns x, y, out and the gradients of x and y as numpy arrays.
    """
    paddle.set_device(device)
    x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
    y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
    z = paddle.to_tensor(np_z, dtype=dtype, stop_gradient=False)
    out_xy = x + y
    if phi_func:
        out_xy = custom_inplace.custom_relu_inplace(out_xy)
        out_xyz = out_xy + z
        out = custom_inplace.custom_relu_inplace(out_xyz)
    else:
        out_xy = paddle.nn.functional.relu_(out_xy)
        out_xyz = out_xy + z
        out = paddle.nn.functional.relu_(out_xyz)

    out.backward()
    return x.numpy(), y.numpy(), out.numpy(), x.grad.numpy(), y.grad.numpy()
def inplace_static_relu(func, device, dtype, np_x, np_y, np_z):
    """Apply two chained ReLUs in static-graph mode and fetch gradients.

    ``func`` is either ``paddle.nn.functional.relu`` or the custom inplace
    op; the graph computes relu(relu(x + y) + z), uses mean(out) as loss,
    and fetches x, y, out, x@GRAD and y@GRAD from one executor run.
    Static mode is enabled on entry and disabled again before returning.
    """
    paddle.enable_static()
    paddle.set_device(device)
    with static.scope_guard(static.Scope()):
        with static.program_guard(static.Program()):
            x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype)
            y = static.data(name="y", shape=[None, np_y.shape[1]], dtype=dtype)
            z = static.data(name="z", shape=[None, np_z.shape[1]], dtype=dtype)
            x.stop_gradient = False
            y.stop_gradient = False
            z.stop_gradient = False
            out_xy = x + y
            out_xy = func(out_xy)
            out_xyz = out_xy + z
            out = func(out_xyz)
            mean_out = paddle.mean(out)
            static.append_backward(mean_out)

            exe = static.Executor()
            exe.run(static.default_startup_program())

            x_v, y_v, out_v, x_grad_v, y_grad_v = exe.run(
                static.default_main_program(),
                feed={
                    "x": np_x.astype(dtype),
                    "y": np_y.astype(dtype),
                    "z": np_z.astype(dtype),
                },
                fetch_list=[
                    x.name,
                    y.name,
                    out.name,
                    x.name + "@GRAD",
                    y.name + "@GRAD",
                ],
            )
    paddle.disable_static()
    return x_v, y_v, out_v, x_grad_v, y_grad_v
class TestCustomInplaceJit(unittest.TestCase):
    """Compares the JIT-compiled custom inplace ops against paddle's
    built-in inplace APIs, in both dynamic and static modes."""

    def setUp(self):
        # Random float32 fixtures shared by all tests; other dtypes are
        # exercised by casting at feed time.
        self.dtypes = ['float32', 'float64']
        self.devices = ['cpu']
        self.np_x = np.random.random((3, 2)).astype("float32")
        self.np_y = np.random.random((3, 2)).astype("float32")
        self.np_z = np.random.random((3, 2)).astype("float32")

    def check_output(self, out, pd_out, name):
        """Assert exact equality between custom-op and paddle results."""
        np.testing.assert_array_equal(
            out,
            pd_out,
            err_msg='custom op {}: {},\n paddle api {}: {}'.format(
                name, out, name, pd_out
            ),
        )

    def check_output_allclose(self, out, pd_out, name):
        """Assert approximate equality (used where fp rounding differs)."""
        np.testing.assert_allclose(
            out,
            pd_out,
            rtol=5e-5,
            atol=1e-2,
            err_msg='custom op {}: {},\n paddle api {}: {}'.format(
                name, out, name, pd_out
            ),
        )

    def test_static_add(self):
        """Static mode: custom_add must inplace x (x == out, and their
        grads match) and agree with paddle.add on all fetched values."""
        for device in self.devices:
            for dtype in self.dtypes:
                (
                    pd_x,
                    pd_out,
                    pd_x_grad,
                    pd_y_grad,
                    pd_out_grad,
                ) = inplace_static_add(
                    paddle.add,
                    device,
                    dtype,
                    self.np_x,
                    self.np_y,
                )
                (
                    phi_x,
                    phi_out,
                    phi_x_grad,
                    phi_y_grad,
                    phi_out_grad,
                ) = inplace_static_add(
                    custom_inplace.custom_add,
                    device,
                    dtype,
                    self.np_x,
                    self.np_y,
                )
                # Inplace property: the input buffer holds the output.
                self.check_output(phi_x, phi_out, "inplace_phi_x")
                self.check_output(
                    phi_x_grad, phi_out_grad, "inplace_phi_x_grad"
                )

                # Numerical agreement with the reference paddle.add path.
                self.check_output(phi_out, pd_out, "out")
                self.check_output(phi_x_grad, pd_x_grad, "x_grad")
                self.check_output(phi_y_grad, pd_y_grad, "y_grad")
                self.check_output(phi_out_grad, pd_out_grad, "out_grad")

    def test_dynamic_add(self):
        """Dynamic mode: custom_add must match x.add_(y) on values and
        gradients, and both must exhibit the inplace x == out aliasing."""
        for device in self.devices:
            for dtype in self.dtypes:
                (
                    pd_x,
                    pd_y,
                    pd_out,
                    pd_x_grad,
                    pd_y_grad,
                ) = inplace_dynamic_add(
                    False,
                    device,
                    dtype,
                    self.np_x,
                    self.np_y,
                )
                (
                    phi_x,
                    phi_y,
                    phi_out,
                    phi_x_grad,
                    phi_y_grad,
                ) = inplace_dynamic_add(
                    True,
                    device,
                    dtype,
                    self.np_x,
                    self.np_y,
                )

                # Inplace property on both implementations.
                self.check_output(phi_x, phi_out, "inplace_phi_x")
                self.check_output(pd_x, pd_out, "inplace_pd_x")

                # Numerical agreement between custom op and paddle API.
                self.check_output(phi_x, pd_x, "x")
                self.check_output(phi_y, pd_y, "y")
                self.check_output(phi_out, pd_out, "out")
                self.check_output(phi_x_grad, pd_x_grad, "x_grad")
                self.check_output(phi_y_grad, pd_y_grad, "y_grad")

    def test_static_multiple_inplace_relu(self):
        """Static mode: two chained custom inplace ReLUs must match the
        out-of-place paddle relu (allclose, not exact, due to rounding)."""
        for device in self.devices:
            for dtype in self.dtypes:
                (
                    pd_x,
                    pd_y,
                    pd_out,
                    pd_x_grad,
                    pd_y_grad,
                ) = inplace_static_relu(
                    paddle.nn.functional.relu,
                    device,
                    dtype,
                    self.np_x,
                    self.np_y,
                    self.np_z,
                )
                (
                    phi_x,
                    phi_y,
                    phi_out,
                    phi_x_grad,
                    phi_y_grad,
                ) = inplace_static_relu(
                    custom_inplace.custom_relu_inplace,
                    device,
                    dtype,
                    self.np_x,
                    self.np_y,
                    self.np_z,
                )
                self.check_output_allclose(phi_x, pd_x, "x")
                self.check_output_allclose(phi_y, pd_y, "y")
                self.check_output_allclose(phi_out, pd_out, "out")
                self.check_output_allclose(phi_x_grad, pd_x_grad, "x_grad")
                self.check_output_allclose(phi_y_grad, pd_y_grad, "y_grad")

    def test_dynamic_multiple_inplace_relu(self):
        """Dynamic mode: two chained custom inplace ReLUs must exactly
        match paddle's built-in relu_ on values and gradients."""
        for device in self.devices:
            for dtype in self.dtypes:
                (
                    pd_x,
                    pd_y,
                    pd_out,
                    pd_x_grad,
                    pd_y_grad,
                ) = inplace_dynamic_relu(
                    False,
                    device,
                    dtype,
                    self.np_x,
                    self.np_y,
                    self.np_z,
                )
                (
                    phi_x,
                    phi_y,
                    phi_out,
                    phi_x_grad,
                    phi_y_grad,
                ) = inplace_dynamic_relu(
                    True,
                    device,
                    dtype,
                    self.np_x,
                    self.np_y,
                    self.np_z,
                )

                self.check_output(phi_x, pd_x, "x")
                self.check_output(phi_y, pd_y, "y")
                self.check_output(phi_out, pd_out, "out")
                self.check_output(phi_x_grad, pd_x_grad, "x_grad")
                self.check_output(phi_y_grad, pd_y_grad, "y_grad")
# Run the full suite when executed directly (the CI invokes this via py_test).
if __name__ == "__main__":
    unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册