From 1349584e907c49eae1a357991bd300eb49c3861b Mon Sep 17 00:00:00 2001
From: Leo Guo <58431564+ZibinGuo@users.noreply.github.com>
Date: Wed, 14 Sep 2022 14:24:34 +0800
Subject: [PATCH] Migrate scale and scatter to phi, and modify the code style
 for roi_align_kernel. test=kunlun (#45938)

---
 paddle/fluid/operators/scale_op_xpu.cc     |  70 -----------
 paddle/fluid/operators/scatter_op_xpu.cc   | 136 ---------------------
 paddle/phi/kernels/xpu/roi_align_kernel.cc |   2 +-
 paddle/phi/kernels/xpu/scale_kernel.cc     |  10 +-
 paddle/phi/kernels/xpu/scatter_kernel.cc   | 110 +++++++++++++++++
 5 files changed, 114 insertions(+), 214 deletions(-)
 delete mode 100644 paddle/fluid/operators/scale_op_xpu.cc
 delete mode 100644 paddle/fluid/operators/scatter_op_xpu.cc
 create mode 100644 paddle/phi/kernels/xpu/scatter_kernel.cc

diff --git a/paddle/fluid/operators/scale_op_xpu.cc b/paddle/fluid/operators/scale_op_xpu.cc
deleted file mode 100644
index 6d7982327a..0000000000
--- a/paddle/fluid/operators/scale_op_xpu.cc
+++ /dev/null
@@ -1,70 +0,0 @@
-/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-
-#include <string>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/phi/kernels/scale_kernel.h"
-
-namespace paddle {
-namespace operators {
-template <typename T>
-class ScaleXPUKernel : public framework::OpKernel<T> {
-  using XPUType = typename XPUTypeTrait<T>::Type;
-
- public:
-  virtual void Compute(const framework::ExecutionContext& ctx) const {
-    auto* in_var = ctx.InputVar("X");
-    auto* in = framework::GetLoDTensorOrSelectedRowsValueFromVar(*in_var);
-    auto scale = static_cast<float>(ctx.Attr<float>("scale"));
-    auto bias = static_cast<float>(ctx.Attr<float>("bias"));
-    auto bias_after_scale = ctx.Attr<bool>("bias_after_scale");
-    auto* out_var = ctx.OutputVar("Out");
-    if (in_var->IsType<phi::SelectedRows>() && in_var != out_var) {
-      auto& in_slr = in_var->Get<phi::SelectedRows>();
-      auto* out_slr = out_var->GetMutable<phi::SelectedRows>();
-      out_slr->set_rows(in_slr.rows());
-      out_slr->set_height(in_slr.height());
-    }
-    auto* out =
-        framework::GetMutableLoDTensorOrSelectedRowsValueFromVar(out_var);
-    out->mutable_data<T>(in->place());
-    auto& dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
-    // call phi kernel
-    phi::ScaleKernel<T>(
-        static_cast<const typename paddle::framework::ConvertToPhiContext<
-            platform::XPUDeviceContext>::TYPE&>(dev_ctx),
-        *in,
-        scale,
-        bias,
-        bias_after_scale,
-        out);
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_XPU_KERNEL(
-    scale,
-    ops::ScaleXPUKernel<float>,
-    ops::ScaleXPUKernel<paddle::platform::float16>,
-    ops::ScaleXPUKernel<int64_t>);
-
-#endif
diff --git a/paddle/fluid/operators/scatter_op_xpu.cc b/paddle/fluid/operators/scatter_op_xpu.cc
deleted file mode 100644
index 5933368332..0000000000
--- a/paddle/fluid/operators/scatter_op_xpu.cc
+++ /dev/null
@@ -1,136 +0,0 @@
-/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
-
-Licensed under the Apache License, Version 2.0 (the "License");
-you may not use this file except in compliance with the License.
-You may obtain a copy of the License at
-
-    http://www.apache.org/licenses/LICENSE-2.0
-
-Unless required by applicable law or agreed to in writing, software
-distributed under the License is distributed on an "AS IS" BASIS,
-WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-See the License for the specific language governing permissions and
-limitations under the License. */
-
-#ifdef PADDLE_WITH_XPU
-#include <memory>
-#include <string>
-
-#include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/operator.h"
-#include "paddle/fluid/framework/tensor.h"
-#include "paddle/fluid/platform/device_context.h"
-
-namespace paddle {
-namespace operators {
-
-using Tensor = framework::Tensor;
-
-template <typename T>
-class ScatterOpXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &ctx) const override {
-    auto *x = ctx.Input<Tensor>("X");
-    auto *index = ctx.Input<Tensor>("Ids");
-    auto *updates = ctx.Input<Tensor>("Updates");
-    auto *out = ctx.Output<Tensor>("Out");
-    bool overwrite = ctx.Attr<bool>("overwrite");
-
-    // In place output: Out = X, Out[ids] = Updates
-    framework::TensorCopy(*x, ctx.GetPlace(), out);
-    // Apply ScatterUpdate: Out[index] = Updates[:]
-    const auto &index_type = framework::TransToProtoVarType(index->dtype());
-    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
-                            index_type == framework::proto::VarType::INT64;
-    PADDLE_ENFORCE_EQ(index_type_match,
-                      true,
-                      platform::errors::InvalidArgument(
-                          "Index holds the wrong type, it holds [%s],"
-                          "but desires to be [%s] or [%s].",
-                          paddle::framework::DataTypeToString(index_type),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT32),
-                          paddle::framework::DataTypeToString(
-                              framework::proto::VarType::INT64)));
-
-    // check index of shape 1-D
-    PADDLE_ENFORCE_EQ(
-        index->dims().size() == 1 ||
-            (index->dims().size() == 2 && index->dims()[1] == 1),
-        true,
-        platform::errors::InvalidArgument(
-            "index's shape is error, "
-            "expect index'dims shape is 1 or 2 and index.dims[1] is 1"
-            "but got index'dims shape is %d",
-            index->dims().size()));
-
-    int index_size = static_cast<int>(index->dims()[0]);
-    auto x_dims = x->dims();
-    auto update_dims = updates->dims();
-    for (int i = 1; i < x_dims.size(); i++)
-      PADDLE_ENFORCE_EQ(
-          x_dims[i],
-          update_dims[i],
-          platform::errors::InvalidArgument(
-              "The dimensions of the source tensor and target tensor should"
-              " match, but received source tensor's %d-th dimension is %d,"
-              "target tensor's %d-th dimension is %d.",
-              i,
-              x_dims[i],
-              i,
-              update_dims[i]));
-
-    int dim0 = static_cast<int>(x->dims()[0]);
-    int dim1 = static_cast<int>(
-        phi::product(phi::slice_ddim(x_dims, 1, x_dims.size())));
-    T *out_data = out->data<T>();
-    const T *updates_data = updates->data<T>();
-
-    auto &dev_ctx =
-        ctx.template device_context<paddle::platform::XPUDeviceContext>();
-    int r = XPU_SUCCESS;
-
-    Tensor indices_cpu(index->type());
-    framework::TensorCopy(*index, platform::CPUPlace(), &indices_cpu);
-
-    if (index_type == framework::proto::VarType::INT32) {
-      auto index_data = const_cast<int *>(index->data<int>());
-      xpu::VectorParam<int> indices{
-          indices_cpu.data<int>(), index_size, index_data};
-      r = xpu::scatter(dev_ctx.x_context(),
-                       updates_data,
-                       out_data,
-                       indices,
-                       dim0,
-                       dim1,
-                       overwrite);
-    } else {
-      auto index_data = const_cast<int64_t *>(index->data<int64_t>());
-      xpu::VectorParam<int64_t> indices{
-          indices_cpu.data<int64_t>(), index_size, index_data};
-      r = xpu::scatter(dev_ctx.x_context(),
-                       updates_data,
-                       out_data,
-                       indices,
-                       dim0,
-                       dim1,
-                       overwrite);
-    }
-    PADDLE_ENFORCE_EQ(r,
-                      XPU_SUCCESS,
-                      platform::errors::External(
-                          "XPU scatter kernel return wrong value[%d %s]",
-                          r,
-                          XPUAPIErrorMsg[r]));
-  }
-};
-
-}  // namespace operators
-}  // namespace paddle
-
-namespace ops = paddle::operators;
-
-REGISTER_OP_XPU_KERNEL(scatter,
-                       ops::ScatterOpXPUKernel<float>,
-                       ops::ScatterOpXPUKernel<int64_t>);
-#endif
diff --git a/paddle/phi/kernels/xpu/roi_align_kernel.cc b/paddle/phi/kernels/xpu/roi_align_kernel.cc
index dacb676693..895a235da5 100644
--- a/paddle/phi/kernels/xpu/roi_align_kernel.cc
+++ b/paddle/phi/kernels/xpu/roi_align_kernel.cc
@@ -137,7 +137,7 @@ void RoiAlignKernel(const Context& dev_ctx,
                           sampling_ratio,
                           true,
                           aligned);
-    PADDLE_ENFORCE_XDNN_SUCCESS(r, "roi_align_grad");
+    PADDLE_ENFORCE_XDNN_SUCCESS(r, "roi_align");
     if (dev_ctx.x_context()->xpu_stream) {
       dev_ctx.Wait();
     }
diff --git a/paddle/phi/kernels/xpu/scale_kernel.cc b/paddle/phi/kernels/xpu/scale_kernel.cc
index b5a07a7a14..a478dfddf1 100644
--- a/paddle/phi/kernels/xpu/scale_kernel.cc
+++ b/paddle/phi/kernels/xpu/scale_kernel.cc
@@ -14,7 +14,7 @@
 
 #include "paddle/phi/kernels/scale_kernel.h"
 
-#include "paddle/fluid/platform/device/xpu/xpu_header.h"
+#include "paddle/phi/backends/xpu/enforce_xpu.h"
 #include "paddle/phi/backends/xpu/xpu_context.h"
 #include "paddle/phi/common/data_type.h"
 #include "paddle/phi/common/float16.h"
@@ -30,7 +30,7 @@ void ScaleKernel(const Context& dev_ctx,
                  float bias,
                  bool bias_after_scale,
                  DenseTensor* out) {
-  out->mutable_data<T>(dev_ctx.GetPlace());
+  dev_ctx.template Alloc<T>(out);
 
   PADDLE_ENFORCE_EQ(
       x.dims(),
@@ -47,11 +47,7 @@ void ScaleKernel(const Context& dev_ctx,
                      bias_after_scale,
                      scale.to<float>(),
                      bias);
-  PADDLE_ENFORCE_EQ(
-      r,
-      XPU_SUCCESS,
-      phi::errors::External(
-          "XPU scale kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
+  PADDLE_ENFORCE_XDNN_SUCCESS(r, "scale");
 }
 
 }  // namespace phi
diff --git a/paddle/phi/kernels/xpu/scatter_kernel.cc b/paddle/phi/kernels/xpu/scatter_kernel.cc
new file mode 100644
index 0000000000..21a4c638a8
--- /dev/null
+++ b/paddle/phi/kernels/xpu/scatter_kernel.cc
@@ -0,0 +1,110 @@
+// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/phi/kernels/scatter_kernel.h" + +#include "paddle/phi/backends/xpu/enforce_xpu.h" +#include "paddle/phi/backends/xpu/xpu_context.h" +#include "paddle/phi/core/kernel_registry.h" + +namespace phi { + +template +void ScatterKernel(const Context &ctx, + const DenseTensor &x, + const DenseTensor &index, + const DenseTensor &updates, + bool overwrite, + DenseTensor *out) { + phi::Copy(ctx, x, ctx.GetPlace(), false, out); + // Apply ScatterUpdate: Out[index] = Updates[:] + const auto &index_type = index.dtype(); + bool index_type_match = + index_type == phi::DataType::INT32 || index_type == phi::DataType::INT64; + PADDLE_ENFORCE_EQ( + index_type_match, + true, + phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s]," + "but desires to be [%s] or [%s].", + index_type, + phi::DataType::INT32, + phi::DataType::INT64)); + + // check index of shape 1-D + PADDLE_ENFORCE_EQ( + index.dims().size() == 1 || + (index.dims().size() == 2 && index.dims()[1] == 1), + true, + phi::errors::InvalidArgument( + "index's shape is error, " + "expect index'dims shape is 1 or 2 and index.dims[1] is 1" + "but got index'dims shape is %d", + index.dims().size())); + + int index_size = static_cast(index.dims()[0]); + auto x_dims = x.dims(); + auto update_dims = updates.dims(); + for (int i = 1; i < x_dims.size(); i++) + PADDLE_ENFORCE_EQ( + x_dims[i], + update_dims[i], + phi::errors::InvalidArgument( + "The dimensions of the source tensor and target tensor should" + " match, but received source tensor's %d-th dimension is %d," + "target tensor's %d-th dimension is %d.", + i, + x_dims[i], + i, + update_dims[i])); + + int dim0 = static_cast(x.dims()[0]); + int dim1 = + static_cast(phi::product(phi::slice_ddim(x_dims, 1, x_dims.size()))); + T *out_data = out->data(); + const T *updates_data = updates.data(); + + DenseTensor indices_cpu(index.type()); + phi::Copy(ctx, index, phi::CPUPlace(), false, &indices_cpu); + + int r = 0; + if (index_type == phi::DataType::INT32) { + auto index_data = const_cast(index.data()); + xpu::VectorParam indices{ + indices_cpu.data(), index_size, index_data}; + r = xpu::scatter(ctx.x_context(), + updates_data, + out_data, + indices, + dim0, + dim1, + overwrite); + } else { + auto index_data = const_cast(index.data()); + xpu::VectorParam indices{ + indices_cpu.data(), index_size, index_data}; + r = xpu::scatter(ctx.x_context(), + updates_data, + out_data, + indices, + dim0, + dim1, + overwrite); + } + PADDLE_ENFORCE_XDNN_SUCCESS(r, "scatter"); +} + +} // namespace phi + +PD_REGISTER_KERNEL( + scatter, XPU, ALL_LAYOUT, phi::ScatterKernel, float, int, int64_t) {} -- GitLab