未验证 提交 d330cf66 编写于 作者: J Jack Zhou 提交者: GitHub

Fix xpu enforce (#27978)

* test=kunlun;

Add elementwise XPU OP kernels for the KUNLUN core, including the following (common broadcast is still not supported):

    * elementwise_div op
    * elementwise_max op
    * elementwise_mul op (with grad op)
    * elementwise_sub op (with grad op)

* 0.05->0.01

* add xpu error message description;test=kunlun
上级 7cb4a8b8
......@@ -49,7 +49,8 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
int axis = ctx.Attr<int>("axis");
PADDLE_ENFORCE_GE(dx_dims.size(), dy_dims_untrimed.size(),
"Rank of first input must >= rank of second input.");
platform::errors::InvalidArgument(
"Rank of first input must >= rank of second input."));
if (dx != nullptr) {
dx->mutable_data<T>(ctx.GetPlace());
......@@ -69,8 +70,9 @@ class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> {
n = dout->numel();
} else {
axis = (axis == -1 ? dx_dims.size() - dy_dims_untrimed.size() : axis);
PADDLE_ENFORCE(axis >= 0 && axis < dx_dims.size(),
"Axis should be in range [0, dx_dims)");
PADDLE_ENFORCE_EQ(axis >= 0 && axis < dx_dims.size(), true,
platform::errors::InvalidArgument(
"Axis should be in range [0, dx_dims)"));
auto dy_dims = trim_trailing_singular_dims(dy_dims_untrimed);
axis = (dy_dims.size() == 0) ? dx_dims.size() : axis;
get_mid_dims(dx_dims, dy_dims, axis, &pre, &n, &post,
......
......@@ -14,9 +14,25 @@ limitations under the License. */
#pragma once
#ifdef PADDLE_WITH_XPU
#include <string>
#include <unordered_map>
#include "paddle/fluid/framework/tensor.h"
#include "paddle/fluid/platform/place.h"
inline std::string get_xpu_error_message(int error_type) {
static std::unordered_map<int, std::string> xpu_error_map = {
{baidu::xpu::api::INVALID_PARAM, "Parameter is invalid."},
{baidu::xpu::api::RUNTIME_ERROR,
"Please check whether Baidu Kunlun Card "
"is properly installed."},
{baidu::xpu::api::NO_ENOUGH_WORKSPACE,
"There is not enough memory in Baidu"
" Kunlun Card."}};
if (xpu_error_map.find(error_type) == xpu_error_map.end()) {
return "Unknown error type!";
}
return xpu_error_map[error_type];
}
#define XPU_MALLOC(addr, num_bytes) \
PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void**>(addr), num_bytes), \
XPU_SUCCESS, \
......@@ -102,21 +118,27 @@ limitations under the License. */
int res = \
xpu::broadcast_ew(dev_ctx.x_context(), y_data, y_broadcast, pre, \
n, post, xpu::ElementwiseOp::ASSIGN); \
PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS, \
platform::errors::Fatal("XPU kernel error!")); \
PADDLE_ENFORCE_EQ( \
res, xpu::Error_t::SUCCESS, \
platform::errors::External("XPU kernel error occur! %s", \
get_xpu_error_message(res))); \
y_data = y_broadcast; \
} \
} \
int res = xpu::elementwise_##kernel_name##_grad( \
dev_ctx.x_context(), x_data, y_data, dout->data<T>() /*out*/, \
dout->data<T>(), dx_data, dy_data, len); \
PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS, \
platform::errors::Fatal("XPU kernel error!")); \
PADDLE_ENFORCE_EQ( \
res, xpu::Error_t::SUCCESS, \
platform::errors::External("XPU kernel error occur! %s", \
get_xpu_error_message(res))); \
if ((dy != nullptr) && (len != n)) { \
int res = xpu::reduce_ew(dev_ctx.x_context(), dy_data, dy->data<T>(), \
pre, n, post, xpu::ElementwiseOp::ASSIGN); \
PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS, \
platform::errors::Fatal("XPU kernel error!")); \
PADDLE_ENFORCE_EQ( \
res, xpu::Error_t::SUCCESS, \
platform::errors::External("XPU kernel error occur! %s", \
get_xpu_error_message(res))); \
dev_ctx.Wait(); \
xpu_free(dy_data); \
} \
......@@ -161,8 +183,8 @@ void XPUElementwise(const framework::ExecutionContext& ctx) {
platform::errors::PreconditionNotMet(
"This kernel only runs on XPU device."));
auto x_var = ctx.InputVar("X");
PADDLE_ENFORCE_NE(x_var, nullptr,
platform::errors::Fatal("Cannot get input Variable X"));
PADDLE_ENFORCE_NE(x_var, nullptr, platform::errors::InvalidArgument(
"Cannot get input Variable X"));
PADDLE_ENFORCE_EQ(
x_var->IsType<framework::LoDTensor>(), true,
platform::errors::InvalidArgument(
......@@ -206,36 +228,36 @@ void XPUElementwise(const framework::ExecutionContext& ctx) {
if (std::is_same<Functor, XPUAddFunctor<T>>::value) {
int res = xpu::matrix_vector_add(dev_ctx.x_context(), x_data, y_data,
z_data, pre, n);
PADDLE_ENFORCE_EQ(
res, xpu::Error_t::SUCCESS,
platform::errors::Fatal("XPU kernel error! res = %d", res));
PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error occur! %s",
get_xpu_error_message(res)));
return;
}
if (std::is_same<Functor, XPUMulFunctor<T>>::value) {
int res = xpu::matrix_vector_mul(dev_ctx.x_context(), x_data, y_data,
z_data, pre, n);
PADDLE_ENFORCE_EQ(
res, xpu::Error_t::SUCCESS,
platform::errors::Fatal("XPU kernel error! res = %d", res));
PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error occur! %s",
get_xpu_error_message(res)));
return;
}
}
if (pre != 1 || post != 1) {
PADDLE_ENFORCE(xpu_malloc(reinterpret_cast<void**>(&y_broadcast),
len * sizeof(T)) == XPU_SUCCESS);
XPU_MALLOC(&y_broadcast, len * sizeof(T));
int res = xpu::broadcast_ew(dev_ctx.x_context(), y_data, y_broadcast, pre,
n, post, xpu::ElementwiseOp::ASSIGN);
PADDLE_ENFORCE_EQ(
res, xpu::Error_t::SUCCESS,
platform::errors::Fatal("XPU kernel error! res = %d", res));
PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,
platform::errors::External("XPU kernel error occur! %s",
get_xpu_error_message(res)));
y_data = y_broadcast;
}
Functor functor;
int res = functor(dev_ctx.x_context(), x_data, y_data, z_data, len);
PADDLE_ENFORCE_EQ(res, xpu::Error_t::SUCCESS,
platform::errors::Fatal("XPU kernel error! res = %d", res));
platform::errors::External("XPU kernel error occur! %s",
get_xpu_error_message(res)));
if (pre != 1 || post != 1) {
dev_ctx.Wait();
......
......@@ -54,7 +54,7 @@ class TestXPUElementwiseOpBase(object):
inputs_to_check,
output_names,
no_grad_set=None,
max_relative_error=0.05):
max_relative_error=0.01):
if self.grad_implemented and not self.is_common_broadcast \
and not self.is_x_size_less_than_y:
if paddle.is_compiled_with_xpu():
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册