未验证 提交 aa0cb2e9 编写于 作者: P pkpk 提交者: GitHub

Refine api of emotion detection and dialogue domain classification (#4308)

* Update README.md (#4267)

* test=develop (#4269)

* 3d use new api (#4275)

* PointNet++ and PointRCNN use new API

* Update Readme of Dygraph BERT (#4277)

Fix some typos.

* Update run_classifier_multi_gpu.sh (#4279)

remove the CUDA_VISIBLE_DEVICES

* Update README.md (#4280)

* 17 update api (#4294)

* update1.7 save/load & fluid.data

* update datafeed to dataloader

* Update resnet_acnet.py (#4297)

Bias attr of square conv should be "False" rather than None during training mode.

* test=develop
Co-authored-by: NKaipeng Deng <dengkaipeng@baidu.com>
Co-authored-by: Nzhang wenhui <frankwhzhang@126.com>
Co-authored-by: Nparap1uie-s <parap1uie-s@users.noreply.github.com>
上级 872f494e
...@@ -172,7 +172,7 @@ class ResNetACNet(object): ...@@ -172,7 +172,7 @@ class ResNetACNet(object):
act=act if self.deploy else None, act=act if self.deploy else None,
param_attr=ParamAttr(name=name + "_acsquare_weights"), param_attr=ParamAttr(name=name + "_acsquare_weights"),
bias_attr=ParamAttr(name=name + "_acsquare_bias") bias_attr=ParamAttr(name=name + "_acsquare_bias")
if self.deploy else None, if self.deploy else False,
name=name + '.acsquare.conv2d.output.1') name=name + '.acsquare.conv2d.output.1')
if self.deploy: if self.deploy:
......
...@@ -32,14 +32,13 @@ try: ...@@ -32,14 +32,13 @@ try:
except ImportError: except ImportError:
import ConfigParser as cp import ConfigParser as cp
random_seed = 7 random_seed = 7
logger = logging.getLogger() logger = logging.getLogger()
format = "%(asctime)s - %(name)s - %(levelname)s -%(filename)s-%(lineno)4d -%(message)s" format = "%(asctime)s - %(name)s - %(levelname)s -%(filename)s-%(lineno)4d -%(message)s"
# format = "%(levelname)8s: %(asctime)s: %(filename)s:%(lineno)4d %(message)s" # format = "%(levelname)8s: %(asctime)s: %(filename)s:%(lineno)4d %(message)s"
logging.basicConfig(format=format) logging.basicConfig(format=format)
logger.setLevel(logging.INFO) logger.setLevel(logging.INFO)
logger = logging.getLogger('Paddle-DDC') logger = logging.getLogger('Paddle-DDC')
def str2bool(v): def str2bool(v):
...@@ -77,6 +76,7 @@ class ArgumentGroup(object): ...@@ -77,6 +76,7 @@ class ArgumentGroup(object):
Arguments: Arguments:
object {[type]} -- [description] object {[type]} -- [description]
""" """
def __init__(self, parser, title, des): def __init__(self, parser, title, des):
self._group = parser.add_argument_group(title=title, description=des) self._group = parser.add_argument_group(title=title, description=des)
...@@ -107,6 +107,7 @@ class DataReader(object): ...@@ -107,6 +107,7 @@ class DataReader(object):
Returns: Returns:
[type] -- [description] [type] -- [description]
""" """
def __init__(self, char_vocab, intent_dict, max_len): def __init__(self, char_vocab, intent_dict, max_len):
self._char_vocab = char_vocab self._char_vocab = char_vocab
self._intent_dict = intent_dict self._intent_dict = intent_dict
...@@ -115,10 +116,10 @@ class DataReader(object): ...@@ -115,10 +116,10 @@ class DataReader(object):
self.all_data = [] self.all_data = []
self.max_len = max_len self.max_len = max_len
self.padding_id = 0 self.padding_id = 0
def _get_num_examples(self): def _get_num_examples(self):
return len(self.all_data) return len(self.all_data)
def prepare_data(self, data_path, batch_size, mode): def prepare_data(self, data_path, batch_size, mode):
""" """
prepare data prepare data
...@@ -128,12 +129,17 @@ class DataReader(object): ...@@ -128,12 +129,17 @@ class DataReader(object):
# word_dict_path), "The given word dictionary dose not exist." # word_dict_path), "The given word dictionary dose not exist."
assert os.path.exists(data_path), "The given data file does not exist." assert os.path.exists(data_path), "The given data file does not exist."
if mode == "train": if mode == "train":
train_reader = fluid.io.batch(paddle.reader.shuffle(self.data_reader(data_path, self.max_len, shuffle=True), train_reader = fluid.io.batch(
buf_size=batch_size * 100), batch_size) paddle.reader.shuffle(
self.data_reader(
data_path, self.max_len, shuffle=True),
buf_size=batch_size * 100),
batch_size)
# train_reader = fluid.io.batch(self.data_reader(data_path), batch_size) # train_reader = fluid.io.batch(self.data_reader(data_path), batch_size)
return train_reader return train_reader
else: else:
test_reader = fluid.io.batch(self.data_reader(data_path, self.max_len), batch_size) test_reader = fluid.io.batch(
self.data_reader(data_path, self.max_len), batch_size)
return test_reader return test_reader
def data_reader(self, file_path, max_len, shuffle=False): def data_reader(self, file_path, max_len, shuffle=False):
...@@ -141,7 +147,7 @@ class DataReader(object): ...@@ -141,7 +147,7 @@ class DataReader(object):
Convert query into id list Convert query into id list
use fixed voc use fixed voc
""" """
for line in codecs.open(file_path, "r", encoding="utf8"): for line in codecs.open(file_path, "r", encoding="utf8"):
line = line.strip() line = line.strip()
if isinstance(line, six.binary_type): if isinstance(line, six.binary_type):
...@@ -150,7 +156,8 @@ class DataReader(object): ...@@ -150,7 +156,8 @@ class DataReader(object):
char_id_list = list(map(lambda x: 0 if x not in self._char_vocab else int(self._char_vocab[x]), \ char_id_list = list(map(lambda x: 0 if x not in self._char_vocab else int(self._char_vocab[x]), \
list(query))) list(query)))
if len(char_id_list) < max_len: if len(char_id_list) < max_len:
char_id_list.extend([self.padding_id] * (max_len - len(char_id_list))) char_id_list.extend([self.padding_id] *
(max_len - len(char_id_list)))
char_id_list = char_id_list[:max_len] char_id_list = char_id_list[:max_len]
intent_id_list = [self.padding_id] * self.intent_size intent_id_list = [self.padding_id] * self.intent_size
for item in intent.split('\2'): for item in intent.split('\2'):
...@@ -159,6 +166,7 @@ class DataReader(object): ...@@ -159,6 +166,7 @@ class DataReader(object):
if shuffle: if shuffle:
random.seed(random_seed) random.seed(random_seed)
random.shuffle(self.all_data) random.shuffle(self.all_data)
def reader(): def reader():
""" """
reader reader
...@@ -166,6 +174,7 @@ class DataReader(object): ...@@ -166,6 +174,7 @@ class DataReader(object):
for char_id_list, intent_id_list in self.all_data: for char_id_list, intent_id_list in self.all_data:
# print char_id_list, intent_id # print char_id_list, intent_id
yield char_id_list, intent_id_list yield char_id_list, intent_id_list
return reader return reader
...@@ -178,6 +187,7 @@ class DataProcesser(object): ...@@ -178,6 +187,7 @@ class DataProcesser(object):
Returns: Returns:
[type] -- [description] [type] -- [description]
""" """
@staticmethod @staticmethod
def read_dict(filename): def read_dict(filename):
""" """
...@@ -211,7 +221,7 @@ class DataProcesser(object): ...@@ -211,7 +221,7 @@ class DataProcesser(object):
char_dict = {} char_dict = {}
intent_dict = {} intent_dict = {}
# readfile # readfile
for line in codecs.open(filename): for line in codecs.open(filename):
line = line.strip() line = line.strip()
if isinstance(line, six.binary_type): if isinstance(line, six.binary_type):
line = line.strip().decode("utf8", errors="ignore") line = line.strip().decode("utf8", errors="ignore")
...@@ -227,7 +237,8 @@ class DataProcesser(object): ...@@ -227,7 +237,8 @@ class DataProcesser(object):
intent_dict[intent] = 0 intent_dict[intent] = 0
intent_dict[intent] += 1 intent_dict[intent] += 1
# save char dict # save char dict
with codecs.open("%s/char.dict" % save_dir, "w", encoding="utf8") as f_out: with codecs.open(
"%s/char.dict" % save_dir, "w", encoding="utf8") as f_out:
f_out.write("PAD\0020\n") f_out.write("PAD\0020\n")
f_out.write("OOV\0021\n") f_out.write("OOV\0021\n")
char_id = 2 char_id = 2
...@@ -238,7 +249,8 @@ class DataProcesser(object): ...@@ -238,7 +249,8 @@ class DataProcesser(object):
f_out.write("%s\002%d\n" % (key, char_id)) f_out.write("%s\002%d\n" % (key, char_id))
char_id += 1 char_id += 1
# save intent dict # save intent dict
with codecs.open("%s/domain.dict" % save_dir, "w", encoding="utf8") as f_out: with codecs.open(
"%s/domain.dict" % save_dir, "w", encoding="utf8") as f_out:
f_out.write("SYS_OTHER\0020\n") f_out.write("SYS_OTHER\0020\n")
intent_id = 1 intent_id = 1
for key, value in intent_dict.items(): for key, value in intent_dict.items():
...@@ -247,7 +259,6 @@ class DataProcesser(object): ...@@ -247,7 +259,6 @@ class DataProcesser(object):
key = key.encode("utf8") key = key.encode("utf8")
f_out.write("%s\002%d\n" % (key, intent_id)) f_out.write("%s\002%d\n" % (key, intent_id))
intent_id += 1 intent_id += 1
class ConfigReader(object): class ConfigReader(object):
...@@ -282,49 +293,13 @@ class ConfigReader(object): ...@@ -282,49 +293,13 @@ class ConfigReader(object):
return flow_data return flow_data
def init_pretraining_params(exe,
pretraining_params_path,
main_program,
use_fp16=False):
"""load params of pretrained model, NOT including moment, learning_rate"""
assert os.path.exists(pretraining_params_path
), "[%s] cann't be found." % pretraining_params_path
def _existed_params(var):
if not isinstance(var, fluid.framework.Parameter):
return False
return os.path.exists(os.path.join(pretraining_params_path, var.name))
fluid.io.load_vars(
exe,
pretraining_params_path,
main_program=main_program,
predicate=_existed_params)
print("Load pretraining parameters from {}.".format(
pretraining_params_path))
def init_checkpoint(exe, init_checkpoint_path, main_program): def init_checkpoint(exe, init_checkpoint_path, main_program):
""" """
Init CheckPoint Init CheckPoint
""" """
assert os.path.exists( fluid.load(main_program, init_checkpoint_path, exe)
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path print("Load model from {}".format(init_checkpoint_path))
def existed_persitables(var):
"""
If existed presitabels
"""
if not fluid.io.is_persistable(var):
return False
return os.path.exists(os.path.join(init_checkpoint_path, var.name))
fluid.io.load_vars(
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persitables)
print ("Load model from {}".format(init_checkpoint_path))
def print_arguments(args): def print_arguments(args):
""" """
...@@ -350,5 +325,3 @@ def check_version(version='1.6.0'): ...@@ -350,5 +325,3 @@ def check_version(version='1.6.0'):
except Exception as e: except Exception as e:
logger.error(err) logger.error(err)
sys.exit(1) sys.exit(1)
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Emotion Detection Task Emotion Detection Task
""" """
...@@ -38,9 +37,7 @@ import reader ...@@ -38,9 +37,7 @@ import reader
import utils import utils
def create_model(args, def create_model(args, num_labels, is_prediction=False):
num_labels,
is_prediction=False):
""" """
Create Model for Emotion Detection Create Model for Emotion Detection
""" """
...@@ -77,10 +74,17 @@ def create_model(args, ...@@ -77,10 +74,17 @@ def create_model(args,
raise ValueError("Unknown network type!") raise ValueError("Unknown network type!")
if is_prediction: if is_prediction:
probs = network(data, seq_len, None, args.vocab_size, class_dim=num_labels, is_prediction=True) probs = network(
data,
seq_len,
None,
args.vocab_size,
class_dim=num_labels,
is_prediction=True)
return loader, probs, [data.name, seq_len.name] return loader, probs, [data.name, seq_len.name]
avg_loss, probs = network(data, seq_len, label, args.vocab_size, class_dim=num_labels) avg_loss, probs = network(
data, seq_len, label, args.vocab_size, class_dim=num_labels)
num_seqs = fluid.layers.create_tensor(dtype='int64') num_seqs = fluid.layers.create_tensor(dtype='int64')
accuracy = fluid.layers.accuracy(input=probs, label=label, total=num_seqs) accuracy = fluid.layers.accuracy(input=probs, label=label, total=num_seqs)
return loader, avg_loss, accuracy, num_seqs return loader, avg_loss, accuracy, num_seqs
...@@ -142,9 +146,10 @@ def main(args): ...@@ -142,9 +146,10 @@ def main(args):
exe = fluid.Executor(place) exe = fluid.Executor(place)
task_name = args.task_name.lower() task_name = args.task_name.lower()
processor = reader.EmoTectProcessor(data_dir=args.data_dir, processor = reader.EmoTectProcessor(
vocab_path=args.vocab_path, data_dir=args.data_dir,
random_seed=args.random_seed) vocab_path=args.vocab_path,
random_seed=args.random_seed)
#num_labels = len(processor.get_labels()) #num_labels = len(processor.get_labels())
num_labels = args.num_labels num_labels = args.num_labels
...@@ -173,9 +178,7 @@ def main(args): ...@@ -173,9 +178,7 @@ def main(args):
with fluid.program_guard(train_program, startup_prog): with fluid.program_guard(train_program, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
train_loader, loss, accuracy, num_seqs = create_model( train_loader, loss, accuracy, num_seqs = create_model(
args, args, num_labels=num_labels, is_prediction=False)
num_labels=num_labels,
is_prediction=False)
sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr) sgd_optimizer = fluid.optimizer.Adagrad(learning_rate=args.lr)
sgd_optimizer.minimize(loss) sgd_optimizer.minimize(loss)
...@@ -189,37 +192,27 @@ def main(args): ...@@ -189,37 +192,27 @@ def main(args):
if args.do_val: if args.do_val:
if args.do_train: if args.do_train:
test_data_generator = processor.data_generator( test_data_generator = processor.data_generator(
batch_size=args.batch_size, batch_size=args.batch_size, phase='dev', epoch=1)
phase='dev',
epoch=1)
else: else:
test_data_generator = processor.data_generator( test_data_generator = processor.data_generator(
batch_size=args.batch_size, batch_size=args.batch_size, phase='test', epoch=1)
phase='test',
epoch=1)
test_prog = fluid.Program() test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
test_loader, loss, accuracy, num_seqs = create_model( test_loader, loss, accuracy, num_seqs = create_model(
args, args, num_labels=num_labels, is_prediction=False)
num_labels=num_labels,
is_prediction=False)
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
if args.do_infer: if args.do_infer:
infer_data_generator = processor.data_generator( infer_data_generator = processor.data_generator(
batch_size=args.batch_size, batch_size=args.batch_size, phase='infer', epoch=1)
phase='infer',
epoch=1)
test_prog = fluid.Program() test_prog = fluid.Program()
with fluid.program_guard(test_prog, startup_prog): with fluid.program_guard(test_prog, startup_prog):
with fluid.unique_name.guard(): with fluid.unique_name.guard():
infer_loader, probs, _ = create_model( infer_loader, probs, _ = create_model(
args, args, num_labels=num_labels, is_prediction=True)
num_labels=num_labels,
is_prediction=True)
test_prog = test_prog.clone(for_test=True) test_prog = test_prog.clone(for_test=True)
exe.run(startup_prog) exe.run(startup_prog)
...@@ -292,8 +285,9 @@ def main(args): ...@@ -292,8 +285,9 @@ def main(args):
time_begin = time.time() time_begin = time.time()
if steps % args.save_steps == 0: if steps % args.save_steps == 0:
save_path = os.path.join(args.save_checkpoint_dir, "step_" + str(steps)) save_path = os.path.join(args.save_checkpoint_dir,
fluid.io.save_persistables(exe, save_path, train_program) "step_" + str(steps))
fluid.save(train_program, save_path)
if steps % args.validation_steps == 0: if steps % args.validation_steps == 0:
# evaluate on dev set # evaluate on dev set
...@@ -306,11 +300,11 @@ def main(args): ...@@ -306,11 +300,11 @@ def main(args):
print("final step: %d " % steps) print("final step: %d " % steps)
if args.do_val: if args.do_val:
evaluate(test_exe, test_prog, test_loader, evaluate(test_exe, test_prog, test_loader,
[loss.name, accuracy.name, num_seqs.name], [loss.name, accuracy.name, num_seqs.name], "dev")
"dev")
save_path = os.path.join(args.save_checkpoint_dir, "step_" + str(steps)) save_path = os.path.join(args.save_checkpoint_dir,
fluid.io.save_persistables(exe, save_path, train_program) "step_" + str(steps))
fluid.save(train_program, save_path)
train_loader.reset() train_loader.reset()
break break
...@@ -334,15 +328,12 @@ def main(args): ...@@ -334,15 +328,12 @@ def main(args):
if not args.do_train and args.do_val: if not args.do_train and args.do_val:
print("Final test result:") print("Final test result:")
evaluate(test_exe, test_prog, test_loader, evaluate(test_exe, test_prog, test_loader,
[loss.name, accuracy.name, num_seqs.name], [loss.name, accuracy.name, num_seqs.name], "test")
"test")
# infer # infer
if args.do_infer: if args.do_infer:
print("Final infer result:") print("Final infer result:")
infer(test_exe, test_prog, infer_loader, infer(test_exe, test_prog, infer_loader, [probs.name], "infer")
[probs.name],
"infer")
def get_cards(): def get_cards():
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
Emotion Detection Task, based on ERNIE Emotion Detection Task, based on ERNIE
""" """
...@@ -350,7 +349,7 @@ def main(args): ...@@ -350,7 +349,7 @@ def main(args):
if steps % args.save_steps == 0: if steps % args.save_steps == 0:
save_path = os.path.join(args.save_checkpoint_dir, "step_" + str(steps)) save_path = os.path.join(args.save_checkpoint_dir, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.save(train_program, save_path)
if steps % args.validation_steps == 0: if steps % args.validation_steps == 0:
# evaluate dev set # evaluate dev set
...@@ -369,7 +368,7 @@ def main(args): ...@@ -369,7 +368,7 @@ def main(args):
except fluid.core.EOFException: except fluid.core.EOFException:
save_path = os.path.join(args.save_checkpoint_dir, "step_" + str(steps)) save_path = os.path.join(args.save_checkpoint_dir, "step_" + str(steps))
fluid.io.save_persistables(exe, save_path, train_program) fluid.save(train_program, save_path)
train_pyreader.reset() train_pyreader.reset()
break break
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
""" """
EmoTect utilities. EmoTect utilities.
""" """
...@@ -29,27 +28,13 @@ import paddle ...@@ -29,27 +28,13 @@ import paddle
import paddle.fluid as fluid import paddle.fluid as fluid
import numpy as np import numpy as np
def init_checkpoint(exe, init_checkpoint_path, main_program): def init_checkpoint(exe, init_checkpoint_path, main_program):
""" """
Init CheckPoint Init CheckPoint
""" """
assert os.path.exists(
init_checkpoint_path), "[%s] cann't be found." % init_checkpoint_path
def existed_persitables(var):
"""
If existed presitabels
"""
if not fluid.io.is_persistable(var):
return False
return os.path.exists(os.path.join(init_checkpoint_path, var.name))
fluid.io.load_vars( fluid.load(main_program, init_checkpoint_path, exe)
exe,
init_checkpoint_path,
main_program=main_program,
predicate=existed_persitables)
print("Load model from {}".format(init_checkpoint_path))
def word2id(word_dict, query): def word2id(word_dict, query):
...@@ -57,8 +42,10 @@ def word2id(word_dict, query): ...@@ -57,8 +42,10 @@ def word2id(word_dict, query):
Convert word sequence into id list Convert word sequence into id list
""" """
unk_id = len(word_dict) unk_id = len(word_dict)
wids = [word_dict[w] if w in word_dict else unk_id wids = [
for w in query.strip().split(" ")] word_dict[w] if w in word_dict else unk_id
for w in query.strip().split(" ")
]
return wids return wids
......
...@@ -114,7 +114,6 @@ loss, data_list = model(dict_dim, emb_dim) ...@@ -114,7 +114,6 @@ loss, data_list = model(dict_dim, emb_dim)
sgd = fluid.optimizer.SGD(learning_rate=args.base_lr) sgd = fluid.optimizer.SGD(learning_rate=args.base_lr)
sgd.minimize(loss) sgd.minimize(loss)
place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace() place = fluid.CUDAPlace(0) if args.use_gpu else fluid.CPUPlace()
feeder = fluid.DataFeeder(feed_list=data_list, place=place)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
for batch_id in range(100): for batch_id in range(100):
......
...@@ -136,17 +136,19 @@ def start_train(args): ...@@ -136,17 +136,19 @@ def start_train(args):
startup_program = fluid.default_startup_program() startup_program = fluid.default_startup_program()
loop_program = fluid.default_main_program() loop_program = fluid.default_main_program()
feeder = fluid.DataFeeder(feed_list=all_slots, place=place)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(startup_program) exe.run(startup_program)
loader = fluid.io.DataLoader.from_generator(
feed_list=all_slots, capacity=10000, iterable=True)
loader.set_sample_list_generator(train_reader, places=place)
total_time = 0 total_time = 0
ce_info = [] ce_info = []
for pass_id in range(args.epochs): for pass_id in range(args.epochs):
start_time = time.time() start_time = time.time()
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(loader()):
loss_val, correct_val = exe.run(loop_program, loss_val, correct_val = exe.run(loop_program,
feed=feeder.feed(data), feed=data,
fetch_list=[avg_cost, correct]) fetch_list=[avg_cost, correct])
logger.info("TRAIN --> pass: {} batch_id: {} avg_cost: {}, acc: {}" logger.info("TRAIN --> pass: {} batch_id: {} avg_cost: {}, acc: {}"
.format(pass_id, batch_id, loss_val, .format(pass_id, batch_id, loss_val,
......
...@@ -87,9 +87,12 @@ def train(args): ...@@ -87,9 +87,12 @@ def train(args):
optimizer.minimize(avg_cost) optimizer.minimize(avg_cost)
data_list = [var.name for var in train_input_data] data_list = [var.name for var in train_input_data]
feeder = fluid.DataFeeder(feed_list=data_list, place=place) print(data_list)
exe = fluid.Executor(place) exe = fluid.Executor(place)
exe.run(fluid.default_startup_program()) exe.run(fluid.default_startup_program())
loader = fluid.io.DataLoader.from_generator(
feed_list=train_input_data, capacity=10000, iterable=True)
loader.set_sample_list_generator(train_reader, places=place)
if parallel: if parallel:
train_exe = fluid.ParallelExecutor( train_exe = fluid.ParallelExecutor(
use_cuda=use_cuda, loss_name=avg_cost.name) use_cuda=use_cuda, loss_name=avg_cost.name)
...@@ -103,10 +106,10 @@ def train(args): ...@@ -103,10 +106,10 @@ def train(args):
print("epoch_%d start" % epoch_idx) print("epoch_%d start" % epoch_idx)
t0 = time.time() t0 = time.time()
i = 0 i = 0
for batch_id, data in enumerate(train_reader()): for batch_id, data in enumerate(loader()):
i += 1 i += 1
loss_val, correct_val = train_exe.run( loss_val, correct_val = train_exe.run(
feed=feeder.feed(data), fetch_list=[avg_cost.name, acc.name]) feed=data, fetch_list=[avg_cost.name, acc.name])
ce_info.append(float(np.mean(correct_val)) / args.batch_size) ce_info.append(float(np.mean(correct_val)) / args.batch_size)
if i % args.print_batch == 0: if i % args.print_batch == 0:
logger.info( logger.info(
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册