Commit 53b6ee19 authored by dangqingqing

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into ssd_target_assign

......@@ -10,8 +10,7 @@ The following example shows the usage of `fluid.switch`.
a = fluid.Var(10)
b = fluid.Var(0)
-switch = fluid.switch()
-with switch.block():
+with switch() as switch:
    with switch.case(fluid.less_equal(a, 10)):
        fluid.print("Case 1")
    with switch.case(fluid.larger(a, 0)):
......
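For comparison, the concrete Switch API that this commit adds to fluid's Python layers is exercised like the following (a minimal sketch distilled from the unit test added later in this change; the import path is the one that test uses):

import paddle.v2.fluid.layers as layers

x = layers.fill_constant(shape=[1], dtype='float32', value=0.5)
zero = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
one = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
result = layers.create_global_var(
    shape=[1], value=-1.0, dtype='float32', persistable=True)

with layers.Switch() as switch:
    # Cases are evaluated in order; only the first true one runs.
    with switch.case(layers.less_than(x, zero)):
        layers.assign(zero, result)
    with switch.default():
        layers.assign(one, result)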
......@@ -11,9 +11,15 @@ cc_test(test_inference_image_classification_resnet
        SRCS test_inference_image_classification.cc
        DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
        ARGS --dirname=${PYTHON_TESTS_DIR}/book/image_classification_resnet.inference.model)
cc_test(test_inference_label_semantic_roles
        SRCS test_inference_label_semantic_roles.cc
        DEPS ARCHIVE_START paddle_fluid ARCHIVE_END
        ARGS --dirname=${PYTHON_TESTS_DIR}/book/label_semantic_roles.inference.model)
set_tests_properties(test_inference_recognize_digits_mlp
    PROPERTIES DEPENDS test_recognize_digits)
set_tests_properties(test_inference_image_classification_vgg
    PROPERTIES DEPENDS test_image_classification_train)
set_tests_properties(test_inference_image_classification_resnet
    PROPERTIES DEPENDS test_image_classification_train)
set_tests_properties(test_inference_label_semantic_roles
    PROPERTIES DEPENDS test_label_semantic_roles)
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/framework/lod_tensor.h"
#include "paddle/inference/io.h"
template <typename T>
void SetupTensor(paddle::framework::LoDTensor& input,
                 paddle::framework::DDim dims,
                 T lower,
                 T upper) {
  srand(time(0));
  T* input_ptr = input.mutable_data<T>(dims, paddle::platform::CPUPlace());
  for (int i = 0; i < input.numel(); ++i) {
    input_ptr[i] =
        (static_cast<T>(rand()) / static_cast<T>(RAND_MAX)) * (upper - lower) +
        lower;
  }
}
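In numpy terms, SetupTensor draws uniform samples in [lower, upper] (an illustrative sketch only, not part of the commit):

import numpy as np

def setup_tensor(dims, lower, upper):
    # Same scaling as the C++ helper: rand()/RAND_MAX lands in [0, 1],
    # then is stretched to [lower, upper].
    return (lower + (upper - lower) * np.random.rand(*dims)).astype('float32')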
template <typename T>
void SetupLoDTensor(paddle::framework::LoDTensor& input,
                    paddle::framework::LoD& lod,
                    T lower,
                    T upper) {
  input.set_lod(lod);
  int dim = lod[0][lod[0].size() - 1];
  SetupTensor(input, {dim, 1}, lower, upper);
}
template <typename T>
void CheckError(paddle::framework::LoDTensor& output1,
                paddle::framework::LoDTensor& output2) {
  // Check lod information
  EXPECT_EQ(output1.lod(), output2.lod());
  EXPECT_EQ(output1.dims(), output2.dims());
  EXPECT_EQ(output1.numel(), output2.numel());

  T err = static_cast<T>(0);
  if (typeid(T) == typeid(float)) {
    err = 1E-3;
  } else if (typeid(T) == typeid(double)) {
    err = 1E-6;
  } else {
    err = 0;
  }

  size_t count = 0;
  for (int64_t i = 0; i < output1.numel(); ++i) {
    if (fabs(output1.data<T>()[i] - output2.data<T>()[i]) > err) {
      count++;
    }
  }
  EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}
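CheckError counts the elements whose absolute difference exceeds a type-dependent tolerance; the numpy analogue (an illustrative sketch, tolerances copied from the code above) is:

import numpy as np

def check_error(output1, output2, err=1e-3):  # 1e-3 for float, 1e-6 for double
    # Count elements that differ by more than the tolerance.
    count = int((np.abs(output1 - output2) > err).sum())
    assert count == 0, "There are %d different elements." % count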
template <typename Place, typename T>
void TestInference(const std::string& dirname,
                   const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                   std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
  // 1. Define place, executor and scope
  auto place = Place();
  auto executor = paddle::framework::Executor(place);
  auto* scope = new paddle::framework::Scope();

  // 2. Initialize the inference_program and load all parameters from file
  auto inference_program = paddle::inference::Load(executor, *scope, dirname);

  // 3. Get the feed_target_names and fetch_target_names
  const std::vector<std::string>& feed_target_names =
      inference_program->GetFeedTargetNames();
  const std::vector<std::string>& fetch_target_names =
      inference_program->GetFetchTargetNames();

  // 4. Prepare inputs: set up maps for feed targets
  std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
  for (size_t i = 0; i < feed_target_names.size(); ++i) {
    // Please make sure that cpu_feeds[i] is right for feed_target_names[i]
    feed_targets[feed_target_names[i]] = cpu_feeds[i];
  }

  // 5. Define Tensor to get the outputs: set up maps for fetch targets
  std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
  for (size_t i = 0; i < fetch_target_names.size(); ++i) {
    fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
  }

  // 6. Run the inference program
  executor.Run(*inference_program, scope, feed_targets, fetch_targets);

  delete scope;
}
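The six numbered steps have a direct Python counterpart via fluid.io.load_inference_model, which the updated book test later in this commit uses; condensed (a sketch, where `dirname` and `input_tensor` are assumed to be defined by the caller):

import paddle.v2.fluid as fluid

place = fluid.CPUPlace()
exe = fluid.Executor(place)
# Steps 2-3: load the program plus its feed/fetch target names.
[program, feed_names, fetch_targets] = fluid.io.load_inference_model(dirname, exe)
# Steps 4-6: map feed names to inputs and run.
results = exe.run(program,
                  feed={feed_names[0]: input_tensor},  # input_tensor: hypothetical
                  fetch_list=fetch_targets,
                  return_numpy=False)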
/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <gtest/gtest.h>
#include <time.h>
#include <sstream>
#include "gflags/gflags.h"
#include "test_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
TEST(inference, label_semantic_roles) {
  if (FLAGS_dirname.empty()) {
    LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
  }

  LOG(INFO) << "FLAGS_dirname: " << FLAGS_dirname << std::endl;
  std::string dirname = FLAGS_dirname;

  // 0. Call `paddle::framework::InitDevices()` to initialize all the devices
  // In unittests, this is done in paddle/testing/paddle_gtest_main.cc

  paddle::framework::LoDTensor word, predicate, ctx_n2, ctx_n1, ctx_0, ctx_p1,
      ctx_p2, mark;
  paddle::framework::LoD lod{{0, 4, 10}};

  SetupLoDTensor(word, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
  SetupLoDTensor(
      predicate, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
  SetupLoDTensor(ctx_n2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
  SetupLoDTensor(ctx_n1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
  SetupLoDTensor(ctx_0, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
  SetupLoDTensor(ctx_p1, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
  SetupLoDTensor(ctx_p2, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));
  SetupLoDTensor(mark, lod, static_cast<int64_t>(0), static_cast<int64_t>(1));

  std::vector<paddle::framework::LoDTensor*> cpu_feeds;
  cpu_feeds.push_back(&word);
  cpu_feeds.push_back(&predicate);
  cpu_feeds.push_back(&ctx_n2);
  cpu_feeds.push_back(&ctx_n1);
  cpu_feeds.push_back(&ctx_0);
  cpu_feeds.push_back(&ctx_p1);
  cpu_feeds.push_back(&ctx_p2);
  cpu_feeds.push_back(&mark);

  paddle::framework::LoDTensor output1;
  std::vector<paddle::framework::LoDTensor*> cpu_fetchs1;
  cpu_fetchs1.push_back(&output1);

  // Run inference on CPU
  TestInference<paddle::platform::CPUPlace, float>(
      dirname, cpu_feeds, cpu_fetchs1);
  LOG(INFO) << output1.lod();
  LOG(INFO) << output1.dims();

#ifdef PADDLE_WITH_CUDA
  paddle::framework::LoDTensor output2;
  std::vector<paddle::framework::LoDTensor*> cpu_fetchs2;
  cpu_fetchs2.push_back(&output2);

  // Run inference on CUDA GPU
  TestInference<paddle::platform::CUDAPlace, float>(
      dirname, cpu_feeds, cpu_fetchs2);
  LOG(INFO) << output2.lod();
  LOG(INFO) << output2.dims();

  CheckError<float>(output1, output2);
#endif
}
......@@ -16,89 +16,10 @@ limitations under the License. */
#include <time.h>
#include <sstream>
#include "gflags/gflags.h"
#include "paddle/framework/lod_tensor.h"
#include "paddle/inference/io.h"
#include "test_helper.h"
DEFINE_string(dirname, "", "Directory of the inference model.");
template <typename Place, typename T>
void TestInference(const std::string& dirname,
                   const std::vector<paddle::framework::LoDTensor*>& cpu_feeds,
                   std::vector<paddle::framework::LoDTensor*>& cpu_fetchs) {
  // 1. Define place, executor and scope
  auto place = Place();
  auto executor = paddle::framework::Executor(place);
  auto* scope = new paddle::framework::Scope();

  // 2. Initialize the inference_program and load all parameters from file
  auto inference_program = paddle::inference::Load(executor, *scope, dirname);

  // 3. Get the feed_target_names and fetch_target_names
  const std::vector<std::string>& feed_target_names =
      inference_program->GetFeedTargetNames();
  const std::vector<std::string>& fetch_target_names =
      inference_program->GetFetchTargetNames();

  // 4. Prepare inputs: set up maps for feed targets
  std::map<std::string, const paddle::framework::LoDTensor*> feed_targets;
  for (size_t i = 0; i < feed_target_names.size(); ++i) {
    // Please make sure that cpu_feeds[i] is right for feed_target_names[i]
    feed_targets[feed_target_names[i]] = cpu_feeds[i];
  }

  // 5. Define Tensor to get the outputs: set up maps for fetch targets
  std::map<std::string, paddle::framework::LoDTensor*> fetch_targets;
  for (size_t i = 0; i < fetch_target_names.size(); ++i) {
    fetch_targets[fetch_target_names[i]] = cpu_fetchs[i];
  }

  // 6. Run the inference program
  executor.Run(*inference_program, scope, feed_targets, fetch_targets);

  delete scope;
}

template <typename T>
void SetupTensor(paddle::framework::LoDTensor& input,
                 paddle::framework::DDim dims,
                 T lower,
                 T upper) {
  srand(time(0));
  float* input_ptr = input.mutable_data<T>(dims, paddle::platform::CPUPlace());
  for (int i = 0; i < input.numel(); ++i) {
    input_ptr[i] =
        (static_cast<T>(rand()) / static_cast<T>(RAND_MAX)) * (upper - lower) +
        lower;
  }
}

template <typename T>
void CheckError(paddle::framework::LoDTensor& output1,
                paddle::framework::LoDTensor& output2) {
  // Check lod information
  EXPECT_EQ(output1.lod(), output2.lod());
  EXPECT_EQ(output1.dims(), output2.dims());
  EXPECT_EQ(output1.numel(), output2.numel());

  T err = static_cast<T>(0);
  if (typeid(T) == typeid(float)) {
    err = 1E-3;
  } else if (typeid(T) == typeid(double)) {
    err = 1E-6;
  } else {
    err = 0;
  }

  size_t count = 0;
  for (int64_t i = 0; i < output1.numel(); ++i) {
    if (fabs(output1.data<T>()[i] - output2.data<T>()[i]) > err) {
      count++;
    }
  }
  EXPECT_EQ(count, 0) << "There are " << count << " different elements.";
}

TEST(inference, recognize_digits) {
  if (FLAGS_dirname.empty()) {
    LOG(FATAL) << "Usage: ./example --dirname=path/to/your/model";
......
......@@ -41,6 +41,21 @@ class ConditionalOp : public framework::OperatorBase {
    });
    return retv;
  }

  bool ScalarCondition(
      const std::vector<const framework::LoDTensor *> &ips) const {
    if (!(ips.size() == 1UL && ips[0]->IsInitialized())) {
      PADDLE_THROW("should have one initialized input as condition");
    }
    if (!(ips[0]->type().hash_code() == typeid(bool).hash_code() &&
          ips[0]->numel() == 1)) {
      PADDLE_THROW(
          "condition input's data type should be bool, "
          "numel should be 1, actual numel is %d",
          ips[0]->numel());
    }
    return ips[0]->data<bool>()[0];
  }
};
class ConditionalBlockOp : public ConditionalOp {
......@@ -53,9 +68,15 @@ class ConditionalBlockOp : public ConditionalOp {
  void Run(const framework::Scope &scope,
           const platform::Place &dev_place) const override {
    auto xs = InputTensors(scope);

-   bool need_run = std::all_of(
-       xs.begin(), xs.end(),
-       [](const framework::LoDTensor *t) { return t->numel() != 0; });
+   bool need_run;
+   if (Attr<bool>("is_scalar_condition")) {
+     need_run = ScalarCondition(xs);
+   } else {
+     need_run = std::all_of(
+         xs.begin(), xs.end(),
+         [](const framework::LoDTensor *t) { return t->numel() != 0; });
+   }

    if (need_run) {
      auto *scope_var = scope.FindVar(Output("Scope"));
......@@ -88,6 +109,10 @@ class ConditionalBlockOpProtoMaker : public framework::OpProtoAndCheckerMaker {
"scope is std::vector<Scope*>");
AddAttr<framework::BlockDesc *>(
"sub_block", "The step block of conditional block operator");
AddAttr<bool>("is_scalar_condition",
"the input X is used as scalar "
"condition")
.SetDefault(false);
AddComment(R"DOC(Conditional block operator
Run the sub-block if X is not empty. Params is the other inputs and Out is the
......@@ -106,9 +131,15 @@ class ConditionalBlockGradOp : public ConditionalOp {
  void Run(const framework::Scope &scope,
           const platform::Place &dev_place) const override {
    auto xs = this->InputTensors(scope);

-   bool need_run = std::all_of(
-       xs.begin(), xs.end(),
-       [](const framework::LoDTensor *t) { return t->numel() != 0; });
+   bool need_run;
+   if (Attr<bool>("is_scalar_condition")) {
+     need_run = ScalarCondition(xs);
+   } else {
+     need_run = std::all_of(
+         xs.begin(), xs.end(),
+         [](const framework::LoDTensor *t) { return t->numel() != 0; });
+   }

    if (need_run) {
      auto *scope_var = scope.FindVar(Input("Scope"));
......@@ -182,6 +213,7 @@ class ConditionalBlockGradMaker : public framework::SingleGradOpDescMaker {
    grad_op->SetOutput(framework::GradVarName("Params"),
                       InputGrad("Params", false));
    grad_op->SetBlockAttr("sub_block", *this->grad_block_[0]);
    grad_op->SetAttr("is_scalar_condition", GetAttr("is_scalar_condition"));
    return std::unique_ptr<framework::OpDesc>(grad_op);
  }
};
......
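ScalarCondition expects exactly one initialized bool tensor with numel == 1. On the Python side such a condition comes out of the comparison layers, e.g. (a sketch based on the Switch test added in this commit):

import paddle.v2.fluid.layers as layers

x = layers.fill_constant(shape=[1], dtype='float32', value=0.5)
zero = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
# less_than on two shape-[1] tensors yields a bool tensor with numel == 1,
# which is what the is_scalar_condition path of conditional_block consumes.
cond = layers.less_than(x, zero)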
......@@ -18,6 +18,7 @@ from tensor import assign, fill_constant
from .. import core
from ..framework import Program, Variable, Operator
from ..layer_helper import LayerHelper, unique_name
from ops import logical_and, logical_not, logical_or
__all__ = [
'split_lod_tensor',
......@@ -27,6 +28,7 @@ __all__ = [
'StaticRNNMemoryLink',
'WhileGuard',
'While',
'Switch',
'lod_rank_table',
'max_sequence_len',
'topk',
......@@ -1063,11 +1065,12 @@ class ConditionalBlockGuard(BlockGuard):
class ConditionalBlock(object):
-   def __init__(self, inputs, name=None):
+   def __init__(self, inputs, is_scalar_condition=False, name=None):
        for each_input in inputs:
            if not isinstance(each_input, Variable):
                raise TypeError("Each input should be variable")
        self.inputs = inputs
+       self.is_scalar_condition = is_scalar_condition
        self.helper = LayerHelper('conditional_block', name=name)

    def block(self):
......@@ -1112,7 +1115,66 @@ class ConditionalBlock(object):
            },
            outputs={'Out': out_list,
                     'Scope': [step_scope]},
-           attrs={'sub_block': inside_block})
+           attrs={
+               'sub_block': inside_block,
+               'is_scalar_condition': self.is_scalar_condition
+           })
class Switch(object):
    def __init__(self, name=None):
        self.helper = LayerHelper('switch', name=name)
        self.inside_scope = False
        self.pre_not_conditions = []

    def case(self, condition):
        """Create a new block for this condition."""
        if not self.inside_scope:
            raise ValueError("case should be called inside with")
        if len(self.pre_not_conditions) == 0:
            cond_block = ConditionalBlock([condition], is_scalar_condition=True)
            not_cond = logical_not(x=condition)
            self.pre_not_conditions.append(not_cond)
        else:
            pre_cond_num = len(self.pre_not_conditions)
            pre_not_cond = self.pre_not_conditions[pre_cond_num - 1]
            new_not_cond = logical_and(
                x=pre_not_cond, y=logical_not(x=condition))
            self.pre_not_conditions.append(new_not_cond)
            cond_block = ConditionalBlock(
                [logical_and(
                    x=pre_not_cond, y=condition)],
                is_scalar_condition=True)

        return ConditionalBlockGuard(cond_block)

    def default(self):
        """Create a default case for this switch."""
        pre_cond_num = len(self.pre_not_conditions)
        if pre_cond_num == 0:
            raise ValueError("there should be at least one condition")
        cond_block = ConditionalBlock(
            [self.pre_not_conditions[pre_cond_num - 1]],
            is_scalar_condition=True)
        return ConditionalBlockGuard(cond_block)

    def __enter__(self):
        """Set the flag that we are now inside the switch block."""
        self.inside_scope = True
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.inside_scope = False
        if exc_type is not None:
            return False  # re-raise exception
        return True
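The chaining in case/default keeps the branches mutually exclusive: the k-th case is guarded by its own condition conjoined with the negations of all earlier conditions, and default fires only when every condition was false. The same semantics in plain Python booleans (a sketch of the logic, not fluid code):

def effective_guards(conditions):
    # guards[k] is true iff condition k holds and no earlier condition did.
    guards, none_before = [], True
    for cond in conditions:
        guards.append(none_before and cond)
        none_before = none_before and not cond
    return guards, none_before  # none_before is the default branch's guard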
class IfElseBlockGuard(object):
......
......@@ -61,6 +61,10 @@ __all__ = [
'clip_by_norm',
'softmax',
'sequence_softmax',
'logical_and',
'logical_or',
'logical_xor',
'logical_not',
] + __activations__
for _OP in set(__all__):
......
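The newly exported logical ops combine bool tensors elementwise and are exactly what Switch.case uses internally, e.g. (a minimal sketch, imports assumed):

import paddle.v2.fluid.layers as layers

a = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
b = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
c1 = layers.less_than(a, b)
c2 = layers.less_than(b, a)
# A guard that holds when c1 is true and c2 is false, mirroring Switch.case.
guard = layers.logical_and(x=layers.logical_not(x=c2), y=c1)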
......@@ -18,7 +18,9 @@ import numpy as np
import paddle.v2 as paddle
import paddle.v2.dataset.conll05 as conll05
import paddle.v2.fluid as fluid
import contextlib
import time
import unittest
word_dict, verb_dict, label_dict = conll05.get_dict()
word_dict_len = len(word_dict)
......@@ -127,7 +129,15 @@ def to_lodtensor(data, place):
return res
-def main():
+def create_random_lodtensor(lod, place, low, high):
+    data = np.random.random_integers(low, high, [lod[-1], 1]).astype("int64")
+    res = fluid.LoDTensor()
+    res.set(data, place)
+    res.set_lod([lod])
+    return res
+
+
+def train(use_cuda, save_dirname=None):
    # define network topology
    word = fluid.layers.data(
        name='word_data', shape=[1], dtype='int64', lod_level=1)
......@@ -175,8 +185,8 @@ def main():
        paddle.reader.shuffle(
            paddle.dataset.conll05.test(), buf_size=8192),
        batch_size=BATCH_SIZE)

-   # place = fluid.CPUPlace()
-   place = fluid.CUDAPlace(0)
+   place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    feeder = fluid.DataFeeder(
        feed_list=[
            word, ctx_n2, ctx_n1, ctx_0, ctx_p1, ctx_p2, predicate, mark, target
......@@ -211,12 +221,102 @@ def main():
                if batch_id != 0:
                    print("second per batch: " + str((time.time() - start_time)
                                                     / batch_id))
-                   # exit early for CI
-                   exit(0)
+                   # Set the threshold low to speed up the CI test
+                   if float(pass_precision) > 0.05:
+                       if save_dirname is not None:
+                           fluid.io.save_inference_model(save_dirname, [
+                               'word_data', 'verb_data', 'ctx_n2_data',
+                               'ctx_n1_data', 'ctx_0_data', 'ctx_p1_data',
+                               'ctx_p2_data', 'mark_data'
+                           ], [feature_out], exe)
+                       return

                batch_id = batch_id + 1
def infer(use_cuda, save_dirname=None):
    if save_dirname is None:
        return

    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
    exe = fluid.Executor(place)

    # Use fluid.io.load_inference_model to obtain the inference program desc,
    # the feed_target_names (the names of variables that will be fed
    # data using feed operators), and the fetch_targets (variables that
    # we want to obtain data from using fetch operators).
    [inference_program, feed_target_names,
     fetch_targets] = fluid.io.load_inference_model(save_dirname, exe)

    lod = [0, 4, 10]
    ts_word = create_random_lodtensor(lod, place, low=0, high=1)
    ts_pred = create_random_lodtensor(lod, place, low=0, high=1)
    ts_ctx_n2 = create_random_lodtensor(lod, place, low=0, high=1)
    ts_ctx_n1 = create_random_lodtensor(lod, place, low=0, high=1)
    ts_ctx_0 = create_random_lodtensor(lod, place, low=0, high=1)
    ts_ctx_p1 = create_random_lodtensor(lod, place, low=0, high=1)
    ts_ctx_p2 = create_random_lodtensor(lod, place, low=0, high=1)
    ts_mark = create_random_lodtensor(lod, place, low=0, high=1)

    # Construct feed as a dictionary of {feed_target_name: feed_target_data}
    # and results will contain a list of data corresponding to fetch_targets.
    assert feed_target_names[0] == 'word_data'
    assert feed_target_names[1] == 'verb_data'
    assert feed_target_names[2] == 'ctx_n2_data'
    assert feed_target_names[3] == 'ctx_n1_data'
    assert feed_target_names[4] == 'ctx_0_data'
    assert feed_target_names[5] == 'ctx_p1_data'
    assert feed_target_names[6] == 'ctx_p2_data'
    assert feed_target_names[7] == 'mark_data'

    results = exe.run(inference_program,
                      feed={
                          feed_target_names[0]: ts_word,
                          feed_target_names[1]: ts_pred,
                          feed_target_names[2]: ts_ctx_n2,
                          feed_target_names[3]: ts_ctx_n1,
                          feed_target_names[4]: ts_ctx_0,
                          feed_target_names[5]: ts_ctx_p1,
                          feed_target_names[6]: ts_ctx_p2,
                          feed_target_names[7]: ts_mark
                      },
                      fetch_list=fetch_targets,
                      return_numpy=False)
    print(results[0].lod())
    np_data = np.array(results[0])
    print("Inference Shape: ", np_data.shape)
    print("Inference results: ", np_data)
def main(use_cuda):
    if use_cuda and not fluid.core.is_compiled_with_cuda():
        return

    # Directory for saving the trained model
    save_dirname = "label_semantic_roles.inference.model"

    train(use_cuda, save_dirname)
    infer(use_cuda, save_dirname)


class TestLabelSemanticRoles(unittest.TestCase):
    def test_cuda(self):
        with self.scope_prog_guard():
            main(use_cuda=True)

    def test_cpu(self):
        with self.scope_prog_guard():
            main(use_cuda=False)

    @contextlib.contextmanager
    def scope_prog_guard(self):
        prog = fluid.Program()
        startup_prog = fluid.Program()
        scope = fluid.core.Scope()
        with fluid.scope_guard(scope):
            with fluid.program_guard(prog, startup_prog):
                yield


if __name__ == '__main__':
-   main()
+   unittest.main()
# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import unittest
import paddle.v2.fluid.core as core
import paddle.v2.fluid.layers as layers
import paddle.v2.fluid.framework as framework
from paddle.v2.fluid.executor import Executor
from paddle.v2.fluid.framework import default_startup_program
class TestSwitch(unittest.TestCase):
    def check_switch(self, value):
        x = layers.fill_constant(shape=[1], dtype='float32', value=value)

        zero_var = layers.fill_constant(shape=[1], dtype='float32', value=0.0)
        one_var = layers.fill_constant(shape=[1], dtype='float32', value=1.0)
        two_var = layers.fill_constant(shape=[1], dtype='float32', value=2.0)
        three_var = layers.fill_constant(shape=[1], dtype='float32', value=3.0)

        result = layers.create_global_var(
            shape=[1], value=-1.0, dtype='float32', persistable=True)

        with layers.Switch() as switch:
            with switch.case(layers.less_than(x, zero_var)):
                layers.assign(zero_var, result)
            with switch.case(layers.less_than(x, one_var)):
                layers.assign(one_var, result)
            with switch.case(layers.less_than(x, two_var)):
                layers.assign(two_var, result)
            with switch.default():
                layers.assign(three_var, result)

        cpu = core.CPUPlace()
        exe = Executor(cpu)
        exe.run(default_startup_program())

        out = exe.run(feed={}, fetch_list=[result])[0][0]
        return out

    def test_switch(self):
        test_data = {(-0.1, 0), (0.1, 1), (1.1, 2), (2.1, 3)}
        for x, expected_result in test_data:
            main_program = framework.Program()
            startup_program = framework.Program()
            with framework.program_guard(main_program, startup_program):
                result = self.check_switch(x)
                self.assertEqual(result, expected_result)


if __name__ == '__main__':
    unittest.main()