From 6b727e08b1f38b3f4acc1708c163ed6ae5df8d58 Mon Sep 17 00:00:00 2001 From: QingshuChen Date: Sun, 27 Sep 2020 13:56:14 +0800 Subject: [PATCH] support elementwise add, activation, matmul on Baidu Kunlun (#27143) * support elementwise add, activation, matmul on Baidu Kunlun * test=kunlun * minor * test=kunlun * reconstuct the xpu directory * test=kunlun * minor * test=kunlun * minor * test=kunlun * minor * test=kunlun * minor * test=kunlun * minor * test=kunlun --- cmake/external/xpu.cmake | 2 +- cmake/operators.cmake | 8 +- .../allocation/naive_best_fit_allocator.cc | 9 +- paddle/fluid/operators/activation_op_xpu.cc | 179 +++++++++ .../elementwise/elementwise_add_op_xpu.cc | 162 ++++++++ .../operators/elementwise/elementwise_xpu.h | 113 ++++++ paddle/fluid/operators/matmul_op_xpu.cc | 343 +++++++++++++++++ .../{xpu/mul_xpu_op.cc => mul_op_xpu.cc} | 2 +- paddle/fluid/platform/init_test.cc | 1 + paddle/fluid/platform/xpu_header.h | 27 ++ python/paddle/__init__.py | 2 + python/paddle/device.py | 35 +- .../paddle/fluid/tests/unittests/op_test.py | 16 + .../fluid/tests/unittests/test_matmul_op.py | 1 + .../fluid/tests/unittests/test_mul_op.py | 54 +-- .../tests/unittests/xpu/test_activation_op.py | 215 +++++++++++ .../unittests/xpu/test_elementwise_add_op.py | 346 +++++++++++++++++ .../tests/unittests/xpu/test_matmul_op.py | 355 ++++++++++++++++++ .../fluid/tests/unittests/xpu/test_mul_op.py | 161 ++++++++ tools/wlist.json | 4 +- 20 files changed, 1970 insertions(+), 65 deletions(-) create mode 100644 paddle/fluid/operators/activation_op_xpu.cc create mode 100644 paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc create mode 100644 paddle/fluid/operators/elementwise/elementwise_xpu.h create mode 100644 paddle/fluid/operators/matmul_op_xpu.cc rename paddle/fluid/operators/{xpu/mul_xpu_op.cc => mul_op_xpu.cc} (100%) create mode 100755 python/paddle/fluid/tests/unittests/xpu/test_activation_op.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_elementwise_add_op.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_matmul_op.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_mul_op.py diff --git a/cmake/external/xpu.cmake b/cmake/external/xpu.cmake index 8a927d8e282..07fe7d245ef 100644 --- a/cmake/external/xpu.cmake +++ b/cmake/external/xpu.cmake @@ -4,7 +4,7 @@ endif() INCLUDE(ExternalProject) SET(XPU_PROJECT "extern_xpu") -SET(XPU_URL "https://kunlun1.su.bcebos.com/xpu.tar.gz" CACHE STRING "" FORCE) +SET(XPU_URL "https://baidu-kunlun-public.su.bcebos.com/paddle_depence/xpu.tar.gz" CACHE STRING "" FORCE) SET(XPU_SOURCE_DIR "${THIRD_PARTY_PATH}/xpu") SET(XPU_DOWNLOAD_DIR "${XPU_SOURCE_DIR}/src/${XPU_PROJECT}") SET(XPU_INSTALL_DIR "${THIRD_PARTY_PATH}/install/xpu") diff --git a/cmake/operators.cmake b/cmake/operators.cmake index 21080fbe8fd..7aa2766763c 100644 --- a/cmake/operators.cmake +++ b/cmake/operators.cmake @@ -62,9 +62,9 @@ function(op_library TARGET) endif() endif() if(WITH_XPU) - string(REPLACE "_op" "_xpu_op" XPU_FILE "${TARGET}") - if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/xpu/${XPU_FILE}.cc) - list(APPEND xpu_cc_srcs xpu/${XPU_FILE}.cc) + string(REPLACE "_op" "_op_xpu" XPU_FILE "${TARGET}") + if (EXISTS ${CMAKE_CURRENT_SOURCE_DIR}/${XPU_FILE}.cc) + list(APPEND xpu_cc_srcs ${XPU_FILE}.cc) endif() endif() else() @@ -83,7 +83,7 @@ function(op_library TARGET) list(APPEND mkldnn_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cu.cc$") list(APPEND cu_cc_srcs ${src}) - elseif(WITH_XPU AND ${src} MATCHES ".*_xpu_op.cc$") + elseif(WITH_XPU AND 
${src} MATCHES ".*_op_xpu.cc$") list(APPEND xpu_cc_srcs ${src}) elseif(${src} MATCHES ".*\\.cc$") list(APPEND cc_srcs ${src}) diff --git a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc index 92e3933a072..c661c9f9c37 100644 --- a/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc +++ b/paddle/fluid/memory/allocation/naive_best_fit_allocator.cc @@ -127,11 +127,10 @@ void *Alloc(const platform::XPUPlace &place, size_t size) { "Baidu Kunlun Card is properly installed.", ret)); ret = xpu_malloc(reinterpret_cast<void **>(&p), size); - PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS, - platform::errors::External( - "XPU API return wrong value[%d], please check whether " - "Baidu Kunlun Card is properly installed.", - ret)); + PADDLE_ENFORCE_EQ( + ret, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], not enough memory", ret)); if (FLAGS_init_allocated_mem) { PADDLE_THROW(platform::errors::Unimplemented( "xpu memory FLAGS_init_allocated_mem is not implemented.")); diff --git a/paddle/fluid/operators/activation_op_xpu.cc b/paddle/fluid/operators/activation_op_xpu.cc new file mode 100644 index 00000000000..49b7a08a7b5 --- /dev/null +++ b/paddle/fluid/operators/activation_op_xpu.cc @@ -0,0 +1,179 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ + +#ifdef PADDLE_WITH_XPU + +#include "paddle/fluid/operators/activation_op.h" +#include +#include "paddle/fluid/platform/xpu_header.h" + +namespace paddle { +namespace operators { + +using paddle::framework::Tensor; + +template +class XPUActivationKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + Functor functor; + + auto attrs = functor.GetAttrs(); + for (auto &attr : attrs) { + *attr.second = context.Attr(attr.first); + } + functor(context); + } +}; + +template +class XPUActivationGradKernel + : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + Functor functor; + + auto attrs = functor.GetAttrs(); + for (auto &attr : attrs) { + *attr.second = context.Attr(attr.first); + } + functor(context); + } +}; + +template +void xpu_activation_forward(const framework::ExecutionContext &ctx, + xpu::Activation_t type) { + const auto *x = ctx.Input("X"); + auto *y = ctx.Output("Out"); + const T *x_data = x->data(); + T *y_data = y->mutable_data(ctx.GetPlace()); + int r = 0; + if (xpu::Activation_t::ACT_POW == type.type) { + type.pow_factor = ctx.Attr("factor"); + } + auto xpu_context = ctx.device_context().x_context(); + r = xpu::activation_forward(xpu_context, type, x->numel(), + reinterpret_cast(x_data), + reinterpret_cast(y_data)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); +} + +template +void xpu_activation_backward(const framework::ExecutionContext &ctx, + xpu::Activation_t type) { + /* TODO: relu tanh sigmoid are inplace */ + const auto *x = ctx.Input("X"); + auto *y = ctx.Input("Out"); + auto *dOut = ctx.Input(framework::GradVarName("Out")); + auto *dX = ctx.Output(framework::GradVarName("X")); + const T *x_data = nullptr; + const T *y_data = nullptr; + const T *y_grad = nullptr; + if (x != nullptr) x_data = x->data(); + if (y != nullptr) y_data = y->data(); + if (dOut != nullptr) y_grad = dOut->data(); + T *x_grad = dX->mutable_data(ctx.GetPlace()); + auto xpu_context = ctx.device_context().x_context(); + int r = xpu::activation_backward(xpu_context, type, dX->numel(), + reinterpret_cast(x_data), + reinterpret_cast(y_data), + reinterpret_cast(y_grad), + reinterpret_cast(x_grad)); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); +} + +template +struct XPUActivationFunc : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_forward(ctx, + algorithm); + } +}; + +template +struct XPUActivationGradFunc : public BaseActivationFunctor { + void operator()(const framework::ExecutionContext &ctx) const { + xpu_activation_backward(ctx, + algorithm); + } +}; + +template +using XPUReluFunctor = XPUActivationFunc; +template +using XPUSigmoidFunctor = XPUActivationFunc; +template +using XPUTanhFunctor = XPUActivationFunc; +template +using XPUGeluFunctor = XPUActivationFunc; +template +using XPULogFunctor = XPUActivationFunc; +template +using XPUSquareFunctor = XPUActivationFunc; +template +using XPUSuareGradFunctor = XPUActivationGradFunc; +template +using XPUReluGradFunctor = XPUActivationGradFunc; +template +using XPUSigmoidGradFunctor = + XPUActivationGradFunc; +template +using XPUTanhGradFunctor = XPUActivationGradFunc; +template +using 
XPUGeluGradFunctor = XPUActivationGradFunc; +template +using XPUSqrtFunctor = XPUActivationFunc; +template +using XPUSqrtGradFunctor = XPUActivationGradFunc; +template +using XPUACTPowFunctor = XPUActivationFunc; +template +using XPUABSFunctor = XPUActivationFunc; +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +#define REGISTER_ACTIVATION_XPU_KERNEL(act_type, functor, grad_functor) \ + REGISTER_OP_XPU_KERNEL(act_type, \ + ops::XPUActivationKernel>); \ + REGISTER_OP_XPU_KERNEL( \ + act_type##_grad, \ + ops::XPUActivationGradKernel>); + +REGISTER_ACTIVATION_XPU_KERNEL(relu, XPUReluFunctor, XPUReluGradFunctor) +REGISTER_ACTIVATION_XPU_KERNEL(tanh, XPUTanhFunctor, XPUTanhGradFunctor) +REGISTER_ACTIVATION_XPU_KERNEL(sigmoid, XPUSigmoidFunctor, + XPUSigmoidGradFunctor) +REGISTER_ACTIVATION_XPU_KERNEL(gelu, XPUGeluFunctor, XPUGeluGradFunctor) +REGISTER_ACTIVATION_XPU_KERNEL(sqrt, XPUSqrtFunctor, XPUSqrtGradFunctor) +REGISTER_ACTIVATION_XPU_KERNEL(square, XPUSquareFunctor, XPUSuareGradFunctor) +REGISTER_OP_XPU_KERNEL(log, + ops::XPUActivationKernel>); +REGISTER_OP_XPU_KERNEL(pow, + ops::XPUActivationKernel>); +REGISTER_OP_XPU_KERNEL(abs, + ops::XPUActivationKernel>); + +#endif // PADDLE_WITH_XPU diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc new file mode 100644 index 00000000000..9ff7a71d7f0 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_add_op_xpu.cc @@ -0,0 +1,162 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
*/ +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/elementwise/elementwise_add_op.h" +#include <memory> +#include <string> +#include "paddle/fluid/operators/elementwise/elementwise_op.h" + +#include "paddle/fluid/operators/elementwise/elementwise_xpu.h" + +namespace paddle { +namespace operators { + +template <typename DeviceContext, typename T> +class ElementwiseAddXPUKernel : public framework::OpKernel<T> { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + XPUElementwise<DeviceContext, T, XPUAddFunctor<T>>(ctx); + } +}; + +template <typename DeviceContext, typename T> +class ElementwiseAddGradXPUKernel : public ElemwiseGradKernel<T> { + public: + void Compute(const framework::ExecutionContext &ctx) const override { + ElemwiseGradKernel<T>::Compute(ctx); + using Tensor = framework::Tensor; + + auto *dout = ctx.Input<Tensor>(framework::GradVarName("Out")); + auto *dx = ctx.Output<Tensor>(framework::GradVarName("X")); + auto *dy = ctx.Output<Tensor>(framework::GradVarName("Y")); + + auto dx_dims = dout->dims(); + auto dy_dims_untrimed = dout->dims(); + T *dx_data = NULL; + T *dy_data = NULL; + + int axis = ctx.Attr<int>("axis"); + PADDLE_ENFORCE_GE(dx_dims.size(), dy_dims_untrimed.size(), + "Rank of first input must be >= rank of second input."); + + if (dx != nullptr) { + dx->mutable_data<T>(ctx.GetPlace()); + dx_dims = dx->dims(); + dx_data = dx->data<T>(); + } + + if (dy != nullptr) { + dy->mutable_data<T>(ctx.GetPlace()); + dy_dims_untrimed = dy->dims(); + dy_data = dy->data<T>(); + } + + int pre, n, post, is_common_broadcast; + if (dx_dims == dy_dims_untrimed) { + pre = post = 1; + n = dout->numel(); + } else { + axis = (axis == -1 ? dx_dims.size() - dy_dims_untrimed.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < dx_dims.size(), + "Axis should be in range [0, dx_dims)"); + auto dy_dims = trim_trailing_singular_dims(dy_dims_untrimed); + axis = (dy_dims.size() == 0) ? dx_dims.size() : axis; + get_mid_dims(dx_dims, dy_dims, axis, &pre, &n, &post, + &is_common_broadcast); + } + int len = pre * n * post; + + auto &dev_ctx = + ctx.template device_context<paddle::platform::XPUDeviceContext>(); + if (post == 1) { + int r = xpu::matrix_vector_add_grad( + dev_ctx.x_context(), dout->data<T>(), dout->data<T>(), + dout->data<T>(), dout->data<T>(), dx_data, dy_data, pre, n); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + return; + } + + if (dx == nullptr) { + PADDLE_ENFORCE_EQ( + xpu_malloc(reinterpret_cast<void **>(&dx_data), len * sizeof(float)), + XPU_SUCCESS, platform::errors::External("XPU does not have enough memory")); + } + + if (dy == nullptr) { + PADDLE_ENFORCE_EQ( + xpu_malloc(reinterpret_cast<void **>(&dy_data), len * sizeof(float)), + XPU_SUCCESS, platform::errors::External("XPU does not have enough memory")); + } else { + if (len != n) { + PADDLE_ENFORCE_EQ(xpu_malloc(reinterpret_cast<void **>(&dy_data), + len * sizeof(float)), + XPU_SUCCESS, platform::errors::External( + "XPU does not have enough memory")); + } + } + + int r = xpu::elementwise_add_grad( + dev_ctx.x_context(), dout->data<T>() /*x*/, dout->data<T>() /*y*/, + dout->data<T>() /*out*/, dout->data<T>(), dx_data, dy_data, len); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + + if ((dy != nullptr) && (len != n)) { + r = xpu::reduce_ew(dev_ctx.x_context(), dy_data, dy->data<T>(), pre, n, + post, xpu::ElementwiseOp::ASSIGN); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + dev_ctx.Wait(); + xpu_free(dy_data); + } + + if ((dx == nullptr || dy == nullptr) && !(dy != nullptr && len != n)) { + dev_ctx.Wait(); + } + + if (dx == nullptr) { + xpu_free(dx_data); + } + if (dy == nullptr) { + xpu_free(dy_data); + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL( + elementwise_add, + ops::ElementwiseAddXPUKernel<paddle::platform::XPUDeviceContext, float>); +REGISTER_OP_XPU_KERNEL(elementwise_add_grad, + ops::ElementwiseAddGradXPUKernel< + paddle::platform::XPUDeviceContext, float>); +#endif diff --git a/paddle/fluid/operators/elementwise/elementwise_xpu.h b/paddle/fluid/operators/elementwise/elementwise_xpu.h new file mode 100644 index 00000000000..53c4332e919 --- /dev/null +++ b/paddle/fluid/operators/elementwise/elementwise_xpu.h @@ -0,0 +1,113 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ +#pragma once +#ifdef PADDLE_WITH_XPU +#include "paddle/fluid/framework/tensor.h" +#include "paddle/fluid/platform/place.h" + +namespace paddle { +namespace operators { + +template <typename T> +struct XPUAddFunctor { + int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) { + return xpu::elementwise_add(ctx, x, y, z, len); + } +}; + +template <typename T> +struct XPUMulFunctor { + int operator()(xpu::Context* ctx, const T* x, const T* y, T* z, int len) { + return xpu::elementwise_mul(ctx, x, y, z, len); + } +}; + +template <typename DeviceContext, typename T, typename Functor> +void XPUElementwise(const framework::ExecutionContext& ctx) { + PADDLE_ENFORCE(platform::is_xpu_place(ctx.GetPlace()), + "This kernel only runs on XPU device."); + auto x_var = ctx.InputVar("X"); + PADDLE_ENFORCE_NE(x_var, nullptr, + platform::errors::Fatal("Cannot get input Variable X")); + PADDLE_ENFORCE(x_var->IsType<framework::LoDTensor>(), + "XPU only supports LoDTensor"); + + auto x = x_var->Get<framework::LoDTensor>(); + auto* y = ctx.Input<framework::LoDTensor>("Y"); + auto* z = ctx.Output<framework::LoDTensor>("Out"); + z->mutable_data<T>(ctx.GetPlace()); + + int axis = ctx.Attr<int>("axis"); + auto x_dims = x.dims(); + auto y_dims_untrimed = y->dims(); + PADDLE_ENFORCE_GE(x_dims.size(), y_dims_untrimed.size(), + "Rank of first input must be >= rank of second input."); + axis = (axis == -1 ? x_dims.size() - y_dims_untrimed.size() : axis); + PADDLE_ENFORCE(axis >= 0 && axis < x_dims.size(), + "Axis should be in range [0, x_dims)"); + auto y_dims = trim_trailing_singular_dims(y_dims_untrimed); + axis = (y_dims.size() == 0) ? x_dims.size() : axis; + int pre, n, post, is_common_broadcast; + get_mid_dims(x_dims, y_dims, axis, &pre, &n, &post, &is_common_broadcast); + int len = pre * n * post; + + const T* x_data = x.data<T>(); + const T* y_data = y->data<T>(); + T* z_data = z->data<T>(); + T* y_broadcast = nullptr; + + auto& dev_ctx = + ctx.template device_context<paddle::platform::XPUDeviceContext>(); + + if (post == 1) { + if (std::is_same<Functor, XPUAddFunctor<T>>::value) { + int res = xpu::matrix_vector_add(dev_ctx.x_context(), x_data, y_data, + z_data, pre, n); + PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d", + res); + return; + } + if (std::is_same<Functor, XPUMulFunctor<T>>::value) { + int res = xpu::matrix_vector_mul(dev_ctx.x_context(), x_data, y_data, + z_data, pre, n); + PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d", + res); + return; + } + } + + if (pre != 1 || post != 1) { + PADDLE_ENFORCE(xpu_malloc(reinterpret_cast<void**>(&y_broadcast), + len * sizeof(T)) == XPU_SUCCESS); + int res = xpu::broadcast_ew(dev_ctx.x_context(), y_data, y_broadcast, pre, + n, post, xpu::ElementwiseOp::ASSIGN); + PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d", + res); + y_data = y_broadcast; + } + + Functor functor; + int res = functor(dev_ctx.x_context(), x_data, y_data, z_data, len); + PADDLE_ENFORCE(res == xpu::Error_t::SUCCESS, "XPU kernel error! res = %d", + res); + + if (pre != 1 || post != 1) { + dev_ctx.Wait(); + xpu_free(y_broadcast); + } +} + +} // namespace operators +} // namespace paddle +#endif diff --git a/paddle/fluid/operators/matmul_op_xpu.cc b/paddle/fluid/operators/matmul_op_xpu.cc new file mode 100644 index 00000000000..ff038d7ef12 --- /dev/null +++ b/paddle/fluid/operators/matmul_op_xpu.cc @@ -0,0 +1,343 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. */ + +#ifdef PADDLE_WITH_XPU + +#include <algorithm> +#include <utility> +#include <vector> +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/blas.h" + +namespace paddle { +namespace operators { + +static framework::DDim RowMatrixFromVector(const framework::DDim &x_dim) { + if (x_dim.size() > 1) { + return x_dim; + } + return framework::make_ddim({1, x_dim[0]}); +} + +static framework::Tensor FoldInitDims(const framework::Tensor &input) { + auto output = input; + auto in_dims = input.dims(); + if (in_dims.size() == 3) { + output.Resize({in_dims[0] * in_dims[1], in_dims[2]}); + } + return output; +} +/** + * Get column matrix shape from a vector shape. If the rank of y_dim > 1, the + * original y_dim is returned. + */ +static framework::DDim ColumnMatrixFromVector(const framework::DDim &y_dim) { + if (y_dim.size() > 1) { + return y_dim; + } + return framework::make_ddim({y_dim[0], 1}); +} + +static void ReshapeTensorIntoMatrixSequence( + framework::Tensor *x, const math::MatDescriptor &descriptor) { + int64_t h, w; + h = descriptor.height_; + w = descriptor.width_; + if (descriptor.trans_) { + std::swap(w, h); + } + if (descriptor.batch_size_) { + x->Resize({descriptor.batch_size_, h, w}); + } else { + x->Resize({h, w}); + } +} +/** + * Reshape the x, y, out tensors to 3-D or 2-D tensors by matrix descriptor + * Out = matmul(x, y) + * + * This method will first calculate X, Y matrix sequence, and then calculate + * the out shape. + * + * Assume X = [BatchSize, H1, W1], Y = [BatchSize, H2, W2] + * The out = [BatchSize, H1, W2] + * + * If there is no batch size in `X` and `Y`, the out will be [H1, W2] + * If any of `X` and `Y` has batch size BatchSize, the out will have the + * BatchSize. 
+ */ +static void ReshapeXYOutIntoMatrixSequence(framework::Tensor *x, + framework::Tensor *y, + framework::Tensor *out, bool trans_x, + bool trans_y) { + auto x_dim = RowMatrixFromVector(x->dims()); + auto y_dim = ColumnMatrixFromVector(y->dims()); + auto mat_dim_x = math::CreateMatrixDescriptor(x_dim, 0, trans_x); + auto mat_dim_y = math::CreateMatrixDescriptor(y_dim, 0, trans_y); + if (mat_dim_x.batch_size_ == 0 && mat_dim_y.batch_size_ == 0) { + out->Resize({mat_dim_x.height_, mat_dim_y.width_}); + } else { + out->Resize({std::max(mat_dim_x.batch_size_, mat_dim_y.batch_size_), + mat_dim_x.height_, mat_dim_y.width_}); + } + + ReshapeTensorIntoMatrixSequence(x, mat_dim_x); + ReshapeTensorIntoMatrixSequence(y, mat_dim_y); +} + +template +class MatMulXPUKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext &context) const override { + auto *x = context.Input("X"); + auto *y = context.Input("Y"); + auto *out = context.Output("Out"); + out->mutable_data(context.GetPlace()); + + auto mat_dim_a = math::CreateMatrixDescriptor( + RowMatrixFromVector(x->dims()), 0, context.Attr("transpose_X")); + auto mat_dim_b = + math::CreateMatrixDescriptor(ColumnMatrixFromVector(y->dims()), 0, + context.Attr("transpose_Y")); + PADDLE_ENFORCE_EQ( + mat_dim_a.width_, mat_dim_b.height_, + platform::errors::InvalidArgument("Shape mistake in matmul_op")); + PADDLE_ENFORCE_EQ( + mat_dim_a.batch_size_, mat_dim_b.batch_size_, + platform::errors::InvalidArgument("Shape mistake in matmul_op")); + T alpha = static_cast(context.Attr("alpha")); + + auto &dev_ctx = context.template device_context(); + float *data_c = out->data(); + if (mat_dim_a.batch_size_ == 0 || mat_dim_a.batch_size_ == 1) { + int r = + xpu::fc_int16(dev_ctx.x_context(), mat_dim_a.trans_, mat_dim_b.trans_, + mat_dim_a.height_, mat_dim_b.width_, mat_dim_a.width_, + alpha, x->data(), y->data(), 0.0f, data_c); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } else { + // batch matmul + int r = xpu::batched_gemm_int16(dev_ctx.x_context(), mat_dim_a.trans_, + mat_dim_b.trans_, mat_dim_a.batch_size_, + mat_dim_a.height_, mat_dim_b.width_, + mat_dim_a.width_, alpha, x->data(), + y->data(), data_c, nullptr, nullptr); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } + } +}; + +// Reshape a rank-3 tensor from P x M x N to M x (P * N). +// (Warning: This requires transposing data and writes into new memory.) +// Identity op if the tensor is not of rank 3. 
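 +// For example (illustrative shapes, not part of the original patch): a +// 2 x 3 x 4 input is transposed via xpu::transpose with axis order {1, 0, 2} +// into a 3 x 2 x 4 buffer and then viewed as a 3 x 8 matrix. 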
+template <typename DeviceContext, typename T> +static framework::Tensor XPUFoldHeadAndLastDims( + const DeviceContext &context, const framework::Tensor &input) { + auto in_dims = input.dims(); + if (in_dims.size() != 3) { + return input; + } + + framework::Tensor output; + output.Resize({in_dims[1], in_dims[0], in_dims[2]}); + output.mutable_data<T>(context.GetPlace()); + std::vector<int> in_shape_host = {static_cast<int>(in_dims[0]), + static_cast<int>(in_dims[1]), + static_cast<int>(in_dims[2])}; + std::vector<int> axis_host = {1, 0, 2}; + + int r = xpu::transpose(context.x_context(), input.data<T>(), output.data<T>(), + in_shape_host.data(), axis_host.data(), /*ndims=*/3); + PADDLE_ENFORCE_EQ(r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + output.Resize({in_dims[1], in_dims[0] * in_dims[2]}); + + return output; +} + +// Using dimensional constraints on matrix multiplication, it is +// straightforward to check the following table for when X and Y +// are both matrices. +// +// transpose_X | False | True | False | True +// transpose_Y | False | False | True | True +// -----------+----------+----------+----------+----------- +// dX = | dOut Y^T | Y dOut^T | dOut Y | Y^T dOut^T +// dY = | X^T dOut | X dOut | dOut^T X | dOut^T X^T +// +// When X is a vector of size K, we treat it instead as a matrix of shape +// (1, K). Similarly, when Y is a vector of size K, we treat it instead as +// a matrix of shape (K, 1). +// +// When X and Y are both 3-dimensional tensors, then the first dimension, +// the batch dimension, can be ignored and the exact same formulas apply +// as for two matrices. +// +// Finally, when, e.g., X is a 3-dimensional tensor but Y is a matrix, we end +// up with formulas like +// +// dY_{ij} = \sum_{p, m} X_{pmi} dOut_{pmj} +// +// To handle this sort of scenario, we reshape X : P x M x K, dOut: P x M x N +// to X: (P * M) x K, dOut: (P * M) x N. 
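 +// +// (Note on the helpers above: FoldInitDims merges the two leading dims, +// P x M x K -> (P * M) x K, while XPUFoldHeadAndLastDims merges the first +// and last dims, P x M x K -> M x (P * K), the layout needed when the +// folded operand sits on the other side of the product.) 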
+template +class MatMulGradXPUKernel : public framework::OpKernel { + public: + void MatMul(const framework::ExecutionContext &context, + const framework::Tensor &a, bool trans_a, + const framework::Tensor &b, bool trans_b, + framework::Tensor *out) const { + out->mutable_data(context.GetPlace()); + auto mat_dim_a = math::CreateMatrixDescriptor(a.dims(), 0, trans_a); + auto mat_dim_b = math::CreateMatrixDescriptor(b.dims(), 0, trans_b); + PADDLE_ENFORCE_EQ( + mat_dim_a.width_, mat_dim_b.height_, + platform::errors::InvalidArgument("Shape mistake in matmul_grad_op")); + PADDLE_ENFORCE_EQ( + mat_dim_a.batch_size_, mat_dim_b.batch_size_, + platform::errors::InvalidArgument("Shape mistake in matmul_grad_op")); + T alpha = static_cast(context.Attr("alpha")); + + auto &dev_ctx = context.template device_context(); + float *data_c = out->data(); + if (mat_dim_a.batch_size_ == 0 || mat_dim_a.batch_size_ == 1) { + int r = + xpu::fc_int16(dev_ctx.x_context(), mat_dim_a.trans_, mat_dim_b.trans_, + mat_dim_a.height_, mat_dim_b.width_, mat_dim_a.width_, + alpha, a.data(), b.data(), 0.0f, data_c); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } else { + // batch matmul + int r = xpu::batched_gemm_int16(dev_ctx.x_context(), mat_dim_a.trans_, + mat_dim_b.trans_, mat_dim_a.batch_size_, + mat_dim_a.height_, mat_dim_b.width_, + mat_dim_a.width_, alpha, a.data(), + b.data(), data_c, nullptr, nullptr); + PADDLE_ENFORCE_EQ( + r, XPU_SUCCESS, + platform::errors::External( + "XPU API return wrong value[%d], please check whether " + "Baidu Kunlun Card is properly installed.", + r)); + } + } + + void CalcInputGrad(const framework::ExecutionContext &context, + const framework::Tensor &a, bool trans_a, + bool is_fold_init_dims_a, const framework::Tensor &b, + bool trans_b, bool is_fold_init_dims_b, + framework::Tensor *out) const { + if (out == nullptr) return; + bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) && + out->dims().size() == 2; + if (!need_combine) { + MatMul(context, a, trans_a, b, trans_b, out); + } else { + auto &dev_ctx = context.template device_context(); + MatMul( + context, is_fold_init_dims_a + ? FoldInitDims(a) + : XPUFoldHeadAndLastDims(dev_ctx, a), + trans_a, is_fold_init_dims_b + ? 
FoldInitDims(b) + : XPUFoldHeadAndLastDims(dev_ctx, b), + trans_b, out); + } + } + + void Compute(const framework::ExecutionContext &context) const override { + auto x = *context.Input("X"); + auto y = *context.Input("Y"); + auto dout = + *context.Input(framework::GradVarName("Out")); + auto *dx = context.Output(framework::GradVarName("X")); + auto *dy = context.Output(framework::GradVarName("Y")); + bool transpose_x = context.Attr("transpose_X"); + bool transpose_y = context.Attr("transpose_Y"); + + ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y); + + framework::DDim dx_dims; + if (dx) { + dx_dims = dx->dims(); + if (dx_dims != x.dims()) { + dx->Resize(x.dims()); + } + } + + framework::DDim dy_dims; + if (dy) { + dy_dims = dy->dims(); + if (dy_dims != y.dims()) { + dy->Resize(y.dims()); + } + } + + if (transpose_x && transpose_y) { + CalcInputGrad(context, y, true, true, dout, true, false, dx); + CalcInputGrad(context, dout, true, true, x, true, false, dy); + } else if (transpose_x) { + CalcInputGrad(context, y, false, false, dout, true, false, dx); + CalcInputGrad(context, x, false, false, dout, false, true, dy); + } else if (transpose_y) { + CalcInputGrad(context, dout, false, false, y, false, true, dx); + CalcInputGrad(context, dout, true, true, x, false, true, dy); + } else { + CalcInputGrad(context, dout, false, false, y, true, false, dx); + CalcInputGrad(context, x, true, true, dout, false, true, dy); + } + + if (dx) { + if (dx_dims != x.dims()) { + dx->Resize(dx_dims); + } + } + + if (dy) { + if (dy_dims != y.dims()) { + dy->Resize(dy_dims); + } + } + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; + +REGISTER_OP_XPU_KERNEL( + matmul, ops::MatMulXPUKernel); +REGISTER_OP_XPU_KERNEL( + matmul_grad, + ops::MatMulGradXPUKernel); +#endif diff --git a/paddle/fluid/operators/xpu/mul_xpu_op.cc b/paddle/fluid/operators/mul_op_xpu.cc similarity index 100% rename from paddle/fluid/operators/xpu/mul_xpu_op.cc rename to paddle/fluid/operators/mul_op_xpu.cc index 79aae71c304..0c8469101ab 100644 --- a/paddle/fluid/operators/xpu/mul_xpu_op.cc +++ b/paddle/fluid/operators/mul_op_xpu.cc @@ -14,11 +14,11 @@ limitations under the License. */ #ifdef PADDLE_WITH_XPU +#include "paddle/fluid/operators/mul_op.h" #include #include #include #include -#include "paddle/fluid/operators/mul_op.h" namespace paddle { namespace operators { diff --git a/paddle/fluid/platform/init_test.cc b/paddle/fluid/platform/init_test.cc index f14fbdd74f9..f1832206a1a 100644 --- a/paddle/fluid/platform/init_test.cc +++ b/paddle/fluid/platform/init_test.cc @@ -15,6 +15,7 @@ limitations under the License. 
*/ #include "paddle/fluid/platform/device_context.h" #include "paddle/fluid/platform/init.h" +#include "paddle/fluid/platform/xpu_info.h" TEST(InitDevices, CPU) { using paddle::framework::InitDevices; diff --git a/paddle/fluid/platform/xpu_header.h b/paddle/fluid/platform/xpu_header.h index d8c5f85f9cf..95e4979951d 100644 --- a/paddle/fluid/platform/xpu_header.h +++ b/paddle/fluid/platform/xpu_header.h @@ -15,9 +15,36 @@ #pragma once #ifdef PADDLE_WITH_XPU +#include +#include + +#include "paddle/fluid/platform/errors.h" #include "xpu/api.h" #include "xpu/runtime.h" #include "xpu/runtime_ex.h" namespace xpu = baidu::xpu::api; + +class XPUActHelper { + public: + // Convert string to activation type in xpu + static xpu::Activation_t ConvertToXpuActType( + const std::string& act_type_str) { + static std::unordered_map str2act = { + {"linear", xpu::Activation_t::LINEAR}, + {"relu", xpu::Activation_t::RELU}, + {"sigmoid", xpu::Activation_t::SIGMOID}, + {"tanh", xpu::Activation_t::TANH}, + {"gelu", xpu::Activation_t::GELU}, + {"leaky_relu", xpu::Activation_t::LEAKY_RELU}, + {"sqrt", xpu::Activation_t::SQRT}, + {"square", xpu::Activation_t::SQUARE}}; + + auto res = str2act.find(act_type_str); + PADDLE_ENFORCE_NE(res, str2act.end(), + paddle::platform::errors::InvalidArgument( + "Invalid activation type(%s) in XPU", act_type_str)); + return res->second; + } +}; #endif diff --git a/python/paddle/__init__.py b/python/paddle/__init__.py index 40275a2ce71..e707de8e068 100755 --- a/python/paddle/__init__.py +++ b/python/paddle/__init__.py @@ -257,6 +257,8 @@ from .tensor.stat import numel #DEFINE_ALIAS from .device import get_cudnn_version from .device import set_device from .device import get_device +from .device import is_compiled_with_xpu +from .device import XPUPlace # from .tensor.tensor import Tensor #DEFINE_ALIAS # from .tensor.tensor import LoDTensor #DEFINE_ALIAS # from .tensor.tensor import LoDTensorArray #DEFINE_ALIAS diff --git a/python/paddle/device.py b/python/paddle/device.py index de24fd87513..46d0ff7bedc 100644 --- a/python/paddle/device.py +++ b/python/paddle/device.py @@ -22,7 +22,9 @@ from paddle.fluid.dygraph.parallel import ParallelEnv __all__ = [ 'get_cudnn_version', 'set_device', - 'get_device' + 'get_device', + 'XPUPlace', + 'is_compiled_with_xpu' # 'cpu_places', # 'CPUPlace', # 'cuda_pinned_places', @@ -35,6 +37,37 @@ __all__ = [ _cudnn_version = None +def is_compiled_with_xpu(): + """ + Whether paddle was built with WITH_XPU=ON to support Baidu Kunlun + + Returns (bool): whether paddle was built with WITH_XPU=ON + + Examples: + .. code-block:: python + + import paddle + support_xpu = paddle.device.is_compiled_with_xpu() + """ + return core.is_compiled_with_xpu() + + +def XPUPlace(dev_id): + """ + Return a Baidu Kunlun Place + + Parameters: + dev_id(int): Baidu Kunlun device id + + Examples: + .. code-block:: python + + import paddle + place = paddle.device.XPUPlace(0) + """ + return core.XPUPlace(dev_id) + + def get_cudnn_version(): """ This funciton return the version of cudnn. 
the retuen value is int which represents the diff --git a/python/paddle/fluid/tests/unittests/op_test.py b/python/paddle/fluid/tests/unittests/op_test.py index d02fdafe995..96efc36ed0a 100644 --- a/python/paddle/fluid/tests/unittests/op_test.py +++ b/python/paddle/fluid/tests/unittests/op_test.py @@ -217,6 +217,9 @@ class OpTest(unittest.TestCase): return False return True + def is_xpu_op_test(): + return hasattr(cls, "use_xpu") and cls.use_xpu == True + def is_mkldnn_op_test(): return hasattr(cls, "use_mkldnn") and cls.use_mkldnn == True @@ -239,6 +242,7 @@ class OpTest(unittest.TestCase): if cls.dtype in [np.float32, np.float64] \ and cls.op_type not in op_accuracy_white_list.NO_FP64_CHECK_GRAD_OP_LIST \ and not hasattr(cls, 'exist_fp64_check_grad') \ + and not is_xpu_op_test() \ and not is_mkldnn_op_test(): raise AssertionError( "This test of %s op needs check_grad with fp64 precision." % @@ -336,6 +340,11 @@ class OpTest(unittest.TestCase): self.attrs["use_mkldnn"] == True): self.__class__.use_mkldnn = True + if (hasattr(self, "use_xpu") and self.use_xpu == True) or \ + (hasattr(self, "attrs") and "use_xpu" in self.attrs and \ + self.attrs["use_xpu"] == True): + self.__class__.use_xpu = True + op_proto = OpProtoHolder.instance().get_op_proto(self.op_type) "infer datatype from inputs and outputs for this test case" self.infer_dtype_from_inputs_outputs(self.inputs, self.outputs) @@ -932,6 +941,8 @@ class OpTest(unittest.TestCase): need_run_ops = self._get_need_run_ops(op_desc) res = {} + if hasattr(self, 'attrs') and bool(self.attrs.get('use_xpu', False)): + return for op_desc, father_op_desc in reversed(need_run_ops): # The first one is the forward op has_infer_inplace = fluid.core.has_infer_inplace(op_desc.type()) @@ -1203,6 +1214,11 @@ class OpTest(unittest.TestCase): self.attrs["use_mkldnn"] == True): self.__class__.use_mkldnn = True + if (hasattr(self, "use_xpu") and self.use_xpu == True) or \ + (hasattr(self, "attrs") and "use_xpu" in self.attrs and \ + self.attrs["use_xpu"] == True): + self.__class__.use_xpu = True + places = self._get_places() for place in places: res = self.check_output_with_place(place, atol, no_check_set, diff --git a/python/paddle/fluid/tests/unittests/test_matmul_op.py b/python/paddle/fluid/tests/unittests/test_matmul_op.py index 3eb822bfed8..2d5f098a7fe 100644 --- a/python/paddle/fluid/tests/unittests/test_matmul_op.py +++ b/python/paddle/fluid/tests/unittests/test_matmul_op.py @@ -14,6 +14,7 @@ from __future__ import print_function +import paddle.fluid.core as core import unittest import numpy as np from op_test import OpTest diff --git a/python/paddle/fluid/tests/unittests/test_mul_op.py b/python/paddle/fluid/tests/unittests/test_mul_op.py index 5f223de1954..927383c1223 100644 --- a/python/paddle/fluid/tests/unittests/test_mul_op.py +++ b/python/paddle/fluid/tests/unittests/test_mul_op.py @@ -18,6 +18,8 @@ import unittest import numpy as np import paddle import paddle.fluid.core as core +import sys +sys.path.append("..") from op_test import OpTest import paddle.fluid as fluid from paddle.fluid import Program, program_guard @@ -175,57 +177,5 @@ class TestFP16MulOp2(TestMulOp2): no_grad_set=set('Y')) -@unittest.skipIf(not core.is_compiled_with_xpu(), - "core is not compiled with XPU") -class TestXPUMulOp1(TestMulOp): - def init_dtype_type(self): - self.dtype = np.float32 - - def test_check_output(self): - place = core.XPUPlace(0) - self.check_output_with_place(place, atol=1e-1) - - def test_check_grad_normal(self): - place = core.XPUPlace(0) - 
self.check_grad_with_place( - place, ['X', 'Y'], 'Out', max_relative_error=0.5) - - def test_check_grad_ingore_x(self): - place = core.XPUPlace(0) - self.check_grad_with_place( - place, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X")) - - def test_check_grad_ingore_y(self): - place = core.XPUPlace(0) - self.check_grad_with_place( - place, ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y')) - - -@unittest.skipIf(not core.is_compiled_with_xpu(), - "core is not compiled with XPU") -class TestXPUMulOp2(TestMulOp2): - def init_dtype_type(self): - self.dtype = np.float32 - - def test_check_output(self): - place = core.XPUPlace(0) - self.check_output_with_place(place, atol=2e-1) - - def test_check_grad_normal(self): - place = core.XPUPlace(0) - self.check_grad_with_place( - place, ['X', 'Y'], 'Out', max_relative_error=0.9) - - def test_check_grad_ingore_x(self): - place = core.XPUPlace(0) - self.check_grad_with_place( - place, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X")) - - def test_check_grad_ingore_y(self): - place = core.XPUPlace(0) - self.check_grad_with_place( - place, ['X'], 'Out', max_relative_error=0.9, no_grad_set=set('Y')) - - if __name__ == "__main__": unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_activation_op.py b/python/paddle/fluid/tests/unittests/xpu/test_activation_op.py new file mode 100755 index 00000000000..788c110a592 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_activation_op.py @@ -0,0 +1,215 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function + +import sys +sys.path.append("..") +import unittest +import numpy as np +import paddle.fluid.core as core +from op_test import OpTest +from scipy.special import expit, erf +import paddle +import paddle.fluid as fluid +import paddle.nn as nn +import paddle.nn.functional as F +from paddle.fluid import compiler, Program, program_guard + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUActivation(OpTest): + def setUp(self): + self.op_type = "exp" + self.init_dtype() + self.init_kernel_type() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.exp(x) + + self.attrs = {'use_xpu': True} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def init_dtype(self): + self.dtype = np.float32 + + def test_check_output(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place, atol=1e-3) + + def init_kernel_type(self): + pass + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUSigmoid(TestXPUActivation): + def setUp(self): + self.op_type = "sigmoid" + self.init_dtype() + + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = 1 / (1 + np.exp(-x)) + + self.attrs = {'use_xpu': True} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + def test_check_grad(self): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place( + place, ['X'], 'Out', max_relative_error=0.01) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUTanh(TestXPUActivation): + def setUp(self): + self.op_type = "tanh" + self.init_dtype() + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.tanh(x) + + self.attrs = {'use_xpu': True} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUSqrt(TestXPUActivation): + def setUp(self): + self.op_type = "sqrt" + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.sqrt(x) + + self.attrs = {'use_xpu': True} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUAbs(TestXPUActivation): + def setUp(self): + self.op_type = "abs" + self.init_dtype() + + x = np.random.uniform(-1, 1, [4, 25]).astype(self.dtype) + # Because we set delta = 0.005 in calculating numeric gradient, + # if x is too small, such as 0.002, x_neg will be -0.003 + # x_pos will be 0.007, so the numeric gradient is inaccurate. 
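 + # (with the central difference (f(x + delta) - f(x - delta)) / (2 * delta) + # straddling the kink of abs at x = 0, the measured slope is wrong there) 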
+ # we should avoid this + x[np.abs(x) < 0.005] = 0.02 + out = np.abs(x) + + self.attrs = {'use_xpu': True} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPURelu(TestXPUActivation): + def setUp(self): + self.op_type = "relu" + self.init_dtype() + + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + # The same reason with TestAbs + x[np.abs(x) < 0.005] = 0.02 + out = np.maximum(x, 0) + + self.attrs = {'use_xpu': True} + self.inputs = {'X': x} + self.outputs = {'Out': out} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUGelu(TestXPUActivation): + def setUp(self): + self.op_type = "gelu" + self.init_dtype() + approximate = False + x = np.random.uniform(-1, 1, [11, 17]).astype(self.dtype) + out = gelu(x, approximate) + + self.inputs = {'X': x} + self.outputs = {'Out': out} + self.attrs = {"approximate": approximate, 'use_xpu': True} + + +def gelu(x, approximate): + if approximate: + y_ref = 0.5 * x * (1.0 + np.tanh( + np.sqrt(2 / np.pi) * (x + 0.044715 * np.power(x, 3)))) + else: + y_ref = 0.5 * x * (1 + erf(x / np.sqrt(2))) + return y_ref.astype(x.dtype) + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPULog(TestXPUActivation): + def setUp(self): + self.op_type = "log" + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.log(x) + + self.attrs = {'use_xpu': True} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUSquare(TestXPUActivation): + def setUp(self): + self.op_type = "square" + self.init_dtype() + + x = np.random.uniform(0.1, 1, [11, 17]).astype(self.dtype) + out = np.square(x) + + self.attrs = {'use_xpu': True} + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.outputs = {'Out': out} + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUPow(TestXPUActivation): + def setUp(self): + self.op_type = "pow" + self.init_dtype() + + x = np.random.uniform(1, 2, [11, 17]).astype(self.dtype) + out = np.power(x, 3) + + self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(x)} + self.attrs = {'factor': 3.0, 'use_xpu': True} + self.outputs = {'Out': out} + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_elementwise_add_op.py b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_add_op.py new file mode 100644 index 00000000000..9c6e7d21c1a --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_elementwise_add_op.py @@ -0,0 +1,346 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +from __future__ import print_function +import sys +sys.path.append("..") +import unittest +import numpy as np +import paddle +import paddle.fluid.core as core +from op_test import OpTest, skip_check_grad_ci +import paddle.fluid as fluid +from paddle.fluid import compiler, Program, program_guard + + +class TestElementwiseAddOp(OpTest): + def init_kernel_type(self): + self.use_mkldnn = False + + def setUp(self): + self.op_type = "elementwise_add" + self.init_dtype() + self.init_input_output() + self.init_kernel_type() + self.init_axis() + + self.inputs = { + 'X': OpTest.np_dtype_to_fluid_dtype(self.x), + 'Y': OpTest.np_dtype_to_fluid_dtype(self.y) + } + self.attrs = {'axis': self.axis, 'use_mkldnn': self.use_mkldnn} + self.outputs = {'Out': self.out} + + def test_check_output(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + self.check_output(check_dygraph=(self.use_mkldnn == False)) + + def test_check_grad_normal(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + if self.dtype == np.float16: + return + self.check_grad( + ['X', 'Y'], 'Out', check_dygraph=(self.use_mkldnn == False)) + + def test_check_grad_ingore_x(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + if self.dtype == np.float16: + return + self.check_grad( + ['Y'], + 'Out', + no_grad_set=set("X"), + check_dygraph=(self.use_mkldnn == False)) + + def test_check_grad_ingore_y(self): + # TODO(wangzhongpu): support mkldnn op in dygraph mode + if self.dtype == np.float16: + return + self.check_grad( + ['X'], + 'Out', + no_grad_set=set('Y'), + check_dygraph=(self.use_mkldnn == False)) + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.add(self.x, self.y) + + def init_dtype(self): + self.dtype = np.float64 + + def init_axis(self): + self.axis = -1 + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestXPUElementwiseAddOp(OpTest): + def setUp(self): + self.op_type = "elementwise_add" + self.init_dtype() + self.init_input_output() + self.init_axis() + + self.inputs = {'X': self.x, 'Y': self.y} + self.attrs = {'axis': self.axis, 'use_mkldnn': False, 'use_xpu': True} + self.outputs = {'Out': self.out} + + def test_check_output(self): + if self.dtype == np.float32 and paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_output_with_place(place) + + def test_check_grad_normal(self): + if self.dtype == np.float32 and paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X', 'Y'], 'Out') + + def test_check_grad_ingore_x(self): + if self.dtype == np.float32 and paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['Y'], 'Out') + + def test_check_grad_ingore_y(self): + if self.dtype == np.float32 and paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + self.check_grad_with_place(place, ['X'], 'Out') + + def init_input_output(self): + self.x = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.y = np.random.uniform(0.1, 1, [13, 17]).astype(self.dtype) + self.out = np.add(self.x, self.y) + + def init_dtype(self): + self.dtype = np.float32 + + def init_axis(self): + self.axis = -1 + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseAddOp_scalar(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 
4).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1,1) to test broadcast.") +class TestElementwiseAddOp_scalar2(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 4).astype(self.dtype) + self.y = np.random.rand(1, 1).astype(self.dtype) + self.out = self.x + self.y + + +class TestElementwiseAddOp_Vector(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.random((100, )).astype(self.dtype) + self.y = np.random.random((100, )).astype(self.dtype) + self.out = np.add(self.x, self.y) + + +class TestElementwiseAddOp_broadcast_0(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 3).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(100, 1, 1) + + def init_axis(self): + self.axis = 0 + + +class TestElementwiseAddOp_broadcast_1(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 100, 3).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 100, 1) + + def init_axis(self): + self.axis = 1 + + +class TestElementwiseAddOp_broadcast_2(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 100).astype(self.dtype) + self.y = np.random.rand(100).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1, 100) + + +class TestElementwiseAddOp_broadcast_3(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 10, 12, 3).astype(self.dtype) + self.y = np.random.rand(10, 12).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 10, 12, 1) + + def init_axis(self): + self.axis = 1 + + +class TestElementwiseAddOp_broadcast_4(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 3, 4).astype(self.dtype) + self.y = np.random.rand(100, 1).astype(self.dtype) + self.out = self.x + self.y.reshape(100, 1, 1, 1) + + def init_axis(self): + self.axis = 0 + + +class TestElementwiseAddOp_broadcast_5(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 3, 12).astype(self.dtype) + self.y = np.random.rand(10, 1, 12).astype(self.dtype) + self.out = self.x + self.y + + +class TestElementwiseAddOp_broadcast_6(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 12, 3, 5).astype(self.dtype) + self.y = np.random.rand(2, 12, 1, 5).astype(self.dtype) + self.out = self.x + self.y + + +class TestElementwiseAddOp_broadcast_7(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(1, 1, 20, 5).astype(self.dtype) + self.y = np.random.rand(20, 5, 1, 1).astype(self.dtype) + self.out = self.x + self.y + + +class TestElementwiseAddOp_rowwise_add_0(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 10, 12).astype(self.dtype) + self.y = np.random.rand(10, 12).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 10, 12) + + def init_axis(self): + self.axis = 1 + + +@skip_check_grad_ci( + reason="[skip shape check] Use y_shape(1) to test broadcast.") +class TestElementwiseAddOp_rowwise_add_1(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 1).astype(self.dtype) + self.y = np.random.rand(1).astype(self.dtype) + self.out = self.x + self.y.reshape(1, 1) + + def init_axis(self): + self.axis = 1 + + +class 
TestElementwiseAddOp_channelwise_add(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(100, 2, 3).astype(self.dtype) + self.y = np.random.rand(100, 1, 1).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = -1 + + +class TestElementwiseAddOp_commonuse_add1(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(2, 3, 100).astype(self.dtype) + self.y = np.random.rand(1, 1, 100).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = -1 + + +class TestElementwiseAddOp_commonuse_add2(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 3, 1, 4).astype(self.dtype) + self.y = np.random.rand(10, 1, 12, 1).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = -1 + + +class TestElementwiseAddOp_xsize_lessthan_ysize_add(TestElementwiseAddOp): + def init_input_output(self): + self.x = np.random.rand(10, 12).astype(self.dtype) + self.y = np.random.rand(2, 3, 10, 12).astype(self.dtype) + self.out = self.x + self.y + + def init_axis(self): + self.axis = 2 + + +class TestElementwiseAddOpError(unittest.TestCase): + def test_errors(self): + with program_guard(Program(), Program()): + # the input of elementwise_add must be Variable. + x1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + y1 = fluid.create_lod_tensor( + np.array([-1, 3, 5, 5]), [[1, 1, 1, 1]], fluid.CPUPlace()) + self.assertRaises(TypeError, fluid.layers.elementwise_add, x1, y1) + + # the input dtype of elementwise_add must be float16 or float32 or float64 or int32 or int64 + # float16 only can be set on GPU place + x2 = fluid.layers.data(name='x2', shape=[3, 4, 5, 6], dtype="uint8") + y2 = fluid.layers.data(name='y2', shape=[3, 4, 5, 6], dtype="uint8") + self.assertRaises(TypeError, fluid.layers.elementwise_add, x2, y2) + + +class TestAddOp(unittest.TestCase): + def test_name(self): + with fluid.program_guard(fluid.Program()): + x = fluid.data(name="x", shape=[2, 3], dtype="float32") + y = fluid.data(name='y', shape=[2, 3], dtype='float32') + + y_1 = paddle.add(x, y, name='add_res') + self.assertEqual(('add_res' in y_1.name), True) + + def test_declarative(self): + with fluid.program_guard(fluid.Program()): + + def gen_data(): + return { + "x": np.array([2, 3, 4]).astype('float32'), + "y": np.array([1, 5, 2]).astype('float32') + } + + x = fluid.data(name="x", shape=[3], dtype='float32') + y = fluid.data(name="y", shape=[3], dtype='float32') + z = paddle.add(x, y) + + place = fluid.CPUPlace() + exe = fluid.Executor(place) + z_value = exe.run(feed=gen_data(), fetch_list=[z.name]) + z_expected = np.array([3., 8., 6.]) + self.assertEqual((z_value == z_expected).all(), True) + + def test_dygraph(self): + with fluid.dygraph.guard(): + np_x = np.array([2, 3, 4]).astype('float64') + np_y = np.array([1, 5, 2]).astype('float64') + x = fluid.dygraph.to_variable(np_x) + y = fluid.dygraph.to_variable(np_y) + z = paddle.add(x, y) + np_z = z.numpy() + z_expected = np.array([3., 8., 6.]) + self.assertEqual((np_z == z_expected).all(), True) + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_op.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_op.py new file mode 100644 index 00000000000..ac32d224910 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_op.py @@ -0,0 +1,355 @@ +# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. 
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import sys
+sys.path.append("..")
+import paddle.fluid.core as core
+import unittest
+import numpy as np
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+
+
+def generate_compatible_shapes(dim_X, dim_Y, transpose_X, transpose_Y):
+    BATCH_SIZE = 2
+    M = 3
+    N = 4
+    K = 5
+    if (dim_X == 1 and transpose_X) or (dim_Y == 1 and transpose_Y):
+        K = 1
+    if dim_X == 1:
+        if transpose_X:
+            shape_X = [M]
+        else:
+            shape_X = [K]
+    if dim_Y == 1:
+        if transpose_Y:
+            shape_Y = [N]
+        else:
+            shape_Y = [K]
+    if dim_X >= 2:
+        if transpose_X:
+            shape_X = [K, M]
+        else:
+            shape_X = [M, K]
+    if dim_X == 3:
+        shape_X = [BATCH_SIZE] + shape_X
+    if dim_Y >= 2:
+        if transpose_Y:
+            shape_Y = [N, K]
+        else:
+            shape_Y = [K, N]
+    if dim_Y == 3:
+        shape_Y = [BATCH_SIZE] + shape_Y
+    return shape_X, shape_Y
+
+
+def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
+    """Reference forward implementation using np.matmul."""
+    # np.matmul does not support the transpose flags, so we manually
+    # transpose X and Y appropriately.
+    if transpose_X:
+        if X.ndim == 1:
+            X = X.reshape((X.size, 1))
+        elif X.ndim == 2:
+            X = X.T
+        else:
+            dim = [i for i in range(len(X.shape))]
+            dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
+            X = np.transpose(X, tuple(dim))
+    if transpose_Y:
+        if Y.ndim == 1:
+            Y = Y.reshape((1, Y.size))
+        else:
+            dim = [i for i in range(len(Y.shape))]
+            dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
+            Y = np.transpose(Y, tuple(dim))
+
+    Out = np.matmul(X, Y)
+    if not Out.shape:
+        # We do not support 0-dimensional Tensors (scalars). So where
+        # np.matmul outputs a scalar, we must convert to a Tensor of
+        # shape (1, ) instead.
+        # Everywhere else, we are compatible with np.matmul.
+        Out = np.array([Out], dtype="float32")
+    return Out
+
+
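+# A small self-check of the transpose handling above (illustrative only;
+# nothing in the test suite calls it): for inputs with more than two
+# dimensions only the last two axes are swapped, so the reference matches
+# np.matmul applied to manually permuted operands.
+def _reference_matmul_transpose_example():
+    X = np.random.random((2, 3, 5)).astype("float32")
+    Y = np.random.random((2, 4, 5)).astype("float32")
+    out = reference_matmul(X, Y, transpose_X=False, transpose_Y=True)
+    assert np.allclose(out, np.matmul(X, np.transpose(Y, (0, 2, 1))))
+
+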
+class Generator(object):
+    def setUp(self):
+        self.op_type = "matmul"
+        X = np.random.random(self.shape_X).astype("float32")
+        Y = np.random.random(self.shape_Y).astype("float32")
+        Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y)
+        self.inputs = {'X': X, 'Y': Y}
+        self.attrs = {
+            'transpose_X': self.transpose_X,
+            'transpose_Y': self.transpose_Y
+        }
+        self.outputs = {'Out': Out}
+
+    def test_check_output(self):
+        self.check_output()
+        if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
+                self.inputs['Y'].shape) and self.inputs['X'].shape[
+                    0] == self.inputs['Y'].shape[0]:
+            place = paddle.XPUPlace(0)
+            self.check_output_with_place(place, atol=1e-3)
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out', max_relative_error=1e-3)
+        if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
+                self.inputs['Y'].shape) and self.inputs['X'].shape[
+                    0] == self.inputs['Y'].shape[0]:
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(
+                place, ['X', 'Y'], 'Out', max_relative_error=5e-2)
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Y'], 'Out', max_relative_error=1e-3, no_grad_set=set("X"))
+        if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
+                self.inputs['Y'].shape) and self.inputs['X'].shape[
+                    0] == self.inputs['Y'].shape[0]:
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(
+                place, ['Y'],
+                'Out',
+                max_relative_error=5e-2,
+                no_grad_set=set("X"))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(
+            ['X'], 'Out', max_relative_error=1e-3, no_grad_set=set('Y'))
+        if paddle.is_compiled_with_xpu() and len(self.inputs['X'].shape) == len(
+                self.inputs['Y'].shape) and self.inputs['X'].shape[
+                    0] == self.inputs['Y'].shape[0]:
+            place = paddle.XPUPlace(0)
+            self.check_grad_with_place(
+                place, ['X'],
+                'Out',
+                max_relative_error=5e-2,
+                no_grad_set=set('Y'))
+
+
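+# The XPU checks in Generator are deliberately gated: they run only when X
+# and Y have the same rank and matching leading dimensions, which appears
+# to be the subset of shapes the new Kunlun matmul kernel supports; every
+# other generated shape still goes through the CPU checks unconditionally.
+
+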
+class TestMatmulOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # The input types of matmul_op must be Variable.
+            input1 = 12
+            self.assertRaises(TypeError, fluid.layers.matmul, input1, input1)
+            # The input dtype of matmul_op must be float32 or float64.
+            input2 = fluid.layers.data(
+                name='input2', shape=[10, 10], dtype="int32")
+            self.assertRaises(TypeError, fluid.layers.matmul, input2, input2)
+            # float16 is accepted and must not raise.
+            input3 = fluid.layers.data(
+                name='input3', shape=[2, 2], dtype="float16")
+            fluid.layers.matmul(input3, input3)
+
+
+# Negative dimension generation
+def generate_negative_dims(in_shape):
+    from itertools import combinations
+    size = len(in_shape)
+    indexs = list()
+    shapes = list()
+    for i in range(size):
+        indexs.extend(list(combinations([j for j in range(size)], i + 1)))
+    for idx in indexs:
+        shapes.append(
+            [in_shape[i] if i not in idx else -1 for i in range(size)])
+    return shapes
+
+
+# Build programs whose input shapes contain negative (unknown) dimensions
+def test_negative_dims_program(obj):
+    for shape_x in generate_negative_dims(obj.shape_X):
+        for shape_y in generate_negative_dims(obj.shape_Y):
+            X = np.random.random(obj.shape_X).astype("float32")
+            Y = np.random.random(obj.shape_Y).astype("float32")
+            Ref = reference_matmul(X, Y, obj.transpose_X, obj.transpose_Y)
+            with program_guard(Program(), Program()):
+                x = fluid.data(name='x', shape=shape_x, dtype='float32')
+                y = fluid.data(name='y', shape=shape_y, dtype='float32')
+                output = fluid.layers.matmul(x, y, obj.transpose_X,
+                                             obj.transpose_Y)
+                obj.assertEqual(len(Ref.shape), len(output.shape))
+                for idx in range(len(Ref.shape)):
+                    if output.shape[idx] != -1:
+                        obj.assertEqual(Ref.shape[idx], output.shape[idx])
+                exe = fluid.Executor(fluid.CPUPlace())
+                res, = exe.run(fluid.default_main_program(),
+                               feed={'x': X,
+                                     'y': Y},
+                               fetch_list=[output])
+                obj.assertTrue(np.allclose(res, Ref, atol=1e-5))
+
+
+# Generate program API cases for all negative-dimension possibilities
+def api_test(dim_x, dim_y, trans_x, trans_y):
+    test_name = ('TestMatMulAPI_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
+        dim_x, dim_y, trans_x, trans_y))
+    shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x,
+                                                  trans_y)
+    globals()[test_name] = type(test_name, (unittest.TestCase, ), {
+        'shape_X': shape_x,
+        'shape_Y': shape_y,
+        'transpose_X': trans_x,
+        'transpose_Y': trans_y,
+        'test_program': test_negative_dims_program,
+    })
+
+
+# Generate operator cases for all possibilities
+def inject_test(dim_x, dim_y, trans_x, trans_y):
+    test_name = ('TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
+        dim_x, dim_y, trans_x, trans_y))
+    shape_x, shape_y = generate_compatible_shapes(dim_x, dim_y, trans_x,
+                                                  trans_y)
+    globals()[test_name] = type(test_name, (Generator, OpTest), {
+        'shape_X': shape_x,
+        'shape_Y': shape_y,
+        'transpose_X': trans_x,
+        'transpose_Y': trans_y,
+    })
+
+
+for dim_X in (1, 2, 3):
+    for dim_Y in (1, 2, 3):
+        for transpose_x in (False, True):
+            for transpose_y in (False, True):
+                inject_test(dim_X, dim_Y, transpose_x, transpose_y)
+                api_test(dim_X, dim_Y, transpose_x, transpose_y)
+
+
+# Test case n-dim
+def generate_compatible_shapes(dim, transpose_X, transpose_Y):
+    M = 2
+    N = 4
+    K = 3
+    shape_X = [2 for _ in range(dim - 2)]
+    shape_Y = [2 for _ in range(dim - 2)]
+
+    if transpose_X:
+        shape_X += [K, M]
+    else:
+        shape_X += [M, K]
+
+    if transpose_Y:
+        shape_Y += [N, K]
+    else:
+        shape_Y += [K, N]
+
+    return shape_X, shape_Y
+
+
+# Test case n-dim
+for dim in [4]:
+    for transpose_X in [False, True]:
+        for transpose_Y in [False, True]:
+            test_name = (
+                'TestMatMulOp_dimX_{}_dim_Y_{}_transX_{}_transY_{}'.format(
+                    dim, dim, transpose_X, transpose_Y))
+            shape_X, shape_Y = generate_compatible_shapes(dim, transpose_X,
+                                                          transpose_Y)
+            globals()[test_name] = type(test_name, (Generator, OpTest), {
+                'shape_X': shape_X,
+                'shape_Y': shape_Y,
+                'transpose_X': transpose_X,
+                'transpose_Y': transpose_Y,
+            })
+
+
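+# Each dynamically generated entry above is equivalent to a hand-written
+# class; for example, the 2-D no-transpose case produced by inject_test
+# behaves as if it were declared as:
+#
+#     class TestMatMulOp_dimX_2_dim_Y_2_transX_False_transY_False(Generator,
+#                                                                 OpTest):
+#         shape_X = [3, 5]
+#         shape_Y = [5, 4]
+#         transpose_X = False
+#         transpose_Y = False
+
+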
+class API_TestMm(unittest.TestCase):
+    def test_out(self):
+        with fluid.program_guard(fluid.Program()):
+            x = fluid.data(name="x", shape=[2], dtype="float64")
+            y = fluid.data(name='y', shape=[2], dtype='float64')
+            res = fluid.data(name="output", shape=[1], dtype="float64")
+            result = paddle.mm(x, y)
+            exe = fluid.Executor(fluid.CPUPlace())
+            data1 = np.random.rand(2)
+            data2 = np.random.rand(2)
+            np_res = exe.run(feed={'x': data1, 'y': data2}, fetch_list=[result])
+            expected_result = np.matmul(
+                data1.reshape(1, 2), data2.reshape(2, 1))
+
+            self.assertTrue(
+                np.allclose(
+                    np_res, expected_result, atol=1e-5),
+                "the two values\n{}\n{} differ, check the diff!".format(
+                    np_res, expected_result))
+
+    def test_dygraph_without_out(self):
+        device = fluid.CPUPlace()
+        with fluid.dygraph.guard(device):
+            input_array1 = np.random.rand(3, 4).astype("float64")
+            input_array2 = np.random.rand(4, 3).astype("float64")
+            data1 = fluid.dygraph.to_variable(input_array1)
+            data2 = fluid.dygraph.to_variable(input_array2)
+            out = paddle.mm(data1, data2)
+            expected_result = np.matmul(input_array1, input_array2)
+            self.assertTrue(np.allclose(expected_result, out.numpy()))
+
+
+class Test_API_Matmul(unittest.TestCase):
+    def test_dygraph_without_out(self):
+        device = fluid.CPUPlace()
+        with fluid.dygraph.guard(device):
+            input_array1 = np.random.rand(3, 4).astype("float64")
+            input_array2 = np.random.rand(4, 3).astype("float64")
+            data1 = fluid.dygraph.to_variable(input_array1)
+            data2 = fluid.dygraph.to_variable(input_array2)
+            out = paddle.matmul(data1, data2)
+            expected_result = np.matmul(input_array1, input_array2)
+            self.assertTrue(np.allclose(expected_result, out.numpy()))
+
+
+class API_TestMmError(unittest.TestCase):
+    def test_errors(self):
+        def test_error1():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data1 = fluid.data(name="data1", shape=[10, 2], dtype="float32")
+                data2 = fluid.data(name="data2", shape=[3, 10], dtype="float32")
+                paddle.mm(data1, data2)
+
+        self.assertRaises(ValueError, test_error1)
+
+        def test_error2():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data1 = fluid.data(
+                    name="data1", shape=[-1, 10, 2], dtype="float32")
+                data2 = fluid.data(
+                    name="data2", shape=[-1, 2, 10], dtype="float32")
+                paddle.mm(data1, data2)
+
+        test_error2()
+
+        def test_error3():
+            with fluid.program_guard(fluid.Program(), fluid.Program()):
+                data1 = fluid.data(
+                    name="data1", shape=[10, 10, 2], dtype="float32")
+                data2 = fluid.data(
+                    name="data2", shape=[3, 2, 10], dtype="float32")
+                paddle.mm(data1, data2)
+
+        self.assertRaises(ValueError, test_error3)
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_mul_op.py b/python/paddle/fluid/tests/unittests/xpu/test_mul_op.py
new file mode 100644
index 00000000000..94ab5b71e4f
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_mul_op.py
@@ -0,0 +1,161 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import paddle
+import paddle.fluid.core as core
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle.fluid as fluid
+from paddle.fluid import Program, program_guard
+
+
+class TestMulOp(OpTest):
+    def setUp(self):
+        self.op_type = "mul"
+        self.dtype = np.float64
+        self.init_dtype_type()
+        self.inputs = {
+            'X': np.random.random((20, 5)).astype(self.dtype),
+            'Y': np.random.random((5, 21)).astype(self.dtype)
+        }
+        self.outputs = {'Out': np.dot(self.inputs['X'], self.inputs['Y'])}
+
+    def init_dtype_type(self):
+        pass
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(
+            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
+
+
+class TestMulOpError(unittest.TestCase):
+    def test_errors(self):
+        with program_guard(Program(), Program()):
+            # The input type of mul_op must be Variable.
+            x1 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            x2 = fluid.create_lod_tensor(
+                np.array([[-1]]), [[1]], fluid.CPUPlace())
+            self.assertRaises(TypeError, fluid.layers.mul, x1, x2)
+            # The input dtype of mul_op must be float32 or float64.
+            x3 = fluid.layers.data(name='x3', shape=[4], dtype="int32")
+            x4 = fluid.layers.data(name='x4', shape=[4], dtype="int32")
+            self.assertRaises(TypeError, fluid.layers.mul, x3, x4)
+
+
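+# mul's x_num_col_dims/y_num_col_dims attributes flatten the inputs to 2-D
+# before the matrix product: in TestMulOp2 below, X of shape (3, 4, 2, 9)
+# with x_num_col_dims=2 is treated as a (3*4, 2*9) matrix and Y of shape
+# (3, 6, 1, 2, 3) with y_num_col_dims=2 as a (3*6, 1*2*3) matrix; the
+# product is then reshaped back to (3, 4, 1, 2, 3), which is exactly how
+# the reference output is built.
+
+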
+class TestMulOp2(OpTest):
+    def setUp(self):
+        self.op_type = "mul"
+        self.dtype = np.float64
+        self.init_dtype_type()
+        self.inputs = {
+            'X': np.random.random((3, 4, 2, 9)).astype(self.dtype),
+            'Y': np.random.random((3, 6, 1, 2, 3)).astype(self.dtype)
+        }
+        self.attrs = {
+            'x_num_col_dims': 2,
+            'y_num_col_dims': 2,
+        }
+        result = np.dot(self.inputs['X'].reshape(3 * 4, 2 * 9),
+                        self.inputs['Y'].reshape(3 * 6, 1 * 2 * 3))
+        result = result.reshape(3, 4, 1, 2, 3)
+        self.outputs = {'Out': result}
+
+    def init_dtype_type(self):
+        pass
+
+    def test_check_output(self):
+        self.check_output()
+
+    def test_check_grad_normal(self):
+        self.check_grad(['X', 'Y'], 'Out')
+
+    def test_check_grad_ignore_x(self):
+        self.check_grad(
+            ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set('X'))
+
+    def test_check_grad_ignore_y(self):
+        self.check_grad(
+            ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
+
+
+@unittest.skipIf(not paddle.is_compiled_with_xpu(),
+                 "core is not compiled with XPU")
+class TestXPUMulOp1(TestMulOp):
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        place = paddle.XPUPlace(0)
+        self.check_output_with_place(place, atol=1e-1)
+
+    def test_check_grad_normal(self):
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X', 'Y'], 'Out', max_relative_error=0.5)
+
+    def test_check_grad_ignore_x(self):
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
+
+    def test_check_grad_ignore_y(self):
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', max_relative_error=0.5, no_grad_set=set('Y'))
+
+
+@unittest.skipIf(not paddle.is_compiled_with_xpu(),
+                 "core is not compiled with XPU")
+class TestXPUMulOp2(TestMulOp2):
+    def init_dtype_type(self):
+        self.dtype = np.float32
+
+    def test_check_output(self):
+        place = paddle.XPUPlace(0)
+        self.check_output_with_place(place, atol=2e-1)
+
+    def test_check_grad_normal(self):
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X', 'Y'], 'Out', max_relative_error=0.9)
+
+    def test_check_grad_ignore_x(self):
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['Y'], 'Out', max_relative_error=0.5, no_grad_set=set("X"))
+
+    def test_check_grad_ignore_y(self):
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', max_relative_error=0.9, no_grad_set=set('Y'))
+
+
+if __name__ == "__main__":
+    unittest.main()
diff --git a/tools/wlist.json b/tools/wlist.json
index 9b36ac6adc7..3ca14cd1dd6 100644
--- a/tools/wlist.json
+++ b/tools/wlist.json
@@ -407,7 +407,9 @@
         "TransformerDecoder.prepare_incremental_cache",
         "LinearChainCRF.forward",
         "CRFDecoding.forward",
-        "SequenceTagging.forward"
+        "SequenceTagging.forward",
+        "XPUPlace",
+        "is_compiled_with_xpu"
     ],
     "gpu_not_white":[
         "deformable_conv",
-- 
GitLab