未验证 提交 6c5f9aa8 编写于 作者: Y ykkk2333 提交者: GitHub

migrate xpu...

migrate xpu activation/activation_grad/transpose/transpose_grad/tril_triu/tril_triu_grad kernel to PHI, test=kunlun (#45554)
上级 530f6b79
此差异已折叠。
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
#include "paddle/phi/kernels/instance_norm_grad_kernel.h"
#include "paddle/phi/kernels/instance_norm_kernel.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// Fluid-side adapter for the instance_norm op on XPU. It reads the op's
// attribute, inputs and outputs from the ExecutionContext and forwards them
// to phi::InstanceNormKernel.
template <typename DeviceContext, typename T>
class InstanceNormXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto eps = ctx.Attr<float>("epsilon");
    const auto* input = ctx.Input<Tensor>("X");
    // NOTE(review): Scale and Bias are dereferenced below without a null
    // check; confirm both inputs are always provided for this op.
    const auto* gamma = ctx.Input<Tensor>("Scale");
    const auto* beta = ctx.Input<Tensor>("Bias");
    auto* output = ctx.Output<Tensor>("Y");
    auto* saved_mean = ctx.Output<Tensor>("SavedMean");
    auto* saved_variance = ctx.Output<Tensor>("SavedVariance");

    // The phi kernel expects the phi flavour of the device context.
    using PhiContext =
        typename paddle::framework::ConvertToPhiContext<DeviceContext>::TYPE;
    const auto& phi_ctx = static_cast<const PhiContext&>(
        ctx.template device_context<DeviceContext>());

    phi::InstanceNormKernel<T>(phi_ctx,
                               *input,
                               *gamma,
                               *beta,
                               eps,
                               output,
                               saved_mean,
                               saved_variance);
  }
};
// Fluid-side adapter for the instance_norm_grad op on XPU. It collects the
// forward input, the saved statistics, the scale and the upstream gradient,
// then delegates to phi::InstanceNormGradKernel.
template <typename DeviceContext, typename T>
class InstanceNormGradXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto eps = ctx.Attr<float>("epsilon");
    const auto* input = ctx.Input<Tensor>("X");
    const auto* saved_mean = ctx.Input<Tensor>("SavedMean");
    const auto* saved_variance = ctx.Input<Tensor>("SavedVariance");
    // NOTE(review): Scale is dereferenced below without a null check;
    // confirm it is always provided for this op.
    const auto* gamma = ctx.Input<Tensor>("Scale");
    const auto* grad_y = ctx.Input<Tensor>(framework::GradVarName("Y"));
    auto* grad_x = ctx.Output<Tensor>(framework::GradVarName("X"));
    auto* grad_scale = ctx.Output<Tensor>(framework::GradVarName("Scale"));
    auto* grad_bias = ctx.Output<Tensor>(framework::GradVarName("Bias"));

    // The phi kernel expects the phi flavour of the device context.
    using PhiContext =
        typename paddle::framework::ConvertToPhiContext<DeviceContext>::TYPE;
    const auto& phi_ctx = static_cast<const PhiContext&>(
        ctx.template device_context<DeviceContext>());

    // Argument order follows the call as originally written: the bias
    // gradient is passed before the scale gradient.
    phi::InstanceNormGradKernel<T>(phi_ctx,
                                   *input,
                                   *grad_y,
                                   *gamma,
                                   *saved_mean,
                                   *saved_variance,
                                   eps,
                                   grad_x,
                                   grad_bias,
                                   grad_scale);
  }
};
} // namespace operators
} // namespace paddle
// Register the float instantiations of the two wrapper classes as the XPU
// kernels of the fluid ops instance_norm and instance_norm_grad.
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
instance_norm,
ops::InstanceNormXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
instance_norm_grad,
ops::InstanceNormGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif // PADDLE_WITH_XPU
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/operators/transpose_op.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle {
namespace operators {
using framework::Tensor;
// XPU kernel registered for the fluid transpose / transpose2 ops. It permutes
// the dimensions of input "X" according to the "axis" attribute and writes
// the result into output "Out" through xpu::transpose.
template <typename DeviceContext, typename T>
class TransposeXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* input = context.Input<framework::Tensor>("X");
    auto* output = context.Output<framework::Tensor>("Out");
    // "axis" holds the permutation to apply.
    const auto perm = context.Attr<std::vector<int>>("axis");
    const int rank = perm.size();
    const auto in_dims = input->dims();
    const T* src = input->data<T>();
    T* dst = output->mutable_data<T>(context.GetPlace());
    // The output is allocated first; an empty output needs no device call.
    if (output->numel() == 0) {
      return;
    }

    // xpu::transpose takes the input shape as a host-side vector of int.
    std::vector<int> in_shape;
    in_shape.reserve(rank);
    for (int d = 0; d < rank; ++d) {
      in_shape.push_back(in_dims[d]);
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
    int status = xpu::transpose<XPUType>(dev_ctx.x_context(),
                                         reinterpret_cast<const XPUType*>(src),
                                         reinterpret_cast<XPUType*>(dst),
                                         in_shape,
                                         perm);
    PADDLE_ENFORCE_EQ(
        status,
        xpu::Error_t::SUCCESS,
        platform::errors::External("XPU kernel error! error code=%d", status));
  }
};
// XPU kernel registered for the fluid transpose_grad / transpose2_grad ops.
// The gradient of a transpose is the transpose of the upstream gradient by
// the inverse permutation, so this kernel inverts "axis" and runs
// xpu::transpose on Out@GRAD to produce X@GRAD.
template <typename DeviceContext, typename T>
class TransposeGradXPUKernel : public framework::OpKernel<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

 public:
  void Compute(const framework::ExecutionContext& context) const override {
    auto* out_grad =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    auto* x_grad =
        context.Output<framework::Tensor>(framework::GradVarName("X"));
    // No gradient requested for X: nothing to do.
    if (!x_grad) return;
    x_grad->mutable_data<T>(context.GetPlace());
    // An empty gradient needs no device call (same guard as the forward
    // TransposeXPUKernel).
    if (x_grad->numel() == 0) return;

    const std::vector<int> axis = context.Attr<std::vector<int>>("axis");
    const int ndims = axis.size();

    // Build the inverse permutation. A negative entry counts from the last
    // dimension, so it is wrapped into [0, ndims) before being used as an
    // index; using it directly would write outside reversed_axis.
    // NOTE(review): entries outside [-ndims, ndims) are assumed to have been
    // rejected before the kernel runs; confirm against the op's shape
    // inference.
    std::vector<int> reversed_axis(axis);
    for (int i = 0; i < ndims; ++i) {
      const int src_dim = axis[i] < 0 ? axis[i] + ndims : axis[i];
      reversed_axis[src_dim] = i;
    }

    // xpu::transpose takes the shape of its input (here Out@GRAD) as a
    // host-side vector of int.
    std::vector<int> out_shape_host(ndims, 0);
    for (int i = 0; i < ndims; ++i) {
      out_shape_host[i] = out_grad->dims()[i];
    }

    auto& dev_ctx = context.template device_context<DeviceContext>();
    int r = xpu::transpose<XPUType>(
        dev_ctx.x_context(),
        reinterpret_cast<const XPUType*>(out_grad->data<T>()),
        reinterpret_cast<XPUType*>(x_grad->data<T>()),
        out_shape_host,
        reversed_axis);
    PADDLE_ENFORCE_EQ(
        r,
        xpu::Error_t::SUCCESS,
        platform::errors::External("XPU kernel error! error code=%d", r));
  }
};
} // namespace operators
} // namespace paddle
// Register the transpose kernels for XPU. Both op generations (transpose and
// transpose2) and their gradients share the same two kernel classes, each
// instantiated for float and float16.
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
transpose,
ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
transpose_grad,
ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
transpose2,
ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::TransposeXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
transpose2_grad,
ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::TransposeGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif // PADDLE_WITH_XPU
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. Licensed under
the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
// XPU kernel for the fluid tril_triu op. Depending on the "lower" attribute
// it calls xpu::tril or xpu::triu on input "X" with the "diagonal" attribute
// and writes the result to output "Out".
template <typename DeviceContext, typename T>
class TrilTriuXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* input = context.Input<framework::Tensor>("X");
    const auto* src = input->data<T>();
    auto* output = context.Output<framework::Tensor>("Out");
    auto* dst = output->mutable_data<T>(context.GetPlace());
    const int diagonal = context.Attr<int>("diagonal");
    const bool lower = context.Attr<bool>("lower");
    // Host-side copy of the input shape, as int.
    const auto shape = phi::vectorize<int>(input->dims());
    auto& dev_ctx = context.template device_context<DeviceContext>();

    if (lower) {
      const int status =
          xpu::tril(dev_ctx.x_context(), src, dst, shape, diagonal);
      PADDLE_ENFORCE_XDNN_SUCCESS(status, "tril_op");
      return;
    }
    const int status =
        xpu::triu(dev_ctx.x_context(), src, dst, shape, diagonal);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "triu_op");
  }
};
// XPU kernel for the fluid tril_triu_grad op. It applies the same
// xpu::tril / xpu::triu call as the forward kernel, with Out@GRAD as the
// source and X@GRAD as the destination.
template <typename DeviceContext, typename T>
class TrilTriuGradXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& context) const override {
    const auto* grad_out =
        context.Input<framework::Tensor>(framework::GradVarName("Out"));
    const auto* src = grad_out->data<T>();
    auto* grad_x =
        context.Output<framework::Tensor>(framework::GradVarName("X"));
    auto* dst = grad_x->mutable_data<T>(context.GetPlace());
    const int diagonal = context.Attr<int>("diagonal");
    const bool lower = context.Attr<bool>("lower");
    // Host-side copy of the upstream gradient's shape, as int.
    const auto shape = phi::vectorize<int>(grad_out->dims());
    auto& dev_ctx = context.template device_context<DeviceContext>();

