未验证 提交 de49a4b7 编写于 作者: C chentianyu03 提交者: GitHub

exchange assign and assign_raw kernel name (#41625)

* exchange assign and assign_raw kernel name

* fix register error
上级 0835de79
......@@ -23,14 +23,14 @@
namespace phi {
template <typename Context>
void AssignRawKernel(const Context& dev_ctx,
void AssignKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out) {
Copy<Context>(dev_ctx, x, x.place(), false, out);
}
template <typename Context>
void AssignKernel(const Context& dev_ctx,
void AssignRawKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out) {
if (x) {
......@@ -38,7 +38,7 @@ void AssignKernel(const Context& dev_ctx,
return;
}
auto& x_tensor = *x.get_ptr();
AssignRawKernel<Context>(dev_ctx, x_tensor, out);
AssignKernel<Context>(dev_ctx, x_tensor, out);
}
}
......@@ -111,14 +111,14 @@ void AssignValueKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_raw,
CPU,
ALL_LAYOUT,
phi::AssignRawKernel<phi::CPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
assign, CPU, ALL_LAYOUT, phi::AssignKernel<phi::CPUContext>, ALL_DTYPE) {
ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_GENERAL_KERNEL(assign_array,
......@@ -136,13 +136,13 @@ PD_REGISTER_KERNEL(assign_value,
int64_t) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(assign_raw,
GPU,
ALL_LAYOUT,
phi::AssignRawKernel<phi::GPUContext>,
ALL_DTYPE) {}
PD_REGISTER_GENERAL_KERNEL(
assign, GPU, ALL_LAYOUT, phi::AssignKernel<phi::GPUContext>, ALL_DTYPE) {
ALL_DTYPE) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
PD_REGISTER_GENERAL_KERNEL(assign_array,
......
......@@ -22,7 +22,7 @@
namespace phi {
template <typename Context>
void AssignRawKernel(const Context& dev_ctx,
void AssignKernel(const Context& dev_ctx,
const DenseTensor& x,
DenseTensor* out);
......@@ -30,7 +30,7 @@ void AssignRawKernel(const Context& dev_ctx,
// assign op maker, the input parameter here needs to be dispensable, but
// this looks weird
template <typename Context>
void AssignKernel(const Context& dev_ctx,
void AssignRawKernel(const Context& dev_ctx,
paddle::optional<const DenseTensor&> x,
DenseTensor* out);
......
......@@ -23,10 +23,10 @@ KernelSignature AssignOpArgumentMapping(const ArgumentMappingContext& ctx) {
} else if (ctx.IsSelectedRowsInput("X")) {
return KernelSignature("assign_sr", {"X"}, {}, {"Out"});
} else {
return KernelSignature("assign", {"X"}, {}, {"Out"});
return KernelSignature("assign_raw", {"X"}, {}, {"Out"});
}
} else {
return KernelSignature("assign", {"X"}, {}, {"Out"});
return KernelSignature("assign_raw", {"X"}, {}, {"Out"});
}
}
......
......@@ -174,7 +174,7 @@
infer_meta :
func : UnchangedInferMeta
kernel :
func : assign_raw
func : assign
backward : assign_grad
# atan
......
......@@ -120,7 +120,7 @@
func : UnchangedInferMeta
param : [out_grad]
kernel :
func : assign_raw
func : assign
- backward_api : atan2_grad
forward : atan2 (Tensor x, Tensor y) -> Tensor(out)
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册