未验证 提交 8753a2bf 编写于 作者: H houj04 提交者: GitHub

[XPU][NPU] (1) add device_guard. (2) add support for LoDTensorArray of sum op. (#44367)

* device_guard support xpu. test=kunlun

* sum op of xpu support LoDTensorArray. add test for while op of xpu. test=kunlun.
上级 28cb0067
......@@ -330,14 +330,12 @@ void apply_device_guard(const OperatorBase* op_base,
VLOG(3) << "Switch into CPUPlace by device_guard.";
expected_kernel_key->place_ = platform::CPUPlace();
} else if (op_device.find("gpu") != std::string::npos &&
(platform::is_gpu_place(place) ||
platform::is_npu_place(place))) {
// when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
// will be executed and a warning will be given at the same time.
platform::is_gpu_place(place)) {
// when the Op that does not have GPUKernel is assigned to GPU, the
// CPUKernel will be executed and a warning will be given at the same
// time.
if (op_base->SupportGPU()) {
expected_kernel_key->place_ = place;
} else if (op_base->SupportNPU()) {
expected_kernel_key->place_ = place;
} else {
expected_kernel_key->place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
......@@ -346,6 +344,36 @@ void apply_device_guard(const OperatorBase* op_base,
}
VLOG(3) << "Switch into " << expected_kernel_key->place_
<< " by device_guard.";
} else if (op_device.find("npu") != std::string::npos &&
platform::is_npu_place(place)) {
// when the Op that does not have NPUKernel is assigned to NPU, the
// CPUKernel will be executed and a warning will be given at the same
// time.
if (op_base->SupportNPU()) {
expected_kernel_key->place_ = place;
} else {
expected_kernel_key->place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
<< "Op(" << op_base->Type()
<< ") has no NPU implementation. It will be assigned to CPUPlace.";
}
VLOG(3) << "Switch into " << expected_kernel_key->place_
<< " by device_guard.";
} else if (op_device.find("xpu") != std::string::npos &&
platform::is_xpu_place(place)) {
// when the Op that does not have XPUKernel is assigned to XPU, the
// CPUKernel will be executed and a warning will be given at the same
// time.
if (op_base->SupportXPU()) {
expected_kernel_key->place_ = place;
} else {
expected_kernel_key->place_ = platform::CPUPlace();
LOG_FIRST_N(WARNING, 1)
<< "Op(" << op_base->Type()
<< ") has no XPU implementation. It will be assigned to CPUPlace.";
}
VLOG(3) << "Switch into " << expected_kernel_key->place_
<< " by device_guard.";
} else {
PADDLE_THROW(
platform::errors::Fatal("Unsupported current place %s", op_device));
......
......@@ -1274,6 +1274,43 @@ bool OperatorWithKernel::SupportNPU() const {
}
}
bool OperatorWithKernel::SupportXPU() const {
#ifdef PADDLE_WITH_XPU
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
phi::TransToPhiKernelName(type_));
auto has_phi_kernel =
std::any_of(phi_kernels.begin(),
phi_kernels.end(),
[](phi::KernelKeyMap::const_reference kern_pair) {
return kern_pair.first.backend() == phi::Backend::XPU;
});
if (has_phi_kernel) {
return true;
} else {
auto kernel_iter = OperatorWithKernel::AllOpKernels().find(type_);
if (kernel_iter == OperatorWithKernel::AllOpKernels().end()) {
return false;
} else {
auto& op_kernels = kernel_iter->second;
return std::any_of(
op_kernels.begin(),
op_kernels.end(),
[this](OpKernelMap::const_reference kern_pair) {
return platform::is_xpu_place(kern_pair.first.place_) &&
paddle::platform::is_xpu_support_op(type_,
kern_pair.first) &&
!paddle::platform::is_in_xpu_black_list(type_);
});
}
}
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"should not call OperatorWithKernel::SupportXPU() when not compiled with "
"XPU support."));
return false;
#endif
}
bool OperatorWithKernel::SupportsMKLDNN(
const proto::VarType::Type data_type) const {
auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap(
......@@ -1733,8 +1770,9 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
<< "Device index is only supported under pipeline parallelism, "
<< "so it will be ignored.";
}
// when the Op that only has CPUKernel is assigned to GPU, the CPUKernel
// will be executed and a warning will be given at the same time.
// when the Op that does not have GPUKernel is assigned to GPU, the
// CPUKernel will be executed and a warning will be given at the same
// time.
expected_kernel_key.place_ = platform::CPUPlace();
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
if (SupportGPU()) {
......@@ -1742,6 +1780,25 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
expected_kernel_key.place_ = dev_ctx.GetPlace();
}
#endif
if (platform::is_cpu_place(expected_kernel_key.place_)) {
LOG_FIRST_N(WARNING, 1)
<< "Op(" << type_
<< ") has no CUDA implementation. It will be assigned to CPUPlace.";
}
} else if (Attr<std::string>("op_device").find("npu") !=
std::string::npos) {
auto device = Attr<std::string>("op_device");
size_t pos = device.find(':');
if (pos != std::string::npos) {
device = device.substr(0, pos);
LOG_FIRST_N(WARNING, 1)
<< "Device index is only supported under pipeline parallelism, "
<< "so it will be ignored.";
}
// when the Op that does not have NPUKernel is assigned to NPU, the
// CPUKernel will be executed and a warning will be given at the same
// time.
expected_kernel_key.place_ = platform::CPUPlace();
#ifdef PADDLE_WITH_ASCEND_CL
if (SupportNPU()) {
auto& dev_ctx = ctx.device_context();
......@@ -1751,7 +1808,32 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType(
if (platform::is_cpu_place(expected_kernel_key.place_)) {
LOG_FIRST_N(WARNING, 1)
<< "Op(" << type_
<< ") has no CUDA implementation. It will be assigned to CPUPlace.";
<< ") has no NPU implementation. It will be assigned to CPUPlace.";
}
} else if (Attr<std::string>("op_device").find("xpu") !=
std::string::npos) {
auto device = Attr<std::string>("op_device");
size_t pos = device.find(':');
if (pos != std::string::npos) {
device = device.substr(0, pos);
LOG_FIRST_N(WARNING, 1)
<< "Device index is only supported under pipeline parallelism, "
<< "so it will be ignored.";
}
// when the Op that does not have XPUKernel is assigned to XPU, the
// CPUKernel will be executed and a warning will be given at the same
// time.
expected_kernel_key.place_ = platform::CPUPlace();
#ifdef PADDLE_WITH_XPU
if (SupportXPU()) {
auto& dev_ctx = ctx.device_context();
expected_kernel_key.place_ = dev_ctx.GetPlace();
}
#endif
if (platform::is_cpu_place(expected_kernel_key.place_)) {
LOG_FIRST_N(WARNING, 1)
<< "Op(" << type_
<< ") has no XPU implementation. It will be assigned to CPUPlace.";
}
}
}
......
......@@ -173,6 +173,7 @@ class OperatorBase {
virtual bool SupportGPU() const { return false; }
virtual bool SupportNPU() const { return false; }
virtual bool SupportMLU() const { return false; }
virtual bool SupportXPU() const { return false; }
const std::string& Type() const { return type_; }
......@@ -596,6 +597,9 @@ class OperatorWithKernel : public OperatorBase {
return platform::is_mlu_place(kern_pair.first.place_);
});
}
bool SupportXPU() const override;
bool SupportsMKLDNN(proto::VarType::Type data_type) const;
bool SupportsKernelType(const OpKernelType& kernel_type) const;
......
......@@ -225,16 +225,17 @@ bool GetCondData(const framework::LoDTensor &cond) {
return cond.data<bool>()[0];
}
// when platform::is_gpu_place(cond.place()) or
// platform::is_npu_place(cond.place()) is true
// platform::is_npu_place(cond.place()) or
// platform::is_xpu_place(cond.place()) is true
std::unique_ptr<framework::LoDTensor> cpu_cond{new framework::LoDTensor()};
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \
defined(PADDLE_WITH_ASCEND_CL)
defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU)
framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get());
#else
PADDLE_THROW(platform::errors::PreconditionNotMet(
"This version of PaddlePaddle does NOT support GPU/NPU but got GPU/NPU "
"tensor "
"Cond in WhileOp. Please compile WITH_GPU or WITH_ASCEND_CL option."));
"This version of PaddlePaddle does NOT support GPU/NPU/XPU but got "
"GPU/NPU/XPU tensor Cond in WhileOp. Please compile WITH_GPU or "
"WITH_ASCEND_CL or WITH_XPU option."));
#endif
return cpu_cond->data<bool>()[0];
}
......
......@@ -14,6 +14,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/operators/sum_op.h"
#include "paddle/fluid/platform/device/device_wrapper.h"
#include "paddle/fluid/platform/device/xpu/xpu_header.h"
namespace paddle {
......@@ -28,38 +29,93 @@ class SumXPUKernel : public framework::OpKernel<T> {
void Compute(const framework::ExecutionContext &context) const override {
auto in_vars = context.MultiInputVar("X");
auto out_var = context.OutputVar("Out");
auto *out = context.Output<LoDTensor>("Out");
bool in_place = out_var == in_vars[0];
int N = in_vars.size();
PADDLE_ENFORCE_EQ(
out_var->IsType<framework::LoDTensor>(),
true,
platform::errors::InvalidArgument("XPU only surpport LodTensor"));
if (!in_place) {
out->mutable_data<T>(context.GetPlace());
}
auto &dev_ctx = context.template device_context<DeviceContext>();
std::vector<const XPUType *> ptrs;
for (int i = 0; i < N; ++i) {
PADDLE_ENFORCE_EQ(
in_vars[i]->IsType<framework::LoDTensor>(),
true,
platform::errors::InvalidArgument("XPU only surpport LodTensor"));
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
if (in_t.numel() == 0) {
continue;
if (out_var->IsType<framework::LoDTensor>()) {
auto *out = context.Output<LoDTensor>("Out");
bool in_place = out_var == in_vars[0];
int N = in_vars.size();
if (!in_place) {
out->mutable_data<T>(context.GetPlace());
}
auto &dev_ctx = context.template device_context<DeviceContext>();
std::vector<const XPUType *> ptrs;
for (int i = 0; i < N; ++i) {
PADDLE_ENFORCE_EQ(
in_vars[i]->IsType<framework::LoDTensor>(),
true,
platform::errors::InvalidArgument("XPU only support LodTensor"));
auto &in_t = in_vars[i]->Get<framework::LoDTensor>();
if (in_t.numel() == 0) {
continue;
}
ptrs.push_back(reinterpret_cast<const XPUType *>(in_t.data<T>()));
}
int r = xpu::sum(dev_ctx.x_context(),
ptrs,
reinterpret_cast<XPUType *>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sum");
} else if (out_var->IsType<framework::LoDTensorArray>()) {
bool in_place = out_var == in_vars[0];
auto &out_array = *out_var->GetMutable<framework::LoDTensorArray>();
for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) {
PADDLE_ENFORCE_EQ(in_vars[i]->IsType<framework::LoDTensorArray>(),
true,
platform::errors::InvalidArgument(
"Only support all inputs are TensorArray, "
"but inputs[%d] is not TensorArray.",
i));
auto &in_array = in_vars[i]->Get<framework::LoDTensorArray>();
for (size_t i = 0; i < in_array.size(); ++i) {
if (in_array[i].IsInitialized() && (in_array[i].numel() != 0)) {
if (i >= out_array.size()) {
out_array.resize(i + 1);
}
if (!out_array[i].IsInitialized() || (out_array[i].numel() == 0)) {
framework::TensorCopy(in_array[i],
in_array[i].place(),
context.device_context(),
&out_array[i]);
out_array[i].set_lod(in_array[i].lod());
} else {
PADDLE_ENFORCE_EQ(
out_array[i].lod(),
in_array[i].lod(),
platform::errors::InvalidArgument(
"The lod message between inputs[%d] and"
" outputs[%d] must be same, but now is not same.",
i,
i));
std::vector<const XPUType *> ptrs;
ptrs.push_back(
reinterpret_cast<const XPUType *>(in_array[i].data<T>()));
ptrs.push_back(
reinterpret_cast<const XPUType *>(out_array[i].data<T>()));
auto &dev_ctx = context.template device_context<DeviceContext>();
// int sum(Context* ctx, const std::vector<const T*>& x_list, T*
// y, int len);
int r =
xpu::sum(dev_ctx.x_context(),
ptrs,
reinterpret_cast<XPUType *>(out_array[i].data<T>()),
out_array[i].numel());
PADDLE_ENFORCE_XDNN_SUCCESS(r, "sum");
}
}
}
}
ptrs.push_back(reinterpret_cast<const XPUType *>(in_t.data<T>()));
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Expected type of Output(out) must be Tensor or "
"LoDTensorArray. But got "
"unsupport type: %s.",
framework::ToTypeName(out_var->Type())));
}
int r = xpu::sum(dev_ctx.x_context(),
ptrs,
reinterpret_cast<XPUType *>(out->data<T>()),
out->numel());
PADDLE_ENFORCE_EQ(
r,
XPU_SUCCESS,
platform::errors::External(
"XPU sum kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r]));
}
};
......
......@@ -7078,9 +7078,9 @@ def device_guard(device=None):
device, index = device.split(':')
if device == 'cpu':
raise ValueError("Should not set device id for cpu.")
if device not in ['cpu', 'gpu', 'npu', '', None]:
if device not in ['cpu', 'gpu', 'npu', 'xpu', '', None]:
raise ValueError(
"The Attr(device) should be 'cpu' 'npu' or 'gpu', and it can also be empty string or None "
"The Attr(device) should be 'cpu' 'npu' 'xpu' or 'gpu', and it can also be empty string or None "
"when there is no need to specify device. But received %s" % device)
if index:
device = ":".join([device, index])
......
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import sys
sys.path.append("..")
from op_test import OpTest
import numpy as np
import paddle
import paddle.fluid as fluid
import paddle.fluid.core as core
import warnings
paddle.enable_static()
def execute(main_program, startup_program):
if paddle.is_compiled_with_xpu():
place = paddle.XPUPlace(0)
else:
place = paddle.CPUPlace()
exe = paddle.static.Executor(place)
exe.run(startup_program)
exe.run(main_program)
def get_vaild_warning_num(warning, w):
num = 0
for i in range(len(w)):
if warning in str(w[i].message):
num += 1
return num
class TestDeviceGuard(unittest.TestCase):
def test_device_guard(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
data1 = paddle.full(shape=[1, 3, 8, 8],
fill_value=0.5,
dtype='float32')
data2 = paddle.full(shape=[1, 3, 5, 5],
fill_value=0.5,
dtype='float32')
shape = paddle.shape(data2)
with paddle.static.device_guard("cpu"):
shape = paddle.slice(shape, axes=[0], starts=[0], ends=[4])
with paddle.static.device_guard("xpu"):
out = fluid.layers.crop_tensor(data1, shape=shape)
# check if the device attr is set correctly
all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops:
if op.type == 'slice':
self.assertEqual(op.desc.attr(device_attr_name), "cpu")
if op.type == 'crop_tensor':
self.assertEqual(op.desc.attr(device_attr_name), "xpu")
execute(main_program, startup_program)
def test_device_guard_with_id(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
data1 = paddle.full(shape=[1, 3, 8, 8],
fill_value=0.5,
dtype='float32')
data2 = paddle.full(shape=[1, 3, 5, 5],
fill_value=0.5,
dtype='float32')
shape = paddle.shape(data2)
with paddle.static.device_guard("cpu"):
shape = paddle.slice(shape, axes=[0], starts=[0], ends=[4])
with paddle.static.device_guard("xpu:1"):
out = fluid.layers.crop_tensor(data1, shape=shape)
# check if the device attr is set correctly
all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops:
if op.type == 'slice':
self.assertEqual(op.desc.attr(device_attr_name), "cpu")
if op.type == 'crop_tensor':
self.assertEqual(op.desc.attr(device_attr_name), "xpu:1")
execute(main_program, startup_program)
def test_cpu_only_op(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
x = paddle.full(shape=[2, 255, 13, 13],
fill_value=0.3,
dtype='float32')
gt_box = paddle.full(shape=[2, 6, 4],
fill_value=0.5,
dtype='float32')
gt_label = paddle.full(shape=[2, 6], fill_value=1.0, dtype='int32')
gt_score = paddle.full(shape=[2, 6],
fill_value=0.5,
dtype='float32')
anchors = [
10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156,
198, 373, 326
]
anchor_mask = [0, 1, 2]
with paddle.static.device_guard("xpu"):
# yolov3_loss only has cpu kernel, so its cpu kernel will be executed
loss = fluid.layers.yolov3_loss(x=x,
gt_box=gt_box,
gt_label=gt_label,
gt_score=gt_score,
anchors=anchors,
anchor_mask=anchor_mask,
class_num=80,
ignore_thresh=0.7,
downsample_ratio=32)
execute(main_program, startup_program)
def test_without_kernel_op(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
i = paddle.full(shape=[1], dtype='int64', fill_value=0)
loop_len = paddle.full(shape=[1], dtype='int64', fill_value=10)
cond = paddle.less_than(x=i, y=loop_len)
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
with paddle.static.device_guard("cpu"):
while_op = fluid.layers.While(cond=cond)
with while_op.block():
i = paddle.increment(x=i, value=1)
fluid.layers.less_than(x=i, y=loop_len, cond=cond)
warning = "The Op(while) is not support to set device."
warning_num = get_vaild_warning_num(warning, w)
assert warning_num == 1
all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops:
if op.type == 'while':
self.assertEqual(op.desc.attr(device_attr_name), "")
execute(main_program, startup_program)
def test_error(self):
def device_attr():
with paddle.static.device_guard("cpu1"):
out = paddle.full(shape=[1], fill_value=0.2, dtype='float32')
def device_attr2():
with paddle.static.device_guard("cpu:1"):
out = paddle.full(shape=[1], fill_value=0.2, dtype='float32')
self.assertRaises(ValueError, device_attr)
self.assertRaises(ValueError, device_attr2)
# check if op_descs have op_device attr
def test_op_descs_device_attr(self):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
data1 = paddle.static.data(name="data_1",
shape=[4, 2],
dtype="float32")
label = paddle.static.data(name="label",
shape=[4, 1],
dtype="int64")
fc1 = paddle.static.nn.fc(x=data1, size=10)
fc2 = paddle.static.nn.fc(x=fc1, size=10)
with paddle.static.device_guard("xpu"):
out = paddle.nn.functional.softmax_with_cross_entropy(
logits=fc1 + fc2, label=label)
loss = paddle.mean(out)
opt = paddle.optimizer.SGD(0.1)
opt.minimize(loss)
all_ops = main_program.global_block().ops
device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName()
for op in all_ops:
self.assertEqual(True, op.desc.has_attr(device_attr_name))
# fill_constant(backward op) is append to mean op, which should have
# the same op_device value as mean op
if op.desc == 'fill_constant':
self.assertEqual(op.desc.attr(device_attr_name), "xpu")
if __name__ == '__main__':
unittest.main()
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import unittest
import paddle
import paddle.fluid.layers as layers
from paddle.fluid.executor import Executor
import paddle.fluid.core as core
import paddle.fluid as fluid
from paddle.fluid.backward import append_backward
import numpy
from paddle.fluid import compiler, Program, program_guard
paddle.enable_static()
class TestWhileOp(unittest.TestCase):
def simple_net(self):
d0 = layers.data("d0",
shape=[10],
append_batch_size=False,
dtype='float32')
d1 = layers.data("d1",
shape=[10],
append_batch_size=False,
dtype='float32')
d2 = layers.data("d2",
shape=[10],
append_batch_size=False,
dtype='float32')
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
init = layers.zeros(shape=[10], dtype='float32')
mem_array = layers.array_write(x=init, i=i)
data_array = layers.array_write(x=d0, i=i)
i = layers.increment(i)
layers.array_write(d1, i, array=data_array)
i = layers.increment(i)
layers.array_write(d2, i, array=data_array)
i = layers.zeros(shape=[1], dtype='int64')
i.stop_gradient = True
array_len = layers.fill_constant(shape=[1], dtype='int64', value=1)
array_len.stop_gradient = True
cond = layers.less_than(x=i, y=array_len)
j = layers.fill_constant(shape=[1], dtype='int64', value=1)
j.stop_gradient = True
array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3)
array_len2.stop_gradient = True
cond2 = layers.less_than(x=j, y=array_len2)
while_op = layers.While(cond=cond)
while_op2 = layers.While(cond=cond2)
with while_op.block():
d = layers.array_read(array=data_array, i=i)
prev = layers.array_read(array=mem_array, i=i)
result = layers.sums(input=[d, prev])
i = layers.increment(x=i, in_place=True)
layers.array_write(result, i=i, array=mem_array)
layers.less_than(x=i, y=array_len, cond=cond)
with while_op2.block():
d2 = layers.array_read(array=data_array, i=j)
prev2 = layers.array_read(array=mem_array, i=j)
result2 = layers.sums(input=[d2, prev2])
j = layers.increment(x=j, in_place=True)
layers.array_write(result2, i=j, array=mem_array)
layers.less_than(x=j, y=array_len2, cond=cond2)
sum_result = layers.array_read(array=mem_array, i=j)
loss = paddle.mean(sum_result)
return loss, sum_result
def test_simple_net(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
loss, sum_result = self.simple_net()
append_backward(loss)
xpu_place = paddle.XPUPlace(0)
exe = Executor(xpu_place)
d = []
for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))
outs = exe.run(feed={
'd0': d[0],
'd1': d[1],
'd2': d[2]
},
fetch_list=[sum_result])
self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01)
def test_simple_net_forward(self):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
self.simple_net()
binary = fluid.compiler.CompiledProgram(main_program)
xpu_place = paddle.XPUPlace(0)
exe = Executor(xpu_place)
d = []
for i in range(3):
d.append(numpy.random.random(size=[10]).astype('float32'))
for _ in range(2):
exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]})
def test_exceptions(self):
i = layers.zeros(shape=[2], dtype='int64')
array_len = layers.fill_constant(shape=[2], dtype='int64', value=1)
cond = layers.less_than(x=i, y=array_len)
with self.assertRaises(TypeError):
layers.While(cond=cond)
cond = layers.cast(cond, dtype='float64')
with self.assertRaises(TypeError):
layers.While(cond=cond)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册