diff --git a/paddle/fluid/operators/matmul_v2_op_xpu.cc b/paddle/fluid/operators/matmul_v2_op_xpu.cc
new file mode 100644
index 0000000000000000000000000000000000000000..d6f3cc226e65543e8ddc923467d7dc26a5ca4432
--- /dev/null
+++ b/paddle/fluid/operators/matmul_v2_op_xpu.cc
@@ -0,0 +1,367 @@
+// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#ifdef PADDLE_WITH_XPU
+
+#include "paddle/fluid/operators/matmul_v2_op.h"
+#include <string>
+#include <vector>
+
+namespace paddle {
+namespace operators {
+
+template <typename T>
+void MatMulXPUFunction(const Tensor* X, const Tensor* Y,
+                       const std::vector<std::int64_t>& x_dims,
+                       const std::vector<std::int64_t>& y_dims, Tensor* Out,
+                       bool trans_x, bool trans_y,
+                       const paddle::framework::ExecutionContext& ctx) {
+  const int x_ndim = x_dims.size();
+  const int y_ndim = y_dims.size();
+
+  auto& dev_ctx =
+      ctx.template device_context<paddle::platform::XPUDeviceContext>();
+
+  // currently only x_ndim == y_ndim and the non-broadcast case are supported
+  PADDLE_ENFORCE_EQ(x_ndim, y_ndim,
+                    platform::errors::InvalidArgument(
+                        "Shape mismatch in matmul_v2_op."));
+  for (int i = 0; i < x_ndim - 2; i++) {
+    PADDLE_ENFORCE_EQ(
+        x_dims.data()[i], y_dims.data()[i],
+        platform::errors::InvalidArgument("Shape mismatch in matmul_v2_op."));
+  }
+
+  int ret = 0;
+  if (x_ndim == 1 && y_ndim == 1) {
+    PADDLE_ENFORCE_EQ(X->numel(), Y->numel(),
+                      platform::errors::InvalidArgument(
+                          "X's numel is not equal to Y's numel "
+                          "when X/Y's ndim = 1."));
+    VLOG(3) << "MatMul's case 1";
+    Out->Resize({1});
+    Out->mutable_data<T>(ctx.GetPlace());
+    ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, false, 1, 1,
+                                    X->numel(), 1.0f, X->data<T>(),
+                                    Y->data<T>(), 0.0f, Out->data<T>());
+    PADDLE_ENFORCE_EQ(
+        ret, XPU_SUCCESS,
+        platform::errors::External(
+            "XPU API returned wrong value[%d] in matmul_v2, please check "
+            "whether the Baidu Kunlun card is properly installed.",
+            ret));
+    return;
+  }
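Editorial note (not part of the patch): case 1 treats two equal-length 1-D tensors as a (1, k) × (k, 1) GEMM, i.e. a dot product returned as a shape-(1,) tensor. A minimal NumPy sketch of that contract:

```python
import numpy as np

x = np.random.rand(100).astype(np.float32)
y = np.random.rand(100).astype(np.float32)

# fc(trans_a=false, trans_b=false, m=1, n=1, k=numel):
# (1, k) @ (k, 1) -> (1, 1), i.e. a plain dot product.
out = x.reshape(1, -1) @ y.reshape(-1, 1)
assert out.shape == (1, 1)
assert np.allclose(out.ravel(), np.dot(x, y))
```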
+
+  if (x_ndim == 1) {
+    const int N = X->numel();
+    if (trans_y) {
+      PADDLE_ENFORCE_EQ(
+          y_dims[y_ndim - 1], N,
+          platform::errors::InvalidArgument("Input(Y) has a wrong dim."));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          y_dims[y_ndim - 2], N,
+          platform::errors::InvalidArgument("Input(Y) has a wrong dim."));
+    }
+    std::vector<std::int64_t> out_dims(y_ndim - 1);
+    if (trans_y) {
+      std::copy_n(y_dims.cbegin(), y_ndim - 1, out_dims.begin());
+    } else {
+      std::copy_n(y_dims.cbegin(), y_ndim - 2, out_dims.begin());
+      out_dims.back() = y_dims.back();
+    }
+    Out->Resize(framework::make_ddim(out_dims));
+    Out->mutable_data<T>(ctx.GetPlace());
+    if (trans_y) {
+      const int M = Y->numel() / N;
+      VLOG(3) << "MatMul's case 2";
+      ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, true, 1, M,
+                                      N, 1.0f, X->data<T>(), Y->data<T>(),
+                                      0.0f, Out->data<T>());
+      PADDLE_ENFORCE_EQ(
+          ret, XPU_SUCCESS,
+          platform::errors::External(
+              "XPU API returned wrong value[%d] in matmul_v2, please check "
+              "whether the Baidu Kunlun card is properly installed.",
+              ret));
+    } else {
+      const int M = y_dims[y_ndim - 1];
+      const int batch_size = Y->numel() / (M * N);
+      for (int i = 0; i < batch_size; i++) {
+        ret = baidu::xpu::api::fc_int16(
+            dev_ctx.x_context(), false, false, 1, M, N, 1.0f, X->data<T>(),
+            Y->data<T>() + i * M * N, 0.0f, Out->data<T>() + i * M);
+        PADDLE_ENFORCE_EQ(
+            ret, XPU_SUCCESS,
+            platform::errors::External(
+                "XPU API returned wrong value[%d] in matmul_v2, please "
+                "check whether the Baidu Kunlun card is properly installed.",
+                ret));
+      }
+    }
+    return;
+  }
+
+  if (y_ndim == 1) {
+    const int N = Y->numel();
+    if (trans_x) {
+      PADDLE_ENFORCE_EQ(
+          x_dims[x_ndim - 2], N,
+          platform::errors::InvalidArgument("Input(X) has a wrong dim."));
+    } else {
+      PADDLE_ENFORCE_EQ(
+          x_dims[x_ndim - 1], N,
+          platform::errors::InvalidArgument("Input(X) has a wrong dim."));
+    }
+    std::vector<std::int64_t> out_dims(x_ndim - 1);
+    if (trans_x) {
+      std::copy_n(x_dims.cbegin(), x_ndim - 2, out_dims.begin());
+      out_dims.back() = x_dims.back();
+    } else {
+      std::copy_n(x_dims.cbegin(), x_ndim - 1, out_dims.begin());
+    }
+    Out->Resize(framework::make_ddim(out_dims));
+    Out->mutable_data<T>(ctx.GetPlace());
+
+    if (trans_x) {
+      const int M = x_dims[x_ndim - 1];
+      const int batch_size = X->numel() / (M * N);
+      for (int i = 0; i < batch_size; i++) {
+        ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), true, false, M,
+                                        1, N, 1.0f, X->data<T>() + i * M * N,
+                                        Y->data<T>(), 0.0f,
+                                        Out->data<T>() + i * M);
+        PADDLE_ENFORCE_EQ(
+            ret, XPU_SUCCESS,
+            platform::errors::External(
+                "XPU API returned wrong value[%d] in matmul_v2, please "
+                "check whether the Baidu Kunlun card is properly installed.",
+                ret));
+      }
+    } else {
+      const int M = X->numel() / N;
+      VLOG(3) << "MatMul's case 7";
+      ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, false, M,
+                                      1, N, 1.0f, X->data<T>(), Y->data<T>(),
+                                      0.0f, Out->data<T>());
+      PADDLE_ENFORCE_EQ(
+          ret, XPU_SUCCESS,
+          platform::errors::External(
+              "XPU API returned wrong value[%d] in matmul_v2, please check "
+              "whether the Baidu Kunlun card is properly installed.",
+              ret));
+    }
+    return;
+  }
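Editorial note (not part of the patch): the two 1-D branches above reduce to vector–matrix and matrix–vector products. A NumPy sketch of the `x_ndim == 1`, `trans_y == false` path, which issues one (1, N) × (N, M) GEMM per batch slice and agrees with what `np.matmul` does for a 1-D left operand:

```python
import numpy as np

x = np.random.rand(100).astype(np.float32)        # 1-D LHS, N = 100
y = np.random.rand(2, 100, 4).astype(np.float32)  # RHS with one batch dim

# One (1, N) @ (N, M) GEMM per batch slice, M = y_dims[-1].
out = np.stack([x @ y[i] for i in range(y.shape[0])])
assert out.shape == (2, 4)
assert np.allclose(out, np.matmul(x, y), atol=1e-5)
```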
+
+  const int M = trans_x ? x_dims[x_ndim - 1] : x_dims[x_ndim - 2];
+  const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
+  if (trans_y) {
+    PADDLE_ENFORCE_EQ(
+        y_dims[y_ndim - 1], K,
+        platform::errors::InvalidArgument("Input(Y) has a wrong dim."));
+  } else {
+    PADDLE_ENFORCE_EQ(
+        y_dims[y_ndim - 2], K,
+        platform::errors::InvalidArgument("Input(Y) has a wrong dim."));
+  }
+  const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
+  const int ndim = (std::max)(x_ndim, y_ndim);
+  std::vector<std::int64_t> out_broadcast_dims(ndim);
+  int batch_size = 1;
+  for (int i = 0; i < ndim - 2; i++) {
+    PADDLE_ENFORCE_EQ(
+        x_dims.data()[i], y_dims.data()[i],
+        platform::errors::InvalidArgument("Shape mismatch in matmul_v2_op."));
+    out_broadcast_dims[i] = x_dims.data()[i];
+    batch_size *= x_dims.data()[i];
+  }
+
+  out_broadcast_dims[ndim - 2] = M;
+  out_broadcast_dims[ndim - 1] = N;
+
+  Out->Resize(framework::make_ddim(out_broadcast_dims));
+  Out->mutable_data<T>(ctx.GetPlace());
+  ret = baidu::xpu::api::batched_gemm_int16(
+      dev_ctx.x_context(), trans_x, trans_y, batch_size, M, N, K, 1.0f,
+      X->data<T>(), Y->data<T>(), Out->data<T>(), nullptr, nullptr);
+  PADDLE_ENFORCE_EQ(
+      ret, XPU_SUCCESS,
+      platform::errors::External(
+          "XPU API returned wrong value[%d] in matmul_v2, please check "
+          "whether the Baidu Kunlun card is properly installed.",
+          ret));
+}
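Editorial note (not part of the patch): the general path hands the whole batch to `batched_gemm_int16` with the transpose flags, after enforcing that the leading (batch) dims of X and Y match exactly — broadcasting is not supported. A NumPy sketch of that contract; `batched_matmul_ref` is an illustrative name:

```python
import numpy as np

def batched_matmul_ref(x, y, trans_x=False, trans_y=False):
    # Emulate the GEMM transpose flags by swapping the last two axes.
    if trans_x:
        x = np.swapaxes(x, -1, -2)
    if trans_y:
        y = np.swapaxes(y, -1, -2)
    return np.matmul(x, y)  # batch dims must already match exactly

x = np.random.rand(2, 2, 50, 3).astype(np.float32)
y = np.random.rand(2, 2, 50, 4).astype(np.float32)
out = batched_matmul_ref(x, y, trans_x=True)  # M=3, K=50, N=4
assert out.shape == (2, 2, 3, 4)
```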
+
+template <typename T>
+class MatMulV2XPUKernel : public framework::OpKernel<T> {
+ public:
+  void Compute(const paddle::framework::ExecutionContext& ctx) const override {
+    auto* X = ctx.Input<Tensor>("X");
+    auto* Y = ctx.Input<Tensor>("Y");
+    auto* Out = ctx.Output<Tensor>("Out");
+    bool trans_x = ctx.Attr<bool>("trans_x");
+    bool trans_y = ctx.Attr<bool>("trans_y");
+    MatMulXPUFunction<T>(X, Y, vectorize(X->dims()), vectorize(Y->dims()),
+                         Out, trans_x, trans_y, ctx);
+  }
+};
+
+template <typename T>
+class MatMulV2XPUGradKernel : public framework::OpKernel<T> {
+ public:
+  void MatMul(const framework::ExecutionContext& context,
+              const framework::Tensor& a, bool trans_a,
+              const framework::Tensor& b, bool trans_b,
+              framework::Tensor* out) const {
+    out->mutable_data<T>(context.GetPlace());
+    MatMulXPUFunction<T>(&a, &b, vectorize(a.dims()), vectorize(b.dims()),
+                         out, trans_a, trans_b, context);
+  }
+
+  void CalcInputGrad(const framework::ExecutionContext& context,
+                     const framework::Tensor& a, bool trans_a,
+                     bool is_fold_init_dims_a, const framework::Tensor& b,
+                     bool trans_b, bool is_fold_init_dims_b,
+                     framework::Tensor* out) const {
+    if (out == nullptr) return;
+    bool need_combine = (a.dims().size() == 3 || b.dims().size() == 3) &&
+                        out->dims().size() == 2;
+    if (!need_combine) {
+      MatMul(context, a, trans_a, b, trans_b, out);
+    } else {
+      // this case is currently not supported
+    }
+  }
+
+  void Compute(const framework::ExecutionContext& ctx) const override {
+    bool transpose_x = ctx.Attr<bool>("trans_x");
+    bool transpose_y = ctx.Attr<bool>("trans_y");
+
+    auto x = *ctx.Input<framework::Tensor>("X");
+    auto y = *ctx.Input<framework::Tensor>("Y");
+    auto dout = *ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+
+    // get dims
+    std::vector<std::int64_t> x_dims = vectorize(x.dims());
+    std::vector<std::int64_t> y_dims = vectorize(y.dims());
+    std::vector<std::int64_t> dout_dims = vectorize(dout.dims());
+
+    int x_ndim = x_dims.size();
+    int y_ndim = y_dims.size();
+
+    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
+
+    auto& dev_ctx =
+        ctx.template device_context<paddle::platform::XPUDeviceContext>();
+    // Case 1: x's or y's ndim == 1
+    int ret = 0;
+    if (x_ndim == 1 && y_ndim == 1) {
+      if (dx) {
+        dx->mutable_data<T>(ctx.GetPlace());
+        ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, false,
+                                        dx->numel(), 1, 1, 1.0f, y.data<T>(),
+                                        dout.data<T>(), 0.0f, dx->data<T>());
+        PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
+                          platform::errors::External(
+                              "XPU API returned wrong value[%d] in "
+                              "matmul_v2_grad, please check whether the "
+                              "Baidu Kunlun card is properly installed.",
+                              ret));
+      }
+      if (dy) {
+        dy->mutable_data<T>(ctx.GetPlace());
+        ret = baidu::xpu::api::fc_int16(dev_ctx.x_context(), false, false,
+                                        dy->numel(), 1, 1, 1.0f, x.data<T>(),
+                                        dout.data<T>(), 0.0f, dy->data<T>());
+        PADDLE_ENFORCE_EQ(ret, XPU_SUCCESS,
+                          platform::errors::External(
+                              "XPU API returned wrong value[%d] in "
+                              "matmul_v2_grad, please check whether the "
+                              "Baidu Kunlun card is properly installed.",
+                              ret));
+      }
+      return;
+    }
+
+    bool is_broadcast = true;
+    if (x_ndim <= 2 || y_ndim <= 2) {
+      is_broadcast = false;
+    } else if (x_ndim != y_ndim) {
+      is_broadcast = true;
+    } else {
+      is_broadcast = !std::equal(x_dims.cbegin(), x_dims.cbegin() + x_ndim - 2,
+                                 y_dims.cbegin());
+    }
+
+    // currently only the non-broadcast case is supported
+    PADDLE_ENFORCE_EQ(
+        is_broadcast, false,
+        platform::errors::InvalidArgument("Shape mismatch in matmul_v2_op."));
+
+    // Case 2: no broadcast and no batch dimension; this fast path is the
+    // same as the old matmul op.
+    if (!is_broadcast) {
+      ReshapeXYOutIntoMatrixSequence(&x, &y, &dout, transpose_x, transpose_y);
+      framework::DDim dx_dims;
+      if (dx) {
+        dx_dims = dx->dims();
+        if (dx_dims != x.dims()) {
+          dx->Resize(x.dims());
+        }
+      }
+
+      framework::DDim dy_dims;
+      if (dy) {
+        dy_dims = dy->dims();
+        if (dy_dims != y.dims()) {
+          dy->Resize(y.dims());
+        }
+      }
+      if (transpose_x && transpose_y) {
+        CalcInputGrad(ctx, y, true, true, dout, true, false, dx);
+        CalcInputGrad(ctx, dout, true, true, x, true, false, dy);
+      } else if (transpose_x) {
+        CalcInputGrad(ctx, y, false, false, dout, true, false, dx);
+        CalcInputGrad(ctx, x, false, false, dout, false, true, dy);
+      } else if (transpose_y) {
+        CalcInputGrad(ctx, dout, false, false, y, false, true, dx);
+        CalcInputGrad(ctx, dout, true, true, x, false, true, dy);
+      } else {
+        CalcInputGrad(ctx, dout, false, false, y, true, false, dx);
+        CalcInputGrad(ctx, x, true, true, dout, false, true, dy);
+      }
+
+      if (dx) {
+        if (dx_dims != x.dims()) {
+          dx->Resize(dx_dims);
+        }
+      }
+      if (dy) {
+        if (dy_dims != y.dims()) {
+          dy->Resize(dy_dims);
+        }
+      }
+    }
+  }
+};
+
+}  // namespace operators
+}  // namespace paddle
+
+namespace ops = paddle::operators;
+
+REGISTER_OP_XPU_KERNEL(matmul_v2, ops::MatMulV2XPUKernel<float>);
+REGISTER_OP_XPU_KERNEL(matmul_v2_grad, ops::MatMulV2XPUGradKernel<float>);
+
+#endif
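Editorial note (not part of the patch): the `CalcInputGrad` dispatch in the grad kernel encodes the standard matmul gradient identities; for example, with both transpose flags false, dX = dOut · Yᵀ and dY = Xᵀ · dOut. A NumPy sketch of that branch:

```python
import numpy as np

x = np.random.rand(4, 3).astype(np.float32)
y = np.random.rand(3, 5).astype(np.float32)
dout = np.ones((4, 5), dtype=np.float32)  # upstream gradient of Out = x @ y

dx = dout @ y.T  # dOut (4,5) @ Y^T (5,3) -> (4,3), same shape as x
dy = x.T @ dout  # X^T (3,4) @ dOut (4,5) -> (3,5), same shape as y
assert dx.shape == x.shape and dy.shape == y.shape
```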
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
new file mode 100644
index 0000000000000000000000000000000000000000..1cc9950f9a15bcacee87cae548670647d6afde84
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/xpu/test_matmul_v2_op_xpu.py
@@ -0,0 +1,277 @@
+# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import unittest
+import numpy as np
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle.fluid.core as core
+
+import paddle
+import paddle.fluid as fluid
+import paddle.fluid.framework as framework
+
+
+def reference_matmul(X, Y, transpose_X=False, transpose_Y=False):
+    """Reference forward implementation using np.matmul."""
+    # np.matmul does not support the transpose flags, so we manually
+    # transpose X and Y appropriately.
+    if transpose_X:
+        if X.ndim == 1:
+            X = X.reshape((X.size, ))
+        elif X.ndim == 2:
+            X = X.T
+        else:
+            dim = [i for i in range(len(X.shape))]
+            dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
+            X = np.transpose(X, tuple(dim))
+    if transpose_Y:
+        if Y.ndim == 1:
+            Y = Y.reshape((Y.size, ))
+        else:
+            dim = [i for i in range(len(Y.shape))]
+            dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
+            Y = np.transpose(Y, tuple(dim))
+
+    Out = np.matmul(X, Y)
+    if not Out.shape:
+        # We do not support 0-dimensional Tensors (scalars). So where
+        # np.matmul outputs a scalar, we must convert to a Tensor of
+        # shape (1, ) instead.
+        # Everywhere else, we are compatible with np.matmul.
+        Out = np.array([Out], dtype="float64")
+    return Out
+
+
+@unittest.skipIf(not paddle.is_compiled_with_xpu(),
+                 "core is not compiled with XPU")
+class TestMatMulV2Op(OpTest):
+    """
+    case 1
+    """
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (100, )
+        self.trans_x = False
+        self.trans_y = False
+
+    def init_kernel_type(self):
+        self.dtype = "float32"
+
+    def setUp(self):
+        self.init_kernel_type()
+        self.config()
+        self.op_type = "matmul_v2"
+        self.use_xpu = True
+        x = np.random.random(self.x_shape).astype(self.dtype)
+        y = np.random.random(self.y_shape).astype(self.dtype)
+        # scale the inputs into the range [-0.1, 0.1]
+        x = -0.1 + 0.2 * x
+        y = -0.1 + 0.2 * y
+        result = reference_matmul(x, y, self.trans_x, self.trans_y)
+        result = result.astype(self.dtype)
+        self.inputs = {
+            'X': x,
+            'Y': y,
+        }
+        self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
+        self.outputs = {'Out': result}
+
+    def test_check_output(self):
+        place = paddle.XPUPlace(0)
+        self.check_output_with_place(place, atol=0.01)
+
+    def test_check_grad(self):
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X', 'Y'], 'Out', max_relative_error=0.1)
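Editorial note: additional shape cases all follow one pattern — subclass `TestMatMulV2Op` and override `config()`. A hypothetical sketch (the class name and shapes are illustrative, not part of this patch), assuming the imports and base class defined in this file:

```python
@unittest.skipIf(not paddle.is_compiled_with_xpu(),
                 "core is not compiled with XPU")
class TestMatMulOpExample(TestMatMulV2Op):
    """Hypothetical extra case: (2, 100) x (100, 2), no transpose."""

    def config(self):
        self.x_shape = (2, 100)
        self.y_shape = (100, 2)
        self.trans_x = False
        self.trans_y = False
```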
XPU") +class TestMatMuklOp6(TestMatMulV2Op): + """ + case 6 + """ + + def config(self): + self.x_shape = (1, 2, 100, 1) + self.y_shape = (100, ) + self.trans_x = True + self.trans_y = False + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestMatMuklOp7(TestMatMulV2Op): + """ + case 7 + """ + + def config(self): + self.x_shape = (1, 2, 1, 100) + self.y_shape = (100, ) + self.trans_x = False + self.trans_y = False +''' + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestMatMuklOp8(TestMatMulV2Op): + """ + case 8 + """ + + def config(self): + self.x_shape = (1, 1, 2, 100) + self.y_shape = (1, 1, 100, 2) + self.trans_x = False + self.trans_y = False + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestMatMuklOp13(TestMatMulV2Op): + """ + case 13 + """ + + def config(self): + self.x_shape = (2, 2, 2, 50) + self.y_shape = (2, 2, 2, 50) + self.trans_x = True + self.trans_y = False + + +''' +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestMatMuklOp16(TestMatMulV2Op): + """ + case 16 : to check the gradient for special case + """ + + def config(self): + self.x_shape = (100) + self.y_shape = (1, 2, 2, 100, 2) + self.trans_x = False + self.trans_y = False + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestMatMuklOp17(TestMatMulV2Op): + """ + case 17 : to check the gradient for special case + """ + + def config(self): + self.x_shape = (2, 1, 100) + self.y_shape = (100) + self.trans_x = False + self.trans_y = False +''' + + +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") +class TestMatMulV2API(unittest.TestCase): + def setUp(self): + self.places = [fluid.CPUPlace()] + self.places.append(fluid.XPUPlace(0)) + + def check_static_result(self, place): + with fluid.program_guard(fluid.Program(), fluid.Program()): + input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32") + input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32") + + result = paddle.matmul(input_x, input_y) + + x_np = np.random.random([4, 3]).astype("float32") + y_np = np.random.random([3, 4]).astype("float32") + + exe = fluid.Executor(place) + fetches = exe.run(fluid.default_main_program(), + feed={"input_x": x_np, + "input_y": y_np}, + fetch_list=[result]) + + def test_static(self): + for place in self.places: + self.check_static_result(place=place) + + +if __name__ == "__main__": + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py index 636453e916b10a981b0f4387f76f035acf20dabe..7cf005fefa613a48403a124b8d8782f87de0bab9 100644 --- a/python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py +++ b/python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py @@ -28,6 +28,8 @@ import time paddle.enable_static() +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") class TestMulOpError(unittest.TestCase): def test_errors(self): with program_guard(Program(), Program()): @@ -43,6 +45,8 @@ class TestMulOpError(unittest.TestCase): self.assertRaises(TypeError, fluid.layers.mul, x3, x4) +@unittest.skipIf(not paddle.is_compiled_with_xpu(), + "core is not compiled with XPU") class TestXPUMulOp1(OpTest): def setUp(self): self.op_type = "mul" @@ -67,18 +71,23 @@ class TestXPUMulOp1(OpTest): pass def 
diff --git a/python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py
index 636453e916b10a981b0f4387f76f035acf20dabe..7cf005fefa613a48403a124b8d8782f87de0bab9 100644
--- a/python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py
+++ b/python/paddle/fluid/tests/unittests/xpu/test_mul_op_xpu.py
@@ -28,6 +28,8 @@ import time
 paddle.enable_static()
 
 
+@unittest.skipIf(not paddle.is_compiled_with_xpu(),
+                 "core is not compiled with XPU")
 class TestMulOpError(unittest.TestCase):
     def test_errors(self):
         with program_guard(Program(), Program()):
@@ -43,6 +45,8 @@ class TestMulOpError(unittest.TestCase):
         self.assertRaises(TypeError, fluid.layers.mul, x3, x4)
 
 
+@unittest.skipIf(not paddle.is_compiled_with_xpu(),
+                 "core is not compiled with XPU")
 class TestXPUMulOp1(OpTest):
     def setUp(self):
         self.op_type = "mul"
@@ -67,18 +71,23 @@ class TestXPUMulOp1(OpTest):
         pass
 
     def test_check_output(self):
-        self.check_output()
+        place = paddle.XPUPlace(0)
+        self.check_output_with_place(place, atol=0.01)
 
     def test_check_grad_normal(self):
-        self.check_grad(['X', 'Y'], 'Out', max_relative_error=0.1)
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X', 'Y'], 'Out', max_relative_error=0.1)
 
     def test_check_grad_ingore_x(self):
-        self.check_grad(
-            ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set('X'))
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['Y'], 'Out', max_relative_error=0.1, no_grad_set=set("X"))
 
     def test_check_grad_ignore_y(self):
-        self.check_grad(
-            ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
+        place = paddle.XPUPlace(0)
+        self.check_grad_with_place(
+            place, ['X'], 'Out', max_relative_error=0.1, no_grad_set=set('Y'))
 
 
 @unittest.skipIf(not paddle.is_compiled_with_xpu(),