未验证 提交 ff818c77 编写于 作者: T TTerror 提交者: GitHub

add fp16 for masked_select on kunlun, *test=kunlun (#41215)

上级 482e5b6c
......@@ -19,13 +19,15 @@ namespace operators {
template <typename T>
class MaskedSelectXPUKernel : public framework::OpKernel<T> {
using XPUType = typename XPUTypeTrait<T>::Type;
public:
void Compute(const framework::ExecutionContext& context) const override {
auto input = context.Input<framework::Tensor>("X");
auto mask = context.Input<framework::Tensor>("Mask");
auto out = context.Output<framework::Tensor>("Y");
auto* mask_data = mask->data<bool>();
auto* input_data = input->data<T>();
auto* input_data = reinterpret_cast<const XPUType*>(input->data<T>());
auto input_dim = input->dims();
auto mask_dim = mask->dims();
PADDLE_ENFORCE_EQ(
......@@ -51,7 +53,8 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> {
framework::DDim out_dim{out_size_cpu};
out->Resize(out_dim);
auto out_data = out->mutable_data<T>(context.GetPlace());
auto out_data =
reinterpret_cast<XPUType*>(out->mutable_data<T>(context.GetPlace()));
auto input_shape = phi::vectorize<int>(input_dim);
auto mask_shape = phi::vectorize<int>(mask_dim);
......@@ -69,6 +72,7 @@ class MaskedSelectXPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_XPU_KERNEL(masked_select, ops::MaskedSelectXPUKernel<float>,
ops::MaskedSelectXPUKernel<paddle::platform::float16>,
ops::MaskedSelectXPUKernel<int>,
ops::MaskedSelectXPUKernel<int64_t>);
#endif
......@@ -243,6 +243,7 @@ XPUOpMap& get_kl2_ops() {
{"masked_select",
XPUKernelSet({pOpKernelType(vartype::INT32, XPUPlace()),
pOpKernelType(vartype::INT64, XPUPlace()),
pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_grad", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"matmul_v2_grad",
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册