diff --git a/paddle/fluid/operators/concat_op_xpu.cc b/paddle/fluid/operators/concat_op_xpu.cc
deleted file mode 100644
index c4e84dcbb5c104ba6f745002562ad62556c2293b..0000000000000000000000000000000000000000
--- a/paddle/fluid/operators/concat_op_xpu.cc
+++ /dev/null
@@ -1,235 +0,0 @@
-/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-#ifdef PADDLE_WITH_XPU
-#include <memory>
-#include <string>
-#include <vector>
-
-#include "paddle/fluid/operators/concat_op.h"
-#include "paddle/fluid/platform/device/xpu/xpu_header.h"
-#include "paddle/phi/core/lod_utils.h"
-
-namespace paddle {
-namespace operators {
-using Tensor = framework::Tensor;
-
-template <typename DeviceContext, typename T>
-class ConcatXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
-    framework::LoDTensor* out = ctx.Output<framework::LoDTensor>("Out");
-    int axis = ctx.Attr<int>("axis");
-    PADDLE_ENFORCE_NE(
-        ins[0],
-        nullptr,
-        platform::errors::InvalidArgument("The input should not be null."));
-    PADDLE_ENFORCE_NE(ctx.HasInput("AxisTensor"),
-                      true,
-                      platform::errors::InvalidArgument(
-                          "XPU donot surpport AxisTensor for now"));
-    axis = ComputeAxis(static_cast<int64_t>(axis),
-                       static_cast<int64_t>(ins[0]->dims().size()));
-    PADDLE_ENFORCE_GE(axis,
-                      0,
-                      platform::errors::InvalidArgument(
-                          "concat: axis should be larger than or "
-                          "equal to 0, but received axis is %d.",
-                          axis));
-    PADDLE_ENFORCE_LT(axis,
-                      ins[0]->dims().size(),
-                      platform::errors::InvalidArgument(
-                          "concat: axis should be less than ins[0]->dims()!"
-                          "But received axis is %d, while ins[0]->dims()"
-                          "size is %d.",
-                          axis,
-                          ins[0]->dims().size()));
-
-    // If axis is 0, the lod of the output is not the same as inputs.
-    if (axis == 0 && ins[0]->lod().size() > 0) {
-      size_t lod_size_0 = ins[0]->lod().size();
-      size_t lod_size = lod_size_0;
-      for (size_t i = 1; i < ins.size(); ++i) {
-        if (ins[i]->lod().size() > 0) {
-          PADDLE_ENFORCE_EQ(
-              ins[i]->lod().size(),
-              lod_size_0,
-              platform::errors::Unimplemented(
-                  "The lod level of all input LoDTensors should be same. "
-                  "Maybe different lod level of input LoDTensors can concat,"
-                  "it is not supported currently. The lod level of %dth input "
-                  "is %d and first input is %d.",
-                  i,
-                  ins[i]->lod().size(),
-                  lod_size_0));
-        } else {
-          lod_size = 0;
-          break;
-        }
-      }
-      if (lod_size) {
-        auto* out_lod = out->mutable_lod();
-        for (size_t i = 1; i < ins.size(); ++i) {
-          auto in_lod = phi::ConvertToLengthBasedLoD(ins[i]->lod());
-          phi::AppendLoD(out_lod, in_lod);
-        }
-      }
-    }
-    auto place = ctx.GetPlace();
-    out->mutable_data<T>(place);
-    std::vector<std::vector<int>> xdims_list;
-    std::vector<const XPUType*> ptrs;
-    for (unsigned int i = 0; i < ins.size(); ++i) {
-      if (ins[i] && ins[i]->numel() > 0) {
-        ptrs.push_back(reinterpret_cast<const XPUType*>(ins[i]->data<T>()));
-        int size = ins[i]->dims().size();
-        std::vector<int> tmp_dims(size);
-        for (int j = 0; j < size; ++j) {
-          tmp_dims[j] = ins[i]->dims()[j];
-        }
-        xdims_list.push_back(tmp_dims);
-      }
-    }
-
-    PADDLE_ENFORCE_GT(
-        xdims_list.size(),
-        0,
-        platform::errors::InvalidArgument("No tensor need concat"));
-    auto& dev_ctx = ctx.template device_context<DeviceContext>();
-
-    int r = xpu::concat<XPUType>(dev_ctx.x_context(),
-                                 ptrs,
-                                 reinterpret_cast<XPUType*>(out->data<T>()),
-                                 xdims_list,
-                                 axis);
-    PADDLE_ENFORCE_EQ(r,
-                      XPU_SUCCESS,
-                      platform::errors::External(
-                          "XPU concat kernel return wrong value[%d %s]",
-                          r,
-                          XPUAPIErrorMsg[r]));
-  }
-};
-
-template <typename DeviceContext, typename T>
-class ConcatGradXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  void Compute(const framework::ExecutionContext& ctx) const {
-    auto* out_grad =
-        ctx.Input<framework::LoDTensor>(framework::GradVarName("Out"));
-    auto ins = ctx.MultiInput<framework::LoDTensor>("X");
-    auto out_var_names = ctx.OutputNames(framework::GradVarName("X"));
-    auto outs =
-        ctx.MultiOutput<framework::LoDTensor>(framework::GradVarName("X"));
-    {
-      auto dx = outs;
-      auto x = ins;
-      for (size_t i = 0; i < dx.size(); ++i) {
-        if (dx[i] != nullptr) {
-          dx[i]->set_lod(x[i]->lod());
-        }
-      }
-    }
-    PADDLE_ENFORCE_NE(
-        ins[0],
-        nullptr,
-        platform::errors::InvalidArgument("The input should not be null."));
-    auto axis = ctx.Attr<int>("axis");
-    if (ctx.HasInput("AxisTensor")) {
-      auto* axis_tensor = ctx.Input<framework::Tensor>("AxisTensor");
-      axis = GetDataFromTensor<int>(axis_tensor)[0];
-    }
-    axis = ComputeAxis(static_cast<int64_t>(axis),
-                       static_cast<int64_t>(ins[0]->dims().size()));
-    // get output tensor that the name is not kEmptyVarName
-    std::vector<XPUType*> ptrs(outs.size());
-    for (size_t j = 0; j < outs.size(); ++j) {
-      if (out_var_names[j] != framework::kEmptyVarName &&
-          outs[j]->numel() != 0UL) {
-        outs[j]->mutable_data<T>(ctx.GetPlace());
-        ptrs[j] = reinterpret_cast<XPUType*>(outs[j]->data<T>());
-      } else {
-        ptrs[j] = nullptr;
-      }
-    }
-    PADDLE_ENFORCE_GE(axis,
-                      0,
-                      platform::errors::InvalidArgument(
-                          "concat_grad: axis should be larger than or "
-                          "equal to 0, but received axis is %d.",
-                          axis));
-    PADDLE_ENFORCE_LT(
-        axis,
-        out_grad->dims().size(),
-        platform::errors::InvalidArgument(
-            "concat_grad: axis should be less than ins[0]->dims()!"
- "But received axis is %d, while ins[0]->dims()" - "size is %d.", - axis, - out_grad->dims().size())); - - auto input_dims = ins[0]->dims(); - std::vector split_list(ins.size()); - std::vector xdims_list(input_dims.size()); - int total_length = 0; - for (size_t i = 0; i < ins.size(); ++i) { - split_list[i] = ins[i]->dims()[axis]; - total_length += ins[i]->dims()[axis]; - } - for (int i = 0; i < input_dims.size(); ++i) { - if (i == axis) { - continue; - } - xdims_list[i] = input_dims[i]; - } - xdims_list[axis] = total_length; - - auto& dev_ctx = ctx.template device_context(); - int r = xpu::split( - dev_ctx.x_context(), - reinterpret_cast(out_grad->data()), - ptrs, - xdims_list, - split_list, - axis); - PADDLE_ENFORCE_EQ( - r, - XPU_SUCCESS, - platform::errors::External( - "XPU API return wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - r)); - } -}; - -} // namespace operators -} // namespace paddle - -namespace ops = paddle::operators; -REGISTER_OP_XPU_KERNEL( - concat, - ops::ConcatXPUKernel, - ops::ConcatXPUKernel); -REGISTER_OP_XPU_KERNEL( - concat_grad, - ops::ConcatGradXPUKernel, - ops::ConcatGradXPUKernel); - -#endif diff --git a/paddle/phi/kernels/xpu/concat_grad_kernel.cc b/paddle/phi/kernels/xpu/concat_grad_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..9f2053ddae8545cac9ba4604346f85ebd2f85753 --- /dev/null +++ b/paddle/phi/kernels/xpu/concat_grad_kernel.cc @@ -0,0 +1,105 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/concat_grad_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/concat_funcs.h" + +namespace phi { + +template +void ConcatGradKernel(const Context& dev_ctx, + const std::vector& x, + const DenseTensor& out_grad, + const Scalar& axis_scalar, + std::vector x_grad) { + using XPUType = typename XPUTypeTrait::Type; + auto outs = x_grad; + { + auto dx = outs; + for (size_t i = 0; i < dx.size(); ++i) { + if (dx[i] != nullptr) { + dx[i]->set_lod(x[i]->lod()); + } + } + } + PADDLE_ENFORCE_NE( + x[0], + nullptr, + phi::errors::InvalidArgument("The input should not be null.")); + auto axis = axis_scalar.to(); + axis = phi::funcs::ComputeAxis(static_cast(axis), + static_cast(x[0]->dims().size())); + // get output tensor that the name is not kEmptyVarName + std::vector ptrs(outs.size()); + for (size_t j = 0; j < outs.size(); ++j) { + if (outs[j] && outs[j]->numel() != 0UL) { + dev_ctx.template Alloc(outs[j]); + ptrs[j] = reinterpret_cast(outs[j]->data()); + } else { + ptrs[j] = nullptr; + } + } + PADDLE_ENFORCE_GE( + axis, + 0, + phi::errors::InvalidArgument("concat_grad: axis should be larger than or " + "equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT(axis, + out_grad.dims().size(), + phi::errors::InvalidArgument( + "concat_grad: axis should be less than x[0]->dims()!" + "But received axis is %d, while x[0]->dims()" + "size is %d.", + axis, + out_grad.dims().size())); + + auto input_dims = x[0]->dims(); + std::vector split_list(x.size()); + std::vector xdims_list(input_dims.size()); + int total_length = 0; + for (size_t i = 0; i < x.size(); ++i) { + split_list[i] = x[i]->dims()[axis]; + total_length += x[i]->dims()[axis]; + } + for (int i = 0; i < input_dims.size(); ++i) { + if (i == axis) { + continue; + } + xdims_list[i] = input_dims[i]; + } + xdims_list[axis] = total_length; + + int r = + xpu::split(dev_ctx.x_context(), + reinterpret_cast(out_grad.data()), + ptrs, + xdims_list, + split_list, + axis); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "concat_grad"); +} + +} // namespace phi + +PD_REGISTER_KERNEL(concat_grad, + XPU, + ALL_LAYOUT, + phi::ConcatGradKernel, + float, + phi::dtype::float16) {} diff --git a/paddle/phi/kernels/xpu/concat_kernel.cc b/paddle/phi/kernels/xpu/concat_kernel.cc new file mode 100644 index 0000000000000000000000000000000000000000..50b323429b0672b6d70853615962fedc587d9f08 --- /dev/null +++ b/paddle/phi/kernels/xpu/concat_kernel.cc @@ -0,0 +1,111 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/phi/kernels/concat_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/core/kernel_registry.h" +#include "paddle/phi/core/lod_utils.h" +#include "paddle/phi/kernels/funcs/axis_utils.h" +#include "paddle/phi/kernels/funcs/concat_funcs.h" + +namespace phi { + +template +void ConcatKernel(const Context& dev_ctx, + const std::vector& x, + const Scalar& axis_scalar, + DenseTensor* out) { + using XPUType = typename XPUTypeTrait::Type; + int64_t axis = axis_scalar.to(); + PADDLE_ENFORCE_NE( + x[0], + nullptr, + phi::errors::InvalidArgument("The input should not be null.")); + axis = phi::funcs::ComputeAxis(axis, x[0]->dims().size()); + PADDLE_ENFORCE_GE( + axis, + 0, + phi::errors::InvalidArgument("concat: axis should be larger than or " + "equal to 0, but received axis is %d.", + axis)); + PADDLE_ENFORCE_LT(axis, + x[0]->dims().size(), + phi::errors::InvalidArgument( + "concat: axis should be less than x[0]->dims()!" + "But received axis is %d, while x[0]->dims()" + "size is %d.", + axis, + x[0]->dims().size())); + + // If axis is 0, the lod of the output is not the same as inputs. + if (axis == 0 && x[0]->lod().size() > 0) { + size_t lod_size_0 = x[0]->lod().size(); + size_t lod_size = lod_size_0; + for (size_t i = 1; i < x.size(); ++i) { + if (x[i]->lod().size() > 0) { + PADDLE_ENFORCE_EQ( + x[i]->lod().size(), + lod_size_0, + phi::errors::Unimplemented( + "The lod level of all input LoDTensors should be same. " + "Maybe different lod level of input LoDTensors can concat," + "it is not supported currently. The lod level of %dth input " + "is %d and first input is %d.", + i, + x[i]->lod().size(), + lod_size_0)); + } else { + lod_size = 0; + break; + } + } + if (lod_size) { + auto* out_lod = out->mutable_lod(); + for (size_t i = 1; i < x.size(); ++i) { + auto in_lod = phi::ConvertToLengthBasedLoD(x[i]->lod()); + phi::AppendLoD(out_lod, in_lod); + } + } + } + dev_ctx.template Alloc(out); + std::vector> xdims_list; + std::vector ptrs; + for (unsigned int i = 0; i < x.size(); ++i) { + if (x[i] && x[i]->numel() > 0) { + ptrs.push_back(reinterpret_cast(x[i]->data())); + int size = x[i]->dims().size(); + std::vector tmp_dims(size); + for (int j = 0; j < size; ++j) { + tmp_dims[j] = x[i]->dims()[j]; + } + xdims_list.push_back(tmp_dims); + } + } + + PADDLE_ENFORCE_GT(xdims_list.size(), + 0, + phi::errors::InvalidArgument("No tensor need concat")); + int r = xpu::concat(dev_ctx.x_context(), + ptrs, + reinterpret_cast(out->data()), + xdims_list, + axis); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "concat"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + concat, XPU, ALL_LAYOUT, phi::ConcatKernel, float, phi::dtype::float16) {}