From 8a1a2af82e0b2750bc332e0959492914ab16fbee Mon Sep 17 00:00:00 2001 From: Huihuang Zheng Date: Fri, 8 May 2020 14:39:01 +0800 Subject: [PATCH] Add Assert Op (#24280) 1. To make ProgramTranslator to support `assert` grammar, this PR adds `assert` python API and C++ code. 2. Fix a bug: graph_pattern_detector.h #include but didn't declared dependency at CMakeLists, which can cause single build failure. 3. Refactoring `Formatter` in print_op to make it reusable and reuse the formatter to print in assert op. --- paddle/fluid/framework/ir/CMakeLists.txt | 8 +- paddle/fluid/operators/CMakeLists.txt | 2 + paddle/fluid/operators/assert_op.cc | 108 ++++++++++++ paddle/fluid/operators/print_op.cc | 158 ++---------------- paddle/fluid/operators/tensor_formatter.cc | 154 +++++++++++++++++ paddle/fluid/operators/tensor_formatter.h | 55 ++++++ python/paddle/fluid/layers/control_flow.py | 76 ++++++++- .../fluid/tests/unittests/test_assert_op.py | 90 ++++++++++ 8 files changed, 500 insertions(+), 151 deletions(-) create mode 100644 paddle/fluid/operators/assert_op.cc create mode 100644 paddle/fluid/operators/tensor_formatter.cc create mode 100644 paddle/fluid/operators/tensor_formatter.h create mode 100644 python/paddle/fluid/tests/unittests/test_assert_op.py diff --git a/paddle/fluid/framework/ir/CMakeLists.txt b/paddle/fluid/framework/ir/CMakeLists.txt index 7c49bc1dcd0..5e6da1b3349 100644 --- a/paddle/fluid/framework/ir/CMakeLists.txt +++ b/paddle/fluid/framework/ir/CMakeLists.txt @@ -40,7 +40,13 @@ cc_library(graph SRCS graph.cc DEPS node pretty_log) cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(graph_traits SRCS graph_traits.cc DEPS graph) -cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits) + +SET(GRAPH_PATTERN_DETECTOR_DEPS graph graph_helper graph_traits) +if (WITH_TESTING) + SET(GRAPH_PATTERN_DETECTOR_DEPS ${GRAPH_PATTERN_DETECTOR_DEPS} gtest) +endif(WITH_TESTING) +cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS}) + cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor) cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) diff --git a/paddle/fluid/operators/CMakeLists.txt b/paddle/fluid/operators/CMakeLists.txt index 8ea27b86383..4e0f0af397c 100644 --- a/paddle/fluid/operators/CMakeLists.txt +++ b/paddle/fluid/operators/CMakeLists.txt @@ -92,6 +92,7 @@ if (WITH_GPU) endif() set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer) +set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter) # FIXME(typhoonzero): operator deps may not needed. # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) @@ -119,6 +120,7 @@ else() cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3) endif() +cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS}) if (WITH_PYTHON) cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind) endif() diff --git a/paddle/fluid/operators/assert_op.cc b/paddle/fluid/operators/assert_op.cc new file mode 100644 index 00000000000..da0e5fda636 --- /dev/null +++ b/paddle/fluid/operators/assert_op.cc @@ -0,0 +1,108 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/framework/var_type.h" +#include "paddle/fluid/operators/controlflow/while_op_helper.h" +#include "paddle/fluid/operators/tensor_formatter.h" + +const char kCond[] = "Cond"; +const char kData[] = "Data"; +const char kSummarize[] = "summarize"; + +namespace paddle { +namespace operators { + +using framework::LoDTensor; + +class AssertOp : public framework::OperatorBase { + public: + AssertOp(const std::string &type, const framework::VariableNameMap &inputs, + const framework::VariableNameMap &outputs, + const framework::AttributeMap &attrs) + : OperatorBase(type, inputs, outputs, attrs) {} + + private: + void RunImpl(const framework::Scope &scope, + const platform::Place &dev_place) const override { + const framework::Variable *cond_var_ptr = scope.FindVar(Input(kCond)); + PADDLE_ENFORCE_NOT_NULL(cond_var_ptr, + platform::errors::NotFound( + "Input(Condition) of AssertOp is not found.")); + const LoDTensor &cond = cond_var_ptr->Get(); + PADDLE_ENFORCE_EQ( + cond.dims(), paddle::framework::make_ddim({1}), + platform::errors::InvalidArgument( + "The numel of Input(Condition) of AssertOp must be 1. But now " + "the Condition's shape is %s.", + cond.dims().to_str())); + + bool cond_data = GetCondData(cond); + if (cond_data) { + return; + } + + TensorFormatter formatter; + formatter.SetSummarize(Attr(kSummarize)); + + const std::vector &x_names = Inputs(kData); + for (const std::string &name : x_names) { + const framework::Variable *x_var_ptr = scope.FindVar(name); + const framework::LoDTensor &x_tensor = x_var_ptr->Get(); + formatter.Print(x_tensor, name); + } + + PADDLE_THROW(platform::errors::InvalidArgument( + "The condition variable '%s' of AssertOp must be " + "true, but received false", + Input(kCond))); + } +}; + +class AssertOpProtoMaker : public framework::OpProtoAndCheckerMaker { + public: + void Make() override { + AddInput( + kCond, + "The boolean scalar condition tensor which is asserted to be true."); + AddInput(kData, + "The tensors to print when the assert condition is not true.") + .AsDuplicable(); + AddAttr( + kSummarize, + "The number of entries of each tensor to print when the " + "assert condition is not true. -1 means print all entries. If " + "the number of entries of a tensor is less then " + "summarize_num, this OP will print all entries of the tensor.") + .SetDefault(-1); + AddComment( + R"DOC(Assert the input Condition Tensor is true and print Tensors if the Condition Tensor is false.)DOC"); + } +}; + +class AssertOpInferShape : public framework::InferShapeBase { + public: + void operator()(framework::InferShapeContext *context) const override { + OP_INOUT_CHECK(context->HasInputs(kCond), "Input", "Condition", "AssertOp"); + } +}; + +} // namespace operators +} // namespace paddle + +namespace ops = paddle::operators; +REGISTER_OPERATOR( + assert, ops::AssertOp, ops::AssertOpProtoMaker, ops::AssertOpInferShape, + paddle::framework::EmptyGradOpMaker, + paddle::framework::EmptyGradOpMaker); diff --git a/paddle/fluid/operators/print_op.cc b/paddle/fluid/operators/print_op.cc index 238d8218a27..2afd16110d5 100644 --- a/paddle/fluid/operators/print_op.cc +++ b/paddle/fluid/operators/print_op.cc @@ -17,6 +17,7 @@ #include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/operators/assign_op.h" +#include "paddle/fluid/operators/tensor_formatter.h" namespace paddle { namespace operators { @@ -28,133 +29,6 @@ const char kForward[] = "FORWARD"; const char kBackward[] = "BACKWARD"; const char kBoth[] = "BOTH"; -class LogGuard { - public: - inline LogGuard() { LogMutex().lock(); } - - inline ~LogGuard() { LogMutex().unlock(); } - - private: - static std::mutex &LogMutex() { - static std::mutex mtx; - return mtx; - } -}; - -struct Formater { - std::string message; - std::string name; - std::string dims; - std::type_index dtype{typeid(const char)}; - std::string layout; - framework::LoD lod; - int summarize; - void *data{nullptr}; - platform::Place place; - std::stringstream logs; - - void operator()(size_t size) { - PrintName(); - PrintMessage(); - PrintLod(); - PrintPlace(); - PrintDims(); - PrintLayout(); - PrintDtype(); - PrintData(size); - LogGuard guard; - CLOG << logs.str(); - } - - private: - void PrintPlace() { logs << " - place: " << place << std::endl; } - void PrintMessage() { - if (!message.empty()) { - logs << " - message: " << message << std::endl; - } - } - void PrintName() { - if (!name.empty()) { - logs << "Variable: " << name << std::endl; - } - } - void PrintDims() { - if (!dims.empty()) { - logs << " - shape: " << dims << std::endl; - } - } - void PrintDtype() { - if (!framework::IsType(dtype)) { - logs << " - dtype: " << platform::demangle(dtype.name()) << std::endl; - } - } - void PrintLayout() { - if (!layout.empty()) { - logs << " - layout: " << layout << std::endl; - } - } - void PrintLod() { - if (!lod.empty()) { - logs << " - lod: {"; - for (auto level : lod) { - logs << "{"; - bool is_first = true; - for (auto i : level) { - if (is_first) { - logs << i; - is_first = false; - } else { - logs << ", " << i; - } - } - logs << "}"; - } - logs << "}" << std::endl; - } - } - - void PrintData(size_t size) { - PADDLE_ENFORCE_NOT_NULL(data); - // print float - if (framework::IsType(dtype)) { - Display(size); - } else if (framework::IsType(dtype)) { - Display(size); - } else if (framework::IsType(dtype)) { - Display(size); - } else if (framework::IsType(dtype)) { - Display(size); - } else if (framework::IsType(dtype)) { - Display(size); - } else { - logs << " - data: unprintable type: " << dtype.name() << std::endl; - } - } - - template - void Display(size_t size) { - auto *d = reinterpret_cast(data); - logs << " - data: ["; - if (summarize != -1) { - summarize = std::min(size, (size_t)summarize); - if (summarize > 0) { - logs << d[0]; - for (int i = 1; i < summarize; ++i) { - logs << " " << d[i]; - } - } - } else { - if (size > 0) { - logs << d[0]; - for (size_t i = 1; i < size; ++i) { - logs << " " << d[i]; - } - } - } - logs << "]" << std::endl; - } -}; - // TODO(ChunweiYan) there should be some other printers for TensorArray class PrintOp : public framework::OperatorBase { public: @@ -211,27 +85,15 @@ class PrintOp : public framework::OperatorBase { TensorCopy(in_tensor, place, &printed_tensor); } - Formater formater; - formater.place = place; - formater.message = Attr("message"); - if (Attr("print_tensor_name")) { - formater.name = printed_var_name; - } - if (Attr("print_tensor_type")) { - formater.dtype = framework::ToTypeIndex(printed_tensor.type()); - } - if (Attr("print_tensor_shape")) { - formater.dims = printed_tensor.dims().to_str(); - } - if (Attr("print_tensor_lod")) { - formater.lod = printed_tensor.lod(); - } - if (Attr("print_tensor_layout")) { - formater.layout = framework::DataLayoutToString(printed_tensor.layout()); - } - formater.summarize = Attr("summarize"); - formater.data = reinterpret_cast(printed_tensor.data()); - formater(printed_tensor.numel()); + TensorFormatter formatter; + const std::string &name = + Attr("print_tensor_name") ? printed_var_name : ""; + formatter.SetPrintTensorType(Attr("print_tensor_type")); + formatter.SetPrintTensorShape(Attr("print_tensor_shape")); + formatter.SetPrintTensorLod(Attr("print_tensor_lod")); + formatter.SetPrintTensorLayout(Attr("print_tensor_layout")); + formatter.SetSummarize(static_cast(Attr("summarize"))); + formatter.Print(printed_tensor, name, Attr("message")); } private: diff --git a/paddle/fluid/operators/tensor_formatter.cc b/paddle/fluid/operators/tensor_formatter.cc new file mode 100644 index 00000000000..7b8b484a11e --- /dev/null +++ b/paddle/fluid/operators/tensor_formatter.cc @@ -0,0 +1,154 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#include +#include + +#include "paddle/fluid/operators/tensor_formatter.h" + +namespace paddle { +namespace operators { + +void TensorFormatter::SetPrintTensorType(bool print_tensor_type) { + print_tensor_type_ = print_tensor_type; +} + +void TensorFormatter::SetPrintTensorShape(bool print_tensor_shape) { + print_tensor_shape_ = print_tensor_shape; +} + +void TensorFormatter::SetPrintTensorLod(bool print_tensor_lod) { + print_tensor_lod_ = print_tensor_lod; +} + +void TensorFormatter::SetPrintTensorLayout(bool print_tensor_layout) { + print_tensor_layout_ = print_tensor_layout; +} + +void TensorFormatter::SetSummarize(int64_t summarize) { + summarize_ = summarize; +} + +void TensorFormatter::Print(const framework::LoDTensor& print_tensor, + const std::string& tensor_name, + const std::string& message) { + static std::mutex mutex; + std::lock_guard lock(mutex); + std::cout << Format(print_tensor, tensor_name, message); +} + +std::string TensorFormatter::Format(const framework::LoDTensor& print_tensor, + const std::string& tensor_name, + const std::string& message) { + std::stringstream log_stream; + if (!tensor_name.empty()) { + log_stream << "Variable: " << tensor_name << std::endl; + } + + if (!message.empty()) { + log_stream << " - message: " << message << std::endl; + } + + if (print_tensor_lod_) { + log_stream << " - lod: {"; + const framework::LoD& lod = print_tensor.lod(); + for (auto level : lod) { + log_stream << "{"; + bool is_first = true; + for (auto i : level) { + if (is_first) { + log_stream << i; + is_first = false; + } else { + log_stream << ", " << i; + } + } + log_stream << "}"; + } + log_stream << "}" << std::endl; + } + + log_stream << " - place: " << print_tensor.place() << std::endl; + + if (print_tensor_shape_) { + log_stream << " - shape: " << print_tensor.dims().to_str() << std::endl; + } + + if (print_tensor_layout_) { + log_stream << " - layout: " + << framework::DataLayoutToString(print_tensor.layout()) + << std::endl; + } + + std::type_index dtype = framework::ToTypeIndex(print_tensor.type()); + if (print_tensor_type_) { + log_stream << " - dtype: " << platform::demangle(dtype.name()) + << std::endl; + } + + if (framework::IsType(dtype)) { + FormatData(print_tensor, log_stream); + } else if (framework::IsType(dtype)) { + FormatData(print_tensor, log_stream); + } else if (framework::IsType(dtype)) { + FormatData(print_tensor, log_stream); + } else if (framework::IsType(dtype)) { + FormatData(print_tensor, log_stream); + } else if (framework::IsType(dtype)) { + FormatData(print_tensor, log_stream); + } else { + log_stream << " - data: unprintable type: " << dtype.name() << std::endl; + } + return log_stream.str(); +} + +template +void TensorFormatter::FormatData(const framework::LoDTensor& print_tensor, + std::stringstream& log_stream) { + int64_t print_size = summarize_ == -1 + ? print_tensor.numel() + : std::min(summarize_, print_tensor.numel()); + const T* data = nullptr; + if (is_cpu_place(print_tensor.place())) { + data = print_tensor.data(); + } else { + framework::LoDTensor cpu_tensor; + platform::CPUPlace cpu_place; + TensorCopy(print_tensor, cpu_place, &cpu_tensor); + data = cpu_tensor.data(); + } + + log_stream << " - data: ["; + if (print_size > 0) { + log_stream << data[0]; + for (int64_t i = 1; i < print_size; ++i) { + log_stream << " " << data[i]; + } + } + log_stream << "]" << std::endl; +} + +template void TensorFormatter::FormatData( + const framework::LoDTensor& print_tensor, std::stringstream& log_stream); +template void TensorFormatter::FormatData( + const framework::LoDTensor& print_tensor, std::stringstream& log_stream); +template void TensorFormatter::FormatData( + const framework::LoDTensor& print_tensor, std::stringstream& log_stream); +template void TensorFormatter::FormatData( + const framework::LoDTensor& print_tensor, std::stringstream& log_stream); +template void TensorFormatter::FormatData( + const framework::LoDTensor& print_tensor, std::stringstream& log_stream); + +} // namespace operators +} // namespace paddle diff --git a/paddle/fluid/operators/tensor_formatter.h b/paddle/fluid/operators/tensor_formatter.h new file mode 100644 index 00000000000..1731348479d --- /dev/null +++ b/paddle/fluid/operators/tensor_formatter.h @@ -0,0 +1,55 @@ +/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. */ + +#pragma once +#include + +#include "paddle/fluid/framework/data_layout.h" +#include "paddle/fluid/framework/var_type.h" + +namespace paddle { +namespace operators { + +class TensorFormatter { + public: + TensorFormatter() {} + + std::string Format(const framework::LoDTensor& print_tensor, + const std::string& tensor_name = "", + const std::string& message = ""); + + void Print(const framework::LoDTensor& print_tensor, + const std::string& tensor_name = "", + const std::string& message = ""); + + void SetPrintTensorType(bool print_tensor_type); + void SetPrintTensorShape(bool print_tensor_shape); + void SetPrintTensorLod(bool print_tensor_lod); + void SetPrintTensorLayout(bool print_tensor_layout); + void SetSummarize(int64_t summarize); + + private: + template + void FormatData(const framework::LoDTensor& print_tensor, + std::stringstream& log_stream); + + int64_t summarize_ = -1; + bool print_tensor_type_ = true; + bool print_tensor_shape_ = true; + bool print_tensor_lod_ = true; + bool print_tensor_layout_ = true; +}; + +} // namespace operators +} // namespace paddle diff --git a/python/paddle/fluid/layers/control_flow.py b/python/paddle/fluid/layers/control_flow.py index e35576056e9..8b10d3a438e 100755 --- a/python/paddle/fluid/layers/control_flow.py +++ b/python/paddle/fluid/layers/control_flow.py @@ -34,8 +34,8 @@ __all__ = [ 'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than', 'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal', 'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN', - 'reorder_lod_tensor_by_rank', 'Print', 'is_empty', 'case', 'switch_case', - 'while_loop' + 'reorder_lod_tensor_by_rank', 'Print', 'Assert', 'is_empty', 'case', + 'switch_case', 'while_loop' ] @@ -300,6 +300,78 @@ def Print(input, return output +def Assert(cond, data=None, summarize=20, name=None): + ''' + This API creates an op that asserts the given condition is true. If the + condition is false, prints the tensors in data. ``summarize`` specifies the + number of the elements in the tensors to print. + + Args: + cond (Variable): The boolean condition tensor whose numel should be 1. + data (list|tuple, optional): list or tuple of tensors to print when + condition is not true. If it's ``None``, no tensor will be printed. + The default value is ``None``. + summarize (int, optional): Number of elements in the tensor to be + printed. If its value is -1, then all elements in the tensor will + be printed. The default value is 20. + name (str, optional): The default value is ``None`` . Normally users + don't have to set this parameter. For more information, please + refer to :ref:`api_guide_Name` . + + Returns: + Operator: the created operation. + + Raises: + TypeError: If ``cond`` is not boolean Variable. + TypeError: If ``data`` is not a list or tuple or ``None``. + TypeError: If ``summarize`` is not int. + TypeError: If ``name`` is not a string or ``None`` . + fluid.core.EnforceNotMet: If the condition is False in running time. + + Examples: + .. code-block:: python + + import paddle.fluid as fluid + import paddle.fluid.layers as layers + + x = layers.fill_constant(shape=[2, 3], dtype='float32', value=2.0) + condition = layers.reduce_max(x) < 1.0 # False + layers.Assert(condition, [x], 10, "example_assert_layer") + + exe = fluid.Executor() + try: + exe.run(fluid.default_main_program()) + # Print x and throws paddle.fluid.core.EnforceNotMet exception + # Example printed message for x: + # + # Variable: fill_constant_0.tmp_0 + # - lod: {} + # - place: CPUPlace() + # - shape: [2, 3] + # - layout: NCHW + # - dtype: float + # - data: [2 2 2 2 2 2] + except fluid.core.EnforceNotMet as e: + print("Assert Exception Example") + + ''' + check_variable_and_dtype(cond, "cond", ["bool"], "fluid.layers.Assert") + check_type(data, "data", (list, tuple, type(None)), "fluid.layers.Assert") + check_type(summarize, "summarize", int, "fluid.layers.Assert") + check_type(name, "name", (str, type(None)), "fluid.layers.Assert") + + layer_name = name if name else ('assert_' + cond.name) + helper = LayerHelper(layer_name, **locals()) + + op = helper.append_op( + type="assert", + inputs={"Cond": cond, + "Data": [] if data is None else list(data)}, + attrs={"summarize": summarize}) + + return op + + class BlockGuard(object): """ BlockGuard class. diff --git a/python/paddle/fluid/tests/unittests/test_assert_op.py b/python/paddle/fluid/tests/unittests/test_assert_op.py new file mode 100644 index 00000000000..47dbb1092c5 --- /dev/null +++ b/python/paddle/fluid/tests/unittests/test_assert_op.py @@ -0,0 +1,90 @@ +# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from __future__ import print_function + +import paddle.fluid as fluid +import paddle.fluid.layers as layers +import unittest + + +class TestAssertOp(unittest.TestCase): + def run_network(self, net_func): + main_program = fluid.Program() + startup_program = fluid.Program() + with fluid.program_guard(main_program, startup_program): + net_func() + exe = fluid.Executor() + exe.run(main_program) + + def test_assert_true(self): + def net_func(): + condition = layers.fill_constant( + shape=[1], dtype='bool', value=True) + layers.Assert(condition, []) + + self.run_network(net_func) + + def test_assert_false(self): + def net_func(): + condition = layers.fill_constant( + shape=[1], dtype='bool', value=False) + layers.Assert(condition) + + with self.assertRaises(fluid.core.EnforceNotMet): + self.run_network(net_func) + + def test_assert_cond_numel_error(self): + def net_func(): + condition = layers.fill_constant( + shape=[1, 2], dtype='bool', value=True) + layers.Assert(condition, []) + + with self.assertRaises(fluid.core.EnforceNotMet): + self.run_network(net_func) + + def test_assert_print_data(self): + def net_func(): + zero = layers.fill_constant(shape=[1], dtype='int64', value=0) + one = layers.fill_constant(shape=[1], dtype='int64', value=1) + condition = layers.less_than(one, zero) # False + layers.Assert(condition, [zero, one]) + + print("test_assert_print_data") + with self.assertRaises(fluid.core.EnforceNotMet): + self.run_network(net_func) + + def test_assert_summary(self): + def net_func(): + x = layers.fill_constant(shape=[10], dtype='float32', value=2.0) + condition = layers.reduce_max(x) < 1.0 + layers.Assert(condition, (x, ), 5) + + print("test_assert_summary") + with self.assertRaises(fluid.core.EnforceNotMet): + self.run_network(net_func) + + def test_assert_summary_greater_than_size(self): + def net_func(): + x = layers.fill_constant(shape=[2, 3], dtype='float32', value=2.0) + condition = layers.reduce_max(x) < 1.0 + layers.Assert(condition, [x], 10, name="test") + + print("test_assert_summary_greater_than_size") + with self.assertRaises(fluid.core.EnforceNotMet): + self.run_network(net_func) + + +if __name__ == '__main__': + unittest.main() -- GitLab