未验证 提交 ed2ad5d9 编写于 作者: H houj04 提交者: GitHub

[XPU] add c_embedding_op_xpu. (#45617)

上级 4ed6f3bc
...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so") ...@@ -10,7 +10,7 @@ set(XPU_RT_LIB_NAME "libxpurt.so")
if(NOT DEFINED XPU_BASE_URL) if(NOT DEFINED XPU_BASE_URL)
set(XPU_BASE_URL_WITHOUT_DATE set(XPU_BASE_URL_WITHOUT_DATE
"https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev") "https://baidu-kunlun-product.cdn.bcebos.com/KL-SDK/klsdk-dev")
set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220820") set(XPU_BASE_URL "${XPU_BASE_URL_WITHOUT_DATE}/20220831")
else() else()
set(XPU_BASE_URL "${XPU_BASE_URL}") set(XPU_BASE_URL "${XPU_BASE_URL}")
endif() endif()
...@@ -19,7 +19,7 @@ endif() ...@@ -19,7 +19,7 @@ endif()
if(NOT DEFINED XPU_XDNN_BASE_URL) if(NOT DEFINED XPU_XDNN_BASE_URL)
set(XPU_XDNN_BASE_URL_WITHOUT_DATE set(XPU_XDNN_BASE_URL_WITHOUT_DATE
"https://klx-sdk-release-public.su.bcebos.com/xdnn/dev") "https://klx-sdk-release-public.su.bcebos.com/xdnn/dev")
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220820") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL_WITHOUT_DATE}/20220831")
else() else()
set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}") set(XPU_XDNN_BASE_URL "${XPU_XDNN_BASE_URL}")
endif() endif()
......
/* Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/collective/c_embedding_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
// XPU kernel for the c_embedding operator.
//
// Reads the embedding table from input "W", the lookup ids from input "Ids"
// and the "start_index" attribute, then fills output "Out" by delegating to
// xpu::embedding. Ids may be int32 or int64; any other id dtype raises an
// Unavailable error.
template <typename DeviceContext, typename T>
class CEmbeddingOpXPUKernel : public framework::OpKernel<T> {
 public:
  void Compute(const framework::ExecutionContext& ctx) const override {
    const auto* weight = ctx.Input<LoDTensor>("W");
    const auto* indices = ctx.Input<LoDTensor>("Ids");
    auto* out = ctx.Output<LoDTensor>("Out");

    const int64_t start_index = ctx.Attr<int64_t>("start_index");

    const T* weight_ptr = weight->data<T>();
    T* out_ptr = out->mutable_data<T>(ctx.GetPlace());

    // The table is indexed as [rows, cols] through dims()[0] and dims()[1].
    const int64_t rows = weight->dims()[0];
    const int64_t cols = weight->dims()[1];

    // XDNN entry point being called:
    //   int embedding(Context* ctx, const T* x, const TID* indices, T* y,
    //                 int xm, int n, int ym, int padding_idx,
    //                 TID start_index = 0);
    //   xm: table height, i.e. number of entries in the table.
    //   n:  embedding dim, i.e. number of values within a single entry.
    //   ym: number of elements of the input ids.
    // NOTE(review): rows, cols and numel() are int64_t while the signature
    // above takes int, so they are narrowed implicitly at the call; confirm
    // that tables/ids large enough to overflow int cannot reach this kernel.
    auto& dev_ctx = ctx.template device_context<DeviceContext>();
    const auto& ids_dtype = framework::TransToProtoVarType(indices->dtype());

    // -1 is what this kernel always passes as padding_idx.
    // NOTE(review): presumably this means "no padding index" -- confirm
    // against the XDNN documentation.
    const int padding_idx = -1;

    switch (ids_dtype) {
      case framework::proto::VarType::INT32: {
        const int32_t* ids_ptr = indices->data<int32_t>();
        // start_index is narrowed to the id type expected by this overload.
        const int32_t first_row = static_cast<int32_t>(start_index);
        int r = xpu::embedding(dev_ctx.x_context(),
                               weight_ptr,
                               ids_ptr,
                               out_ptr,
                               rows,
                               cols,
                               indices->numel(),
                               padding_idx,
                               first_row);
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
        break;
      }
      case framework::proto::VarType::INT64: {
        const int64_t* ids_ptr = indices->data<int64_t>();
        const int64_t first_row = static_cast<int64_t>(start_index);
        int r = xpu::embedding(dev_ctx.x_context(),
                               weight_ptr,
                               ids_ptr,
                               out_ptr,
                               rows,
                               cols,
                               indices->numel(),
                               padding_idx,
                               first_row);
        PADDLE_ENFORCE_XDNN_SUCCESS(r, "embedding");
        break;
      }
      default:
        PADDLE_THROW(platform::errors::Unavailable(
            "XPU c_embedding ids only support int32 or int64."));
    }
  }
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;

// Register the XPU kernel of the c_embedding op, instantiated for float on
// the XPU device context.
REGISTER_OP_XPU_KERNEL(
    c_embedding,
    ops::CEmbeddingOpXPUKernel<plat::XPUDeviceContext, float>);
...@@ -84,6 +84,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -84,6 +84,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace()),
pOpKernelType(vartype::INT32, XPUPlace())})}, pOpKernelType(vartype::INT32, XPUPlace())})},
{"c_embedding", XPUKernelSet({pOpKernelType(vartype::FP32, XPUPlace())})},
{"c_identity", {"c_identity",
XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()), XPUKernelSet({pOpKernelType(vartype::FP16, XPUPlace()),
pOpKernelType(vartype::FP32, XPUPlace()), pOpKernelType(vartype::FP32, XPUPlace()),
......
...@@ -42,6 +42,8 @@ class TestCEmbeddingCPU(OpTest): ...@@ -42,6 +42,8 @@ class TestCEmbeddingCPU(OpTest):
self.initcase() self.initcase()
if core.is_compiled_with_npu(): if core.is_compiled_with_npu():
self.__class__.use_npu = True self.__class__.use_npu = True
elif core.is_compiled_with_xpu():
self.__class__.use_xpu = True
elif core.is_compiled_with_cuda(): elif core.is_compiled_with_cuda():
self.__class__.exist_fp64_check_grad = True self.__class__.exist_fp64_check_grad = True
...@@ -59,6 +61,8 @@ class TestCEmbeddingCPU(OpTest): ...@@ -59,6 +61,8 @@ class TestCEmbeddingCPU(OpTest):
self.attrs = {'start_index': self.start_index} self.attrs = {'start_index': self.start_index}
if core.is_compiled_with_npu(): if core.is_compiled_with_npu():
self.__class__.use_npu = True self.__class__.use_npu = True
elif core.is_compiled_with_xpu():
self.__class__.use_xpu = True
def test_check_cpu(self): def test_check_cpu(self):
self.check_output_with_place(core.CPUPlace()) self.check_output_with_place(core.CPUPlace())
...@@ -82,12 +86,16 @@ class TestCEmbeddingOpBase(TestCEmbeddingCPU): ...@@ -82,12 +86,16 @@ class TestCEmbeddingOpBase(TestCEmbeddingCPU):
self.check_output_with_place(core.CUDAPlace(0)) self.check_output_with_place(core.CUDAPlace(0))
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
self.check_output_with_place(core.NPUPlace(0)) self.check_output_with_place(core.NPUPlace(0))
elif core.is_compiled_with_xpu():
self.check_output_with_place(core.XPUPlace(0))
def test_check_grad(self): def test_check_grad(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
self.check_grad_with_place(core.CUDAPlace(0), ['W'], 'Out') self.check_grad_with_place(core.CUDAPlace(0), ['W'], 'Out')
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
self.check_grad_with_place(core.NPUPlace(0), ['W'], 'Out') self.check_grad_with_place(core.NPUPlace(0), ['W'], 'Out')
elif core.is_compiled_with_xpu():
self.check_grad_with_place(core.XPUPlace(0), ['W'], 'Out')
def init_dtype(self): def init_dtype(self):
if core.is_compiled_with_cuda(): if core.is_compiled_with_cuda():
...@@ -96,6 +104,9 @@ class TestCEmbeddingOpBase(TestCEmbeddingCPU): ...@@ -96,6 +104,9 @@ class TestCEmbeddingOpBase(TestCEmbeddingCPU):
elif core.is_compiled_with_npu(): elif core.is_compiled_with_npu():
self.dtype = "float32" self.dtype = "float32"
self.ids_dtype = "int32" self.ids_dtype = "int32"
elif core.is_compiled_with_xpu():
self.dtype = "float32"
self.ids_dtype = "int64"
class TestCEmbeddingOpFP32(TestCEmbeddingOpBase): class TestCEmbeddingOpFP32(TestCEmbeddingOpBase):
...@@ -123,6 +134,8 @@ class TestCEmbeddingOpFP32(TestCEmbeddingOpBase): ...@@ -123,6 +134,8 @@ class TestCEmbeddingOpFP32(TestCEmbeddingOpBase):
if core.is_compiled_with_npu(): if core.is_compiled_with_npu():
self.__class__.use_npu = True self.__class__.use_npu = True
elif core.is_compiled_with_xpu():
self.__class__.use_xpu = True
elif core.is_compiled_with_cuda(): elif core.is_compiled_with_cuda():
self.__class__.exist_fp64_check_grad = True self.__class__.exist_fp64_check_grad = True
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
import paddle
from paddle.fluid.tests.unittests.c_embedding_op_base import TestCEmbeddingCPU, TestCEmbeddingOpBase, TestCEmbeddingOpFP32
paddle.enable_static()
# Instantiate each imported test case class once; the resulting objects are
# not kept.
# NOTE(review): presumably importing the classes into this module is already
# enough for unittest.main() to discover them -- confirm whether these calls
# are still needed.
TestCEmbeddingCPU()
TestCEmbeddingOpBase()
TestCEmbeddingOpFP32()
# Run the unittest runner when this file is executed as a script.
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册