未验证 提交 c4d5a77f 编写于 作者: Z zhangyikun02 提交者: GitHub

concat and relu support FP16 in XPU, test=kunlun (#41631)

上级 468c1ad7
...@@ -490,7 +490,6 @@ REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, XPULeakyReluFunctor, ...@@ -490,7 +490,6 @@ REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, XPULeakyReluFunctor,
XPULeakyReluGradFunctor) XPULeakyReluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(reciprocal, XPUReciprocalFunctor, REGISTER_ACTIVATION_XPU_KERNEL(reciprocal, XPUReciprocalFunctor,
XPUReciprocalGradFunctor) XPUReciprocalGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(relu, XPUReluFunctor, XPUReluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor, REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor,
XPUSigmoidGradFunctor) XPUSigmoidGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
...@@ -500,6 +499,13 @@ REGISTER_ACTIVATION_XPU_KERNEL(softplus, XPUSoftPlusFunctor, ...@@ -500,6 +499,13 @@ REGISTER_ACTIVATION_XPU_KERNEL(softplus, XPUSoftPlusFunctor,
REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor) REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor)
REGISTER_OP_XPU_KERNEL(
relu, ops::XPUActivationKernel<ops::XPUReluFunctor<float>>,
ops::XPUActivationKernel<ops::XPUReluFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
relu_grad, ops::XPUActivationGradKernel<ops::XPUReluGradFunctor<float>>,
ops::XPUActivationGradKernel<
ops::XPUReluGradFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
tanh, ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>, tanh, ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>,
ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>); ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>);
......
...@@ -26,6 +26,8 @@ using Tensor = framework::Tensor; ...@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ConcatXPUKernel : public framework::OpKernel<T> { class ConcatXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X"); auto ins = ctx.MultiInput<framework::LoDTensor>("X");
...@@ -79,10 +81,10 @@ class ConcatXPUKernel : public framework::OpKernel<T> { ...@@ -79,10 +81,10 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace(); auto place = ctx.GetPlace();
out->mutable_data<T>(place); out->mutable_data<T>(place);
std::vector<std::vector<int>> xdims_list; std::vector<std::vector<int>> xdims_list;
std::vector<const T*> ptrs; std::vector<const XPUType*> ptrs;
for (unsigned int i = 0; i < ins.size(); ++i) { for (unsigned int i = 0; i < ins.size(); ++i) {
if (ins[i] && ins[i]->numel() > 0) { if (ins[i] && ins[i]->numel() > 0) {
ptrs.push_back(ins[i]->data<T>()); ptrs.push_back(reinterpret_cast<const XPUType*>(ins[i]->data<T>()));
int size = ins[i]->dims().size(); int size = ins[i]->dims().size();
std::vector<int> tmp_dims(size); std::vector<int> tmp_dims(size);
for (int j = 0; j < size; ++j) { for (int j = 0; j < size; ++j) {
...@@ -96,8 +98,9 @@ class ConcatXPUKernel : public framework::OpKernel<T> { ...@@ -96,8 +98,9 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
"No tensor need concat")); "No tensor need concat"));
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::concat<T>(dev_ctx.x_context(), ptrs, out->data<T>(), int r = xpu::concat<XPUType>(dev_ctx.x_context(), ptrs,
xdims_list, axis); reinterpret_cast<XPUType*>(out->data<T>()),
xdims_list, axis);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
"XPU concat kernel return wrong value[%d %s]", r, "XPU concat kernel return wrong value[%d %s]", r,
...@@ -107,6 +110,8 @@ class ConcatXPUKernel : public framework::OpKernel<T> { ...@@ -107,6 +110,8 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ConcatGradXPUKernel : public framework::OpKernel<T> { class ConcatGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& ctx) const { void Compute(const framework::ExecutionContext& ctx) const {
auto* out_grad = auto* out_grad =
...@@ -134,12 +139,12 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> { ...@@ -134,12 +139,12 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
axis = ComputeAxis(static_cast<int64_t>(axis), axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size())); static_cast<int64_t>(ins[0]->dims().size()));
// get output tensor that the name is not kEmptyVarName // get output tensor that the name is not kEmptyVarName
std::vector<T*> ptrs(outs.size()); std::vector<XPUType*> ptrs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) { for (size_t j = 0; j < outs.size(); ++j) {
if (out_var_names[j] != framework::kEmptyVarName && if (out_var_names[j] != framework::kEmptyVarName &&
outs[j]->numel() != 0UL) { outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace()); outs[j]->mutable_data<T>(ctx.GetPlace());
ptrs[j] = outs[j]->data<T>(); ptrs[j] = reinterpret_cast<XPUType*>(outs[j]->data<T>());
} else { } else {
ptrs[j] = nullptr; ptrs[j] = nullptr;
} }
...@@ -173,8 +178,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> { ...@@ -173,8 +178,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
xdims_list[axis] = total_length; xdims_list[axis] = total_length;
auto& dev_ctx = ctx.template device_context<DeviceContext>(); auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::split<T>(dev_ctx.x_context(), out_grad->data<T>(), ptrs, int r = xpu::split<XPUType>(
xdims_list, split_list, axis); dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad->data<T>()), ptrs, xdims_list,
split_list, axis);
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS, r, XPU_SUCCESS,
platform::errors::External( platform::errors::External(
...@@ -189,9 +196,13 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> { ...@@ -189,9 +196,13 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>); concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL( REGISTER_OP_XPU_KERNEL(
concat_grad, concat_grad,
ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>); ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif #endif
...@@ -56,8 +56,10 @@ XPUOpMap& get_kl2_ops() { ...@@ -56,8 +56,10 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})}, pOpKernelType(vartype::INT32, XPUPlace())})},
{"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()), {"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
...@@ -288,8 +290,10 @@ XPUOpMap& get_kl2_ops() { ...@@ -288,8 +290,10 @@ XPUOpMap& get_kl2_ops() {
{"reduce_sum_grad", {"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP16, XPUPlace())})},
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"reshape2_grad", {"reshape2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册