Unverified commit 60e1eccb authored by wanghuancoder, committed by GitHub

[Phi] gather gather_grad gather_nd gaussian_random xpu to Phi (#45465)

* gather gather_grad gather_nd gaussian_random xpu to phi
Parent: ca5567e1
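
For orientation, the sketch below summarizes the migration pattern applied to each of the four operators. The old fluid kernels (the *_op_xpu.cc files shown first) fetch inputs from an ExecutionContext by name and are registered with REGISTER_OP_XPU_KERNEL; the new Phi kernels (the files under paddle/phi/kernels/xpu/, shown afterwards; paths inferred from the code, not from the diff text) are plain function templates with explicit, typed arguments, registered with PD_REGISTER_KERNEL. The signatures are copied from the gather kernels in this diff; the bodies are elided.

// Old style (fluid): inputs looked up on the ExecutionContext by name,
// registered per device with REGISTER_OP_XPU_KERNEL.
template <typename T>
class GatherOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<Tensor>("X");
    auto *index = ctx.Input<Tensor>("Index");
    auto *out = ctx.Output<Tensor>("Out");
    // ... call xpu::gather(...) as in the file below ...
  }
};
REGISTER_OP_XPU_KERNEL(gather, ops::GatherOpXPUKernel<float>);

// New style (Phi): a free function template with explicit, typed arguments,
// registered with PD_REGISTER_KERNEL; the kernel body stays essentially the same.
template <typename T, typename Context>
void GatherKernel(const Context &dev_ctx,
                  const DenseTensor &x,
                  const DenseTensor &index,
                  const Scalar &axis,
                  DenseTensor *out);
PD_REGISTER_KERNEL(
    gather, XPU, ALL_LAYOUT, phi::GatherKernel, float, phi::dtype::float16) {}
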
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/device_context.h"

namespace paddle {
namespace operators {

template <typename T>
class GatherNdXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    auto *x = ctx.Input<framework::Tensor>("X");
    auto *index = ctx.Input<framework::Tensor>("Index");
    auto *out = ctx.Output<framework::Tensor>("Out");

    out->template mutable_data<T>(ctx.GetPlace());
    if (x->numel() == 0) return;

    if (index->numel() == 0) {
      framework::TensorCopy(*x, ctx.GetPlace(), ctx.device_context(), out);
      return;
    }

    const auto &index_type = framework::TransToProtoVarType(index->dtype());
    bool index_type_match = index_type == framework::proto::VarType::INT32 ||
                            index_type == framework::proto::VarType::INT64;
    PADDLE_ENFORCE_EQ(index_type_match,
                      true,
                      platform::errors::InvalidArgument(
                          "Index holds the wrong type, it holds [%s],"
                          "but desires to be [%s] or [%s]",
                          paddle::framework::DataTypeToString(index_type),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT32),
                          paddle::framework::DataTypeToString(
                              framework::proto::VarType::INT64)));

    auto x_shape = phi::vectorize<int>(x->dims());
    auto index_shape = phi::vectorize<int>(index->dims());
    if (index_shape.size() == 1) {
      index_shape.insert(index_shape.begin(), 1);
    }
    xpu::VectorParam<int> x_vec = {
        x_shape.data(), static_cast<int>(x_shape.size()), nullptr};

    auto &dev_ctx =
        ctx.template device_context<paddle::platform::XPUDeviceContext>();
    int ret = XPU_SUCCESS;
    if (index_type == framework::proto::VarType::INT32) {
      ret = xpu::gather_nd<T, int>(dev_ctx.x_context(),
                                   x->data<T>(),
                                   index->data<int>(),
                                   out->data<T>(),
                                   x_vec,
                                   index_shape);
    } else {
      ret = xpu::gather_nd<T, int64_t>(dev_ctx.x_context(),
                                       x->data<T>(),
                                       index->data<int64_t>(),
                                       out->data<T>(),
                                       x_vec,
                                       index_shape);
    }
    PADDLE_ENFORCE_EQ(ret,
                      XPU_SUCCESS,
                      platform::errors::External(
                          "XPU gather_nd kernel return wrong value[%d %s]",
                          ret,
                          XPUAPIErrorMsg[ret]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(gather_nd,
                       ops::GatherNdXPUKernel<int>,
                       ops::GatherNdXPUKernel<int64_t>,
                       ops::GatherNdXPUKernel<float>);
#endif
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include <vector>

#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/ddim.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename T>
class GatherOpXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_xpu_place(ctx.GetPlace()),
        true,
        platform::errors::PreconditionNotMet("This kernel only runs on XPU."));

    auto *x = ctx.Input<Tensor>("X");
    auto *index = ctx.Input<Tensor>("Index");
    auto *output = ctx.Output<Tensor>("Out");

    int axis = ctx.Attr<int>("axis");
    if (ctx.HasInput("Axis")) {
      Tensor cpu_axis;
      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
      const auto &axis_type = axis_tensor->dtype();
      if (framework::TransToProtoVarType(axis_type) ==
          framework::proto::VarType::INT32) {
        axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
      } else if (framework::TransToProtoVarType(axis_type) ==
                 framework::proto::VarType::INT64) {
        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
      }
    }

    output->mutable_data<T>(ctx.GetPlace());
    if (x->numel() == 0) return;

    const auto index_dims = index->dims();
    if (index_dims.size() == 2) {
      PADDLE_ENFORCE_EQ(
          index_dims[1],
          1,
          platform::errors::InvalidArgument(
              "The last dim of index should be 1 when it is 2D, but we get %d",
              index_dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          index_dims.size(),
          1,
          platform::errors::InvalidArgument(
              "The index should be 1D, when it is not 2D, but we get %d",
              index_dims.size()));
    }
    std::vector<int> xshape(x->dims().size());
    for (int i = 0; i < x->dims().size(); ++i) {
      xshape[i] = x->dims()[i];
    }

    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
    int r = XPU_SUCCESS;
    if (framework::TransToProtoVarType(index->dtype()) ==
        framework::proto::VarType::INT32) {
      r = xpu::gather<XPUType, int>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUType *>(x->data<T>()),
          index->data<int>(),
          reinterpret_cast<XPUType *>(output->data<T>()),
          xshape,
          index->dims()[0],
          axis);
    } else {
      r = xpu::gather<XPUType, int64_t>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUType *>(x->data<T>()),
          index->data<int64_t>(),
          reinterpret_cast<XPUType *>(output->data<T>()),
          xshape,
          index->dims()[0],
          axis);
    }
    PADDLE_ENFORCE_EQ(r,
                      xpu::Error_t::SUCCESS,
                      platform::errors::External(
                          "XPU gather kernel return wrong value[%d %s]",
                          r,
                          XPUAPIErrorMsg[r]));
  }
};

template <typename T>
class GatherGradOpXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    PADDLE_ENFORCE_EQ(
        platform::is_xpu_place(ctx.GetPlace()),
        true,
        platform::errors::PreconditionNotMet("This kernel only runs on XPU."));

    auto *index = ctx.Input<Tensor>("Index");
    auto *dx = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
    auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();

    int axis = ctx.Attr<int>("axis");
    if (ctx.HasInput("Axis")) {
      Tensor cpu_axis;
      const Tensor *axis_tensor = ctx.Input<Tensor>("Axis");
      framework::TensorCopy(*axis_tensor, platform::CPUPlace(), &cpu_axis);
      const auto &axis_type = axis_tensor->dtype();
      if (framework::TransToProtoVarType(axis_type) ==
          framework::proto::VarType::INT32) {
        axis = static_cast<int>(cpu_axis.data<int32_t>()[0]);
      } else if (framework::TransToProtoVarType(axis_type) ==
                 framework::proto::VarType::INT64) {
        axis = static_cast<int>(cpu_axis.data<int64_t>()[0]);
      }
    }

    if (dout->numel() == 0) {
      return;
    }

    bool overwrite = ctx.Attr<bool>("overwrite");
    const auto index_dims = index->dims();
    if (index_dims.size() == 2) {
      PADDLE_ENFORCE_EQ(
          index_dims[1],
          1,
          platform::errors::InvalidArgument(
              "The last dim of index should be 1 when it is 2D, but we get %d",
              index_dims[1]));
    } else {
      PADDLE_ENFORCE_EQ(
          index_dims.size(),
          1,
          platform::errors::InvalidArgument(
              "The index should be 1D, when it is not 2D, but we get %d",
              index_dims.size()));
    }
    std::vector<int> xshape(dx->dims().size());
    for (int i = 0; i < dx->dims().size(); ++i) {
      xshape[i] = dx->dims()[i];
    }

    dx->mutable_data<T>(ctx.GetPlace());

    int r = XPU_SUCCESS;
    if (framework::TransToProtoVarType(index->dtype()) ==
        framework::proto::VarType::INT32) {
      r = xpu::gather_grad<XPUType, int>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUType *>(dout->data<T>()),
          index->data<int>(),
          reinterpret_cast<XPUType *>(dx->data<T>()),
          xshape,
          index->dims()[0],
          axis,
          overwrite);
    } else {
      xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
      int *index_int_ptr_l3 =
          RAII_GUARD.alloc_l3_or_gm<int32_t>(index->numel());
      r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
                                         index->data<int64_t>(),
                                         index_int_ptr_l3,
                                         index->numel());
      PADDLE_ENFORCE_EQ(
          r,
          XPU_SUCCESS,
          platform::errors::External("XPU API(cast_v2) return wrong "
                                     "value[%d %s]",
                                     r,
                                     XPUAPIErrorMsg[r]));
      r = xpu::gather_grad<XPUType, int>(
          dev_ctx.x_context(),
          reinterpret_cast<const XPUType *>(dout->data<T>()),
          index_int_ptr_l3,
          reinterpret_cast<XPUType *>(dx->data<T>()),
          xshape,
          index->dims()[0],
          axis,
          overwrite);
    }
    PADDLE_ENFORCE_EQ(r,
                      xpu::Error_t::SUCCESS,
                      platform::errors::External(
                          "XPU gather grad kernel return wrong value[%d %s]",
                          r,
                          XPUAPIErrorMsg[r]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(gather,
                       ops::GatherOpXPUKernel<float>,
                       ops::GatherOpXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(gather_grad,
                       ops::GatherGradOpXPUKernel<float>,
                       ops::GatherGradOpXPUKernel<paddle::platform::float16>);
#endif
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */

#ifdef PADDLE_WITH_XPU
#include <random>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

template <typename T>
class XPUGaussianRandomKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    float mean = context.Attr<float>("mean");
    float std = context.Attr<float>("std");
    auto* tensor = context.Output<framework::Tensor>("Out");

    std::normal_distribution<T> dist(mean, std);
    int64_t size = tensor->numel();
    T* data = tensor->mutable_data<T>(context.GetPlace());
    unsigned int seed = static_cast<unsigned int>(context.Attr<int>("seed"));
    // TODO(pangyoki): implement GetXPURandomEngine to set different seeds on
    // corresponding XPU device.
    auto engine = framework::GetCPURandomEngine(seed);

    std::unique_ptr<T[]> data_cpu(new T[size]);
    for (int64_t i = 0; i < size; ++i) {
      data_cpu[i] = dist(*engine);
    }
    memory::Copy(context.GetPlace(),
                 data,
                 platform::CPUPlace(),
                 reinterpret_cast<void*>(data_cpu.get()),
                 size * sizeof(T));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(gaussian_random, ops::XPUGaussianRandomKernel<float>);
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gather_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void GatherGradKernel(const Context& dev_ctx,
                      const DenseTensor& x,
                      const DenseTensor& index,
                      const DenseTensor& out_grad,
                      const Scalar& axis,
                      bool overwrite,
                      DenseTensor* x_grad) {
  auto axis_v = axis.to<int>();
  const auto& index_type = index.dtype();

  if (out_grad.numel() == 0) {
    return;
  }

  const auto index_dims = index.dims();
  if (index_dims.size() == 2) {
    PADDLE_ENFORCE_EQ(
        index_dims[1],
        1,
        phi::errors::InvalidArgument(
            "The last dim of index should be 1 when it is 2D, but we get %d",
            index_dims[1]));
  } else {
    PADDLE_ENFORCE_EQ(
        index_dims.size(),
        1,
        phi::errors::InvalidArgument(
            "The index should be 1D, when it is not 2D, but we get %d",
            index_dims.size()));
  }
  std::vector<int> xshape(x_grad->dims().size());
  for (int i = 0; i < x_grad->dims().size(); ++i) {
    xshape[i] = x_grad->dims()[i];
  }

  dev_ctx.template Alloc<T>(x_grad);

  using XPUType = typename XPUTypeTrait<T>::Type;

  int r = XPU_SUCCESS;
  if (index_type == DataType::INT32) {
    r = xpu::gather_grad<XPUType, int>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType*>(out_grad.data<T>()),
        index.data<int>(),
        reinterpret_cast<XPUType*>(x_grad->data<T>()),
        xshape,
        index.dims()[0],
        axis_v,
        overwrite);
  } else {
    xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
    int* index_int_ptr_l3 = RAII_GUARD.alloc_l3_or_gm<int32_t>(index.numel());
    r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
                                       index.data<int64_t>(),
                                       index_int_ptr_l3,
                                       index.numel());
    PADDLE_ENFORCE_EQ(r,
                      XPU_SUCCESS,
                      phi::errors::External("XPU API(cast_v2) return wrong "
                                            "value[%d %s]",
                                            r,
                                            XPUAPIErrorMsg[r]));
    r = xpu::gather_grad<XPUType, int>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType*>(out_grad.data<T>()),
        index_int_ptr_l3,
        reinterpret_cast<XPUType*>(x_grad->data<T>()),
        xshape,
        index.dims()[0],
        axis_v,
        overwrite);
  }
  PADDLE_ENFORCE_EQ(
      r,
      xpu::Error_t::SUCCESS,
      phi::errors::External("XPU gather grad kernel return wrong value[%d %s]",
                            r,
                            XPUAPIErrorMsg[r]));
}

}  // namespace phi

PD_REGISTER_KERNEL(gather_grad,
                   XPU,
                   ALL_LAYOUT,
                   phi::GatherGradKernel,
                   float,
                   phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gather_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void GatherKernel(const Context& dev_ctx,
                  const DenseTensor& x,
                  const DenseTensor& index,
                  const Scalar& axis,
                  DenseTensor* out) {
  auto axis_v = axis.to<int>();
  const auto& index_type = index.dtype();

  dev_ctx.template Alloc<T>(out);
  if (x.numel() == 0) return;

  const auto index_dims = index.dims();
  if (index_dims.size() == 2) {
    PADDLE_ENFORCE_EQ(
        index_dims[1],
        1,
        phi::errors::InvalidArgument(
            "The last dim of index should be 1 when it is 2D, but we get %d",
            index_dims[1]));
  } else {
    PADDLE_ENFORCE_EQ(
        index_dims.size(),
        1,
        phi::errors::InvalidArgument(
            "The index should be 1D, when it is not 2D, but we get %d",
            index_dims.size()));
  }
  std::vector<int> xshape(x.dims().size());
  for (int i = 0; i < x.dims().size(); ++i) {
    xshape[i] = x.dims()[i];
  }

  using XPUType = typename XPUTypeTrait<T>::Type;

  int r = XPU_SUCCESS;
  if (index_type == DataType::INT32) {
    r = xpu::gather<XPUType, int>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType*>(x.data<T>()),
        index.data<int>(),
        reinterpret_cast<XPUType*>(out->data<T>()),
        xshape,
        index.dims()[0],
        axis_v);
  } else {
    r = xpu::gather<XPUType, int64_t>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType*>(x.data<T>()),
        index.data<int64_t>(),
        reinterpret_cast<XPUType*>(out->data<T>()),
        xshape,
        index.dims()[0],
        axis_v);
  }
  PADDLE_ENFORCE_EQ(
      r,
      xpu::Error_t::SUCCESS,
      phi::errors::External(
          "XPU gather kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
}

}  // namespace phi

PD_REGISTER_KERNEL(
    gather, XPU, ALL_LAYOUT, phi::GatherKernel, float, phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gather_nd_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void GatherNdKernel(const Context &ctx,
                    const DenseTensor &x,
                    const DenseTensor &index,
                    DenseTensor *out) {
  ctx.template Alloc<T>(out);
  const auto &index_type = index.dtype();

  if (x.numel() == 0) return;

  if (index.numel() == 0) {
    phi::Copy(ctx, x, phi::XPUPlace(), true, out);
    return;
  }

  bool index_type_match =
      index_type == DataType::INT32 || index_type == DataType::INT64;
  PADDLE_ENFORCE_EQ(
      index_type_match,
      true,
      phi::errors::InvalidArgument("Index holds the wrong type, it holds [%s],"
                                   "but desires to be [%s] or [%s]",
                                   index_type,
                                   DataType::INT32,
                                   DataType::INT64));

  auto x_shape = phi::vectorize<int>(x.dims());
  auto index_shape = phi::vectorize<int>(index.dims());
  if (index_shape.size() == 1) {
    index_shape.insert(index_shape.begin(), 1);
  }
  xpu::VectorParam<int> x_vec = {
      x_shape.data(), static_cast<int>(x_shape.size()), nullptr};

  int ret = XPU_SUCCESS;
  if (index_type == DataType::INT32) {
    ret = xpu::gather_nd<T, int>(ctx.x_context(),
                                 x.data<T>(),
                                 index.data<int>(),
                                 out->data<T>(),
                                 x_vec,
                                 index_shape);
  } else {
    ret = xpu::gather_nd<T, int64_t>(ctx.x_context(),
                                     x.data<T>(),
                                     index.data<int64_t>(),
                                     out->data<T>(),
                                     x_vec,
                                     index_shape);
  }
  PADDLE_ENFORCE_EQ(
      ret,
      XPU_SUCCESS,
      phi::errors::External("XPU gather_nd kernel return wrong value[%d %s]",
                            ret,
                            XPUAPIErrorMsg[ret]));
}

}  // namespace phi

PD_REGISTER_KERNEL(
    gather_nd, XPU, ALL_LAYOUT, phi::GatherNdKernel, float, int64_t, int) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "paddle/phi/kernels/gaussian_random_kernel.h"

#include <random>

#include "paddle/fluid/framework/generator.h"
#include "paddle/fluid/memory/memcpy.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void GaussianRandomKernel(const Context& ctx,
                          const IntArray& shape,
                          float mean,
                          float std,
                          int seed,
                          DataType dtype,
                          DenseTensor* out) {
  std::normal_distribution<T> dist(mean, std);
  int64_t size = out->numel();
  ctx.template Alloc<T>(out);
  auto* data = out->data();
  uint64_t seed_v = static_cast<uint64_t>(seed);
  // TODO(pangyoki): implement GetXPURandomEngine to set different seeds on
  // corresponding XPU device.
  auto engine = paddle::framework::GetCPURandomEngine(seed_v);

  std::unique_ptr<T[]> data_cpu(new T[size]);
  for (int64_t i = 0; i < size; ++i) {
    data_cpu[i] = dist(*engine);
  }
  paddle::memory::Copy(phi::XPUPlace(),
                       data,
                       phi::CPUPlace(),
                       reinterpret_cast<void*>(data_cpu.get()),
                       size * sizeof(T));
}

}  // namespace phi

PD_REGISTER_KERNEL(
    gaussian_random, XPU, ALL_LAYOUT, phi::GaussianRandomKernel, float) {}