未验证 提交 7850f7ce 编写于 作者: Q Qi Li 提交者: GitHub

[NPU] fix matmul_v2 and utils.run_check, test=develop (#36164)

* [NPU] fix matmul_v2 and utils.run_check, test=develop

* remove debug files, test=develop

* fix install_check, test=develop

* fix doc, test=develop

* fix review comments, test=develop
上级 83541fd4
...@@ -21,166 +21,387 @@ limitations under the License. */ ...@@ -21,166 +21,387 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename DeviceContext, typename T> using Tensor = framework::Tensor;
using NPUDeviceContext = platform::NPUDeviceContext;
// Launches the NPU "MatMul" operator on `stream` with inputs {X, Y} and
// output {*Out}. `trans_x` / `trans_y` are passed through unchanged as the
// operator attributes "transpose_x1" / "transpose_x2".
//
// Out->mutable_data<T>() is called for ctx.GetPlace() before the launch; the
// dims of *Out are neither set nor checked here.
// NOTE(review): presumably callers pass 2-D X / Y and an Out that already has
// (or is compatible with) the result shape -- confirm at the call sites.
template <typename T>
static void MatMul2D(const framework::ExecutionContext& ctx,
                     const aclrtStream& stream, const Tensor& X,
                     const Tensor& Y, Tensor* Out, const bool trans_x,
                     const bool trans_y) {
  Out->mutable_data<T>(ctx.GetPlace());

  // The runner is only needed for this single launch, so it is used as a
  // temporary instead of being bound to a named reference.
  NpuOpRunner("MatMul", {X, Y}, {*Out},
              {{"transpose_x1", trans_x}, {"transpose_x2", trans_y}})
      .Run(stream);
}
// Launches the NPU "BatchMatMul" operator on `stream` with inputs {X, Y} and
// output {*Out}. `trans_x` / `trans_y` are passed through unchanged as the
// operator attributes "adj_x1" / "adj_x2".
//
// Out->mutable_data<T>() is called for ctx.GetPlace() before the launch; the
// dims of *Out are neither set nor checked here.
// NOTE(review): presumably X and Y must already share identical batch dims
// (callers broadcast them beforehand) -- confirm at the call sites.
template <typename T>
static void MatMulND(const framework::ExecutionContext& ctx,
                     const aclrtStream& stream, const Tensor& X,
                     const Tensor& Y, Tensor* Out, const bool trans_x,
                     const bool trans_y) {
  Out->mutable_data<T>(ctx.GetPlace());

  // Single-use runner: build it and launch in one expression.
  NpuOpRunner("BatchMatMul", {X, Y}, {*Out},
              {{"adj_x1", trans_x}, {"adj_x2", trans_y}})
      .Run(stream);
}
// Sums `in` over the axes along which it was broadcast and writes the result
// to *out via the NPU "ReduceSumD" operator (keep_dims = false).
//
//   dims     - the un-broadcast shape.
//   brd_dims - the broadcast shape; `dims` is aligned against its trailing
//              entries.
//
// An axis of `brd_dims` is reduced when it is one of the leading axes that
// `dims` does not have, or when its extent is larger than the aligned extent
// in `dims`.
// NOTE(review): brd_dims is assumed to have at least as many entries as dims;
// nothing here checks that -- confirm against callers.
template <typename T>
static void ReduceDims(const framework::ExecutionContext& ctx,
                       const aclrtStream& stream,
                       const std::vector<int64_t>& dims,
                       const std::vector<int64_t>& brd_dims, const Tensor& in,
                       Tensor* out) {
  const int64_t brd_rank = brd_dims.size();
  // Number of leading axes present only in the broadcast shape.
  const int64_t rank_gap = brd_dims.size() - dims.size();

  std::vector<int64_t> reduce_axes;
  for (int64_t axis = 0; axis < brd_rank; ++axis) {
    const bool only_in_broadcast = axis < rank_gap;
    // Short-circuit keeps `dims` from being indexed for the leading axes.
    if (only_in_broadcast || brd_dims[axis] > dims[axis - rank_gap]) {
      reduce_axes.push_back(axis);
    }
  }

  out->mutable_data<T>(ctx.GetPlace());
  NpuOpRunner("ReduceSumD", {in}, {*out},
              {{"axes", reduce_axes}, {"keep_dims", false}})
      .Run(stream);
}
template <typename T>
class MatMulV2NPUKernel : public framework::OpKernel<T> { class MatMulV2NPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X"); auto* X = ctx.Input<Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y"); auto* Y = ctx.Input<Tensor>("Y");
auto* out = ctx.Output<framework::Tensor>("Out"); auto* Out = ctx.Output<Tensor>("Out");
bool transpose_x = ctx.Attr<bool>("trans_x"); const bool trans_x = ctx.Attr<bool>("trans_x");
bool transpose_y = ctx.Attr<bool>("trans_y"); const bool trans_y = ctx.Attr<bool>("trans_y");
if (x->dims().size() == 2) { std::vector<int64_t> x_dims = framework::vectorize(X->dims());
out->mutable_data<T>(ctx.GetPlace()); std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
std::vector<int64_t> out_dims = framework::vectorize(Out->dims());
const auto& runner = NpuOpRunner( int x_ndim = x_dims.size();
"MatMul", {*x, *y}, {*out}, int y_ndim = y_dims.size();
{{"transpose_x1", transpose_x}, {"transpose_x2", transpose_y}}); int out_ndim = out_dims.size();
auto stream =
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream);
} else if (x->dims().size() > 2) { auto stream = ctx.template device_context<NPUDeviceContext>().stream();
out->mutable_data<T>(ctx.GetPlace());
const auto& runner = // Case 1: [K] x [K] = [1]
NpuOpRunner("BatchMatMul", {*x, *y}, {*out}, if (x_ndim == 1 && y_ndim == 1) {
{{"adj_x1", transpose_x}, {"adj_x2", transpose_y}}); PADDLE_ENFORCE_EQ(
X->numel(), Y->numel(),
platform::errors::InvalidArgument(
"X's numbers must be equal to Y's numbers,"
"when X/Y's dims =1. But received X has [%d] elements,"
"received Y has [%d] elements",
X->numel(), Y->numel()));
Out->Resize({1});
Out->mutable_data<T>(ctx.GetPlace());
auto stream = const auto& runner = NpuOpRunner("Dot", {*X, *Y}, {*Out});
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
runner.Run(stream); runner.Run(stream);
return;
}
// Resize dim 1 to 2
Tensor x_temp, y_temp;
x_temp.ShareDataWith(*X);
y_temp.ShareDataWith(*Y);
if (x_ndim == 1) {
x_dims.insert(x_dims.begin(), 1);
out_dims.insert(out_dims.end() - 1, 1);
x_temp.Resize(framework::make_ddim(x_dims));
x_ndim = 2;
out_ndim += 1;
}
if (y_ndim == 1) {
y_dims.push_back(1);
out_dims.push_back(1);
y_temp.Resize(framework::make_ddim(y_dims));
y_ndim = 2;
out_ndim += 1;
}
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
if (trans_y) {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 1], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 1, K, y_ndim - 1, y_dims[y_ndim - 1]));
} else {
PADDLE_ENFORCE_EQ(y_dims[y_ndim - 2], K,
platform::errors::InvalidArgument(
"Input(Y) has error dim."
"Y'dims[%d] must be equal to %d"
"But received Y'dims[%d] is %d",
y_ndim - 2, K, y_ndim - 2, y_dims[y_ndim - 2]));
} }
// Case 2: [M, K] x [K, N] = [M, N]
if (x_ndim == 2 && y_ndim == 2) {
MatMul2D<T>(ctx, stream, x_temp, y_temp, Out, trans_x, trans_y);
return;
}
// Case 3: [B, M, K] x [K, N] = [B, M, N], when trans_x = false
// Equal: [B * M, K] x [K, N] = [B * M, N] => [B, M, N]
if (trans_x == false && y_ndim == 2) {
std::vector<int64_t> vec_dim = {x_temp.numel() / K, K};
x_temp.Resize(framework::make_ddim(vec_dim));
MatMul2D<T>(ctx, stream, x_temp, y_temp, Out, trans_x, trans_y);
return;
}
// Case 4: [B, M, K] x [B, K, N] = [B, M, N]
std::vector<int64_t> x_broadcast_dims(out_ndim, 1);
std::vector<int64_t> y_broadcast_dims(out_ndim, 1);
std::copy(out_dims.begin(), out_dims.end() - 2, x_broadcast_dims.begin());
std::copy(out_dims.begin(), out_dims.end() - 2, y_broadcast_dims.begin());
std::copy(x_dims.end() - 2, x_dims.end(), x_broadcast_dims.end() - 2);
std::copy(y_dims.end() - 2, y_dims.end(), y_broadcast_dims.end() - 2);
Tensor x_temp_brd(X->type());
if (x_dims == x_broadcast_dims) {
x_temp_brd.ShareDataWith(*X);
x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
} else {
x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
x_temp_brd.mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner_brd;
runner_brd.SetType("BroadcastTo")
.AddInput(x_temp)
.AddInput(std::move(x_broadcast_dims))
.AddOutput(x_temp_brd)
.Run(stream);
}
Tensor y_temp_brd(Y->type());
if (y_dims == y_broadcast_dims) {
y_temp_brd.ShareDataWith(*Y);
y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
} else {
y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
y_temp_brd.mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner_brd;
runner_brd.SetType("BroadcastTo")
.AddInput(y_temp)
.AddInput(std::move(y_broadcast_dims))
.AddOutput(y_temp_brd)
.Run(stream);
}
MatMulND<T>(ctx, stream, x_temp_brd, y_temp_brd, Out, trans_x, trans_y);
} }
}; };
template <typename DeviceContext, typename T> template <typename T>
class MatMulV2GradNPUKernel : public framework::OpKernel<T> { class MatMulV2GradNPUKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& ctx) const override { void Compute(const framework::ExecutionContext& ctx) const override {
auto* x = ctx.Input<framework::Tensor>("X"); auto* X = ctx.Input<Tensor>("X");
auto* y = ctx.Input<framework::Tensor>("Y"); auto* Y = ctx.Input<Tensor>("Y");
auto* dout = ctx.Input<framework::Tensor>(framework::GradVarName("Out")); auto* dOut = ctx.Input<Tensor>(framework::GradVarName("Out"));
auto* dx = ctx.Output<framework::Tensor>(framework::GradVarName("X")); auto* dX = ctx.Output<Tensor>(framework::GradVarName("X"));
auto* dy = ctx.Output<framework::Tensor>(framework::GradVarName("Y")); auto* dY = ctx.Output<Tensor>(framework::GradVarName("Y"));
bool transpose_y = ctx.Attr<bool>("trans_y"); const bool trans_x = ctx.Attr<bool>("trans_x");
auto stream = const bool trans_y = ctx.Attr<bool>("trans_y");
ctx.template device_context<paddle::platform::NPUDeviceContext>()
.stream();
if (x->dims().size() == 2) {
if (transpose_y) {
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
const auto& runner_dx =
NpuOpRunner("MatMul", {*dout, *y}, {*dx},
{{"transpose_x1", false}, {"transpose_x2", false}});
runner_dx.Run(stream);
}
if (dy) {
dy->mutable_data<T>(ctx.GetPlace());
const auto& runner_dy =
NpuOpRunner("MatMul", {*dout, *x}, {*dy},
{{"transpose_x1", true}, {"transpose_x2", false}});
runner_dy.Run(stream); std::vector<int64_t> x_dims = framework::vectorize(X->dims());
} std::vector<int64_t> y_dims = framework::vectorize(Y->dims());
std::vector<int64_t> out_dims = framework::vectorize(dOut->dims());
int x_ndim = x_dims.size();
int y_ndim = y_dims.size();
int out_ndim = out_dims.size();
} else { auto stream = ctx.template device_context<NPUDeviceContext>().stream();
if (dx) {
dx->mutable_data<T>(ctx.GetPlace());
const auto& runner_dx =
NpuOpRunner("MatMul", {*dout, *y}, {*dx},
{{"transpose_x1", false}, {"transpose_x2", true}});
runner_dx.Run(stream); // Case 1: [K] x [K] = [1]
} if (x_ndim == 1 && y_ndim == 1) {
if (dy) { Tensor dout_temp(dOut->type());
dy->mutable_data<T>(ctx.GetPlace()); dout_temp.Resize(X->dims());
const auto& runner_dy = dout_temp.mutable_data<T>(ctx.GetPlace());
NpuOpRunner("MatMul", {*x, *dout}, {*dy}, NpuOpRunner runner;
{{"transpose_x1", true}, {"transpose_x2", false}}); runner.SetType("BroadcastTo")
.AddInput(*dOut)
.AddInput(std::move(x_dims))
.AddOutput(dout_temp)
.Run(stream);
runner_dy.Run(stream); if (dX) {
dX->mutable_data<T>(ctx.GetPlace());
const auto& runner_dx = NpuOpRunner("Mul", {dout_temp, *Y}, {*dX}, {});
runner_dx.Run(stream);
}
if (dY) {
dY->mutable_data<T>(ctx.GetPlace());
const auto& runner_dy = NpuOpRunner("Mul", {dout_temp, *X}, {*dY}, {});
runner_dy.Run(stream);
}
return;
}
// Resize dim 1 to 2
Tensor x_temp, y_temp, dout_temp;
x_temp.ShareDataWith(*X);
y_temp.ShareDataWith(*Y);
dout_temp.ShareDataWith(*dOut);
if (x_ndim == 1) {
x_dims.insert(x_dims.begin(), 1);
out_dims.insert(out_dims.end() - 1, 1);
x_temp.Resize(framework::make_ddim(x_dims));
dout_temp.Resize(framework::make_ddim(out_dims));
x_ndim = 2;
out_ndim += 1;
}
if (y_ndim == 1) {
y_dims.push_back(1);
out_dims.push_back(1);
y_temp.Resize(framework::make_ddim(y_dims));
dout_temp.Resize(framework::make_ddim(out_dims));
y_ndim = 2;
out_ndim += 1;
}
// Case 2: [M, K] x [K, N] = [M, N]
if (out_ndim == 2) {
if (dX) {
dX->Resize(framework::make_ddim(x_dims));
if (trans_x) {
MatMul2D<T>(ctx, stream, y_temp, dout_temp, dX, trans_y, true);
} else {
MatMul2D<T>(ctx, stream, dout_temp, y_temp, dX, false, !trans_y);
} }
dX->Resize(X->dims());
} }
} else if (x->dims().size() > 2) { if (dY) {
if (transpose_y) { dY->Resize(framework::make_ddim(y_dims));
if (dx) { if (trans_y) {
dx->mutable_data<T>(ctx.GetPlace()); MatMul2D<T>(ctx, stream, dout_temp, x_temp, dY, true, trans_x);
const auto& runner_dx = } else {
NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx}, MatMul2D<T>(ctx, stream, x_temp, dout_temp, dY, !trans_x, false);
{{"adj_x1", false}, {"adj_x2", false}});
runner_dx.Run(stream);
} }
if (dy) { dY->Resize(Y->dims());
dy->mutable_data<T>(ctx.GetPlace()); }
const auto& runner_dy = return;
NpuOpRunner("BatchMatMul", {*dout, *x}, {*dy}, }
{{"adj_x1", true}, {"adj_x2", false}});
const int K = trans_x ? x_dims[x_ndim - 2] : x_dims[x_ndim - 1];
const int N = trans_y ? y_dims[y_ndim - 2] : y_dims[y_ndim - 1];
runner_dy.Run(stream); // Case 3: [B, M, K] x [K, N] = [B, M, N], when trans_x = false
// Equal: [B * M, K] x [K, N] = [B * M, N] => [B, M, N]
if (trans_x == false && y_ndim == 2) {
std::vector<int64_t> x_vec_dim = {x_temp.numel() / K, K};
dout_temp.Resize(
framework::make_ddim(std::vector<int64_t>{dout_temp.numel() / N, N}));
if (dX) {
dX->Resize(framework::make_ddim(x_vec_dim));
MatMul2D<T>(ctx, stream, dout_temp, y_temp, dX, false, !trans_y);
dX->Resize(X->dims());
}
if (dY) {
x_temp.Resize(framework::make_ddim(x_vec_dim));
if (trans_y) {
MatMul2D<T>(ctx, stream, dout_temp, x_temp, dY, true, false);
} else {
MatMul2D<T>(ctx, stream, x_temp, dout_temp, dY, true, false);
} }
} else { }
if (dx) { return;
dx->mutable_data<T>(ctx.GetPlace()); }
const auto& runner_dx =
NpuOpRunner("BatchMatMul", {*dout, *y}, {*dx}, // Case 4: [B, M, K] x [B, K, N] = [B, M, N]
{{"adj_x1", false}, {"adj_x2", true}}); std::vector<int64_t> x_broadcast_dims(out_ndim, 1);
std::vector<int64_t> y_broadcast_dims(out_ndim, 1);
std::copy(out_dims.begin(), out_dims.end() - 2, x_broadcast_dims.begin());
std::copy(out_dims.begin(), out_dims.end() - 2, y_broadcast_dims.begin());
std::copy(x_dims.end() - 2, x_dims.end(), x_broadcast_dims.end() - 2);
std::copy(y_dims.end() - 2, y_dims.end(), y_broadcast_dims.end() - 2);
Tensor x_temp_brd(X->type());
if (x_dims == x_broadcast_dims) {
x_temp_brd.ShareDataWith(*X);
x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
} else {
x_temp_brd.Resize(framework::make_ddim(x_broadcast_dims));
x_temp_brd.mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner_brd;
runner_brd.SetType("BroadcastTo")
.AddInput(x_temp)
.AddInput(std::move(x_broadcast_dims))
.AddOutput(x_temp_brd)
.Run(stream);
}
runner_dx.Run(stream); Tensor y_temp_brd(Y->type());
if (y_dims == y_broadcast_dims) {
y_temp_brd.ShareDataWith(*Y);
y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
} else {
y_temp_brd.Resize(framework::make_ddim(y_broadcast_dims));
y_temp_brd.mutable_data<T>(ctx.GetPlace());
NpuOpRunner runner_brd;
runner_brd.SetType("BroadcastTo")
.AddInput(y_temp)
.AddInput(std::move(y_broadcast_dims))
.AddOutput(y_temp_brd)
.Run(stream);
}
if (dX) {
if (x_dims == x_broadcast_dims) {
if (trans_x) {
MatMulND<T>(ctx, stream, y_temp_brd, dout_temp, dX, trans_y, true);
} else {
MatMulND<T>(ctx, stream, dout_temp, y_temp_brd, dX, false, !trans_y);
} }
if (dy) { } else {
dy->mutable_data<T>(ctx.GetPlace()); Tensor dx_temp(X->type());
if ((x->dims().size() == 3) && (dout->dims().size() == 3) && dx_temp.Resize(framework::make_ddim(x_broadcast_dims));
(dy->dims().size() == 2)) { if (trans_x) {
framework::Tensor dout_tmp; MatMulND<T>(ctx, stream, y_temp_brd, dout_temp, &dx_temp, trans_y,
dout_tmp.ShareDataWith(*dout); true);
std::vector<int> vec_dim = } else {
framework::vectorize<int>(dout_tmp.dims()); MatMulND<T>(ctx, stream, dout_temp, y_temp_brd, &dx_temp, false,
std::vector<int> vec_dim_v{vec_dim[0] * vec_dim[1], vec_dim[2]}; !trans_y);
dout_tmp.Resize(framework::make_ddim(vec_dim_v));
framework::Tensor x_tmp;
x_tmp.ShareDataWith(*x);
std::vector<int> vec_dim_x =
framework::vectorize<int>(x_tmp.dims());
std::vector<int> vec_dim_x_v{vec_dim_x[0] * vec_dim_x[1],
vec_dim_x[2]};
x_tmp.Resize(framework::make_ddim(vec_dim_x_v));
const auto& runner_dy =
NpuOpRunner("MatMul", {x_tmp, dout_tmp}, {*dy},
{{"transpose_x1", true}, {"transpose_x2", false}});
runner_dy.Run(stream);
} else {
const auto& runner_dy =
NpuOpRunner("BatchMatMul", {*x, *dout}, {*dy},
{{"adj_x1", true}, {"adj_x2", false}});
runner_dy.Run(stream);
}
} }
ReduceDims<T>(ctx, stream, x_dims, x_broadcast_dims, dx_temp, dX);
}
}
if (dY) {
if (y_dims == y_broadcast_dims) {
if (trans_y) {
MatMulND<T>(ctx, stream, dout_temp, x_temp_brd, dY, true, trans_x);
} else {
MatMulND<T>(ctx, stream, x_temp_brd, dout_temp, dY, !trans_x, false);
}
} else {
Tensor dy_temp(Y->type());
dy_temp.Resize(framework::make_ddim(y_broadcast_dims));
if (trans_y) {
MatMulND<T>(ctx, stream, dout_temp, x_temp_brd, &dy_temp, true,
trans_x);
} else {
MatMulND<T>(ctx, stream, x_temp_brd, dout_temp, &dy_temp, !trans_x,
false);
}
ReduceDims<T>(ctx, stream, y_dims, y_broadcast_dims, dy_temp, dY);
} }
} }
} }
}; };
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP_NPU_KERNEL( REGISTER_OP_NPU_KERNEL(matmul_v2, ops::MatMulV2NPUKernel<float>,
matmul_v2, ops::MatMulV2NPUKernel<paddle::platform::float16>);
ops::MatMulV2NPUKernel<paddle::platform::NPUDeviceContext, float>, REGISTER_OP_NPU_KERNEL(matmul_v2_grad, ops::MatMulV2GradNPUKernel<float>,
ops::MatMulV2NPUKernel<paddle::platform::NPUDeviceContext, ops::MatMulV2GradNPUKernel<paddle::platform::float16>);
paddle::platform::float16>);
REGISTER_OP_NPU_KERNEL(
matmul_v2_grad,
ops::MatMulV2GradNPUKernel<paddle::platform::NPUDeviceContext, float>,
ops::MatMulV2GradNPUKernel<paddle::platform::NPUDeviceContext,
paddle::platform::float16>);
...@@ -55,6 +55,7 @@ __all__ = [ ...@@ -55,6 +55,7 @@ __all__ = [
'is_compiled_with_cuda', 'is_compiled_with_cuda',
'is_compiled_with_rocm', 'is_compiled_with_rocm',
'is_compiled_with_xpu', 'is_compiled_with_xpu',
'is_compiled_with_npu',
'Variable', 'Variable',
'require_version', 'require_version',
'device_guard', 'device_guard',
...@@ -380,6 +381,15 @@ def _xpu_ids(): ...@@ -380,6 +381,15 @@ def _xpu_ids():
return device_ids return device_ids
def _npu_ids():
npus_env = os.getenv("FLAGS_selected_npus")
if npus_env:
device_ids = [int(s) for s in npus_env.split(",")]
else:
device_ids = six.moves.range(core.get_npu_device_count())
return device_ids
def is_compiled_with_xpu(): def is_compiled_with_xpu():
""" """
Whether this whl package can be used to run the model on XPU. Whether this whl package can be used to run the model on XPU.
...@@ -395,6 +405,21 @@ def is_compiled_with_xpu(): ...@@ -395,6 +405,21 @@ def is_compiled_with_xpu():
return core.is_compiled_with_xpu() return core.is_compiled_with_xpu()
def is_compiled_with_npu():
    """
    Tell whether this whl package was built with NPU support.

    Returns:
        bool: the value reported by ``core.is_compiled_with_npu()``.

    Examples:
        .. code-block:: python

            import paddle.fluid as fluid
            support_npu = fluid.is_compiled_with_npu()
    """
    npu_supported = core.is_compiled_with_npu()
    return npu_supported
def disable_signal_handler(): def disable_signal_handler():
""" """
Reset signal handler registered by Paddle. Reset signal handler registered by Paddle.
...@@ -538,6 +563,47 @@ def xpu_places(device_ids=None): ...@@ -538,6 +563,47 @@ def xpu_places(device_ids=None):
return [core.XPUPlace(dev_id) for dev_id in device_ids] return [core.XPUPlace(dev_id) for dev_id in device_ids]
def npu_places(device_ids=None):
    """
    **Note**:
        For multi-card tasks, please use the `FLAGS_selected_npus` environment
        variable to set the visible NPU devices.

    Build a list of :code:`paddle.NPUPlace` objects.

    When :code:`device_ids` is None, the ids come from :code:`_npu_ids()`:
    the environment variable :code:`FLAGS_selected_npus` is consulted first
    (for example :code:`FLAGS_selected_npus=0,1,2` yields
    [paddle.NPUPlace(0), paddle.NPUPlace(1), paddle.NPUPlace(2)]), and if it
    is not set, one place per NPU counted by
    :code:`core.get_npu_device_count()` is returned.

    When :code:`device_ids` is given, one place is created per id, e.g.
    :code:`device_ids=[0,1,2]` yields
    [paddle.NPUPlace(0), paddle.NPUPlace(1), paddle.NPUPlace(2)]. A value
    that is neither a list nor a tuple is treated as a single device id.

    Parameters:
        device_ids (list or tuple of int, optional): list of NPU device ids.

    Returns:
        list of paddle.NPUPlace: Created NPU place list.

    Examples:
        .. code-block:: python

            # required: npu

            import paddle
            import paddle.static as static

            paddle.enable_static()
            npu_places = static.npu_places()
    """
    assert core.is_compiled_with_npu(), \
        "Not compiled with NPU"

    if device_ids is None:
        selected_ids = _npu_ids()
    elif isinstance(device_ids, (list, tuple)):
        selected_ids = device_ids
    else:
        # A lone id is wrapped so the loop below handles every case.
        selected_ids = [device_ids]

    places = []
    for dev_id in selected_ids:
        places.append(core.NPUPlace(dev_id))
    return places
def cpu_places(device_count=None): def cpu_places(device_count=None):
""" """
This function creates a list of :code:`paddle.CPUPlace` objects, and returns the created list. This function creates a list of :code:`paddle.CPUPlace` objects, and returns the created list.
...@@ -1927,6 +1993,10 @@ class Variable(object): ...@@ -1927,6 +1993,10 @@ class Variable(object):
p = core.Place() p = core.Place()
p.set_place(t._place()) p.set_place(t._place())
place = core.XPUPlace(p.xpu_device_id()) place = core.XPUPlace(p.xpu_device_id())
elif p.is_npu_place():
p = core.Place()
p.set_place(t._place())
place = core.NPUPlace(p.npu_device_id())
else: else:
p = core.Place() p = core.Place()
p.set_place(t._place()) p.set_place(t._place())
......
...@@ -20,4 +20,5 @@ if (WITH_ASCEND_CL) ...@@ -20,4 +20,5 @@ if (WITH_ASCEND_CL)
set_tests_properties(test_stack_op_npu PROPERTIES TIMEOUT 300) set_tests_properties(test_stack_op_npu PROPERTIES TIMEOUT 300)
set_tests_properties(test_conv2d_transpose_op_npu PROPERTIES TIMEOUT 200) set_tests_properties(test_conv2d_transpose_op_npu PROPERTIES TIMEOUT 200)
set_tests_properties(test_conv2d_op_npu PROPERTIES TIMEOUT 300) set_tests_properties(test_conv2d_op_npu PROPERTIES TIMEOUT 300)
set_tests_properties(test_matmulv2_op_npu PROPERTIES TIMEOUT 300)
endif() endif()
...@@ -21,56 +21,35 @@ sys.path.append("..") ...@@ -21,56 +21,35 @@ sys.path.append("..")
from op_test import OpTest from op_test import OpTest
import paddle import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
from test_matmul_v2_op import reference_matmul
paddle.enable_static() paddle.enable_static()
SEED = 2021 SEED = 2021
def reference_matmul(X, Y, transpose_X=False, transpose_Y=False): class TestMatMulV2Op(OpTest):
"""Reference forward implementation using np.matmul.""" """
# np.matmul does not support the transpose flags, so we manually case 1
# transpose X and Y appropriately. """
if transpose_X:
if X.ndim == 1: def set_npu(self):
X = X.reshape((X.size)) self.__class__.use_npu = True
elif X.ndim == 2: self.place = paddle.NPUPlace(0)
X = X.T
else:
dim = [i for i in range(len(X.shape))]
dim[-1], dim[len(X.shape) - 2] = dim[len(X.shape) - 2], dim[-1]
X = np.transpose(X, tuple(dim))
if transpose_Y:
if Y.ndim == 1:
Y = Y.reshape((Y.size))
else:
dim = [i for i in range(len(Y.shape))]
dim[-1], dim[len(Y.shape) - 2] = dim[len(Y.shape) - 2], dim[-1]
Y = np.transpose(Y, tuple(dim))
Out = np.matmul(X, Y)
if not Out.shape:
# We do not support 0-dimensional Tensors (scalars). So where
# np.matmul outputs a scalar, we must convert to a Tensor of
# shape (1) instead.
# Everywhere else, we are compatible with np.matmul.
Out = np.array([Out], dtype="float64")
return Out
class TestMatMul(OpTest):
def config(self): def config(self):
self.x_shape = (100, 24) self.x_shape = (100, )
self.y_shape = (24, 100) self.y_shape = (100, )
self.trans_x = False self.trans_x = False
self.trans_y = False self.trans_y = False
def init_kernel_type(self):
self.dtype = "float32"
def setUp(self): def setUp(self):
self.set_npu() self.set_npu()
self.op_type = "matmul_v2" self.init_kernel_type()
self.place = paddle.NPUPlace(0)
self.init_dtype()
self.config() self.config()
np.random.seed(SEED) self.op_type = "matmul_v2"
x = np.random.random(self.x_shape).astype(self.dtype) x = np.random.random(self.x_shape).astype(self.dtype)
y = np.random.random(self.y_shape).astype(self.dtype) y = np.random.random(self.y_shape).astype(self.dtype)
# -0.1 ~ 0.1 # -0.1 ~ 0.1
...@@ -85,201 +64,314 @@ class TestMatMul(OpTest): ...@@ -85,201 +64,314 @@ class TestMatMul(OpTest):
self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y} self.attrs = {'trans_x': self.trans_x, 'trans_y': self.trans_y}
self.outputs = {'Out': result} self.outputs = {'Out': result}
def set_npu(self):
self.__class__.use_npu = True
self.__class__.no_need_check_grad = True
def init_dtype(self):
self.dtype = np.float32
def test_check_output(self): def test_check_output(self):
self.check_output_with_place(self.place, atol=1e-5) self.check_output_with_place(self.place, atol=1e-7)
def test_check_grad(self):
self.check_grad_with_place(self.place, ['X', 'Y'], 'Out')
# TODO(ascendrc): Add grad test class TestMatMuklOp2(TestMatMulV2Op):
# def test_check_grad(self):
# if self.dtype == np.float16:
# return
# self.check_grad(['X'], 'Out')
#
class TestMatMul2(TestMatMul):
""" """
case 2 case 2
""" """
def config(self): def config(self):
self.x_shape = (32, 24) self.x_shape = (100, )
self.y_shape = (32, 24) self.y_shape = (1, 3, 2, 100)
self.trans_x = False self.trans_x = False
self.trans_y = True self.trans_y = True
class TestMatMul3(TestMatMul): class TestMatMuklOp3(TestMatMulV2Op):
""" """
case 3 case 3
""" """
def init_dtype(self): def config(self):
self.dtype = np.float16 self.x_shape = (100, )
self.y_shape = (1, 1, 100, 2)
self.trans_x = False
self.trans_y = False
class TestMatMul4(TestMatMul): class TestMatMuklOp4(TestMatMulV2Op):
""" """
case 4 dim=3 case 4
""" """
def config(self): def config(self):
self.x_shape = (2, 3, 4) self.x_shape = (100, )
self.y_shape = (2, 4, 3) self.y_shape = (1, 2, 100, 2)
self.trans_x = False self.trans_x = False
self.trans_y = False self.trans_y = False
class TestMatMulNet(unittest.TestCase): class TestMatMuklOp5(TestMatMulV2Op):
def _test(self, run_npu=True): """
main_prog = paddle.static.Program() case 5
startup_prog = paddle.static.Program() """
main_prog.random_seed = SEED
startup_prog.random_seed = SEED def config(self):
np.random.seed(SEED) self.x_shape = (1, 1, 100, 1)
self.y_shape = (100, )
a_np = np.random.random(size=(2, 3)).astype('float32') self.trans_x = True
b_np = np.random.random(size=(2, 3)).astype('float32') self.trans_y = False
c_np = np.random.random(size=(3, 2)).astype('float32')
d_np = np.random.random(size=(3, 2)).astype('float32')
label_np = np.random.randint(2, size=(2, 1)).astype('int64') class TestMatMuklOp6(TestMatMulV2Op):
"""
with paddle.static.program_guard(main_prog, startup_prog): case 6
a = paddle.static.data(name="a", shape=[2, 3], dtype='float32') """
b = paddle.static.data(name="b", shape=[2, 3], dtype='float32')
c = paddle.static.data(name="c", shape=[3, 2], dtype='float32') def config(self):
d = paddle.static.data(name="d", shape=[3, 2], dtype='float32') self.x_shape = (1, 2, 102, 1)
label = paddle.static.data( self.y_shape = (102, )
name="label", shape=[2, 1], dtype='int64') self.trans_x = True
self.trans_y = False
sum_1 = paddle.add(a, b)
sum_2 = paddle.add(c, d)
result = paddle.matmul(sum_1, sum_2) class TestMatMuklOp7(TestMatMulV2Op):
"""
fc_1 = fluid.layers.fc(input=result, size=8) case 7
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax') """
cost = fluid.layers.cross_entropy(input=prediction, label=label) def config(self):
loss = fluid.layers.reduce_mean(cost) self.x_shape = (1, 2, 1, 100)
sgd = fluid.optimizer.SGD(learning_rate=0.01) self.y_shape = (100, )
sgd.minimize(loss) self.trans_x = False
self.trans_y = False
if run_npu:
place = paddle.NPUPlace(0)
else: class TestMatMuklOp8(TestMatMulV2Op):
place = paddle.CPUPlace() """
exe = paddle.static.Executor(place) case 8
exe.run(startup_prog) """
print("Start run on {}".format(place)) def config(self):
for epoch in range(100): self.x_shape = (1, 1, 2, 100)
self.y_shape = (1, 1, 100, 2)
pred_res, loss_res = exe.run(main_prog, self.trans_x = False
feed={ self.trans_y = False
"a": a_np,
"b": b_np,
"c": c_np, class TestMatMuklOp9(TestMatMulV2Op):
"d": d_np, """
"label": label_np case 9
}, """
fetch_list=[prediction, loss])
if epoch % 10 == 0: def config(self):
print("Epoch {} | Prediction[0]: {}, Loss: {}".format( self.x_shape = (1, 1, 1, 100)
epoch, pred_res[0], loss_res)) self.y_shape = (2, 1, 2, 100)
self.trans_x = False
return pred_res, loss_res self.trans_y = True
def test_npu(self):
cpu_pred, cpu_loss = self._test(False) class TestMatMuklOp10(TestMatMulV2Op):
npu_pred, npu_loss = self._test(True) """
case 10
self.assertTrue(np.allclose(npu_pred, cpu_pred)) """
self.assertTrue(np.allclose(npu_loss, cpu_loss))
def config(self):
self.x_shape = (1, 1, 25, 4)
# The precision is aligned in NPU and GPU separately, which is only used for the usage method. self.y_shape = (1, 2, 4, 25)
self.trans_x = False
self.trans_y = False
class TestMatMulNet3_2(unittest.TestCase):
def _test(self, run_npu=True):
main_prog = paddle.static.Program() class TestMatMuklOp11(TestMatMulV2Op):
startup_prog = paddle.static.Program() """
main_prog.random_seed = SEED case 11
startup_prog.random_seed = SEED """
np.random.seed(SEED)
self._dtype = "float32" def config(self):
self.x_shape = (2, 1, 2, 100)
a_np = np.random.random(size=(2, 1, 3)).astype(self._dtype) self.y_shape = (1, 1, 100, 2)
b_np = np.random.random(size=(2, 1, 3)).astype(self._dtype) self.trans_x = False
c_np = np.random.random(size=(3, 2)).astype(self._dtype) self.trans_y = False
d_np = np.random.random(size=(3, 2)).astype(self._dtype)
label_np = np.random.randint(2, size=(2, 1)).astype('int64')
class TestMatMuklOp12(TestMatMulV2Op):
with paddle.static.program_guard(main_prog, startup_prog): """
a = paddle.static.data(name="a", shape=[2, 1, 3], dtype=self._dtype) case 12
b = paddle.static.data(name="b", shape=[2, 1, 3], dtype=self._dtype) """
c = paddle.static.data(name="c", shape=[3, 2], dtype=self._dtype)
d = paddle.static.data(name="d", shape=[3, 2], dtype=self._dtype) def config(self):
label = paddle.static.data( self.x_shape = (2, 1, 4, 25)
name="label", shape=[2, 1], dtype='int64') self.y_shape = (1, 1, 4, 25)
self.trans_x = True
sum_1 = paddle.add(a, b) self.trans_y = False
sum_2 = paddle.add(c, d)
sum_1 = paddle.cast(sum_1, 'float16')
sum_2 = paddle.cast(sum_2, 'float16') class TestMatMuklOp13(TestMatMulV2Op):
if not run_npu: """
sum_1 = paddle.cast(sum_1, 'float32') case 13
sum_2 = paddle.cast(sum_2, 'float32') """
result = paddle.matmul(sum_1, sum_2) def config(self):
if run_npu: self.x_shape = (2, 2, 10, 10)
result = paddle.cast(result, 'float32') self.y_shape = (2, 2, 10, 10)
self.trans_x = True
result = paddle.reshape(result, shape=[2, 2]) self.trans_y = False
fc_1 = fluid.layers.fc(input=result, size=8)
prediction = fluid.layers.fc(input=fc_1, size=2, act='softmax')
class TestMatMuklOp14(TestMatMulV2Op):
cost = fluid.layers.cross_entropy(input=prediction, label=label) """
loss = fluid.layers.reduce_mean(cost) case 14_1
sgd = fluid.optimizer.SGD(learning_rate=0.01) """
sgd.minimize(loss)
def config(self):
if run_npu: self.x_shape = (3, 1, 6, 6)
self.y_shape = (1, 2, 6, 9)
self.trans_x = True
self.trans_y = False
class TestMatMuklOp15(TestMatMulV2Op):
    """
    case 14_2: 4-D operands whose batch dims differ ((3, 1) vs (1, 2)),
    neither operand transposed.
    """

    def config(self):
        self.trans_x = self.trans_y = False
        self.x_shape, self.y_shape = (3, 1, 6, 6), (1, 2, 6, 9)
class TestMatMuklOp16(TestMatMulV2Op):
    """
    case 16 : to check the gradient for special case
    (1-D X multiplied with a 5-D Y).
    """

    def config(self):
        # `(100)` is just the int 100; the trailing comma makes it the 1-D
        # shape tuple that the sibling test cases use.
        self.x_shape = (100, )
        self.y_shape = (1, 2, 2, 100, 2)
        self.trans_x = False
        self.trans_y = False
class TestMatMuklOp17(TestMatMulV2Op):
    """
    case 17 : to check the gradient for special case
    (3-D X multiplied with a 1-D Y).
    """

    def config(self):
        self.x_shape = (2, 1, 100)
        # `(100)` is just the int 100; the trailing comma makes it the 1-D
        # shape tuple that the sibling test cases use.
        self.y_shape = (100, )
        self.trans_x = False
        self.trans_y = False
class TestMatMuklOpBroadcast1(TestMatMulV2Op):
    """
    case 14_3: 4-D operands whose batch dims differ ((3, 1) vs (1, 2)),
    both operands transposed.
    """

    def config(self):
        self.trans_x = self.trans_y = True
        self.x_shape, self.y_shape = (3, 1, 10, 10), (1, 2, 10, 10)
class TestMatMuklOpBroadcast2(TestMatMulV2Op):
    """
    case 14_4: 4-D operands whose batch dims differ ((3, 1) vs (1, 2)),
    only Y transposed.
    """

    def config(self):
        self.trans_x, self.trans_y = False, True
        self.x_shape, self.y_shape = (3, 1, 10, 10), (1, 2, 10, 10)
#--------------------test matmul fp16--------------------
def create_test_fp16_class(parent, atol=0.001, max_relative_error=2.5):
    """
    Generate a float16 variant of the test case ``parent`` and register it in
    this module's globals as ``<parent.__name__>_Fp16``.

    The generated class overrides ``init_kernel_type`` to select
    ``np.float16`` and overrides both checks to use the tolerances given here.

    Args:
        parent (type): test case class to derive from.
        atol (float): absolute tolerance passed to ``check_output_with_place``.
        max_relative_error (float): tolerance passed to
            ``check_grad_with_place``.
    """

    class TestMatMulOpFp16Case(parent):
        def init_kernel_type(self):
            self.dtype = np.float16

        def test_check_output(self):
            self.check_output_with_place(self.place, atol=atol)

        def test_check_grad(self):
            self.check_grad_with_place(
                self.place, ['X', 'Y'],
                'Out',
                max_relative_error=max_relative_error)

    cls_name = "{0}_{1}".format(parent.__name__, "Fp16")
    TestMatMulOpFp16Case.__name__ = cls_name
    # Rename __qualname__ as well: without it every generated class keeps the
    # same "create_test_fp16_class.<locals>.TestMatMulOpFp16Case" qualified
    # name, so they cannot be told apart wherever that name is displayed.
    TestMatMulOpFp16Case.__qualname__ = cls_name
    globals()[cls_name] = TestMatMulOpFp16Case
# Register a float16 variant (default tolerances) for each float32 case below,
# in the same order as the cases are defined.
for _fp32_case in (
        TestMatMulV2Op,
        TestMatMuklOp2,
        TestMatMuklOp3,
        TestMatMuklOp4,
        TestMatMuklOp5,
        TestMatMuklOp6,
        TestMatMuklOp7,
        TestMatMuklOp8,
        TestMatMuklOp9,
        TestMatMuklOp10,
        TestMatMuklOp11,
        TestMatMuklOp12,
        TestMatMuklOp13,
        TestMatMuklOp14,
        TestMatMuklOp15,
        TestMatMuklOp16,
        TestMatMuklOp17, ):
    create_test_fp16_class(_fp32_case)
# Drop the loop variable so the module does not keep a second global name
# bound to the last test case class.
del _fp32_case
class TestMatMulV2API(unittest.TestCase):
    """Smoke tests for ``paddle.matmul`` in static-graph and dygraph mode.

    Every test runs on the CPU place and, when Paddle is compiled with NPU
    support, additionally on NPU device 0.
    """

    def setUp(self):
        self.places = [paddle.CPUPlace()]
        if paddle.is_compiled_with_npu():
            self.places.append(paddle.NPUPlace(0))

    def check_static_result(self, place):
        """Build and run a (4, 3) x (3, 4) matmul program on ``place``."""
        with fluid.program_guard(fluid.Program(), fluid.Program()):
            input_x = fluid.data(name="input_x", shape=[4, 3], dtype="float32")
            input_y = fluid.data(name="input_y", shape=[3, 4], dtype="float32")

            result = paddle.matmul(input_x, input_y)

            x_np = np.random.random([4, 3]).astype("float32")
            y_np = np.random.random([3, 4]).astype("float32")

            exe = fluid.Executor(place)
            fetches = exe.run(fluid.default_main_program(),
                              feed={"input_x": x_np,
                                    "input_y": y_np},
                              fetch_list=[result])
            # (4, 3) x (3, 4) must produce a (4, 4) result.
            self.assertEqual(tuple(fetches[0].shape), (4, 4))

    def test_static(self):
        for place in self.places:
            self.check_static_result(place=place)

    def test_dygraph(self):
        for place in self.places:
            with fluid.dygraph.guard(place):
                input_x = np.random.random([4, 3]).astype("float32")
                input_y = np.random.random([3, 4]).astype("float32")
                x = paddle.to_tensor(input_x)
                y = paddle.to_tensor(input_y)
                result = paddle.matmul(x, y)
                # (4, 3) x (3, 4) must produce a (4, 4) result.
                self.assertEqual(list(result.shape), [4, 4])
def test_dygraph_fp16(self):
if paddle.is_compiled_with_npu():
place = paddle.NPUPlace(0) place = paddle.NPUPlace(0)
else: with fluid.dygraph.guard(place):
place = paddle.CPUPlace() input_x = np.random.random([4, 3]).astype("float16")
exe = paddle.static.Executor(place) input_y = np.random.random([3, 4]).astype("float16")
exe.run(startup_prog) x = paddle.to_tensor(input_x)
y = paddle.to_tensor(input_y)
print("Start run on {}".format(place)) result = paddle.matmul(x, y)
for epoch in range(100):
pred_res, loss_res = exe.run(main_prog,
feed={
"a": a_np,
"b": b_np,
"c": c_np,
"d": d_np,
"label": label_np
},
fetch_list=[prediction, loss])
if epoch % 10 == 0:
print("Epoch {} | Prediction[0]: {}, Loss: {}".format(
epoch, pred_res[0], loss_res))
return pred_res, loss_res
def test_npu(self):
cpu_pred, cpu_loss = self._test(False)
npu_pred, npu_loss = self._test(True)
self.assertTrue(np.allclose(npu_pred, cpu_pred, atol=1e-4))
self.assertTrue(np.allclose(npu_loss, cpu_loss, atol=1e-4))
if __name__ == '__main__': if __name__ == '__main__':
......
...@@ -43,6 +43,7 @@ from ..fluid.framework import program_guard # noqa: F401 ...@@ -43,6 +43,7 @@ from ..fluid.framework import program_guard # noqa: F401
from ..fluid.framework import cpu_places # noqa: F401 from ..fluid.framework import cpu_places # noqa: F401
from ..fluid.framework import cuda_places # noqa: F401 from ..fluid.framework import cuda_places # noqa: F401
from ..fluid.framework import xpu_places # noqa: F401 from ..fluid.framework import xpu_places # noqa: F401
from ..fluid.framework import npu_places # noqa: F401
from ..fluid.framework import Variable # noqa: F401 from ..fluid.framework import Variable # noqa: F401
from ..fluid.layers.control_flow import Print # noqa: F401 from ..fluid.layers.control_flow import Print # noqa: F401
from ..fluid.layers.nn import py_func # noqa: F401 from ..fluid.layers.nn import py_func # noqa: F401
...@@ -99,6 +100,7 @@ __all__ = [ #noqa ...@@ -99,6 +100,7 @@ __all__ = [ #noqa
'cpu_places', 'cpu_places',
'cuda_places', 'cuda_places',
'xpu_places', 'xpu_places',
'npu_places',
'Variable', 'Variable',
'create_global_var', 'create_global_var',
'accuracy', 'accuracy',
......
...@@ -74,7 +74,22 @@ def _is_cuda_available(): ...@@ -74,7 +74,22 @@ def _is_cuda_available():
return False return False
def _run_dygraph_single(use_cuda): def _is_npu_available():
"""
Check whether NPU is avaiable.
"""
try:
assert len(paddle.static.npu_places()) > 0
return True
except Exception as e:
logging.warning(
"You are using NPU version PaddlePaddle, but there is no NPU "
"detected on your machine. Maybe NPU devices is not set properly."
"\n Original Error is {}".format(e))
return False
def _run_dygraph_single(use_cuda, use_npu):
""" """
Testing the simple network in dygraph mode using one CPU/GPU. Testing the simple network in dygraph mode using one CPU/GPU.
...@@ -84,6 +99,8 @@ def _run_dygraph_single(use_cuda): ...@@ -84,6 +99,8 @@ def _run_dygraph_single(use_cuda):
paddle.disable_static() paddle.disable_static()
if use_cuda: if use_cuda:
paddle.set_device('gpu') paddle.set_device('gpu')
elif use_npu:
paddle.set_device('npu')
else: else:
paddle.set_device('cpu') paddle.set_device('cpu')
weight_attr = paddle.ParamAttr( weight_attr = paddle.ParamAttr(
...@@ -102,7 +119,7 @@ def _run_dygraph_single(use_cuda): ...@@ -102,7 +119,7 @@ def _run_dygraph_single(use_cuda):
opt.step() opt.step()
def _run_static_single(use_cuda): def _run_static_single(use_cuda, use_npu):
""" """
Testing the simple network with executor running directly, using one CPU/GPU. Testing the simple network with executor running directly, using one CPU/GPU.
...@@ -119,8 +136,14 @@ def _run_static_single(use_cuda): ...@@ -119,8 +136,14 @@ def _run_static_single(use_cuda):
param_grads = paddle.static.append_backward( param_grads = paddle.static.append_backward(
out, parameter_list=[weight.name])[0] out, parameter_list=[weight.name])[0]
exe = paddle.static.Executor( if use_cuda:
paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()) place = paddle.CUDAPlace(0)
elif use_npu:
place = paddle.NPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
exe.run(train_prog, exe.run(train_prog,
feed={input.name: _prepare_data(1)}, feed={input.name: _prepare_data(1)},
...@@ -128,7 +151,7 @@ def _run_static_single(use_cuda): ...@@ -128,7 +151,7 @@ def _run_static_single(use_cuda):
paddle.disable_static() paddle.disable_static()
def _run_static_parallel(use_cuda, device_list): def _run_static_parallel(use_cuda, use_npu, device_list):
""" """
Testing the simple network in data parallel mode, using multiple CPU/GPU. Testing the simple network in data parallel mode, using multiple CPU/GPU.
...@@ -150,8 +173,15 @@ def _run_static_parallel(use_cuda, device_list): ...@@ -150,8 +173,15 @@ def _run_static_parallel(use_cuda, device_list):
train_prog).with_data_parallel( train_prog).with_data_parallel(
loss_name=loss.name, places=device_list) loss_name=loss.name, places=device_list)
exe = paddle.static.Executor( if use_cuda:
paddle.CUDAPlace(0) if use_cuda else paddle.CPUPlace()) place = paddle.CUDAPlace(0)
elif use_npu:
place = paddle.NPUPlace(0)
compiled_prog = train_prog
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_prog) exe.run(startup_prog)
exe.run(compiled_prog, exe.run(compiled_prog,
feed={input.name: _prepare_data(len(device_list))}, feed={input.name: _prepare_data(len(device_list))},
...@@ -182,23 +212,31 @@ def run_check(): ...@@ -182,23 +212,31 @@ def run_check():
if paddle.is_compiled_with_cuda(): if paddle.is_compiled_with_cuda():
use_cuda = _is_cuda_available() use_cuda = _is_cuda_available()
use_npu = False
elif paddle.is_compiled_with_npu():
use_npu = _is_npu_available()
use_cuda = False
else: else:
use_npu = False
use_cuda = False use_cuda = False
if use_cuda: if use_cuda:
device_str = "GPU" device_str = "GPU"
device_list = paddle.static.cuda_places() device_list = paddle.static.cuda_places()
elif use_npu:
device_str = "NPU"
device_list = paddle.static.npu_places()
else: else:
device_str = "CPU" device_str = "CPU"
device_list = paddle.static.cpu_places(device_count=2) device_list = paddle.static.cpu_places(device_count=2)
device_count = len(device_list) device_count = len(device_list)
_run_static_single(use_cuda) _run_static_single(use_cuda, use_npu)
_run_dygraph_single(use_cuda) _run_dygraph_single(use_cuda, use_npu)
print("PaddlePaddle works well on 1 {}.".format(device_str)) print("PaddlePaddle works well on 1 {}.".format(device_str))
try: try:
_run_static_parallel(use_cuda, device_list) _run_static_parallel(use_cuda, use_npu, device_list)
print("PaddlePaddle works well on {} {}s.".format(device_count, print("PaddlePaddle works well on {} {}s.".format(device_count,
device_str)) device_str))
print( print(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册