From f1f74e9e6cd5cdd3141f8712f9149363eae8a9da Mon Sep 17 00:00:00 2001
From: Chen Weihang
Date: Mon, 7 Feb 2022 09:54:10 +0800
Subject: [PATCH] [CustomOp] Support output as input argument of kernel func
 (#39353)

* refactor custom op kernel func and utils

* add output sync

* adapt tensor* in utils

* fix windows symbol error
---
 paddle/fluid/framework/custom_operator.cc     | 125 +++++++----
 paddle/pten/api/ext/op_meta_info.h            | 210 ++++++++++++------
 paddle/pten/api/lib/op_meta_info.cc           |  92 ++++++++
 .../fluid/tests/custom_op/custom_relu_op.cc   |  60 +++++
 .../fluid/tests/custom_op/custom_relu_op.cu   |  28 +++
 .../custom_op/test_custom_relu_op_jit.py      |   3 +-
 6 files changed, 409 insertions(+), 109 deletions(-)

diff --git a/paddle/fluid/framework/custom_operator.cc b/paddle/fluid/framework/custom_operator.cc
index 35b6b918931..68445e7976e 100644
--- a/paddle/fluid/framework/custom_operator.cc
+++ b/paddle/fluid/framework/custom_operator.cc
@@ -110,8 +110,8 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
                           const std::vector<std::string>& outputs,
                           const std::vector<std::string>& attrs) {
   VLOG(3) << "Custom Operator: Start run KernelFunc.";
-  std::vector<paddle::experimental::Tensor> custom_ins;
-  std::vector<std::vector<paddle::experimental::Tensor>> custom_vec_ins;
+  // prepare CustomOpKernelContext
+  paddle::CustomOpKernelContext kernel_ctx;
   for (auto& in_name : inputs) {
     VLOG(3) << "Custom Operator: input name - " << in_name;
     if (detail::IsDuplicableVar(in_name)) {
@@ -136,7 +136,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
         custom_t.set_impl(std::make_shared<pten::DenseTensor>(*x));
         custom_vec_in.emplace_back(custom_t);
       }
-      custom_vec_ins.emplace_back(custom_vec_in);
+      kernel_ctx.EmplaceBackInputs(std::move(custom_vec_in));
     } else {
       auto* x = ctx.Input<Tensor>(in_name);
       PADDLE_ENFORCE_NOT_NULL(x, platform::errors::NotFound(
@@ -146,33 +146,32 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
                             "Input tensor (%s) is not initialized.", in_name));
       paddle::experimental::Tensor custom_in;
       custom_in.set_impl(std::make_shared<pten::DenseTensor>(*x));
-      custom_ins.emplace_back(custom_in);
+      kernel_ctx.EmplaceBackInput(std::move(custom_in));
     }
   }

-  std::vector<paddle::any> custom_attrs;
   for (auto& attr_str : attrs) {
     auto attr_name_and_type = detail::ParseAttrStr(attr_str);
     auto attr_name = attr_name_and_type[0];
     auto attr_type_str = attr_name_and_type[1];
     if (attr_type_str == "bool") {
-      custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<bool>(attr_name));
     } else if (attr_type_str == "int") {
-      custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<int>(attr_name));
     } else if (attr_type_str == "float") {
-      custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<float>(attr_name));
     } else if (attr_type_str == "int64_t") {
-      custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<int64_t>(attr_name));
     } else if (attr_type_str == "std::string") {
-      custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<std::string>(attr_name));
     } else if (attr_type_str == "std::vector<int>") {
-      custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<int>>(attr_name));
     } else if (attr_type_str == "std::vector<float>") {
-      custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<float>>(attr_name));
     } else if (attr_type_str == "std::vector<int64_t>") {
-      custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<int64_t>>(attr_name));
     } else if (attr_type_str == "std::vector<std::string>") {
-      custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
+      kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<std::string>>(attr_name));
     } else {
       PADDLE_THROW(platform::errors::Unimplemented(
           "Unsupported `%s` type value as custom attribute now. "
@@ -185,35 +184,75 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
     }
   }

-  VLOG(3) << "Custom Operator: Run ComputeFunc.";
-  try {
-    auto outs = func(custom_ins, custom_vec_ins, custom_attrs);
+  VLOG(3) << "Custom Operator: push outputs into CustomOpKernelContext.";
+  // cache the target tensor pointers
+  std::vector<Tensor*> true_out_ptrs;
+  for (size_t i = 0; i < outputs.size(); ++i) {
+    auto out_name = outputs[i];
+    if (detail::IsDuplicableVar(out_name)) {
+      PADDLE_ENFORCE(i == 0UL && outputs.size() == 1UL,
+                     platform::errors::PreconditionNotMet(
+                         "If custom operator's outputs contains `paddle::Vec("
+                         ")` type, "
+                         "it only can hold one output."));
+      auto vec_out = ctx.MultiOutput<Tensor>(out_name);
+      PADDLE_ENFORCE_NE(vec_out.empty(), true,
+                        platform::errors::NotFound(
+                            "Output vector (%s) is empty.", out_name));
+      std::vector<paddle::experimental::Tensor> custom_vec_out;
+      for (size_t j = 0; j < vec_out.size(); ++j) {
+        auto* out = vec_out[j];
+        PADDLE_ENFORCE_NOT_NULL(
+            out,
+            platform::errors::NotFound(
+                "The %d-th tensor in output vector (%s) is nullptr.", j,
+                out_name));
+        true_out_ptrs.emplace_back(out);
+        paddle::experimental::Tensor custom_t;
+        // here only can copy the output tensor into context
+        custom_t.set_impl(std::make_shared<pten::DenseTensor>(*out));
+        custom_vec_out.emplace_back(custom_t);
+      }
+      kernel_ctx.EmplaceBackOutputs(std::move(custom_vec_out));
+    } else {
+      auto* out = ctx.Output<Tensor>(out_name);
+      PADDLE_ENFORCE_NOT_NULL(
+          out, platform::errors::NotFound("Output tensor (%s) is nullptr.",
+                                          out_name));
+      true_out_ptrs.emplace_back(out);
+      paddle::experimental::Tensor custom_out;
+      // here only can copy the output tensor into context
+      custom_out.set_impl(std::make_shared<pten::DenseTensor>(*out));
+      kernel_ctx.EmplaceBackOutput(std::move(custom_out));
+    }
+  }

-    VLOG(3) << "Custom Operator: Share outputs into ExecutionContext.";
-    for (size_t i = 0; i < outputs.size(); ++i) {
-      auto out_name = outputs[i];
-      if (detail::IsDuplicableVar(out_name)) {
-        PADDLE_ENFORCE(i == 0UL && outputs.size() == 1UL,
-                       platform::errors::PreconditionNotMet(
-                           "If custom operator's outputs contains `paddle::Vec("
-                           ")` type, "
-                           "it only can hold one output."));
-        auto vec_true_outs = ctx.MultiOutput<Tensor>(out_name);
-        PADDLE_ENFORCE_EQ(
-            vec_true_outs.size(), outs.size(),
-            platform::errors::InvalidArgument(
-                "The number of element in custom operator outputs is wrong, "
-                "expected contains %d Tensors, but actually contains %d "
-                "Tensors.",
-                vec_true_outs.size(), outs.size()));
-        for (size_t j = 0; j < vec_true_outs.size(); ++j) {
-          *vec_true_outs.at(j) =
-              *std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(j).impl());
-        }
-      } else {
-        auto* true_out = ctx.Output<Tensor>(out_name);
-        *true_out =
-            *std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(i).impl());
+  try {
+    VLOG(3) << "Custom Operator: Run ComputeFunc.";
+    func(&kernel_ctx);
+
+    // sync output tensor data into original output
+    auto* calc_outs = kernel_ctx.AllMutableOutput();
+    PADDLE_ENFORCE_EQ(
+        true_out_ptrs.size(), calc_outs->size(),
+        platform::errors::InvalidArgument(
+            "The number of element in custom operator outputs is wrong, "
+            "expected contains %d Tensors, but actually contains %d "
+            "Tensors.",
+            true_out_ptrs.size(), calc_outs->size()));
+    for (size_t i = 0; i < true_out_ptrs.size(); ++i) {
+      auto* true_out = true_out_ptrs.at(i);
+      auto calc_out =
+          std::dynamic_pointer_cast<pten::DenseTensor>(calc_outs->at(i).impl());
+      // assign meta info
+      auto* true_out_meta =
+          pten::DenseTensorUtils::GetMutableMeta(true_out);
+      true_out_meta->dims = calc_out->dims();
+      true_out_meta->dtype = calc_out->dtype();
+      true_out_meta->layout = calc_out->layout();
+      // lod and offset no need to be reset
+      // reset holder if needed
+      if (true_out->Holder() != calc_out->Holder()) {
+        true_out->ResetHolder(calc_out->Holder());
       }
     }
   } catch (platform::EnforceNotMet& exception) {
@@ -609,7 +648,7 @@ void RegisterOperatorWithMetaInfo(
   auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta);

   if (OpInfoMap::Instance().Has(op_name)) {
-    LOG(WARNING) << "Operator (" << op_name << ")has been registered.";
+    LOG(WARNING) << "Operator (" << op_name << ") has been registered.";
     return;
   }

diff --git a/paddle/pten/api/ext/op_meta_info.h b/paddle/pten/api/ext/op_meta_info.h
index ac37d724698..a8f12bad187 100644
--- a/paddle/pten/api/ext/op_meta_info.h
+++ b/paddle/pten/api/ext/op_meta_info.h
@@ -17,6 +17,7 @@ limitations under the License. */
 #include <iostream>
 #include <string>
 #include <unordered_map>
+#include <utility>
 #include <vector>

 #include "paddle/pten/api/ext/dll_decl.h"
@@ -76,37 +77,66 @@ inline std::string Vec(const std::string& t_name) {
   return result;
 }

+PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst);
+
+////////////////////// Kernel Context ////////////////////////
+
+class PADDLE_API CustomOpKernelContext {
+ public:
+  CustomOpKernelContext() = default;
+
+  void EmplaceBackInput(Tensor&& input);
+  void EmplaceBackInputs(std::vector<Tensor>&& inputs);
+  void EmplaceBackOutput(Tensor&& output);
+  void EmplaceBackOutputs(std::vector<Tensor>&& outputs);
+  void EmplaceBackAttr(paddle::any attr);
+
+  const std::pair<size_t, size_t>& InputRangeAt(size_t idx) const;
+  const std::pair<size_t, size_t>& OutputRangeAt(size_t idx) const;
+
+  const Tensor& InputAt(size_t idx) const;
+  std::vector<Tensor> InputsBetween(size_t start, size_t end) const;
+
+  Tensor* MutableOutputAt(size_t idx);
+  std::vector<Tensor*> MutableOutputBetweeen(size_t start, size_t end);
+  std::vector<Tensor>* AllMutableOutput();
+
+  template <typename AttrType>
+  AttrType AttrAt(size_t idx) const {
+    try {
+      return paddle::any_cast<AttrType>(attrs_.at(idx));
+    } catch (paddle::bad_any_cast&) {
+      PD_THROW("Attribute cast error in Custom Op Kernel Context.");
+    }
+  }
+
+ private:
+  // TODO(chenweihang): replaced by SmallVector
+  std::vector<Tensor> inputs_;
+  std::vector<Tensor> outputs_;
+  std::vector<paddle::any> attrs_;
+
+  std::vector<std::pair<size_t, size_t>> input_range_;
+  std::vector<std::pair<size_t, size_t>> output_range_;
+};
+
 ////////////////////// Kernel Function (PD_KERNEL) ////////////////////////

 // Record Op kernel core function
-using KernelFunc =
-    std::vector<Tensor> (*)(const std::vector<Tensor>& inputs,
-                            const std::vector<std::vector<Tensor>>& vec_inputs,
-                            const std::vector<paddle::any>& attrs);
-
-#define PD_SPECIALIZE_ComputeCallHelper(attr_type) \
-  template <typename... Tail> \
-  struct ComputeCallHelper<attr_type, Tail...> { \
-    template <int in_idx, int vec_in_idx, int attr_idx, typename... PreviousArgs> \
-    static Return Compute(const std::vector<Tensor>& inputs, \
-                          const std::vector<std::vector<Tensor>>& vec_inputs, \
-                          const std::vector<paddle::any>& attrs, \
-                          const PreviousArgs&... pargs) { \
-      try { \
-        attr_type arg = paddle::any_cast<attr_type>(attrs[attr_idx]); \
-        return ComputeCallHelper<Tail...>::template Compute<in_idx, vec_in_idx, attr_idx + 1>( \
-            inputs, vec_inputs, attrs, pargs..., arg); \
-      } catch (paddle::bad_any_cast&) { \
-        PD_THROW( \
-            "Attribute cast error in custom operator. Expected " #attr_type \
-            " value."); \
-      } \
-    } \
+using KernelFunc = void (*)(CustomOpKernelContext*);
+
+#define PD_SPECIALIZE_ComputeCallHelper(attr_type) \
+  template <typename... Tail> \
+  struct ComputeCallHelper<attr_type, Tail...> { \
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
+    static void Compute(CustomOpKernelContext* ctx, \
+                        const PreviousArgs&... pargs) { \
+      attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \
+      ComputeCallHelper< \
+          Tail...>::template Compute<in_idx, attr_idx + 1, out_idx>(ctx, \
+                                                                    pargs..., \
+                                                                    arg); \
+    } \
   }

 template <typename F, F f>
@@ -117,11 +147,8 @@ struct KernelFuncImpl;

 template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
 struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
-  static Return Compute(const std::vector<Tensor>& inputs,
-                        const std::vector<std::vector<Tensor>>& vec_inputs,
-                        const std::vector<paddle::any>& attrs) {
-    return ComputeCallHelper<Args..., TypeTag<int>>::template Compute<0, 0, 0>(
-        inputs, vec_inputs, attrs);
+  static void Compute(CustomOpKernelContext* ctx) {
+    ComputeCallHelper<Args..., TypeTag<int>>::template Compute<0, 0, 0>(ctx);
   }

  private:
@@ -130,37 +157,29 @@ struct KernelFuncImpl {
   template <typename... Tail>
   struct ComputeCallHelper<const Tensor&, Tail...> {
-    template <int in_idx, int vec_in_idx, int attr_idx,
-              typename... PreviousArgs>
-    static Return Compute(const std::vector<Tensor>& inputs,
-                          const std::vector<std::vector<Tensor>>& vec_inputs,
-                          const std::vector<paddle::any>& attrs,
-                          const PreviousArgs&... pargs) {
-      const Tensor& arg = inputs[in_idx];
-      return ComputeCallHelper<Tail...>::template Compute<in_idx + 1, vec_in_idx, attr_idx>(
-          inputs, vec_inputs, attrs, pargs..., arg);
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static void Compute(CustomOpKernelContext* ctx,
+                        const PreviousArgs&... pargs) {
+      auto& range = ctx->InputRangeAt(in_idx);
+      auto& arg = ctx->InputAt(range.first);
+      ComputeCallHelper<
+          Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
+                                                                    pargs...,
+                                                                    arg);
     }
   };

   template <typename... Tail>
   struct ComputeCallHelper<const std::vector<Tensor>&, Tail...> {
-    template <int in_idx, int vec_in_idx, int attr_idx,
-              typename... PreviousArgs>
-    static Return Compute(const std::vector<Tensor>& inputs,
-                          const std::vector<std::vector<Tensor>>& vec_inputs,
-                          const std::vector<paddle::any>& attrs,
-                          const PreviousArgs&... pargs) {
-      const std::vector<Tensor>& arg = vec_inputs[vec_in_idx];
-      return ComputeCallHelper<Tail...>::template Compute<in_idx, vec_in_idx + 1, attr_idx>(
-          inputs, vec_inputs, attrs, pargs..., arg);
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static void Compute(CustomOpKernelContext* ctx,
+                        const PreviousArgs&... pargs) {
+      auto& range = ctx->InputRangeAt(in_idx);
+      auto arg = ctx->InputsBetween(range.first, range.second);
+      ComputeCallHelper<
+          Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
+                                                                    pargs...,
+                                                                    arg);
     }
   };

@@ -194,15 +213,76 @@ struct KernelFuncImpl {
   PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
   PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);

+  template <typename... Tail>
+  struct ComputeCallHelper<Tensor*, Tail...> {
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static void Compute(CustomOpKernelContext* ctx,
+                        const PreviousArgs&... pargs) {
+      auto& range = ctx->OutputRangeAt(out_idx);
+      auto* arg = ctx->MutableOutputAt(range.first);
+      ComputeCallHelper<
+          Tail...>::template Compute<in_idx, attr_idx, out_idx + 1>(ctx,
+                                                                    pargs...,
+                                                                    arg);
+    }
+  };
+
+  // TODO(chenweihang): What is the appropriate output form?
+  // std::vector<Tensor>*? or std::vector<Tensor*>? or std::vector<Tensor*>*
+  template <typename... Tail>
+  struct ComputeCallHelper<std::vector<Tensor*>, Tail...> {
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static void Compute(CustomOpKernelContext* ctx,
+                        const PreviousArgs&... pargs) {
+      auto& range = ctx->OutputRangeAt(out_idx);
+      auto arg = ctx->MutableOutputBetweeen(range.first, range.second);
+      ComputeCallHelper<
+          Tail...>::template Compute<in_idx, attr_idx, out_idx + 1>(ctx,
+                                                                    pargs...,
+                                                                    arg);
+    }
+  };
+
+  template <int out_idx, typename T>
+  struct ComputeReturnHelper;
+
+  // For compatibility with the original custom op form
+  template <int out_idx>
+  struct ComputeReturnHelper<out_idx, std::vector<Tensor>> {
+    static void Compute(CustomOpKernelContext* ctx, const Args&... args) {
+      static_assert(out_idx == 0,
+                    "If return std::vector<Tensor> in Custom OpKernel, "
+                    "you cannot pass output by kernel function argument.");
+      auto outs = impl_fn(args...);
+      auto* orig_outs = ctx->AllMutableOutput();
+      PD_CHECK(orig_outs->size() == outs.size(),
+               "The number of element in custom operator outputs is wrong, "
+               "expected contains ",
+               orig_outs->size(),
+               " Tensors, but actually contains ",
+               outs.size(),
+               " Tensors.");
+      for (size_t i = 0; i < outs.size(); ++i) {
+        AssignTensorImpl(outs.at(i), &(orig_outs->at(i)));
+      }
+    }
+  };
+
+  template <int out_idx>
+  struct ComputeReturnHelper<out_idx, void> {
+    static void Compute(CustomOpKernelContext* ctx, const Args&... args) {
+      static_assert(out_idx > 0, "Custom OpKernel has no output.");
+      impl_fn(args...);
+    }
+  };
+
   // end: base template
   template <typename T>
   struct ComputeCallHelper<TypeTag<T>> {
-    template <int in_idx, int vec_in_idx, int attr_idx>
-    static Return Compute(const std::vector<Tensor>& inputs,
-                          const std::vector<std::vector<Tensor>>& vec_inputs,
-                          const std::vector<paddle::any>& attrs,
-                          const Args&... args) {
-      return impl_fn(args...);
+    template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
+    static void Compute(CustomOpKernelContext* ctx,
+                        const PreviousArgs&... pargs) {
+      ComputeReturnHelper<out_idx, Return>::Compute(ctx, pargs...);
     }
   };
 };
diff --git a/paddle/pten/api/lib/op_meta_info.cc b/paddle/pten/api/lib/op_meta_info.cc
index 82d465b4c21..649960e1e1c 100644
--- a/paddle/pten/api/lib/op_meta_info.cc
+++ b/paddle/pten/api/lib/op_meta_info.cc
@@ -19,10 +19,102 @@ limitations under the License. */
 #include <vector>

 #include "paddle/fluid/framework/custom_operator.h"
+#include "paddle/pten/core/dense_tensor.h"
 #include "paddle/pten/core/enforce.h"

 namespace paddle {

+PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) {
+  PADDLE_ENFORCE_EQ(src.is_dense_tensor() && dst->is_dense_tensor(),
+                    true,
+                    pten::errors::Unavailable(
+                        "Now only supported DenseTensor in Custom Operator."));
+  PADDLE_ENFORCE_EQ(
+      src.initialized(),
+      true,
+      pten::errors::Unavailable(
+          "The Custom OpKernel calculate output is not initialized."));
+  PADDLE_ENFORCE_EQ(dst->defined(),
+                    true,
+                    pten::errors::Unavailable(
+                        "The Custom OpKernel origin output is not defined."));
+  auto& dense_src = static_cast<const pten::DenseTensor&>(*src.impl());
+  auto* dense_dst = static_cast<pten::DenseTensor*>(dst->impl().get());
+  *dense_dst = dense_src;
+}
+
+////////////////////// Kernel Context //////////////////////
+
+void CustomOpKernelContext::EmplaceBackInput(Tensor&& input) {
+  size_t index = inputs_.size();
+  inputs_.emplace_back(input);
+  input_range_.emplace_back(std::make_pair(index, index + 1));
+}
+
+void CustomOpKernelContext::EmplaceBackInputs(std::vector<Tensor>&& inputs) {
+  size_t index = inputs_.size();
+  input_range_.emplace_back(std::make_pair(index, index + inputs.size()));
+  inputs_.insert(inputs_.end(),
+                 std::make_move_iterator(inputs.begin()),
+                 std::make_move_iterator(inputs.end()));
+}
+
+void CustomOpKernelContext::EmplaceBackOutput(Tensor&& output) {
+  size_t index = outputs_.size();
+  outputs_.emplace_back(output);
+  output_range_.emplace_back(std::make_pair(index, index + 1));
+}
+
+void CustomOpKernelContext::EmplaceBackOutputs(std::vector<Tensor>&& outputs) {
+  size_t index = outputs_.size();
+  output_range_.emplace_back(std::make_pair(index, index + outputs.size()));
+  outputs_.insert(outputs_.end(),
+                  std::make_move_iterator(outputs.begin()),
+                  std::make_move_iterator(outputs.end()));
+}
+
+void CustomOpKernelContext::EmplaceBackAttr(paddle::any attr) {
+  attrs_.emplace_back(std::move(attr));
+}
+
+const Tensor& CustomOpKernelContext::InputAt(size_t idx) const {
+  return inputs_.at(idx);
+}
+
+std::vector<Tensor> CustomOpKernelContext::InputsBetween(size_t start,
+                                                         size_t end) const {
+  std::vector<Tensor> rlt;
+  for (size_t i = start; i < end; ++i) {
+    rlt.emplace_back(inputs_.at(i));
+  }
+  return rlt;
+}
+
+Tensor* CustomOpKernelContext::MutableOutputAt(size_t idx) {
+  return &(outputs_.at(idx));
+}
+
+std::vector<Tensor*> CustomOpKernelContext::MutableOutputBetweeen(size_t start,
+                                                                  size_t end) {
+  std::vector<Tensor*> rlt;
+  for (size_t i = start; i < end; ++i) {
+    rlt.emplace_back(&(outputs_.at(i)));
+  }
+  return rlt;
+}
+
+std::vector<Tensor>* CustomOpKernelContext::AllMutableOutput() {
+  return &outputs_;
+}
+
+const std::pair<size_t, size_t>& CustomOpKernelContext::InputRangeAt(
+    size_t idx) const {
+  return input_range_.at(idx);
+}
+
+const std::pair<size_t, size_t>& CustomOpKernelContext::OutputRangeAt(
+    size_t idx) const {
+  return output_range_.at(idx);
+}
+
 ////////////////////// Op Meta Info //////////////////////

 OpMetaInfo& OpMetaInfo::Inputs(std::vector<std::string>&& inputs) {
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
index c5ec3191c1b..c89990be34c 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cc
@@ -151,3 +151,63 @@ PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward)
     .Outputs({paddle::Grad("X")})
     .SetKernelFn(PD_KERNEL(ReluBackwardWithoutX))
     .SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape));
+
+void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
+  PD_DISPATCH_FLOATING_TYPES(
+      x.type(), "relu_cpu_forward", ([&] {
+        relu_cpu_forward_kernel<data_t>(
+            x.data<data_t>(), out->mutable_data<data_t>(x.place()), x.size());
+      }));
+}
+
+void relu_cpu_backward_out(const paddle::Tensor& x,
+                           const paddle::Tensor& out,
+                           const paddle::Tensor& grad_out,
+                           paddle::Tensor* grad_x) {
+  PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
+                               relu_cpu_backward_kernel<data_t>(
+                                   grad_out.data<data_t>(),
+                                   out.data<data_t>(),
+                                   grad_x->mutable_data<data_t>(x.place()),
+                                   out.size());
+                             }));
+}
+
+void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out);
+void relu_cuda_backward_out(const paddle::Tensor& x,
+                            const paddle::Tensor& out,
+                            const paddle::Tensor& grad_out,
+                            paddle::Tensor* grad_x);
+
+void ReluForwardOut(const paddle::Tensor& x, paddle::Tensor* out) {
+  if (x.place() == paddle::PlaceType::kCPU) {
+    return relu_cpu_forward_out(x, out);
+  } else if (x.place() == paddle::PlaceType::kGPU) {
+    return relu_cuda_forward_out(x, out);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+void ReluBackwardOut(const paddle::Tensor& x,
+                     const paddle::Tensor& out,
+                     const paddle::Tensor& grad_out,
+                     paddle::Tensor* grad_x) {
+  if (x.place() == paddle::PlaceType::kCPU) {
+    return relu_cpu_backward_out(x, out, grad_out, grad_x);
+  } else if (x.place() == paddle::PlaceType::kGPU) {
+    return relu_cuda_backward_out(x, out, grad_out, grad_x);
+  } else {
+    PD_THROW("Not implemented.");
+  }
+}
+
+PD_BUILD_OP(custom_relu_out)
+    .Inputs({"X"})
+    .Outputs({"Out"})
+    .SetKernelFn(PD_KERNEL(ReluForwardOut));
+
+PD_BUILD_GRAD_OP(custom_relu_out)
+    .Inputs({"X", "Out", paddle::Grad("Out")})
+    .Outputs({paddle::Grad("X")})
+    .SetKernelFn(PD_KERNEL(ReluBackwardOut));
diff --git a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
index 637deeb9056..33c5ede299b 100644
--- a/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
+++ b/python/paddle/fluid/tests/custom_op/custom_relu_op.cu
@@ -89,3 +89,31 @@ std::vector<paddle::Tensor> relu_cuda_backward_without_x(

   return {grad_x};
 }
+
+void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
+  int numel = x.size();
+  int block = 512;
+  int grid = (numel + block - 1) / block;
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
+      x.type(), "relu_cuda_forward_kernel", ([&] {
+        relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
+            x.data<data_t>(), out->mutable_data<data_t>(x.place()), numel);
+      }));
+}
+
+void relu_cuda_backward_out(const paddle::Tensor& x,
+                            const paddle::Tensor& out,
+                            const paddle::Tensor& grad_out,
+                            paddle::Tensor* grad_x) {
+  int numel = out.size();
+  int block = 512;
+  int grid = (numel + block - 1) / block;
+  PD_DISPATCH_FLOATING_AND_HALF_TYPES(
+      out.type(), "relu_cuda_backward_kernel", ([&] {
+        relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
+            grad_out.data<data_t>(),
+            out.data<data_t>(),
+            grad_x->mutable_data<data_t>(x.place()),
+            numel);
+      }));
+}
diff --git a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
index 16458841f44..407eb342ba9 100644
--- a/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
+++ b/python/paddle/fluid/tests/custom_op/test_custom_relu_op_jit.py
@@ -50,7 +50,8 @@ class TestJITLoad(unittest.TestCase):
     def setUp(self):
         self.custom_ops = [
             custom_module.custom_relu, custom_module.custom_relu_dup,
-            custom_module.custom_relu_no_x_in_backward
+            custom_module.custom_relu_no_x_in_backward,
+            custom_module.custom_relu_out
         ]
         self.dtypes = ['float32', 'float64']
         if paddle.is_compiled_with_cuda():
--
GitLab
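
The template machinery the patch introduces in op_meta_info.h is compact but dense. The program below is a minimal, self-contained model of the same dispatch pattern, written from scratch rather than taken from Paddle: a context holds type-erased inputs and outputs, one Helper specialization per parameter type peels off one kernel argument and advances the matching index cursor, and a terminal TypeTag case finally invokes the user kernel with everything collected. All names here (MiniTensor, MiniCtx, KernelImpl, my_relu) are illustrative stand-ins; the real implementation additionally threads attributes, input/output ranges, and the vector-output case through CustomOpKernelContext.

// mini_dispatch.cc -- compile with: c++ -std=c++14 mini_dispatch.cc
#include <cassert>
#include <vector>

struct MiniTensor {
  float value = 0.f;
};

struct MiniCtx {
  std::vector<MiniTensor> inputs;
  std::vector<MiniTensor> outputs;
};

template <typename T>
struct TypeTag {};

template <typename F, F f>
struct KernelImpl;

template <typename... Args, void (*impl_fn)(Args...)>
struct KernelImpl<void (*)(Args...), impl_fn> {
  static void Compute(MiniCtx* ctx) {
    // both cursors start at zero; TypeTag<int> marks the end of the list
    Helper<Args..., TypeTag<int>>::template Compute<0, 0>(ctx);
  }

 private:
  template <typename... Remaining>
  struct Helper;

  // `const MiniTensor&` consumes the next input slot
  template <typename... Tail>
  struct Helper<const MiniTensor&, Tail...> {
    template <int in_idx, int out_idx, typename... Prev>
    static void Compute(MiniCtx* ctx, const Prev&... prev) {
      const MiniTensor& arg = ctx->inputs[in_idx];
      Helper<Tail...>::template Compute<in_idx + 1, out_idx>(ctx, prev...,
                                                             arg);
    }
  };

  // `MiniTensor*` consumes the next output slot -- the form this patch adds
  template <typename... Tail>
  struct Helper<MiniTensor*, Tail...> {
    template <int in_idx, int out_idx, typename... Prev>
    static void Compute(MiniCtx* ctx, const Prev&... prev) {
      MiniTensor* arg = &ctx->outputs[out_idx];
      Helper<Tail...>::template Compute<in_idx, out_idx + 1>(ctx, prev...,
                                                             arg);
    }
  };

  // end of the argument list: call the user kernel with collected arguments
  template <typename T>
  struct Helper<TypeTag<T>> {
    template <int in_idx, int out_idx, typename... Prev>
    static void Compute(MiniCtx* ctx, const Prev&... prev) {
      (void)ctx;
      impl_fn(prev...);
    }
  };
};

// a user kernel in the new "output as argument" style
void my_relu(const MiniTensor& x, MiniTensor* out) {
  out->value = x.value > 0.f ? x.value : 0.f;
}

int main() {
  MiniCtx ctx;
  ctx.inputs.push_back(MiniTensor{-2.f});
  ctx.outputs.resize(1);
  KernelImpl<decltype(&my_relu), &my_relu>::Compute(&ctx);
  assert(ctx.outputs[0].value == 0.f);
  return 0;
}

The key property the sketch demonstrates is that the kernel's signature alone drives the unpacking: adding a new supported parameter type means adding one Helper specialization, which is exactly how the patch extends Paddle's ComputeCallHelper with Tensor* and std::vector<Tensor*> cases.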
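The output-sync step that RunKernelFunc now performs can likewise be reduced to a small model: after the kernel has run on the context's tensor copies, the framework writes shape metadata back into the original output tensor and, when the kernel re-allocated, shares the new allocation instead of copying element data. The sketch below assumes nothing from Paddle; DenseTensorLike and its shared_ptr "holder" are simplified stand-ins for pten::DenseTensor and its Holder/ResetHolder machinery.

// mini_sync.cc -- compile with: c++ -std=c++14 mini_sync.cc
#include <cassert>
#include <cstdint>
#include <memory>
#include <vector>

struct DenseTensorLike {
  std::vector<int64_t> dims;
  std::shared_ptr<std::vector<float>> holder;  // models the allocation
};

void SyncOutput(const DenseTensorLike& calc_out, DenseTensorLike* true_out) {
  // assign meta info back to the framework-owned tensor
  true_out->dims = calc_out.dims;
  // reset holder if needed: share the kernel's allocation, no data copy
  if (true_out->holder != calc_out.holder) {
    true_out->holder = calc_out.holder;
  }
}

int main() {
  DenseTensorLike calc_out{{2, 3},
                           std::make_shared<std::vector<float>>(6, 1.f)};
  DenseTensorLike true_out;  // not yet allocated
  SyncOutput(calc_out, &true_out);
  assert(true_out.holder == calc_out.holder);  // shared, not copied
  assert(true_out.dims == calc_out.dims);
  return 0;
}

This mirrors why the patch caches true_out_ptrs before running the kernel: the context only holds copies of the output tensor handles, so the framework needs the original pointers afterwards to propagate dims/dtype/layout and the holder back.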