未验证 提交 ff818c77 编写于 作者: TTerror 提交者: GitHub

add fp16 for masked_select on kunlun, *test=kunlun (#41215)

上级 482e5b6c
...@@ -19,13 +19,15 @@ namespace operators { ...@@ -19,13 +19,15 @@ namespace operators {
template <typename T> template <typename T>
class MaskedSelectXPUKernel : public framework::OpKernel<T> { class MaskedSelectXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<framework::Tensor>("X"); auto input = context.Input<framework::Tensor>("X");
auto mask = context.Input<framework::Tensor>("Mask"); auto mask = context.Input<framework::Tensor>("Mask");
auto out = context.Output<framework::Tensor>("Y"); auto out = context.Output<framework::Tensor>("Y");
auto* mask_data = mask->data<bool>(); auto* mask_data = mask->data<bool>();
auto* input_data = input->data<T>(); auto* input_data = reinterpret_cast<const XPUType*>(input->data<T>());
auto input_dim = input->dims(); auto input_dim = input->dims();
auto mask_dim = mask->dims(); auto mask_dim = mask->dims();
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -51,7 +53,8 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> { ...@@ -51,7 +53,8 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> {
framework::DDim out_dim{out_size_cpu}; framework::DDim out_dim{out_size_cpu};
out->Resize(out_dim); out->Resize(out_dim);
auto out_data = out->mutable_data<T>(context.GetPlace()); auto out_data =
reinterpret_cast<XPUType*>(out->mutable_data<T>(context.GetPlace()));
auto input_shape = phi::vectorize<int>(input_dim); auto input_shape = phi::vectorize<int>(input_dim);
auto mask_shape = phi::vectorize<int>(mask_dim); auto mask_shape = phi::vectorize<int>(mask_dim);
...@@ -69,6 +72,7 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> { ...@@ -69,6 +72,7 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators; namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(masked_select, ops::MaskedSelectXPUKernel<float>, REGISTER_OP_XPU_KERNEL(masked_select, ops::MaskedSelectXPUKernel<float>,
ops::MaskedSelectXPUKernel<paddle::platform::float16>,
ops::MaskedSelectXPUKernel<int>, ops::MaskedSelectXPUKernel<int>,
ops::MaskedSelectXPUKernel<int64_t>); ops::MaskedSelectXPUKernel<int64_t>);
#endif #endif
...@@ -243,6 +243,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -243,6 +243,7 @@ XPUOpMap& get_kl2_ops() {
{"masked_select", {"masked_select",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()), pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})}, pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})}, {"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_v2_grad", {"matmul_v2_grad",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册