Skip to content

  • 体验新版
    • 正在加载...
  • 登录
  • PaddlePaddle
  • Paddle
  • Issue
  • #12572

P
Paddle
  • 项目概览

PaddlePaddle / Paddle
大约 2 年 前同步成功

通知 2325
Star 20933
Fork 5424
  • 代码
    • 文件
    • 提交
    • 分支
    • Tags
    • 贡献者
    • 分支图
    • Diff
  • Issue 1423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
  • Wiki 0
    • Wiki
  • 分析
    • 仓库
    • DevOps
  • 项目成员
  • Pages
P
Paddle
  • 项目概览
    • 项目概览
    • 详情
    • 发布
  • 仓库
    • 仓库
    • 文件
    • 提交
    • 分支
    • 标签
    • 贡献者
    • 分支图
    • 比较
  • Issue 1,423
    • Issue 1,423
    • 列表
    • 看板
    • 标记
    • 里程碑
  • 合并请求 543
    • 合并请求 543
  • Pages
  • 分析
    • 分析
    • 仓库分析
    • DevOps
  • Wiki 0
    • Wiki
  • 成员
    • 成员
  • 收起侧边栏
  • 动态
  • 分支图
  • 创建新Issue
  • 提交
  • Issue看板
已关闭
开放中
Opened 8月 07, 2018 by saxon_zh@saxon_zhGuest

Cannot run operator on place CUDAPlace(0) at [/paddle/paddle/fluid/framework/operator.cc:104]

Created by: Angus07

训练脚本

import os
import gzip
import paddle
import paddle.fluid as fluid
import read_short as reader
import utils
from functools import partial
import network_conf_short5 as network
os.environ["CUDA_VISIBLE_DEVICES"] = '8'
GPU_CNT = 1
BUFFER_SIZE = 12800
step = 0


def load_initial_model(model_path, parameters):
    """ Initalize parameters in the network from a trained model.
    This is useful in resuming the training from previously saved models.
    Arguments:
        - model_path:    The path of a trained model.
        - parameters:    The parameters in a network which will be initialized
                         from the specified model.
    """
    with gzip.open(model_path, "rb") as f:
        parameters.init_from_tar(f)


def train(topology,
          train_data_dir=None,
          test_data_dir=None,
          word_dict_path=None,
          l1_dict_path=None,
          label_dict_path=None,
          model_save_dir="models",
          batch_size=128,
          num_passes=20):
    """
    train dnn model
    """
    place = fluid.CUDAPlace(0)
    if not os.path.exists(model_save_dir):
        os.mkdir(model_save_dir)
    if word_dict_path is None or not os.path.exists(word_dict_path):
        utils.logger.info(("word dictionary is not given, the dictionary "
                           "is automatically built from the training data."))

        utils.build_dict(
            data_dir=train_data_dir,
            save_path=word_dict_path,
            use_col=0,
            cutoff_fre=10,
            insert_extra_words=["<UNK>"])
    if not os.path.exists(label_dict_path):
        utils.logger.info(("label dictionary is not given, the dictionary "
                           "is automatically built from the training data."))
        utils.build_dict(
            data_dir=train_data_dir, save_path=label_dict_path, use_col=2)
    word_dict = utils.load_dict(word_dict_path)
    #l1_dict is a feature of topic tagger task
    l1_dict = utils.load_dict(l1_dict_path)
    lbl_dict = utils.load_dict(label_dict_path)
    class_num = len(lbl_dict)
    utils.logger.info("class number is : %d." % (len(lbl_dict)))
    train_reader = paddle.batch(
        paddle.reader.shuffle(paddle.reader.buffered(
            reader.train_reader(train_data_dir, word_dict, l1_dict, lbl_dict),
            BUFFER_SIZE), buf_size=BUFFER_SIZE),
        batch_size=batch_size)
    if test_data_dir is not None:
        test_reader = paddle.batch(
            paddle.reader.shuffle(
                reader.train_reader(test_data_dir, word_dict, l1_dict, lbl_dict),
                buf_size=BUFFER_SIZE),
            batch_size=batch_size)
    else:
        test_reader = None
    dict_dim = len(word_dict)
    utils.logger.info("length of word dictionary is : %d." % (dict_dim))
    def optimizer_func():
        return fluid.optimizer.Adam(learning_rate=1e-3, regularization=fluid.regularizer.L2DecayRegularizer(8e-4))
    # create trainer
    trainer = fluid.Trainer(
        train_func=partial(topology, dict_dim, class_num),
        place=place,
        optimizer_func=optimizer_func)
    feed_order = ["title1", "title2", "title3", "l1" , "label"]
    def _event_handler(event):
        """
        Define end batch and end pass event handler
        """
        global step
        if isinstance(event,  fluid.EndStepEvent):
            if step % 100 == 0:
                utils.logger.info("Pass %d, Batch %d, Step %d, Cost %f\n" %
                    (event.pass_id, event.batch_id, step, event.cost))

            if step % 10000 == 0:
                if test_reader is not None:
                    result = trainer.test(reader=test_reader, feed_order =feed_order)
                    utils.logger.info("Test at Pass %d, Step %d, Cost %f\n" %
                                (event.pass_id, step, result.cost))
                with gzip.open(
                        os.path.join(model_save_dir, "dnn_params_step_%d.tar.gz" %
                                     step), "w") as f:
                    trainer.save_params(model_save_dir)
            step += 1

    trainer.train(
        num_passes,
        reader=train_reader,
        event_handler=_event_handler,
        feed_order=feed_order,
        )
    utils.logger.info("Training has finished.")

