diff --git a/paddle/fluid/operators/stack_op_xpu.cc b/paddle/fluid/operators/stack_op_xpu.cc deleted file mode 100644 index 4c65b21488264c8a074de1be7042e8f2f2ef18e5..0000000000000000000000000000000000000000 --- a/paddle/fluid/operators/stack_op_xpu.cc +++ /dev/null @@ -1,114 +0,0 @@ -// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -#ifdef PADDLE_WITH_XPU -#include -#include - -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/concat_op.h" -#include "paddle/fluid/platform/device/xpu/xpu_header.h" - -namespace paddle { -namespace operators { - -using framework::Tensor; -template -class StackXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto x = ctx.MultiInput("X"); - auto* y = ctx.Output("Y"); - int axis = ctx.Attr("axis"); - if (axis < 0) { - axis += x[0]->dims().size() + 1; - } - auto* y_data = y->mutable_data(ctx.GetPlace()); - - auto& dim = x[0]->dims(); - std::vector xdims; - for (auto i = 0; i < dim.size(); ++i) { - xdims.push_back(dim[i]); - } - xdims.push_back(1); - std::vector> xdims_list; - int n = static_cast(x.size()); - for (int i = 0; i < n; i++) { - xdims_list.push_back(xdims); - } - - std::vector x_list; - for (int i = 0; i < n; i++) { - x_list.push_back(x[i]->data()); - } - - auto& dev_ctx = ctx.template device_context(); - int r = - xpu::concat(dev_ctx.x_context(), x_list, y_data, xdims_list, axis); - PADDLE_ENFORCE_EQ(r, - xpu::Error_t::SUCCESS, - platform::errors::External( - "The stack XPU API return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -}; - -template -class StackGradXPUKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - auto* dy = ctx.Input(framework::GradVarName("Y")); - auto dx = ctx.MultiOutput(framework::GradVarName("X")); - auto axis = ctx.Attr("axis"); - auto& dev_ctx = ctx.template device_context(); - auto dy_dims = dy->dims(); - - if (axis < 0) axis += dy_dims.size() + 1; - auto dy_shape = phi::vectorize(dy_dims); - - std::vector dx_dims_list(dx.size(), 1); - std::vector dx_lists; - for (auto out : dx) { - dx_lists.push_back(out->mutable_data(ctx.GetPlace())); - } - - int r = xpu::split(dev_ctx.x_context(), - dy->data(), - dx_lists, - dy_shape, - dx_dims_list, - axis); - PADDLE_ENFORCE_EQ(r, - XPU_SUCCESS, - platform::errors::External( - "The stack_grad XPU kernel return wrong value[%d %s]", - r, - XPUAPIErrorMsg[r])); - } -}; - -} // namespace operators -} // namespace paddle - -namespace plat = paddle::platform; -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL(stack, - ops::StackXPUKernel, - ops::StackXPUKernel, - ops::StackXPUKernel); -REGISTER_OP_XPU_KERNEL(stack_grad, - ops::StackGradXPUKernel, - ops::StackGradXPUKernel); -#endif diff --git a/paddle/phi/kernels/xpu/stack_grad_kernel.cc b/paddle/phi/kernels/xpu/stack_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..59319d3e7624ce3ca4fc29531e8f21e3414a6162 --- /dev/null +++ b/paddle/phi/kernels/xpu/stack_grad_kernel.cc @@ -0,0 +1,53 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/stack_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void StackGradKernel(const Context& dev_ctx, + const DenseTensor& out, + int axis, + std::vector x_grad) { + using XPUType = typename XPUTypeTrait::Type; + auto outs = x_grad; + auto dy_dims = out.dims(); + + if (axis < 0) axis += dy_dims.size() + 1; + auto dy_shape = phi::vectorize(dy_dims); + + std::vector dx_dims_list(x_grad.size(), 1); + std::vector dx_lists; + for (size_t j = 0; j < outs.size(); ++j) { + dev_ctx.template Alloc(outs[j]); + dx_lists.push_back(reinterpret_cast(outs[j]->data())); + } + + int r = xpu::split(dev_ctx.x_context(), + reinterpret_cast(out.data()), + dx_lists, + dy_shape, + dx_dims_list, + axis); + + PADDLE_ENFORCE_XDNN_SUCCESS(r, "split in stack_grad op"); +} +} // namespace phi + +PD_REGISTER_KERNEL( + stack_grad, XPU, ALL_LAYOUT, phi::StackGradKernel, float, int) {} diff --git a/paddle/phi/kernels/xpu/stack_kernel.cc b/paddle/phi/kernels/xpu/stack_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..b908a6a080e785264a449072b479d6ec39451d1b --- /dev/null +++ b/paddle/phi/kernels/xpu/stack_kernel.cc @@ -0,0 +1,59 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/phi/kernels/stack_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void StackKernel(const Context& dev_ctx, + const std::vector& x, + int axis, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + if (axis < 0) { + axis += x[0]->dims().size() + 1; + } + dev_ctx.template Alloc(out); + auto& dim = x[0]->dims(); + std::vector xdims; + for (auto i = 0; i < dim.size(); ++i) { + xdims.push_back(dim[i]); + } + xdims.push_back(1); + std::vector> xdims_list; + int n = static_cast(x.size()); + for (int i = 0; i < n; i++) { + xdims_list.push_back(xdims); + } + + std::vector x_list; + for (int i = 0; i < n; i++) { + x_list.push_back(reinterpret_cast(x[i]->data())); + } + + int r = xpu::concat(dev_ctx.x_context(), + x_list, + reinterpret_cast(out->data()), + xdims_list, + axis); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "concat in stack op"); +} +} // namespace phi + +PD_REGISTER_KERNEL( + stack, XPU, ALL_LAYOUT, phi::StackKernel, float, int, int64_t) {}