From 19b02d95e099ae066a1f58161501ed2d5140988a Mon Sep 17 00:00:00 2001
From: Aganlengzi
Date: Mon, 25 Oct 2021 19:46:15 +0800
Subject: [PATCH] [NPU] modifications for model ernie-1.0 (#36642)

* [NPU] modifications for model ernie-1.0

* rollback 503003 and change cast to dtype
---
 paddle/fluid/operators/cumsum_op_npu.cc       |  45 +-
 .../elementwise/elementwise_sub_op_npu.cc     |   6 +
 .../fluid/operators/lookup_table_v2_op_npu.cc |  55 +-
 paddle/fluid/operators/matmul_op_npu.cc       | 528 ++++++++++++++----
 .../tests/unittests/npu/test_cumsum_op_npu.py |  40 ++
 .../npu/test_elementwise_sub_op_npu.py        |   5 +
 .../npu/test_lookup_table_v2_op_npu.py        |  40 +-
 .../tests/unittests/npu/test_matmul_op_npu.py | 329 +++++++++++
 8 files changed, 908 insertions(+), 140 deletions(-)
 create mode 100644 python/paddle/fluid/tests/unittests/npu/test_matmul_op_npu.py

diff --git a/paddle/fluid/operators/cumsum_op_npu.cc b/paddle/fluid/operators/cumsum_op_npu.cc
index e8cf1a46db3..486e85b0f0d 100644
--- a/paddle/fluid/operators/cumsum_op_npu.cc
+++ b/paddle/fluid/operators/cumsum_op_npu.cc
@@ -21,6 +21,38 @@ namespace operators {
 
 using Tensor = framework::Tensor;
 
+static void CumsumImp(const Tensor& input, Tensor* output,
+                      const framework::NPUAttributeMap& attr_input,
+                      const framework::ExecutionContext& ctx) {
+  auto stream =
+      ctx.template device_context<paddle::platform::NPUDeviceContext>()
+          .stream();
+  if (input.type() == framework::proto::VarType::INT64) {
+    Tensor tmp_input;
+    tmp_input.mutable_data<float>(input.dims(), ctx.GetPlace());
+    auto dst_acl_dtype = ConvertToNpuDtype(tmp_input.type());
+    const auto& cast_runner_1 =
+        NpuOpRunner("Cast", {input}, {tmp_input},
+                    {{"dst_type", static_cast<int>(dst_acl_dtype)}});
+    cast_runner_1.Run(stream);
+
+    Tensor tmp_output;
+    tmp_output.mutable_data<float>(output->dims(), ctx.GetPlace());
+    const auto& runner =
+        NpuOpRunner("CumsumD", {tmp_input}, {tmp_output}, attr_input);
+    runner.Run(stream);
+
+    dst_acl_dtype = ConvertToNpuDtype(output->type());
+    const auto& cast_runner_2 =
+        NpuOpRunner("Cast", {tmp_output}, {*output},
+                    {{"dst_type", static_cast<int>(dst_acl_dtype)}});
+    cast_runner_2.Run(stream);
+  } else {
+    const auto& runner = NpuOpRunner("CumsumD", {input}, {*output}, attr_input);
+    runner.Run(stream);
+  }
+}
+
 template <typename DeviceContext, typename T>
 class CumSumNPUKernel : public framework::OpKernel<T> {
  public:
@@ -36,10 +68,6 @@ class CumSumNPUKernel : public framework::OpKernel<T> {
     framework::NPUAttributeMap attr_input = {
         {"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}};
 
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
-
     bool flatten = ctx.Attr<bool>("flatten");
     if (flatten) {
       PADDLE_ENFORCE_EQ(
@@ -53,11 +81,9 @@ class CumSumNPUKernel : public framework::OpKernel<T> {
 
       new_x.Resize(framework::make_ddim({x->numel()}));
 
-      const auto& runner = NpuOpRunner("CumsumD", {new_x}, {*out}, attr_input);
-      runner.Run(stream);
+      CumsumImp(new_x, out, attr_input, ctx);
     } else {
-      const auto& runner = NpuOpRunner("CumsumD", {*x}, {*out}, attr_input);
-      runner.Run(stream);
+      CumsumImp(*x, out, attr_input, ctx);
     }
   }
 };
@@ -69,5 +95,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_NPU_KERNEL(
     cumsum, ops::CumSumNPUKernel<plat::NPUDeviceContext, int>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+    ops::CumSumNPUKernel<plat::NPUDeviceContext, int64_t>,
+#endif
    ops::CumSumNPUKernel<plat::NPUDeviceContext, float>,
    ops::CumSumNPUKernel<plat::NPUDeviceContext, plat::float16>);
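Note: CumsumD has no int64 kernel on the targeted CANN versions, so the new
CumsumImp round-trips int64 inputs through float32. A minimal numpy sketch of
that workaround (illustration only, not Paddle API; the round trip is exact
as long as partial sums stay below 2**24, where float32 is still lossless for
integers, which holds for the test data used here):

    import numpy as np

    x = np.random.randint(1, 10000, size=(5, 6, 10)).astype(np.int64)
    # cast -> CumsumD -> cast back, as CumsumImp does on the NPU
    out = x.astype(np.float32).cumsum(axis=-1).astype(np.int64)
    assert (out == x.cumsum(axis=-1)).all()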
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
index 48b98dafc7b..4cc4228b164 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
@@ -167,10 +167,16 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
 REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel<int>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+                       ops::ElementwiseSubNPUKernel<int64_t>,
+#endif
                        ops::ElementwiseSubNPUKernel<float>,
                        ops::ElementwiseSubNPUKernel<plat::float16>);
 
 REGISTER_OP_NPU_KERNEL(elementwise_sub_grad,
                        ops::ElementwiseSubGradNPUKernel<int>,
+#ifdef PADDLE_WITH_ASCEND_INT64
+                       ops::ElementwiseSubGradNPUKernel<int64_t>,
+#endif
                        ops::ElementwiseSubGradNPUKernel<float>,
                        ops::ElementwiseSubGradNPUKernel<plat::float16>);
diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu.cc b/paddle/fluid/operators/lookup_table_v2_op_npu.cc
index b75ae8a6588..3cb91c71233 100644
--- a/paddle/fluid/operators/lookup_table_v2_op_npu.cc
+++ b/paddle/fluid/operators/lookup_table_v2_op_npu.cc
@@ -21,6 +21,9 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+constexpr int64_t kNoPadding = -1;
+
 template <typename T>
 class LookupTableV2NPUKernel : public framework::OpKernel<T> {
  public:
@@ -35,16 +38,52 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
         platform::errors::InvalidArgument("npu only accept LoDTensor"));
     output_t->mutable_data<T>(ctx.GetPlace());
 
-    NpuOpRunner runner;
-    runner.SetType("GatherV2")
-        .AddInput(*table_t)
-        .AddInput(*ids_t)
-        .AddInput(std::vector<int32_t>{0})
+    int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
+    if (padding_idx == kNoPadding) {
+      NpuOpRunner runner;
+      runner.SetType("GatherV2")
+          .AddInput(*table_t)
+          .AddInput(*ids_t)
+          .AddInput(std::vector<int32_t>{0})
+#if (CANN_VERSION_CODE >= 503003)
+          .AddAttrs({{"batch_dims", 0}})
+#endif
+          .AddOutput(*output_t);
+      runner.Run();
+    } else {
+      Tensor tmp_table_t(table_t->type());
+      tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());
+
+      Tensor index;
+      index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
+      FillNpuTensorWithConstant<int32_t>(&index,
+                                         static_cast<int32_t>(padding_idx));
+
+      auto updata_dim = framework::make_ddim({1, table_t->dims()[1]});
+      Tensor update;
+      update.mutable_data<T>(updata_dim, ctx.GetPlace());
+      FillNpuTensorWithConstant<T>(&update, static_cast<T>(0));
+      update.Resize(updata_dim);
+
+      NpuOpRunner update_runner;
+      update_runner.SetType("TensorScatterUpdate")
+          .AddInput(*table_t)
+          .AddInput(index)
+          .AddInput(update)
+          .AddOutput(tmp_table_t);
+      update_runner.Run();
+
+      NpuOpRunner runner;
+      runner.SetType("GatherV2")
+          .AddInput(tmp_table_t)
+          .AddInput(*ids_t)
+          .AddInput(std::vector<int32_t>{0})
 #if (CANN_VERSION_CODE >= 503003)
-        .AddAttrs({{"batch_dims", 0}})
+          .AddAttrs({{"batch_dims", 0}})
 #endif
-        .AddOutput(*output_t);
-    runner.Run();
+          .AddOutput(*output_t);
+      runner.Run();
+    }
   }
 };
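Note: the semantics of the new padding_idx branch above, sketched in numpy
(illustration only): TensorScatterUpdate writes one zero row into a scratch
copy of the embedding table, and GatherV2 then gathers from that copy, so
every id equal to padding_idx maps to an all-zero embedding.

    import numpy as np

    vocab, dim, padding_idx = 10, 20, 3
    table = np.random.random((vocab, dim)).astype(np.float32)
    ids = np.random.randint(0, vocab, size=(6, 8))

    tmp_table = table.copy()
    tmp_table[padding_idx] = 0.0   # TensorScatterUpdate with a zero row
    out = tmp_table[ids]           # GatherV2 along axis 0
    assert (out[ids == padding_idx] == 0).all()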
diff --git a/paddle/fluid/operators/matmul_op_npu.cc b/paddle/fluid/operators/matmul_op_npu.cc
index d5606177a55..df811abc1de 100644
--- a/paddle/fluid/operators/matmul_op_npu.cc
+++ b/paddle/fluid/operators/matmul_op_npu.cc
@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#include <string>
-#include <vector>
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/framework/op_version_registry.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
@@ -21,40 +19,253 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+using NPUDeviceContext = platform::NPUDeviceContext;
+
+template <typename T>
+static void Mul(const framework::ExecutionContext& ctx,
+                const aclrtStream& stream, const Tensor& X, const Tensor& Y,
+                Tensor* Out, const float alpha) {
+  Out->mutable_data<T>(ctx.GetPlace());
+
+  if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
+    const auto& runner_dx = NpuOpRunner("Mul", {X, Y}, {*Out}, {});
+    runner_dx.Run(stream);
+  } else {
+    Tensor Out_temp(Out->type());
+    Out_temp.mutable_data<T>(Out->dims(), ctx.GetPlace());
+    const auto& runner_dx = NpuOpRunner("Mul", {X, Y}, {Out_temp}, {});
+    runner_dx.Run(stream);
+
+    const auto& runner =
+        NpuOpRunner("Muls", {Out_temp}, {*Out}, {{"value", alpha}});
+    runner.Run(stream);
+  }
+}
+
+template <typename T>
+static void Dot(const framework::ExecutionContext& ctx,
+                const aclrtStream& stream, const Tensor& X, const Tensor& Y,
+                Tensor* Out, const float alpha) {
+  Out->mutable_data<T>(ctx.GetPlace());
+
+  if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
+    const auto& runner = NpuOpRunner("Dot", {X, Y}, {*Out});
+    runner.Run(stream);
+  } else {
+    Tensor Out_temp(Out->type());
+    Out_temp.mutable_data<T>(Out->dims(), ctx.GetPlace());
+    const auto& out_temp_runner = NpuOpRunner("Dot", {X, Y}, {Out_temp});
+    out_temp_runner.Run(stream);
+
+    const auto& runner =
+        NpuOpRunner("Muls", {Out_temp}, {*Out}, {{"value", alpha}});
+    runner.Run(stream);
+  }
+}
+
+template <typename T>
+static void MatMul2D(const framework::ExecutionContext& ctx,
+                     const aclrtStream& stream, const Tensor& X,
+                     const Tensor& Y, Tensor* Out, const bool trans_x,
+                     const bool trans_y, const float alpha) {
+  Out->mutable_data<T>(ctx.GetPlace());
+
+  if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
+    const auto& runner =
+        NpuOpRunner("MatMul", {X, Y}, {*Out},
+                    {{"transpose_x1", trans_x}, {"transpose_x2", trans_y}});
+    runner.Run(stream);
+  } else {
+    Tensor Out_temp(Out->type());
+    Out_temp.mutable_data<T>(Out->dims(), ctx.GetPlace());
+    const auto& out_temp_runner =
+        NpuOpRunner("MatMul", {X, Y}, {Out_temp},
+                    {{"transpose_x1", trans_x}, {"transpose_x2", trans_y}});
+    out_temp_runner.Run(stream);
+
+    const auto& runner =
+        NpuOpRunner("Muls", {Out_temp}, {*Out}, {{"value", alpha}});
+    runner.Run(stream);
+  }
+}
+
+template <typename T>
+static void MatMulND(const framework::ExecutionContext& ctx,
+                     const aclrtStream& stream, const Tensor& X,
+                     const Tensor& Y, Tensor* Out, const bool trans_x,
+                     const bool trans_y, const float alpha) {
+  Out->mutable_data<T>(ctx.GetPlace());
+
+  if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
+    const auto& runner =
+        NpuOpRunner("BatchMatMul", {X, Y}, {*Out},
+                    {{"adj_x1", trans_x}, {"adj_x2", trans_y}});
+    runner.Run(stream);
+  } else {
+    Tensor Out_temp(Out->type());
+    Out_temp.mutable_data<T>(Out->dims(), ctx.GetPlace());
+    const auto& out_temp_runner =
+        NpuOpRunner("BatchMatMul", {X, Y}, {Out_temp},
+                    {{"adj_x1", trans_x}, {"adj_x2", trans_y}});
+    out_temp_runner.Run(stream);
+
+    const auto& runner =
+        NpuOpRunner("Muls", {Out_temp}, {*Out}, {{"value", alpha}});
+    runner.Run(stream);
+  }
+}
+
+template <typename T>
+static void ReduceDims(const framework::ExecutionContext& ctx,
+                       const aclrtStream& stream,
+                       const std::vector<int64_t>& dims,
+                       const std::vector<int64_t>& brd_dims, const Tensor& in,
+                       Tensor* out) {
+  std::vector<int64_t> axes;
+  int64_t size = brd_dims.size();
+  int64_t diff = brd_dims.size() - dims.size();
+  for (int64_t i = 0; i < size; ++i) {
+    if (i < diff) {
+      axes.push_back(i);
+      continue;
+    }
+    if (brd_dims[i] > dims[i - diff]) {
+      axes.push_back(i);
+    }
+  }
+  out->mutable_data<T>(ctx.GetPlace());
+  const auto& runner = NpuOpRunner("ReduceSumD", {in}, {*out},
+                                   {{"axes", axes}, {"keep_dims", false}});
+  runner.Run(stream);
+}
+
 template <typename DeviceContext, typename T>
 class MatMulNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<framework::Tensor>("X");
-    auto* y = ctx.Input<framework::Tensor>("Y");
-    auto* out = ctx.Output<framework::Tensor>("Out");
+    auto* X = ctx.Input<framework::Tensor>("X");
+    auto* Y = ctx.Input<framework::Tensor>("Y");
+    auto* Out = ctx.Output<framework::Tensor>("Out");
     bool transpose_x = ctx.Attr<bool>("transpose_X");
     bool transpose_y = ctx.Attr<bool>("transpose_Y");
+    float alpha = static_cast<float>(ctx.Attr<float>("alpha"));
+
+    std::vector<int64_t> x_dims = framework::vectorize(X->dims());
+    std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
+    std::vector<int64_t> out_dims = framework::vectorize(Out->dims());
+    int x_ndim = x_dims.size();
+    int y_ndim = y_dims.size();
+    int out_ndim = out_dims.size();
 
-    if (x->dims().size() == 2) {
-      out->mutable_data<T>(ctx.GetPlace());
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
 
-      const auto& runner = NpuOpRunner(
-          "MatMul", {*x, *y}, {*out},
-          {{"transpose_x1", transpose_x}, {"transpose_x2", transpose_y}});
+    // Case 1: [K] x [K] = [1]
+    if (x_ndim == 1 && y_ndim == 1) {
+      PADDLE_ENFORCE_EQ(
+          X->numel(), Y->numel(),
+          platform::errors::InvalidArgument(
+              "X's numbers must be equal to Y's numbers,"
+              "when X/Y's dims =1. But received X has [%d] elements,"
+              "received Y has [%d] elements",
+              X->numel(), Y->numel()));
+      Out->Resize({1});
+      Dot<T>(ctx, stream, *X, *Y, Out, alpha);
+      return;
+    }
 
-      auto stream =
-          ctx.template device_context<paddle::platform::NPUDeviceContext>()
-              .stream();
-      runner.Run(stream);
+    // Resize dim 1 to 2
+    Tensor x_temp, y_temp;
+    x_temp.ShareDataWith(*X);
+    y_temp.ShareDataWith(*Y);
+    if (x_ndim == 1) {
+      x_dims.insert(x_dims.begin(), 1);
+      out_dims.insert(out_dims.end() - 1, 1);
+      x_temp.Resize(framework::make_ddim(x_dims));
+      x_ndim = 2;
+      out_ndim += 1;
+    }
+    if (y_ndim == 1) {
+      y_dims.push_back(1);
+      out_dims.push_back(1);
+      y_temp.Resize(framework::make_ddim(y_dims));
+      y_ndim = 2;
+      out_ndim += 1;
+    }
+
+    const int K = transpose_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
+    if (transpose_y) {
+      PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
+                        platform::errors::InvalidArgument(
+                            "Input(Y) has error dim."
+                            "Y'dims[%d] must be equal to %d"
+                            "But received Y'dims[%d] is %d",
+                            y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1]));
+    } else {
+      PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
+                        platform::errors::InvalidArgument(
+                            "Input(Y) has error dim."
+                            "Y'dims[%d] must be equal to %d"
+                            "But received Y'dims[%d] is %d",
+                            y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2]));
+    }
+
+    // Case 2: [M, K] x [K, N] = [M, N]
+    if (x_ndim == 2 && y_ndim == 2) {
+      MatMul2D<T>(ctx, stream, x_temp, y_temp, Out, transpose_x, transpose_y,
+                  alpha);
+      return;
+    }
+
+    // Case 3: [B, M, K] x [K, N] = [B, M, N], when transpose_x = false
+    // Equal: [B * M, K] x [K, N] = [B * M, N] => [B, M, N]
+    if (transpose_x == false && y_ndim == 2) {
+      std::vector<int64_t> vec_dim = {x_temp.numel() / K, K};
+      x_temp.Resize(framework::make_ddim(vec_dim));
+      MatMul2D<T>(ctx, stream, x_temp, y_temp, Out, transpose_x, transpose_y,
+                  alpha);
+      return;
+    }
 
-    } else if (x->dims().size() > 2) {
-      out->mutable_data<T>(ctx.GetPlace());
+    // Case 4: [B, M, K] x [B, K, N] = [B, M, N]
+    std::vector<int64_t> x_broadcast_dims(out_ndim, 1);
+    std::vector<int64_t> y_broadcast_dims(out_ndim, 1);
+    std::copy(out_dims.begin(), out_dims.end() - 2, x_broadcast_dims.begin());
+    std::copy(out_dims.begin(), out_dims.end() - 2, y_broadcast_dims.begin());
+    std::copy(x_dims.end() - 2, x_dims.end(), x_broadcast_dims.end() - 2);
+    std::copy(y_dims.end() - 2, y_dims.end(), y_broadcast_dims.end() - 2);
 
-      const auto& runner =
-          NpuOpRunner("BatchMatMul", {*x, *y}, {*out},
-                      {{"adj_x1", transpose_x}, {"adj_x2", transpose_y}});
+    Tensor x_temp_brd(X->type());
+    if (x_dims == x_broadcast_dims) {
+      x_temp_brd.ShareDataWith(*X);
+      x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
+    } else {
+      x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
+      x_temp_brd.mutable_data<T>(ctx.GetPlace());
+      NpuOpRunner runner_brd;
+      runner_brd.SetType("BroadcastTo")
+          .AddInput(x_temp)
+          .AddInput(std::move(x_broadcast_dims))
+          .AddOutput(x_temp_brd)
+          .Run(stream);
+    }
 
-      auto stream =
-          ctx.template device_context<paddle::platform::NPUDeviceContext>()
-              .stream();
-      runner.Run(stream);
+    Tensor y_temp_brd(Y->type());
+    if (y_dims == y_broadcast_dims) {
+      y_temp_brd.ShareDataWith(*Y);
+      y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
+    } else {
+      y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
+      y_temp_brd.mutable_data<T>(ctx.GetPlace());
+      NpuOpRunner runner_brd;
+      runner_brd.SetType("BroadcastTo")
+          .AddInput(y_temp)
+          .AddInput(std::move(y_broadcast_dims))
+          .AddOutput(y_temp_brd)
+          .Run(stream);
     }
+    MatMulND<T>(ctx, stream, x_temp_brd, y_temp_brd, Out, transpose_x,
+                transpose_y, alpha);
   }
 };
 
@@ -62,109 +273,200 @@ template <typename DeviceContext, typename T>
 class MatMulGradNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
-    auto* x = ctx.Input<framework::Tensor>("X");
-    auto* y = ctx.Input<framework::Tensor>("Y");
-    auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
-    auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
-    auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
+    auto* X = ctx.Input<framework::Tensor>("X");
+    auto* Y = ctx.Input<framework::Tensor>("Y");
+    auto* dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
+    auto* dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
+    auto* dY = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
+
     bool transpose_x = ctx.Attr<bool>("transpose_X");
     bool transpose_y = ctx.Attr<bool>("transpose_Y");
-    auto stream =
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .stream();
-
-    if (x->dims().size() == 2) {
-      if (transpose_y) {
-        if (dx) {
-          dx->mutable_data<T>(ctx.GetPlace());
-          const auto& runner_dx =
-              NpuOpRunner("MatMul", {*dout, *y}, {*dx},
-                          {{"transpose_x1", false}, {"transpose_x2", false}});
-
-          runner_dx.Run(stream);
-        }
-        if (dy) {
-          dy->mutable_data<T>(ctx.GetPlace());
-          const auto& runner_dy =
-              NpuOpRunner("MatMul", {*dout, *x}, {*dy},
-                          {{"transpose_x1", true}, {"transpose_x2", false}});
-
-          runner_dy.Run(stream);
-        }
-      } else {
-        if (dx) {
-          dx->mutable_data<T>(ctx.GetPlace());
-          const auto& runner_dx =
-              NpuOpRunner("MatMul", {*dout, *y}, {*dx},
-                          {{"transpose_x1", false}, {"transpose_x2", true}});
-
-          runner_dx.Run(stream);
-        }
-        if (dy) {
-          dy->mutable_data<T>(ctx.GetPlace());
-          const auto& runner_dy =
-              NpuOpRunner("MatMul", {*x, *dout}, {*dy},
-                          {{"transpose_x1", true}, {"transpose_x2", false}});
-
-          runner_dy.Run(stream);
-        }
-      }
-    } else if (x->dims().size() > 2) {
-      if (transpose_y) {
-        if (dx) {
-          dx->mutable_data<T>(ctx.GetPlace());
-          const auto& runner_dx =
-              NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx},
-                          {{"adj_x1", false}, {"adj_x2", false}});
-
-          runner_dx.Run(stream);
-        }
-        if (dy) {
-          dy->mutable_data<T>(ctx.GetPlace());
-          const auto& runner_dy =
-              NpuOpRunner("BatchMatMul", {*dout, *x}, {*dy},
-                          {{"adj_x1", true}, {"adj_x2", false}});
-
-          runner_dy.Run(stream);
-        }
-      } else {
-        if (dx) {
-          dx->mutable_data<T>(ctx.GetPlace());
-          const auto& runner_dx =
-              NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx},
-                          {{"adj_x1", false}, {"adj_x2", true}});
-
-          runner_dx.Run(stream);
-        }
-        if (dy) {
-          dy->mutable_data<T>(ctx.GetPlace());
-          if ((x->dims().size() == 3) && (dout->dims().size() == 3) &&
-              (dy->dims().size() == 2)) {
-            framework::Tensor dout_tmp;
-            dout_tmp.ShareDataWith(*dout);
-            std::vector<int64_t> vec_dim =
-                framework::vectorize<int64_t>(dout_tmp.dims());
-            std::vector<int64_t> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]};
-            dout_tmp.Resize(framework::make_ddim(vec_dim_v));
-
-            framework::Tensor x_tmp;
-            x_tmp.ShareDataWith(*x);
-            std::vector<int64_t> vec_dim_x =
-                framework::vectorize<int64_t>(x_tmp.dims());
-            std::vector<int64_t> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
-                                             vec_dim_x[2]};
-            x_tmp.Resize(framework::make_ddim(vec_dim_x_v));
-            const auto& runner_dy =
-                NpuOpRunner("MatMul", {x_tmp, dout_tmp}, {*dy},
-                            {{"transpose_x1", true}, {"transpose_x2", false}});
-            runner_dy.Run(stream);
-          } else {
-            const auto& runner_dy =
-                NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
-                            {{"adj_x1", true}, {"adj_x2", false}});
-            runner_dy.Run(stream);
-          }
-        }
-      }
-    }
+    float alpha = static_cast<float>(ctx.Attr<float>("alpha"));
+
+    std::vector<int64_t> x_dims = framework::vectorize(X->dims());
+    std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
+    std::vector<int64_t> out_dims = framework::vectorize(dOut->dims());
+    int x_ndim = x_dims.size();
+    int y_ndim = y_dims.size();
+    int out_ndim = out_dims.size();
+
+    auto stream = ctx.template device_context<NPUDeviceContext>().stream();
+
+    // Case 1: [K] x [K] = [1]
+    if (x_ndim == 1 && y_ndim == 1) {
+      Tensor dout_temp(dOut->type());
+      dout_temp.Resize(X->dims());
+      dout_temp.mutable_data<T>(ctx.GetPlace());
+      NpuOpRunner runner;
+      runner.SetType("BroadcastTo")
+          .AddInput(*dOut)
+          .AddInput(std::move(x_dims))
+          .AddOutput(dout_temp)
+          .Run(stream);
+
+      if (dX) {
+        Mul<T>(ctx, stream, dout_temp, *Y, dX, alpha);
+      }
+      if (dY) {
+        Mul<T>(ctx, stream, dout_temp, *X, dY, alpha);
+      }
+      return;
+    }
+
+    // Resize dim 1 to 2
+    Tensor x_temp, y_temp, dout_temp;
+    x_temp.ShareDataWith(*X);
+    y_temp.ShareDataWith(*Y);
+    dout_temp.ShareDataWith(*dOut);
+    if (x_ndim == 1) {
+      x_dims.insert(x_dims.begin(), 1);
+      out_dims.insert(out_dims.end() - 1, 1);
+      x_temp.Resize(framework::make_ddim(x_dims));
+      dout_temp.Resize(framework::make_ddim(out_dims));
+      x_ndim = 2;
+      out_ndim += 1;
+    }
+    if (y_ndim == 1) {
+      y_dims.push_back(1);
+      out_dims.push_back(1);
+      y_temp.Resize(framework::make_ddim(y_dims));
+      dout_temp.Resize(framework::make_ddim(out_dims));
+      y_ndim = 2;
+      out_ndim += 1;
+    }
+
+    // Case 2: [M, K] x [K, N] = [M, N]
+    if (out_ndim == 2) {
+      if (dX) {
+        dX->Resize(framework::make_ddim(x_dims));
+        if (transpose_x) {
+          MatMul2D<T>(ctx, stream, y_temp, dout_temp, dX, transpose_y, true,
+                      alpha);
+        } else {
+          MatMul2D<T>(ctx, stream, dout_temp, y_temp, dX, false, !transpose_y,
+                      alpha);
+        }
+        dX->Resize(X->dims());
+      }
+      if (dY) {
+        dY->Resize(framework::make_ddim(y_dims));
+        if (transpose_y) {
+          MatMul2D<T>(ctx, stream, dout_temp, x_temp, dY, true, transpose_x,
+                      alpha);
+        } else {
+          MatMul2D<T>(ctx, stream, x_temp, dout_temp, dY, !transpose_x, false,
+                      alpha);
+        }
+        dY->Resize(Y->dims());
+      }
+      return;
+    }
+
+    const int K = transpose_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
+    const int N = transpose_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
+
+    // Case 3: [B, M, K] x [K, N] = [B, M, N], when transpose_x = false
+    // Equal: [B * M, K] x [K, N] = [B * M, N] => [B, M, N]
+    if (transpose_x == false && y_ndim == 2) {
+      std::vector<int64_t> x_vec_dim = {x_temp.numel() / K, K};
+      dout_temp.Resize(
+          framework::make_ddim(std::vector<int64_t>{dout_temp.numel() / N, N}));
+      if (dX) {
+        dX->Resize(framework::make_ddim(x_vec_dim));
+        MatMul2D<T>(ctx, stream, dout_temp, y_temp, dX, false, !transpose_y,
+                    alpha);
+        dX->Resize(X->dims());
+      }
+      if (dY) {
+        x_temp.Resize(framework::make_ddim(x_vec_dim));
+        if (transpose_y) {
+          MatMul2D<T>(ctx, stream, dout_temp, x_temp, dY, true, false, alpha);
+        } else {
+          MatMul2D<T>(ctx, stream, x_temp, dout_temp, dY, true, false, alpha);
+        }
+      }
+      return;
+    }
+
+    // Case 4: [B, M, K] x [B, K, N] = [B, M, N]
+    std::vector<int64_t> x_broadcast_dims(out_ndim, 1);
+    std::vector<int64_t> y_broadcast_dims(out_ndim, 1);
+    std::copy(out_dims.begin(), out_dims.end() - 2, x_broadcast_dims.begin());
+    std::copy(out_dims.begin(), out_dims.end() - 2, y_broadcast_dims.begin());
+    std::copy(x_dims.end() - 2, x_dims.end(), x_broadcast_dims.end() - 2);
+    std::copy(y_dims.end() - 2, y_dims.end(), y_broadcast_dims.end() - 2);
+
+    Tensor x_temp_brd(X->type());
+    if (x_dims == x_broadcast_dims) {
+      x_temp_brd.ShareDataWith(*X);
+      x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
+    } else {
+      x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
+      x_temp_brd.mutable_data<T>(ctx.GetPlace());
+      NpuOpRunner runner_brd;
+      runner_brd.SetType("BroadcastTo")
+          .AddInput(x_temp)
+          .AddInput(std::move(x_broadcast_dims))
+          .AddOutput(x_temp_brd)
+          .Run(stream);
+    }
+
+    Tensor y_temp_brd(Y->type());
+    if (y_dims == y_broadcast_dims) {
+      y_temp_brd.ShareDataWith(*Y);
+      y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
+    } else {
+      y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
+      y_temp_brd.mutable_data<T>(ctx.GetPlace());
+      NpuOpRunner runner_brd;
+      runner_brd.SetType("BroadcastTo")
+          .AddInput(y_temp)
+          .AddInput(std::move(y_broadcast_dims))
+          .AddOutput(y_temp_brd)
+          .Run(stream);
+    }
+
+    if (dX) {
+      if (x_dims == x_broadcast_dims) {
+        if (transpose_x) {
+          MatMulND<T>(ctx, stream, y_temp_brd, dout_temp, dX, transpose_y,
+                      true, alpha);
+        } else {
+          MatMulND<T>(ctx, stream, dout_temp, y_temp_brd, dX, false,
+                      !transpose_y, alpha);
+        }
+      } else {
+        Tensor dx_temp(X->type());
+        dx_temp.Resize(framework::make_ddim(x_broadcast_dims));
+        if (transpose_x) {
+          MatMulND<T>(ctx, stream, y_temp_brd, dout_temp, &dx_temp,
+                      transpose_y, true, alpha);
+        } else {
+          MatMulND<T>(ctx, stream, dout_temp, y_temp_brd, &dx_temp, false,
+                      !transpose_y, alpha);
+        }
+        ReduceDims<T>(ctx, stream, x_dims, x_broadcast_dims, dx_temp, dX);
+      }
+    }
+    if (dY) {
+      if (y_dims == y_broadcast_dims) {
+        if (transpose_y) {
+          MatMulND<T>(ctx, stream, dout_temp, x_temp_brd, dY, true,
+                      transpose_x, alpha);
+        } else {
+          MatMulND<T>(ctx, stream, x_temp_brd, dout_temp, dY, !transpose_x,
+                      false, alpha);
+        }
+      } else {
+        Tensor dy_temp(Y->type());
+        dy_temp.Resize(framework::make_ddim(y_broadcast_dims));
+        if (transpose_y) {
+          MatMulND<T>(ctx, stream, dout_temp, x_temp_brd, &dy_temp, true,
+                      transpose_x, alpha);
+        } else {
+          MatMulND<T>(ctx, stream, x_temp_brd, dout_temp, &dy_temp,
+                      !transpose_x, false, alpha);
+        }
+        ReduceDims<T>(ctx, stream, y_dims, y_broadcast_dims, dy_temp, dY);
+      }
+    }
   }
 };
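Note: ReduceDims above implements the standard gradient rule for
broadcasting: when an operand was broadcast from dims to brd_dims for the
batched matmul, its gradient must be summed back over every broadcast axis
(all leading axes, plus any axis where brd_dims[i] > dims[i - diff]). A numpy
sketch of the same axis selection (illustration only):

    import numpy as np

    def reduce_dims(grad, dims, brd_dims):
        diff = len(brd_dims) - len(dims)
        axes = tuple(i for i in range(len(brd_dims))
                     if i < diff or brd_dims[i] > dims[i - diff])
        return grad.sum(axis=axes)   # ReduceSumD with keep_dims=False

    dx_brd = np.ones((2, 3, 4, 5))   # gradient in broadcast shape
    print(reduce_dims(dx_brd, [4, 5], [2, 3, 4, 5]).shape)  # (4, 5)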
diff --git a/python/paddle/fluid/tests/unittests/npu/test_cumsum_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_cumsum_op_npu.py
index 5a3f98524bb..9289da6641e 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_cumsum_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_cumsum_op_npu.py
@@ -249,5 +249,45 @@ class TestNPUCumSumWithFlatten2(TestNPUCumSumOp1):
         self.outputs = {'Out': self.inputs['X'].cumsum()}
 
 
+#----------------Cumsum Int64----------------
+class TestNPUCumSumOpInt64(TestNPUCumSumOp1):
+    def init_testcase(self):
+        self.attrs = {'axis': -1, 'reverse': True}
+        self.inputs = {
+            'X': np.random.randint(
+                1, 10000, size=(5, 6, 10)).astype(self.dtype)
+        }
+        self.outputs = {
+            'Out': np.flip(
+                np.flip(
+                    self.inputs['X'], axis=2).cumsum(axis=2), axis=2)
+        }
+
+
+def create_test_int64(parent):
+    class TestCumSumInt64(parent):
+        def init_dtype(self):
+            self.dtype = np.int64
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Int64")
+    TestCumSumInt64.__name__ = cls_name
+    globals()[cls_name] = TestCumSumInt64
+
+
+create_test_int64(TestNPUCumSumOp1)
+create_test_int64(TestNPUCumSumOp2)
+create_test_int64(TestNPUCumSumOp3)
+create_test_int64(TestNPUCumSumOp4)
+create_test_int64(TestNPUCumSumOp5)
+create_test_int64(TestNPUCumSumOp7)
+create_test_int64(TestNPUCumSumExclusive1)
+create_test_int64(TestNPUCumSumExclusive2)
+create_test_int64(TestNPUCumSumExclusive3)
+create_test_int64(TestNPUCumSumExclusive4)
+create_test_int64(TestNPUCumSumExclusive5)
+create_test_int64(TestNPUCumSumReverseExclusive)
+create_test_int64(TestNPUCumSumWithFlatten1)
+create_test_int64(TestNPUCumSumWithFlatten2)
+
 if __name__ == '__main__':
     unittest.main()
diff --git a/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py
index 7c8710fd42b..fac2bc66ff4 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_elementwise_sub_op_npu.py
@@ -95,6 +95,11 @@ class TestElementwiseSubOpInt32(TestElementwiseSubOp):
         self.dtype = np.int32
 
 
+class TestElementwiseSubOpInt64(TestElementwiseSubOp):
+    def init_dtype(self):
+        self.dtype = np.int64
+
+
 class TestSubtractAPI(unittest.TestCase):
     def test_name(self):
         with paddle.static.program_guard(paddle.static.Program()):
diff --git a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py
index 56f04a6e993..1031be4c1a7 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_lookup_table_v2_op_npu.py
@@ -33,14 +33,15 @@ class TestLookupTableV2(OpTest):
         self.place = paddle.NPUPlace(0)
 
         self.init_dtype()
-        self.init_dim()
+        self.init_dims()
+        self.init_padding_idx()
         np.random.seed(SEED)
-        bsz = 6
-        seqlen = 8
-        vocab = 10
-        w = np.ones([vocab, self.dim]).astype(self.dtype)
-        x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32)
-        out = np.ones([bsz, seqlen, self.dim]).astype(self.dtype)
+        w = np.random.random([self.vocab, self.dim]).astype(self.dtype)
+        x = np.random.randint(
+            0, self.vocab, size=(self.bsz, self.seqlen)).astype(np.int32)
+        out = w[x]
+        if self.padding_idx != -1:
+            out[np.squeeze(x == self.padding_idx)] = np.zeros(self.dim)
 
         self.inputs = {
             'W': OpTest.np_dtype_to_fluid_dtype(w),
@@ -50,7 +51,7 @@ class TestLookupTableV2(OpTest):
             'is_sparse': False,
             'is_distributed': False,
             'remote_prefetch': False,
-            'padding_idx': -1
+            'padding_idx': self.padding_idx
         }
         self.outputs = {'Out': out}
 
@@ -60,10 +61,16 @@ class TestLookupTableV2(OpTest):
     def init_dtype(self):
         self.dtype = np.float32
 
-    def init_dim(self):
+    def init_dims(self):
+        self.bsz = 6
+        self.seqlen = 8
+        self.vocab = 10
         # embedding_dim is not multiple of 32
         self.dim = 20
 
+    def init_padding_idx(self):
+        self.padding_idx = -1
+
     def test_check_output(self):
         self.check_output_with_place(self.place)
 
@@ -85,7 +92,10 @@ class TestLookupTableV2FP16(TestLookupTableV2):
 
 
 class TestLookupTableV2Dim32(TestLookupTableV2):
-    def init_dim(self):
+    def init_dims(self):
+        self.bsz = 6
+        self.seqlen = 8
+        self.vocab = 10
         # embedding_dim is multiple of 32
         self.dim = 64
 
@@ -96,7 +106,10 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
     def init_dtype(self):
         self.dtype = np.float16
 
-    def init_dim(self):
+    def init_dims(self):
+        self.bsz = 6
+        self.seqlen = 8
+        self.vocab = 10
         self.dim = 64
 
     def set_npu(self):
@@ -104,5 +117,10 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
         self.__class__.no_need_check_grad = True
 
 
+class TestLookupTableV2WithPadding(TestLookupTableV2):
+    def init_padding_idx(self):
+        self.padding_idx = np.random.randint(0, self.vocab)
+
+
 if __name__ == '__main__':
     unittest.main()
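Note: the new test file below checks the NPU matmul kernel against a numpy
reference. The forward semantics under test reduce to the following sketch
(illustration only, not the actual OpTest harness):

    import numpy as np

    def matmul_ref(X, Y, transpose_X=False, transpose_Y=False, alpha=1.0):
        if transpose_X and X.ndim >= 2:
            X = np.swapaxes(X, -1, -2)   # transpose the last two dims
        if transpose_Y and Y.ndim >= 2:
            Y = np.swapaxes(Y, -1, -2)
        return alpha * np.matmul(X, Y)   # numpy-style batch broadcasting

    print(matmul_ref(np.ones((2, 5, 10)), np.ones((10, 4))).shape)  # (2, 5, 4)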
diff --git a/python/paddle/fluid/tests/unittests/npu/test_matmul_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_matmul_op_npu.py
new file mode 100644
index 00000000000..a8dc0c137c3
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/npu/test_matmul_op_npu.py
@@ -0,0 +1,329 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import print_function
+
+import numpy as np
+import unittest
+import sys
+sys.path.append("..")
+from op_test import OpTest
+import paddle
+import paddle.fluid as fluid
+
+paddle.enable_static()
+SEED = 2021
+
+
+def reference_matmul(X, Y, transpose_X=False, transpose_Y=False, scale=1.0):
+    """Reference forward implementation using np.matmul."""
+    # np.matmul does not support the transpose flags, so we manually
+    # transpose X and Y appropriately.
+    if transpose_X:
+        if X.ndim == 1:
+            X = X.reshape((X.size, ))
+        elif X.ndim == 2:
+            X = X.T
+        else:
+            dim = [i for i in range(len(X.shape))]
+            dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
+            X = np.transpose(X, tuple(dim))
+    if transpose_Y:
+        if Y.ndim == 1:
+            Y = Y.reshape((Y.size, ))
+        else:
+            dim = [i for i in range(len(Y.shape))]
+            dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
+            Y = np.transpose(Y, tuple(dim))
+
+    Out = np.matmul(X, Y)
+    if not Out.shape:
+        # We do not support 0-dimensional Tensors (scalars). So where
+        # np.matmul outputs a scalar, we must convert to a Tensor of
+        # shape (1, ) instead.
+        # Everywhere else, we are compatible with np.matmul.
+        Out = np.array([Out], dtype="float64")
+    if abs(scale - 1.0) > 1e-09:
+        Out = Out * scale
+    return Out
+
+
+class TestMatMulOp(OpTest):
+    """
+    basic case
+    """
+
+    def setUp(self):
+        self.set_npu()
+        self.op_type = "matmul"
+        self.init_dtype()
+        self.init_alpha()
+        self.config()
+
+        X = np.random.random(self.x_shape).astype(self.dtype)
+        Y = np.random.random(self.y_shape).astype(self.dtype)
+        # -0.1 ~ 0.1
+        X = -0.1 + 0.2 * X
+        Y = -0.1 + 0.2 * Y
+
+        Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y,
+                               self.alpha)
+        Out = Out.astype(self.dtype)
+        self.inputs = {'X': X, 'Y': Y}
+        self.attrs = {
+            'transpose_X': self.transpose_X,
+            'transpose_Y': self.transpose_Y,
+            'alpha': self.alpha
+        }
+        self.outputs = {'Out': Out}
+
+    def set_npu(self):
+        self.__class__.use_npu = True
+        self.place = paddle.NPUPlace(0)
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (100, )
+        self.transpose_X = False
+        self.transpose_Y = False
+
+    def init_alpha(self):
+        self.alpha = 1.0
+
+    def init_dtype(self):
+        self.dtype = "float32"
+
+    def test_check_output(self):
+        self.check_output_with_place(self.place, atol=1e-7)
+
+    def test_check_grad_normal(self):
+        self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
+
+
+class TestMatMulOp1(TestMatMulOp):
+    """
+    case x_ndim == 1, y_ndim != 1
+    """
+
+    def config(self):
+        self.x_shape = (100, )
+        self.y_shape = (1, 3, 2, 100)
+        self.transpose_X = False
+        self.transpose_Y = True
+
+
+class TestMatMulOp2(TestMatMulOp):
+    """
+    case x_ndim != 1, y_ndim == 1
+    """
+
+    def config(self):
+        self.x_shape = (1, 2, 100, 1)
+        self.y_shape = (100, )
+        self.transpose_X = True
+        self.transpose_Y = False
+
+
+class TestMatMulOp3(TestMatMulOp):
+    """
+    case [M, K] x [K, N] = [M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 100)
+        self.y_shape = (100, 2)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+class TestMatMulOp4(TestMatMulOp):
+    """
+    case [M, K] x [K, N] = [M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 100)
+        self.y_shape = (2, 100)
+        self.transpose_X = False
+        self.transpose_Y = True
+
+
+class TestMatMulOp5(TestMatMulOp):
+    """
+    case [M, K] x [K, N] = [M, N]
+    """
+
+    def config(self):
+        self.x_shape = (100, 2)
+        self.y_shape = (100, 2)
+        self.transpose_X = True
+        self.transpose_Y = False
+
+
+class TestMatMulOp6(TestMatMulOp):
+    """
+    case [B, M, K] x [K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 2, 25)
+        self.y_shape = (25, 4)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+class TestMatMulOp7(TestMatMulOp):
+    """
+    case [B, M, K] x [K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (1, 2, 25)
+        self.y_shape = (4, 25)
+        self.transpose_X = False
+        self.transpose_Y = True
+
+
+class TestMatMulOp8(TestMatMulOp):
+    """
+    case [B, M, K] x [K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (1, 25, 4)
+        self.y_shape = (25, 4)
+        self.transpose_X = True
+        self.transpose_Y = False
+
+
+class TestMatMulOp9(TestMatMulOp):
+    """
+    case [B, M, K] x [B, K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 5, 10)
+        self.y_shape = (2, 10, 5)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+class TestMatMulOp10(TestMatMulOp):
+    """
+    case [B, M, K] x [B, K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 10, 5)
+        self.y_shape = (2, 10, 5)
+        self.transpose_X = True
+        self.transpose_Y = False
+
+
+class TestMatMulOp11(TestMatMulOp):
+    """
+    case [B, M, K] x [B, K, N] = [B, M, N]
+    """
+
+    def config(self):
+        self.x_shape = (2, 5, 10)
+        self.y_shape = (2, 5, 10)
+        self.transpose_X = False
+        self.transpose_Y = True
+
+
+class TestMatMulOp12(TestMatMulOp):
+    """
+    case to check the gradient for special case
+    """
+
+    def config(self):
+        self.x_shape = (100)
+        self.y_shape = (1, 2, 2, 100, 2)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+class TestMatMulOp13(TestMatMulOp):
+    """
+    case to check the gradient for special case
+    """
+
+    def config(self):
+        self.x_shape = (2, 1, 100)
+        self.y_shape = (100)
+        self.transpose_X = False
+        self.transpose_Y = False
+
+
+#--------------------test matmul alpha--------------------
+def create_test_alpha_class(parent):
+    class TestMatMulOpAlphaCase(parent):
+        def init_alpha(self):
+            self.alpha = 0.125
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Alpha")
+    TestMatMulOpAlphaCase.__name__ = cls_name
+    globals()[cls_name] = TestMatMulOpAlphaCase
+
+
+create_test_alpha_class(TestMatMulOp)
+create_test_alpha_class(TestMatMulOp1)
+create_test_alpha_class(TestMatMulOp2)
+create_test_alpha_class(TestMatMulOp3)
+create_test_alpha_class(TestMatMulOp4)
+create_test_alpha_class(TestMatMulOp5)
+create_test_alpha_class(TestMatMulOp6)
+create_test_alpha_class(TestMatMulOp9)
+create_test_alpha_class(TestMatMulOp10)
+create_test_alpha_class(TestMatMulOp11)
+create_test_alpha_class(TestMatMulOp12)
+create_test_alpha_class(TestMatMulOp13)
+
+
+#--------------------test matmul fp16--------------------
+def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
+    class TestMatMulOpFp16Case(parent):
+        def init_kernel_type(self):
+            self.dtype = np.float16
+
+        def test_check_output(self):
+            self.check_output_with_place(self.place, atol=atol)
+
+        def test_check_grad(self):
+            self.check_grad_with_place(
+                self.place, ['X', 'Y'],
+                'Out',
+                max_relative_error=max_relative_error)
+
+    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
+    TestMatMulOpFp16Case.__name__ = cls_name
+    globals()[cls_name] = TestMatMulOpFp16Case
+
+
+create_test_fp16_class(TestMatMulOp)
+create_test_fp16_class(TestMatMulOp1)
+create_test_fp16_class(TestMatMulOp2)
+create_test_fp16_class(TestMatMulOp3)
+create_test_fp16_class(TestMatMulOp4)
+create_test_fp16_class(TestMatMulOp5)
+create_test_fp16_class(TestMatMulOp6)
+create_test_fp16_class(TestMatMulOp9)
+create_test_fp16_class(TestMatMulOp10)
+create_test_fp16_class(TestMatMulOp11)
+create_test_fp16_class(TestMatMulOp12)
+create_test_fp16_class(TestMatMulOp13)
+
+if __name__ == "__main__":
+    unittest.main()
-- 
GitLab