Unverified    Commit 6a3c8725 authored by T taixiurong, committed by GitHub

support transformer v2.0 (#30381)

Parent e85be1b1
@@ -10,7 +10,7 @@ if (WITH_AARCH64)
 elseif(WITH_SUNWAY)
   SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/sunway/xpu_2020_1227.tar.gz" CACHE STRING "" FORCE)
 else()
-  SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_0105.tar.gz" CACHE STRING "" FORCE)
+  SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu_2021_01_13.tar.gz" CACHE STRING "" FORCE)
 endif()
 SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu")
...
@@ -45,15 +45,13 @@ class LayerNormXPUKernel : public framework::OpKernel<T> {
     auto* mean_data = mean->mutable_data<T>(ctx.GetPlace());
     auto* variance_data = variance->mutable_data<T>(ctx.GetPlace());
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::layer_norm(dev_ctx.x_context(), left, right, x_data, y_data,
-                            scale_data, bias_data, epsilon, mean_data,
-                            variance_data, false);
-    PADDLE_ENFORCE_EQ(
-        r, XPU_SUCCESS,
-        platform::errors::External("XPU API(layer_norm) return wrong "
-                                   "value[%d], please check whether Baidu "
-                                   "Kunlun Card is properly installed.",
-                                   r));
+    int r = xpu::layer_norm(dev_ctx.x_context(), x_data, y_data, left, right,
+                            epsilon, scale_data, bias_data, mean_data,
+                            variance_data);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU layer_norm kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
   }
 };
@@ -87,15 +85,14 @@ class LayerNormGradXPUKernel : public framework::OpKernel<T> {
     auto* dx_data =
         (dx == nullptr ? nullptr : dx->mutable_data<T>(ctx.GetPlace()));
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::layer_norm_backward(
-        dev_ctx.x_context(), left, right, x_data, scale_data, variance_data,
-        mean_data, dy_data, dx_data, dscale_data, dbias_data, epsilon);
+    int r = xpu::layer_norm_grad(dev_ctx.x_context(), x_data, dy_data, dx_data,
+                                 left, right, epsilon, scale_data, mean_data,
+                                 variance_data, dscale_data, dbias_data);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
-        platform::errors::External("XPU API(layer_norm_backward) return wrong "
-                                   "value[%d], please check whether Baidu "
-                                   "Kunlun Card is properly installed.",
-                                   r));
+        platform::errors::External(
+            "XPU layer_norm_grad kernel return wrong value[%d %s]", r,
+            XPUAPIErrorMsg[r]));
   }
 };
...
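For reference, the forward kernel normalizes each of the left rows of the flattened (left, right) input and then applies the per-column scale and bias, also producing the per-row mean and variance for the backward pass. A minimal numpy sketch of that computation follows; the helper name and the 1e-5 default epsilon are illustrative, not taken from the kernel.

import numpy as np

def layer_norm_ref(x2d, scale, bias, epsilon=1e-5):
    # x2d has shape (left, right); scale and bias have shape (right,).
    mean = x2d.mean(axis=1, keepdims=True)
    var = x2d.var(axis=1, keepdims=True)
    y = (x2d - mean) / np.sqrt(var + epsilon)
    # Mean and variance are returned as well, mirroring the kernel outputs.
    return y * scale + bias, mean.ravel(), var.ravel()

x = np.random.rand(4, 8).astype("float32")   # left = 4, right = 8
y, mean, var = layer_norm_ref(x, np.ones(8, "float32"), np.zeros(8, "float32"))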
@@ -24,6 +24,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using framework::Tensor;
+
 static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) {
   if (x_dim.size() > 1) {
     return x_dim;
@@ -97,26 +99,23 @@ static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x,
   ReshapeTensorIntoMatrixSequence(y, mat_dim_y);
 }
 
-template <typename DeviceContext, typename T>
-class MatMulXPUKernel : public framework::OpKernel<T> {
- public:
-  void Compute(const framework::ExecutionContext &context) const override {
-    auto *x = context.Input<framework::Tensor>("X");
-    auto *y = context.Input<framework::Tensor>("Y");
-    auto *out = context.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(context.GetPlace());
-    auto mat_dim_a = math::CreateMatrixDescriptor(
-        RowMatrixFromVector(x->dims()), 0, context.Attr<bool>("transpose_X"));
-    auto mat_dim_b =
-        math::CreateMatrixDescriptor(ColumnMatrixFromVector(y->dims()), 0,
-                                     context.Attr<bool>("transpose_Y"));
-    const auto &x_dims = x->dims();
-    const auto &y_dims = y->dims();
-    if (x_dims.size() == 3 && y_dims.size() <= 2) {
-      // if transpose_X is true, the transpose cost much time
-      if (!context.Attr<bool>("transpose_X")) {
-        mat_dim_a.height_ *= mat_dim_a.batch_size_;
-        mat_dim_a.batch_size_ = 0;
-      } else {
+template <typename T, typename FCT>
+static void MatMulXPUFunction(const Tensor *x, const Tensor *y, Tensor *out,
+                              bool trans_x, bool trans_y,
+                              const paddle::framework::ExecutionContext &ctx) {
+  const auto &x_dims = x->dims();
+  const auto &y_dims = y->dims();
+  auto &dev_ctx =
+      ctx.template device_context<paddle::platform::XPUDeviceContext>();
+  auto mat_dim_a =
+      math::CreateMatrixDescriptor(RowMatrixFromVector(x_dims), 0, trans_x);
+  auto mat_dim_b =
+      math::CreateMatrixDescriptor(ColumnMatrixFromVector(y_dims), 0, trans_y);
+  if (x_dims.size() == 3 && y_dims.size() <= 2) {
+    // if transpose_X is true, the transpose cost much time
+    if (!trans_x) {
+      mat_dim_a.height_ *= mat_dim_a.batch_size_;
+      mat_dim_a.batch_size_ = 0;
+    } else {
@@ -124,7 +123,6 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
       mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
     }
   }
-
   PADDLE_ENFORCE_EQ(
       mat_dim_a.width_, mat_dim_b.height_,
       platform::errors::InvalidArgument("Shape mistake in matmul_op, the "
@@ -139,9 +137,9 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
                         "tensor batch_size:%d, second "
                         "tensor batch_size:%d",
                         mat_dim_a.batch_size_, mat_dim_b.batch_size_));
-    T alpha = static_cast<T>(context.Attr<float>("alpha"));
-    auto &dev_ctx = context.template device_context<DeviceContext>();
+
+  T alpha = static_cast<T>(ctx.Attr<float>("alpha"));
+
   float *data_c = out->data<T>();
   int m = mat_dim_a.height_;
   int n = mat_dim_b.width_;
@@ -150,11 +148,12 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
   int ldy = mat_dim_b.trans_ ? k : n;
   int ldout = n;
   int batch_size = mat_dim_a.batch_size_;
-    if (batch_size == 0 || batch_size == 1) {
-      int r = xpu::fc_fusion<float, float, float, int16_t>(
-          dev_ctx.x_context(), x->data<T>(), y->data<T>(), data_c, m, n, k,
-          mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-          ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
+  if (batch_size == 0) {
+    int r = xpu::fc_fusion<float, float, float, FCT>(
+        dev_ctx.x_context(), x->data<T>(), y->data<T>(), data_c, m, n, k,
+        mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx, ldy,
+        ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
     PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                       platform::errors::External(
                           "XPU fc_fusion kernel return wrong value[%d %s]", r,
@@ -168,16 +167,33 @@ class MatMulXPUKernel : public framework::OpKernel<T> {
       const float *x_data = x->data<T>() + x_stride * i;
       const float *y_data = y->data<T>() + y_stride * i;
       float *out_data = data_c + out_stride * i;
-      int r = xpu::fc_fusion<float, float, float, int16_t>(
+      int r = xpu::fc_fusion<float, float, float, FCT>(
           dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
           mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
           ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
       PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
                         platform::errors::External(
-                            "XPU fc_fusion kernel return wrong value[%d %s]",
-                            r, XPUAPIErrorMsg[r]));
+                            "XPU fc_fusion kernel return wrong value[%d %s]", r,
+                            XPUAPIErrorMsg[r]));
     }
   }
+}
+
+template <typename DeviceContext, typename T>
+class MatMulXPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const framework::ExecutionContext &context) const override {
+    auto *x = context.Input<framework::Tensor>("X");
+    auto *y = context.Input<framework::Tensor>("Y");
+    auto *out = context.Output<framework::Tensor>("Out");
+    out->mutable_data<T>(context.GetPlace());
+    bool trans_x = context.Attr<bool>("transpose_X");
+    bool trans_y = context.Attr<bool>("transpose_Y");
+    if (std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") != nullptr) {
+      MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, context);
+    } else {
+      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, context);
+    }
   }
 };
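The rewritten Compute above chooses the fc_fusion accumulation type at run time: when the XPU_PADDLE_MAT_MUL_FCINT32 environment variable is present the int32_t path is taken, otherwise int16_t. Below is a hedged sketch of how one might opt in from Python; the device call is an assumption about an XPU build, and only the variable name comes from the diff.

import os

# The kernel reads std::getenv("XPU_PADDLE_MAT_MUL_FCINT32") when it executes,
# so the variable must be in the process environment before the op runs.
# Presence alone selects the int32_t path; the value itself is not inspected.
os.environ["XPU_PADDLE_MAT_MUL_FCINT32"] = "1"

import paddle
paddle.set_device("xpu")  # assumes an XPU build; subsequent matmul ops see the flag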
@@ -244,75 +260,10 @@ class MatMulGradXPUKernel : public framework::OpKernel<T> {
               const framework::Tensor &b, bool trans_b,
               framework::Tensor *out) const {
     out->mutable_data<T>(context.GetPlace());
-    auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a);
-    auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b);
-    const auto &a_dims = a.dims();
-    const auto &b_dims = b.dims();
-    if (a_dims.size() == 3 && b_dims.size() <= 2) {
-      // if transpose_X is true, the transpose cost much time
-      if (!context.Attr<bool>("transpose_X")) {
-        mat_dim_a.height_ *= mat_dim_a.batch_size_;
-        mat_dim_a.batch_size_ = 0;
-      } else {
-        mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
-        mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
-      }
-    }
-    PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_,
-                      platform::errors::InvalidArgument(
-                          "Shape mistake in matmul_grad_op, the "
-                          "first tensor width must be same as second tensor "
-                          "height, but received "
-                          "width:%d, height:%d",
-                          mat_dim_a.width_, mat_dim_b.height_));
-    PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
-                      platform::errors::InvalidArgument(
-                          "Shape mistake in matmul_grad_op, the two input"
-                          "tensor batch_size must be same, but received first "
-                          "tensor batch_size:%d, second "
-                          "tensor batch_size:%d",
-                          mat_dim_a.batch_size_, mat_dim_b.batch_size_));
-    T alpha = static_cast<T>(context.Attr<float>("alpha"));
-    auto &dev_ctx = context.template device_context<DeviceContext>();
-    float *data_c = out->data<T>();
-    int m = mat_dim_a.height_;
-    int n = mat_dim_b.width_;
-    int k = mat_dim_a.width_;
-    int ldx = mat_dim_a.trans_ ? m : k;
-    int ldy = mat_dim_b.trans_ ? k : n;
-    int ldout = n;
-    int batch_size = mat_dim_a.batch_size_;
-    if (batch_size == 0 || batch_size == 1) {
-      int r = xpu::fc_fusion<float, float, float, int16_t>(
-          dev_ctx.x_context(), a.data<T>(), b.data<T>(), data_c, m, n, k,
-          mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-          ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
-      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                        platform::errors::External(
-                            "XPU fc_fusion kernel return wrong value[%d %s]", r,
-                            XPUAPIErrorMsg[r]));
-    } else {
-      // batch matmul
-      int x_stride = mat_dim_a.stride_;
-      int y_stride = mat_dim_b.stride_;
-      int out_stride = m * n;
-      for (int i = 0; i < batch_size; ++i) {
-        const float *x_data = a.data<T>() + x_stride * i;
-        const float *y_data = b.data<T>() + y_stride * i;
-        float *out_data = data_c + out_stride * i;
-        int r = xpu::fc_fusion<float, float, float, int16_t>(
-            dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
-            mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr, ldx,
-            ldy, ldout, alpha, 0, nullptr, xpu::Activation_t::LINEAR);
-        PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
-                          platform::errors::External(
-                              "XPU fc_fusion kernel return wrong value[%d %s]",
-                              r, XPUAPIErrorMsg[r]));
-      }
-    }
+    if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_FCINT32") != nullptr) {
+      MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, context);
+    } else {
+      MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, context);
+    }
   }
...
@@ -21,211 +21,141 @@
 namespace paddle {
 namespace operators {
 
-template <typename T>
-void MatMulXPUFunction(const Tensor* X, const Tensor* Y,
-                       const std::vector<std::int64_t>& x_dims,
-                       const std::vector<std::int64_t>& y_dims, Tensor* Out,
-                       bool trans_x, bool trans_y,
-                       const paddle::framework::ExecutionContext& ctx) {
-  const int x_ndim = x_dims.size();
-  const int y_ndim = y_dims.size();
-
-  auto& dev_ctx =
-      ctx.template device_context<paddle::platform::XPUDeviceContext>();
-
-  // currently only support x_ndim == y_dim and non-broadcast case
-  PADDLE_ENFORCE_EQ(x_ndim, y_ndim, platform::errors::InvalidArgument(
-                                        "Shape mistake in matmul_v2_op"));
-  for (int i = 0; i < x_ndim - 2; i++) {
-    PADDLE_ENFORCE_EQ(
-        x_dims.data()[i], y_dims.data()[i],
-        platform::errors::InvalidArgument("Shape mistake in matmul_v2_op"));
-  }
-
-  int ret = 0;
-  if (x_ndim == 1 && y_ndim == 1) {
-    PADDLE_ENFORCE_EQ(X->numel(), Y->numel(),
-                      platform::errors::InvalidArgument(
-                          "X's numbers is not equal to Y's numbers,"
-                          "when X/Y's dims =1"));
-    VLOG(3) << "MatMul's case 1";
-    Out->Resize({1});
-    Out->mutable_data<T>(ctx.GetPlace());
-    ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, false, 1, 1,
-                                    X->numel(), 1.0f, X->data<T>(),
-                                    Y->data<T>(), 0.0f, Out->data<T>());
-    PADDLE_ENFORCE_EQ(
-        ret, XPU_SUCCESS,
-        platform::errors::External(
-            "XPU API return wrong value[%d] in matmul_v2, please check whether "
-            "Baidu Kunlun Card is properly installed.",
-            ret));
-    return;
-  }
-
-  if (x_ndim == 1) {
-    const int N = X->numel();
-    if (trans_y) {
-      PADDLE_ENFORCE_EQ(
-          y_dims[y_ndim - 1], N,
-          platform::errors::InvalidArgument("Input(Y) has error dim."));
-    } else {
-      PADDLE_ENFORCE_EQ(
-          y_dims[y_ndim - 2], N,
-          platform::errors::InvalidArgument("Input(Y) has error dim."));
-    }
-    std::vector<std::int64_t> out_dims(y_ndim - 1);
-    if (trans_y) {
-      std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin());
-    } else {
-      std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin());
-      out_dims.back() = y_dims.back();
-    }
-    Out->Resize(framework::make_ddim(out_dims));
-    Out->mutable_data<T>(ctx.GetPlace());
-    if (trans_y) {
-      const int M = Y->numel() / N;
-      VLOG(3) << "MatMul's case 2";
-      ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, true, 1, M, N,
-                                      1.0f, X->data<T>(), Y->data<T>(), 0.0f,
-                                      Out->data<T>());
-      PADDLE_ENFORCE_EQ(
-          ret, XPU_SUCCESS,
-          platform::errors::External("XPU API return wrong value[%d] in "
-                                     "matmul_v2, please check whether "
-                                     "Baidu Kunlun Card is properly installed.",
-                                     ret));
-    } else {
-      const int M = y_dims[y_ndim - 1];
-      const int batch_size = Y->numel() / (M * N);
-      for (int i = 0; i < batch_size; i++) {
-        ret = baidu::xpu::api::fc_int16(
-            dev_ctx.x_context(), false, false, 1, M, N, 1.0f, X->data<T>(),
-            Y->data<T>() + i * M * N, 0.0f, Out->data<T>() + i * M);
-        PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
-                          platform::errors::External(
-                              "XPU API return wrong value[%d] in matmul_v2, "
-                              "please check whether "
-                              "Baidu Kunlun Card is properly installed.",
-                              ret));
-      }
-    }
-    return;
-  }
-
-  if (y_ndim == 1) {
-    const int N = Y->numel();
-    if (trans_x) {
-      PADDLE_ENFORCE_EQ(
-          x_dims[x_ndim - 2], N,
-          platform::errors::InvalidArgument("Input(X) has error dim."));
-    } else {
-      PADDLE_ENFORCE_EQ(
-          x_dims[x_ndim - 1], N,
-          platform::errors::InvalidArgument("Input(X) has error dim."));
-    }
-    std::vector<std::int64_t> out_dims(x_ndim - 1);
-    if (trans_x) {
-      std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin());
-      out_dims.back() = x_dims.back();
-    } else {
-      std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin());
-    }
-    Out->Resize(framework::make_ddim(out_dims));
-    Out->mutable_data<T>(ctx.GetPlace());
-    if (trans_x) {
-      const int M = x_dims[x_ndim - 1];
-      const int batch_size = X->numel() / (M * N);
-      for (int i = 0; i < batch_size; i++) {
-        ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), true, false, M, 1,
-                                        N, 1.0f, X->data<T>() + i * M * N,
-                                        Y->data<T>(), 0.0f,
-                                        Out->data<T>() + i * M);
-        PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
-                          platform::errors::External(
-                              "XPU API return wrong value[%d] in matmul_v2, "
-                              "please check whether "
-                              "Baidu Kunlun Card is properly installed.",
-                              ret));
-      }
-    } else {
-      const int M = X->numel() / N;
-      VLOG(3) << "MatMul's case 7";
-      ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, false, M, 1,
-                                      N, 1.0f, X->data<T>(), Y->data<T>(), 0.0f,
-                                      Out->data<T>());
-      PADDLE_ENFORCE_EQ(
-          ret, XPU_SUCCESS,
-          platform::errors::External("XPU API return wrong value[%d] in "
-                                     "matmul_v2, please check whether "
-                                     "Baidu Kunlun Card is properly installed.",
-                                     ret));
-    }
-    return;
-  }
-
-  const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
-  const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
-  if (trans_y) {
-    PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K, platform::errors::InvalidArgument(
-                                                 "Input(X) has error dim."));
-  } else {
-    PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K, platform::errors::InvalidArgument(
-                                                 "Input(X) has error dim."));
-  }
-  const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
-  const int ndim = (std::max)(x_ndim, y_ndim);
-  std::vector<std::int64_t> out_broadcast_dims(ndim);
-  int batch_size = 1;
-  for (int i = 0; i < ndim - 2; i++) {
-    PADDLE_ENFORCE_EQ(
-        x_dims.data()[i], y_dims.data()[i],
-        platform::errors::InvalidArgument("Shape mistake in matmul_v2_op"));
-    out_broadcast_dims[i] = x_dims.data()[i];
-    batch_size *= x_dims.data()[i];
-  }
-  out_broadcast_dims[ndim - 2] = M;
-  out_broadcast_dims[ndim - 1] = N;
-  Out->Resize(framework::make_ddim(out_broadcast_dims));
-  Out->mutable_data<T>(ctx.GetPlace());
-  ret = baidu::xpu::api::batched_gemm_int16(
-      dev_ctx.x_context(), trans_x, trans_y, batch_size, M, N, K, 1.0f,
-      X->data<T>(), Y->data<T>(), Out->data<T>(), nullptr, nullptr);
-  PADDLE_ENFORCE_EQ(
-      ret, XPU_SUCCESS,
-      platform::errors::External(
-          "XPU API return wrong value[%d] in matmul_v2, please check whether "
-          "Baidu Kunlun Card is properly installed.",
-          ret));
-}
+template <typename T, typename FCT>
+static void MatMulXPUFunction(const Tensor* x, const Tensor* y, Tensor* out,
+                              bool trans_x, bool trans_y,
+                              const paddle::framework::ExecutionContext& ctx) {
+  const auto& x_dims = x->dims();
+  const auto& y_dims = y->dims();
+  auto& dev_ctx =
+      ctx.template device_context<paddle::platform::XPUDeviceContext>();
+
+  auto mat_dim_a =
+      math::CreateMatrixDescriptor(RowMatrixFromVector(x_dims), 0, trans_x);
+  auto mat_dim_b =
+      math::CreateMatrixDescriptor(ColumnMatrixFromVector(y_dims), 0, trans_y);
+
+  if (x_dims.size() == 3 && y_dims.size() <= 2) {
+    // if transpose_X is true, the transpose cost much time
+    if (!trans_x) {
+      mat_dim_a.height_ *= mat_dim_a.batch_size_;
+      mat_dim_a.batch_size_ = 0;
+    } else {
+      mat_dim_b.batch_size_ = mat_dim_a.batch_size_;
+      mat_dim_b.height_ = mat_dim_b.height_ / mat_dim_b.batch_size_;
+    }
+  }
+
+  if (mat_dim_a.width_ == mat_dim_b.height_) {
+    if (mat_dim_a.batch_size_ == 0 && mat_dim_b.batch_size_ == 1) {
+      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
+    }
+    if (mat_dim_a.batch_size_ == 1 && mat_dim_b.batch_size_ == 0) {
+      mat_dim_a.batch_size_ = mat_dim_b.batch_size_ = 0;
+    }
+  }
+
+  PADDLE_ENFORCE_EQ(mat_dim_a.width_, mat_dim_b.height_,
+                    platform::errors::InvalidArgument(
+                        "Shape mistake in matmul_v2_op xdims = %s ydims = %s",
+                        x_dims.to_str(), y_dims.to_str()));
+  PADDLE_ENFORCE_EQ(mat_dim_a.batch_size_, mat_dim_b.batch_size_,
+                    platform::errors::InvalidArgument(
+                        "Shape mistake in matmul_v2_op xdims = %s ydims = %s",
+                        x_dims.to_str(), y_dims.to_str()));
+
+  float* data_c = out->data<T>();
+  int m = mat_dim_a.height_;
+  int n = mat_dim_b.width_;
+  int k = mat_dim_a.width_;
+  int batch_size = mat_dim_a.batch_size_;
+
+  if (batch_size == 0) {
+    int r = xpu::fc<float, float, float, FCT>(
+        dev_ctx.x_context(), x->data<T>(), y->data<T>(), data_c, m, n, k,
+        mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External(
+                          "XPU fc_fusion kernel return wrong value[%d %s]", r,
+                          XPUAPIErrorMsg[r]));
+  } else {
+    // batch matmul
+    int x_stride = mat_dim_a.stride_;
+    int y_stride = mat_dim_b.stride_;
+    int out_stride = m * n;
+    for (int i = 0; i < batch_size; ++i) {
+      const float* x_data = x->data<T>() + x_stride * i;
+      const float* y_data = y->data<T>() + y_stride * i;
+      float* out_data = data_c + out_stride * i;
+      int r = xpu::fc<float, float, float, FCT>(
+          dev_ctx.x_context(), x_data, y_data, out_data, m, n, k,
+          mat_dim_a.trans_, mat_dim_b.trans_, nullptr, nullptr, nullptr);
+      PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                        platform::errors::External(
+                            "XPU fc_fusion kernel return wrong value[%d %s]", r,
+                            XPUAPIErrorMsg[r]));
+    }
+  }
+}
 
 template <typename T>
 class MatMulV2XPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const paddle::framework::ExecutionContext& ctx) const override {
-    auto* X = ctx.Input<Tensor>("X");
-    auto* Y = ctx.Input<Tensor>("Y");
-    auto* Out = ctx.Output<Tensor>("Out");
+    auto* x = ctx.Input<Tensor>("X");
+    auto* y = ctx.Input<Tensor>("Y");
+    auto* out = ctx.Output<Tensor>("Out");
     bool trans_x = ctx.Attr<bool>("trans_x");
     bool trans_y = ctx.Attr<bool>("trans_y");
-    MatMulXPUFunction<T>(X, Y, vectorize(X->dims()), vectorize(Y->dims()), Out,
-                         trans_x, trans_y, ctx);
+    out->mutable_data<T>(ctx.GetPlace());
+    if (std::getenv("XPU_PADDLE_MAT_MUL_V2_FCINT32") != nullptr) {
+      MatMulXPUFunction<T, int32_t>(x, y, out, trans_x, trans_y, ctx);
+    } else {
+      MatMulXPUFunction<T, int16_t>(x, y, out, trans_x, trans_y, ctx);
+    }
   }
 };
 
+template <typename DeviceContext, typename T>
+static framework::Tensor XPUFoldHeadAndLastDims(
+    const DeviceContext& context, const framework::Tensor& input) {
+  auto in_dims = input.dims();
+  if (in_dims.size() != 3) {
+    return input;
+  }
+
+  framework::Tensor output;
+  output.Resize({in_dims[1], in_dims[0], in_dims[2]});
+  output.mutable_data<T>(context.GetPlace());
+  std::vector<int> in_shape_host = {static_cast<int>(in_dims[0]),
+                                    static_cast<int>(in_dims[1]),
+                                    static_cast<int>(in_dims[2])};
+  std::vector<int> axis_host = {1, 0, 2};
+
+  int r = xpu::transpose(context.x_context(), input.data<T>(), output.data<T>(),
+                         in_shape_host.data(), axis_host.data(), /*ndims=*/3);
+  PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                    platform::errors::External(
+                        "XPU transpose kernel return wrong value[%d %s]", r,
+                        XPUAPIErrorMsg[r]));
+  output.Resize({in_dims[1], in_dims[0] * in_dims[2]});
+
+  return output;
+}
+
 template <typename T>
 class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
  public:
-  void MatMul(const framework::ExecutionContext& context,
+  void MatMul(const framework::ExecutionContext& ctx,
               const framework::Tensor& a, bool trans_a,
               const framework::Tensor& b, bool trans_b,
               framework::Tensor* out) const {
-    out->mutable_data<T>(context.GetPlace());
-    MatMulXPUFunction<T>(&a, &b, vectorize(a.dims()), vectorize(b.dims()), out,
-                         trans_a, trans_b, context);
+    out->mutable_data<T>(ctx.GetPlace());
+    if (std::getenv("XPU_PADDLE_MAT_MUL_GRAD_V2_FCINT32") != nullptr) {
+      MatMulXPUFunction<T, int32_t>(&a, &b, out, trans_a, trans_b, ctx);
+    } else {
+      MatMulXPUFunction<T, int16_t>(&a, &b, out, trans_a, trans_b, ctx);
+    }
   }
 
   void CalcInputGrad(const framework::ExecutionContext& context,
@@ -239,79 +169,33 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
     if (!need_combine) {
       MatMul(context, a, trans_a, b, trans_b, out);
     } else {
-      // currently not support this case
-    }
-  }
-
-  void Compute(const framework::ExecutionContext& ctx) const override {
-    bool transpose_x = ctx.Attr<bool>("trans_x");
-    bool transpose_y = ctx.Attr<bool>("trans_y");
-    auto x = *ctx.Input<framework::Tensor>("X");
-    auto y = *ctx.Input<framework::Tensor>("Y");
-    auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    // get dims
-    std::vector<std::int64_t> x_dims = vectorize(x.dims());
-    std::vector<std::int64_t> y_dims = vectorize(y.dims());
-    std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
-    int x_ndim = x_dims.size();
-    int y_ndim = y_dims.size();
-    auto* dx = ctx.Output<Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<Tensor>(framework::GradVarName("Y"));
-    auto& dev_ctx =
-        ctx.template device_context<paddle::platform::XPUDeviceContext>();
-    // Case1 : x's or y's dim = 1
-    int ret = 0;
-    if (x_ndim == 1 && y_ndim == 1) {
-      if (dx) {
-        dx->mutable_data<T>(ctx.GetPlace());
-        ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, false,
-                                        dx->numel(), 1, 1, 1.0f, y.data<T>(),
-                                        dout.data<T>(), 0.0f, dx->data<T>());
-        PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
-                          platform::errors::External(
-                              "XPU API return wrong value[%d] in "
-                              "matmul_v2_grad, please check whether "
-                              "Baidu Kunlun Card is properly installed.",
-                              ret));
-      }
-      if (dy) {
-        dy->mutable_data<T>(ctx.GetPlace());
-        ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, false,
-                                        dy->numel(), 1, 1, 1.0f, x.data<T>(),
-                                        dout.data<T>(), 0.0f, dy->data<T>());
-        PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
-                          platform::errors::External(
-                              "XPU API return wrong value[%d] in "
-                              "matmul_v2_grad, please check whether "
-                              "Baidu Kunlun Card is properly installed.",
-                              ret));
-      }
-      return;
-    }
-
-    bool is_broadcast = true;
-    if (x_ndim <= 2 || y_ndim <= 2) {
-      is_broadcast = false;
-    } else if (x_ndim != y_ndim) {
-      is_broadcast = true;
-    } else {
-      is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2,
-                                 y_dims.cbegin());
-    }
-
-    // currently only support non-broadcast case
-    PADDLE_ENFORCE_EQ(
-        is_broadcast, false,
-        platform::errors::InvalidArgument("Shape mistake in matmul_v2_op"));
-
-    // Case2: no broadcast or no batch size, it aims to speed and it is same as
-    // matmul in old version.
-    if (!is_broadcast) {
-      ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
-      framework::DDim dx_dims;
-      if (dx) {
+      auto& dev_ctx =
+          context.template device_context<paddle::platform::XPUDeviceContext>();
+      MatMul(
+          context,
+          is_fold_init_dims_a
+              ? FoldInitDims(a)
+              : XPUFoldHeadAndLastDims<paddle::platform::XPUDeviceContext, T>(
+                    dev_ctx, a),
+          trans_a,
+          is_fold_init_dims_b
+              ? FoldInitDims(b)
+              : XPUFoldHeadAndLastDims<paddle::platform::XPUDeviceContext, T>(
+                    dev_ctx, b),
+          trans_b, out);
+    }
+  }
+
+  void Compute(const framework::ExecutionContext& context) const override {
+    bool transpose_x = context.Attr<bool>("trans_x");
+    bool transpose_y = context.Attr<bool>("trans_y");
+
+    auto x = *context.Input<framework::Tensor>("X");
+    auto y = *context.Input<framework::Tensor>("Y");
+    auto dout =
+        *context.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dx = context.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* dy = context.Output<framework::Tensor>(framework::GradVarName("Y"));
+
+    ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
+    framework::DDim dx_dims;
+    if (dx) {
@@ -328,18 +212,19 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
         dy->Resize(y.dims());
       }
     }
-      if (transpose_x && transpose_y) {
-        CalcInputGrad(ctx, y, true, true, dout, true, false, dx);
-        CalcInputGrad(ctx, dout, true, true, x, true, false, dy);
-      } else if (transpose_x) {
-        CalcInputGrad(ctx, y, false, false, dout, true, false, dx);
-        CalcInputGrad(ctx, x, false, false, dout, false, true, dy);
-      } else if (transpose_y) {
-        CalcInputGrad(ctx, dout, false, false, y, false, true, dx);
-        CalcInputGrad(ctx, dout, true, true, x, false, true, dy);
-      } else {
-        CalcInputGrad(ctx, dout, false, false, y, true, false, dx);
-        CalcInputGrad(ctx, x, true, true, dout, false, true, dy);
-      }
+
+    if (transpose_x && transpose_y) {
+      CalcInputGrad(context, y, true, true, dout, true, false, dx);
+      CalcInputGrad(context, dout, true, true, x, true, false, dy);
+    } else if (transpose_x) {
+      CalcInputGrad(context, y, false, false, dout, true, false, dx);
+      CalcInputGrad(context, x, false, false, dout, false, true, dy);
+    } else if (transpose_y) {
+      CalcInputGrad(context, dout, false, false, y, false, true, dx);
+      CalcInputGrad(context, dout, true, true, x, false, true, dy);
+    } else {
+      CalcInputGrad(context, dout, false, false, y, true, false, dx);
+      CalcInputGrad(context, x, true, true, dout, false, true, dy);
+    }
     if (dx) {
@@ -347,13 +232,13 @@ class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
         dx->Resize(dx_dims);
       }
     }
     if (dy) {
       if (dy_dims != y.dims()) {
         dy->Resize(dy_dims);
       }
     }
-    }
   }
 };
 
 }  // namespace operators
...
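The new XPUFoldHeadAndLastDims helper above reshuffles a rank-3 tensor (d0, d1, d2) into a (d1, d0 * d2) matrix by transposing the first two axes and then merging the batch and last dimensions, which is what the gradient path needs when an operand has to be folded. A small numpy sketch of the same reshuffle; the Python function name is illustrative only.

import numpy as np

def fold_head_and_last_dims(t):
    # (d0, d1, d2) -> transpose to (d1, d0, d2) -> reshape to (d1, d0 * d2),
    # mirroring the xpu::transpose + Resize sequence in XPUFoldHeadAndLastDims.
    if t.ndim != 3:
        return t
    d0, d1, d2 = t.shape
    return t.transpose(1, 0, 2).reshape(d1, d0 * d2)

a = np.arange(24).reshape(2, 3, 4)
print(fold_head_and_last_dims(a).shape)  # (3, 8)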
@@ -35,7 +35,7 @@ class OneHotXPUKernel : public framework::OpKernel<T> {
     if (context.HasInput("depth_tensor")) {
       auto* depth_tensor = context.Input<Tensor>("depth_tensor");
       auto* depth_data = depth_tensor->data<int32_t>();
-      if (depth_tensor->place() == platform::XPUPlace()) {
+      if (platform::is_xpu_place(depth_tensor->place())) {
         xpu_memcpy(static_cast<void*>(&depth),
                    static_cast<const void*>(depth_data), sizeof(int32_t),
                    XPU_DEVICE_TO_HOST);
...
// Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifdef PADDLE_WITH_XPU
#include <string>
#include <vector>
#include "paddle/fluid/framework/framework.pb.h"
#include "paddle/fluid/operators/one_hot_op.h"
namespace paddle {
namespace operators {
using LoDTensor = framework::LoDTensor;
using Tensor = framework::Tensor;
template <typename DeviceContext, typename T>
class OneHotV2XPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<LoDTensor>("X");
auto* out = context.Output<LoDTensor>("Out");
int depth = context.Attr<int>("depth");
if (context.HasInput("depth_tensor")) {
auto* depth_tensor = context.Input<Tensor>("depth_tensor");
auto* depth_data = depth_tensor->data<int32_t>();
if (platform::is_xpu_place(depth_tensor->place())) {
xpu_memcpy(static_cast<void*>(&depth),
static_cast<const void*>(depth_data), sizeof(int32_t),
XPU_DEVICE_TO_HOST);
} else {
depth = depth_data[0];
}
auto out_dims = out->dims();
out_dims[out_dims.size() - 1] = depth;
out->Resize(out_dims);
}
auto& dev_ctx = context.template device_context<DeviceContext>();
int len = in->numel();
int ret = xpu::one_hot<T>(dev_ctx.x_context(), in->data<T>(),
out->mutable_data<float>(context.GetPlace()), len,
depth, 1.0, 0.0);
PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
platform::errors::External(
"XPU one_hot kernel return wrong value[%d %s]", ret,
XPUAPIErrorMsg[ret]));
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
one_hot_v2, ops::OneHotV2XPUKernel<paddle::platform::XPUDeviceContext, int>,
ops::OneHotV2XPUKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif
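The new one_hot_v2 kernel above writes a float32 tensor of shape [len, depth] with a 1.0 at each input index and 0.0 elsewhere. As a plain-numpy reference for those semantics (not the Paddle API itself, just what the output looks like):

import numpy as np

def one_hot_ref(indices, depth):
    # indices: 1-D int array of length len; result: [len, depth] float32 with
    # out[i, indices[i]] = 1.0, matching the xpu::one_hot call above.
    out = np.zeros((len(indices), depth), dtype="float32")
    out[np.arange(len(indices)), indices] = 1.0
    return out

print(one_hot_ref(np.array([4, 1, 3, 3]), depth=10))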
@@ -46,10 +46,13 @@ class ScaleXPUKernel : public framework::OpKernel<T> {
                           in->dims().to_str().c_str(),
                           out->dims().to_str().c_str()));
     auto& dev_ctx = ctx.template device_context<DeviceContext>();
-    int r = xpu::scale(dev_ctx.x_context(), in->numel(), scale, bias,
-                       bias_after_scale, in->data<float>(), out->data<float>());
-    PADDLE_ENFORCE_EQ(r, xpu::Error_t::SUCCESS,
-                      platform::errors::Fatal("XPU scale kernel error!"));
+    int r =
+        xpu::scale(dev_ctx.x_context(), in->data<float>(), out->data<float>(),
+                   in->numel(), bias_after_scale, scale, bias);
+    PADDLE_ENFORCE_EQ(
+        r, XPU_SUCCESS,
+        platform::errors::External("XPU scale kernel return wrong value[%d %s]",
+                                   r, XPUAPIErrorMsg[r]));
   }
 };
...
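For reference, the scale op computes scale * x + bias when bias_after_scale is true and scale * (x + bias) otherwise; the reordered xpu::scale call above passes exactly those three scalars. A one-function numpy sketch of the two variants:

import numpy as np

def scale_ref(x, scale=1.0, bias=0.0, bias_after_scale=True):
    # bias_after_scale=True:  out = scale * x + bias
    # bias_after_scale=False: out = scale * (x + bias)
    return scale * x + bias if bias_after_scale else scale * (x + bias)

x = np.array([1.0, 2.0, 3.0], dtype="float32")
print(scale_ref(x, scale=2.0, bias=0.5))                # [2.5 4.5 6.5]
print(scale_ref(x, 2.0, 0.5, bias_after_scale=False))   # [3. 5. 7.]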
@@ -41,8 +41,21 @@ class SoftmaxXPUKernel : public framework::OpKernel<T> {
     }
 
     auto& dev_ctx = context.template device_context<DeviceContext>();
-    int r = xpu::softmax<T>(dev_ctx.x_context(), x->data<float>(),
-                            out->data<float>(), x_dims, axis);
+
+    int r = XPU_SUCCESS;
+    Tensor clip_x;
+    int len = x->numel();
+    T* clip_x_data =
+        clip_x.mutable_data<T>(platform::XPUPlace(), len * sizeof(T));
+    r = xpu::clip(dev_ctx.x_context(), x->data<float>(), clip_x_data, len,
+                  -1e30, 1e30);
+    PADDLE_ENFORCE_EQ(r, XPU_SUCCESS,
+                      platform::errors::External("XPU API(clip) return wrong "
+                                                 "value[%d %s]",
+                                                 r, XPUAPIErrorMsg[r]));
+
+    r = xpu::softmax<T>(dev_ctx.x_context(), clip_x_data, out->data<float>(),
+                        x_dims, axis);
     PADDLE_ENFORCE_EQ(
         r, XPU_SUCCESS,
         platform::errors::External("XPU API(softmax2d_forward) return wrong "
...
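The change above clips the input to [-1e30, 1e30] before calling xpu::softmax, so extreme values cannot overflow inside the kernel. A numpy sketch of the same clip-then-softmax pattern (with the usual max-subtraction that a reference softmax also needs):

import numpy as np

def clipped_softmax_ref(x, axis=-1, limit=1e30):
    # Clip first, as the XPU kernel now does, then apply a numerically
    # stable softmax along the requested axis.
    x = np.clip(x, -limit, limit)
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)

print(clipped_softmax_ref(np.array([[1.0, 2.0, 3.0], [1e38, 0.0, -1e38]])))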
@@ -13,12 +13,11 @@
 # limitations under the License.
 
 from __future__ import print_function
-import unittest
-import numpy as np
 
 import sys
 sys.path.append("..")
-from op_test import OpTest
+import unittest
+import numpy as np
+from op_test_xpu import XPUOpTest
 import paddle.fluid.core as core
 import paddle
@@ -57,9 +56,7 @@ def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
     return Out
 
 
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestMatMulV2Op(OpTest):
+class TestMatMulV2Op(XPUOpTest):
     """
     case 1
     """
@@ -74,10 +71,10 @@ class TestMatMulV2Op(OpTest):
         self.dtype = "float32"
 
     def setUp(self):
+        self.use_xpu = True
         self.init_kernel_type()
         self.config()
         self.op_type = "matmul_v2"
-        self.use_xpu = True
         x = np.random.random(self.x_shape).astype(self.dtype)
         y = np.random.random(self.y_shape).astype(self.dtype)
         # -0.1 ~ 0.1
@@ -94,31 +91,25 @@ class TestMatMulV2Op(OpTest):
     def test_check_output(self):
         place = paddle.XPUPlace(0)
-        self.check_output_with_place(place, atol=0.01)
+        self.check_output_with_place(place)
 
     def test_check_grad(self):
         place = paddle.XPUPlace(0)
-        self.check_grad_with_place(
-            place, ['X', 'Y'], 'Out', max_relative_error=0.1)
+        self.check_grad_with_place(place, ['X', 'Y'], 'Out')
 
-'''
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestMatMuklOp2(TestMatMulV2Op):
-    """
-    case 2
-    """
-
-    def config(self):
-        self.x_shape = (100, )
-        self.y_shape = (1, 3, 2, 100)
-        self.trans_x = False
-        self.trans_y = True
-
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
+# class TestMatMuklOp2(TestMatMulV2Op):
+#     """
+#     case 2
+#     """
+
+#     def config(self):
+#         self.x_shape = (100, )
+#         self.y_shape = (1, 3, 2, 100)
+#         self.trans_x = False
+#         self.trans_y = True
+
+
 class TestMatMuklOp3(TestMatMulV2Op):
     """
     case 3
@@ -131,21 +122,18 @@ class TestMatMuklOp3(TestMatMulV2Op):
         self.trans_y = False
 
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestMatMuklOp4(TestMatMulV2Op):
-    """
-    case 4
-    """
-
-    def config(self):
-        self.x_shape = (100, )
-        self.y_shape = (1, 2, 100, 2)
-        self.trans_x = False
-        self.trans_y = False
-
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
+# class TestMatMuklOp4(TestMatMulV2Op):
+#     """
+#     case 4
+#     """
+
+#     def config(self):
+#         self.x_shape = (100, )
+#         self.y_shape = (1, 2, 100, 2)
+#         self.trans_x = False
+#         self.trans_y = False
+
+
 class TestMatMuklOp5(TestMatMulV2Op):
     """
     case 5
@@ -158,37 +146,29 @@ class TestMatMuklOp5(TestMatMulV2Op):
         self.trans_y = False
 
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestMatMuklOp6(TestMatMulV2Op):
-    """
-    case 6
-    """
-
-    def config(self):
-        self.x_shape = (1, 2, 100, 1)
-        self.y_shape = (100, )
-        self.trans_x = True
-        self.trans_y = False
-
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestMatMuklOp7(TestMatMulV2Op):
-    """
-    case 7
-    """
-
-    def config(self):
-        self.x_shape = (1, 2, 1, 100)
-        self.y_shape = (100, )
-        self.trans_x = False
-        self.trans_y = False
-'''
-
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
+# class TestMatMuklOp6(TestMatMulV2Op):
+#     """
+#     case 6
+#     """
+
+#     def config(self):
+#         self.x_shape = (1, 2, 102, 1)
+#         self.y_shape = (102, )
+#         self.trans_x = True
+#         self.trans_y = False
+
+# class TestMatMuklOp7(TestMatMulV2Op):
+#     """
+#     case 7
+#     """
+
+#     def config(self):
+#         self.x_shape = (1, 2, 1, 100)
+#         self.y_shape = (100, )
+#         self.trans_x = False
+#         self.trans_y = False
+
+
 class TestMatMuklOp8(TestMatMulV2Op):
     """
     case 8
@@ -201,37 +181,97 @@ class TestMatMuklOp8(TestMatMulV2Op):
         self.trans_y = False
 
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
+# class TestMatMuklOp9(TestMatMulV2Op):
+#     """
+#     case 9
+#     """
+
+#     def config(self):
+#         self.x_shape = (1, 1, 1, 100)
+#         self.y_shape = (2, 1, 2, 100)
+#         self.trans_x = False
+#         self.trans_y = True
+
+# class TestMatMuklOp10(TestMatMulV2Op):
+#     """
+#     case 10
+#     """
+
+#     def config(self):
+#         self.x_shape = (1, 1, 25, 4)
+#         self.y_shape = (1, 2, 4, 25)
+#         self.trans_x = False
+#         self.trans_y = False
+
+# class TestMatMuklOp11(TestMatMulV2Op):
+#     """
+#     case 11
+#     """
+
+#     def config(self):
+#         self.x_shape = (2, 1, 2, 100)
+#         self.y_shape = (1, 1, 100, 2)
+#         self.trans_x = False
+#         self.trans_y = False
+
+# class TestMatMuklOp12(TestMatMulV2Op):
+#     """
+#     case 12
+#     """
+
+#     def config(self):
+#         self.x_shape = (2, 1, 4, 25)
+#         self.y_shape = (1, 1, 4, 25)
+#         self.trans_x = True
+#         self.trans_y = False
+
+
 class TestMatMuklOp13(TestMatMulV2Op):
     """
     case 13
     """
 
     def config(self):
-        self.x_shape = (2, 2, 2, 50)
-        self.y_shape = (2, 2, 2, 50)
+        self.x_shape = (2, 2, 10, 10)
+        self.y_shape = (2, 2, 10, 10)
         self.trans_x = True
         self.trans_y = False
 
-'''
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestMatMuklOp16(TestMatMulV2Op):
-    """
-    case 16 : to check the gradient for special case
-    """
-
-    def config(self):
-        self.x_shape = (100)
-        self.y_shape = (1, 2, 2, 100, 2)
-        self.trans_x = False
-        self.trans_y = False
-'''
-
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
+# class TestMatMuklOp14(TestMatMulV2Op):
+#     """
+#     case 14_1
+#     """
+
+#     def config(self):
+#         self.x_shape = (3, 1, 6, 6)
+#         self.y_shape = (1, 2, 6, 9)
+#         self.trans_x = True
+#         self.trans_y = False
+
+# class TestMatMuklOp15(TestMatMulV2Op):
+#     """
+#     case 14_2
+#     """
+
+#     def config(self):
+#         self.x_shape = (3, 1, 6, 6)
+#         self.y_shape = (1, 2, 6, 9)
+#         self.trans_x = False
+#         self.trans_y = False
+
+# class TestMatMuklOp16(TestMatMulV2Op):
+#     """
+#     case 16 : to check the gradient for special case
+#     """
+
+#     def config(self):
+#         self.x_shape = (100)
+#         self.y_shape = (1, 2, 2, 100, 2)
+#         self.trans_x = False
+#         self.trans_y = False
+
+
 class TestMatMuklOp17(TestMatMulV2Op):
     """
     case 17 : to check the gradient for special case
@@ -242,36 +282,30 @@ class TestMatMuklOp17(TestMatMulV2Op):
         self.y_shape = (100)
         self.trans_x = False
         self.trans_y = False
-'''
-
-@unittest.skipIf(not paddle.is_compiled_with_xpu(),
-                 "core is not compiled with XPU")
-class TestMatMulV2API(unittest.TestCase):
-    def setUp(self):
-        self.places = [fluid.CPUPlace()]
-        self.places.append(fluid.XPUPlace(0))
-
-    def check_static_result(self, place):
-        with fluid.program_guard(fluid.Program(), fluid.Program()):
-            input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32")
-            input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32")
-            result = paddle.matmul(input_x, input_y)
-
-            x_np = np.random.random([4, 3]).astype("float32")
-            y_np = np.random.random([3, 4]).astype("float32")
-
-            exe = fluid.Executor(place)
-            fetches = exe.run(fluid.default_main_program(),
-                              feed={"input_x": x_np,
-                                    "input_y": y_np},
-                              fetch_list=[result])
-
-    def test_static(self):
-        for place in self.places:
-            self.check_static_result(place=place)
+
+# class TestMatMuklOpBroadcast1(TestMatMulV2Op):
+#     """
+#     case 14_3
+#     """
+
+#     def config(self):
+#         self.x_shape = (3, 1, 10, 10)
+#         self.y_shape = (1, 2, 10, 10)
+#         self.trans_x = True
+#         self.trans_y = True
+
+# class TestMatMuklOpBroadcast2(TestMatMulV2Op):
+#     """
+#     case 14_4
+#     """
+
+#     def config(self):
+#         self.x_shape = (3, 1, 10, 10)
+#         self.y_shape = (1, 2, 10, 10)
+#         self.trans_x = False
+#         self.trans_y = True
+
 
 if __name__ == "__main__":
+    paddle.enable_static()
     unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import numpy as np
import paddle
import paddle.fluid.core as core
import sys
sys.path.append("..")
from op_test_xpu import XPUOpTest
import paddle.fluid as fluid
from paddle.fluid import Program, program_guard
import time
paddle.enable_static()
class TestOneHotOp(XPUOpTest):
def setUp(self):
self.use_xpu = True
self.op_type = 'one_hot_v2'
depth = 10
depth_np = np.array(10).astype('int32')
# dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0])])
out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
for i in range(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32)}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
class TestOneHotOp_attr(XPUOpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), 1,
depth)).astype('float32')
for i in range(np.product(x.shape)):
out[i, 0, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)}
self.attrs = {'dtype': int(core.VarDesc.VarType.FP32), 'depth': depth}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
class TestOneHotOp_default_dtype(XPUOpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
depth_np = np.array(10).astype('int32')
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0])])
out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
for i in range(np.product(x.shape)):
out[i, x[i]] = 1.0
self.inputs = {'X': (x, x_lod), 'depth_tensor': depth_np}
self.attrs = {}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
class TestOneHotOp_default_dtype_attr(XPUOpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
dimension = 12
x_lod = [[4, 1, 3, 3]]
x = [np.random.randint(0, depth - 1) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0]), 1])
out = np.zeros(shape=(np.product(x.shape[:-1]), 1,
depth)).astype('float32')
for i in range(np.product(x.shape)):
out[i, 0, x[i]] = 1.0
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
class TestOneHotOp_out_of_range(XPUOpTest):
def setUp(self):
self.op_type = 'one_hot_v2'
depth = 10
x_lod = [[4, 1, 3, 3]]
x = [np.random.choice([-1, depth]) for i in range(sum(x_lod[0]))]
x = np.array(x).astype('int32').reshape([sum(x_lod[0])])
out = np.zeros(shape=(np.product(x.shape), depth)).astype('float32')
self.inputs = {'X': (x, x_lod)}
self.attrs = {'depth': depth, 'allow_out_of_range': True}
self.outputs = {'Out': (out, x_lod)}
def test_check_output(self):
place = paddle.XPUPlace(0)
self.check_output_with_place(place, check_dygraph=False)
class TestOneHotOpApi(unittest.TestCase):
def test_api(self):
depth = 10
self._run(depth)
def test_api_with_depthTensor(self):
depth = fluid.layers.assign(input=np.array([10], dtype=np.int32))
self._run(depth)
def test_api_with_dygraph(self):
depth = 10
label = np.array([np.random.randint(0, depth - 1)
for i in range(6)]).reshape([6, 1])
with fluid.dygraph.guard():
one_hot_label = fluid.one_hot(
input=fluid.dygraph.to_variable(label), depth=depth)
def _run(self, depth):
label = fluid.layers.data(name="label", shape=[1], dtype="int64")
one_hot_label = fluid.one_hot(input=label, depth=depth)
place = fluid.XPUPlace(0)
label_data = np.array([np.random.randint(0, 10 - 1)
for i in range(6)]).reshape([6, 1])
exe = fluid.Executor(place)
exe.run(fluid.default_startup_program())
ret = exe.run(feed={'label': label_data, },
fetch_list=[one_hot_label],
return_numpy=False)
class BadInputTestOnehotV2(unittest.TestCase):
def test_error(self):
with fluid.program_guard(fluid.Program()):
def test_bad_x():
label = fluid.layers.data(
name="label",
shape=[4],
append_batch_size=False,
dtype="float32")
one_hot_label = fluid.one_hot(input=label, depth=4)
self.assertRaises(TypeError, test_bad_x)
if __name__ == '__main__':
paddle.enable_static()
unittest.main()