Unverified  Commit 5648bd80 authored by pangyoki, committed by GitHub

[NPU] Remove TensorFromVector and avoid sync copy in npu op kernel for better performance (#31994)

* enable async copy and add wait before sync operation

* remove unnecessary wait

* add FillNpuTensorWithConstant

* refine

* fix fill_constant

* change TensorFromVector to FillNpuTensorWithConstant

* fix ignored api

* delete extra unittest

* fix minor error

* fix update_loss_scaling_op_npu and check_finite_and_unscale_op_npu

* change TensorCopySync to TensorCopy

* delete useless Wait and add StreamWait

* fix npu_stream error

* fix check_finite_and_unscale_op_npu TensorCopy

* keep only the stream wait

* fix NPUDeviceContext in all c++ unittest

* delete wait
Co-authored-by: zhiqiu <chenqiuliang@baidu.com>
Parent 5ad94e7b
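Editor's note: the commit applies two recurring substitutions across the NPU kernels shown below. The fragment that follows is an illustrative sketch, not code from the commit: it reuses the helpers that appear in the hunks (FillNpuTensorWithConstant, framework::TensorCopy) and assumes it runs inside an op kernel's Compute() with an ExecutionContext named ctx; the names value, src, and dst are hypothetical.

// Pattern 1: scalar initialization.
// Before, a scalar was staged in a host std::vector, copied to the device,
// and the whole device context was blocked afterwards:
//   TensorFromVector(std::vector<float>{value}, ctx.device_context(), &t);
//   ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
// After, the constant is written into the NPU tensor directly:
Tensor t(framework::proto::VarType::FP32);
t.mutable_data<float>({1}, ctx.GetPlace());
FillNpuTensorWithConstant<float>(&t, value);

// Pattern 2: tensor copies.
// Before, TensorCopySync blocked the host until the copy finished:
//   framework::TensorCopySync(src, ctx.GetPlace(), dst);
// After, the copy is enqueued asynchronously on the kernel's device context,
// and Wait() is kept only where correctness requires it (the source is a
// short-lived host buffer, or the host reads the result next):
framework::TensorCopy(
    src, ctx.GetPlace(),
    ctx.template device_context<platform::DeviceContext>(), dst);
ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();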
@@ -77,8 +77,7 @@ class PowGradNPUKernel : public framework::OpKernel<T> {
// 2.1 Get a factor tensor with shape [1].
Tensor factor_tensor(framework::proto::VarType::FP32);
factor_tensor.mutable_data<float>({1}, place);
- TensorFromVector(std::vector<float>{factor}, ctx.device_context(),
-                  &factor_tensor);
+ FillNpuTensorWithConstant<float>(&factor_tensor, factor);
// 2.2 Get the factor which has the shape with x and the same value with
// factor.
......
@@ -44,10 +44,7 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
// step1: inverse scale(RealDiv)
Tensor const_tensor;
const_tensor.mutable_data<T>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<T>{static_cast<T>(1.0)}, ctx.device_context(),
-                  &const_tensor);
- ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
+ FillNpuTensorWithConstant<T>(&const_tensor, static_cast<T>(1.0));
// Inverse(1.0/scale)
Tensor* tmp_inverse_out = const_cast<Tensor*>(scale);
@@ -105,7 +102,11 @@ class CheckFiniteAndUnscaleNPUKernel : public framework::OpKernel<T> {
bool* is_found_inf =
found_inf_tensor.mutable_data<bool>(paddle::platform::CPUPlace());
*is_found_inf = true;
- framework::TensorCopySync(found_inf_tensor, ctx.GetPlace(), found_inf);
+ framework::TensorCopy(
+     found_inf_tensor, ctx.GetPlace(),
+     ctx.template device_context<platform::DeviceContext>(), found_inf);
+ ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
}
}
};
......
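Editor's note on the hunk above: found_inf_tensor is a stack-local CPU tensor, so the now-asynchronous copy to the device must be followed by Wait() before the buffer goes out of scope. A minimal sketch of the lifetime hazard, with hypothetical names (host_flag, device_flag, dev_ctx):

{
  framework::Tensor host_flag;  // temporary CPU tensor, dies at end of scope
  host_flag.Resize({1});
  *host_flag.mutable_data<bool>(platform::CPUPlace()) = true;
  framework::TensorCopy(host_flag, ctx.GetPlace(), dev_ctx, device_flag);
  dev_ctx.Wait();  // keeps host_flag alive until the async H2D copy finishes
}  // without the Wait(), host_flag could be freed while the copy is in flight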
@@ -41,7 +41,7 @@ void Update(const platform::NPUDeviceContext& ctx,
// bad_out_data = bad_in_data + 1
Tensor factor_tensor(bad_out_tensor->type());
factor_tensor.mutable_data<int>({1}, place);
- TensorFromVector(std::vector<int>{1}, ctx, &factor_tensor);
+ FillNpuTensorWithConstant<int>(&factor_tensor, static_cast<int>(1));
auto runner_p2 = NpuOpRunner("Add", {*bad_in_tensor, factor_tensor},
{*bad_out_tensor}, {});
runner_p2.Run(stream);
@@ -84,7 +84,7 @@ void Update(const platform::NPUDeviceContext& ctx,
// good_out_data = good_in_data + 1
Tensor factor_tensor(good_out_tensor->type());
factor_tensor.mutable_data<int>({1}, place);
- TensorFromVector(std::vector<int>{1}, ctx, &factor_tensor);
+ FillNpuTensorWithConstant<int>(&factor_tensor, static_cast<int>(1));
auto runner_p2 = NpuOpRunner("Add", {*good_in_tensor, factor_tensor},
{*good_out_tensor}, {});
runner_p2.Run(stream);
......
@@ -100,9 +100,9 @@ class ElementwiseAddGradNPUKernel : public framework::OpKernel<T> {
{{"axes", axes}, {"keep_dims", true}});
runner.Run(stream);
} else {
- ctx.template device_context<paddle::platform::NPUDeviceContext>()
-     .Wait();
- framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
+ framework::TensorCopy(
+     *tmp_dout, ctx.GetPlace(),
+     ctx.template device_context<platform::DeviceContext>(), dx);
}
}
@@ -127,8 +127,6 @@ class ElementwiseAddGradNPUKernel : public framework::OpKernel<T> {
{{"axes", axes}, {"keep_dims", false}});
runner.Run(stream);
tmp_dout = &reduced_dout;
- ctx.template device_context<paddle::platform::NPUDeviceContext>()
-     .Wait();
}
// stage 2
@@ -144,9 +142,9 @@ class ElementwiseAddGradNPUKernel : public framework::OpKernel<T> {
{{"axes", axes}, {"keep_dims", true}});
runner.Run(stream);
} else {
- ctx.template device_context<paddle::platform::NPUDeviceContext>()
-     .Wait();
- framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dy);
+ framework::TensorCopy(
+     *tmp_dout, ctx.GetPlace(),
+     ctx.template device_context<platform::DeviceContext>(), dy);
}
}
}
......
@@ -102,7 +102,9 @@ class ElementwiseSubGradNPUKernel : public framework::OpKernel<T> {
{{"axes", axes}, {"keep_dims", true}});
runner.Run(stream);
} else {
- framework::TensorCopySync(*tmp_dout, ctx.GetPlace(), dx);
+ framework::TensorCopy(
+     *tmp_dout, ctx.GetPlace(),
+     ctx.template device_context<platform::DeviceContext>(), dx);
}
}
if (dy) {
......
@@ -12,10 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/increment_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace framework {
@@ -30,7 +29,6 @@ class OpBase;
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class IncrementalNPUKernel : public framework::OpKernel<T> {
public:
@@ -41,21 +39,15 @@ class IncrementalNPUKernel : public framework::OpKernel<T> {
out_tensor->mutable_data<T>(context.GetPlace());
Tensor step_tensor(x_tensor->type());
- std::vector<T> step_vec;
- step_vec.push_back(static_cast<T>(step));
- framework::TensorFromVector(step_vec, context.device_context(),
-                             &step_tensor);
+ step_tensor.mutable_data<T>({1}, context.GetPlace());
+ FillNpuTensorWithConstant<T>(&step_tensor, static_cast<T>(step));
auto runner =
    NpuOpRunner("Add", {*x_tensor, step_tensor}, {*out_tensor}, {});
auto stream =
    context.template device_context<paddle::platform::NPUDeviceContext>()
        .stream();
runner.Run(stream);
}
};
@@ -63,7 +55,6 @@ class IncrementalNPUKernel : public framework::OpKernel<T> {
} // namespace operators
} // namespace paddle
namespace plat = paddle::platform;
namespace ops = paddle::operators;
@@ -73,5 +64,5 @@ REGISTER_OP_NPU_KERNEL(
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, double>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext, int64_t>,
ops::IncrementalNPUKernel<paddle::platform::NPUDeviceContext,
                          plat::float16>)
@@ -80,8 +80,7 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
default_scale.mutable_data<T>(framework::make_ddim(axes), place);
Tensor value(x->type());
value.mutable_data<T>({1}, place);
- TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
-                  ctx.device_context(), &value);
+ FillNpuTensorWithConstant<T>(&value, static_cast<T>(1.0));
auto runner =
NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
runner.Run(stream);
@@ -95,8 +94,7 @@ class LayerNormNPUKernel : public framework::OpKernel<T> {
default_bias.mutable_data<T>(framework::make_ddim(axes), place);
Tensor value(x->type());
value.mutable_data<T>({1}, place);
- TensorFromVector(std::vector<T>{static_cast<T>(0)}, ctx.device_context(),
-                  &value);
+ FillNpuTensorWithConstant<T>(&value, static_cast<T>(0));
auto runner =
NpuOpRunner("FillD", {value}, {default_bias}, {{"dims", axes}});
runner.Run(stream);
@@ -251,8 +249,7 @@ class LayerNormGradNPUKernel : public framework::OpKernel<T> {
default_scale.mutable_data<T>(framework::make_ddim(axes), place);
Tensor value(x->type());
value.mutable_data<T>({1}, place);
- TensorFromVector(std::vector<T>{static_cast<T>(1.0)},
-                  ctx.device_context(), &value);
+ FillNpuTensorWithConstant<T>(&value, static_cast<T>(1.0));
auto runner =
NpuOpRunner("FillD", {value}, {default_scale}, {{"dims", axes}});
runner.Run(stream);
......
/* Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifndef _WIN32
#include <unistd.h>
#endif
#include <cmath>
#include <iostream>
#include <numeric>
#include <string>
#include <thread> // NOLINT
#include <vector>
#include "gtest/gtest.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/string/printf.h"
namespace f = paddle::framework;
namespace p = paddle::platform;
namespace m = paddle::operators::math;
USE_OP(lookup_table_v2);
USE_OP_DEVICE_KERNEL(lookup_table_v2, NPU);
template <typename T>
void Compare(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto ids = scope->Var("Ids");
auto out = scope->Var("Out");
auto w = scope->Var("W");
auto ids_t = ids->GetMutable<f::LoDTensor>();
auto out_t = out->GetMutable<f::LoDTensor>();
auto w_t = w->GetMutable<f::LoDTensor>();
int bsz = 10;
int dim = 32;
int seqlen = 8;
int vocab_size = 100;
TensorFromVector(std::vector<int64_t>(bsz * seqlen, 3), ctx, ids_t);
std::vector<T> val(vocab_size * dim, 10.);
TensorFromVector(val, ctx, w_t);
ids_t->Resize({bsz, seqlen});
w_t->Resize({vocab_size, dim});
out_t->Resize({bsz, seqlen, dim});
ctx.Wait();
auto place = ctx.GetPlace();
out_t->mutable_data<T>(place);
f::AttributeMap attrs = {{}};
auto op = f::OpRegistry::CreateOp("lookup_table_v2",
{{"W", {"W"}}, {"Ids", {"Ids"}}},
{{"Out", {"Out"}}}, attrs);
op->Run(*scope, place);
std::vector<T> out_v;
TensorToVector(*out_t, ctx, &out_v);
ctx.Wait();
EXPECT_EQ(out_t->numel(), bsz * seqlen * dim);
T res = std::accumulate(out_v.begin(), out_v.end(), 0.);
float eps = 1.e-6;
EXPECT_LT(fabs(res - bsz * seqlen * dim * 10.), eps);
}
template <typename T>
void CompareGrad(f::Scope* scope, const p::DeviceContext& ctx) {
// init
auto w = scope->Var("W");
auto ids = scope->Var("Ids");
auto out = scope->Var("DOut");
auto dw = scope->Var("DW");
auto w_t = w->GetMutable<f::LoDTensor>();
auto ids_t = ids->GetMutable<f::LoDTensor>();
auto out_t = out->GetMutable<f::LoDTensor>();
auto dw_t = dw->GetMutable<f::LoDTensor>();
int bsz = 2;
int dim = 2;
int seqlen = 2;
int vocab_size = 4;
std::vector<int64_t> val_int(bsz * seqlen, 3);
std::vector<T> val(vocab_size * dim, 0.);
std::vector<T> val_out(bsz * seqlen * dim, 1.);
TensorFromVector(val_int, ctx, ids_t);
TensorFromVector(val, ctx, w_t);
TensorFromVector(val, ctx, dw_t);
TensorFromVector(val_out, ctx, out_t);
w_t->Resize({vocab_size, dim});
ids_t->Resize({bsz, seqlen});
out_t->Resize({bsz, seqlen, dim});
dw_t->Resize({vocab_size, dim});
ctx.Wait();
auto place = ctx.GetPlace();
out_t->mutable_data<T>(place);
w_t->mutable_data<T>(place);
dw_t->mutable_data<T>(place);
f::AttributeMap attrs = {{}};
auto op = f::OpRegistry::CreateOp(
"lookup_table_v2_grad",
{{"Ids", {"Ids"}}, {"W", {"W"}}, {"Out@GRAD", {"DOut"}}},
{{"W@GRAD", {"DW"}}}, attrs);
op->Run(*scope, place);
ctx.Wait();
std::vector<T> w_v;
TensorToVector(*dw_t, ctx, &w_v);
ctx.Wait();
EXPECT_EQ(dw_t->numel(), vocab_size * dim);
T res = std::accumulate(w_v.begin(), w_v.end(), 0.);
float eps = 1.e-6;
EXPECT_LT(fabs(res - bsz * seqlen * dim), eps);
}
TEST(lookup_table_v2, NPU_fp32) {
f::Scope scope;
auto* ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(0));
Compare<float>(&scope, *ctx);
}
TEST(lookup_table_v2_grad, NPU_fp32) {
f::Scope scope;
auto* ctx = p::DeviceContextPool::Instance().Get(p::NPUPlace(0));
CompareGrad<float>(&scope, *ctx);
}
@@ -10,9 +10,8 @@ See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/mean_op.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/platform/float16.h"
namespace paddle {
namespace operators {
@@ -26,34 +25,27 @@ class MeanNPUKernel : public framework::OpKernel<T> {
std::vector<int> axes;
framework::NPUAttributeMap attr_input = {{"keep_dims", false},
                                         {"axes", axes}};
out->mutable_data<T>(ctx.GetPlace());
auto runner = NpuOpRunner("ReduceMeanD",
{*x},
{*out},
attr_input);
auto runner = NpuOpRunner("ReduceMeanD", {*x}, {*out}, attr_input);
auto stream =
    ctx.template device_context<paddle::platform::NPUDeviceContext>()
        .stream();
runner.Run(stream);
}
};
template <typename DeviceContext, typename T>
class MeanGradNPUKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto stream =
    context.template device_context<paddle::platform::NPUDeviceContext>()
        .stream();
auto grad = context.Input<Tensor>(framework::GradVarName("Out"));
@@ -76,11 +68,8 @@ class MeanGradNPUKernel : public framework::OpKernel<T> {
Tensor mean_tensor(grad->type());
mean_tensor.Resize({1});
mean_tensor.mutable_data<T>(context.GetPlace());
- std::vector<float> mean_vec;
- mean_vec.push_back(1.0 / static_cast<float>(IG->numel()));
- framework::TensorFromVector(mean_vec, context.device_context(),
-                             &mean_tensor);
+ FillNpuTensorWithConstant<T>(
+     &mean_tensor, static_cast<T>(1.0 / static_cast<float>(IG->numel())));
// means mul ones
Tensor mean_ma(grad->type());
@@ -95,23 +84,19 @@ class MeanGradNPUKernel : public framework::OpKernel<T> {
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(
mean, ops::MeanNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::MeanNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::MeanNPUKernel<paddle::platform::NPUDeviceContext, double>,
ops::MeanNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>)
REGISTER_OP_NPU_KERNEL(
mean_grad, ops::MeanGradNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::MeanGradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::MeanGradNPUKernel<paddle::platform::NPUDeviceContext, double>,
ops::MeanGradNPUKernel<paddle::platform::NPUDeviceContext, plat::float16>)
@@ -61,23 +61,17 @@ class AdamNPUKernel : public framework::OpKernel<T> {
param_out->mutable_data<T>(ctx.GetPlace());
mom1_out->mutable_data<T>(ctx.GetPlace());
mom2_out->mutable_data<T>(ctx.GetPlace());
+ beta1_pow_out->mutable_data<T>(ctx.GetPlace());
+ beta2_pow_out->mutable_data<T>(ctx.GetPlace());
// NOTE(zhiqiu): beta1_pow and beta2_pow may on CPU and not transform place.
if (beta1_pow->place() == platform::CPUPlace()) {
- float beta1 = *beta1_pow->data<float>();
- beta1_pow_out->mutable_data<T>(ctx.GetPlace());
- TensorFromVector(std::vector<float>{beta1}, ctx.device_context(),
-                  beta1_pow_out);
- } else {
- beta1_pow_out->mutable_data<T>(ctx.GetPlace());
+ T beta1 = *beta1_pow->data<T>();
+ FillNpuTensorWithConstant<T>(beta1_pow_out, beta1);
}
if (beta2_pow->place() == platform::CPUPlace()) {
- float beta2 = *beta2_pow->data<float>();
- beta2_pow_out->mutable_data<T>(ctx.GetPlace());
- TensorFromVector(std::vector<float>{beta2}, ctx.device_context(),
-                  beta2_pow_out);
- } else {
- beta2_pow_out->mutable_data<T>(ctx.GetPlace());
+ T beta2 = *beta2_pow->data<T>();
+ FillNpuTensorWithConstant<T>(beta2_pow_out, beta2);
}
T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
@@ -116,18 +110,15 @@ class AdamNPUKernel : public framework::OpKernel<T> {
// reshape
Tensor beta1_tensor(framework::proto::VarType::FP32);
- beta1_tensor.mutable_data<float>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<T>{beta1}, ctx.device_context(),
-                  &beta1_tensor);
+ beta1_tensor.mutable_data<T>({1}, ctx.GetPlace());
+ FillNpuTensorWithConstant<T>(&beta1_tensor, beta1);
Tensor beta2_tensor(framework::proto::VarType::FP32);
- beta2_tensor.mutable_data<float>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<T>{beta2}, ctx.device_context(),
-                  &beta2_tensor);
+ beta2_tensor.mutable_data<T>({1}, ctx.GetPlace());
+ FillNpuTensorWithConstant<T>(&beta2_tensor, beta2);
Tensor epsilon_tensor(framework::proto::VarType::FP32);
epsilon_tensor.mutable_data<T>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<T>{epsilon}, ctx.device_context(),
-                  &epsilon_tensor);
+ FillNpuTensorWithConstant<T>(&epsilon_tensor, epsilon);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
@@ -146,16 +137,19 @@ class AdamNPUKernel : public framework::OpKernel<T> {
// NOTE(zhiqiu): ApplyAdamD updates params inplace, so
// if param and param_out is not same, we need to do copy.
if (param_out->data<T>() != param->data<T>()) {
- ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
- framework::TensorCopySync(*param, ctx.GetPlace(), param_out);
+ framework::TensorCopy(
+     *param, ctx.GetPlace(),
+     ctx.template device_context<platform::DeviceContext>(), param_out);
}
if (mom1_out->data<T>() != mom1->data<T>()) {
- ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
- framework::TensorCopySync(*mom1, ctx.GetPlace(), mom1_out);
+ framework::TensorCopy(
+     *mom1, ctx.GetPlace(),
+     ctx.template device_context<platform::DeviceContext>(), mom1_out);
}
if (mom2_out->data<T>() != mom2->data<T>()) {
- ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
- framework::TensorCopySync(*mom2, ctx.GetPlace(), mom2_out);
+ framework::TensorCopy(
+     *mom2, ctx.GetPlace(),
+     ctx.template device_context<platform::DeviceContext>(), mom2_out);
}
auto runner_m1 =
NpuOpRunner("Mul", {*beta1_pow, beta1_tensor}, {*beta1_pow_out}, {});
......
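Editor's note on the Adam hunks above: the Wait() before each copy can be dropped because source and destination both live on the NPU, and work submitted to a single stream executes in submission order. A hedged sketch of that ordering argument (runner and the tensor names stand in for the ones in the hunks):

// Assumption: both operations are issued on the same NPUDeviceContext, so the
// device-to-device copy is serialized behind the compute op on its stream,
// without any host-side synchronization.
auto& dev_ctx = ctx.template device_context<platform::NPUDeviceContext>();
runner.Run(dev_ctx.stream());                  // 1) ApplyAdamD on the stream
framework::TensorCopy(*param, ctx.GetPlace(),  // 2) same stream, so ordered
                      dev_ctx, param_out);     //    after (1); no Wait() needed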
@@ -44,8 +44,9 @@ class SGDNPUKernel : public framework::OpKernel<T> {
// NOTE(zhiqiu): ApplyGradientDescent updates params inplace, so
// if param and param_out is not same, we need to do copy.
if (param_out->data<T>() != param_var->data<T>()) {
- ctx.template device_context<paddle::platform::NPUDeviceContext>().Wait();
- framework::TensorCopySync(*param_var, ctx.GetPlace(), param_out);
+ framework::TensorCopy(
+     *param_var, ctx.GetPlace(),
+     ctx.template device_context<platform::DeviceContext>(), param_out);
}
}
};
......
@@ -16,20 +16,19 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/range_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/dropout_op.h"
#include "paddle/fluid/operators/math/math_function.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/range_op.h"
#include "paddle/fluid/operators/utils.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class RangeNPUKernel : public framework::OpKernel<T> {
public:
@@ -40,11 +39,23 @@ class RangeNPUKernel : public framework::OpKernel<T> {
auto* out = context.Output<framework::Tensor>("Out");
framework::Tensor n;
- framework::TensorCopySync(*start_t, platform::CPUPlace(), &n);
+ framework::TensorCopy(
+     *start_t, platform::CPUPlace(),
+     context.template device_context<platform::DeviceContext>(), &n);
+ context.template device_context<paddle::platform::NPUDeviceContext>()
+     .Wait();
T start = n.data<T>()[0];
- framework::TensorCopySync(*end_t, platform::CPUPlace(), &n);
+ framework::TensorCopy(
+     *end_t, platform::CPUPlace(),
+     context.template device_context<platform::DeviceContext>(), &n);
+ context.template device_context<paddle::platform::NPUDeviceContext>()
+     .Wait();
T end = n.data<T>()[0];
- framework::TensorCopySync(*step_t, platform::CPUPlace(), &n);
+ framework::TensorCopy(
+     *step_t, platform::CPUPlace(),
+     context.template device_context<platform::DeviceContext>(), &n);
+ context.template device_context<paddle::platform::NPUDeviceContext>()
+     .Wait();
T step = n.data<T>()[0];
int64_t size = 0;
@@ -70,8 +81,7 @@ class RangeNPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL(
range, ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, int>,
ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::RangeNPUKernel<paddle::platform::NPUDeviceContext, double>)
......
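Editor's note on the Range hunk above: start, end, and step must be read on the host to compute the output size, so each device-to-host copy is still followed by an explicit Wait() before the pointer is dereferenced; this is the "keep only the stream wait" case from the commit message. The rule, restated as a sketch with the names used in the hunk:

// A host read of copied data is only safe after the asynchronous copy has
// completed on the device context.
framework::Tensor n;
framework::TensorCopy(
    *start_t, platform::CPUPlace(),
    context.template device_context<platform::DeviceContext>(), &n);
context.template device_context<paddle::platform::NPUDeviceContext>().Wait();
T start = n.data<T>()[0];  // safe: Wait() ordered this read after the copy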
@@ -67,12 +67,10 @@ class SoftmaxWithCrossEntropyNPUKernel : public framework::OpKernel<T> {
// on and off
Tensor on_tensor(framework::proto::VarType::INT32);
on_tensor.mutable_data<int>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<int>{static_cast<int>(1)},
-                  ctx.device_context(), &on_tensor);
+ FillNpuTensorWithConstant<int>(&on_tensor, static_cast<int>(1));
Tensor off_tensor(framework::proto::VarType::INT32);
off_tensor.mutable_data<int>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<int>{static_cast<int>(0)},
-                  ctx.device_context(), &off_tensor);
+ FillNpuTensorWithConstant<int>(&off_tensor, static_cast<int>(0));
// one_hot
Tensor tmp_onehot(on_tensor.type());
@@ -142,12 +140,10 @@ class SoftmaxWithCrossEntropyGradNPUKernel : public framework::OpKernel<T> {
// on and off
Tensor on_tensor(framework::proto::VarType::INT32);
on_tensor.mutable_data<int>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<int>{static_cast<int>(1)},
-                  ctx.device_context(), &on_tensor);
+ FillNpuTensorWithConstant<int>(&on_tensor, static_cast<int>(1));
Tensor off_tensor(framework::proto::VarType::INT32);
off_tensor.mutable_data<int>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<int>{static_cast<int>(0)},
-                  ctx.device_context(), &off_tensor);
+ FillNpuTensorWithConstant<int>(&off_tensor, static_cast<int>(0));
// one_hot
Tensor tmp_onehot(on_tensor.type());
......
@@ -12,14 +12,14 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/top_k_op.h"
#include "paddle/fluid/operators/npu_op_runner.h"
#include "paddle/fluid/operators/top_k_op.h"
namespace paddle {
namespace operators {
void gen_assist_seq(framework::Tensor* assit_tensor, int64_t dim,
                    const framework::ExecutionContext& ctx) {
const int64_t dimx2 = dim;
std::vector<paddle::platform::float16> assit;
assit.resize(2 * dimx2);
@@ -28,15 +28,14 @@ void gen_assist_seq(framework::Tensor* assit_tensor,
assit[i] = static_cast<paddle::platform::float16>(i);
// for i in range [dim, dimx2]
int64_t idx =
    static_cast<int64_t>(static_cast<paddle::platform::float16>(i));
int64_t gap = i - idx;
assit[i + dim] = static_cast<paddle::platform::float16>(gap);
}
framework::TensorFromVector(assit, ctx.device_context(), assit_tensor);
}
template <typename DeviceContext, typename T>
class TopkNPUKernel : public framework::OpKernel<T> {
public:
@@ -64,10 +63,8 @@ class TopkNPUKernel : public framework::OpKernel<T> {
{"largest", true}};
// run ascend
auto runner = NpuOpRunner("TopKD",
{*input, assist_seq_tensor},
{*output, *indices},
attr_input);
auto runner = NpuOpRunner("TopKD", {*input, assist_seq_tensor},
{*output, *indices}, attr_input);
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
@@ -83,7 +80,6 @@ class TopkNPUKernel : public framework::OpKernel<T> {
namespace ops = paddle::operators;
// Ascend Op TopKD only support input float 16 dtype
REGISTER_OP_NPU_KERNEL(top_k,
                       ops::TopkNPUKernel<paddle::platform::NPUDeviceContext,
                                          paddle::platform::float16>);
@@ -35,28 +35,24 @@ class TruncatedGaussianRandomNPUKernel : public framework::OpKernel<T> {
float mean = ctx.Attr<float>("mean");
Tensor mean_tensor(framework::proto::VarType::FP32);
mean_tensor.mutable_data<float>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<float>{mean}, ctx.device_context(),
-                  &mean_tensor);
+ FillNpuTensorWithConstant<float>(&mean_tensor, mean);
float std = ctx.Attr<float>("std");
Tensor std_tensor(framework::proto::VarType::FP32);
std_tensor.mutable_data<float>({1}, ctx.GetPlace());
- TensorFromVector(std::vector<float>{std}, ctx.device_context(),
-                  &std_tensor);
+ FillNpuTensorWithConstant<float>(&std_tensor, std);
int32_t seed_var = ctx.Attr<int32_t>("seed");
Tensor min_tensor(framework::proto::VarType::FP32);
min_tensor.mutable_data<float>({1}, ctx.GetPlace());
float min_value = mean - std * 2.0;
- TensorFromVector(std::vector<float>{min_value}, ctx.device_context(),
-                  &min_tensor);
+ FillNpuTensorWithConstant<float>(&min_tensor, min_value);
Tensor max_tensor(framework::proto::VarType::FP32);
max_tensor.mutable_data<float>({1}, ctx.GetPlace());
float max_value = mean + std * 2.0;
- TensorFromVector(std::vector<float>{max_value}, ctx.device_context(),
-                  &max_tensor);
+ FillNpuTensorWithConstant<float>(&max_tensor, max_value);
auto* out = ctx.Output<framework::Tensor>("Out");
out->mutable_data<T>(ctx.GetPlace());
......
@@ -46,7 +46,6 @@ void NPUStream::Wait() const {
PADDLE_ENFORCE_NPU_SUCCESS(aclrtSynchronizeStream(stream_));
}
} // namespace stream
} // namespace platform
} // namespace paddle
@@ -26,7 +26,7 @@ from paddle.fluid import core
paddle.enable_static()
SEED = 2021
- NPUPlace = 5
+ NPUPlace = 0
@unittest.skipIf(not paddle.is_compiled_with_npu(),
@@ -38,7 +38,10 @@ class TestIncrement(OpTest):
self.op_type = "increment"
self.init_dtype()
self.inputs = {
    'X': OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)),
}
self.attrs = {"Step": 1}
self.outputs = {'Out': np.array([2])}
@@ -63,7 +66,10 @@ class TestIncrementFP16(OpTest):
self.op_type = "increment"
self.init_dtype()
self.inputs = {
    'X': OpTest.np_dtype_to_fluid_dtype(np.array([1]).astype(self.dtype)),
}
self.pre_input_id = id(self.inputs['X'])
self.attrs = {"Step": 1}
@@ -100,10 +106,7 @@ class TestIncrementInplace(unittest.TestCase):
exe = paddle.static.Executor(place)
exe.run(startup_prog)
b_value = exe.run(main_prog, feed={"a": a_np, }, fetch_list=[b])
print('input a id is : {}'.format(id(a)))
print('input b id is : {}'.format(id(b)))
......