Unverified commit d330cf66, authored by Jack Zhou, committed by GitHub

Fix xpu enforce (#27978)

* test=kunlun;

Add elementwise XPU OP kernels for the KUNLUN core (common broadcast is still not supported), including:

    * elementwise_div op
    * elementwise_max op
    * elementwise_mul op (with grad op)
    * elementwise_sub op (with grad op)

* Tighten max_relative_error in the gradient checks from 0.05 to 0.01

* Add XPU error message descriptions (see the sketch below); test=kunlun
Parent: 7cb4a8b8
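In essence, the commit replaces the untyped platform::errors::Fatal checks with typed error categories and routes raw XPU status codes through a new get_xpu_error_message helper. Below is a minimal standalone sketch of that reporting pattern; the enum values and the CheckXpu wrapper are simplified stand-ins for illustration, not the real Paddle or XPU runtime API:

#include <iostream>
#include <stdexcept>
#include <string>
#include <unordered_map>

// Simplified stand-ins for the XPU runtime status codes.
enum XpuStatus { SUCCESS = 0, INVALID_PARAM, RUNTIME_ERROR, NO_ENOUGH_WORKSPACE };

// Mirrors get_xpu_error_message: map a raw status code to a description.
inline std::string GetXpuErrorMessage(int error_type) {
  static const std::unordered_map<int, std::string> error_map = {
      {INVALID_PARAM, "Parameter is invalid."},
      {RUNTIME_ERROR, "Please check whether Baidu Kunlun Card is properly installed."},
      {NO_ENOUGH_WORKSPACE, "There is not enough memory in Baidu Kunlun Card."}};
  auto it = error_map.find(error_type);
  return it == error_map.end() ? "Unknown error type!" : it->second;
}

// Stand-in for PADDLE_ENFORCE_EQ(res, SUCCESS, platform::errors::External(...)).
void CheckXpu(int res) {
  if (res != SUCCESS) {
    throw std::runtime_error("XPU kernel error occurred! " + GetXpuErrorMessage(res));
  }
}

int main() {
  try {
    CheckXpu(NO_ENOUGH_WORKSPACE);  // pretend an XPU call failed
  } catch (const std::exception& e) {
    std::cout << e.what() << "\n";  // prints the mapped description
  }
}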
@@ -49,7 +49,8 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
     int axis = ctx.Attr<int>("axis");
     PADDLE_ENFORCE_GE(dx_dims.size(), dy_dims_untrimed.size(),
-                      "Rank of first input must >= rank of second input.");
+                      platform::errors::InvalidArgument(
+                          "Rank of first input must >= rank of second input."));
     if (dx != nullptr) {
       dx->mutable_data<T>(ctx.GetPlace());
@@ -69,8 +70,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
       n = dout->numel();
     } else {
       axis = (axis == -1 ? dx_dims.size() - dy_dims_untrimed.size() : axis);
-      PADDLE_ENFORCE(axis >= 0 && axis < dx_dims.size(),
-                     "Axis should be in range [0, dx_dims)");
+      PADDLE_ENFORCE_EQ(axis >= 0 && axis < dx_dims.size(), true,
+                        platform::errors::InvalidArgument(
+                            "Axis should be in range [0, dx_dims)"));
       auto dy_dims = trim_trailing_singular_dims(dy_dims_untrimed);
       axis = (dy_dims.size() == 0) ? dx_dims.size() : axis;
       get_mid_dims(dx_dims, dy_dims, axis, &pre, &n, &post,
...
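The same mechanical rewrite recurs across the other elementwise kernels: every bare-string PADDLE_ENFORCE(cond, "msg") becomes PADDLE_ENFORCE_EQ(cond, true, ...) carrying a typed error category (platform::errors::InvalidArgument for bad inputs, PreconditionNotMet for running on the wrong device, External for failures reported by the XPU runtime itself). The before/after shape, using the axis check above as the example:

// Before: bare condition with an untyped message (deprecated style).
PADDLE_ENFORCE(axis >= 0 && axis < dx_dims.size(),
               "Axis should be in range [0, dx_dims)");

// After: explicit comparison macro plus a typed error category.
PADDLE_ENFORCE_EQ(axis >= 0 && axis < dx_dims.size(), true,
                  platform::errors::InvalidArgument(
                      "Axis should be in range [0, dx_dims)"));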
@@ -14,9 +14,25 @@ limitations under the License. */
 #pragma once
 #ifdef PADDLE_WITH_XPU
 #include <string>
+#include <unordered_map>
 #include "paddle/fluid/framework/tensor.h"
 #include "paddle/fluid/platform/place.h"
+
+inline std::string get_xpu_error_message(int error_type) {
+  static std::unordered_map<int, std::string> xpu_error_map = {
+      {baidu::xpu::api::INVALID_PARAM, "Parameter is invalid."},
+      {baidu::xpu::api::RUNTIME_ERROR,
+       "Please check whether Baidu Kunlun Card "
+       "is properly installed."},
+      {baidu::xpu::api::NO_ENOUGH_WORKSPACE,
+       "There is not enough memory in Baidu"
+       " Kunlun Card."}};
+  if (xpu_error_map.find(error_type) == xpu_error_map.end()) {
+    return "Unknown error type!";
+  }
+  return xpu_error_map[error_type];
+}
+
 #define XPU_MALLOC(addr, num_bytes)                                          \
   PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(addr), num_bytes),   \
                     XPU_SUCCESS,                                             \
@@ -102,21 +118,27 @@ limitations under the License. */
       int res =                                                            \
           xpu::broadcast_ew(dev_ctx.x_context(), y_data, y_broadcast, pre, \
                             n, post, xpu::ElementwiseOp::ASSIGN);          \
-      PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,                        \
-                        platform::errors::Fatal("XPU kernel error!"));     \
+      PADDLE_ENFORCE_EQ(                                                   \
+          res, xpu::Error_t::SUCCESS,                                      \
+          platform::errors::External("XPU kernel error occurred! %s",      \
+                                     get_xpu_error_message(res)));         \
       y_data = y_broadcast;                                                \
     }                                                                      \
   }                                                                        \
   int res = xpu::elementwise_##kernel_name##_grad(                         \
       dev_ctx.x_context(), x_data, y_data, dout->data<T>() /*out*/,        \
       dout->data<T>(), dx_data, dy_data, len);                             \
-  PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,                            \
-                    platform::errors::Fatal("XPU kernel error!"));         \
+  PADDLE_ENFORCE_EQ(                                                       \
+      res, xpu::Error_t::SUCCESS,                                          \
+      platform::errors::External("XPU kernel error occurred! %s",          \
+                                 get_xpu_error_message(res)));             \
   if ((dy != nullptr) && (len != n)) {                                     \
     int res = xpu::reduce_ew(dev_ctx.x_context(), dy_data, dy->data<T>(),  \
                              pre, n, post, xpu::ElementwiseOp::ASSIGN);    \
-    PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,                          \
-                      platform::errors::Fatal("XPU kernel error!"));       \
+    PADDLE_ENFORCE_EQ(                                                     \
+        res, xpu::Error_t::SUCCESS,                                        \
+        platform::errors::External("XPU kernel error occurred! %s",        \
+                                   get_xpu_error_message(res)));           \
    dev_ctx.Wait();                                                        \
    xpu_free(dy_data);                                                     \
  }                                                                        \
@@ -161,8 +183,8 @@ void XPUElementwise(const framework::ExecutionContext& ctx) {
                     platform::errors::PreconditionNotMet(
                         "This kernel only runs on XPU device."));
   auto x_var = ctx.InputVar("X");
-  PADDLE_ENFORCE_NE(x_var, nullptr,
-                    platform::errors::Fatal("Cannot get input Variable X"));
+  PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument(
+                                        "Cannot get input Variable X"));
   PADDLE_ENFORCE_EQ(
       x_var->IsType<framework::LoDTensor>(), true,
       platform::errors::InvalidArgument(
@@ -206,36 +228,36 @@ void XPUElementwise(const framework::ExecutionContext& ctx) {
     if (std::is_same<Functor, XPUAddFunctor<T>>::value) {
       int res = xpu::matrix_vector_add(dev_ctx.x_context(), x_data, y_data,
                                        z_data, pre, n);
-      PADDLE_ENFORCE_EQ(
-          res, xpu::Error_t::SUCCESS,
-          platform::errors::Fatal("XPU kernel error! res = %d", res));
+      PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,
+                        platform::errors::External("XPU kernel error occurred! %s",
+                                                   get_xpu_error_message(res)));
       return;
     }
     if (std::is_same<Functor, XPUMulFunctor<T>>::value) {
       int res = xpu::matrix_vector_mul(dev_ctx.x_context(), x_data, y_data,
                                        z_data, pre, n);
-      PADDLE_ENFORCE_EQ(
-          res, xpu::Error_t::SUCCESS,
-          platform::errors::Fatal("XPU kernel error! res = %d", res));
+      PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,
+                        platform::errors::External("XPU kernel error occurred! %s",
+                                                   get_xpu_error_message(res)));
       return;
     }
   }
   if (pre != 1 || post != 1) {
-    PADDLE_ENFORCE(xpu_malloc(reinterpret_cast<void**>(&y_broadcast),
-                              len * sizeof(T)) == XPU_SUCCESS);
+    XPU_MALLOC(&y_broadcast, len * sizeof(T));
     int res = xpu::broadcast_ew(dev_ctx.x_context(), y_data, y_broadcast, pre,
                                 n, post, xpu::ElementwiseOp::ASSIGN);
-    PADDLE_ENFORCE_EQ(
-        res, xpu::Error_t::SUCCESS,
-        platform::errors::Fatal("XPU kernel error! res = %d", res));
+    PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,
+                      platform::errors::External("XPU kernel error occurred! %s",
+                                                 get_xpu_error_message(res)));
     y_data = y_broadcast;
   }
   Functor functor;
   int res = functor(dev_ctx.x_context(), x_data, y_data, z_data, len);
   PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,
-                    platform::errors::Fatal("XPU kernel error! res = %d", res));
+                    platform::errors::External("XPU kernel error occurred! %s",
+                                               get_xpu_error_message(res)));
   if (pre != 1 || post != 1) {
     dev_ctx.Wait();
...
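One smaller cleanup above: the raw PADDLE_ENFORCE(xpu_malloc(...) == XPU_SUCCESS) call is replaced by the XPU_MALLOC convenience macro declared earlier in the header, so device allocations fail through the same enforce machinery as kernel launches. Its usage shape is sketched below; the macro's full error message is elided in the hunk above:

T* y_broadcast = nullptr;
// Allocates len * sizeof(T) bytes on the XPU device; raises a typed
// Paddle enforce error if xpu_malloc does not return XPU_SUCCESS.
XPU_MALLOC(&y_broadcast, len * sizeof(T));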
@@ -54,7 +54,7 @@ class TestXPUElementwiseOpBase(object):
                        inputs_to_check,
                        output_names,
                        no_grad_set=None,
-                       max_relative_error=0.05):
+                       max_relative_error=0.01):
            if self.grad_implemented and not self.is_common_broadcast \
               and not self.is_x_size_less_than_y:
                if paddle.is_compiled_with_xpu():
...