From 8753a2bf6dd6df80ef08ef59fa42e3341b35f9d5 Mon Sep 17 00:00:00 2001 From: houj04 <35131887+houj04@users.noreply.github.com> Date: Wed, 20 Jul 2022 11:47:11 +0800 Subject: [PATCH] [XPU][NPU] (1) add device_guard. (2) add support for LoDTensorArray of sum op. (#44367) * device_guard support xpu. test=kunlun * sum op of xpu support LoDTensorArray. add test for while op of xpu. test=kunlun. --- .../new_executor/interpretercore_util.cc | 40 +++- paddle/fluid/framework/operator.cc | 88 +++++++- paddle/fluid/framework/operator.h | 4 + .../operators/controlflow/while_op_helper.cc | 11 +- paddle/fluid/operators/sum_op_xpu.cc | 116 +++++++--- python/paddle/fluid/framework.py | 4 +- .../unittests/xpu/test_device_guard_xpu.py | 209 ++++++++++++++++++ .../tests/unittests/xpu/test_while_op_xpu.py | 139 ++++++++++++ 8 files changed, 565 insertions(+), 46 deletions(-) create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_device_guard_xpu.py create mode 100644 python/paddle/fluid/tests/unittests/xpu/test_while_op_xpu.py diff --git a/paddle/fluid/framework/new_executor/interpretercore_util.cc b/paddle/fluid/framework/new_executor/interpretercore_util.cc index 35f02189ca9..b58a74a659c 100644 --- a/paddle/fluid/framework/new_executor/interpretercore_util.cc +++ b/paddle/fluid/framework/new_executor/interpretercore_util.cc @@ -330,14 +330,12 @@ void apply_device_guard(const OperatorBase* op_base, VLOG(3) << "Switch into CPUPlace by device_guard."; expected_kernel_key->place_ = platform::CPUPlace(); } else if (op_device.find("gpu") != std::string::npos && - (platform::is_gpu_place(place) || - platform::is_npu_place(place))) { - // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel - // will be executed and a warning will be given at the same time. + platform::is_gpu_place(place)) { + // when the Op that does not have GPUKernel is assigned to GPU, the + // CPUKernel will be executed and a warning will be given at the same + // time. if (op_base->SupportGPU()) { expected_kernel_key->place_ = place; - } else if (op_base->SupportNPU()) { - expected_kernel_key->place_ = place; } else { expected_kernel_key->place_ = platform::CPUPlace(); LOG_FIRST_N(WARNING, 1) @@ -346,6 +344,36 @@ void apply_device_guard(const OperatorBase* op_base, } VLOG(3) << "Switch into " << expected_kernel_key->place_ << " by device_guard."; + } else if (op_device.find("npu") != std::string::npos && + platform::is_npu_place(place)) { + // when the Op that does not have NPUKernel is assigned to NPU, the + // CPUKernel will be executed and a warning will be given at the same + // time. + if (op_base->SupportNPU()) { + expected_kernel_key->place_ = place; + } else { + expected_kernel_key->place_ = platform::CPUPlace(); + LOG_FIRST_N(WARNING, 1) + << "Op(" << op_base->Type() + << ") has no NPU implementation. It will be assigned to CPUPlace."; + } + VLOG(3) << "Switch into " << expected_kernel_key->place_ + << " by device_guard."; + } else if (op_device.find("xpu") != std::string::npos && + platform::is_xpu_place(place)) { + // when the Op that does not have XPUKernel is assigned to XPU, the + // CPUKernel will be executed and a warning will be given at the same + // time. + if (op_base->SupportXPU()) { + expected_kernel_key->place_ = place; + } else { + expected_kernel_key->place_ = platform::CPUPlace(); + LOG_FIRST_N(WARNING, 1) + << "Op(" << op_base->Type() + << ") has no XPU implementation. It will be assigned to CPUPlace."; + } + VLOG(3) << "Switch into " << expected_kernel_key->place_ + << " by device_guard."; } else { PADDLE_THROW( platform::errors::Fatal("Unsupported current place %s", op_device)); diff --git a/paddle/fluid/framework/operator.cc b/paddle/fluid/framework/operator.cc index 0632af59bfd..fb8dbe603c3 100644 --- a/paddle/fluid/framework/operator.cc +++ b/paddle/fluid/framework/operator.cc @@ -1274,6 +1274,43 @@ bool OperatorWithKernel::SupportNPU() const { } } +bool OperatorWithKernel::SupportXPU() const { +#ifdef PADDLE_WITH_XPU + auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( + phi::TransToPhiKernelName(type_)); + auto has_phi_kernel = + std::any_of(phi_kernels.begin(), + phi_kernels.end(), + [](phi::KernelKeyMap::const_reference kern_pair) { + return kern_pair.first.backend() == phi::Backend::XPU; + }); + if (has_phi_kernel) { + return true; + } else { + auto kernel_iter = OperatorWithKernel::AllOpKernels().find(type_); + if (kernel_iter == OperatorWithKernel::AllOpKernels().end()) { + return false; + } else { + auto& op_kernels = kernel_iter->second; + return std::any_of( + op_kernels.begin(), + op_kernels.end(), + [this](OpKernelMap::const_reference kern_pair) { + return platform::is_xpu_place(kern_pair.first.place_) && + paddle::platform::is_xpu_support_op(type_, + kern_pair.first) && + !paddle::platform::is_in_xpu_black_list(type_); + }); + } + } +#else + PADDLE_THROW(platform::errors::PreconditionNotMet( + "should not call OperatorWithKernel::SupportXPU() when not compiled with " + "XPU support.")); + return false; +#endif +} + bool OperatorWithKernel::SupportsMKLDNN( const proto::VarType::Type data_type) const { auto phi_kernels = phi::KernelFactory::Instance().SelectKernelMap( @@ -1733,8 +1770,9 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( << "Device index is only supported under pipeline parallelism, " << "so it will be ignored."; } - // when the Op that only has CPUKernel is assigned to GPU, the CPUKernel - // will be executed and a warning will be given at the same time. + // when the Op that does not have GPUKernel is assigned to GPU, the + // CPUKernel will be executed and a warning will be given at the same + // time. expected_kernel_key.place_ = platform::CPUPlace(); #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) if (SupportGPU()) { @@ -1742,6 +1780,25 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( expected_kernel_key.place_ = dev_ctx.GetPlace(); } #endif + if (platform::is_cpu_place(expected_kernel_key.place_)) { + LOG_FIRST_N(WARNING, 1) + << "Op(" << type_ + << ") has no CUDA implementation. It will be assigned to CPUPlace."; + } + } else if (Attr("op_device").find("npu") != + std::string::npos) { + auto device = Attr("op_device"); + size_t pos = device.find(':'); + if (pos != std::string::npos) { + device = device.substr(0, pos); + LOG_FIRST_N(WARNING, 1) + << "Device index is only supported under pipeline parallelism, " + << "so it will be ignored."; + } + // when the Op that does not have NPUKernel is assigned to NPU, the + // CPUKernel will be executed and a warning will be given at the same + // time. + expected_kernel_key.place_ = platform::CPUPlace(); #ifdef PADDLE_WITH_ASCEND_CL if (SupportNPU()) { auto& dev_ctx = ctx.device_context(); @@ -1751,7 +1808,32 @@ OpKernelType OperatorWithKernel::InnerGetExpectedKernelType( if (platform::is_cpu_place(expected_kernel_key.place_)) { LOG_FIRST_N(WARNING, 1) << "Op(" << type_ - << ") has no CUDA implementation. It will be assigned to CPUPlace."; + << ") has no NPU implementation. It will be assigned to CPUPlace."; + } + } else if (Attr("op_device").find("xpu") != + std::string::npos) { + auto device = Attr("op_device"); + size_t pos = device.find(':'); + if (pos != std::string::npos) { + device = device.substr(0, pos); + LOG_FIRST_N(WARNING, 1) + << "Device index is only supported under pipeline parallelism, " + << "so it will be ignored."; + } + // when the Op that does not have XPUKernel is assigned to XPU, the + // CPUKernel will be executed and a warning will be given at the same + // time. + expected_kernel_key.place_ = platform::CPUPlace(); +#ifdef PADDLE_WITH_XPU + if (SupportXPU()) { + auto& dev_ctx = ctx.device_context(); + expected_kernel_key.place_ = dev_ctx.GetPlace(); + } +#endif + if (platform::is_cpu_place(expected_kernel_key.place_)) { + LOG_FIRST_N(WARNING, 1) + << "Op(" << type_ + << ") has no XPU implementation. It will be assigned to CPUPlace."; } } } diff --git a/paddle/fluid/framework/operator.h b/paddle/fluid/framework/operator.h index be3259c7b1d..2568a459f31 100644 --- a/paddle/fluid/framework/operator.h +++ b/paddle/fluid/framework/operator.h @@ -173,6 +173,7 @@ class OperatorBase { virtual bool SupportGPU() const { return false; } virtual bool SupportNPU() const { return false; } virtual bool SupportMLU() const { return false; } + virtual bool SupportXPU() const { return false; } const std::string& Type() const { return type_; } @@ -596,6 +597,9 @@ class OperatorWithKernel : public OperatorBase { return platform::is_mlu_place(kern_pair.first.place_); }); } + + bool SupportXPU() const override; + bool SupportsMKLDNN(proto::VarType::Type data_type) const; bool SupportsKernelType(const OpKernelType& kernel_type) const; diff --git a/paddle/fluid/operators/controlflow/while_op_helper.cc b/paddle/fluid/operators/controlflow/while_op_helper.cc index a671273eae4..a4c8c23438b 100644 --- a/paddle/fluid/operators/controlflow/while_op_helper.cc +++ b/paddle/fluid/operators/controlflow/while_op_helper.cc @@ -225,16 +225,17 @@ bool GetCondData(const framework::LoDTensor &cond) { return cond.data()[0]; } // when platform::is_gpu_place(cond.place()) or - // platform::is_npu_place(cond.place()) is true + // platform::is_npu_place(cond.place()) or + // platform::is_xpu_place(cond.place()) is true std::unique_ptr cpu_cond{new framework::LoDTensor()}; #if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP) || \ - defined(PADDLE_WITH_ASCEND_CL) + defined(PADDLE_WITH_ASCEND_CL) || defined(PADDLE_WITH_XPU) framework::TensorCopySync(cond, platform::CPUPlace(), cpu_cond.get()); #else PADDLE_THROW(platform::errors::PreconditionNotMet( - "This version of PaddlePaddle does NOT support GPU/NPU but got GPU/NPU " - "tensor " - "Cond in WhileOp. Please compile WITH_GPU or WITH_ASCEND_CL option.")); + "This version of PaddlePaddle does NOT support GPU/NPU/XPU but got " + "GPU/NPU/XPU tensor Cond in WhileOp. Please compile WITH_GPU or " + "WITH_ASCEND_CL or WITH_XPU option.")); #endif return cpu_cond->data()[0]; } diff --git a/paddle/fluid/operators/sum_op_xpu.cc b/paddle/fluid/operators/sum_op_xpu.cc index b3df08095fa..b73677b59ce 100644 --- a/paddle/fluid/operators/sum_op_xpu.cc +++ b/paddle/fluid/operators/sum_op_xpu.cc @@ -14,6 +14,7 @@ limitations under the License. */ #include #include "paddle/fluid/operators/sum_op.h" +#include "paddle/fluid/platform/device/device_wrapper.h" #include "paddle/fluid/platform/device/xpu/xpu_header.h" namespace paddle { @@ -28,38 +29,93 @@ class SumXPUKernel : public framework::OpKernel { void Compute(const framework::ExecutionContext &context) const override { auto in_vars = context.MultiInputVar("X"); auto out_var = context.OutputVar("Out"); - auto *out = context.Output("Out"); - bool in_place = out_var == in_vars[0]; - int N = in_vars.size(); - PADDLE_ENFORCE_EQ( - out_var->IsType(), - true, - platform::errors::InvalidArgument("XPU only surpport LodTensor")); - if (!in_place) { - out->mutable_data(context.GetPlace()); - } - auto &dev_ctx = context.template device_context(); - std::vector ptrs; - for (int i = 0; i < N; ++i) { - PADDLE_ENFORCE_EQ( - in_vars[i]->IsType(), - true, - platform::errors::InvalidArgument("XPU only surpport LodTensor")); - auto &in_t = in_vars[i]->Get(); - if (in_t.numel() == 0) { - continue; + + if (out_var->IsType()) { + auto *out = context.Output("Out"); + bool in_place = out_var == in_vars[0]; + int N = in_vars.size(); + + if (!in_place) { + out->mutable_data(context.GetPlace()); + } + auto &dev_ctx = context.template device_context(); + std::vector ptrs; + for (int i = 0; i < N; ++i) { + PADDLE_ENFORCE_EQ( + in_vars[i]->IsType(), + true, + platform::errors::InvalidArgument("XPU only support LodTensor")); + auto &in_t = in_vars[i]->Get(); + if (in_t.numel() == 0) { + continue; + } + ptrs.push_back(reinterpret_cast(in_t.data())); + } + int r = xpu::sum(dev_ctx.x_context(), + ptrs, + reinterpret_cast(out->data()), + out->numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sum"); + } else if (out_var->IsType()) { + bool in_place = out_var == in_vars[0]; + auto &out_array = *out_var->GetMutable(); + + for (size_t i = in_place ? 1 : 0; i < in_vars.size(); ++i) { + PADDLE_ENFORCE_EQ(in_vars[i]->IsType(), + true, + platform::errors::InvalidArgument( + "Only support all inputs are TensorArray, " + "but inputs[%d] is not TensorArray.", + i)); + auto &in_array = in_vars[i]->Get(); + + for (size_t i = 0; i < in_array.size(); ++i) { + if (in_array[i].IsInitialized() && (in_array[i].numel() != 0)) { + if (i >= out_array.size()) { + out_array.resize(i + 1); + } + if (!out_array[i].IsInitialized() || (out_array[i].numel() == 0)) { + framework::TensorCopy(in_array[i], + in_array[i].place(), + context.device_context(), + &out_array[i]); + out_array[i].set_lod(in_array[i].lod()); + } else { + PADDLE_ENFORCE_EQ( + out_array[i].lod(), + in_array[i].lod(), + platform::errors::InvalidArgument( + "The lod message between inputs[%d] and" + " outputs[%d] must be same, but now is not same.", + i, + i)); + + std::vector ptrs; + ptrs.push_back( + reinterpret_cast(in_array[i].data())); + ptrs.push_back( + reinterpret_cast(out_array[i].data())); + + auto &dev_ctx = context.template device_context(); + // int sum(Context* ctx, const std::vector& x_list, T* + // y, int len); + int r = + xpu::sum(dev_ctx.x_context(), + ptrs, + reinterpret_cast(out_array[i].data()), + out_array[i].numel()); + PADDLE_ENFORCE_XDNN_SUCCESS(r, "sum"); + } + } + } } - ptrs.push_back(reinterpret_cast(in_t.data())); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "Expected type of Output(out) must be Tensor or " + "LoDTensorArray. But got " + "unsupport type: %s.", + framework::ToTypeName(out_var->Type()))); } - int r = xpu::sum(dev_ctx.x_context(), - ptrs, - reinterpret_cast(out->data()), - out->numel()); - PADDLE_ENFORCE_EQ( - r, - XPU_SUCCESS, - platform::errors::External( - "XPU sum kernel return wrong value[%d %s]", r, XPUAPIErrorMsg[r])); } }; diff --git a/python/paddle/fluid/framework.py b/python/paddle/fluid/framework.py index d6e4af58669..4ce4801d32b 100644 --- a/python/paddle/fluid/framework.py +++ b/python/paddle/fluid/framework.py @@ -7078,9 +7078,9 @@ def device_guard(device=None): device, index = device.split(':') if device == 'cpu': raise ValueError("Should not set device id for cpu.") - if device not in ['cpu', 'gpu', 'npu', '', None]: + if device not in ['cpu', 'gpu', 'npu', 'xpu', '', None]: raise ValueError( - "The Attr(device) should be 'cpu' 'npu' or 'gpu', and it can also be empty string or None " + "The Attr(device) should be 'cpu' 'npu' 'xpu' or 'gpu', and it can also be empty string or None " "when there is no need to specify device. But received %s" % device) if index: device = ":".join([device, index]) diff --git a/python/paddle/fluid/tests/unittests/xpu/test_device_guard_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_device_guard_xpu.py new file mode 100644 index 00000000000..06880f74817 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_device_guard_xpu.py @@ -0,0 +1,209 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import sys + +sys.path.append("..") +from op_test import OpTest + +import numpy as np +import paddle +import paddle.fluid as fluid +import paddle.fluid.core as core +import warnings + +paddle.enable_static() + + +def execute(main_program, startup_program): + if paddle.is_compiled_with_xpu(): + place = paddle.XPUPlace(0) + else: + place = paddle.CPUPlace() + exe = paddle.static.Executor(place) + exe.run(startup_program) + exe.run(main_program) + + +def get_vaild_warning_num(warning, w): + num = 0 + for i in range(len(w)): + if warning in str(w[i].message): + num += 1 + return num + + +class TestDeviceGuard(unittest.TestCase): + + def test_device_guard(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + data1 = paddle.full(shape=[1, 3, 8, 8], + fill_value=0.5, + dtype='float32') + data2 = paddle.full(shape=[1, 3, 5, 5], + fill_value=0.5, + dtype='float32') + shape = paddle.shape(data2) + with paddle.static.device_guard("cpu"): + shape = paddle.slice(shape, axes=[0], starts=[0], ends=[4]) + with paddle.static.device_guard("xpu"): + out = fluid.layers.crop_tensor(data1, shape=shape) + # check if the device attr is set correctly + all_ops = main_program.global_block().ops + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + for op in all_ops: + if op.type == 'slice': + self.assertEqual(op.desc.attr(device_attr_name), "cpu") + if op.type == 'crop_tensor': + self.assertEqual(op.desc.attr(device_attr_name), "xpu") + + execute(main_program, startup_program) + + def test_device_guard_with_id(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + data1 = paddle.full(shape=[1, 3, 8, 8], + fill_value=0.5, + dtype='float32') + data2 = paddle.full(shape=[1, 3, 5, 5], + fill_value=0.5, + dtype='float32') + shape = paddle.shape(data2) + with paddle.static.device_guard("cpu"): + shape = paddle.slice(shape, axes=[0], starts=[0], ends=[4]) + with paddle.static.device_guard("xpu:1"): + out = fluid.layers.crop_tensor(data1, shape=shape) + # check if the device attr is set correctly + all_ops = main_program.global_block().ops + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + for op in all_ops: + if op.type == 'slice': + self.assertEqual(op.desc.attr(device_attr_name), "cpu") + if op.type == 'crop_tensor': + self.assertEqual(op.desc.attr(device_attr_name), "xpu:1") + + execute(main_program, startup_program) + + def test_cpu_only_op(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + x = paddle.full(shape=[2, 255, 13, 13], + fill_value=0.3, + dtype='float32') + gt_box = paddle.full(shape=[2, 6, 4], + fill_value=0.5, + dtype='float32') + gt_label = paddle.full(shape=[2, 6], fill_value=1.0, dtype='int32') + gt_score = paddle.full(shape=[2, 6], + fill_value=0.5, + dtype='float32') + anchors = [ + 10, 13, 16, 30, 33, 23, 30, 61, 62, 45, 59, 119, 116, 90, 156, + 198, 373, 326 + ] + anchor_mask = [0, 1, 2] + with paddle.static.device_guard("xpu"): + # yolov3_loss only has cpu kernel, so its cpu kernel will be executed + loss = fluid.layers.yolov3_loss(x=x, + gt_box=gt_box, + gt_label=gt_label, + gt_score=gt_score, + anchors=anchors, + anchor_mask=anchor_mask, + class_num=80, + ignore_thresh=0.7, + downsample_ratio=32) + + execute(main_program, startup_program) + + def test_without_kernel_op(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + i = paddle.full(shape=[1], dtype='int64', fill_value=0) + loop_len = paddle.full(shape=[1], dtype='int64', fill_value=10) + cond = paddle.less_than(x=i, y=loop_len) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + with paddle.static.device_guard("cpu"): + while_op = fluid.layers.While(cond=cond) + with while_op.block(): + i = paddle.increment(x=i, value=1) + fluid.layers.less_than(x=i, y=loop_len, cond=cond) + + warning = "The Op(while) is not support to set device." + warning_num = get_vaild_warning_num(warning, w) + assert warning_num == 1 + + all_ops = main_program.global_block().ops + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + for op in all_ops: + if op.type == 'while': + self.assertEqual(op.desc.attr(device_attr_name), "") + + execute(main_program, startup_program) + + def test_error(self): + + def device_attr(): + with paddle.static.device_guard("cpu1"): + out = paddle.full(shape=[1], fill_value=0.2, dtype='float32') + + def device_attr2(): + with paddle.static.device_guard("cpu:1"): + out = paddle.full(shape=[1], fill_value=0.2, dtype='float32') + + self.assertRaises(ValueError, device_attr) + self.assertRaises(ValueError, device_attr2) + + # check if op_descs have op_device attr + def test_op_descs_device_attr(self): + main_program = paddle.static.Program() + startup_program = paddle.static.Program() + with paddle.static.program_guard(main_program, startup_program): + data1 = paddle.static.data(name="data_1", + shape=[4, 2], + dtype="float32") + label = paddle.static.data(name="label", + shape=[4, 1], + dtype="int64") + fc1 = paddle.static.nn.fc(x=data1, size=10) + fc2 = paddle.static.nn.fc(x=fc1, size=10) + with paddle.static.device_guard("xpu"): + out = paddle.nn.functional.softmax_with_cross_entropy( + logits=fc1 + fc2, label=label) + loss = paddle.mean(out) + opt = paddle.optimizer.SGD(0.1) + opt.minimize(loss) + + all_ops = main_program.global_block().ops + device_attr_name = core.op_proto_and_checker_maker.kOpDeviceAttrName() + for op in all_ops: + self.assertEqual(True, op.desc.has_attr(device_attr_name)) + # fill_constant(backward op) is append to mean op, which should have + # the same op_device value as mean op + if op.desc == 'fill_constant': + self.assertEqual(op.desc.attr(device_attr_name), "xpu") + + +if __name__ == '__main__': + unittest.main() diff --git a/python/paddle/fluid/tests/unittests/xpu/test_while_op_xpu.py b/python/paddle/fluid/tests/unittests/xpu/test_while_op_xpu.py new file mode 100644 index 00000000000..3265e11a574 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/xpu/test_while_op_xpu.py @@ -0,0 +1,139 @@ +# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import unittest +import paddle +import paddle.fluid.layers as layers +from paddle.fluid.executor import Executor +import paddle.fluid.core as core +import paddle.fluid as fluid +from paddle.fluid.backward import append_backward +import numpy +from paddle.fluid import compiler, Program, program_guard + +paddle.enable_static() + + +class TestWhileOp(unittest.TestCase): + + def simple_net(self): + d0 = layers.data("d0", + shape=[10], + append_batch_size=False, + dtype='float32') + d1 = layers.data("d1", + shape=[10], + append_batch_size=False, + dtype='float32') + d2 = layers.data("d2", + shape=[10], + append_batch_size=False, + dtype='float32') + i = layers.zeros(shape=[1], dtype='int64') + i.stop_gradient = True + init = layers.zeros(shape=[10], dtype='float32') + mem_array = layers.array_write(x=init, i=i) + data_array = layers.array_write(x=d0, i=i) + i = layers.increment(i) + layers.array_write(d1, i, array=data_array) + i = layers.increment(i) + layers.array_write(d2, i, array=data_array) + i = layers.zeros(shape=[1], dtype='int64') + i.stop_gradient = True + array_len = layers.fill_constant(shape=[1], dtype='int64', value=1) + array_len.stop_gradient = True + cond = layers.less_than(x=i, y=array_len) + j = layers.fill_constant(shape=[1], dtype='int64', value=1) + j.stop_gradient = True + array_len2 = layers.fill_constant(shape=[1], dtype='int64', value=3) + array_len2.stop_gradient = True + cond2 = layers.less_than(x=j, y=array_len2) + while_op = layers.While(cond=cond) + while_op2 = layers.While(cond=cond2) + with while_op.block(): + d = layers.array_read(array=data_array, i=i) + prev = layers.array_read(array=mem_array, i=i) + result = layers.sums(input=[d, prev]) + + i = layers.increment(x=i, in_place=True) + layers.array_write(result, i=i, array=mem_array) + layers.less_than(x=i, y=array_len, cond=cond) + + with while_op2.block(): + d2 = layers.array_read(array=data_array, i=j) + prev2 = layers.array_read(array=mem_array, i=j) + result2 = layers.sums(input=[d2, prev2]) + + j = layers.increment(x=j, in_place=True) + layers.array_write(result2, i=j, array=mem_array) + layers.less_than(x=j, y=array_len2, cond=cond2) + sum_result = layers.array_read(array=mem_array, i=j) + loss = paddle.mean(sum_result) + return loss, sum_result + + def test_simple_net(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + loss, sum_result = self.simple_net() + + append_backward(loss) + + xpu_place = paddle.XPUPlace(0) + exe = Executor(xpu_place) + d = [] + + for i in range(3): + d.append(numpy.random.random(size=[10]).astype('float32')) + + outs = exe.run(feed={ + 'd0': d[0], + 'd1': d[1], + 'd2': d[2] + }, + fetch_list=[sum_result]) + self.assertAlmostEqual(numpy.sum(d), numpy.sum(outs[0]), delta=0.01) + + def test_simple_net_forward(self): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + self.simple_net() + binary = fluid.compiler.CompiledProgram(main_program) + + xpu_place = paddle.XPUPlace(0) + exe = Executor(xpu_place) + d = [] + + for i in range(3): + d.append(numpy.random.random(size=[10]).astype('float32')) + + for _ in range(2): + exe.run(binary, feed={'d0': d[0], 'd1': d[1], 'd2': d[2]}) + + def test_exceptions(self): + i = layers.zeros(shape=[2], dtype='int64') + array_len = layers.fill_constant(shape=[2], dtype='int64', value=1) + cond = layers.less_than(x=i, y=array_len) + with self.assertRaises(TypeError): + layers.While(cond=cond) + cond = layers.cast(cond, dtype='float64') + with self.assertRaises(TypeError): + layers.While(cond=cond) + + +if __name__ == '__main__': + unittest.main() -- GitLab