Unverified commit 1ef0de81, authored by zhangyikun02, committed by GitHub

[XPU] silu op support to use fast_swish (#53980)

Parent 3aa5d64e
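The change gates the XPU silu forward and backward kernels on the XPU_PADDLE_ACT_LUT environment variable: when the variable is set, the LUT-based xpu::fast_swish / xpu::fast_swish_grad kernels are called; otherwise the existing xpu::swish / xpu::swish_grad path is used as before. Below is a minimal, self-contained sketch of that dispatch pattern; the run_* functions are hypothetical stand-ins for the real XDNN calls, which additionally need an XPU context and device buffers.

#include <cstdlib>
#include <iostream>

// Hypothetical stand-ins for the real XDNN kernels (xpu::fast_swish and
// xpu::swish); they only show which branch would be taken.
int run_fast_swish() { std::cout << "LUT path: xpu::fast_swish\n"; return 0; }
int run_swish() { std::cout << "default path: xpu::swish\n"; return 0; }

int main() {
  // Same gating as the patched functors: the fast LUT-based kernel is used
  // only when XPU_PADDLE_ACT_LUT is present in the environment.
  int r = (std::getenv("XPU_PADDLE_ACT_LUT") != nullptr) ? run_fast_swish()
                                                         : run_swish();
  return r;
}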
...@@ -378,10 +378,16 @@ struct XPUSiluGradFunctor : public funcs::BaseActivationFunctor<T> {
const XPUType* y_grad = reinterpret_cast<const XPUType*>(dout->data<T>());
XPUType* x_grad = reinterpret_cast<XPUType*>(dx->data<T>());
if (std::getenv("XPU_PADDLE_ACT_LUT") != nullptr) {
int r = xpu::fast_swish_grad(
dev_ctx.x_context(), x_data, y_grad, x_grad, dx->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_swish_grad");
} else {
int r = xpu::swish_grad(
dev_ctx.x_context(), x_data, y_grad, x_grad, dx->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish_grad");
}
}
};
template <typename T>
...
...@@ -333,10 +333,16 @@ struct XPUSiluFunctor : public funcs::BaseActivationFunctor<T> {
XPUType* y_data = reinterpret_cast<XPUType*>(out->data<T>());
auto xpu_context = dev_ctx.x_context();
if (std::getenv("XPU_PADDLE_ACT_LUT") != nullptr) {
int r = xpu::fast_swish(
xpu_context, x_data, y_data, x.numel(), nullptr, nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_swish");
} else {
int r =
xpu::swish(xpu_context, x_data, y_data, x.numel(), nullptr, nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish");
}
}
};
template <typename T>
...
...@@ -17,6 +17,8 @@ import unittest
sys.path.append('../../python/paddle/fluid/tests/unittests')
import os
import numpy as np
from eager_op_test import OpTest
from get_test_cover_info import (
...@@ -105,16 +107,38 @@ class XPUTestSiluOP(XPUOpTestWrapper):
self.outputs = {'Out': out}
self.attrs = {'use_xpu': True}
def test_check_output(self):
self.set_env()
self.check_output_with_place(self.place)
self.delete_env()
def test_check_grad(self):
self.set_env()
self.check_grad_with_place(self.place, ['X'], 'Out')
self.delete_env()
def init_shape(self):
self.shape = [11, 17]
def set_env(self):
pass
def delete_env(self):
pass
class TestSilu_ZeroDim(XPUTestSilu):
def init_shape(self):
self.shape = []
class TestSilu_LUT(XPUTestSilu):
def set_env(self):
# set "XPU_PADDLE_ACT_LUT" env to enable lut
os.environ['XPU_PADDLE_ACT_LUT'] = "1"
def delete_env(self):
if os.getenv('XPU_PADDLE_ACT_LUT'):
del os.environ['XPU_PADDLE_ACT_LUT']
class TestSiluAPI(unittest.TestCase):
# test paddle.nn.Silu, paddle.nn.functional.silu
...