def main(args):
    """
    main
    """
    if args.nn_type == "cnn":
        topology = network.convolution_net
    train(
        topology=topology,
        train_data_dir=args.train_data_dir,
        test_data_dir=args.test_data_dir,
        word_dict_path=args.word_dict,
        label_dict_path=args.label_dict,
        l1_dict_path="../conf/level_1_tag_25_raw",
        batch_size=args.batch_size,
        num_passes=args.num_passes,
        model_save_dir=args.model_save_dir)


if __name__ == "__main__":
    args = utils.parse_train_cmd()
    if args.train_data_dir is not None:
        assert args.word_dict and args.label_dict, (
            "the parameter train_data_dir, word_dict_path, and label_dict_path "
            "should be set at the same time.")
    main(args)

错误信息: Traceback (most recent call last): File "train_short5.py", line 193, in main(args) File "train_short5.py", line 184, in main model_save_dir=args.model_save_dir) File "train_short5.py", line 104, in train optimizer_func=optimizer_func) File "/home/du/chenliangyu/python27-gcc482/lib/python2.7/site-packages/paddle/fluid/trainer.py", line 284, in init exe.run(self.startup_program) File "/home/du/chenliangyu/python27-gcc482/lib/python2.7/site-packages/paddle/fluid/executor.py", line 443, in run self.executor.run(program.desc, scope, 0, True, True) paddle.fluid.core.EnforceNotMet: Cannot run operator on place CUDAPlace(0) at [/paddle/paddle/fluid/framework/operator.cc:104]

网络结构:

"""
Define the network
"""

import paddle.fluid as fluid
import sys

__all__ = ["fc_net", "convolution_net"]

LAYER1_SIZE = 896
LAYER2_SIZE = 448
LAYER3_SIZE = 224