    if (lower) {
      const int status =
          xpu::tril(dev_ctx.x_context(), src, dst, shape, diagonal);
      PADDLE_ENFORCE_XDNN_SUCCESS(status, "tril_op");
      return;
    }
    const int status =
        xpu::triu(dev_ctx.x_context(), src, dst, shape, diagonal);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "triu_op");
  }
};
} // namespace operators
} // namespace paddle
// Register the tril_triu forward and gradient kernels for XPU, each
// instantiated for int and float.
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
tril_triu,
ops::TrilTriuXPUKernel<paddle::platform::XPUDeviceContext, int>,
ops::TrilTriuXPUKernel<paddle::platform::XPUDeviceContext, float>);
REGISTER_OP_XPU_KERNEL(
tril_triu_grad,
ops::TrilTriuGradXPUKernel<paddle::platform::XPUDeviceContext, int>,
ops::TrilTriuGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/abs_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// Gradient of abs on XPU: allocates dx and fills it through xpu::abs_grad
// from the forward input x and the upstream gradient dout.
template <typename T, typename Context>
void AbsGradKernel(const Context& ctx,
                   const DenseTensor& x,
                   const DenseTensor& dout,
                   DenseTensor* dx) {
  ctx.template Alloc<T>(dx);

  const T* x_data = x.data<T>();
  const T* dout_data = dout.data<T>();
  T* dx_data = dx->data<T>();

  // The upstream gradient buffer is supplied for both the third and the
  // fourth pointer argument.
  // NOTE(review): presumably the third slot is the forward output, which
  // abs_grad does not need - confirm against the XDNN API.
  const int status = xpu::abs_grad(
      ctx.x_context(), x_data, dout_data, dout_data, dx_data, x.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(status, "abs_grad");
}
} // namespace phi
// Register phi::AbsGradKernel as the XPU kernel "abs_grad" for float.
PD_REGISTER_KERNEL(abs_grad, XPU, ALL_LAYOUT, phi::AbsGradKernel, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/abs_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// abs on XPU: allocates out and fills it by calling xpu::abs on the
// x.numel() values of x.
template <typename T, typename Context>
void AbsKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  ctx.template Alloc<T>(out);

  const T* src = x.data<T>();
  T* dst = out->data<T>();
  const int status = xpu::abs(ctx.x_context(), src, dst, x.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(status, "abs");
}
} // namespace phi
// Register phi::AbsKernel as the XPU kernel "abs" for float.
PD_REGISTER_KERNEL(abs, XPU, ALL_LAYOUT, phi::AbsKernel, float) {}
此差异已折叠。
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/phi/kernels/activation_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/funcs/activation_functor.h"
#include "paddle/fluid/memory/memory.h"
namespace phi {
// Shared driver for the functor-based activation kernels: checks that the
// output pointer is set, allocates the output as T, then lets the functor do
// the actual computation.
template <typename T, typename Context, typename Functor>
void ActivationXPUImpl(const Context& dev_ctx,
                       const DenseTensor& x,
                       DenseTensor* out,
                       const Functor& functor) {
  PADDLE_ENFORCE_NOT_NULL(
      out, errors::NotFound("Output Out should not be nullptr"));

  // The functors read out->data<T>(), so the buffer must exist first.
  dev_ctx.template Alloc<T>(out);

  functor(dev_ctx, x, out);
}
// Defines `name##Kernel(dev_ctx, x, out)` for an activation without
// attributes: it default-constructs functor_class<T> and runs it through
// ActivationXPUImpl.
#define DEFINE_XPU_ACTIVATION_KERNEL(name, functor_class) \
template <typename T, typename Context> \
void name##Kernel( \
const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) { \
functor_class<T> functor; \
ActivationXPUImpl<T, Context, functor_class<T>>(dev_ctx, x, out, functor); \
}
// Defines `name##Kernel(dev_ctx, x, attr, out)` for an activation with one
// float attribute. The attribute value is stored through the first pointer
// returned by functor.GetAttrs() before the functor is run through
// ActivationXPUImpl, so the functor's GetAttrs() must list that attribute
// first.
#define DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(name, functor_class, attr) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
float attr, \
DenseTensor* out) { \
functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr; \
ActivationXPUImpl<T, Context, functor_class<T>>(dev_ctx, x, out, functor); \
}
// Defines `name##Kernel(dev_ctx, x, attr1, attr2, out)` for an activation
// with two float attributes. attr1 and attr2 are stored through the first and
// second pointers returned by functor.GetAttrs(), in that order, before the
// functor is run through ActivationXPUImpl.
#define DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS( \
name, functor_class, attr1, attr2) \
template <typename T, typename Context> \
void name##Kernel(const Context& dev_ctx, \
const DenseTensor& x, \
float attr1, \
float attr2, \
DenseTensor* out) { \
functor_class<T> functor; \
auto attrs = functor.GetAttrs(); \
*(attrs[0].second) = attr1; \
*(attrs[1].second) = attr2; \
ActivationXPUImpl<T, Context, functor_class<T>>(dev_ctx, x, out, functor); \
}
// Invokes an attribute-free XDNN activation `func` with the device context,
// the input and output buffers reinterpreted as XPUType, and x.numel() as the
// length. Returns the status code produced by `func`; the output must already
// be allocated.
template <typename Context, typename T, typename XPUType>
int xpu_activation_func(
    const Context& dev_ctx,
    const DenseTensor& x,
    DenseTensor* out,
    std::function<int(xpu::Context*, const XPUType*, XPUType*, int)> func) {
  const auto* src = reinterpret_cast<const XPUType*>(x.data<T>());
  auto* dst = reinterpret_cast<XPUType*>(out->data<T>());
  return func(dev_ctx.x_context(), src, dst, x.numel());
}
// Same as the attribute-free helper, but for an XDNN activation `func` that
// takes one extra float argument after the length. Returns the status code
// produced by `func`; the output must already be allocated.
template <typename Context, typename T, typename XPUType>
int xpu_activation_1attr_func(
    const Context& dev_ctx,
    const DenseTensor& x,
    DenseTensor* out,
    float attr,
    std::function<int(xpu::Context*, const XPUType*, XPUType*, int, float)>
        func) {
  const auto* src = reinterpret_cast<const XPUType*>(x.data<T>());
  auto* dst = reinterpret_cast<XPUType*>(out->data<T>());
  return func(dev_ctx.x_context(), src, dst, x.numel(), attr);
}
// Same as the attribute-free helper, but for an XDNN activation `func` that
// takes two extra float arguments after the length. Returns the status code
// produced by `func`; the output must already be allocated.
template <typename Context, typename T, typename XPUType>
int xpu_activation_2attr_func(
    const Context& dev_ctx,
    const DenseTensor& x,
    DenseTensor* out,
    float attr1,
    float attr2,
    std::function<
        int(xpu::Context*, const XPUType*, XPUType*, int, float, float)> func) {
  const auto* src = reinterpret_cast<const XPUType*>(x.data<T>());
  auto* dst = reinterpret_cast<XPUType*>(out->data<T>());
  return func(dev_ctx.x_context(), src, dst, x.numel(), attr1, attr2);
}
// Activation functor that forwards x and out to xpu::exp<XPUType> through
// xpu_activation_func and enforces that the call succeeded.
template <typename T>
struct XPUExpFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::exp<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "exp");
  }
};
// Activation functor that forwards x and out to xpu::log<XPUType> through
// xpu_activation_func and enforces that the call succeeded.
template <typename T>
struct XPULogFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::log<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "log");
  }
};
// Activation functor for leaky_relu. `alpha` is exposed through GetAttrs() so
// the kernel wrapper can set it, and is passed to xpu::leaky_relu<XPUType> as
// the extra float argument.
template <typename T>
struct XPULeakyReluFunctor : public funcs::BaseActivationFunctor<T> {
  float alpha;

