未验证 提交 c4d5a77f 编写于 作者: Z zhangyikun02 提交者: GitHub

concat and relu sopport FP16 in XPU, test=kunlun (#41631)

上级 468c1ad7
......@@ -490,7 +490,6 @@ REGISTER_ACTIVATION_XPU_KERNEL(leaky_relu, XPULeakyReluFunctor,
XPULeakyReluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(reciprocal, XPUReciprocalFunctor,
XPUReciprocalGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(relu, XPUReluFunctor, XPUReluGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor,
XPUSigmoidGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor)
......@@ -500,6 +499,13 @@ REGISTER_ACTIVATION_XPU_KERNEL(softplus, XPUSoftPlusFunctor,
REGISTER_ACTIVATION_XPU_KERNEL(swish, XPUSwishFunctor, XPUSwishGradFunctor)
REGISTER_ACTIVATION_XPU_KERNEL(pow, XPUPowFunctor, XPUPowGradFunctor)
REGISTER_OP_XPU_KERNEL(
relu, ops::XPUActivationKernel<ops::XPUReluFunctor<float>>,
ops::XPUActivationKernel<ops::XPUReluFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
relu_grad, ops::XPUActivationGradKernel<ops::XPUReluGradFunctor<float>>,
ops::XPUActivationGradKernel<
ops::XPUReluGradFunctor<paddle::platform::float16>>);
REGISTER_OP_XPU_KERNEL(
tanh, ops::XPUActivationKernel<ops::XPUTanhFunctor<float>>,
ops::XPUActivationKernel<ops::XPUTanhFunctor<paddle::platform::float16>>);
......
......@@ -26,6 +26,8 @@ using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class ConcatXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const override {
auto ins = ctx.MultiInput<framework::LoDTensor>("X");
......@@ -79,10 +81,10 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
auto place = ctx.GetPlace();
out->mutable_data<T>(place);
std::vector<std::vector<int>> xdims_list;
std::vector<const T*> ptrs;
std::vector<const XPUType*> ptrs;
for (unsigned int i = 0; i < ins.size(); ++i) {
if (ins[i] && ins[i]->numel() > 0) {
ptrs.push_back(ins[i]->data<T>());
ptrs.push_back(reinterpret_cast<const XPUType*>(ins[i]->data<T>()));
int size = ins[i]->dims().size();
std::vector<int> tmp_dims(size);
for (int j = 0; j < size; ++j) {
......@@ -96,8 +98,9 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
"No tensor need concat"));
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::concat<T>(dev_ctx.x_context(), ptrs, out->data<T>(),
xdims_list, axis);
int r = xpu::concat<XPUType>(dev_ctx.x_context(), ptrs,
reinterpret_cast<XPUType*>(out->data<T>()),
xdims_list, axis);
PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
platform::errors::External(
"XPU concat kernel return wrong value[%d %s]", r,
......@@ -107,6 +110,8 @@ class ConcatXPUKernel : public framework::OpKernel<T> {
template <typename DeviceContext, typename T>
class ConcatGradXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& ctx) const {
auto* out_grad =
......@@ -134,12 +139,12 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
axis = ComputeAxis(static_cast<int64_t>(axis),
static_cast<int64_t>(ins[0]->dims().size()));
// get output tensor that the name is not kEmptyVarName
std::vector<T*> ptrs(outs.size());
std::vector<XPUType*> ptrs(outs.size());
for (size_t j = 0; j < outs.size(); ++j) {
if (out_var_names[j] != framework::kEmptyVarName &&
outs[j]->numel() != 0UL) {
outs[j]->mutable_data<T>(ctx.GetPlace());
ptrs[j] = outs[j]->data<T>();
ptrs[j] = reinterpret_cast<XPUType*>(outs[j]->data<T>());
} else {
ptrs[j] = nullptr;
}
......@@ -173,8 +178,10 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
xdims_list[axis] = total_length;
auto& dev_ctx = ctx.template device_context<DeviceContext>();
int r = xpu::split<T>(dev_ctx.x_context(), out_grad->data<T>(), ptrs,
xdims_list, split_list, axis);
int r = xpu::split<XPUType>(
dev_ctx.x_context(),
reinterpret_cast<const XPUType*>(out_grad->data<T>()), ptrs, xdims_list,
split_list, axis);
PADDLE_ENFORCE_EQ(
r, XPU_SUCCESS,
platform::errors::External(
......@@ -189,9 +196,13 @@ class ConcatGradXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>);
concat, ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::ConcatXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
REGISTER_OP_XPU_KERNEL(
concat_grad,
ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>);
ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext, float>,
ops::ConcatGradXPUKernel<paddle::platform::XPUDeviceContext,
paddle::platform::float16>);
#endif
......@@ -56,8 +56,10 @@ XPUOpMap& get_kl2_ops() {
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})},
{"clip", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"concat_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"concat", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"conv2d", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
......@@ -288,8 +290,10 @@ XPUOpMap& get_kl2_ops() {
{"reduce_sum_grad",
XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"reduce_sum", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"relu_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"relu", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace())})},
{"reshape2_grad",
XPUKernelSet({pOpKernelType(vartype::FP64, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册