def convolution_net(dict_dim, class_dim, emb_dim=128,
                    hid_dim=128, is_infer=False):
    data1 = fluid.layers.data(
        name="title1", shape=[1], dtype="int64", lod_level=1)
    data2 = fluid.layers.data(
        name="title2", shape=[1], dtype="int64", lod_level=1)
    data3 = fluid.layers.data(
        name="title3", shape=[1], dtype="int64", lod_level=1)
    data4 = fluid.layers.data(
        name="l1", shape=[1], dtype="int64", lod_level=1)
    if not is_infer:
        #lbl = paddle.layer.data("label", paddle.data_type.dense_vector(class_num, int))
        #label = fluid.layers.data(name="label", shape=[class_dim], dtype="int64")
        label = fluid.layers.data(name="label", shape=[class_dim], dtype="int64")

    # define the embedding layer
    emb1 = fluid.layers.embedding(
        input=data1, size=[dict_dim, emb_dim], is_sparse=True, param_attr=fluid.ParamAttr(name='emb'))
    emb2 = fluid.layers.embedding(
        input=data2, size=[dict_dim, emb_dim], is_sparse=True,param_attr=fluid.ParamAttr(name='emb'))
    emb3 = fluid.layers.embedding(
        input=data3, size=[dict_dim, emb_dim], is_sparse=True,param_attr=fluid.ParamAttr(name='emb'))
    emb4 = fluid.layers.embedding(
        input=data4, size=[27, emb_dim], is_sparse=True,param_attr=fluid.ParamAttr(name='l1_emb'))
    # max pooling to reduce the input sequence into a vector (non-sequence)
    seq_pool1  = fluid.layers.sequence_pool(input=emb1, pool_type='average')
    seq_pool2  = fluid.layers.sequence_pool(input=emb2, pool_type='average')
    seq_pool3  = fluid.layers.sequence_pool(input=emb3, pool_type='average')
    seq_pool4  = fluid.layers.sequence_pool(input=emb4, pool_type='average')

    # convolution layers with max pooling
    conv_1_win3 = fluid.nets.sequence_conv_pool(
        input=emb1,
        num_filters=hid_dim,
        filter_size=3,
        act="relu",
        pool_type="max")
    conv_1_win4 = fluid.nets.sequence_conv_pool(
        input=emb1,
        num_filters=hid_dim,
        filter_size=4,
        act="relu",
        pool_type="max")
    conv_1_win5 = fluid.nets.sequence_conv_pool(
        input=emb1,
        num_filters=hid_dim,
        filter_size=5,
        act="relu",
        pool_type="max")
    conv_2_win3 = fluid.nets.sequence_conv_pool(
        input=emb2,
        num_filters=hid_dim,
        filter_size=3,
        act="relu",
        pool_type="max")
    conv_2_win4 = fluid.nets.sequence_conv_pool(
        input=emb2,
        num_filters=hid_dim,
        filter_size=4,
        act="relu",
        pool_type="max")
    conv_2_win5 = fluid.nets.sequence_conv_pool(
        input=emb2,
        num_filters=hid_dim,
        filter_size=5,
        act="relu",
        pool_type="max")
    concat_vec = fluid.layers.concat(
        input=[conv_1_win3, conv_1_win4, conv_1_win5,conv_2_win3, conv_2_win4, conv_2_win5,
                                           seq_pool1, seq_pool2, seq_pool3
                                          , seq_pool4], axis=1)
    print >> sys.stderr, concat_vec
    bn = fluid.layers.batch_norm(input=concat_vec, act='relu')
    prediction = fluid.layers.fc(
        bn, size=2000, act="relu")
    #hidden = paddle.layer.dropout(input=hidden, dropout_rate=0.5)
    prob = fluid.layers.fc(
        input=prediction, size=class_dim,act='sigmoid')
    cost = fluid.layers.cross_entropy(input=prob, label=label, soft_label= True)
    cost = fluid.layers.mean(x=cost)
    if is_infer:
        return prob
    else:
        return [cost, prob, label]
指派人
分配到
无
里程碑
无
分配里程碑
工时统计
无
截止日期
无
标识: paddlepaddle/Paddle#12572
渝ICP备2023009037号

京公网安备11010502055752号

网络110报警服务 Powered by GitLab CE v13.7
开源知识
Git 入门 Pro Git 电子书 在线学 Git
Markdown 基础入门 IT 技术知识开源图谱
帮助
使用手册 反馈建议 博客
《GitCode 隐私声明》 《GitCode 服务条款》 关于GitCode
Powered by GitLab CE v13.7