  // Name/pointer pair for the single attribute of this functor.
  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"alpha", &alpha}};
  }

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    using XPUType = typename XPUTypeTrait<T>::Type;
    const int status = xpu_activation_1attr_func<Context, T, XPUType>(
        dev_ctx, x, out, alpha, xpu::leaky_relu<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "leaky_relu");
  }
};
// pow on XPU: computes out = x raised to the scalar `factor` by staging the
// exponent in a one-element device buffer and calling xpu::broadcast_pow with
// the exponent's shape given as {1}.
//
// dev_ctx: device context providing allocation and the XDNN context.
// x:       input tensor.
// factor:  exponent, converted to T before it is copied to the device.
// out:     output tensor; allocated here as T.
template <typename T, typename Context>
void PowKernel(const Context& dev_ctx,
               const DenseTensor& x,
               const Scalar& factor,
               DenseTensor* out) {
  dev_ctx.template Alloc<T>(out);
  // Hold the exponent on the host as a T. The device buffer below stores one
  // T and exactly sizeof(T) bytes are copied into it, so the host-side source
  // must have that same type and size (copying from a float would only be
  // correct when T is float).
  T pow_factor = factor.to<T>();
  const T* x_data = x.data<T>();
  T* y_data = out->data<T>();

  auto xpu_context = dev_ctx.x_context();
  // allocate temp memory for factor on xpu
  xpu::ctx_guard RAII_GUARD(xpu_context);
  T* factor_data = RAII_GUARD.alloc_l3_or_gm<T>(1);
  PADDLE_ENFORCE_NOT_NULL(
      factor_data, errors::External("XPU alloc_l3_or_gm returns nullptr"));
  paddle::memory::Copy(dev_ctx.GetPlace(),
                       static_cast<void*>(factor_data),
                       phi::CPUPlace(),
                       static_cast<void*>(&pow_factor),
                       sizeof(T));

  // broadcast_pow(Context* ctx, const T* x, const T* y, T* z, const
  // std::vector<int>& xshape, const std::vector<int>& yshape);
  auto x_dims = vectorize<int>(x.dims());
  int r =
      xpu::broadcast_pow(xpu_context, x_data, factor_data, y_data, x_dims, {1});
  PADDLE_ENFORCE_XDNN_SUCCESS(r, "broadcast_pow");
}
// Activation functor for hard_swish. The three attributes are exposed through
// GetAttrs(), but the device call takes none of them: the functor only
// accepts threshold == 6, scale == 6 and offset == 3 and raises an error for
// anything else before calling xpu::hard_swish<XPUType>.
template <typename T>
struct XPUHardSwishFunctor : public funcs::BaseActivationFunctor<T> {
  float threshold;
  float scale;
  float offset;

