未验证 提交 f1f74e9e 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp] Support output as input argument of kernel func (#39353)

* refactor custom op kernel func and utils

* add output sync

* adapte tensor* in utils

* fix windows symbol error
上级 a821c4a9
......@@ -110,8 +110,8 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs) {
VLOG(3) << "Custom Operator: Start run KernelFunc.";
std::vector<paddle::experimental::Tensor> custom_ins;
std::vector<std::vector<paddle::experimental::Tensor>> custom_vec_ins;
// prepare CustomOpKernelContext
paddle::CustomOpKernelContext kernel_ctx;
for (auto& in_name : inputs) {
VLOG(3) << "Custom Operator: input name - " << in_name;
if (detail::IsDuplicableVar(in_name)) {
......@@ -136,7 +136,7 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
custom_t.set_impl(std::make_shared<pten::DenseTensor>(*x));
custom_vec_in.emplace_back(custom_t);
}
custom_vec_ins.emplace_back(custom_vec_in);
kernel_ctx.EmplaceBackInputs(std::move(custom_vec_in));
} else {
auto* x = ctx.Input<Tensor>(in_name);
PADDLE_ENFORCE_NOT_NULL(x, platform::errors::NotFound(
......@@ -146,33 +146,32 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
"Input tensor (%s) is not initialized.", in_name));
paddle::experimental::Tensor custom_in;
custom_in.set_impl(std::make_shared<pten::DenseTensor>(*x));
custom_ins.emplace_back(custom_in);
kernel_ctx.EmplaceBackInput(std::move(custom_in));
}
}
std::vector<paddle::any> custom_attrs;
for (auto& attr_str : attrs) {
auto attr_name_and_type = detail::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(ctx.Attr<bool>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<bool>(attr_name));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(ctx.Attr<int>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<int>(attr_name));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(ctx.Attr<float>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<float>(attr_name));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(ctx.Attr<int64_t>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<int64_t>(attr_name));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(ctx.Attr<std::string>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::string>(attr_name));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int>>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<int>>(attr_name));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<float>>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<float>>(attr_name));
} else if (attr_type_str == "std::vector<int64_t>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<int64_t>>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<int64_t>>(attr_name));
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(ctx.Attr<std::vector<std::string>>(attr_name));
kernel_ctx.EmplaceBackAttr(ctx.Attr<std::vector<std::string>>(attr_name));
} else {
PADDLE_THROW(platform::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
......@@ -185,11 +184,9 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
}
}
VLOG(3) << "Custom Operator: Run ComputeFunc.";
try {
auto outs = func(custom_ins, custom_vec_ins, custom_attrs);
VLOG(3) << "Custom Operator: Share outputs into ExecutionContext.";
VLOG(3) << "Custom Operator: push outputs into CustomOpKernelContext.";
// cache the target tensor pointers
std::vector<Tensor*> true_out_ptrs;
for (size_t i = 0; i < outputs.size(); ++i) {
auto out_name = outputs[i];
if (detail::IsDuplicableVar(out_name)) {
......@@ -198,22 +195,64 @@ static void RunKernelFunc(const framework::ExecutionContext& ctx,
"If custom operator's outputs contains `paddle::Vec("
")` type, "
"it only can hold one output."));
auto vec_true_outs = ctx.MultiOutput<Tensor>(out_name);
auto vec_out = ctx.MultiOutput<Tensor>(out_name);
PADDLE_ENFORCE_NE(vec_out.empty(), true,
platform::errors::NotFound(
"Output vector<tensor> (%s) is empty.", out_name));
std::vector<paddle::experimental::Tensor> custom_vec_out;
for (size_t j = 0; j < vec_out.size(); ++j) {
auto* out = vec_out[j];
PADDLE_ENFORCE_NOT_NULL(
out,
platform::errors::NotFound(
"The %d-th tensor in output vector<tensor> (%s) is nullptr.", j,
out_name));
true_out_ptrs.emplace_back(out);
paddle::experimental::Tensor custom_t;
// here only can copy the output tensor into context
custom_t.set_impl(std::make_shared<pten::DenseTensor>(*out));
custom_vec_out.emplace_back(custom_t);
}
kernel_ctx.EmplaceBackOutputs(std::move(custom_vec_out));
} else {
auto* out = ctx.Output<Tensor>(out_name);
PADDLE_ENFORCE_NOT_NULL(
out, platform::errors::NotFound("Output tensor (%s) is nullptr.",
out_name));
true_out_ptrs.emplace_back(out);
paddle::experimental::Tensor custom_out;
// here only can copy the output tensor into context
custom_out.set_impl(std::make_shared<pten::DenseTensor>(*out));
kernel_ctx.EmplaceBackOutput(std::move(custom_out));
}
}
try {
VLOG(3) << "Custom Operator: Run ComputeFunc.";
func(&kernel_ctx);
// sync output tensor data into original output
auto* calc_outs = kernel_ctx.AllMutableOutput();
PADDLE_ENFORCE_EQ(
vec_true_outs.size(), outs.size(),
true_out_ptrs.size(), calc_outs->size(),
platform::errors::InvalidArgument(
"The number of element in custom operator outputs is wrong, "
"expected contains %d Tensors, but actually contains %d "
"Tensors.",
vec_true_outs.size(), outs.size()));
for (size_t j = 0; j < vec_true_outs.size(); ++j) {
*vec_true_outs.at(j) =
*std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(j).impl());
}
} else {
auto* true_out = ctx.Output<Tensor>(out_name);
*true_out =
*std::dynamic_pointer_cast<pten::DenseTensor>(outs.at(i).impl());
true_out_ptrs.size(), calc_outs->size()));
for (size_t i = 0; i < true_out_ptrs.size(); ++i) {
auto* true_out = true_out_ptrs.at(i);
auto calc_out =
std::dynamic_pointer_cast<pten::DenseTensor>(calc_outs->at(i).impl());
// assgin meta info
auto* true_out_meta = pten::DenseTensorUtils::GetMutableMeta(true_out);
true_out_meta->dims = calc_out->dims();
true_out_meta->dtype = calc_out->dtype();
true_out_meta->layout = calc_out->layout();
// lod and offset no need to be reset
// reset holder if needed
if (true_out->Holder() != calc_out->Holder()) {
true_out->ResetHolder(calc_out->Holder());
}
}
} catch (platform::EnforceNotMet& exception) {
......@@ -609,7 +648,7 @@ void RegisterOperatorWithMetaInfo(
auto op_name = OpMetaInfoHelper::GetOpName(base_op_meta);
if (OpInfoMap::Instance().Has(op_name)) {
LOG(WARNING) << "Operator (" << op_name << ")has been registered.";
LOG(WARNING) << "Operator (" << op_name << ") has been registered.";
return;
}
......
......@@ -17,6 +17,7 @@ limitations under the License. */
#include <iostream>
#include <string>
#include <unordered_map>
#include <utility>
#include <vector>
#include "paddle/pten/api/ext/dll_decl.h"
......@@ -76,36 +77,65 @@ inline std::string Vec(const std::string& t_name) {
return result;
}
PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst);
////////////////////// Kernel Context ////////////////////////
class PADDLE_API CustomOpKernelContext {
public:
CustomOpKernelContext() = default;
void EmplaceBackInput(Tensor&& input);
void EmplaceBackInputs(std::vector<Tensor>&& inputs);
void EmplaceBackOutput(Tensor&& output);
void EmplaceBackOutputs(std::vector<Tensor>&& outputs);
void EmplaceBackAttr(paddle::any attr);
const std::pair<size_t, size_t>& InputRangeAt(size_t idx) const;
const std::pair<size_t, size_t>& OutputRangeAt(size_t idx) const;
const Tensor& InputAt(size_t idx) const;
std::vector<Tensor> InputsBetween(size_t start, size_t end) const;
Tensor* MutableOutputAt(size_t idx);
std::vector<Tensor*> MutableOutputBetweeen(size_t start, size_t end);
std::vector<Tensor>* AllMutableOutput();
template <typename AttrType>
AttrType AttrAt(size_t idx) const {
try {
return paddle::any_cast<AttrType>(attrs_.at(idx));
} catch (paddle::bad_any_cast&) {
PD_THROW("Attribute cast error in Custom Op Kernel Context.");
}
}
private:
// TODO(chenweihang): replaced be SmallVector
std::vector<Tensor> inputs_;
std::vector<Tensor> outputs_;
std::vector<paddle::any> attrs_;
std::vector<std::pair<size_t, size_t>> input_range_;
std::vector<std::pair<size_t, size_t>> output_range_;
};
////////////////////// Kernel Function (PD_KERNEL) ////////////////////////
// Record Op kernel core function
using KernelFunc =
std::vector<Tensor> (*)(const std::vector<Tensor>& inputs,
const std::vector<std::vector<Tensor>>& vec_inputs,
const std::vector<paddle::any>& attrs);
using KernelFunc = void (*)(CustomOpKernelContext*);
#define PD_SPECIALIZE_ComputeCallHelper(attr_type) \
template <typename... Tail> \
struct ComputeCallHelper<attr_type, Tail...> { \
template <int in_idx, \
int vec_in_idx, \
int attr_idx, \
typename... PreviousArgs> \
static Return Compute(const std::vector<Tensor>& inputs, \
const std::vector<std::vector<Tensor>>& vec_inputs, \
const std::vector<paddle::any>& attrs, \
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs> \
static void Compute(CustomOpKernelContext* ctx, \
const PreviousArgs&... pargs) { \
try { \
attr_type arg = paddle::any_cast<attr_type>(attrs[attr_idx]); \
return ComputeCallHelper<Tail...>::template Compute<in_idx, \
vec_in_idx, \
attr_idx + 1>( \
inputs, vec_inputs, attrs, pargs..., arg); \
} catch (paddle::bad_any_cast&) { \
PD_THROW( \
"Attribute cast error in custom operator. Expected " #attr_type \
" value."); \
} \
attr_type arg = ctx->AttrAt<attr_type>(attr_idx); \
ComputeCallHelper< \
Tail...>::template Compute<in_idx, attr_idx + 1, out_idx>(ctx, \
pargs..., \
arg); \
} \
}
......@@ -117,11 +147,8 @@ struct KernelFuncImpl;
template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
static Return Compute(const std::vector<Tensor>& inputs,
const std::vector<std::vector<Tensor>>& vec_inputs,
const std::vector<paddle::any>& attrs) {
return ComputeCallHelper<Args..., TypeTag<int>>::template Compute<0, 0, 0>(
inputs, vec_inputs, attrs);
static void Compute(CustomOpKernelContext* ctx) {
ComputeCallHelper<Args..., TypeTag<int>>::template Compute<0, 0, 0>(ctx);
}
private:
......@@ -130,37 +157,29 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
template <typename... Tail>
struct ComputeCallHelper<const Tensor&, Tail...> {
template <int in_idx,
int vec_in_idx,
int attr_idx,
typename... PreviousArgs>
static Return Compute(const std::vector<Tensor>& inputs,
const std::vector<std::vector<Tensor>>& vec_inputs,
const std::vector<paddle::any>& attrs,
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx,
const PreviousArgs&... pargs) {
const Tensor& arg = inputs[in_idx];
return ComputeCallHelper<Tail...>::template Compute<in_idx + 1,
vec_in_idx,
attr_idx>(
inputs, vec_inputs, attrs, pargs..., arg);
auto& range = ctx->InputRangeAt(in_idx);
auto& arg = ctx->InputAt(range.first);
ComputeCallHelper<
Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
template <typename... Tail>
struct ComputeCallHelper<const std::vector<Tensor>&, Tail...> {
template <int in_idx,
int vec_in_idx,
int attr_idx,
typename... PreviousArgs>
static Return Compute(const std::vector<Tensor>& inputs,
const std::vector<std::vector<Tensor>>& vec_inputs,
const std::vector<paddle::any>& attrs,
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx,
const PreviousArgs&... pargs) {
const std::vector<Tensor>& arg = vec_inputs[vec_in_idx];
return ComputeCallHelper<Tail...>::template Compute<in_idx,
vec_in_idx + 1,
attr_idx>(
inputs, vec_inputs, attrs, pargs..., arg);
auto& range = ctx->InputRangeAt(in_idx);
auto arg = ctx->InputsBetween(range.first, range.second);
ComputeCallHelper<
Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
};
......@@ -194,15 +213,76 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_ComputeCallHelper(std::vector<int64_t>);
PD_SPECIALIZE_ComputeCallHelper(std::vector<std::string>);
template <typename... Tail>
struct ComputeCallHelper<Tensor*, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx,
const PreviousArgs&... pargs) {
auto& range = ctx->OutputRangeAt(out_idx);
auto* arg = ctx->MutableOutputAt(range.first);
ComputeCallHelper<
Tail...>::template Compute<in_idx, attr_idx, out_idx + 1>(ctx,
pargs...,
arg);
}
};
// TODO(chenweihang): What is the appropriate output form?
// std::vector<Tensor>*? or std::vector<Tensor*>? or std::vector<Tensor*>*
template <typename... Tail>
struct ComputeCallHelper<std::vector<Tensor*>, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx,
const PreviousArgs&... pargs) {
auto& range = ctx->OutputRangeAt(out_idx);
auto arg = ctx->MutableOutputBetweeen(range.first, range.second);
ComputeCallHelper<
Tail...>::template Compute<in_idx, attr_idx, out_idx + 1>(ctx,
pargs...,
arg);
}
};
template <int out_idx, typename T>
struct ComputeReturnHelper;
// For compatibility with the original custom op form
template <int out_idx>
struct ComputeReturnHelper<out_idx, std::vector<Tensor>> {
static void Compute(CustomOpKernelContext* ctx, const Args&... args) {
static_assert(out_idx == 0,
"If return std::vector<Tensor> in Custom OpKernel, "
"you cannot pass output by kernel funciton argument.");
auto outs = impl_fn(args...);
auto* orig_outs = ctx->AllMutableOutput();
PD_CHECK(orig_outs->size() == outs.size(),
"The number of element in custom operator outputs is wrong, "
"expected contains ",
orig_outs->size(),
" Tensors, but actually contains ",
outs.size(),
" Tensors.");
for (size_t i = 0; i < outs.size(); ++i) {
AssignTensorImpl(outs.at(i), &(orig_outs->at(i)));
}
}
};
template <int out_idx>
struct ComputeReturnHelper<out_idx, void> {
static void Compute(CustomOpKernelContext* ctx, const Args&... args) {
static_assert(out_idx > 0, "Custom OpKernel has no output.");
impl_fn(args...);
}
};
// end: base template
template <typename T>
struct ComputeCallHelper<TypeTag<T>> {
template <int in_idx, int vec_in_idx, int attr_idx>
static Return Compute(const std::vector<Tensor>& inputs,
const std::vector<std::vector<Tensor>>& vec_inputs,
const std::vector<paddle::any>& attrs,
const Args&... args) {
return impl_fn(args...);
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx,
const PreviousArgs&... pargs) {
ComputeReturnHelper<out_idx, Return>::Compute(ctx, pargs...);
}
};
};
......
......@@ -19,10 +19,102 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/framework/custom_operator.h"
#include "paddle/pten/core/dense_tensor.h"
#include "paddle/pten/core/enforce.h"
namespace paddle {
PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst) {
PADDLE_ENFORCE_EQ(src.is_dense_tensor() && dst->is_dense_tensor(),
true,
pten::errors::Unavailable(
"Now only supported DenseTensor in Custom Operator."));
PADDLE_ENFORCE_EQ(
src.initialized(),
true,
pten::errors::Unavailable(
"The Custom OpKernel calculate output is not initialized."));
PADDLE_ENFORCE_EQ(dst->defined(),
true,
pten::errors::Unavailable(
"The Custom OpKernel origin output is not defined."));
auto& dense_src = static_cast<const pten::DenseTensor&>(*src.impl());
auto* dense_dst = static_cast<pten::DenseTensor*>(dst->impl().get());
*dense_dst = dense_src;
}
////////////////////// Kernel Context //////////////////////
void CustomOpKernelContext::EmplaceBackInput(Tensor&& input) {
size_t index = inputs_.size();
inputs_.emplace_back(input);
input_range_.emplace_back(std::make_pair(index, index + 1));
}
void CustomOpKernelContext::EmplaceBackInputs(std::vector<Tensor>&& inputs) {
size_t index = inputs_.size();
input_range_.emplace_back(std::make_pair(index, index + inputs.size()));
inputs_.insert(inputs_.end(),
std::make_move_iterator(inputs.begin()),
std::make_move_iterator(inputs.end()));
}
void CustomOpKernelContext::EmplaceBackOutput(Tensor&& output) {
size_t index = outputs_.size();
outputs_.emplace_back(output);
output_range_.emplace_back(std::make_pair(index, index + 1));
}
void CustomOpKernelContext::EmplaceBackOutputs(std::vector<Tensor>&& outputs) {
size_t index = outputs_.size();
output_range_.emplace_back(std::make_pair(index, index + outputs.size()));
outputs_.insert(outputs_.end(),
std::make_move_iterator(outputs.begin()),
std::make_move_iterator(outputs.end()));
}
void CustomOpKernelContext::EmplaceBackAttr(paddle::any attr) {
attrs_.emplace_back(std::move(attr));
}
const Tensor& CustomOpKernelContext::InputAt(size_t idx) const {
return inputs_.at(idx);
}
std::vector<Tensor> CustomOpKernelContext::InputsBetween(size_t start,
size_t end) const {
std::vector<Tensor> rlt;
for (size_t i = start; i < end; ++i) {
rlt.emplace_back(inputs_.at(i));
}
return rlt;
}
Tensor* CustomOpKernelContext::MutableOutputAt(size_t idx) {
return &(outputs_.at(idx));
}
std::vector<Tensor*> CustomOpKernelContext::MutableOutputBetweeen(size_t start,
size_t end) {
std::vector<Tensor*> rlt;
for (size_t i = start; i < end; ++i) {
rlt.emplace_back(&(outputs_.at(i)));
}
return rlt;
}
std::vector<Tensor>* CustomOpKernelContext::AllMutableOutput() {
return &outputs_;
}
const std::pair<size_t, size_t>& CustomOpKernelContext::InputRangeAt(
size_t idx) const {
return input_range_.at(idx);
}
const std::pair<size_t, size_t>& CustomOpKernelContext::OutputRangeAt(
size_t idx) const {
return output_range_.at(idx);
}
////////////////////// Op Meta Info //////////////////////
OpMetaInfo& OpMetaInfo::Inputs(std::vector<std::string>&& inputs) {
......
......@@ -151,3 +151,63 @@ PD_BUILD_GRAD_OP(custom_relu_no_x_in_backward)
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackwardWithoutX))
.SetInferShapeFn(PD_INFER_SHAPE(ReluBackwardWithoutXInferShape));
void relu_cpu_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
PD_DISPATCH_FLOATING_TYPES(
x.type(), "relu_cpu_forward", ([&] {
relu_cpu_forward_kernel<data_t>(
x.data<data_t>(), out->mutable_data<data_t>(x.place()), x.size());
}));
}
void relu_cpu_backward_out(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out,
paddle::Tensor* grad_x) {
PD_DISPATCH_FLOATING_TYPES(out.type(), "relu_cpu_backward", ([&] {
relu_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x->mutable_data<data_t>(x.place()),
out.size());
}));
}
void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out);
void relu_cuda_backward_out(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out,
paddle::Tensor* grad_x);
void ReluForwardOut(const paddle::Tensor& x, paddle::Tensor* out) {
if (x.place() == paddle::PlaceType::kCPU) {
return relu_cpu_forward_out(x, out);
} else if (x.place() == paddle::PlaceType::kGPU) {
return relu_cuda_forward_out(x, out);
} else {
PD_THROW("Not implemented.");
}
}
void ReluBackwardOut(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out,
paddle::Tensor* grad_x) {
if (x.place() == paddle::PlaceType::kCPU) {
return relu_cpu_backward_out(x, out, grad_out, grad_x);
} else if (x.place() == paddle::PlaceType::kGPU) {
return relu_cuda_backward_out(x, out, grad_out, grad_x);
} else {
PD_THROW("Not implemented.");
}
}
PD_BUILD_OP(custom_relu_out)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(ReluForwardOut));
PD_BUILD_GRAD_OP(custom_relu_out)
.Inputs({"X", "Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(ReluBackwardOut));
......@@ -89,3 +89,31 @@ std::vector<paddle::Tensor> relu_cuda_backward_without_x(
return {grad_x};
}
void relu_cuda_forward_out(const paddle::Tensor& x, paddle::Tensor* out) {
int numel = x.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
x.type(), "relu_cuda_forward_kernel", ([&] {
relu_cuda_forward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
x.data<data_t>(), out->mutable_data<data_t>(x.place()), numel);
}));
}
void relu_cuda_backward_out(const paddle::Tensor& x,
const paddle::Tensor& out,
const paddle::Tensor& grad_out,
paddle::Tensor* grad_x) {
int numel = out.size();
int block = 512;
int grid = (numel + block - 1) / block;
PD_DISPATCH_FLOATING_AND_HALF_TYPES(
out.type(), "relu_cuda_backward_kernel", ([&] {
relu_cuda_backward_kernel<data_t><<<grid, block, 0, x.stream()>>>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x->mutable_data<data_t>(x.place()),
numel);
}));
}
......@@ -50,7 +50,8 @@ class TestJITLoad(unittest.TestCase):
def setUp(self):
self.custom_ops = [
custom_module.custom_relu, custom_module.custom_relu_dup,
custom_module.custom_relu_no_x_in_backward
custom_module.custom_relu_no_x_in_backward,
custom_module.custom_relu_out
]
self.dtypes = ['float32', 'float64']
if paddle.is_compiled_with_cuda():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册