Unverified commit b982af4a authored by houj04, committed by GitHub

[XPU] add fused_softmax_mask and fused_softmax_mask_grad. (#55914)

Parent 4315bc4c
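For reference, fused_softmax_mask computes a numerically stable softmax over the last axis of x + mask, with the mask broadcast across the head dimension. Below is a minimal NumPy sketch of the intended semantics (shapes follow the unit test further down; this illustrates the math only, not the XPU kernel itself):

    import numpy as np

    def fused_softmax_mask_reference(x, mask):
        # x:    [batch, num_heads, query_seq_len, key_seq_len]
        # mask: [batch, 1,         query_seq_len, key_seq_len], broadcast over heads
        masked = x + mask
        masked = masked - np.max(masked, axis=-1, keepdims=True)  # numerical stability
        exp = np.exp(masked)
        return exp / np.sum(exp, axis=-1, keepdims=True)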
@@ -724,8 +724,10 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64})},
{"softmax",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fused_softmax_mask", XPUKernelSet({phi::DataType::FLOAT32})},
{"softmax_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"fused_softmax_mask_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"softmax_with_cross_entropy_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"softmax_with_cross_entropy",
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/kernels/softmax_grad_kernel.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedSoftmaxMaskGradKernel(const Context& dev_ctx,
const DenseTensor& out,
const DenseTensor& out_grad,
DenseTensor* x_grad) {
dev_ctx.template Alloc<T>(x_grad);
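  // Since out = softmax(x + mask) and the additive mask does not depend on x,
  // d(loss)/dx equals the plain softmax gradient computed from `out` and
  // `out_grad` along the last axis (axis 3 of the 4-D input).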
SoftmaxGradKernel<T, Context>(
dev_ctx, out, out_grad, 3, x_grad); // axis for softmax
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_softmax_mask_grad,
XPU,
ALL_LAYOUT,
phi::fusion::FusedSoftmaxMaskGradKernel,
float) {}
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/kernel_registry.h"
namespace phi {
namespace fusion {
template <typename T, typename Context>
void FusedSoftmaxMaskKernel(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& mask,
DenseTensor* out) {
auto* x_data = x.data<T>();
auto* mask_data = mask.data<T>();
auto* y_data = dev_ctx.template Alloc<T>(out);
auto x_dim = x.dims();
auto mask_dim = mask.dims();
auto query_seq_len = x_dim[2];
auto key_seq_len = x_dim[3];
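  // Expected layout (enforced below): x is [batch, num_heads, query_seq_len,
  // key_seq_len] and mask is [batch, 1, query_seq_len, key_seq_len]; the mask
  // is broadcast across the head dimension.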
PADDLE_ENFORCE_GT(query_seq_len,
1,
phi::errors::InvalidArgument(
"Input x's second last dim must be large than 1 but "
"received the second last dimension of x is %d",
query_seq_len));
PADDLE_ENFORCE_EQ(key_seq_len >= 32 && key_seq_len < 8192,
true,
phi::errors::InvalidArgument(
"Input x's last dim must be between [32, 8192) "
"received the last dimension of x is %d",
key_seq_len));
PADDLE_ENFORCE_EQ(mask_dim[1],
1,
phi::errors::InvalidArgument(
"Input mask's second dim must be 1 "
"received the second dimension of mask is %d",
mask_dim[1]));
  // All dims of x and mask must match, except dim 1 (the head dimension).
for (size_t idx = 0; idx < 4; ++idx) {
if (idx == 1) continue;
PADDLE_ENFORCE_EQ(
x_dim[idx],
mask_dim[idx],
phi::errors::InvalidArgument(
"Input x's %dth dim should be equal with input mask's %dth dim "
"but "
"received the %dth dimension of x and mask are not equal "
"the %dth dim of x is %d, while the %dth dim of mask is %d.",
idx,
idx,
idx,
idx,
x_dim[idx],
idx,
mask_dim[idx]));
}
std::vector<int64_t> x_shape = phi::vectorize<int64_t>(x.dims());
std::vector<int64_t> mask_shape = phi::vectorize<int64_t>(mask.dims());
// int softmax_with_mask(Context* ctx, const T* x, const T* mask, T* y, const
// std::vector<int64_t>& x_shape, const std::vector<int64_t>& mask_shape);
int r = xpu::softmax_with_mask(
dev_ctx.x_context(), x_data, mask_data, y_data, x_shape, mask_shape);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "softmax_with_mask");
}
} // namespace fusion
} // namespace phi
PD_REGISTER_KERNEL(fused_softmax_mask,
XPU,
ALL_LAYOUT,
phi::fusion::FusedSoftmaxMaskKernel,
float) {}
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import numpy as np
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test_xpu import XPUOpTest
import paddle
paddle.enable_static()
def _get_softmax(x, mask):
masked_x = (x + mask).astype("float32")
max_value = np.max(masked_x, axis=-1, keepdims=True)
before_exp = masked_x - max_value
exp = np.exp(before_exp)
exp_sum = np.sum(exp, axis=-1, keepdims=True)
rst = exp / exp_sum
return rst
class XPUTestFusedSoftmaxMaskOp(XPUOpTestWrapper):
"""Test sigmoid_cross_entropy_with_logit_op with binary label"""
def __init__(self):
self.op_name = "fused_softmax_mask"
self.use_dynamic_create_class = False
class TestFusedSoftmaxMaskOp(XPUOpTest):
def setUp(self):
self.set_xpu()
self.op_type = "fused_softmax_mask"
self.init_dtype()
self.set_output()
def set_output(self):
x = np.random.random((1, 4, 4096, 4096)).astype("float32")
mask_input = np.random.random((1, 1, 4096, 4096)).astype("float32")
self.inputs = {'X': x, 'Mask': mask_input}
rst = _get_softmax(x, mask_input)
self.outputs = {'Out': rst}
def test_check_output(self):
self.check_output_with_place(self.place)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X'], 'Out')
def set_xpu(self):
self.__class__.use_xpu = True
self.place = paddle.XPUPlace(0)
def init_dtype(self):
self.dtype = self.in_type
support_types = get_xpu_op_support_types('fused_softmax_mask')
for stype in support_types:
create_test_class(globals(), XPUTestFusedSoftmaxMaskOp, stype)
if __name__ == '__main__':
unittest.main()
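For completeness, a minimal usage sketch from the Python side. It assumes paddle.incubate.softmax_mask_fuse dispatches to the fused_softmax_mask op registered in this change (as it does for the existing GPU kernel) and that Paddle was built with XPU support; shapes mirror the unit test above.

    import numpy as np
    import paddle

    paddle.set_device('xpu')  # assumes an XPU build of Paddle
    x = paddle.to_tensor(np.random.random((1, 4, 4096, 4096)).astype('float32'))
    mask = paddle.to_tensor(np.random.random((1, 1, 4096, 4096)).astype('float32'))
    # softmax(x + mask) over the last axis, mask broadcast across the head dim
    out = paddle.incubate.softmax_mask_fuse(x, mask)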