未验证 提交 de49a4b7 编写于 作者: C chentianyu03 提交者: GitHub

exchange assign and assign_raw kernel name (#41625)

* exchange assign and assign_raw kernel name

* fix register error
上级 0835de79
...@@ -23,22 +23,22 @@ ...@@ -23,22 +23,22 @@
namespace phi { namespace phi {
template <typename Context> template <typename Context>
void AssignRawKernel(const Context& dev_ctx, void AssignKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
DenseTensor* out) { DenseTensor* out) {
Copy<Context>(dev_ctx, x, x.place(), false, out); Copy<Context>(dev_ctx, x, x.place(), false, out);
} }
template <typename Context> template <typename Context>
void AssignKernel(const Context& dev_ctx, void AssignRawKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x, paddle::optional<const DenseTensor&> x,
DenseTensor* out) { DenseTensor* out) {
if (x) { if (x) {
if (!x->IsInitialized()) { if (!x->IsInitialized()) {
return; return;
} }
auto& x_tensor = *x.get_ptr(); auto& x_tensor = *x.get_ptr();
AssignRawKernel<Context>(dev_ctx, x_tensor, out); AssignKernel<Context>(dev_ctx, x_tensor, out);
} }
} }
...@@ -111,14 +111,14 @@ void AssignValueKernel(const Context& dev_ctx, ...@@ -111,14 +111,14 @@ void AssignValueKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_raw, PD_REGISTER_GENERAL_KERNEL(assign_raw,
CPU, CPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::AssignRawKernel<phi::CPUContext>, phi::AssignRawKernel<phi::CPUContext>,
ALL_DTYPE) {} ALL_DTYPE) {
PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
} }
PD_REGISTER_GENERAL_KERNEL(assign_array, PD_REGISTER_GENERAL_KERNEL(assign_array,
...@@ -136,13 +136,13 @@ PD_REGISTER_KERNEL(assign_value, ...@@ -136,13 +136,13 @@ PD_REGISTER_KERNEL(assign_value,
int64_t) {} int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_raw, PD_REGISTER_GENERAL_KERNEL(assign_raw,
GPU, GPU,
ALL_LAYOUT, ALL_LAYOUT,
phi::AssignRawKernel<phi::GPUContext>, phi::AssignRawKernel<phi::GPUContext>,
ALL_DTYPE) {} ALL_DTYPE) {
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND); kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
} }
PD_REGISTER_GENERAL_KERNEL(assign_array, PD_REGISTER_GENERAL_KERNEL(assign_array,
......
...@@ -22,17 +22,17 @@ ...@@ -22,17 +22,17 @@
namespace phi { namespace phi {
template <typename Context> template <typename Context>
void AssignRawKernel(const Context& dev_ctx, void AssignKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
DenseTensor* out); DenseTensor* out);
// In order to be compatible with the `AsDispensable` input in the original // In order to be compatible with the `AsDispensable` input in the original
// assign op maker, the input parameter here needs to be dispensable, but // assign op maker, the input parameter here needs to be dispensable, but
// this looks weird // this looks weird
template <typename Context> template <typename Context>
void AssignKernel(const Context& dev_ctx, void AssignRawKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x, paddle::optional<const DenseTensor&> x,
DenseTensor* out); DenseTensor* out);
template <typename Context> template <typename Context>
void AssignArrayKernel(const Context& dev_ctx, void AssignArrayKernel(const Context& dev_ctx,
......
...@@ -23,10 +23,10 @@ KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) { ...@@ -23,10 +23,10 @@ KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) {
} else if (ctx.IsSelectedRowsInput("X")) { } else if (ctx.IsSelectedRowsInput("X")) {
return KernelSignature("assign_sr", {"X"}, {}, {"Out"}); return KernelSignature("assign_sr", {"X"}, {}, {"Out"});
} else { } else {
return KernelSignature("assign", {"X"}, {}, {"Out"}); return KernelSignature("assign_raw", {"X"}, {}, {"Out"});
} }
} else { } else {
return KernelSignature("assign", {"X"}, {}, {"Out"}); return KernelSignature("assign_raw", {"X"}, {}, {"Out"});
} }
} }
......
...@@ -174,7 +174,7 @@ ...@@ -174,7 +174,7 @@
infer_meta : infer_meta :
func : UnchangedInferMeta func : UnchangedInferMeta
kernel : kernel :
func : assign_raw func : assign
backward : assign_grad backward : assign_grad
# atan # atan
......
...@@ -120,7 +120,7 @@ ...@@ -120,7 +120,7 @@
func : UnchangedInferMeta func : UnchangedInferMeta
param : [out_grad] param : [out_grad]
kernel : kernel :
func : assign_raw func : assign
- backward_api : atan2_grad - backward_api : atan2_grad
forward : atan2 (Tensor x, Tensor y) -> Tensor(out) forward : atan2 (Tensor x, Tensor y) -> Tensor(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册