未验证 提交 19b02d95 编写于 作者: A Aganlengzi 提交者: GitHub

[NPU] modifications for model ernie-1.0 (#36642)

* [NPU] modifications for model ernie-1.0

* rollback 503003 and change cast to dtype
上级 2dd0a46a
...@@ -21,6 +21,38 @@ namespace operators { ...@@ -21,6 +21,38 @@ namespace operators {
using Tensor = framework::Tensor; using Tensor = framework::Tensor;
static void CumsumImp(const Tensor& input, Tensor* output,
const framework::NPUAttributeMap& attr_input,
const framework::ExecutionContext& ctx) {
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
if (input.type() == framework::proto::VarType::INT64) {
Tensor tmp_input;
tmp_input.mutable_data<float>(input.dims(), ctx.GetPlace());
auto dst_acl_dtype = ConvertToNpuDtype(tmp_input.type());
const auto& cast_runner_1 =
NpuOpRunner("Cast", {input}, {tmp_input},
{{"dst_type", static_cast<int>(dst_acl_dtype)}});
cast_runner_1.Run(stream);
Tensor tmp_output;
tmp_output.mutable_data<float>(output->dims(), ctx.GetPlace());
const auto& runner =
NpuOpRunner("CumsumD", {tmp_input}, {tmp_output}, attr_input);
runner.Run(stream);
dst_acl_dtype = ConvertToNpuDtype(output->type());
const auto& cast_runner_2 =
NpuOpRunner("Cast", {tmp_output}, {*output},
{{"dst_type", static_cast<int>(dst_acl_dtype)}});
cast_runner_2.Run(stream);
} else {
const auto& runner = NpuOpRunner("CumsumD", {input}, {*output}, attr_input);
runner.Run(stream);
}
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class CumSumNPUKernel : public framework::OpKernel<T> { class CumSumNPUKernel : public framework::OpKernel<T> {
public: public:
...@@ -36,10 +68,6 @@ class CumSumNPUKernel : public framework::OpKernel<T> { ...@@ -36,10 +68,6 @@ class CumSumNPUKernel : public framework::OpKernel<T> {
framework::NPUAttributeMap attr_input = { framework::NPUAttributeMap attr_input = {
{"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}}; {"axis", axis}, {"exclusive", exclusive}, {"reverse", reverse}};
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
bool flatten = ctx.Attr<bool>("flatten"); bool flatten = ctx.Attr<bool>("flatten");
if (flatten) { if (flatten) {
PADDLE_ENFORCE_EQ( PADDLE_ENFORCE_EQ(
...@@ -53,11 +81,9 @@ class CumSumNPUKernel : public framework::OpKernel<T> { ...@@ -53,11 +81,9 @@ class CumSumNPUKernel : public framework::OpKernel<T> {
new_x.Resize(framework::make_ddim({x->numel()})); new_x.Resize(framework::make_ddim({x->numel()}));
const auto& runner = NpuOpRunner("CumsumD", {new_x}, {*out}, attr_input); CumsumImp(new_x, out, attr_input, ctx);
runner.Run(stream);
} else { } else {
const auto& runner = NpuOpRunner("CumsumD", {*x}, {*out}, attr_input); CumsumImp(*x, out, attr_input, ctx);
runner.Run(stream);
} }
} }
}; };
...@@ -69,5 +95,8 @@ namespace ops = paddle::operators; ...@@ -69,5 +95,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(
cumsum, ops::CumSumNPUKernel<plat::NPUDeviceContext, int>, cumsum, ops::CumSumNPUKernel<plat::NPUDeviceContext, int>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::CumSumNPUKernel<plat::NPUDeviceContext, int64_t>,
#endif
ops::CumSumNPUKernel<plat::NPUDeviceContext, float>, ops::CumSumNPUKernel<plat::NPUDeviceContext, float>,
ops::CumSumNPUKernel<plat::NPUDeviceContext, plat::float16>); ops::CumSumNPUKernel<plat::NPUDeviceContext, plat::float16>);
...@@ -167,10 +167,16 @@ namespace ops = paddle::operators; ...@@ -167,10 +167,16 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel<int>, REGISTER_OP_NPU_KERNEL(elementwise_sub, ops::ElementwiseSubNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::ElementwiseSubNPUKernel<int64_t>,
#endif
ops::ElementwiseSubNPUKernel<float>, ops::ElementwiseSubNPUKernel<float>,
ops::ElementwiseSubNPUKernel<plat::float16>); ops::ElementwiseSubNPUKernel<plat::float16>);
REGISTER_OP_NPU_KERNEL(elementwise_sub_grad, REGISTER_OP_NPU_KERNEL(elementwise_sub_grad,
ops::ElementwiseSubGradNPUKernel<int>, ops::ElementwiseSubGradNPUKernel<int>,
#ifdef PADDLE_WITH_ASCEND_INT64
ops::ElementwiseSubGradNPUKernel<int64_t>,
#endif
ops::ElementwiseSubGradNPUKernel<float>, ops::ElementwiseSubGradNPUKernel<float>,
ops::ElementwiseSubGradNPUKernel<plat::float16>); ops::ElementwiseSubGradNPUKernel<plat::float16>);
...@@ -21,6 +21,9 @@ limitations under the License. */ ...@@ -21,6 +21,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
constexpr int64_t kNoPadding = -1;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class LookupTableV2NPUKernel : public framework::OpKernel<T> { class LookupTableV2NPUKernel : public framework::OpKernel<T> {
public: public:
...@@ -35,16 +38,52 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> { ...@@ -35,16 +38,52 @@ class LookupTableV2NPUKernel : public framework::OpKernel<T> {
platform::errors::InvalidArgument("npu only accept LoDTensor")); platform::errors::InvalidArgument("npu only accept LoDTensor"));
output_t->mutable_data<T>(ctx.GetPlace()); output_t->mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner; int64_t padding_idx = ctx.Attr<int64_t>("padding_idx");
runner.SetType("GatherV2") if (padding_idx == kNoPadding) {
.AddInput(*table_t) NpuOpRunner runner;
.AddInput(*ids_t) runner.SetType("GatherV2")
.AddInput(std::vector<int32_t>{0}) .AddInput(*table_t)
.AddInput(*ids_t)
.AddInput(std::vector<int32_t>{0})
#if (CANN_VERSION_CODE >= 503003)
.AddAttrs({{"batch_dims", 0}})
#endif
.AddOutput(*output_t);
runner.Run();
} else {
Tensor tmp_table_t(table_t->type());
tmp_table_t.mutable_data<T>(table_t->dims(), ctx.GetPlace());
Tensor index;
index.mutable_data<int32_t>({1, 1}, ctx.GetPlace());
FillNpuTensorWithConstant<int32_t>(&index,
static_cast<int32_t>(padding_idx));
auto updata_dim = framework::make_ddim({1, table_t->dims()[1]});
Tensor update;
update.mutable_data<T>(updata_dim, ctx.GetPlace());
FillNpuTensorWithConstant<T>(&update, static_cast<T>(0));
update.Resize(updata_dim);
NpuOpRunner update_runner;
update_runner.SetType("TensorScatterUpdate")
.AddInput(*table_t)
.AddInput(index)
.AddInput(update)
.AddOutput(tmp_table_t);
update_runner.Run();
NpuOpRunner runner;
runner.SetType("GatherV2")
.AddInput(tmp_table_t)
.AddInput(*ids_t)
.AddInput(std::vector<int32_t>{0})
#if (CANN_VERSION_CODE >= 503003) #if (CANN_VERSION_CODE >= 503003)
.AddAttrs({{"batch_dims", 0}}) .AddAttrs({{"batch_dims", 0}})
#endif #endif
.AddOutput(*output_t); .AddOutput(*output_t);
runner.Run(); runner.Run();
}
} }
}; };
......
...@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,8 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/fluid/operators/npu_op_runner.h" #include "paddle/fluid/operators/npu_op_runner.h"
...@@ -21,40 +19,253 @@ limitations under the License. */ ...@@ -21,40 +19,253 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
using NPUDeviceContext = platform::NPUDeviceContext;
template <typename T>
static void Mul(const framework::ExecutionContext& ctx,
const aclrtStream& stream, const Tensor& X, const Tensor& Y,
Tensor* Out, const float alpha) {
Out->mutable_data<T>(ctx.GetPlace());
if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
const auto& runner_dx = NpuOpRunner("Mul", {X, Y}, {*Out}, {});
runner_dx.Run(stream);
} else {
Tensor Out_temp(Out->type());
Out_temp.mutable_data<T>(Out->dims(), ctx.GetPlace());
const auto& runner_dx = NpuOpRunner("Mul", {X, Y}, {Out_temp}, {});
runner_dx.Run(stream);
const auto& runner =
NpuOpRunner("Muls", {Out_temp}, {*Out}, {{"value", alpha}});
runner.Run(stream);
}
}
template <typename T>
static void Dot(const framework::ExecutionContext& ctx,
const aclrtStream& stream, const Tensor& X, const Tensor& Y,
Tensor* Out, const float alpha) {
Out->mutable_data<T>(ctx.GetPlace());
if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
const auto& runner = NpuOpRunner("Dot", {X, Y}, {*Out});
runner.Run(stream);
} else {
Tensor Out_temp(Out->type());
Out_temp.mutable_data<T>(Out->dims(), ctx.GetPlace());
const auto& out_temp_runner = NpuOpRunner("Dot", {X, Y}, {Out_temp});
out_temp_runner.Run(stream);
const auto& runner =
NpuOpRunner("Muls", {Out_temp}, {*Out}, {{"value", alpha}});
runner.Run(stream);
}
}
template <typename T>
static void MatMul2D(const framework::ExecutionContext& ctx,
const aclrtStream& stream, const Tensor& X,
const Tensor& Y, Tensor* Out, const bool trans_x,
const bool trans_y, const float alpha) {
Out->mutable_data<T>(ctx.GetPlace());
if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
const auto& runner =
NpuOpRunner("MatMul", {X, Y}, {*Out},
{{"transpose_x1", trans_x}, {"transpose_x2", trans_y}});
runner.Run(stream);
} else {
Tensor Out_temp(Out->type());
Out_temp.mutable_data<T>(Out->dims(), ctx.GetPlace());
const auto& out_temp_runner =
NpuOpRunner("MatMul", {X, Y}, {Out_temp},
{{"transpose_x1", trans_x}, {"transpose_x2", trans_y}});
out_temp_runner.Run(stream);
const auto& runner =
NpuOpRunner("Muls", {Out_temp}, {*Out}, {{"value", alpha}});
runner.Run(stream);
}
}
template <typename T>
static void MatMulND(const framework::ExecutionContext& ctx,
const aclrtStream& stream, const Tensor& X,
const Tensor& Y, Tensor* Out, const bool trans_x,
const bool trans_y, const float alpha) {
Out->mutable_data<T>(ctx.GetPlace());
if (fabs(alpha - 1.0) < std::numeric_limits<float>::epsilon()) {
const auto& runner =
NpuOpRunner("BatchMatMul", {X, Y}, {*Out},
{{"adj_x1", trans_x}, {"adj_x2", trans_y}});
runner.Run(stream);
} else {
Tensor Out_temp(Out->type());
Out_temp.mutable_data<T>(Out->dims(), ctx.GetPlace());
const auto& out_temp_runner =
NpuOpRunner("BatchMatMul", {X, Y}, {Out_temp},
{{"adj_x1", trans_x}, {"adj_x2", trans_y}});
out_temp_runner.Run(stream);
const auto& runner =
NpuOpRunner("Muls", {Out_temp}, {*Out}, {{"value", alpha}});
runner.Run(stream);
}
}
template <typename T>
static void ReduceDims(const framework::ExecutionContext& ctx,
const aclrtStream& stream,
const std::vector<int64_t>& dims,
const std::vector<int64_t>& brd_dims, const Tensor& in,
Tensor* out) {
std::vector<int64_t> axes;
int64_t size = brd_dims.size();
int64_t diff = brd_dims.size() - dims.size();
for (int64_t i = 0; i < size; ++i) {
if (i < diff) {
axes.push_back(i);
continue;
}
if (brd_dims[i] > dims[i - diff]) {
axes.push_back(i);
}
}
out->mutable_data<T>(ctx.GetPlace());
const auto& runner = NpuOpRunner("ReduceSumD", {in}, {*out},
{{"axes", axes}, {"keep_dims", false}});
runner.Run(stream);
}
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T>
class MatMulNPUKernel : public framework::OpKernel<T> { class MatMulNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X"); auto* X = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y"); auto* Y = ctx.Input<framework::Tensor>("Y");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* Out = ctx.Output<framework::Tensor>("Out");
bool transpose_x = ctx.Attr<bool>("transpose_X"); bool transpose_x = ctx.Attr<bool>("transpose_X");
bool transpose_y = ctx.Attr<bool>("transpose_Y"); bool transpose_y = ctx.Attr<bool>("transpose_Y");
float alpha = static_cast<T>(ctx.Attr<float>("alpha"));
std::vector<int64_t> x_dims = framework::vectorize(X->dims());
std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
std::vector<int64_t> out_dims = framework::vectorize(Out->dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int out_ndim = out_dims.size();
if (x->dims().size() == 2) { auto stream = ctx.template device_context<NPUDeviceContext>().stream();
out->mutable_data<T>(ctx.GetPlace());
const auto& runner = NpuOpRunner( // Case 1: [K] x [K] = [1]
"MatMul", {*x, *y}, {*out}, if (x_ndim == 1 && y_ndim == 1) {
{{"transpose_x1", transpose_x}, {"transpose_x2", transpose_y}}); PADDLE_ENFORCE_EQ(
X->numel(), Y->numel(),
platform::errors::InvalidArgument(
"X's numbers must be equal to Y's numbers,"
"when X/Y's dims =1. But received X has [%d] elements,"
"received Y has [%d] elements",
X->numel(), Y->numel()));
Out->Resize({1});
Dot<T>(ctx, stream, *X, *Y, Out, alpha);
return;
}
auto stream = // Resize dim 1 to 2
ctx.template device_context<paddle::platform::NPUDeviceContext>() Tensor x_temp, y_temp;
.stream(); x_temp.ShareDataWith(*X);
runner.Run(stream); y_temp.ShareDataWith(*Y);
if (x_ndim == 1) {
x_dims.insert(x_dims.begin(), 1);
out_dims.insert(out_dims.end() - 1, 1);
x_temp.Resize(framework::make_ddim(x_dims));
x_ndim = 2;
out_ndim += 1;
}
if (y_ndim == 1) {
y_dims.push_back(1);
out_dims.push_back(1);
y_temp.Resize(framework::make_ddim(y_dims));
y_ndim = 2;
out_ndim += 1;
}
const int K = transpose_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
if (transpose_y) {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2]));
}
// Case 2: [M, K] x [K, N] = [M, N]
if (x_ndim == 2 && y_ndim == 2) {
MatMul2D<T>(ctx, stream, x_temp, y_temp, Out, transpose_x, transpose_y,
alpha);
return;
}
// Case 3: [B, M, K] x [K, N] = [B, M, N], when transpose_x = false
// Equal: [B * M, K] x [K, N] = [B * M, N] => [B, M, N]
if (transpose_x == false && y_ndim == 2) {
std::vector<int64_t> vec_dim = {x_temp.numel() / K, K};
x_temp.Resize(framework::make_ddim(vec_dim));
MatMul2D<T>(ctx, stream, x_temp, y_temp, Out, transpose_x, transpose_y,
alpha);
return;
}
} else if (x->dims().size() > 2) { // Case 4: [B, M, K] x [B, K, N] = [B, M, N]
out->mutable_data<T>(ctx.GetPlace()); std::vector<int64_t> x_broadcast_dims(out_ndim, 1);
std::vector<int64_t> y_broadcast_dims(out_ndim, 1);
std::copy(out_dims.begin(), out_dims.end() - 2, x_broadcast_dims.begin());
std::copy(out_dims.begin(), out_dims.end() - 2, y_broadcast_dims.begin());
std::copy(x_dims.end() - 2, x_dims.end(), x_broadcast_dims.end() - 2);
std::copy(y_dims.end() - 2, y_dims.end(), y_broadcast_dims.end() - 2);
const auto& runner = Tensor x_temp_brd(X->type());
NpuOpRunner("BatchMatMul", {*x, *y}, {*out}, if (x_dims == x_broadcast_dims) {
{{"adj_x1", transpose_x}, {"adj_x2", transpose_y}}); x_temp_brd.ShareDataWith(*X);
x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
} else {
x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
x_temp_brd.mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner_brd;
runner_brd.SetType("BroadcastTo")
.AddInput(x_temp)
.AddInput(std::move(x_broadcast_dims))
.AddOutput(x_temp_brd)
.Run(stream);
}
auto stream = Tensor y_temp_brd(Y->type());
ctx.template device_context<paddle::platform::NPUDeviceContext>() if (y_dims == y_broadcast_dims) {
.stream(); y_temp_brd.ShareDataWith(*Y);
runner.Run(stream); y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
} else {
y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
y_temp_brd.mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner_brd;
runner_brd.SetType("BroadcastTo")
.AddInput(y_temp)
.AddInput(std::move(y_broadcast_dims))
.AddOutput(y_temp_brd)
.Run(stream);
} }
MatMulND<T>(ctx, stream, x_temp_brd, y_temp_brd, Out, transpose_x,
transpose_y, alpha);
} }
}; };
...@@ -62,109 +273,200 @@ template <typename DeviceContext, typename T> ...@@ -62,109 +273,200 @@ template <typename DeviceContext, typename T>
class MatMulGradNPUKernel : public framework::OpKernel<T> { class MatMulGradNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X"); auto* X = ctx.Input<framework::Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y"); auto* Y = ctx.Input<framework::Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* dOut = ctx.Input<framework::Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* dX = ctx.Output<framework::Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); auto* dY = ctx.Output<framework::Tensor>(framework::GradVarName("Y"));
bool transpose_x = ctx.Attr<bool>("transpose_X");
bool transpose_y = ctx.Attr<bool>("transpose_Y"); bool transpose_y = ctx.Attr<bool>("transpose_Y");
auto stream = float alpha = static_cast<T>(ctx.Attr<float>("alpha"));
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
if (x->dims().size() == 2) {
if (transpose_y) {
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
const auto& runner_dx =
NpuOpRunner("MatMul", {*dout, *y}, {*dx},
{{"transpose_x1", false}, {"transpose_x2", false}});
runner_dx.Run(stream);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
const auto& runner_dy =
NpuOpRunner("MatMul", {*dout, *x}, {*dy},
{{"transpose_x1", true}, {"transpose_x2", false}});
runner_dy.Run(stream); std::vector<int64_t> x_dims = framework::vectorize(X->dims());
} std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
std::vector<int64_t> out_dims = framework::vectorize(dOut->dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int out_ndim = out_dims.size();
} else { auto stream = ctx.template device_context<NPUDeviceContext>().stream();
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
const auto& runner_dx =
NpuOpRunner("MatMul", {*dout, *y}, {*dx},
{{"transpose_x1", false}, {"transpose_x2", true}});
runner_dx.Run(stream); // Case 1: [K] x [K] = [1]
} if (x_ndim == 1 && y_ndim == 1) {
if (dy) { Tensor dout_temp(dOut->type());
dy->mutable_data<T>(ctx.GetPlace()); dout_temp.Resize(X->dims());
const auto& runner_dy = dout_temp.mutable_data<T>(ctx.GetPlace());
NpuOpRunner("MatMul", {*x, *dout}, {*dy}, NpuOpRunner runner;
{{"transpose_x1", true}, {"transpose_x2", false}}); runner.SetType("BroadcastTo")
.AddInput(*dOut)
.AddInput(std::move(x_dims))
.AddOutput(dout_temp)
.Run(stream);
if (dX) {
Mul<T>(ctx, stream, dout_temp, *Y, dX, alpha);
}
if (dY) {
Mul<T>(ctx, stream, dout_temp, *X, dY, alpha);
}
return;
}
// Resize dim 1 to 2
Tensor x_temp, y_temp, dout_temp;
x_temp.ShareDataWith(*X);
y_temp.ShareDataWith(*Y);
dout_temp.ShareDataWith(*dOut);
if (x_ndim == 1) {
x_dims.insert(x_dims.begin(), 1);
out_dims.insert(out_dims.end() - 1, 1);
x_temp.Resize(framework::make_ddim(x_dims));
dout_temp.Resize(framework::make_ddim(out_dims));
x_ndim = 2;
out_ndim += 1;
}
if (y_ndim == 1) {
y_dims.push_back(1);
out_dims.push_back(1);
y_temp.Resize(framework::make_ddim(y_dims));
dout_temp.Resize(framework::make_ddim(out_dims));
y_ndim = 2;
out_ndim += 1;
}
runner_dy.Run(stream); // Case 2: [M, K] x [K, N] = [M, N]
if (out_ndim == 2) {
if (dX) {
dX->Resize(framework::make_ddim(x_dims));
if (transpose_x) {
MatMul2D<T>(ctx, stream, y_temp, dout_temp, dX, transpose_y, true,
alpha);
} else {
MatMul2D<T>(ctx, stream, dout_temp, y_temp, dX, false, !transpose_y,
alpha);
} }
dX->Resize(X->dims());
} }
} else if (x->dims().size() > 2) { if (dY) {
if (transpose_y) { dY->Resize(framework::make_ddim(y_dims));
if (dx) { if (transpose_y) {
dx->mutable_data<T>(ctx.GetPlace()); MatMul2D<T>(ctx, stream, dout_temp, x_temp, dY, true, transpose_x,
const auto& runner_dx = alpha);
NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx}, } else {
{{"adj_x1", false}, {"adj_x2", false}}); MatMul2D<T>(ctx, stream, x_temp, dout_temp, dY, !transpose_x, false,
alpha);
runner_dx.Run(stream);
} }
if (dy) { dY->Resize(Y->dims());
dy->mutable_data<T>(ctx.GetPlace()); }
const auto& runner_dy = return;
NpuOpRunner("BatchMatMul", {*dout, *x}, {*dy}, }
{{"adj_x1", true}, {"adj_x2", false}});
const int K = transpose_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
const int N = transpose_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
runner_dy.Run(stream); // Case 3: [B, M, K] x [K, N] = [B, M, N], when transpose_x = false
// Equal: [B * M, K] x [K, N] = [B * M, N] => [B, M, N]
if (transpose_x == false && y_ndim == 2) {
std::vector<int64_t> x_vec_dim = {x_temp.numel() / K, K};
dout_temp.Resize(
framework::make_ddim(std::vector<int64_t>{dout_temp.numel() / N, N}));
if (dX) {
dX->Resize(framework::make_ddim(x_vec_dim));
MatMul2D<T>(ctx, stream, dout_temp, y_temp, dX, false, !transpose_y,
alpha);
dX->Resize(X->dims());
}
if (dY) {
x_temp.Resize(framework::make_ddim(x_vec_dim));
if (transpose_y) {
MatMul2D<T>(ctx, stream, dout_temp, x_temp, dY, true, false, alpha);
} else {
MatMul2D<T>(ctx, stream, x_temp, dout_temp, dY, true, false, alpha);
} }
} else { }
if (dx) { return;
dx->mutable_data<T>(ctx.GetPlace()); }
const auto& runner_dx =
NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx},
{{"adj_x1", false}, {"adj_x2", true}});
runner_dx.Run(stream); // Case 4: [B, M, K] x [B, K, N] = [B, M, N]
std::vector<int64_t> x_broadcast_dims(out_ndim, 1);
std::vector<int64_t> y_broadcast_dims(out_ndim, 1);
std::copy(out_dims.begin(), out_dims.end() - 2, x_broadcast_dims.begin());
std::copy(out_dims.begin(), out_dims.end() - 2, y_broadcast_dims.begin());
std::copy(x_dims.end() - 2, x_dims.end(), x_broadcast_dims.end() - 2);
std::copy(y_dims.end() - 2, y_dims.end(), y_broadcast_dims.end() - 2);
Tensor x_temp_brd(X->type());
if (x_dims == x_broadcast_dims) {
x_temp_brd.ShareDataWith(*X);
x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
} else {
x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
x_temp_brd.mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner_brd;
runner_brd.SetType("BroadcastTo")
.AddInput(x_temp)
.AddInput(std::move(x_broadcast_dims))
.AddOutput(x_temp_brd)
.Run(stream);
}
Tensor y_temp_brd(Y->type());
if (y_dims == y_broadcast_dims) {
y_temp_brd.ShareDataWith(*Y);
y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
} else {
y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
y_temp_brd.mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner_brd;
runner_brd.SetType("BroadcastTo")
.AddInput(y_temp)
.AddInput(std::move(y_broadcast_dims))
.AddOutput(y_temp_brd)
.Run(stream);
}
if (dX) {
if (x_dims == x_broadcast_dims) {
if (transpose_x) {
MatMulND<T>(ctx, stream, y_temp_brd, dout_temp, dX, transpose_y, true,
alpha);
} else {
MatMulND<T>(ctx, stream, dout_temp, y_temp_brd, dX, false,
!transpose_y, alpha);
}
} else {
Tensor dx_temp(X->type());
dx_temp.Resize(framework::make_ddim(x_broadcast_dims));
if (transpose_x) {
MatMulND<T>(ctx, stream, y_temp_brd, dout_temp, &dx_temp, transpose_y,
true, alpha);
} else {
MatMulND<T>(ctx, stream, dout_temp, y_temp_brd, &dx_temp, false,
!transpose_y, alpha);
} }
if (dy) { ReduceDims<T>(ctx, stream, x_dims, x_broadcast_dims, dx_temp, dX);
dy->mutable_data<T>(ctx.GetPlace()); }
if ((x->dims().size() == 3) && (dout->dims().size() == 3) && }
(dy->dims().size() == 2)) { if (dY) {
framework::Tensor dout_tmp; if (y_dims == y_broadcast_dims) {
dout_tmp.ShareDataWith(*dout); if (transpose_y) {
std::vector<int> vec_dim = MatMulND<T>(ctx, stream, dout_temp, x_temp_brd, dY, true, transpose_x,
framework::vectorize<int>(dout_tmp.dims()); alpha);
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]}; } else {
dout_tmp.Resize(framework::make_ddim(vec_dim_v)); MatMulND<T>(ctx, stream, x_temp_brd, dout_temp, dY, !transpose_x,
false, alpha);
framework::Tensor x_tmp; }
x_tmp.ShareDataWith(*x); } else {
std::vector<int> vec_dim_x = Tensor dy_temp(Y->type());
framework::vectorize<int>(x_tmp.dims()); dy_temp.Resize(framework::make_ddim(y_broadcast_dims));
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1], if (transpose_y) {
vec_dim_x[2]}; MatMulND<T>(ctx, stream, dout_temp, x_temp_brd, &dy_temp, true,
x_tmp.Resize(framework::make_ddim(vec_dim_x_v)); transpose_x, alpha);
const auto& runner_dy = } else {
NpuOpRunner("MatMul", {x_tmp, dout_tmp}, {*dy}, MatMulND<T>(ctx, stream, x_temp_brd, dout_temp, &dy_temp,
{{"transpose_x1", true}, {"transpose_x2", false}}); !transpose_x, false, alpha);
runner_dy.Run(stream);
} else {
const auto& runner_dy =
NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
{{"adj_x1", true}, {"adj_x2", false}});
runner_dy.Run(stream);
}
} }
ReduceDims<T>(ctx, stream, y_dims, y_broadcast_dims, dy_temp, dY);
} }
} }
} }
......
...@@ -249,5 +249,45 @@ class TestNPUCumSumWithFlatten2(TestNPUCumSumOp1): ...@@ -249,5 +249,45 @@ class TestNPUCumSumWithFlatten2(TestNPUCumSumOp1):
self.outputs = {'Out': self.inputs['X'].cumsum()} self.outputs = {'Out': self.inputs['X'].cumsum()}
#----------------Cumsum Int64----------------
class TestNPUCumSumOpInt64(TestNPUCumSumOp1):
def init_testcase(self):
self.attrs = {'axis': -1, 'reverse': True}
self.inputs = {
'X': np.random.randint(
1, 10000, size=(5, 6, 10)).astype(self.dtype)
}
self.outputs = {
'Out': np.flip(
np.flip(
self.inputs['X'], axis=2).cumsum(axis=2), axis=2)
}
def create_test_int64(parent):
class TestCumSumInt64(parent):
def init_dtype(self):
self.dtype = np.int64
cls_name = "{0}_{1}".format(parent.__name__, "Int64")
TestCumSumInt64.__name__ = cls_name
globals()[cls_name] = TestCumSumInt64
create_test_int64(TestNPUCumSumOp1)
create_test_int64(TestNPUCumSumOp2)
create_test_int64(TestNPUCumSumOp3)
create_test_int64(TestNPUCumSumOp4)
create_test_int64(TestNPUCumSumOp5)
create_test_int64(TestNPUCumSumOp7)
create_test_int64(TestNPUCumSumExclusive1)
create_test_int64(TestNPUCumSumExclusive2)
create_test_int64(TestNPUCumSumExclusive3)
create_test_int64(TestNPUCumSumExclusive4)
create_test_int64(TestNPUCumSumExclusive5)
create_test_int64(TestNPUCumSumReverseExclusive)
create_test_int64(TestNPUCumSumWithFlatten1)
create_test_int64(TestNPUCumSumWithFlatten2)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -95,6 +95,11 @@ class TestElementwiseSubOpInt32(TestElementwiseSubOp): ...@@ -95,6 +95,11 @@ class TestElementwiseSubOpInt32(TestElementwiseSubOp):
self.dtype = np.int32 self.dtype = np.int32
class TestElementwiseSubOpInt64(TestElementwiseSubOp):
def init_dtype(self):
self.dtype = np.int64
class TestSubtractAPI(unittest.TestCase): class TestSubtractAPI(unittest.TestCase):
def test_name(self): def test_name(self):
with paddle.static.program_guard(paddle.static.Program()): with paddle.static.program_guard(paddle.static.Program()):
......
...@@ -33,14 +33,15 @@ class TestLookupTableV2(OpTest): ...@@ -33,14 +33,15 @@ class TestLookupTableV2(OpTest):
self.place = paddle.NPUPlace(0) self.place = paddle.NPUPlace(0)
self.init_dtype() self.init_dtype()
self.init_dim() self.init_dims()
self.init_padding_idx()
np.random.seed(SEED) np.random.seed(SEED)
bsz = 6 w = np.random.random([self.vocab, self.dim]).astype(self.dtype)
seqlen = 8 x = np.random.randint(
vocab = 10 0, self.vocab, size=(self.bsz, self.seqlen)).astype(np.int32)
w = np.ones([vocab, self.dim]).astype(self.dtype) out = w[x]
x = np.random.randint(0, vocab, size=(bsz, seqlen)).astype(np.int32) if self.padding_idx != -1:
out = np.ones([bsz, seqlen, self.dim]).astype(self.dtype) out[np.squeeze(x == self.padding_idx)] = np.zeros(self.dim)
self.inputs = { self.inputs = {
'W': OpTest.np_dtype_to_fluid_dtype(w), 'W': OpTest.np_dtype_to_fluid_dtype(w),
...@@ -50,7 +51,7 @@ class TestLookupTableV2(OpTest): ...@@ -50,7 +51,7 @@ class TestLookupTableV2(OpTest):
'is_sparse': False, 'is_sparse': False,
'is_distributed': False, 'is_distributed': False,
'remote_prefetch': False, 'remote_prefetch': False,
'padding_idx': -1 'padding_idx': self.padding_idx
} }
self.outputs = {'Out': out} self.outputs = {'Out': out}
...@@ -60,10 +61,16 @@ class TestLookupTableV2(OpTest): ...@@ -60,10 +61,16 @@ class TestLookupTableV2(OpTest):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float32 self.dtype = np.float32
def init_dim(self): def init_dims(self):
self.bsz = 6
self.seqlen = 8
self.vocab = 10
# embedding_dim is not multiple of 32 # embedding_dim is not multiple of 32
self.dim = 20 self.dim = 20
def init_padding_idx(self):
self.padding_idx = -1
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place) self.check_output_with_place(self.place)
...@@ -85,7 +92,10 @@ class TestLookupTableV2FP16(TestLookupTableV2): ...@@ -85,7 +92,10 @@ class TestLookupTableV2FP16(TestLookupTableV2):
class TestLookupTableV2Dim32(TestLookupTableV2): class TestLookupTableV2Dim32(TestLookupTableV2):
def init_dim(self): def init_dims(self):
self.bsz = 6
self.seqlen = 8
self.vocab = 10
# embedding_dim is multiple of 32 # embedding_dim is multiple of 32
self.dim = 64 self.dim = 64
...@@ -96,7 +106,10 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2): ...@@ -96,7 +106,10 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
def init_dtype(self): def init_dtype(self):
self.dtype = np.float16 self.dtype = np.float16
def init_dim(self): def init_dims(self):
self.bsz = 6
self.seqlen = 8
self.vocab = 10
self.dim = 64 self.dim = 64
def set_npu(self): def set_npu(self):
...@@ -104,5 +117,10 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2): ...@@ -104,5 +117,10 @@ class TestLookupTableV2Dim32FP16(TestLookupTableV2):
self.__class__.no_need_check_grad = True self.__class__.no_need_check_grad = True
class TestLookupTableV2WithPadding(TestLookupTableV2):
def init_padding_idx(self):
self.padding_idx = np.random.randint(0, self.vocab)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import numpy as np
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import paddle
import paddle.fluid as fluid
paddle.enable_static()
SEED = 2021
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False, scale=1.0):
"""Reference forward implementation using np.matmul."""
# np.matmul does not support the transpose flags, so we manually
# transpose X and Y appropriately.
if transpose_X:
if X.ndim == 1:
X = X.reshape((X.size, ))
elif X.ndim == 2:
X = X.T
else:
dim = [i for i in range(len(X.shape))]
dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
X = np.transpose(X, tuple(dim))
if transpose_Y:
if Y.ndim == 1:
Y = Y.reshape((Y.size, ))
else:
dim = [i for i in range(len(Y.shape))]
dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
Y = np.transpose(Y, tuple(dim))
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1, ) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float64")
if abs(scale - 1.0) > 1e-09:
Out = Out * scale
return Out
class TestMatMulOp(OpTest):
"""
basic case
"""
def setUp(self):
self.set_npu()
self.op_type = "matmul"
self.init_dtype()
self.init_alpha()
self.config()
X = np.random.random(self.x_shape).astype(self.dtype)
Y = np.random.random(self.y_shape).astype(self.dtype)
# -0.1 ~ 0.1
X = -0.1 + 0.2 * X
Y = -0.1 + 0.2 * Y
Out = reference_matmul(X, Y, self.transpose_X, self.transpose_Y,
self.alpha)
Out = Out.astype(self.dtype)
self.inputs = {'X': X, 'Y': Y}
self.attrs = {
'transpose_X': self.transpose_X,
'transpose_Y': self.transpose_Y,
'alpha': self.alpha
}
self.outputs = {'Out': Out}
def set_npu(self):
self.__class__.use_npu = True
self.place = paddle.NPUPlace(0)
def config(self):
self.x_shape = (100, )
self.y_shape = (100, )
self.transpose_X = False
self.transpose_Y = False
def init_alpha(self):
self.alpha = 1.0
def init_dtype(self):
self.dtype = "float32"
def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-7)
def test_check_grad_normal(self):
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
class TestMatMulOp1(TestMatMulOp):
"""
case x_ndim == 1, y_ndim != 1
"""
def config(self):
self.x_shape = (100, )
self.y_shape = (1, 3, 2, 100)
self.transpose_X = False
self.transpose_Y = True
class TestMatMulOp2(TestMatMulOp):
"""
case x_ndim != 1, y_ndim == 1
"""
def config(self):
self.x_shape = (1, 2, 100, 1)
self.y_shape = (100, )
self.transpose_X = True
self.transpose_Y = False
class TestMatMulOp3(TestMatMulOp):
"""
case [M, K] x [K, N] = [M, N]
"""
def config(self):
self.x_shape = (2, 100)
self.y_shape = (100, 2)
self.transpose_X = False
self.transpose_Y = False
class TestMatMulOp4(TestMatMulOp):
"""
case [M, K] x [K, N] = [M, N]
"""
def config(self):
self.x_shape = (2, 100)
self.y_shape = (2, 100)
self.transpose_X = False
self.transpose_Y = True
class TestMatMulOp5(TestMatMulOp):
"""
case [M, K] x [K, N] = [M, N]
"""
def config(self):
self.x_shape = (100, 2)
self.y_shape = (100, 2)
self.transpose_X = True
self.transpose_Y = False
class TestMatMulOp6(TestMatMulOp):
"""
case [B, M, K] x [K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (2, 2, 25)
self.y_shape = (25, 4)
self.transpose_X = False
self.transpose_Y = False
class TestMatMulOp7(TestMatMulOp):
"""
case [B, M, K] x [K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (1, 2, 25)
self.y_shape = (4, 25)
self.transpose_X = False
self.transpose_Y = True
class TestMatMulOp8(TestMatMulOp):
"""
case [B, M, K] x [K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (1, 25, 4)
self.y_shape = (25, 4)
self.transpose_X = True
self.transpose_Y = False
class TestMatMulOp9(TestMatMulOp):
"""
case [B, M, K] x [B, K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (2, 5, 10)
self.y_shape = (2, 10, 5)
self.transpose_X = False
self.transpose_Y = False
class TestMatMulOp10(TestMatMulOp):
"""
case [B, M, K] x [B, K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (2, 10, 5)
self.y_shape = (2, 10, 5)
self.transpose_X = True
self.transpose_Y = False
class TestMatMulOp11(TestMatMulOp):
"""
case [B, M, K] x [B, K, N] = [B, M, N]
"""
def config(self):
self.x_shape = (2, 5, 10)
self.y_shape = (2, 5, 10)
self.transpose_X = False
self.transpose_Y = True
class TestMatMulOp12(TestMatMulOp):
"""
case to check the gradient for special case
"""
def config(self):
self.x_shape = (100)
self.y_shape = (1, 2, 2, 100, 2)
self.transpose_X = False
self.transpose_Y = False
class TestMatMulOp13(TestMatMulOp):
"""
case to check the gradient for special case
"""
def config(self):
self.x_shape = (2, 1, 100)
self.y_shape = (100)
self.transpose_X = False
self.transpose_Y = False
#--------------------test matmul alpha--------------------
def create_test_alpha_class(parent):
class TestMatMulOpAlphaCase(parent):
def init_alpha(self):
self.alpha = 0.125
cls_name = "{0}_{1}".format(parent.__name__, "Alpha")
TestMatMulOpAlphaCase.__name__ = cls_name
globals()[cls_name] = TestMatMulOpAlphaCase
create_test_alpha_class(TestMatMulOp)
create_test_alpha_class(TestMatMulOp1)
create_test_alpha_class(TestMatMulOp2)
create_test_alpha_class(TestMatMulOp3)
create_test_alpha_class(TestMatMulOp4)
create_test_alpha_class(TestMatMulOp5)
create_test_alpha_class(TestMatMulOp6)
create_test_alpha_class(TestMatMulOp9)
create_test_alpha_class(TestMatMulOp10)
create_test_alpha_class(TestMatMulOp11)
create_test_alpha_class(TestMatMulOp12)
create_test_alpha_class(TestMatMulOp13)
#--------------------test matmul fp16--------------------
def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
class TestMatMulOpFp16Case(parent):
def init_kernel_type(self):
self.dtype = np.float16
def test_check_output(self):
self.check_output_with_place(self.place, atol=atol)
def test_check_grad(self):
self.check_grad_with_place(
self.place, ['X', 'Y'],
'Out',
max_relative_error=max_relative_error)
cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
TestMatMulOpFp16Case.__name__ = cls_name
globals()[cls_name] = TestMatMulOpFp16Case
create_test_fp16_class(TestMatMulOp)
create_test_fp16_class(TestMatMulOp1)
create_test_fp16_class(TestMatMulOp2)
create_test_fp16_class(TestMatMulOp3)
create_test_fp16_class(TestMatMulOp4)
create_test_fp16_class(TestMatMulOp5)
create_test_fp16_class(TestMatMulOp6)
create_test_fp16_class(TestMatMulOp9)
create_test_fp16_class(TestMatMulOp10)
create_test_fp16_class(TestMatMulOp11)
create_test_fp16_class(TestMatMulOp12)
create_test_fp16_class(TestMatMulOp13)
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册