  // Name/pointer pairs in the order threshold, scale, offset.
  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}, {"scale", &scale}, {"offset", &offset}};
  }

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    using XPUType = typename XPUTypeTrait<T>::Type;

    // Reject attribute values other than the fixed ones accepted here.
    PADDLE_ENFORCE_EQ(
        threshold,
        6.0f,
        errors::External("Not support threshold [%f] in XPU", threshold));
    PADDLE_ENFORCE_EQ(
        scale, 6.0f, errors::External("Not support scale [%f] in XPU", scale));
    PADDLE_ENFORCE_EQ(
        offset,
        3.0f,
        errors::External("Not support offset [%f] in XPU", offset));

    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::hard_swish<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "hard_swish");
  }
};
// Activation functor that forwards x and out to xpu::reciprocal<XPUType>
// through xpu_activation_func and enforces that the call succeeded.
template <typename T>
struct XPUReciprocalFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::reciprocal<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "reciprocal");
  }
};
// Activation functor for relu. Unlike most functors in this file it calls
// xpu::relu directly, because that function takes two additional trailing
// pointer arguments, which are passed as nullptr here.
template <typename T>
struct XPUReluFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const auto* src = reinterpret_cast<const XPUType*>(x.data<T>());
    auto* dst = reinterpret_cast<XPUType*>(out->data<T>());
    auto xdnn_ctx = dev_ctx.x_context();
    const int status =
        xpu::relu(xdnn_ctx, src, dst, x.numel(), nullptr, nullptr);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "relu");
  }
};
// Activation functor for relu6, dispatched to xpu::relu6<XPUType> through
// xpu_activation_func.
// NOTE(review): `threshold` can be set through GetAttrs() but is never passed
// to the device call or validated; confirm xpu::relu6 is only meant to be
// used with its built-in threshold.
template <typename T>
struct XPURelu6Functor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  float threshold;

  // Name/pointer pair for the single attribute of this functor.
  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::relu6<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "relu6");
  }
};
// Activation functor that forwards x and out to xpu::sigmoid<XPUType> through
// xpu_activation_func and enforces that the call succeeded.
template <typename T>
struct XPUSigmoidFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::sigmoid<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "sigmoid");
  }
};
// Activation functor that forwards x and out to xpu::square<XPUType> through
// xpu_activation_func and enforces that the call succeeded.
template <typename T>
struct XPUSquareFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::square<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "square");
  }
};
// Activation functor that forwards x and out to xpu::sqrt<XPUType> through
// xpu_activation_func and enforces that the call succeeded.
template <typename T>
struct XPUSqrtFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::sqrt<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "sqrt");
  }
};
// Activation functor for mish. `threshold` is exposed through GetAttrs() and
// passed to xpu::mish<XPUType> as the extra float argument.
template <typename T>
struct XPUMishFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  float threshold;

  // Name/pointer pair for the single attribute of this functor.
  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"threshold", &threshold}};
  }

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_1attr_func<Context, T, XPUType>(
        dev_ctx, x, out, threshold, xpu::mish<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "mish");
  }
};
// swish on XPU: allocates out and fills it by calling xpu::swish on the
// x.numel() values of x.
// NOTE(review): `beta` is accepted but never forwarded to xpu::swish or
// validated; confirm the device routine's fixed behaviour matches every beta
// callers may pass.
template <typename T, typename Context>
void SwishKernel(const Context& dev_ctx,
                 const DenseTensor& x,
                 float beta,
                 DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  dev_ctx.template Alloc<T>(out);

  const auto* src = reinterpret_cast<const XPUType*>(x.data<T>());
  auto* dst = reinterpret_cast<XPUType*>(out->data<T>());
  const int status = xpu::swish(dev_ctx.x_context(), src, dst, x.numel());
  PADDLE_ENFORCE_XDNN_SUCCESS(status, "swish");
}
// Activation functor for softplus. `beta` and `threshold` are exposed through
// GetAttrs() in that order and passed, in that order, to
// xpu::softplus<XPUType> as the two extra float arguments.
template <typename T>
struct XPUSoftplusFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;
  float beta;
  float threshold;

  // Name/pointer pairs in the order beta, threshold.
  typename funcs::BaseActivationFunctor<T>::AttrPair GetAttrs() {
    return {{"beta", &beta}, {"threshold", &threshold}};
  }

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_2attr_func<Context, T, XPUType>(
        dev_ctx, x, out, beta, threshold, xpu::softplus<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "softplus");
  }
};
// Activation functor that forwards x and out to xpu::tanh<XPUType> through
// xpu_activation_func and enforces that the call succeeded.
template <typename T>
struct XPUTanhFunctor : public funcs::BaseActivationFunctor<T> {
  using XPUType = typename XPUTypeTrait<T>::Type;

