Unverified commit 3b9b4c34 authored by ykkk2333, committed by GitHub

migrate shape, sgd, split, sign xpu kernels to phi, test=kunlun (#45607)

Parent 445fce62
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <string>

#include "paddle/fluid/operators/optimizers/sgd_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class SGDOpXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext &ctx) const override {
    const auto *learning_rate = ctx.Input<framework::Tensor>("LearningRate");
    const auto *param_var = ctx.InputVar("Param");
    const auto *grad_var = ctx.InputVar("Grad");
    if (param_var->IsType<framework::LoDTensor>() &&
        grad_var->IsType<framework::LoDTensor>()) {
      const auto *param = ctx.Input<framework::Tensor>("Param");
      auto *param_out = ctx.Output<framework::Tensor>("ParamOut");
      // Actually, all tensors are LoDTensor except SelectedRows.
      const auto *grad = ctx.Input<framework::Tensor>("Grad");
      auto sz = param_out->numel();
      PADDLE_ENFORCE_EQ(param->numel(),
                        sz,
                        platform::errors::InvalidArgument(
                            "The input tensor Param's numel of SgdOp "
                            "should be equal to ParamOut's numel. "
                            "But received Param's "
                            "numel = [%s], ParamOut's numel = [%s]",
                            param->numel(),
                            sz));
      PADDLE_ENFORCE_EQ(grad->numel(),
                        sz,
                        platform::errors::InvalidArgument(
                            "The input tensor Grad's numel of SgdOp "
                            "should be equal to ParamOut's numel. "
                            "But received Grad's "
                            "numel = [%s], ParamOut's numel = [%s]",
                            grad->numel(),
                            sz));
      const T *lr_t = learning_rate->data<T>();
      auto &dev_ctx = ctx.template device_context<DeviceContext>();
      xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
      const float *lr = nullptr;
      if (std::is_same<T, paddle::platform::float16>::value) {
        float *lr_float =
            RAII_GUARD.alloc_l3_or_gm<float>(learning_rate->numel());
        int r = xpu::cast_v2<XPUType, float>(
            dev_ctx.x_context(),
            reinterpret_cast<const XPUType *>(lr_t),
            lr_float,
            learning_rate->numel());
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2");
        lr = lr_float;
      } else {
        lr = reinterpret_cast<const float *>(lr_t);
      }
      const T *param_data = param->data<T>();
      const T *grad_data = grad->data<T>();
      T *out_data = param_out->mutable_data<T>(ctx.GetPlace());
      int r = xpu::sgd(dev_ctx.x_context(),
                       reinterpret_cast<const XPUType *>(grad_data),
                       reinterpret_cast<const XPUType *>(param_data),
                       lr,
                       reinterpret_cast<XPUType *>(out_data),
                       sz);
      PADDLE_ENFORCE_XDNN_SUCCESS(r, "sgd");
    }
  }
};
}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
namespace plat = paddle::platform;

REGISTER_OP_XPU_KERNEL(
    sgd,
    ops::SGDOpXPUKernel<paddle::platform::XPUDeviceContext, float>,
    ops::SGDOpXPUKernel<paddle::platform::XPUDeviceContext, plat::float16>);
#endif
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
* http://www.apache.org/licenses/LICENSE-2.0
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <algorithm>

#include "paddle/fluid/framework/op_registry.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;
using LoDTensor = framework::LoDTensor;
using SelectedRows = phi::SelectedRows;

template <typename T>
class ShapeXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* in_var = ctx.InputVar("Input");
    framework::DDim in_dims;
    if (in_var->IsType<phi::SelectedRows>()) {
      in_dims = in_var->Get<phi::SelectedRows>().value().dims();
    } else {
      in_dims = in_var->Get<LoDTensor>().dims();
    }
    auto* out_t = ctx.Output<Tensor>("Out");
    out_t->Resize({in_dims.size()});
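    // The output is tiny (one int32 per dimension), so it is allocated in host (CPU) memory.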
    auto out_data = out_t->mutable_data<int32_t>(platform::CPUPlace());
    for (int i = 0; i < in_dims.size(); ++i) {
      out_data[i] = in_dims[i];
    }
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(shape,
                       ops::ShapeXPUKernel<bool>,
                       ops::ShapeXPUKernel<int>,
                       ops::ShapeXPUKernel<int64_t>,
                       ops::ShapeXPUKernel<float>,
                       ops::ShapeXPUKernel<double>);
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <string>
#include <vector>

#include "paddle/fluid/operators/split_op.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"

namespace paddle {
namespace operators {

using Tensor = framework::Tensor;

template <typename DeviceContext, typename T>
class SplitXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    auto* input = ctx.Input<framework::Tensor>("X");
    auto output = ctx.MultiOutput<framework::Tensor>("Out");
    int num = ctx.Attr<int>("num");
    std::vector<int> sections = ctx.Attr<std::vector<int>>("sections");
    int axis = ctx.Attr<int>("axis");
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    auto in_dims = input->dims();
    auto input_shape = phi::vectorize<int>(in_dims);
    std::vector<int> split_lists;
    std::vector<T*> out_ptrs;
    auto outs_number = output.size();
    std::vector<framework::DDim> outs_dims =
        UpdateOutsDims(true, true, in_dims, num, sections, axis, outs_number);
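    // Resize and allocate every output, recording each output's extent along the split axis.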
    for (size_t i = 0; i < output.size(); ++i) {
      output[i]->Resize(outs_dims[i]);
      out_ptrs.push_back(output[i]->mutable_data<T>(ctx.GetPlace()));
      split_lists.push_back(output[i]->dims()[axis]);
    }
    int r = xpu::split<T>(dev_ctx.x_context(),
                          input->data<T>(),
                          out_ptrs,
                          input_shape,
                          split_lists,
                          axis);
    PADDLE_ENFORCE_EQ(
        r,
        XPU_SUCCESS,
        platform::errors::External("XPU split kernel returned wrong value [%d %s]",
                                   r,
                                   XPUAPIErrorMsg[r]));
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    split,
    ops::SplitXPUKernel<paddle::platform::XPUDeviceContext, float>,
    ops::SplitXPUKernel<paddle::platform::XPUDeviceContext, int>);
#endif
@@ -67,3 +67,17 @@ PD_REGISTER_KERNEL(shape,
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
#endif

#if defined(PADDLE_WITH_XPU)
PD_REGISTER_KERNEL(shape,
                   XPU,
                   ALL_LAYOUT,
                   phi::ShapeKernel,
                   bool,
                   int,
                   int64_t,
                   float,
                   double) {
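  // The input may live on any backend; ShapeKernel reads only its dims metadata.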
  kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/sgd_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void SGDDenseKernel(const Context &dev_ctx,
                    const DenseTensor &param,
                    const DenseTensor &learning_rate,
                    const DenseTensor &grad,
                    const paddle::optional<DenseTensor> &master_param,
                    bool multi_precision,
                    DenseTensor *param_out,
                    DenseTensor *master_param_out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  auto sz = param_out->numel();
  PADDLE_ENFORCE_EQ(
      param.numel(),
      sz,
      errors::InvalidArgument("The input tensor Param's numel of SgdOp "
                              "should be equal to ParamOut's numel. "
                              "But received Param's "
                              "numel = [%s], ParamOut's numel = [%s]",
                              param.numel(),
                              sz));
  PADDLE_ENFORCE_EQ(
      grad.numel(),
      sz,
      errors::InvalidArgument("The input tensor Grad's numel of SgdOp "
                              "should be equal to ParamOut's numel. "
                              "But received Grad's "
                              "numel = [%s], ParamOut's numel = [%s]",
                              grad.numel(),
                              sz));

  const T *lr_t = learning_rate.data<T>();
  xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
  const float *lr = nullptr;
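  // xpu::sgd consumes the learning rate as float, so a float16 learning rate is cast up first.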
  if (std::is_same<T, dtype::float16>::value) {
    float *lr_float = RAII_GUARD.alloc_l3_or_gm<float>(learning_rate.numel());
    int r =
        xpu::cast_v2<XPUType, float>(dev_ctx.x_context(),
                                     reinterpret_cast<const XPUType *>(lr_t),
                                     lr_float,
                                     learning_rate.numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "cast_v2");
    lr = lr_float;
  } else {
    lr = reinterpret_cast<const float *>(lr_t);
  }

  const T *param_data = param.data<T>();
  const T *grad_data = grad.data<T>();
  dev_ctx.template Alloc<T>(param_out);
  T *out_data = param_out->data<T>();
  int r = xpu::sgd(dev_ctx.x_context(),
                   reinterpret_cast<const XPUType *>(grad_data),
                   reinterpret_cast<const XPUType *>(param_data),
                   lr,
                   reinterpret_cast<XPUType *>(out_data),
                   sz);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "sgd");
}

}  // namespace phi

PD_REGISTER_KERNEL(
    sgd, XPU, ALL_LAYOUT, phi::SGDDenseKernel, phi::dtype::float16, float) {}
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
@@ -12,32 +12,22 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"

namespace paddle {
namespace operators {

template <typename DeviceContext, typename T>
class SignXPUKernel : public framework::OpKernel<T> {
 public:
  virtual void Compute(const framework::ExecutionContext& context) const {
    auto* out = context.Output<framework::Tensor>("Out");
    auto* in = context.Input<framework::Tensor>("X");
    out->mutable_data<T>(in->place());
    auto xpu_context = context.device_context<DeviceContext>().x_context();
    // int sign(Context* ctx, const T* x, T* y, int len);
    int r = xpu::sign(xpu_context, in->data<T>(), out->data<T>(), in->numel());
    PADDLE_ENFORCE_XDNN_SUCCESS(r, "sign");
  }
};

}  // namespace operators
}  // namespace paddle

namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
    sign, ops::SignXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
#include "paddle/phi/kernels/sign_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void SignKernel(const Context& dev_ctx,
                const DenseTensor& x,
                DenseTensor* out) {
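  // Allocate the output on the current XPU device, then take the element-wise sign of x.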
  dev_ctx.template Alloc<T>(out);
  auto xpu_context = dev_ctx.x_context();
  int r = xpu::sign(xpu_context, x.data<T>(), out->data<T>(), x.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "sign");
}

}  // namespace phi

PD_REGISTER_KERNEL(sign, XPU, ALL_LAYOUT, phi::SignKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/split_kernel.h"

#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"

namespace phi {

template <typename T, typename Context>
void SplitKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 const IntArray& sections,
                 const Scalar& axis_scalar,
                 std::vector<DenseTensor*> outs) {
  int axis = axis_scalar.to<int>();
  auto in_dims = x.dims();
  auto input_shape = vectorize<int>(in_dims);
  std::vector<T*> out_ptrs;
  std::vector<int> split_lists;
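  // Allocate each output and collect its data pointer and its length along the split axis.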
  for (size_t j = 0; j < outs.size(); ++j) {
    dev_ctx.template Alloc<T>(outs[j]);
    out_ptrs.push_back(outs[j]->data<T>());
    split_lists.push_back(outs[j]->dims()[axis]);
  }
  int r = xpu::split<T>(dev_ctx.x_context(),
                        x.data<T>(),
                        out_ptrs,
                        input_shape,
                        split_lists,
                        axis);
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "split");
}

template <typename T, typename Context>
void SplitWithNumKernel(const Context& dev_ctx,
                        const DenseTensor& x,
                        int num,
                        const Scalar& axis_scalar,
                        std::vector<DenseTensor*> outs) {
  int axis_value = axis_scalar.to<int>();
  auto input_axis_dim = x.dims().at(axis_value);
  std::vector<int64_t> sections_vec;
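  // Split evenly: each of the num outputs gets input_axis_dim / num elements along the axis.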
  for (int i = 0; i < num; ++i) {
    sections_vec.push_back(input_axis_dim / num);
  }
  IntArray sections(sections_vec);
  SplitKernel<T, Context>(dev_ctx, x, sections, axis_scalar, outs);
}

}  // namespace phi

PD_REGISTER_KERNEL(split, XPU, ALL_LAYOUT, phi::SplitKernel, float, int) {}
PD_REGISTER_KERNEL(
    split_with_num, XPU, ALL_LAYOUT, phi::SplitWithNumKernel, float, int) {}