未验证 提交 442688a8 编写于 作者: T taixiurong 提交者: GitHub

add some ops support fp16 in kunlun2 (#36854)

* aaaa

* add some ops support fp16 in kunlun2
上级 113816d8
......@@ -53,14 +53,14 @@ class XPUActivationGradKernel
}
};
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T, typename XPUT>
void xpu_activation_forward(
const framework::ExecutionContext &ctx,
std::function<int(xpu::Context *, const T *, T *, int)> func) {
std::function<int(xpu::Context *, const XPUT *, XPUT *, int)> func) {
const auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Output<Tensor>("Out");
const T *x_data = x->data<T>();
T *y_data = y->mutable_data<T>(ctx.GetPlace());
const XPUT *x_data = reinterpret_cast<const XPUT *>(x->data<T>());
XPUT *y_data = reinterpret_cast<XPUT *>(y->mutable_data<T>(ctx.GetPlace()));
auto xpu_context = ctx.device_context<DeviceContext>().x_context();
int r = func(xpu_context, x_data, y_data, x->numel());
......@@ -70,23 +70,24 @@ void xpu_activation_forward(
r, XPUAPIErrorMsg[r]));
}
template <typename DeviceContext, typename T>
void xpu_activation_backward(const framework::ExecutionContext &ctx,
std::function<int(xpu::Context *, const T *,
const T *, const T *, T *, int)>
func) {
template <typename DeviceContext, typename T, typename XPUT>
void xpu_activation_backward(
const framework::ExecutionContext &ctx,
std::function<int(xpu::Context *, const XPUT *, const XPUT *, const XPUT *,
XPUT *, int)>
func) {
/* TODO: relu tanh sigmoid are inplace */
const auto *x = ctx.Input<Tensor>("X");
auto *y = ctx.Input<Tensor>("Out");
auto *dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto *dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
const T *x_data = nullptr;
const T *y_data = nullptr;
const T *y_grad = nullptr;
if (x != nullptr) x_data = x->data<T>();
if (y != nullptr) y_data = y->data<T>();
if (dOut != nullptr) y_grad = dOut->data<T>();
T *x_grad = dX->mutable_data<T>(ctx.GetPlace());
const XPUT *x_data = nullptr;
const XPUT *y_data = nullptr;
const XPUT *y_grad = nullptr;
if (x != nullptr) x_data = reinterpret_cast<const XPUT *>(x->data<T>());
if (y != nullptr) y_data = reinterpret_cast<const XPUT *>(y->data<T>());
if (dOut != nullptr) y_grad = reinterpret_cast<const XPUT *>(dOut->data<T>());
XPUT *x_grad = reinterpret_cast<XPUT *>(dX->mutable_data<T>(ctx.GetPlace()));
auto xpu_context = ctx.device_context<DeviceContext>().x_context();
int r = func(xpu_context, x_data, y_data, y_grad, x_grad, dX->numel());
......@@ -98,65 +99,64 @@ void xpu_activation_backward(const framework::ExecutionContext &ctx,
template <typename T>
struct XPUReluFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(ctx,
xpu::relu<T>);
xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::relu<XPUType>);
}
};
template <typename T>
struct XPUSigmoidFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::sigmoid<T>);
xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::sigmoid<XPUType>);
}
};
template <typename T>
struct XPUTanhFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(ctx,
xpu::tanh<T>);
}
};
template <typename T>
struct XPUGeluFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(ctx,
xpu::gelu<T>);
xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::tanh<XPUType>);
}
};
template <typename T>
struct XPULogFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(ctx,
xpu::log<T>);
xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::log<XPUType>);
}
};
template <typename T>
struct XPUSquareFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::square<T>);
xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::square<XPUType>);
}
};
template <typename T>
struct XPUSqrtFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(ctx,
xpu::sqrt<T>);
xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::sqrt<XPUType>);
}
};
template <typename T>
struct XPUAbsFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(ctx,
xpu::abs<T>);
xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::abs<XPUType>);
}
};
......@@ -196,6 +196,7 @@ struct XPUPowFunctor : public BaseActivationFunctor<T> {
template <typename T>
struct XPUHardSwishFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
float threshold = ctx.Attr<float>("threshold");
float scale = ctx.Attr<float>("scale");
......@@ -208,61 +209,59 @@ struct XPUHardSwishFunctor : public BaseActivationFunctor<T> {
PADDLE_ENFORCE_EQ(
offset, 3.0f,
platform::errors::External("Not support offset [%f] in XPU", offset));
xpu_activation_forward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::hard_swish<T>);
xpu_activation_forward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::hard_swish<XPUType>);
}
};
template <typename T>
struct XPUReluGradFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_backward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::relu_grad<T>);
xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::relu_grad<XPUType>);
}
};
template <typename T>
struct XPUTanhGradFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_backward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::tanh_grad<T>);
xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::tanh_grad<XPUType>);
}
};
template <typename T>
struct XPUSigmoidGradFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_backward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::sigmoid_grad<T>);
}
};
template <typename T>
struct XPUGeluGradFunctor : public BaseActivationFunctor<T> {
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_backward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::gelu_grad<T>);
xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::sigmoid_grad<XPUType>);
}
};
template <typename T>
struct XPUSqrtGradFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_backward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::sqrt_grad<T>);
xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::sqrt_grad<XPUType>);
}
};
template <typename T>
struct XPUSquareGradFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
xpu_activation_backward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::square_grad<T>);
xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::square_grad<XPUType>);
}
};
template <typename T>
struct XPUHardSwishGradFunctor : public BaseActivationFunctor<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
void operator()(const framework::ExecutionContext &ctx) const {
float threshold = ctx.Attr<float>("threshold");
float scale = ctx.Attr<float>("scale");
......@@ -275,8 +274,8 @@ struct XPUHardSwishGradFunctor : public BaseActivationFunctor<T> {
PADDLE_ENFORCE_EQ(
offset, 3.0f,
platform::errors::External("Not support offset [%f] in XPU", offset));
xpu_activation_backward<paddle::platform::XPUDeviceContext, T>(
ctx, xpu::hard_swish_grad<T>);
xpu_activation_backward<paddle::platform::XPUDeviceContext, T, XPUType>(
ctx, xpu::hard_swish_grad<XPUType>);
}
};
......@@ -342,16 +341,23 @@ namespace ops = paddle::operators;
ops::XPUActivationGradKernel<ops::grad_functor<float>>);
REGISTER_ACTIVATION_XPU_KERNEL(relu, XPUReluFunctor, XPUReluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(tanh, XPUTanhFunctor, XPUTanhGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor,
XPUSigmoidGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(gelu, XPUGeluFunctor, XPUGeluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSquareGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(hard_swish, XPUHardSwishFunctor,
XPUHardSwishGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, XPULeakyReluFunctor,
XPULeakyReluGradFunctor)
REGISTER_OP_XPU_KERNEL(
tanh, ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>,
ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
tanh_grad, ops::XPUActivationGradKernel<ops::XPUTanhGradFunctor<float>>,
ops::XPUActivationGradKernel<
ops::XPUTanhGradFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(log,
ops::XPUActivationKernel<ops::XPULogFunctor<float>>);
REGISTER_OP_XPU_KERNEL(pow,
......
......@@ -74,27 +74,15 @@ class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
platform::errors::External("XPU API(logical_not) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::isnan(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(x->data<T>()),
is_nan.data<bool>(), x->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(isnan) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::logical_or(dev_ctx.x_context(), is_finite.data<bool>(),
is_nan.data<bool>(), is_finite.data<bool>(),
x->numel());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(logical_or) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::any(dev_ctx.x_context(), is_finite.data<bool>(),
found_inf_data, x->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(any) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
memory::Copy(platform::CPUPlace(), &cpu_found_inf_data,
BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()),
found_inf_data, sizeof(bool));
......@@ -103,12 +91,12 @@ class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
if (cpu_found_inf_data) {
inverse_scale = 0.0;
}
auto dev_env = XPUEnv::getenv("XPUSIM_DEVICE_MODEL");
paddle::platform::XPUVersion version = dev_ctx.xpu_version();
framework::Tensor float_x;
framework::Tensor float_out;
if (std::is_same<T, paddle::platform::float16>::value &&
(dev_env == nullptr || std::strcmp(dev_env, "KUNLUN1"))) {
framework::Tensor float_x;
framework::Tensor float_out;
(version == paddle::platform::XPUVersion::XPU1)) {
float_x.mutable_data<MPDType>(dev_ctx.GetPlace(),
x->numel() * sizeof(MPDType));
float_out.mutable_data<MPDType>(dev_ctx.GetPlace(),
......@@ -137,10 +125,6 @@ class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
"XPU API(cast_v2) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
} else {
int r = xpu::scale(dev_ctx.x_context(),
reinterpret_cast<const XPUTyp*>(x->data<T>()),
......@@ -152,6 +136,9 @@ class CheckFiniteAndUnscaleXPUKernel : public framework::OpKernel<T> {
r, XPUAPIErrorMsg[r]));
}
}
if (dev_ctx.x_context()->xpu_stream) {
dev_ctx.Wait();
}
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()),
found_inf_data, platform::CPUPlace(), &cpu_found_inf_data,
sizeof(bool));
......
......@@ -113,10 +113,9 @@ class UpdateLossScalingXPUKernel : public framework::OpKernel<T> {
} else {
cpu_pre_loss_scaling_data = (*pre_loss_scaling_data);
}
int cpu_good_out_data = 0;
int cpu_bad_out_data = 0;
MPDType cpu_updated_loss_scaling_data;
MPDType cpu_updated_loss_scaling_data = cpu_pre_loss_scaling_data;
if (cpu_found_inf_data) {
cpu_good_out_data = 0;
......@@ -140,8 +139,7 @@ class UpdateLossScalingXPUKernel : public framework::OpKernel<T> {
cpu_good_out_data = 0;
}
}
// copy to host
// copy to device
memory::Copy(BOOST_GET_CONST(platform::XPUPlace, dev_ctx.GetPlace()),
bad_out_data, platform::CPUPlace(), &cpu_bad_out_data,
sizeof(int));
......
......@@ -17,8 +17,11 @@ namespace ops = paddle::operators;
#ifdef PADDLE_WITH_XPU
REGISTER_OP_XPU_KERNEL(
fill_constant, ops::FillConstantKernel<float>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<double>,
ops::FillConstantKernel<bool>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<double>, ops::FillConstantKernel<uint8_t>,
ops::FillConstantKernel<int16_t>, ops::FillConstantKernel<int>,
ops::FillConstantKernel<int64_t>, ops::FillConstantKernel<bool>,
ops::FillConstantKernel<paddle::platform::float16>,
ops::FillConstantKernel<paddle::platform::bfloat16>,
ops::FillConstantKernel<paddle::platform::complex<float>>,
ops::FillConstantKernel<paddle::platform::complex<double>>);
#endif
......@@ -24,6 +24,8 @@ namespace operators {
template <typename T>
class GatherOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
......@@ -63,13 +65,16 @@ class GatherOpXPUKernel : public framework::OpKernel<T> {
auto &dev_ctx = ctx.template device_context<platform::XPUDeviceContext>();
int r = XPU_SUCCESS;
if (index->type() == framework::proto::VarType::INT32) {
r = xpu::gather<T, int>(dev_ctx.x_context(), x->data<T>(),
index->data<int>(), output->data<T>(), xshape,
index->dims()[0], 0);
r = xpu::gather<XPUType, int>(
dev_ctx.x_context(), reinterpret_cast<const XPUType *>(x->data<T>()),
index->data<int>(), reinterpret_cast<XPUType *>(output->data<T>()),
xshape, index->dims()[0], 0);
} else {
r = xpu::gather<T, int64_t>(dev_ctx.x_context(), x->data<T>(),
index->data<int64_t>(), output->data<T>(),
xshape, index->dims()[0], 0);
r = xpu::gather<XPUType, int64_t>(
dev_ctx.x_context(), reinterpret_cast<const XPUType *>(x->data<T>()),
index->data<int64_t>(),
reinterpret_cast<XPUType *>(output->data<T>()), xshape,
index->dims()[0], 0);
}
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
......@@ -80,6 +85,8 @@ class GatherOpXPUKernel : public framework::OpKernel<T> {
template <typename T>
class GatherGradOpXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext &ctx) const override {
PADDLE_ENFORCE_EQ(
......@@ -123,13 +130,28 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
int r = XPU_SUCCESS;
if (index->type() == framework::proto::VarType::INT32) {
r = xpu::gather_grad<T, int>(dev_ctx.x_context(), dout->data<T>(),
index->data<int>(), dx->data<T>(), xshape,
index->dims()[0], 0, overwrite);
r = xpu::gather_grad<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(dout->data<T>()),
index->data<int>(), reinterpret_cast<XPUType *>(dx->data<T>()),
xshape, index->dims()[0], 0, overwrite);
} else {
r = xpu::gather_grad<T, int64_t>(dev_ctx.x_context(), dout->data<T>(),
index->data<int64_t>(), dx->data<T>(),
xshape, index->dims()[0], 0, overwrite);
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
int *index_int_ptr_l3 =
RAII_GUARD.alloc_l3_or_gm<int32_t>(index->numel());
r = xpu::cast_v2<int64_t, int32_t>(dev_ctx.x_context(),
index->data<int64_t>(),
index_int_ptr_l3, index->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, platform::errors::External(
"XPU API(cast_v2) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::gather_grad<XPUType, int>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType *>(dout->data<T>()), index_int_ptr_l3,
reinterpret_cast<XPUType *>(dx->data<T>()), xshape, index->dims()[0],
0, overwrite);
}
PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
platform::errors::External(
......@@ -142,6 +164,8 @@ class GatherGradOpXPUKernel : public framework::OpKernel<T> {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(gather, ops::GatherOpXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(gather_grad, ops::GatherGradOpXPUKernel<float>);
REGISTER_OP_XPU_KERNEL(gather, ops::GatherOpXPUKernel<float>,
ops::GatherOpXPUKernel<paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(gather_grad, ops::GatherGradOpXPUKernel<float>,
ops::GatherGradOpXPUKernel<paddle::platform::float16>);
#endif
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/gelu_op.h"
namespace paddle {
namespace operators {
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class GeluXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* out = ctx.Output<Tensor>("Out");
auto place = ctx.GetPlace();
const XPUType* x_data = reinterpret_cast<const XPUType*>(x->data<T>());
XPUType* y_data = reinterpret_cast<XPUType*>(out->mutable_data<T>(place));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::gelu<XPUType>(dev_ctx.x_context(), x_data, y_data, x->numel());
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU gelu kernel return wrong value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
};
template <typename DeviceContext, typename T>
class GeluGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<Tensor>("X");
auto* dout = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
auto place = ctx.GetPlace();
const XPUType* x_data = reinterpret_cast<const XPUType*>(x->data<T>());
const XPUType* dout_data =
reinterpret_cast<const XPUType*>(dout->data<T>());
XPUType* dx_data = reinterpret_cast<XPUType*>(dx->mutable_data<T>(place));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::gelu_grad<XPUType>(dev_ctx.x_context(), x_data, nullptr,
dout_data, dx_data, dout->numel());
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU gelu_grad kernel return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
gelu, ops::GeluXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GeluXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
gelu_grad,
ops::GeluGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::GeluGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
......@@ -85,9 +85,10 @@ class SoftmaxOp : public framework::OperatorWithKernel {
#ifndef PADDLE_WITH_ASCEND_CL
if (input_data_type == framework::proto::VarType::FP16) {
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()), true,
platform::errors::InvalidArgument(
"float16 can only be used on GPU place"));
PADDLE_ENFORCE_EQ(platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace()),
true, platform::errors::InvalidArgument(
"float16 can only be used on GPU/XPU place"));
}
#endif
......@@ -214,9 +215,10 @@ class SoftmaxOpGrad : public framework::OperatorWithKernel {
#endif
if (input_data_type == framework::proto::VarType::FP16) {
if (!(platform::is_gpu_place(ctx.GetPlace()) ||
platform::is_npu_place(ctx.GetPlace())))
platform::is_npu_place(ctx.GetPlace()) ||
platform::is_xpu_place(ctx.GetPlace())))
PADDLE_THROW(platform::errors::InvalidArgument(
"float16 can only be used on GPU/NPU place"));
"float16 can only be used on GPU/NPU/XPU place"));
}
return framework::OpKernelType(input_data_type, ctx.GetPlace(), layout_,
......
......@@ -22,6 +22,8 @@ using DDim = framework::DDim;
template <typename DeviceContext, typename T>
class SoftmaxXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* x = context.Input<Tensor>("X");
......@@ -43,29 +45,43 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = XPU_SUCCESS;
Tensor clip_x;
int len = x->numel();
T* clip_x_data =
clip_x.mutable_data<T>(context.GetPlace(), len * sizeof(T));
r = xpu::clip_v2(dev_ctx.x_context(), x->data<float>(), clip_x_data, len,
static_cast<float>(-1e20), static_cast<float>(1e20));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External("XPU API(clip) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
r = xpu::softmax<T>(dev_ctx.x_context(), clip_x_data, out->data<float>(),
x_dims, axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(softmax2d_forward) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
paddle::platform::XPUVersion version = dev_ctx.xpu_version();
if (version == paddle::platform::XPUVersion::XPU1) {
xpu::ctx_guard RAII_GUARD(dev_ctx.x_context());
XPUType* clip_x_data_l3 = RAII_GUARD.alloc_l3_or_gm<XPUType>(x->numel());
r = xpu::clip_v2(dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(x->data<T>()),
clip_x_data_l3, x->numel(), static_cast<XPUType>(-1e20),
static_cast<XPUType>(1e20));
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU API(clip_v2) return wrong value[%d %s]", r,
XPUAPIErrorMsg[r]));
r = xpu::softmax<XPUType>(dev_ctx.x_context(), clip_x_data_l3,
reinterpret_cast<XPUType*>(out->data<T>()),
x_dims, axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(softmax2d_forward) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
} else {
r = xpu::softmax<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(x->data<T>()),
reinterpret_cast<XPUType*>(out->data<T>()), x_dims, axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(softmax2d_forward) return wrong "
"value[%d %s]",
r, XPUAPIErrorMsg[r]));
}
}
};
template <typename DeviceContext, typename T>
class SoftmaxGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* out = context.Input<Tensor>("Out");
......@@ -86,9 +102,10 @@ class SoftmaxGradXPUKernel : public framework::OpKernel<T> {
}
auto& dev_ctx = context.template device_context<DeviceContext>();
int r = xpu::softmax_grad<T>(dev_ctx.x_context(), out->data<float>(),
dout->data<float>(), dx->data<float>(), x_dims,
axis);
int r = xpu::softmax_grad<XPUType>(
dev_ctx.x_context(), reinterpret_cast<const XPUType*>(out->data<T>()),
reinterpret_cast<const XPUType*>(dout->data<T>()),
reinterpret_cast<XPUType*>(dx->data<T>()), x_dims, axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External("XPU API(softmax2d_backward) return wrong "
......@@ -103,9 +120,13 @@ class SoftmaxGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
softmax, ops::SoftmaxXPUKernel<paddle::platform::XPUDeviceContext, float>);
softmax, ops::SoftmaxXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::SoftmaxXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
softmax_grad,
ops::SoftmaxGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::SoftmaxGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::SoftmaxGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif // PADDLE_WITH_XPU
......@@ -186,7 +186,36 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::FP32, XPUPlace())})},
{"scale", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace())})}
pOpKernelType(vartype::INT64, XPUPlace())})},
{"tanh", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"tanh_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gelu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gelu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gather", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"gather_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"fill_constant",
XPUKernelSet({pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT16, XPUPlace()),
pOpKernelType(vartype::INT8, XPUPlace()),
pOpKernelType(vartype::BOOL, XPUPlace()),
pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::BF16, XPUPlace()),
pOpKernelType(vartype::COMPLEX64, XPUPlace()),
pOpKernelType(vartype::COMPLEX128, XPUPlace())})},
{"softmax", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"softmax_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}
// AddMore
};
......
......@@ -19,6 +19,7 @@
#include <string>
#include <unordered_map>
#include "paddle/fluid/platform/bfloat16.h"
#include "paddle/fluid/platform/errors.h"
#include "paddle/fluid/platform/float16.h"
#include "xpu/runtime.h"
......@@ -68,4 +69,10 @@ class XPUTypeTrait<paddle::platform::float16> {
using Type = float16;
};
template <>
class XPUTypeTrait<paddle::platform::bfloat16> {
public:
using Type = bfloat16;
};
#endif
......@@ -89,6 +89,8 @@ class XPUOpTest(OpTest):
if self.dtype == np.float16:
if core.is_float16_supported(place) == False:
return
if self.dtype == np.float16:
atol = 0.1
return super().check_output_with_place(
place, atol, no_check_set, equal_nan, check_dygraph, inplace_atol)
......@@ -115,6 +117,7 @@ class XPUOpTest(OpTest):
return
if self.dtype == np.float16:
max_relative_error = 1.0
return super().check_grad_with_place(
place, inputs_to_check, output_names, no_grad_set,
numeric_grad_delta, in_place, max_relative_error,
......
......@@ -95,6 +95,26 @@ class TestXPUTanh(TestXPUActivation):
self.check_grad_with_place(place, ['X'], 'Out')
class TestXPUTanhFP16(TestXPUActivation):
def setUp(self):
self.op_type = "tanh"
self.init_dtype()
x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype)
out = np.tanh(x)
self.attrs = {'use_xpu': True}
self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)}
self.outputs = {'Out': out}
def init_dtype(self):
self.dtype = np.float16
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPUSqrt(TestXPUActivation):
......@@ -177,6 +197,27 @@ class TestXPUGelu(TestXPUActivation):
self.check_grad_with_place(place, ['X'], 'Out')
class TestXPUGelu(TestXPUActivation):
def setUp(self):
self.op_type = "gelu"
self.init_dtype()
approximate = False
x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype)
out = gelu(x, approximate)
self.inputs = {'X': x}
self.outputs = {'Out': out}
self.attrs = {"approximate": approximate, 'use_xpu': True}
def init_dtype(self):
self.dtype = np.float16
def test_check_grad(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
def gelu(x, approximate):
if approximate:
y_ref = 0.5 * x * (1.0 + np.tanh(
......
......@@ -36,7 +36,6 @@ def gather_numpy(x, index, axis):
class TestXPUGatherOp(XPUOpTest):
def setUp(self):
self.dtype = "float32"
self.op_type = "gather"
self.use_xpu = True
self.use_mkldnn = False
......@@ -50,6 +49,16 @@ class TestXPUGatherOp(XPUOpTest):
}
self.outputs = {'Out': self.inputs["X"][self.inputs["Index"]]}
def config(self):
"""
For multi-dimension input
"""
self.dtype = np.float32
self.x_shape = (10, 20)
self.x_type = np.float32
self.index = [1, 3, 5]
self.index_type = np.int32
def test_check_output(self):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
......@@ -60,25 +69,17 @@ class TestXPUGatherOp(XPUOpTest):
place = paddle.XPUPlace(0)
self.check_grad_with_place(place, ['X'], 'Out')
def config(self):
"""
For multi-dimension input
"""
self.x_shape = (10, 20)
self.x_type = "float32"
self.index = [1, 3, 5]
self.index_type = "int32"
class TestCase1(TestXPUGatherOp):
def config(self):
"""
For one dimension input
"""
self.dtype = np.float32
self.x_shape = (100)
self.x_type = "float32"
self.x_type = np.float32
self.index = [1, 3, 5]
self.index_type = "int32"
self.index_type = np.int32
class TestCase2(TestXPUGatherOp):
......@@ -86,10 +87,11 @@ class TestCase2(TestXPUGatherOp):
"""
For int64_t index type
"""
self.dtype = np.float32
self.x_shape = (100)
self.x_type = "float32"
self.x_type = np.float32
self.index = [1, 3, 5]
self.index_type = "int32"
self.index_type = np.int64
class TestCase3(TestXPUGatherOp):
......@@ -97,46 +99,128 @@ class TestCase3(TestXPUGatherOp):
"""
For other input type
"""
self.dtype = np.float32
self.x_shape = (10, 20)
self.x_type = "float32"
self.x_type = np.float32
self.index = [1, 3, 5]
self.index_type = "int32"
self.index_type = np.int32
class TestCase4(TestXPUGatherOp):
def config(self):
self.dtype = np.float32
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': False}
self.x_type = "float32"
self.x_type = np.float32
self.index = [1, 1]
self.index_type = "int32"
self.index_type = np.int32
class TestCase5(TestXPUGatherOp):
def config(self):
self.dtype = np.float32
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': False}
self.x_type = "float32"
self.x_type = np.float32
self.index = [1, 1, 3]
self.index_type = "int32"
self.index_type = np.int32
class TestCase6(TestXPUGatherOp):
def config(self):
self.dtype = np.float32
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': True}
self.x_type = "float32"
self.x_type = np.float32
self.index = [1, 3]
self.index_type = "int32"
self.index_type = np.int32
class TestCase7(TestXPUGatherOp):
def config(self):
self.dtype = np.float32
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': True}
self.x_type = np.float32
self.index = [1, 3]
self.index_type = np.int64
## test fp16
class TestCaseFP161(TestXPUGatherOp):
def config(self):
"""
For one dimension input
"""
self.dtype = np.float16
self.x_shape = (100)
self.x_type = np.float16
self.index = [1, 3, 5]
self.index_type = np.int32
class TestCaseFP162(TestXPUGatherOp):
def config(self):
"""
For int64_t index type
"""
self.dtype = np.float16
self.x_shape = (100)
self.x_type = np.float16
self.index = [1, 3, 5]
self.index_type = np.int64
class TestCaseFP163(TestXPUGatherOp):
def config(self):
"""
For other input type
"""
self.dtype = np.float16
self.x_shape = (10, 20)
self.x_type = np.float16
self.index = [1, 3, 5]
self.index_type = np.int32
class TestCaseFP164(TestXPUGatherOp):
def config(self):
self.dtype = np.float16
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': False}
self.x_type = np.float16
self.index = [1, 1]
self.index_type = np.int32
class TestCaseFP165(TestXPUGatherOp):
def config(self):
self.dtype = np.float16
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': False}
self.x_type = np.float16
self.index = [1, 1, 3]
self.index_type = np.int32
class TestCaseFP166(TestXPUGatherOp):
def config(self):
self.dtype = np.float16
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': True}
self.x_type = np.float16
self.index = [1, 3]
self.index_type = np.int32
class TestCaseFP167(TestXPUGatherOp):
def config(self):
self.dtype = np.float16
self.x_shape = (10, 20)
self.attrs = {'use_xpu': True, 'overwrite': True}
self.x_type = "float32"
self.x_type = np.float16
self.index = [1, 3]
self.index_type = "int64"
self.index_type = np.int64
if __name__ == "__main__":
......
......@@ -17,8 +17,7 @@ import numpy as np
import sys
import unittest
sys.path.append("..")
from op_test import OpTest
from op_test_xpu import XPUOpTest
paddle.enable_static()
np.random.seed(10)
......@@ -41,15 +40,13 @@ def ref_softmax(x, axis=None, dtype=None):
return np.apply_along_axis(stable_softmax, axis, x_t)
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPUSoftmaxOp(OpTest):
class TestXPUSoftmaxOp(XPUOpTest):
def setUp(self):
self.op_type = "softmax"
self.dtype = np.float32
self.shape = [2, 3, 4, 5]
self.axis = -1
self.set_attrs()
self.init_type()
x = np.random.uniform(-1, 1, self.shape).astype(self.dtype)
out = np.apply_along_axis(stable_softmax, self.axis, x)
......@@ -58,6 +55,9 @@ class TestXPUSoftmaxOp(OpTest):
self.outputs = {'Out': out}
self.attrs = {'axis': self.axis, 'use_xpu': True}
def init_type(self):
self.dtype = np.float16
def set_attrs(self):
pass
......@@ -68,26 +68,35 @@ class TestXPUSoftmaxOp(OpTest):
self.check_grad_with_place(paddle.XPUPlace(0), ['X'], 'Out')
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPUSoftmaxAxis3(TestXPUSoftmaxOp):
def set_attrs(self):
self.axis = 3
# class TestXPUSoftmaxAxis3(TestXPUSoftmaxOp):
# def set_attrs(self):
# self.axis = 3
# class TestXPUSoftmax2D(TestXPUSoftmaxOp):
# def set_attrs(self):
# self.shape = [10, 12]
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPUSoftmax2D(TestXPUSoftmaxOp):
def set_attrs(self):
self.shape = [10, 12]
# class TestXPUSoftmax3D(TestXPUSoftmaxOp):
# def set_attrs(self):
# self.shape = [4, 5, 6]
# class TestXPUSoftmaxAxis3FP16(TestXPUSoftmaxOp):
# def set_attrs(self):
# self.axis = 3
# def init_type(self):
# self.dtype = np.float16
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
"core is not compiled with XPU")
class TestXPUSoftmax3D(TestXPUSoftmaxOp):
def set_attrs(self):
self.shape = [4, 5, 6]
# class TestXPUSoftmax2DFP16(TestXPUSoftmaxOp):
# def set_attrs(self):
# self.shape = [10, 12]
# def init_type(self):
# self.dtype = np.float16
# class TestXPUSoftmax3DFP16(TestXPUSoftmaxOp):
# def set_attrs(self):
# self.shape = [4, 5, 6]
# def init_type(self):
# self.dtype = np.float16
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册