  template <typename Context>
  void operator()(const Context& dev_ctx,
                  const DenseTensor& x,
                  DenseTensor* out) const {
    const int status = xpu_activation_func<Context, T, XPUType>(
        dev_ctx, x, out, xpu::tanh<XPUType>);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "tanh");
  }
};
// Generate the kernel entry points from the functors.
// Activations without attributes:
DEFINE_XPU_ACTIVATION_KERNEL(Exp, XPUExpFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Log, XPULogFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Reciprocal, XPUReciprocalFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Relu, XPUReluFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Sigmoid, XPUSigmoidFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Square, XPUSquareFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Sqrt, XPUSqrtFunctor)
DEFINE_XPU_ACTIVATION_KERNEL(Tanh, XPUTanhFunctor)
// Activations with one float attribute:
DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Mish, XPUMishFunctor, threshold)
DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(LeakyRelu,
XPULeakyReluFunctor,
alpha)
DEFINE_XPU_ACTIVATION_KERNEL_WITH_ONE_ATTRS(Relu6, XPURelu6Functor, threshold)
// Activations with two float attributes:
DEFINE_XPU_ACTIVATION_KERNEL_WITH_TWO_ATTRS(Softplus,
XPUSoftplusFunctor,
beta,
threshold)
template <typename T, typename Context>
void HardSwishKernel(const Context& dev_ctx,
const DenseTensor& x,
float threshold,
float scale,
float offset,
DenseTensor* out) {
XPUHardSwishFunctor<T> functor;
auto attrs = functor.GetAttrs();
*(attrs[0].second) = threshold;
*(attrs[1].second) = scale;
*(attrs[2].second) = offset;
ActivationXPUImpl<T, Context, XPUHardSwishFunctor<T>>(
dev_ctx, x, out, functor);
}
} // namespace phi
// XPU kernel registrations for the activations defined in this file.
// relu (here) and tanh (below) are registered for float and float16.
PD_REGISTER_KERNEL(
relu, XPU, ALL_LAYOUT, phi::ReluKernel, float, phi::dtype::float16) {}
// Shorthand that registers phi::func as the float-only XPU kernel `name`.
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
PD_REGISTER_KERNEL(
tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {}
// All remaining activations are registered for float only.
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(log, LogKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_swish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6, Relu6Kernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish, SwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(square, SquareKernel)
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/transpose_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
template <typename T, typename Context>
void TransposeGradKernel(const Context& dev_ctx,
const DenseTensor& out_grad,
const std::vector<int>& axis,
DenseTensor* x_grad) {
using XPUType = typename XPUTypeTrait<T>::Type;
dev_ctx.template Alloc<T>(x_grad);
std::vector<int> reversed_axis(axis);
for (size_t i = 0; i < axis.size(); i++) {
reversed_axis[axis[i]] = i;
}
int ndims = axis.size();
std::vector<int> out_shape_host(ndims, 0);
for (int i = 0; i < ndims; ++i) {
out_shape_host[i] = out_grad.dims()[i];
}
int r = xpu::transpose<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad.data<T>()),
reinterpret_cast<XPUType*>(x_grad->data<T>()),
out_shape_host,
reversed_axis);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "transpose_grad");
}
} // namespace phi
// Register phi::TransposeGradKernel as the XPU kernel "transpose_grad" for
// float and float16.
PD_REGISTER_KERNEL(transpose_grad,
XPU,
ALL_LAYOUT,
phi::TransposeGradKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/transpose_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// transpose on XPU: permutes the dimensions of x according to `axis` and
// writes the result into out through xpu::transpose. An empty output returns
// immediately, before any allocation.
template <typename T, typename Context>
void TransposeKernel(const Context& dev_ctx,
                     const DenseTensor& x,
                     const std::vector<int>& axis,
                     DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  if (out->numel() == 0) {
    return;
  }
  dev_ctx.template Alloc<T>(out);

  // xpu::transpose takes the input shape as a host-side vector of int; one
  // entry is collected per element of `axis`.
  const int rank = axis.size();
  std::vector<int> in_shape;
  in_shape.reserve(rank);
  for (int d = 0; d < rank; ++d) {
    in_shape.push_back(x.dims()[d]);
  }

  const auto* src = reinterpret_cast<const XPUType*>(x.data<T>());
  auto* dst = reinterpret_cast<XPUType*>(out->data<T>());
  int status =
      xpu::transpose<XPUType>(dev_ctx.x_context(), src, dst, in_shape, axis);
  PADDLE_ENFORCE_XDNN_SUCCESS(status, "transpose");
}
} // namespace phi
// Register phi::TransposeKernel as the XPU kernel "transpose" for float and
// float16.
PD_REGISTER_KERNEL(transpose,
XPU,
ALL_LAYOUT,
phi::TransposeKernel,
float,
phi::dtype::float16) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tril_triu_grad_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// Gradient of tril_triu on XPU: allocates x_grad and fills it by applying
// xpu::tril (when `lower` is true) or xpu::triu (otherwise) to out_grad with
// the given `diagonal`.
template <typename T, typename Context>
void TrilTriuGradKernel(const Context& ctx,
                        const DenseTensor& out_grad,
                        int diagonal,
                        bool lower,
                        DenseTensor* x_grad) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  ctx.template Alloc<T>(x_grad);

  // Host-side copy of the upstream gradient's shape, as int.
  const auto shape = vectorize<int>(out_grad.dims());
  const auto* src = reinterpret_cast<const XPUType*>(out_grad.data<T>());
  auto* dst = reinterpret_cast<XPUType*>(x_grad->data<T>());

  if (lower) {
    const int status = xpu::tril(ctx.x_context(), src, dst, shape, diagonal);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "tril_op");
    return;
  }
  const int status = xpu::triu(ctx.x_context(), src, dst, shape, diagonal);
  PADDLE_ENFORCE_XDNN_SUCCESS(status, "triu_op");
}
} // namespace phi
// Register phi::TrilTriuGradKernel as the XPU kernel "tril_triu_grad" for int
// and float.
PD_REGISTER_KERNEL(
tril_triu_grad, XPU, ALL_LAYOUT, phi::TrilTriuGradKernel, int, float) {}
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/tril_triu_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
// tril_triu on XPU: allocates out and fills it by applying xpu::tril (when
// `lower` is true) or xpu::triu (otherwise) to x with the given `diagonal`.
template <typename T, typename Context>
void TrilTriuKernel(const Context& ctx,
                    const DenseTensor& x,
                    int diagonal,
                    bool lower,
                    DenseTensor* out) {
  using XPUType = typename XPUTypeTrait<T>::Type;
  ctx.template Alloc<T>(out);

  // Host-side copy of the input shape, as int.
  const auto shape = vectorize<int>(x.dims());
  const auto* src = reinterpret_cast<const XPUType*>(x.data<T>());
  auto* dst = reinterpret_cast<XPUType*>(out->data<T>());

  if (lower) {
    const int status = xpu::tril(ctx.x_context(), src, dst, shape, diagonal);
    PADDLE_ENFORCE_XDNN_SUCCESS(status, "tril_op");
    return;
  }
  const int status = xpu::triu(ctx.x_context(), src, dst, shape, diagonal);
  PADDLE_ENFORCE_XDNN_SUCCESS(status, "triu_op");
}
} // namespace phi
// Register phi::TrilTriuKernel as the XPU kernel "tril_triu" for int and
// float.
PD_REGISTER_KERNEL(
tril_triu, XPU, ALL_LAYOUT, phi::TrilTriuKernel, int, float) {}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册