From 5648bd80d9dc07afea3b93395e61888cdeb40424 Mon Sep 17 00:00:00 2001
From: pangyoki
Date: Mon, 12 Apr 2021 17:12:44 +0800
Subject: [PATCH] [NPU] Remove TensorFromVector and avoid sync copy in npu op
 kernel for better performance (#31994)

* enable async copy and add wait before sync operation
* remove unnecessary wait
* add FillNpuTensorWithConstant
* refine
* fix fill_constant
* change TensorFromVector to FillNpuTensorWithConstant
* fix ignored api
* delete extra unittest
* fix little error
* fix update_loss_scaling_op_npu and check_finite_and_unscale_op_npu
* change TensorCopySync to TensorCopy
* delete useless Wait and add StreamWait
* fix npu_stream error
* fix check_finite_and_unscale_op_npu TensorCopy
* only save stream wait
* fix NPUDeviceContext in all c++ unittest
* delete wait

Co-authored-by: zhiqiu
---
 paddle/fluid/operators/activation_op_npu.cc   |   3 +-
 .../amp/check_finite_and_unscale_op_npu.cc    |  11 +-
 .../amp/update_loss_scaling_op_npu.cc         |   4 +-
 .../elementwise/elementwise_add_op_npu.cc     |  14 +-
 .../elementwise/elementwise_sub_op_npu.cc     |   4 +-
 paddle/fluid/operators/increment_op_npu.cc    |  27 ++--
 paddle/fluid/operators/layer_norm_op_npu.cc   |   9 +-
 .../operators/lookup_table_v2_op_npu_test.cc  | 142 ++++++++++++++++++
 paddle/fluid/operators/mean_op_npu.cc         |  39 ++---
 .../fluid/operators/optimizers/adam_op_npu.cc |  46 +++---
 .../fluid/operators/optimizers/sgd_op_npu.cc  |   5 +-
 paddle/fluid/operators/range_op_npu.cc        |  30 ++--
 .../softmax_with_cross_entropy_op_npu.cc      |  12 +-
 paddle/fluid/operators/top_k_op_npu.cc        |  24 ++-
 .../truncated_gaussian_random_op_npu.cc       |  12 +-
 paddle/fluid/platform/stream/npu_stream.cc    |   1 -
 .../unittests/npu/test_increment_op_npu.py    |  17 ++-
 17 files changed, 255 insertions(+), 145 deletions(-)
 create mode 100644 paddle/fluid/operators/lookup_table_v2_op_npu_test.cc

diff --git a/paddle/fluid/operators/activation_op_npu.cc b/paddle/fluid/operators/activation_op_npu.cc
index 923b581af28..f368c658230 100644
--- a/paddle/fluid/operators/activation_op_npu.cc
+++ b/paddle/fluid/operators/activation_op_npu.cc
@@ -77,8 +77,7 @@ class PowGradNPUKernel : public framework::OpKernel<T> {
     // 2.1 Get a factor tensor with shape [1].
     Tensor factor_tensor(framework::proto::VarType::FP32);
     factor_tensor.mutable_data<float>({1}, place);
-    TensorFromVector(std::vector<float>{factor}, ctx.device_context(),
-                     &factor_tensor);
+    FillNpuTensorWithConstant<float>(&factor_tensor, factor);
 
     // 2.2 Get the factor which has the shape with x and the same value with
     // factor.
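Note on the new helper used throughout this patch: FillNpuTensorWithConstant<T> replaces the old pattern of building a one-element std::vector on the host, pushing it through TensorFromVector, and then blocking on the device context. A minimal sketch of the idea (not the exact Paddle implementation, which appears to live in npu_op_runner.h and handles more cases; aclrtMemcpy and ACL_MEMCPY_HOST_TO_DEVICE are the CANN runtime copy API):

    // Sketch: write a single host scalar into an initialized 1-element
    // device tensor. One tiny host->device copy replaces the vector
    // allocation + TensorFromVector + device-wide Wait() of the old code.
    template <typename T>
    void FillNpuTensorWithConstant(Tensor* tensor, T val) {
      PADDLE_ENFORCE_NPU_SUCCESS(
          aclrtMemcpy(tensor->data<T>(), sizeof(T), &val, sizeof(T),
                      ACL_MEMCPY_HOST_TO_DEVICE));
    }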
diff --git a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc
index 3db45805025..21968dcb05d 100644
--- a/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc
+++ b/paddle/fluid/operators/amp/check_finite_and_unscale_op_npu.cc
@@ -44,10 +44,7 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
     // step1: inverse scale(RealDiv)
     Tensor const_tensor;
     const_tensor.mutable_data<T>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<T>{static_cast<T>(1.0)}, ctx.device_context(),
-                     &const_tensor);
-
-    ctx.template device_context<platform::NPUDeviceContext>().Wait();
+    FillNpuTensorWithConstant<T>(&const_tensor, static_cast<T>(1.0));
 
     // Inverse(1.0/scale)
     Tensor* tmp_inverse_out = const_cast<Tensor*>(scale);
@@ -105,7 +102,11 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
       bool* is_found_inf =
           found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
       *is_found_inf = true;
-      framework::TensorCopySync(found_inf_tensor, ctx.GetPlace(), found_inf);
+
+      framework::TensorCopy(
+          found_inf_tensor, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), found_inf);
+      ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
     }
   }
 };
diff --git a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
index dd6dbfd5c0b..45b28bf61e5 100644
--- a/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
+++ b/paddle/fluid/operators/amp/update_loss_scaling_op_npu.cc
@@ -41,7 +41,7 @@ void Update(const platform::NPUDeviceContext& ctx,
     // bad_out_data = bad_in_data + 1
     Tensor factor_tensor(bad_out_tensor->type());
     factor_tensor.mutable_data<int>({1}, place);
-    TensorFromVector(std::vector<int>{1}, ctx, &factor_tensor);
+    FillNpuTensorWithConstant<int>(&factor_tensor, static_cast<int>(1));
     auto runner_p2 = NpuOpRunner("Add", {*bad_in_tensor, factor_tensor},
                                  {*bad_out_tensor}, {});
     runner_p2.Run(stream);
@@ -84,7 +84,7 @@ void Update(const platform::NPUDeviceContext& ctx,
     // good_out_data = good_in_data + 1
     Tensor factor_tensor(good_out_tensor->type());
     factor_tensor.mutable_data<int>({1}, place);
-    TensorFromVector(std::vector<int>{1}, ctx, &factor_tensor);
+    FillNpuTensorWithConstant<int>(&factor_tensor, static_cast<int>(1));
     auto runner_p2 = NpuOpRunner("Add", {*good_in_tensor, factor_tensor},
                                  {*good_out_tensor}, {});
     runner_p2.Run(stream);
diff --git a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
index 5b8d08a8943..3768748931d 100644
--- a/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_add_op_npu.cc
@@ -100,9 +100,9 @@ class ElementwiseAddGradNPUKernel : public framework::OpKernel<T> {
                                     {{"axes", axes}, {"keep_dims", true}});
         runner.Run(stream);
       } else {
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .Wait();
-        framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
+        framework::TensorCopy(
+            *tmp_dout, ctx.GetPlace(),
+            ctx.template device_context<platform::DeviceContext>(), dx);
       }
     }
 
@@ -127,8 +127,6 @@ class ElementwiseAddGradNPUKernel : public framework::OpKernel<T> {
                                     {{"axes", axes}, {"keep_dims", false}});
         runner.Run(stream);
         tmp_dout = &reduced_dout;
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .Wait();
       }
 
       // stage 2
@@ -144,9 +142,9 @@ class ElementwiseAddGradNPUKernel : public framework::OpKernel<T> {
                                     {{"axes", axes}, {"keep_dims", true}});
         runner.Run(stream);
       } else {
-        ctx.template device_context<paddle::platform::NPUDeviceContext>()
-            .Wait();
-        framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dy);
+        framework::TensorCopy(
+            *tmp_dout, ctx.GetPlace(),
+            ctx.template device_context<platform::DeviceContext>(), dy);
       }
     }
   }
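The TensorCopySync-to-TensorCopy rewrites above all follow one pattern: the copy is enqueued on the kernel's own NPU stream, so stream order already places it after the producing op and before any on-device consumer, and a blocking Wait() survives only where the host reads the copied data (as in the found_inf copy above). A hedged sketch of the pattern, with dev_ctx standing for the kernel's NPUDeviceContext:

    // async: ordered with other work on dev_ctx's stream, returns at once
    framework::TensorCopy(src, dst_place, dev_ctx, &dst);
    // ...more ops may safely be queued on the same stream here...
    dev_ctx.Wait();  // needed only if the host dereferences dst next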
diff --git a/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
index 809445c2862..a6e438f8016 100644
--- a/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
+++ b/paddle/fluid/operators/elementwise/elementwise_sub_op_npu.cc
@@ -102,7 +102,9 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
                                     {{"axes", axes}, {"keep_dims", true}});
         runner.Run(stream);
       } else {
-        framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
+        framework::TensorCopy(
+            *tmp_dout, ctx.GetPlace(),
+            ctx.template device_context<platform::DeviceContext>(), dx);
       }
     }
     if (dy) {
diff --git a/paddle/fluid/operators/increment_op_npu.cc b/paddle/fluid/operators/increment_op_npu.cc
index 90f9787cc38..22c70e99a44 100644
--- a/paddle/fluid/operators/increment_op_npu.cc
+++ b/paddle/fluid/operators/increment_op_npu.cc
@@ -12,10 +12,9 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.
 
-
 #include "paddle/fluid/operators/increment_op.h"
-#include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace framework {
@@ -30,7 +29,6 @@ class OpBase;
 
 namespace paddle {
 namespace operators {
-
 template <typename DeviceContext, typename T>
 class IncrementalNPUKernel : public framework::OpKernel<T> {
  public:
@@ -41,21 +39,15 @@ class IncrementalNPUKernel : public framework::OpKernel<T> {
     out_tensor->mutable_data<T>(context.GetPlace());
 
     Tensor step_tensor(x_tensor->type());
-    std::vector<T> step_vec;
-    step_vec.push_back(static_cast<T>(step));
-    framework::TensorFromVector(
-        step_vec,
-        context.device_context(),
-        &step_tensor);
+    step_tensor.mutable_data<T>({1}, context.GetPlace());
+    FillNpuTensorWithConstant<T>(&step_tensor, static_cast<T>(step));
 
-    auto runner = NpuOpRunner("Add",
-                              {*x_tensor, step_tensor},
-                              {*out_tensor},
-                              {});
+    auto runner =
+        NpuOpRunner("Add", {*x_tensor, step_tensor}, {*out_tensor}, {});
 
     auto stream =
-        context.template device_context<paddle::platform::NPUDeviceContext>()
-        .stream();
+        context.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
     runner.Run(stream);
   }
 };
 
 } // namespace operators
 } // namespace paddle
 
-
 namespace plat = paddle::platform;
 namespace ops = paddle::operators;
 
@@ -73,5 +64,5 @@ REGISTER_OP_NPU_KERNEL(
     increment,
     ops::IncrementalNPUKernel<plat::NPUDeviceContext, float>,
     ops::IncrementalNPUKernel<plat::NPUDeviceContext, double>,
    ops::IncrementalNPUKernel<plat::NPUDeviceContext, int>,
-    ops::IncrementalNPUKernel<plat::NPUDeviceContext, int64_t>)
-
+    ops::IncrementalNPUKernel<plat::NPUDeviceContext, int64_t>)
diff --git a/paddle/fluid/operators/layer_norm_op_npu.cc b/paddle/fluid/operators/layer_norm_op_npu.cc
index 95549319cd2..c0c228ef22a 100644
--- a/paddle/fluid/operators/layer_norm_op_npu.cc
+++ b/paddle/fluid/operators/layer_norm_op_npu.cc
@@ -80,8 +80,7 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
       default_scale.mutable_data<T>(framework::make_ddim(axes), place);
       Tensor value(x->type());
       value.mutable_data<T>({1}, place);
-      TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
-                       ctx.device_context(), &value);
+      FillNpuTensorWithConstant<T>(&value, static_cast<T>(1.0));
       auto runner =
           NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
       runner.Run(stream);
@@ -95,8 +94,7 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
       default_bias.mutable_data<T>(framework::make_ddim(axes), place);
       Tensor value(x->type());
       value.mutable_data<T>({1}, place);
-      TensorFromVector(std::vector<T>{static_cast<T>(0)}, ctx.device_context(),
-                       &value);
+      FillNpuTensorWithConstant<T>(&value, static_cast<T>(0));
       auto runner =
           NpuOpRunner("FillD", {value}, {default_bias}, {{"dims", axes}});
       runner.Run(stream);
@@ -251,8 +249,7 @@ class LayerNormGradNPUKernel : public framework::OpKernel<T> {
       default_scale.mutable_data<T>(framework::make_ddim(axes), place);
       Tensor value(x->type());
       value.mutable_data<T>({1}, place);
-      TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
-                       ctx.device_context(), &value);
+      FillNpuTensorWithConstant<T>(&value, static_cast<T>(1.0));
       auto runner =
           NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
       runner.Run(stream);
diff --git a/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc
new file mode 100644
index 00000000000..028d70b4224
--- /dev/null
+++ b/paddle/fluid/operators/lookup_table_v2_op_npu_test.cc
@@ -0,0 +1,142 @@
+/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License. */
+
+#ifndef _WIN32
+#include <unistd.h>
+#endif
+
+#include <cmath>
+#include <iostream>
+#include <numeric>
+#include <string>
+#include <thread>  // NOLINT
+#include <vector>
+
+#include "gtest/gtest.h"
+#include "paddle/fluid/framework/op_registry.h"
+#include "paddle/fluid/framework/operator.h"
+#include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/operators/dropout_op.h"
+#include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/string/printf.h"
+
+namespace f = paddle::framework;
+namespace p = paddle::platform;
+namespace m = paddle::operators::math;
+
+USE_OP(lookup_table_v2);
+USE_OP_DEVICE_KERNEL(lookup_table_v2, NPU);
+
+template <typename T>
+void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
+  // init
+  auto ids = scope->Var("Ids");
+  auto out = scope->Var("Out");
+  auto w = scope->Var("W");
+
+  auto ids_t = ids->GetMutable<f::LoDTensor>();
+  auto out_t = out->GetMutable<f::LoDTensor>();
+  auto w_t = w->GetMutable<f::LoDTensor>();
+  int bsz = 10;
+  int dim = 32;
+  int seqlen = 8;
+  int vocab_size = 100;
+  TensorFromVector(std::vector<int64_t>(bsz * seqlen, 3), ctx, ids_t);
+  std::vector<T> val(vocab_size * dim, 10.);
+  TensorFromVector(val, ctx, w_t);
+  ids_t->Resize({bsz, seqlen});
+  w_t->Resize({vocab_size, dim});
+  out_t->Resize({bsz, seqlen, dim});
+  ctx.Wait();
+
+  auto place = ctx.GetPlace();
+  out_t->mutable_data<T>(place);
+
+  f::AttributeMap attrs = {{}};
+  auto op = f::OpRegistry::CreateOp("lookup_table_v2",
+                                    {{"W", {"W"}}, {"Ids", {"Ids"}}},
+                                    {{"Out", {"Out"}}}, attrs);
+  op->Run(*scope, place);
+  std::vector<T> out_v;
+  TensorToVector(*out_t, ctx, &out_v);
+  ctx.Wait();
+  EXPECT_EQ(out_t->numel(), bsz * seqlen * dim);
+  T res = std::accumulate(out_v.begin(), out_v.end(), 0.);
+  float eps = 1.e-6;
+  EXPECT_LT(fabs(res - bsz * seqlen * dim * 10.), eps);
+}
+
+template <typename T>
+void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) {
+  // init
+  auto w = scope->Var("W");
+  auto ids = scope->Var("Ids");
+  auto out = scope->Var("DOut");
+  auto dw = scope->Var("DW");
+
+  auto w_t = w->GetMutable<f::LoDTensor>();
+  auto ids_t = ids->GetMutable<f::LoDTensor>();
+  auto out_t = out->GetMutable<f::LoDTensor>();
+  auto dw_t = dw->GetMutable<f::LoDTensor>();
+
+  int bsz = 2;
+  int dim = 2;
+  int seqlen = 2;
+  int vocab_size = 4;
+
+  std::vector<int64_t> val_int(bsz * seqlen, 3);
+  std::vector<T> val(vocab_size * dim, 0.);
+  std::vector<T> val_out(bsz * seqlen * dim, 1.);
+
+  TensorFromVector(val_int, ctx, ids_t);
+  TensorFromVector(val, ctx, w_t);
+  TensorFromVector(val, ctx, dw_t);
+  TensorFromVector(val_out, ctx, out_t);
+
+  w_t->Resize({vocab_size, dim});
+  ids_t->Resize({bsz, seqlen});
+  out_t->Resize({bsz, seqlen, dim});
+  dw_t->Resize({vocab_size, dim});
+
+  ctx.Wait();
+
+  auto place = ctx.GetPlace();
+  out_t->mutable_data<T>(place);
+  w_t->mutable_data<T>(place);
+  dw_t->mutable_data<T>(place);
+
+  f::AttributeMap attrs = {{}};
+  auto op = f::OpRegistry::CreateOp(
+      "lookup_table_v2_grad",
+      {{"Ids", {"Ids"}}, {"W", {"W"}}, {"Out@GRAD", {"DOut"}}},
+      {{"W@GRAD", {"DW"}}}, attrs);
+  op->Run(*scope, place);
+  ctx.Wait();
+  std::vector<T> w_v;
+  TensorToVector(*dw_t, ctx, &w_v);
+  ctx.Wait();
+  EXPECT_EQ(dw_t->numel(), vocab_size * dim);
+  T res = std::accumulate(w_v.begin(), w_v.end(), 0.);
+  float eps = 1.e-6;
+  EXPECT_LT(fabs(res - bsz * seqlen * dim), eps);
+}
+
+TEST(lookup_table_v2, NPU_fp32) {
+  f::Scope scope;
+  auto* ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(0));
+  Compare<float>(&scope, *ctx);
+}
+
+TEST(lookup_table_v2_grad, NPU_fp32) {
+  f::Scope scope;
+  auto* ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(0));
+  CompareGrad<float>(&scope, *ctx);
+}
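A quick check of the new test's expectations: in Compare, W is filled with 10.0 and each of the bsz * seqlen = 80 ids selects a dim = 32 row, so the summed output is 10 * 8 * 32 * 10.0 = 25600, which the final EXPECT_LT asserts to within eps. In CompareGrad, DOut is all ones and each of the bsz * seqlen = 4 lookups scatters dim = 2 ones into DW, giving a sum of bsz * seqlen * dim = 8.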
diff --git a/paddle/fluid/operators/mean_op_npu.cc b/paddle/fluid/operators/mean_op_npu.cc
index a577da80de4..d6e982039fa 100644
--- a/paddle/fluid/operators/mean_op_npu.cc
+++ b/paddle/fluid/operators/mean_op_npu.cc
@@ -10,9 +10,8 @@ See the License for the specific language governing permissions and
 limitations under the License. */
 
 #include "paddle/fluid/operators/mean_op.h"
-#include "paddle/fluid/platform/float16.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
-
+#include "paddle/fluid/platform/float16.h"
 
 namespace paddle {
 namespace operators {
@@ -26,34 +25,27 @@ class MeanNPUKernel : public framework::OpKernel<T> {
 
     std::vector<int> axes;
 
-    framework::NPUAttributeMap attr_input = {
-        {"keep_dims", false},
-        {"axes", axes}};
+    framework::NPUAttributeMap attr_input = {{"keep_dims", false},
+                                             {"axes", axes}};
 
     out->mutable_data<T>(ctx.GetPlace());
 
-    auto runner = NpuOpRunner("ReduceMeanD",
-                              {*x},
-                              {*out},
-                              attr_input);
+    auto runner = NpuOpRunner("ReduceMeanD", {*x}, {*out}, attr_input);
 
     auto stream =
-        ctx.template device_context<
-            paddle::platform::NPUDeviceContext>()
-        .stream();
+        ctx.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
     runner.Run(stream);
   }
 };
 
-
 template <typename DeviceContext, typename T>
 class MeanGradNPUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
     auto stream =
-        context.template device_context<
-            paddle::platform::NPUDeviceContext>()
-        .stream();
+        context.template device_context<paddle::platform::NPUDeviceContext>()
+            .stream();
 
     auto grad = context.Input<Tensor>(framework::GradVarName("Out"));
@@ -76,11 +68,8 @@ class MeanGradNPUKernel : public framework::OpKernel<T> {
     Tensor mean_tensor(grad->type());
     mean_tensor.Resize({1});
     mean_tensor.mutable_data<T>(context.GetPlace());
-    std::vector<T> mean_vec;
-    mean_vec.push_back(1.0 / static_cast<T>(IG->numel()));
-    framework::TensorFromVector(mean_vec,
-                                context.device_context(),
-                                &mean_tensor);
+    FillNpuTensorWithConstant<T>(
+        &mean_tensor, static_cast<T>(1.0 / static_cast<float>(IG->numel())));
 
     // means mul ones
     Tensor mean_ma(grad->type());
@@ -95,23 +84,19 @@ class MeanGradNPUKernel : public framework::OpKernel<T> {
   }
 };
 
-
 } // namespace operators
 } // namespace paddle
 
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_NPU_KERNEL(
-    mean,
-    ops::MeanNPUKernel<plat::NPUDeviceContext, int>,
+    mean, ops::MeanNPUKernel<plat::NPUDeviceContext, int>,
     ops::MeanNPUKernel<plat::NPUDeviceContext, float>,
     ops::MeanNPUKernel<plat::NPUDeviceContext, double>,
     ops::MeanNPUKernel<plat::NPUDeviceContext, plat::float16>)
 
-
 REGISTER_OP_NPU_KERNEL(
-    mean_grad,
-    ops::MeanGradNPUKernel<plat::NPUDeviceContext, int>,
+    mean_grad, ops::MeanGradNPUKernel<plat::NPUDeviceContext, int>,
     ops::MeanGradNPUKernel<plat::NPUDeviceContext, float>,
     ops::MeanGradNPUKernel<plat::NPUDeviceContext, double>,
     ops::MeanGradNPUKernel<plat::NPUDeviceContext, plat::float16>)
diff --git a/paddle/fluid/operators/optimizers/adam_op_npu.cc b/paddle/fluid/operators/optimizers/adam_op_npu.cc
index e2d262ff97d..6592022711e 100644
--- a/paddle/fluid/operators/optimizers/adam_op_npu.cc
+++ b/paddle/fluid/operators/optimizers/adam_op_npu.cc
@@ -61,23 +61,17 @@ class AdamNPUKernel : public framework::OpKernel<T> {
     param_out->mutable_data<T>(ctx.GetPlace());
     mom1_out->mutable_data<T>(ctx.GetPlace());
     mom2_out->mutable_data<T>(ctx.GetPlace());
+    beta1_pow_out->mutable_data<T>(ctx.GetPlace());
+    beta2_pow_out->mutable_data<T>(ctx.GetPlace());
 
     // NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place.
     if (beta1_pow->place() == platform::CPUPlace()) {
-      float beta1 = *beta1_pow->data<float>();
-      beta1_pow_out->mutable_data<T>(ctx.GetPlace());
-      TensorFromVector(std::vector<float>{beta1}, ctx.device_context(),
-                       beta1_pow_out);
-    } else {
-      beta1_pow_out->mutable_data<T>(ctx.GetPlace());
+      T beta1 = *beta1_pow->data<T>();
+      FillNpuTensorWithConstant<T>(beta1_pow_out, beta1);
     }
     if (beta2_pow->place() == platform::CPUPlace()) {
-      float beta2 = *beta2_pow->data<float>();
-      beta2_pow_out->mutable_data<T>(ctx.GetPlace());
-      TensorFromVector(std::vector<float>{beta2}, ctx.device_context(),
-                       beta2_pow_out);
-    } else {
-      beta2_pow_out->mutable_data<T>(ctx.GetPlace());
+      T beta2 = *beta2_pow->data<T>();
+      FillNpuTensorWithConstant<T>(beta2_pow_out, beta2);
     }
 
     T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
@@ -116,18 +110,15 @@ class AdamNPUKernel : public framework::OpKernel<T> {
 
     // reshape
     Tensor beta1_tensor(framework::proto::VarType::FP32);
-    beta1_tensor.mutable_data<float>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<float>{beta1}, ctx.device_context(),
-                     &beta1_tensor);
+    beta1_tensor.mutable_data<T>({1}, ctx.GetPlace());
+    FillNpuTensorWithConstant<T>(&beta1_tensor, beta1);
 
     Tensor beta2_tensor(framework::proto::VarType::FP32);
-    beta2_tensor.mutable_data<float>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<float>{beta2}, ctx.device_context(),
-                     &beta2_tensor);
+    beta2_tensor.mutable_data<T>({1}, ctx.GetPlace());
+    FillNpuTensorWithConstant<T>(&beta2_tensor, beta2);
 
     Tensor epsilon_tensor(framework::proto::VarType::FP32);
     epsilon_tensor.mutable_data<T>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<float>{epsilon}, ctx.device_context(),
-                     &epsilon_tensor);
+    FillNpuTensorWithConstant<T>(&epsilon_tensor, epsilon);
 
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
             .stream();
@@ -146,16 +137,19 @@ class AdamNPUKernel : public framework::OpKernel<T> {
 
     // NOTE(zhiqiu): ApplyAdamD updates params inplace, so
     // if param and param_out is not same, we need to do copy.
     if (param_out->data<T>() != param->data<T>()) {
-      ctx.template device_context<platform::NPUDeviceContext>().Wait();
-      framework::TensorCopySync(*param, ctx.GetPlace(), param_out);
+      framework::TensorCopy(
+          *param, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), param_out);
     }
     if (mom1_out->data<T>() != mom1->data<T>()) {
-      ctx.template device_context<platform::NPUDeviceContext>().Wait();
-      framework::TensorCopySync(*mom1, ctx.GetPlace(), mom1_out);
+      framework::TensorCopy(
+          *mom1, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), mom1_out);
     }
     if (mom2_out->data<T>() != mom2->data<T>()) {
-      ctx.template device_context<platform::NPUDeviceContext>().Wait();
-      framework::TensorCopySync(*mom2, ctx.GetPlace(), mom2_out);
+      framework::TensorCopy(
+          *mom2, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), mom2_out);
     }
     auto runner_m1 =
         NpuOpRunner("Mul", {*beta1_pow, beta1_tensor}, {*beta1_pow_out}, {});
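Why the Wait() before each copy could be dropped in AdamNPUKernel: the TensorCopy calls and the ApplyAdamD / Mul runners are all issued to the same NPU stream, so FIFO stream order already guarantees each copy finishes before any later op touches param_out, mom1_out, or mom2_out, and nothing in this kernel reads them back on the host. A sketch of the ordering argument (assuming, as in this file, a single per-device stream):

    framework::TensorCopy(*param, place, dev_ctx, param_out);  // queued 1st
    runner_m1.Run(stream);  // queued 2nd on the same stream; cannot overtake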
diff --git a/paddle/fluid/operators/optimizers/sgd_op_npu.cc b/paddle/fluid/operators/optimizers/sgd_op_npu.cc
index b7aaff5d457..a8d19148ef5 100644
--- a/paddle/fluid/operators/optimizers/sgd_op_npu.cc
+++ b/paddle/fluid/operators/optimizers/sgd_op_npu.cc
@@ -44,8 +44,9 @@ class SGDNPUKernel : public framework::OpKernel<T> {
     // NOTE(zhiqiu): ApplyGradientDescent updates params inplace, so
     // if param and param_out is not same, we need to do copy.
     if (param_out->data<T>() != param_var->data<T>()) {
-      ctx.template device_context<platform::NPUDeviceContext>().Wait();
-      framework::TensorCopySync(*param_var, ctx.GetPlace(), param_out);
+      framework::TensorCopy(
+          *param_var, ctx.GetPlace(),
+          ctx.template device_context<platform::DeviceContext>(), param_out);
     }
   }
 };
diff --git a/paddle/fluid/operators/range_op_npu.cc b/paddle/fluid/operators/range_op_npu.cc
index acdc092ade3..a9a2effd2eb 100644
--- a/paddle/fluid/operators/range_op_npu.cc
+++ b/paddle/fluid/operators/range_op_npu.cc
@@ -16,20 +16,19 @@ limitations under the License. */
 
 #include <memory>
 #include <string>
 
-#include "paddle/fluid/operators/range_op.h"
-#include "paddle/fluid/operators/npu_op_runner.h"
-#include "paddle/fluid/operators/utils.h"
 #include "paddle/fluid/framework/op_registry.h"
-#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/framework/operator.h"
 #include "paddle/fluid/framework/program_desc.h"
+#include "paddle/fluid/framework/tensor_util.h"
 #include "paddle/fluid/operators/dropout_op.h"
 #include "paddle/fluid/operators/math/math_function.h"
+#include "paddle/fluid/operators/npu_op_runner.h"
+#include "paddle/fluid/operators/range_op.h"
+#include "paddle/fluid/operators/utils.h"
 
 namespace paddle {
 namespace operators {
-
 template <typename T>
 class RangeNPUKernel : public framework::OpKernel<T> {
  public:
@@ -40,11 +39,23 @@ class RangeNPUKernel : public framework::OpKernel<T> {
     auto* out = context.Output<framework::Tensor>("Out");
 
     framework::Tensor n;
-    framework::TensorCopySync(*start_t, platform::CPUPlace(), &n);
+    framework::TensorCopy(
+        *start_t, platform::CPUPlace(),
+        context.template device_context<platform::DeviceContext>(), &n);
+    context.template device_context<paddle::platform::NPUDeviceContext>()
+        .Wait();
     T start = n.data<T>()[0];
-    framework::TensorCopySync(*end_t, platform::CPUPlace(), &n);
+    framework::TensorCopy(
+        *end_t, platform::CPUPlace(),
+        context.template device_context<platform::DeviceContext>(), &n);
+    context.template device_context<paddle::platform::NPUDeviceContext>()
+        .Wait();
     T end = n.data<T>()[0];
-    framework::TensorCopySync(*step_t, platform::CPUPlace(), &n);
+    framework::TensorCopy(
+        *step_t, platform::CPUPlace(),
+        context.template device_context<platform::DeviceContext>(), &n);
+    context.template device_context<paddle::platform::NPUDeviceContext>()
+        .Wait();
     T step = n.data<T>()[0];
 
     int64_t size = 0;
@@ -70,8 +81,7 @@ class RangeNPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 REGISTER_OP_NPU_KERNEL(
-    range,
-    ops::RangeNPUKernel<int>,
+    range, ops::RangeNPUKernel<int>,
     ops::RangeNPUKernel<float>,
     ops::RangeNPUKernel<double>)
diff --git a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc
index c777a02f96b..a34946315f5 100644
--- a/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc
+++ b/paddle/fluid/operators/softmax_with_cross_entropy_op_npu.cc
@@ -67,12 +67,10 @@ class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel<T> {
     // on and off
     Tensor on_tensor(framework::proto::VarType::INT32);
     on_tensor.mutable_data<int>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<int>{static_cast<int>(1)},
-                     ctx.device_context(), &on_tensor);
+    FillNpuTensorWithConstant<int>(&on_tensor, static_cast<int>(1));
     Tensor off_tensor(framework::proto::VarType::INT32);
     off_tensor.mutable_data<int>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<int>{static_cast<int>(0)},
-                     ctx.device_context(), &off_tensor);
+    FillNpuTensorWithConstant<int>(&off_tensor, static_cast<int>(0));
 
     // one_hot
     Tensor tmp_onehot(on_tensor.type());
@@ -142,12 +140,10 @@ class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel<T> {
     // on and off
     Tensor on_tensor(framework::proto::VarType::INT32);
     on_tensor.mutable_data<int>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<int>{static_cast<int>(1)},
-                     ctx.device_context(), &on_tensor);
+    FillNpuTensorWithConstant<int>(&on_tensor, static_cast<int>(1));
     Tensor off_tensor(framework::proto::VarType::INT32);
     off_tensor.mutable_data<int>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<int>{static_cast<int>(0)},
-                     ctx.device_context(), &off_tensor);
+    FillNpuTensorWithConstant<int>(&off_tensor, static_cast<int>(0));
 
     // one_hot
     Tensor tmp_onehot(on_tensor.type());
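RangeNPUKernel above is the counterexample to the delete-the-Wait rule: its Wait() calls are load-bearing, because each device-to-host copy is immediately followed by a host-side read of the CPU buffer. Annotated for emphasis (the comments are editorial, not from the source):

    framework::TensorCopy(
        *start_t, platform::CPUPlace(),
        context.template device_context<platform::DeviceContext>(), &n);
    context.template device_context<paddle::platform::NPUDeviceContext>()
        .Wait();  // required: the dereference below must see the copied value
    T start = n.data<T>()[0];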
diff --git a/paddle/fluid/operators/top_k_op_npu.cc b/paddle/fluid/operators/top_k_op_npu.cc
index a4690133410..684bd476b6e 100644
--- a/paddle/fluid/operators/top_k_op_npu.cc
+++ b/paddle/fluid/operators/top_k_op_npu.cc
@@ -12,14 +12,14 @@ limitations under the License. */
 
 #include <memory>
 #include <string>
 
-#include "paddle/fluid/operators/top_k_op.h"
 #include "paddle/fluid/operators/npu_op_runner.h"
+#include "paddle/fluid/operators/top_k_op.h"
 
 namespace paddle {
 namespace operators {
 
-void gen_assist_seq(framework::Tensor* assit_tensor,
-                    int64_t dim, const framework::ExecutionContext& ctx) {
+void gen_assist_seq(framework::Tensor* assit_tensor, int64_t dim,
+                    const framework::ExecutionContext& ctx) {
   const int64_t dimx2 = dim;
   std::vector<paddle::platform::float16> assit;
   assit.resize(2 * dimx2);
@@ -28,15 +28,14 @@ void gen_assist_seq(framework::Tensor* assit_tensor, int64_t dim,
     assit[i] = static_cast<paddle::platform::float16>(i);
 
     // for i in range [dim, dimx2]
-    int64_t idx = static_cast<int64_t>(
-        static_cast<paddle::platform::float16>(i));
+    int64_t idx =
+        static_cast<int64_t>(static_cast<paddle::platform::float16>(i));
     int64_t gap = i - idx;
     assit[i + dim] = static_cast<paddle::platform::float16>(gap);
   }
   framework::TensorFromVector(assit, ctx.device_context(), assit_tensor);
 }
 
-
 template <typename DeviceContext, typename T>
 class TopkNPUKernel : public framework::OpKernel<T> {
  public:
@@ -64,10 +63,8 @@ class TopkNPUKernel : public framework::OpKernel<T> {
                                        {"largest", true}};
 
     // run ascend
-    auto runner = NpuOpRunner("TopKD",
-                              {*input, assist_seq_tensor},
-                              {*output, *indices},
-                              attr_input);
+    auto runner = NpuOpRunner("TopKD", {*input, assist_seq_tensor},
+                              {*output, *indices}, attr_input);
 
     auto stream =
         ctx.template device_context<paddle::platform::NPUDeviceContext>()
            .stream();
@@ -83,7 +80,6 @@ class TopkNPUKernel : public framework::OpKernel<T> {
 namespace ops = paddle::operators;
 
 // Ascend Op TopKD only support input float 16 dtype
-REGISTER_OP_NPU_KERNEL(
-    top_k,
-    ops::TopkNPUKernel<paddle::platform::NPUDeviceContext, paddle::platform::float16>);
+REGISTER_OP_NPU_KERNEL(top_k,
+                       ops::TopkNPUKernel<paddle::platform::NPUDeviceContext, paddle::platform::float16>);
diff --git a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc
index 4253187fdde..7f3190d9112 100644
--- a/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc
+++ b/paddle/fluid/operators/truncated_gaussian_random_op_npu.cc
@@ -35,28 +35,24 @@ class TruncatedGaussianRandomNPUKernel : public framework::OpKernel<T> {
     float mean = ctx.Attr<float>("mean");
     Tensor mean_tensor(framework::proto::VarType::FP32);
     mean_tensor.mutable_data<float>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<float>{mean}, ctx.device_context(),
-                     &mean_tensor);
+    FillNpuTensorWithConstant<float>(&mean_tensor, mean);
 
     float std = ctx.Attr<float>("std");
     Tensor std_tensor(framework::proto::VarType::FP32);
     std_tensor.mutable_data<float>({1}, ctx.GetPlace());
-    TensorFromVector(std::vector<float>{std}, ctx.device_context(),
-                     &std_tensor);
+    FillNpuTensorWithConstant<float>(&std_tensor, std);
 
     int32_t seed_var = ctx.Attr<int32_t>("seed");
 
     Tensor min_tensor(framework::proto::VarType::FP32);
     min_tensor.mutable_data<float>({1}, ctx.GetPlace());
     float min_value = mean - std * 2.0;
-    TensorFromVector(std::vector<float>{min_value}, ctx.device_context(),
-                     &min_tensor);
+    FillNpuTensorWithConstant<float>(&min_tensor, min_value);
 
     Tensor max_tensor(framework::proto::VarType::FP32);
     max_tensor.mutable_data<float>({1}, ctx.GetPlace());
     float max_value = mean + std * 2.0;
-    TensorFromVector(std::vector<float>{max_value}, ctx.device_context(),
-                     &max_tensor);
+    FillNpuTensorWithConstant<float>(&max_tensor, max_value);
 
     auto* out = ctx.Output<framework::Tensor>("Out");
     out->mutable_data<T>(ctx.GetPlace());
diff --git a/paddle/fluid/platform/stream/npu_stream.cc b/paddle/fluid/platform/stream/npu_stream.cc
index 1a07a1ed837..2664ac7194b 100644
--- a/paddle/fluid/platform/stream/npu_stream.cc
+++ b/paddle/fluid/platform/stream/npu_stream.cc
@@ -46,7 +46,6 @@ void NPUStream::Wait() const {
   PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_));
 }
 
-
 } // namespace stream
 } // namespace platform
 } // namespace paddle
diff --git a/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py b/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py
index 09019e36c82..3e2e8f944b8 100644
--- a/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py
+++ b/python/paddle/fluid/tests/unittests/npu/test_increment_op_npu.py
@@ -26,7 +26,7 @@ from paddle.fluid import core
 paddle.enable_static()
 SEED = 2021
-NPUPlace = 5
+NPUPlace = 0
 
 
 @unittest.skipIf(not paddle.is_compiled_with_npu(),
@@ -38,7 +38,10 @@ class TestIncrement(OpTest):
         self.op_type = "increment"
         self.init_dtype()
 
-        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)), }
+        self.inputs = {
+            'X':
+            OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)),
+        }
 
         self.attrs = {"Step": 1}
         self.outputs = {'Out': np.array([2])}
@@ -63,7 +66,10 @@ class TestIncrementFP16(OpTest):
         self.op_type = "increment"
         self.init_dtype()
 
-        self.inputs = {'X': OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)), }
+        self.inputs = {
+            'X':
+            OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)),
+        }
         self.pre_input_id = id(self.inputs['X'])
 
         self.attrs = {"Step": 1}
@@ -100,10 +106,7 @@ class TestIncrementInplace(unittest.TestCase):
         exe = paddle.static.Executor(place)
         exe.run(startup_prog)
 
-        b_value = exe.run(
-            main_prog,
-            feed={"a": a_np,},
-            fetch_list=[b])
+        b_value = exe.run(main_prog, feed={"a": a_np, }, fetch_list=[b])
 
         print('input a id is : {}'.format(id(a)))
         print('input b id is : {}'.format(id(b)))
-- 
GitLab