未验证 提交 8e1c3ddf 编写于 作者: Q QingshuChen 提交者: GitHub

add aarch64 and sunway kunlun lib (#30027)

* add aarch64 and sunway kunlun lib

* minor

* optimize elementwise_add for kunlun

* update kunlun dependence

* minor

* minor
上级 05b27695
...@@ -4,7 +4,15 @@ endif() ...@@ -4,7 +4,15 @@ endif()
INCLUDE(ExternalProject) INCLUDE(ExternalProject)
SET(XPU_PROJECT "extern_xpu") SET(XPU_PROJECT "extern_xpu")
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2020_12_15.tar.gz" CACHE STRING "" FORCE)
if (WITH_AARCH64)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/aarch64/xpu_2020_1229.tar.gz" CACHE STRING "" FORCE)
elseif(WITH_SUNWAY)
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2020_1227.tar.gz" CACHE STRING "" FORCE)
else()
SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_0105.tar.gz" CACHE STRING "" FORCE)
endif()
SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}")
SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu")
......
...@@ -26,17 +26,140 @@ namespace operators { ...@@ -26,17 +26,140 @@ namespace operators {
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddXPUKernel : public framework::OpKernel<T> { class ElementwiseAddXPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
XPUElementwise<T>(ctx, xpu::add<T>); // XPUElementwise<T>(ctx, xpu::add<T>);
// ToDo(QingshuChen): update this optimization to elementwise_xpu.h
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument(
"Cannot get input Variable X"));
PADDLE_ENFORCE_EQ(
x_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
"XPU only support LoDTensor, Input(X) is not LoDTensor"));
auto x = x_var->Get<framework::LoDTensor>();
auto* y = ctx.Input<framework::LoDTensor>("Y");
auto* z = ctx.Output<framework::LoDTensor>("Out");
z->mutable_data<T>(ctx.GetPlace());
auto x_dims = x.dims();
auto y_dims = y->dims();
int max_dim = std::max(x_dims.size(), y_dims.size());
int axis = ctx.Attr<int>("axis");
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(
axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
const T* x_data = x.data<T>();
const T* y_data = y->data<T>();
T* z_data = z->data<T>();
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::SUCCESS;
ret = xpu::broadcast_add<T>(dev_ctx.x_context(), x_data, y_data, z_data,
x_dims_vec, y_dims_vec);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
public: public:
void Compute(const framework::ExecutionContext &ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
ElemwiseGradKernel<T>::Compute(ctx); ElemwiseGradKernel<T>::Compute(ctx);
XPUElementwiseGrad<T>(ctx, xpu::add_grad<T>, false); // XPUElementwiseGrad<T>(ctx, xpu::add_grad<T>, false);
auto* x = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y");
auto* dz = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
int axis = ctx.Attr<int>("axis");
const framework::DDim& x_dims = x->dims();
const framework::DDim& y_dims = y->dims();
int max_dim = std::max(x_dims.size(), y_dims.size());
axis = (axis == -1 ? std::abs(x_dims.size() - y_dims.size()) : axis);
PADDLE_ENFORCE_GE(
axis, 0,
platform::errors::InvalidArgument(
"Axis should be great than or equal to 0, but received axis is %d.",
axis));
PADDLE_ENFORCE_LT(
axis, max_dim,
platform::errors::InvalidArgument(
"Axis should be less than %d, but received axis is %d.", max_dim,
axis));
std::vector<int> x_dims_vec(max_dim, 1);
std::vector<int> y_dims_vec(max_dim, 1);
if (x_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
x_dims_vec[i] = x_dims[i];
}
} else {
for (int i = 0; i < x_dims.size(); i++) {
x_dims_vec[i + axis] = x_dims[i];
}
}
if (y_dims.size() == max_dim) {
for (int i = 0; i < max_dim; i++) {
y_dims_vec[i] = y_dims[i];
}
} else {
for (int i = 0; i < y_dims.size(); i++) {
y_dims_vec[i + axis] = y_dims[i];
}
}
T* dx_data = nullptr;
T* dy_data = nullptr;
if (dx) {
dx_data = dx->mutable_data<T>(ctx.GetPlace());
}
if (dy) {
dy_data = dy->mutable_data<T>(ctx.GetPlace());
}
auto& dev_ctx =
ctx.template device_context<paddle::platform::XPUDeviceContext>();
int ret = xpu::broadcast_add_grad<T>(dev_ctx.x_context(), dx_data, dx_data,
dx_data, dz->data<T>(), dy_data,
dx_data, x_dims_vec, y_dims_vec);
PADDLE_ENFORCE_EQ(
ret, xpu::SUCCESS,
platform::errors::External(
"XPU kernel Elementwise occur error in XPUElementwise error code ",
ret, XPUAPIErrorMsg[ret]));
} }
}; };
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册