Unverified · Commit 0c703fe7, authored by Chen Weihang, committed by GitHub

[Phi] Move gather op kernel into phi (#40500)

* add phi gather kernel

* update year

* remove original gather opkernel

* add gather grad phi kernels

* remove original gather grad kernel

* fix failed npu and xpu

* fix xpu compile failed
Parent: dde9cec0
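In summary, this change removes the per-dtype fluid kernel registrations for gather and gather_grad and re-registers the kernels once per backend in phi, adding a compat signature file so the legacy op description still maps onto the new kernel arguments. A minimal sketch of the registration shape before and after (dtype lists abbreviated; both forms are taken verbatim from the diff below):

    // Before: a fluid OpKernel class instantiated and registered per dtype.
    REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
                           ops::GatherOpKernel<double> /* , ... */);

    // After: a templated phi kernel function registered once per backend.
    PD_REGISTER_KERNEL(gather, CPU, ALL_LAYOUT, phi::GatherKernel,
                       float, double /* , ... */) {}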
paddle/fluid/operators/gather_op.cc
@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/gather_op.h"
 #include <memory>
 #include <string>
 #include <vector>

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/ddim.h"
@@ -198,17 +198,7 @@ REGISTER_OPERATOR(gather, ops::GatherOp, ops::GatherOpMaker,
                   ops::GatherGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(gather_grad, ops::GatherGradOp,
                   ops::GatherGradNoNeedBufferVarInferer);
-REGISTER_OP_CPU_KERNEL(gather, ops::GatherOpKernel<float>,
-                       ops::GatherOpKernel<double>, ops::GatherOpKernel<int>,
-                       ops::GatherOpKernel<uint8_t>,
-                       ops::GatherOpKernel<int64_t>,
-                       ops::GatherOpKernel<phi::dtype::bfloat16>);
-REGISTER_OP_CPU_KERNEL(gather_grad, ops::GatherGradientOpKernel<float>,
-                       ops::GatherGradientOpKernel<double>,
-                       ops::GatherGradientOpKernel<int>,
-                       ops::GatherGradientOpKernel<uint8_t>,
-                       ops::GatherGradientOpKernel<int64_t>,
-                       ops::GatherGradientOpKernel<phi::dtype::bfloat16>);
 REGISTER_OP_VERSION(gather)
     .AddCheckpoint(R"ROC(upgrad gather, add a new input [Axis])ROC",
                    paddle::framework::compatible::OpVersionDesc().NewInput(
......
paddle/fluid/operators/gather_op.cu (deleted by this commit)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/operators/gather_op.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
namespace paddle {
namespace operators {
template <typename T>
class GatherOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
// get axis from tensor
if (ctx.HasInput("Axis")) {
Tensor cpu_axis;
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type =
framework::TransToProtoVarType(axis_tensor->dtype());
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT16) {
axis = static_cast<int>(cpu_axis.data<int16_t>()[0]);
}
}
const auto &place = ctx.GetPlace();
const auto &index_type = framework::TransToProtoVarType(index->dtype());
const auto &dev_ctx = ctx.cuda_device_context();
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::GatherV2CUDAFunction<T, int32_t>(x, index, axis, output,
dev_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GatherV2CUDAFunction<T, int64_t>(x, index, axis, output,
dev_ctx);
} else if (index_type == framework::proto::VarType::INT16) {
phi::funcs::GatherV2CUDAFunction<T, int16_t>(x, index, axis, output,
dev_ctx);
}
return;
}
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::GPUGather<T, int>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GPUGather<T, int64_t>(dev_ctx, *x, *index, output);
} else if (index_type == framework::proto::VarType::INT16) {
phi::funcs::GPUGather<T, int16_t>(dev_ctx, *x, *index, output);
}
}
};
template <typename T>
class GatherGradOpCUDAKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet(
"This kernel only runs on GPU device."));
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
Tensor cpu_axis;
framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
const auto &axis_type =
framework::TransToProtoVarType(axis_tensor->dtype());
if (axis_type == framework::proto::VarType::INT32) {
axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
} else if (axis_type == framework::proto::VarType::INT64) {
axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
}
}
const auto &dev_ctx = ctx.cuda_device_context();
const auto &index_type = framework::TransToProtoVarType(index->dtype());
if (axis != 0) {
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::GatherV2GradCUDAFunction<T, int32_t>(dO, index, axis, dX,
dev_ctx);
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GatherV2GradCUDAFunction<T, int64_t>(dO, index, axis, dX,
dev_ctx);
}
return;
}
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto &place = *ctx.template device_context<platform::CUDADeviceContext>()
.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
if (index_type == framework::proto::VarType::INT32) {
phi::funcs::GPUScatterAssign<T, int>(dev_ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
} else if (index_type == framework::proto::VarType::INT64) {
phi::funcs::GPUScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX,
ctx.Attr<bool>("overwrite"));
}
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL(gather, ops::GatherOpCUDAKernel<float>,
ops::GatherOpCUDAKernel<double>,
ops::GatherOpCUDAKernel<int64_t>,
ops::GatherOpCUDAKernel<int>,
ops::GatherOpCUDAKernel<int16_t>,
ops::GatherOpCUDAKernel<plat::float16>,
ops::GatherOpCUDAKernel<plat::bfloat16>);
REGISTER_OP_CUDA_KERNEL(gather_grad, ops::GatherGradOpCUDAKernel<float>,
ops::GatherGradOpCUDAKernel<double>,
ops::GatherGradOpCUDAKernel<int64_t>,
ops::GatherGradOpCUDAKernel<int>,
ops::GatherGradOpCUDAKernel<plat::float16>,
ops::GatherGradOpCUDAKernel<plat::bfloat16>);
paddle/fluid/operators/gather_op.h (deleted by this commit)

/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/eigen.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/scatter.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename T>
class GatherOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *x = ctx.Input<Tensor>("X");
auto *index = ctx.Input<Tensor>("Index");
auto *output = ctx.Output<Tensor>("Out");
int axis = ctx.Attr<int>("axis");
// get axis from tensor
if (ctx.HasInput("Axis")) {
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
const auto &axis_type = axis_tensor->dtype();
if (axis_type == phi::DataType::INT32) {
axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
} else if (axis_type == phi::DataType::INT64) {
axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
}
}
const auto &index_type = index->dtype();
auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
if (axis != 0) {
if (index_type == phi::DataType::INT32) {
phi::funcs::GatherV2Function<T, int32_t>(dev_ctx, x, index, axis,
output);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2Function<T, int64_t>(dev_ctx, x, index, axis,
output);
}
return;
}
output->mutable_data<T>(ctx.GetPlace());
if (x->numel() == 0) return;
if (index_type == phi::DataType::INT32) {
phi::funcs::CPUGather<T, int>(dev_ctx, *x, *index, output);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::CPUGather<T, int64_t>(dev_ctx, *x, *index, output);
}
}
};
template <typename T>
class GatherGradientOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
platform::is_cpu_place(ctx.GetPlace()), true,
platform::errors::PreconditionNotMet("This kernel only runs on CPU."));
auto *index = ctx.Input<Tensor>("Index");
auto *dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto *dO = ctx.Input<Tensor>(framework::GradVarName("Out"));
int axis = ctx.Attr<int>("axis");
if (ctx.HasInput("Axis")) {
const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
const auto &axis_type = axis_tensor->dtype();
if (axis_type == phi::DataType::INT32) {
axis = static_cast<int>(axis_tensor->data<int32_t>()[0]);
} else if (axis_type == phi::DataType::INT64) {
axis = static_cast<int>(axis_tensor->data<int64_t>()[0]);
}
}
const auto &index_type = index->dtype();
auto &dev_ctx = ctx.template device_context<phi::CPUContext>();
if (axis != 0) {
if (index_type == phi::DataType::INT32) {
phi::funcs::GatherV2GradFunction<T, int32_t>(dev_ctx, dO, index, axis,
dX);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2GradFunction<T, int64_t>(dev_ctx, dO, index, axis,
dX);
}
return;
}
dX->mutable_data<T>(ctx.GetPlace());
auto dxt = framework::EigenVector<T>::Flatten(*dX);
auto &place = *dev_ctx.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (dO->numel() == 0) return;
bool overwrite = ctx.Attr<bool>("overwrite");
if (index_type == phi::DataType::INT32) {
if (overwrite) {
phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, *dO, *index, dX);
} else {
phi::funcs::ScatterAssignAdd<T, int32_t>(dev_ctx, *dO, *index, dX);
}
} else if (index_type == phi::DataType::INT64) {
if (overwrite) {
phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, *dO, *index, dX);
} else {
phi::funcs::ScatterAssignAdd<T, int64_t>(dev_ctx, *dO, *index, dX);
}
}
}
};
} // namespace operators
} // namespace paddle
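The CPU forward logic removed above reappears essentially unchanged as phi::GatherKernel below. For reference, its axis == 0 fast path delegates to phi::funcs::CPUGather, which is a plain row gather; a standalone sketch of that semantics (NaiveGatherRows is a hypothetical illustration, not Paddle code):

    #include <cstdint>
    #include <vector>

    // out[j, :] = x[index[j], :] for each j; x is treated as [rows, cols].
    template <typename T, typename IndexT>
    void NaiveGatherRows(const std::vector<T>& x, int64_t cols,
                         const std::vector<IndexT>& index, std::vector<T>* out) {
      out->resize(index.size() * cols);
      for (size_t j = 0; j < index.size(); ++j)
        for (int64_t c = 0; c < cols; ++c)
          (*out)[j * cols + c] = x[static_cast<int64_t>(index[j]) * cols + c];
    }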
paddle/fluid/operators/gather_op_npu.cc
@@ -12,10 +12,11 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */

-#include "paddle/fluid/operators/gather_op.h"
 #include <memory>
 #include <string>
 #include <vector>

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/platform/device/npu/npu_info.h"
 #include "paddle/fluid/platform/device/npu/npu_op_runner.h"
......
paddle/fluid/operators/gather_op_npu_test.cc
@@ -24,16 +24,15 @@ limitations under the License. */
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
-#include "paddle/fluid/operators/gather_op.h"
 #include "paddle/fluid/string/printf.h"
 #include "paddle/phi/kernels/funcs/math_function.h"

 namespace f = paddle::framework;
 namespace p = paddle::platform;

-USE_OP(gather);
+USE_OP_ITSELF(gather);
 USE_OP_DEVICE_KERNEL(gather, NPU);
-USE_OP(gather_grad);
+USE_OP_ITSELF(gather_grad);
 USE_OP_DEVICE_KERNEL(gather_grad, NPU);

 template <typename T>
......
paddle/fluid/operators/gather_op_xpu.cc
@@ -13,15 +13,18 @@ See the License for the specific language governing permissions and
 limitations under the License. */

 #ifdef PADDLE_WITH_XPU
-#include "paddle/fluid/operators/gather_op.h"
 #include <memory>
 #include <string>
 #include <vector>

+#include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/phi/core/ddim.h"

 namespace paddle {
 namespace operators {

+using Tensor = framework::Tensor;
+
 template <typename T>
 class GatherOpXPUKernel : public framework::OpKernel<T> {
   using XPUType = typename XPUTypeTrait<T>::Type;
......
paddle/phi/kernels/cpu/gather_grad_kernel.cc (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gather_grad_kernel.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/gather.h"
#include "paddle/phi/kernels/funcs/scatter.h"
namespace phi {
template <typename T, typename Context>
void GatherGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
const Scalar& axis,
bool overwrite,
DenseTensor* x_grad) {
const auto& index_type = index.dtype();
auto axis_v = axis.to<int>();
if (axis_v != 0) {
if (index_type == phi::DataType::INT32) {
phi::funcs::GatherV2GradFunction<T, int32_t>(
dev_ctx, &out_grad, &index, axis_v, x_grad);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2GradFunction<T, int64_t>(
dev_ctx, &out_grad, &index, axis_v, x_grad);
}
return;
}
dev_ctx.template Alloc<T>(x_grad);
auto dxt = EigenVector<T>::Flatten(*x_grad);
auto& place = *dev_ctx.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (x_grad->numel() == 0) return;
if (index_type == phi::DataType::INT32) {
if (overwrite) {
phi::funcs::ScatterAssign<T, int32_t>(dev_ctx, out_grad, index, x_grad);
} else {
phi::funcs::ScatterAssignAdd<T, int32_t>(
dev_ctx, out_grad, index, x_grad);
}
} else if (index_type == phi::DataType::INT64) {
if (overwrite) {
phi::funcs::ScatterAssign<T, int64_t>(dev_ctx, out_grad, index, x_grad);
} else {
phi::funcs::ScatterAssignAdd<T, int64_t>(
dev_ctx, out_grad, index, x_grad);
}
}
}
} // namespace phi
PD_REGISTER_KERNEL(gather_grad,
CPU,
ALL_LAYOUT,
phi::GatherGradKernel,
float,
double,
int,
uint8_t,
int64_t,
phi::dtype::bfloat16) {}
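As in the removed fluid kernel, the axis == 0 backward path zeroes x_grad and scatters rows of out_grad back through the index: ScatterAssign when overwrite is true, ScatterAssignAdd when it is false, so repeated indices accumulate. A standalone sketch of that semantics (NaiveGatherGradRows is a hypothetical illustration, not Paddle code):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // x_grad is zero-filled, then row j of out_grad is written to row
    // index[j]; overwrite=true assigns (last write wins), overwrite=false
    // accumulates.
    template <typename T, typename IndexT>
    void NaiveGatherGradRows(const std::vector<T>& out_grad,
                             const std::vector<IndexT>& index, int64_t cols,
                             bool overwrite, std::vector<T>* x_grad) {
      std::fill(x_grad->begin(), x_grad->end(), static_cast<T>(0));
      for (size_t j = 0; j < index.size(); ++j) {
        for (int64_t c = 0; c < cols; ++c) {
          T& slot = (*x_grad)[static_cast<int64_t>(index[j]) * cols + c];
          slot = overwrite ? out_grad[j * cols + c]
                           : slot + out_grad[j * cols + c];
        }
      }
    }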
paddle/phi/kernels/cpu/gather_kernel.cc (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gather_kernel.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather.h"
namespace phi {
template <typename T, typename Context>
void GatherKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const Scalar& axis,
DenseTensor* out) {
const auto& index_type = index.dtype();
auto axis_v = axis.to<int>();
if (axis_v != 0) {
if (index_type == phi::DataType::INT32) {
phi::funcs::GatherV2Function<T, int32_t>(
dev_ctx, &x, &index, axis_v, out);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2Function<T, int64_t>(
dev_ctx, &x, &index, axis_v, out);
}
return;
}
dev_ctx.template Alloc<T>(out);
if (x.numel() == 0) {
return;
}
if (index_type == phi::DataType::INT32) {
phi::funcs::CPUGather<T, int>(dev_ctx, x, index, out);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::CPUGather<T, int64_t>(dev_ctx, x, index, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(gather,
CPU,
ALL_LAYOUT,
phi::GatherKernel,
float,
double,
int,
uint8_t,
int64_t,
phi::dtype::bfloat16) {}
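When axis != 0 the kernel instead calls phi::funcs::GatherV2Function, which gathers along an arbitrary axis. Viewing x as [outer, axis_dim, inner], the computation amounts to the following (NaiveGatherAlongAxis is a hypothetical illustration of the semantics, not the phi implementation):

    #include <cstdint>

    // out[(o, j, i)] = x[(o, index[j], i)], with x viewed as
    // [outer, axis_dim, inner] and out as [outer, index_size, inner].
    template <typename T, typename IndexT>
    void NaiveGatherAlongAxis(const T* x, int64_t outer, int64_t axis_dim,
                              int64_t inner, const IndexT* index,
                              int64_t index_size, T* out) {
      for (int64_t o = 0; o < outer; ++o)
        for (int64_t j = 0; j < index_size; ++j)
          for (int64_t i = 0; i < inner; ++i)
            out[(o * index_size + j) * inner + i] =
                x[(o * axis_dim + index[j]) * inner + i];
    }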
paddle/phi/kernels/gather_grad_kernel.h (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GatherGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
const Scalar& axis,
bool overwrite,
DenseTensor* x_grad);
} // namespace phi
paddle/phi/kernels/gather_kernel.h (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include "paddle/phi/common/scalar.h"
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void GatherKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const Scalar& axis,
DenseTensor* out);
} // namespace phi
paddle/phi/kernels/gpu/gather_grad_kernel.cu (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gather_kernel.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/eigen/common.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
#include "paddle/phi/kernels/funcs/scatter.cu.h"
namespace phi {
template <typename T, typename Context>
void GatherGradKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const DenseTensor& out_grad,
const Scalar& axis,
bool overwrite,
DenseTensor* x_grad) {
const auto& index_type = index.dtype();
auto axis_v = axis.to<int>();
if (axis_v != 0) {
if (index_type == DataType::INT32) {
phi::funcs::GatherV2GradCUDAFunction<T, int32_t>(
&out_grad, &index, axis_v, x_grad, dev_ctx);
} else if (index_type == DataType::INT64) {
phi::funcs::GatherV2GradCUDAFunction<T, int64_t>(
&out_grad, &index, axis_v, x_grad, dev_ctx);
}
return;
}
dev_ctx.template Alloc<T>(x_grad);
auto dxt = EigenVector<T>::Flatten(*x_grad);
auto& place = *dev_ctx.eigen_device();
dxt.device(place) = dxt.constant(static_cast<T>(0));
if (out_grad.numel() == 0) return;
if (index_type == DataType::INT32) {
phi::funcs::GPUScatterAssign<T, int>(
dev_ctx, out_grad, index, x_grad, overwrite);
} else if (index_type == DataType::INT64) {
phi::funcs::GPUScatterAssign<T, int64_t>(
dev_ctx, out_grad, index, x_grad, overwrite);
}
}
} // namespace phi
PD_REGISTER_KERNEL(gather_grad,
GPU,
ALL_LAYOUT,
phi::GatherGradKernel,
float,
double,
int64_t,
int,
phi::dtype::float16,
phi::dtype::bfloat16) {}
paddle/phi/kernels/gpu/gather_kernel.cu (new file)

// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/gather_kernel.h"
#include "paddle/phi/common/bfloat16.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/gather.cu.h"
namespace phi {
template <typename T, typename Context>
void GatherKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& index,
const Scalar& axis,
DenseTensor* out) {
const auto& index_type = index.dtype();
auto axis_v = axis.to<int>();
if (axis_v != 0) {
if (index_type == phi::DataType::INT32) {
phi::funcs::GatherV2CUDAFunction<T, int32_t>(
&x, &index, axis_v, out, dev_ctx);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::GatherV2CUDAFunction<T, int64_t>(
&x, &index, axis_v, out, dev_ctx);
} else if (index_type == phi::DataType::INT16) {
phi::funcs::GatherV2CUDAFunction<T, int16_t>(
&x, &index, axis_v, out, dev_ctx);
}
return;
}
dev_ctx.template Alloc<T>(out);
if (x.numel() == 0) return;
if (index_type == phi::DataType::INT32) {
phi::funcs::GPUGather<T, int>(dev_ctx, x, index, out);
} else if (index_type == phi::DataType::INT64) {
phi::funcs::GPUGather<T, int64_t>(dev_ctx, x, index, out);
} else if (index_type == phi::DataType::INT16) {
phi::funcs::GPUGather<T, int16_t>(dev_ctx, x, index, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(gather,
GPU,
ALL_LAYOUT,
phi::GatherKernel,
float,
double,
int64_t,
int,
int16_t,
phi::dtype::float16,
phi::dtype::bfloat16) {}
paddle/phi/ops/compat/gather_sig.cc (new file)

/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature GatherOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Axis")) {
return KernelSignature("gather", {"X", "Index"}, {"Axis"}, {"Out"});
} else {
return KernelSignature("gather", {"X", "Index"}, {"axis"}, {"Out"});
}
}
KernelSignature GatherGradOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.HasInput("Axis")) {
return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")},
{"Axis", "overwrite"},
{GradVarName("X")});
} else {
return KernelSignature("gather_grad",
{"X", "Index", GradVarName("Out")},
{"axis", "overwrite"},
{GradVarName("X")});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(gather, phi::GatherOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(gather_grad, phi::GatherGradOpArgumentMapping);
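For reference, the mapping above selects the runtime Axis input when the program provides one and falls back to the axis attribute otherwise, and GradVarName expands names with Paddle's grad suffix. A sketch of the resolved signatures (comments only; the @GRAD suffix is Paddle's kGradVarSuffix):

    // Forward:  gather(X, Index, Axis|axis) -> Out
    // Backward: gather_grad(X, Index, Out@GRAD, Axis|axis, overwrite) -> X@GRAD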