未验证 提交 1ef0de81 编写于 作者: Z zhangyikun02 提交者: GitHub

[XPU] silu op support to use fast_swish (#53980)

上级 3aa5d64e
......@@ -378,10 +378,16 @@ struct XPUSiluGradFunctor : public funcs::BaseActivationFunctor<T> {
const XPUType* y_grad = reinterpret_cast<const XPUType*>(dout->data<T>());
XPUType* x_grad = reinterpret_cast<XPUType*>(dx->data<T>());
if (std::getenv("XPU_PADDLE_ACT_LUT") != nullptr) {
int r = xpu::fast_swish_grad(
dev_ctx.x_context(), x_data, y_grad, x_grad, dx->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_swish_grad");
} else {
int r = xpu::swish_grad(
dev_ctx.x_context(), x_data, y_grad, x_grad, dx->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish_grad");
}
}
};
template <typename T>
......
......@@ -333,10 +333,16 @@ struct XPUSiluFunctor : public funcs::BaseActivationFunctor<T> {
XPUType* y_data = reinterpret_cast<XPUType*>(out->data<T>());
auto xpu_context = dev_ctx.x_context();
if (std::getenv("XPU_PADDLE_ACT_LUT") != nullptr) {
int r = xpu::fast_swish(
xpu_context, x_data, y_data, x.numel(), nullptr, nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "fast_swish");
} else {
int r =
xpu::swish(xpu_context, x_data, y_data, x.numel(), nullptr, nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "swish");
}
}
};
template <typename T>
......
......@@ -17,6 +17,8 @@ import unittest
sys.path.append('../../python/paddle/fluid/tests/unittests')
import os
import numpy as np
from eager_op_test import OpTest
from get_test_cover_info import (
......@@ -105,16 +107,38 @@ class XPUTestSiluOP(XPUOpTestWrapper):
self.outputs = {'Out': out}
self.attrs = {'use_xpu': True}
def test_check_output(self):
self.set_env()
self.check_output_with_place(self.place)
self.delete_env()
def test_check_grad(self):
self.set_env()
self.check_grad_with_place(self.place, ['X'], 'Out')
self.delete_env()
def init_shape(self):
self.shape = [11, 17]
def set_env(self):
pass
def delete_env(self):
pass
class TestSilu_ZeroDim(XPUTestSilu):
def init_shape(self):
self.shape = []
class TestSilu_LUT(XPUTestSilu):
def set_env(self):
# set "XPU_PADDLE_ACT_LUT" env to enable lut
os.environ['XPU_PADDLE_ACT_LUT'] = "1"
def delete_env(self):
if os.getenv('XPU_PADDLE_ACT_LUT'):
del os.environ['XPU_PADDLE_ACT_LUT']
class TestSiluAPI(unittest.TestCase):
# test paddle.nn.Silu, paddle.nn.functional.silu
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册