未验证 提交 8a1a2af8 编写于 作者: H Huihuang Zheng 提交者: GitHub

Add Assert Op (#24280)

1. To make ProgramTranslator to support `assert` grammar, this PR adds `assert` python API and C++ code. 

2. Fix a bug: graph_pattern_detector.h #include <gtest/gtest_prod.h> but didn't declared dependency at CMakeLists, which can cause single build failure.

3. Refactoring `Formatter` in print_op to make it reusable and reuse the formatter to print in assert op.
上级 8c296dea
...@@ -40,7 +40,13 @@ cc_library(graph SRCS graph.cc DEPS node pretty_log) ...@@ -40,7 +40,13 @@ cc_library(graph SRCS graph.cc DEPS node pretty_log)
cc_library(graph_helper SRCS graph_helper.cc DEPS graph) cc_library(graph_helper SRCS graph_helper.cc DEPS graph)
cc_library(pass SRCS pass.cc DEPS graph node graph_helper) cc_library(pass SRCS pass.cc DEPS graph node graph_helper)
cc_library(graph_traits SRCS graph_traits.cc DEPS graph) cc_library(graph_traits SRCS graph_traits.cc DEPS graph)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS graph graph_helper graph_traits)
SET(GRAPH_PATTERN_DETECTOR_DEPS graph graph_helper graph_traits)
if (WITH_TESTING)
SET(GRAPH_PATTERN_DETECTOR_DEPS ${GRAPH_PATTERN_DETECTOR_DEPS} gtest)
endif(WITH_TESTING)
cc_library(graph_pattern_detector SRCS graph_pattern_detector.cc DEPS ${GRAPH_PATTERN_DETECTOR_DEPS})
cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor) cc_library(subgraph_detector SRCS subgraph_detector.cc DEPS graph_pattern_detector executor)
cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass) cc_library(fuse_pass_base SRCS fuse_pass_base.cc DEPS pass)
cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass) cc_library(placement_pass_base SRCS placement_pass_base.cc DEPS pass)
......
...@@ -92,6 +92,7 @@ if (WITH_GPU) ...@@ -92,6 +92,7 @@ if (WITH_GPU)
endif() endif()
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} device_memory_aligment)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer) set(COMMON_OP_DEPS ${COMMON_OP_DEPS} layer)
set(COMMON_OP_DEPS ${COMMON_OP_DEPS} tensor_formatter)
# FIXME(typhoonzero): operator deps may not needed. # FIXME(typhoonzero): operator deps may not needed.
# op_library(lod_tensor_to_array_op DEPS lod_rank_table_op) # op_library(lod_tensor_to_array_op DEPS lod_rank_table_op)
...@@ -119,6 +120,7 @@ else() ...@@ -119,6 +120,7 @@ else()
cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3) cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
endif() endif()
cc_library(tensor_formatter SRCS tensor_formatter.cc DEPS ${OP_HEADER_DEPS})
if (WITH_PYTHON) if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind) cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
endif() endif()
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/controlflow/while_op_helper.h"
#include "paddle/fluid/operators/tensor_formatter.h"
const char kCond[] = "Cond";
const char kData[] = "Data";
const char kSummarize[] = "summarize";
namespace paddle {
namespace operators {
using framework::LoDTensor;
class AssertOp : public framework::OperatorBase {
public:
AssertOp(const std::string &type, const framework::VariableNameMap &inputs,
const framework::VariableNameMap &outputs,
const framework::AttributeMap &attrs)
: OperatorBase(type, inputs, outputs, attrs) {}
private:
void RunImpl(const framework::Scope &scope,
const platform::Place &dev_place) const override {
const framework::Variable *cond_var_ptr = scope.FindVar(Input(kCond));
PADDLE_ENFORCE_NOT_NULL(cond_var_ptr,
platform::errors::NotFound(
"Input(Condition) of AssertOp is not found."));
const LoDTensor &cond = cond_var_ptr->Get<LoDTensor>();
PADDLE_ENFORCE_EQ(
cond.dims(), paddle::framework::make_ddim({1}),
platform::errors::InvalidArgument(
"The numel of Input(Condition) of AssertOp must be 1. But now "
"the Condition's shape is %s.",
cond.dims().to_str()));
bool cond_data = GetCondData(cond);
if (cond_data) {
return;
}
TensorFormatter formatter;
formatter.SetSummarize(Attr<int64_t>(kSummarize));
const std::vector<std::string> &x_names = Inputs(kData);
for (const std::string &name : x_names) {
const framework::Variable *x_var_ptr = scope.FindVar(name);
const framework::LoDTensor &x_tensor = x_var_ptr->Get<LoDTensor>();
formatter.Print(x_tensor, name);
}
PADDLE_THROW(platform::errors::InvalidArgument(
"The condition variable '%s' of AssertOp must be "
"true, but received false",
Input(kCond)));
}
};
class AssertOpProtoMaker : public framework::OpProtoAndCheckerMaker {
public:
void Make() override {
AddInput(
kCond,
"The boolean scalar condition tensor which is asserted to be true.");
AddInput(kData,
"The tensors to print when the assert condition is not true.")
.AsDuplicable();
AddAttr<int64_t>(
kSummarize,
"The number of entries of each tensor to print when the "
"assert condition is not true. -1 means print all entries. If "
"the number of entries of a tensor is less then "
"summarize_num, this OP will print all entries of the tensor.")
.SetDefault(-1);
AddComment(
R"DOC(Assert the input Condition Tensor is true and print Tensors if the Condition Tensor is false.)DOC");
}
};
class AssertOpInferShape : public framework::InferShapeBase {
public:
void operator()(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInputs(kCond), "Input", "Condition", "AssertOp");
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OPERATOR(
assert, ops::AssertOp, ops::AssertOpProtoMaker, ops::AssertOpInferShape,
paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/var_type.h" #include "paddle/fluid/framework/var_type.h"
#include "paddle/fluid/operators/assign_op.h" #include "paddle/fluid/operators/assign_op.h"
#include "paddle/fluid/operators/tensor_formatter.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -28,133 +29,6 @@ const char kForward[] = "FORWARD"; ...@@ -28,133 +29,6 @@ const char kForward[] = "FORWARD";
const char kBackward[] = "BACKWARD"; const char kBackward[] = "BACKWARD";
const char kBoth[] = "BOTH"; const char kBoth[] = "BOTH";
class LogGuard {
public:
inline LogGuard() { LogMutex().lock(); }
inline ~LogGuard() { LogMutex().unlock(); }
private:
static std::mutex &LogMutex() {
static std::mutex mtx;
return mtx;
}
};
struct Formater {
std::string message;
std::string name;
std::string dims;
std::type_index dtype{typeid(const char)};
std::string layout;
framework::LoD lod;
int summarize;
void *data{nullptr};
platform::Place place;
std::stringstream logs;
void operator()(size_t size) {
PrintName();
PrintMessage();
PrintLod();
PrintPlace();
PrintDims();
PrintLayout();
PrintDtype();
PrintData(size);
LogGuard guard;
CLOG << logs.str();
}
private:
void PrintPlace() { logs << " - place: " << place << std::endl; }
void PrintMessage() {
if (!message.empty()) {
logs << " - message: " << message << std::endl;
}
}
void PrintName() {
if (!name.empty()) {
logs << "Variable: " << name << std::endl;
}
}
void PrintDims() {
if (!dims.empty()) {
logs << " - shape: " << dims << std::endl;
}
}
void PrintDtype() {
if (!framework::IsType<const char>(dtype)) {
logs << " - dtype: " << platform::demangle(dtype.name()) << std::endl;
}
}
void PrintLayout() {
if (!layout.empty()) {
logs << " - layout: " << layout << std::endl;
}
}
void PrintLod() {
if (!lod.empty()) {
logs << " - lod: {";
for (auto level : lod) {
logs << "{";
bool is_first = true;
for (auto i : level) {
if (is_first) {
logs << i;
is_first = false;
} else {
logs << ", " << i;
}
}
logs << "}";
}
logs << "}" << std::endl;
}
}
void PrintData(size_t size) {
PADDLE_ENFORCE_NOT_NULL(data);
// print float
if (framework::IsType<const float>(dtype)) {
Display<float>(size);
} else if (framework::IsType<const double>(dtype)) {
Display<double>(size);
} else if (framework::IsType<const int>(dtype)) {
Display<int>(size);
} else if (framework::IsType<const int64_t>(dtype)) {
Display<int64_t>(size);
} else if (framework::IsType<const bool>(dtype)) {
Display<bool>(size);
} else {
logs << " - data: unprintable type: " << dtype.name() << std::endl;
}
}
template <typename T>
void Display(size_t size) {
auto *d = reinterpret_cast<T *>(data);
logs << " - data: [";
if (summarize != -1) {
summarize = std::min(size, (size_t)summarize);
if (summarize > 0) {
logs << d[0];
for (int i = 1; i < summarize; ++i) {
logs << " " << d[i];
}
}
} else {
if (size > 0) {
logs << d[0];
for (size_t i = 1; i < size; ++i) {
logs << " " << d[i];
}
}
}
logs << "]" << std::endl;
}
};
// TODO(ChunweiYan) there should be some other printers for TensorArray // TODO(ChunweiYan) there should be some other printers for TensorArray
class PrintOp : public framework::OperatorBase { class PrintOp : public framework::OperatorBase {
public: public:
...@@ -211,27 +85,15 @@ class PrintOp : public framework::OperatorBase { ...@@ -211,27 +85,15 @@ class PrintOp : public framework::OperatorBase {
TensorCopy(in_tensor, place, &printed_tensor); TensorCopy(in_tensor, place, &printed_tensor);
} }
Formater formater; TensorFormatter formatter;
formater.place = place; const std::string &name =
formater.message = Attr<std::string>("message"); Attr<bool>("print_tensor_name") ? printed_var_name : "";
if (Attr<bool>("print_tensor_name")) { formatter.SetPrintTensorType(Attr<bool>("print_tensor_type"));
formater.name = printed_var_name; formatter.SetPrintTensorShape(Attr<bool>("print_tensor_shape"));
} formatter.SetPrintTensorLod(Attr<bool>("print_tensor_lod"));
if (Attr<bool>("print_tensor_type")) { formatter.SetPrintTensorLayout(Attr<bool>("print_tensor_layout"));
formater.dtype = framework::ToTypeIndex(printed_tensor.type()); formatter.SetSummarize(static_cast<int64_t>(Attr<int>("summarize")));
} formatter.Print(printed_tensor, name, Attr<std::string>("message"));
if (Attr<bool>("print_tensor_shape")) {
formater.dims = printed_tensor.dims().to_str();
}
if (Attr<bool>("print_tensor_lod")) {
formater.lod = printed_tensor.lod();
}
if (Attr<bool>("print_tensor_layout")) {
formater.layout = framework::DataLayoutToString(printed_tensor.layout());
}
formater.summarize = Attr<int>("summarize");
formater.data = reinterpret_cast<void *>(printed_tensor.data<void>());
formater(printed_tensor.numel());
} }
private: private:
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <algorithm>
#include <string>
#include "paddle/fluid/operators/tensor_formatter.h"
namespace paddle {
namespace operators {
void TensorFormatter::SetPrintTensorType(bool print_tensor_type) {
print_tensor_type_ = print_tensor_type;
}
void TensorFormatter::SetPrintTensorShape(bool print_tensor_shape) {
print_tensor_shape_ = print_tensor_shape;
}
void TensorFormatter::SetPrintTensorLod(bool print_tensor_lod) {
print_tensor_lod_ = print_tensor_lod;
}
void TensorFormatter::SetPrintTensorLayout(bool print_tensor_layout) {
print_tensor_layout_ = print_tensor_layout;
}
void TensorFormatter::SetSummarize(int64_t summarize) {
summarize_ = summarize;
}
void TensorFormatter::Print(const framework::LoDTensor& print_tensor,
const std::string& tensor_name,
const std::string& message) {
static std::mutex mutex;
std::lock_guard<std::mutex> lock(mutex);
std::cout << Format(print_tensor, tensor_name, message);
}
std::string TensorFormatter::Format(const framework::LoDTensor& print_tensor,
const std::string& tensor_name,
const std::string& message) {
std::stringstream log_stream;
if (!tensor_name.empty()) {
log_stream << "Variable: " << tensor_name << std::endl;
}
if (!message.empty()) {
log_stream << " - message: " << message << std::endl;
}
if (print_tensor_lod_) {
log_stream << " - lod: {";
const framework::LoD& lod = print_tensor.lod();
for (auto level : lod) {
log_stream << "{";
bool is_first = true;
for (auto i : level) {
if (is_first) {
log_stream << i;
is_first = false;
} else {
log_stream << ", " << i;
}
}
log_stream << "}";
}
log_stream << "}" << std::endl;
}
log_stream << " - place: " << print_tensor.place() << std::endl;
if (print_tensor_shape_) {
log_stream << " - shape: " << print_tensor.dims().to_str() << std::endl;
}
if (print_tensor_layout_) {
log_stream << " - layout: "
<< framework::DataLayoutToString(print_tensor.layout())
<< std::endl;
}
std::type_index dtype = framework::ToTypeIndex(print_tensor.type());
if (print_tensor_type_) {
log_stream << " - dtype: " << platform::demangle(dtype.name())
<< std::endl;
}
if (framework::IsType<const float>(dtype)) {
FormatData<float>(print_tensor, log_stream);
} else if (framework::IsType<const double>(dtype)) {
FormatData<double>(print_tensor, log_stream);
} else if (framework::IsType<const int>(dtype)) {
FormatData<int>(print_tensor, log_stream);
} else if (framework::IsType<const int64_t>(dtype)) {
FormatData<int64_t>(print_tensor, log_stream);
} else if (framework::IsType<const bool>(dtype)) {
FormatData<bool>(print_tensor, log_stream);
} else {
log_stream << " - data: unprintable type: " << dtype.name() << std::endl;
}
return log_stream.str();
}
template <typename T>
void TensorFormatter::FormatData(const framework::LoDTensor& print_tensor,
std::stringstream& log_stream) {
int64_t print_size = summarize_ == -1
? print_tensor.numel()
: std::min(summarize_, print_tensor.numel());
const T* data = nullptr;
if (is_cpu_place(print_tensor.place())) {
data = print_tensor.data<T>();
} else {
framework::LoDTensor cpu_tensor;
platform::CPUPlace cpu_place;
TensorCopy(print_tensor, cpu_place, &cpu_tensor);
data = cpu_tensor.data<T>();
}
log_stream << " - data: [";
if (print_size > 0) {
log_stream << data[0];
for (int64_t i = 1; i < print_size; ++i) {
log_stream << " " << data[i];
}
}
log_stream << "]" << std::endl;
}
template void TensorFormatter::FormatData<bool>(
const framework::LoDTensor& print_tensor, std::stringstream& log_stream);
template void TensorFormatter::FormatData<float>(
const framework::LoDTensor& print_tensor, std::stringstream& log_stream);
template void TensorFormatter::FormatData<double>(
const framework::LoDTensor& print_tensor, std::stringstream& log_stream);
template void TensorFormatter::FormatData<int>(
const framework::LoDTensor& print_tensor, std::stringstream& log_stream);
template void TensorFormatter::FormatData<int64_t>(
const framework::LoDTensor& print_tensor, std::stringstream& log_stream);
} // namespace operators
} // namespace paddle
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <string>
#include "paddle/fluid/framework/data_layout.h"
#include "paddle/fluid/framework/var_type.h"
namespace paddle {
namespace operators {
class TensorFormatter {
public:
TensorFormatter() {}
std::string Format(const framework::LoDTensor& print_tensor,
const std::string& tensor_name = "",
const std::string& message = "");
void Print(const framework::LoDTensor& print_tensor,
const std::string& tensor_name = "",
const std::string& message = "");
void SetPrintTensorType(bool print_tensor_type);
void SetPrintTensorShape(bool print_tensor_shape);
void SetPrintTensorLod(bool print_tensor_lod);
void SetPrintTensorLayout(bool print_tensor_layout);
void SetSummarize(int64_t summarize);
private:
template <typename T>
void FormatData(const framework::LoDTensor& print_tensor,
std::stringstream& log_stream);
int64_t summarize_ = -1;
bool print_tensor_type_ = true;
bool print_tensor_shape_ = true;
bool print_tensor_lod_ = true;
bool print_tensor_layout_ = true;
};
} // namespace operators
} // namespace paddle
...@@ -34,8 +34,8 @@ __all__ = [ ...@@ -34,8 +34,8 @@ __all__ = [
'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than', 'While', 'Switch', 'increment', 'array_write', 'create_array', 'less_than',
'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal', 'less_equal', 'greater_than', 'greater_equal', 'equal', 'not_equal',
'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN', 'array_read', 'array_length', 'cond', 'IfElse', 'DynamicRNN', 'StaticRNN',
'reorder_lod_tensor_by_rank', 'Print', 'is_empty', 'case', 'switch_case', 'reorder_lod_tensor_by_rank', 'Print', 'Assert', 'is_empty', 'case',
'while_loop' 'switch_case', 'while_loop'
] ]
...@@ -300,6 +300,78 @@ def Print(input, ...@@ -300,6 +300,78 @@ def Print(input,
return output return output
def Assert(cond, data=None, summarize=20, name=None):
'''
This API creates an op that asserts the given condition is true. If the
condition is false, prints the tensors in data. ``summarize`` specifies the
number of the elements in the tensors to print.
Args:
cond (Variable): The boolean condition tensor whose numel should be 1.
data (list|tuple, optional): list or tuple of tensors to print when
condition is not true. If it's ``None``, no tensor will be printed.
The default value is ``None``.
summarize (int, optional): Number of elements in the tensor to be
printed. If its value is -1, then all elements in the tensor will
be printed. The default value is 20.
name (str, optional): The default value is ``None`` . Normally users
don't have to set this parameter. For more information, please
refer to :ref:`api_guide_Name` .
Returns:
Operator: the created operation.
Raises:
TypeError: If ``cond`` is not boolean Variable.
TypeError: If ``data`` is not a list or tuple or ``None``.
TypeError: If ``summarize`` is not int.
TypeError: If ``name`` is not a string or ``None`` .
fluid.core.EnforceNotMet: If the condition is False in running time.
Examples:
.. code-block:: python
import paddle.fluid as fluid
import paddle.fluid.layers as layers
x = layers.fill_constant(shape=[2, 3], dtype='float32', value=2.0)
condition = layers.reduce_max(x) < 1.0 # False
layers.Assert(condition, [x], 10, "example_assert_layer")
exe = fluid.Executor()
try:
exe.run(fluid.default_main_program())
# Print x and throws paddle.fluid.core.EnforceNotMet exception
# Example printed message for x:
#
# Variable: fill_constant_0.tmp_0
# - lod: {}
# - place: CPUPlace()
# - shape: [2, 3]
# - layout: NCHW
# - dtype: float
# - data: [2 2 2 2 2 2]
except fluid.core.EnforceNotMet as e:
print("Assert Exception Example")
'''
check_variable_and_dtype(cond, "cond", ["bool"], "fluid.layers.Assert")
check_type(data, "data", (list, tuple, type(None)), "fluid.layers.Assert")
check_type(summarize, "summarize", int, "fluid.layers.Assert")
check_type(name, "name", (str, type(None)), "fluid.layers.Assert")
layer_name = name if name else ('assert_' + cond.name)
helper = LayerHelper(layer_name, **locals())
op = helper.append_op(
type="assert",
inputs={"Cond": cond,
"Data": [] if data is None else list(data)},
attrs={"summarize": summarize})
return op
class BlockGuard(object): class BlockGuard(object):
""" """
BlockGuard class. BlockGuard class.
......
# Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import print_function
import paddle.fluid as fluid
import paddle.fluid.layers as layers
import unittest
class TestAssertOp(unittest.TestCase):
def run_network(self, net_func):
main_program = fluid.Program()
startup_program = fluid.Program()
with fluid.program_guard(main_program, startup_program):
net_func()
exe = fluid.Executor()
exe.run(main_program)
def test_assert_true(self):
def net_func():
condition = layers.fill_constant(
shape=[1], dtype='bool', value=True)
layers.Assert(condition, [])
self.run_network(net_func)
def test_assert_false(self):
def net_func():
condition = layers.fill_constant(
shape=[1], dtype='bool', value=False)
layers.Assert(condition)
with self.assertRaises(fluid.core.EnforceNotMet):
self.run_network(net_func)
def test_assert_cond_numel_error(self):
def net_func():
condition = layers.fill_constant(
shape=[1, 2], dtype='bool', value=True)
layers.Assert(condition, [])
with self.assertRaises(fluid.core.EnforceNotMet):
self.run_network(net_func)
def test_assert_print_data(self):
def net_func():
zero = layers.fill_constant(shape=[1], dtype='int64', value=0)
one = layers.fill_constant(shape=[1], dtype='int64', value=1)
condition = layers.less_than(one, zero) # False
layers.Assert(condition, [zero, one])
print("test_assert_print_data")
with self.assertRaises(fluid.core.EnforceNotMet):
self.run_network(net_func)
def test_assert_summary(self):
def net_func():
x = layers.fill_constant(shape=[10], dtype='float32', value=2.0)
condition = layers.reduce_max(x) < 1.0
layers.Assert(condition, (x, ), 5)
print("test_assert_summary")
with self.assertRaises(fluid.core.EnforceNotMet):
self.run_network(net_func)
def test_assert_summary_greater_than_size(self):
def net_func():
x = layers.fill_constant(shape=[2, 3], dtype='float32', value=2.0)
condition = layers.reduce_max(x) < 1.0
layers.Assert(condition, [x], 10, name="test")
print("test_assert_summary_greater_than_size")
with self.assertRaises(fluid.core.EnforceNotMet):
self.run_network(net_func)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册