提交 6bc92775 编写于 作者: Q qingqing01 提交者: GitHub

Merge branch 'develop' into row_conv

......@@ -453,6 +453,14 @@ eos
.. autoclass:: paddle.v2.layer.eos
:noindex:
Miscs
=====
dropout
--------------
.. autoclass:: paddle.v2.layer.dropout
:noindex:
Activation with learnable parameter
===================================
......@@ -460,4 +468,3 @@ prelu
--------
.. autoclass:: paddle.v2.layer.prelu
:noindex:
......@@ -125,11 +125,3 @@ simple_attention
:members: simple_attention
:noindex:
Miscs
=====
dropout_layer
--------------
.. automodule:: paddle.v2.networks
:members: dropout_layer
:noindex:
......@@ -8,6 +8,7 @@ add_subdirectory(gserver)
add_subdirectory(pserver)
add_subdirectory(trainer)
add_subdirectory(scripts)
add_subdirectory(strings)
# Do not build go directory until go cmake is working smoothly.
# if(CMAKE_Go_COMPILER)
......
......@@ -41,6 +41,7 @@ SET(SWIG_MODULE_swig_paddle_EXTRA_DEPS
paddle_network
paddle_proto
${external_project_dependencies}
${RDMA_LIBS}
)
IF(APPLE)
......@@ -73,6 +74,7 @@ SWIG_LINK_LIBRARIES(swig_paddle
${CMAKE_DL_LIBS}
${EXTERNAL_LIBS}
${CMAKE_THREAD_LIBS_INIT}
${RDMA_LD_FLAGS}
${START_END}
)
......
cc_library(stringpiece SRCS stringpiece.cc)
cc_test(stringpiece_test SRCS stringpiece_test.cc DEPS stringpiece glog gflags)
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/strings/stringpiece.h"
#include <string.h>
#include <algorithm>
#include <iosfwd>
#include <stdexcept>
namespace paddle {
// Builds an empty piece: no backing buffer and a length of zero.
StringPiece::StringPiece() : data_(nullptr), size_(0) {}
// Refers to the n bytes starting at d without copying them. A null d is
// accepted only together with n == 0; any other combination is rejected with
// std::invalid_argument.
StringPiece::StringPiece(const char* d, size_t n) : data_(d), size_(n) {
  const bool null_with_length = (d == nullptr) && (n != 0);
  if (null_with_length) {
    throw std::invalid_argument(
        "StringPiece requires len to be 0 for NULL data");
  }
}
// Refers to a NUL-terminated C string; the length is measured with strlen.
// A null pointer yields an empty piece.
StringPiece::StringPiece(const char* s)
    : data_(s), size_(s == nullptr ? 0 : strlen(s)) {}
// Refers to the characters held by s. Only the pointer and the length are
// stored, so s has to stay alive and unmodified while the piece is in use.
StringPiece::StringPiece(const std::string& s)
    : data_(s.c_str()), size_(s.length()) {}
// Bounds-checked element access: returns the n-th byte, or throws
// std::invalid_argument when n is not a valid index.
char StringPiece::operator[](size_t n) const {
  if (n < len()) {
    return data_[n];
  }
  throw std::invalid_argument("index out of StringPiece length");
}
// Three-way, byte-wise comparison of a and b.
//
// The common prefix (the first min(a.len(), b.len()) bytes) is compared with
// memcmp; if it is identical, the shorter piece orders first. Returns a
// negative value when a < b, zero when they are equal and a positive value
// when a > b.
int Compare(StringPiece a, StringPiece b) {
  const size_t min_len = std::min(a.len(), b.len());
  // memcmp must not be handed a null pointer even for a zero length, and an
  // empty StringPiece may carry data() == NULL. When there is no common prefix
  // there is nothing to compare, so skip the call entirely. (A non-zero
  // min_len implies both pieces are non-empty and therefore non-null.)
  const int r = (min_len == 0) ? 0 : memcmp(a.data(), b.data(), min_len);
  if (r != 0) return r;
  if (a.len() < b.len()) return -1;
  if (a.len() > b.len()) return 1;
  return 0;
}
// Two pieces are equal when they have the same length and the same bytes.
bool operator==(StringPiece x, StringPiece y) {
  if (x.len() != y.len()) return false;
  // Equal-length empty pieces are equal whatever their pointers are, and
  // pieces sharing a pointer and a length trivially match. Returning here
  // also keeps a possibly-null data() (e.g. a default-constructed piece
  // compared against "") away from memcmp, which requires valid pointers even
  // for a zero length.
  if (x.len() == 0 || x.data() == y.data()) return true;
  return memcmp(x.data(), y.data(), x.len()) == 0;
}
// Inequality is the negation of operator==.
bool operator!=(StringPiece x, StringPiece y) {
  const bool same = (x == y);
  return !same;
}

// The ordering operators all interpret the sign of Compare(x, y).
bool operator<(StringPiece x, StringPiece y) {
  const int order = Compare(x, y);
  return order < 0;
}

bool operator>(StringPiece x, StringPiece y) {
  const int order = Compare(x, y);
  return order > 0;
}

bool operator<=(StringPiece x, StringPiece y) {
  const int order = Compare(x, y);
  return order <= 0;
}

bool operator>=(StringPiece x, StringPiece y) {
  const int order = Compare(x, y);
  return order >= 0;
}
// Returns true when s begins with the bytes of x. Every piece, including an
// empty one, begins with the empty prefix.
bool HasPrefix(StringPiece s, StringPiece x) {
  if (x.len() > s.len()) return false;
  // Answer the empty-prefix case without touching memory: s.data() or
  // x.data() may be NULL for an empty piece, and memcmp requires valid
  // pointers even when asked to compare zero bytes.
  if (x.len() == 0) return true;
  return memcmp(s.data(), x.data(), x.len()) == 0;
}
// Returns true when s ends with the bytes of x. Every piece, including an
// empty one, ends with the empty suffix.
bool HasSuffix(StringPiece s, StringPiece x) {
  if (x.len() > s.len()) return false;
  // Answer the empty-suffix case without touching memory: s.data() or
  // x.data() may be NULL for an empty piece, and memcmp requires valid
  // pointers even when asked to compare zero bytes.
  if (x.len() == 0) return true;
  const char* tail = s.data() + (s.len() - x.len());
  return memcmp(tail, x.data(), x.len()) == 0;
}
// Returns s without its first n bytes. Throws std::invalid_argument when n
// exceeds the length of s.
StringPiece SkipPrefix(StringPiece s, size_t n) {
  const size_t total = s.len();
  if (total < n) {
    throw std::invalid_argument("Skip distance larger than StringPiece length");
  }
  return StringPiece(s.data() + n, total - n);
}
// Returns s without its last n bytes. Throws std::invalid_argument when n
// exceeds the length of s.
StringPiece SkipSuffix(StringPiece s, size_t n) {
  const size_t total = s.len();
  if (total < n) {
    throw std::invalid_argument("Skip distance larger than StringPiece length");
  }
  return StringPiece(s.data(), total - n);
}
// Drops x from the front of s when s starts with x; otherwise returns s
// unchanged.
StringPiece TrimPrefix(StringPiece s, StringPiece x) {
  if (!HasPrefix(s, x)) {
    return s;
  }
  return SkipPrefix(s, x.len());
}
// Drops x from the back of s when s ends with x; otherwise returns s
// unchanged.
StringPiece TrimSuffix(StringPiece s, StringPiece x) {
  if (!HasSuffix(s, x)) {
    return s;
  }
  return SkipSuffix(s, x.len());
}
// Reports whether sub occurs somewhere in s. std::search returns the start of
// the range for an empty sub, so a non-empty s contains the empty sub while
// an empty s (whose start equals its end) contains nothing.
bool Contains(StringPiece s, StringPiece sub) {
  const char* const last = s.end();
  const char* const hit = std::search(s.begin(), last, sub.begin(), sub.end());
  return hit != last;
}
// Returns the offset of the first occurrence of sub in s, or
// StringPiece::npos when there is none. An empty sub is found at offset 0 of
// a non-empty s; in an empty s nothing is found.
size_t Index(StringPiece s, StringPiece sub) {
  const char* const last = s.end();
  const char* const hit = std::search(s.begin(), last, sub.begin(), sub.end());
  if (hit == last) {
    return StringPiece::npos;
  }
  return hit - s.data();
}
// Returns the offset (relative to the start of s) of the first c located at
// or after pos, or StringPiece::npos when pos is at/after the end of s or c
// does not occur in that range.
size_t Find(StringPiece s, char c, size_t pos) {
  if (pos >= s.len()) {
    return StringPiece::npos;
  }
  const size_t remaining = s.len() - pos;
  const void* hit = memchr(s.data() + pos, c, remaining);
  if (hit == nullptr) {
    return StringPiece::npos;
  }
  return static_cast<const char*>(hit) - s.data();
}
// Returns the offset of the last occurrence of c within s[0..pos] (pos
// inclusive), or StringPiece::npos when there is none. A pos beyond the last
// valid index (including npos) is clamped, so the whole piece is searched.
size_t RFind(StringPiece s, char c, size_t pos) {
  if (s.len() == 0) return StringPiece::npos;
  // Walk indices rather than pointers. Counting i down from (last index + 1)
  // and inspecting element i - 1 never forms a pointer before the start of
  // the buffer, which a "p >= s.data(); p--" loop does on its final decrement
  // (undefined behaviour). The + 1 cannot overflow because the clamped start
  // is at most s.len() - 1.
  const char* const base = s.data();
  for (size_t i = std::min(pos, s.len() - 1) + 1; i > 0; --i) {
    if (base[i - 1] == c) {
      return i - 1;
    }
  }
  return StringPiece::npos;
}
// Returns the piece covering at most n bytes of s starting at pos. Both
// arguments are clamped: pos to the length of s, and n to the number of bytes
// left after pos, so out-of-range requests yield a shorter (possibly empty)
// piece instead of an error.
StringPiece SubStr(StringPiece s, size_t pos, size_t n) {
  const size_t start = std::min(pos, s.len());
  const size_t count = std::min(n, s.len() - start);
  return StringPiece(s.data() + start, count);
}
// Streams the referenced bytes by first copying them into a std::string and
// inserting that string into o.
std::ostream& operator<<(std::ostream& o, StringPiece piece) {
  const std::string copy = piece.ToString();
  o << copy;
  return o;
}
} // namespace paddle
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#pragma once
#include <string>
namespace paddle {
// StringPiece refers to a run of characters (the contents of a
// std::string or a C string) but doesn't own them. It is for efficient
// access to strings, like Go's string type. Note that StringPiece
// doesn't mutate the underlying string, so it is thread-safe given
// that the underlying string doesn't change. Because StringPiece
// contains only two data members (a pointer and a length) and doesn't
// own/manage the string, it is cheap to construct StringPieces and
// pass them around by value.
class StringPiece {
public:
// Sentinel returned by the search functions when nothing is found.
static const size_t npos = static_cast<size_t>(-1);
// We provide non-explicit single-argument constructors so users can
// pass in a "const char*" or a "string" wherever a "StringPiece"
// is expected. These constructors ensure that if data_ is NULL,
// size_ is 0.
//
// Default: an empty piece with NULL data and zero length.
StringPiece();
// The n bytes starting at d. Throws std::invalid_argument when d is
// NULL and n is not 0.
StringPiece(const char* d, size_t n);
// A NUL-terminated C string, measured with strlen; NULL gives an
// empty piece.
StringPiece(const char* d);
// The contents of s; s must outlive the piece.
StringPiece(const std::string& s);
// Pointer to the first referenced byte (NULL for a default-constructed
// piece). The bytes are not necessarily NUL-terminated.
const char* data() const { return data_; }
// Number of referenced bytes.
size_t len() const { return size_; }
// Returns the n-th byte; throws std::invalid_argument when n >= len().
char operator[](size_t n) const;
// StringPiece doesn't own the string, so both iterator and const
// iterator are plain const char*.
typedef const char* const_iterator;
typedef const char* iterator;
// Iterators delimiting the referenced bytes: [begin(), end()).
iterator begin() const { return data_; }
iterator end() const { return data_ + size_; }
// Return a string that contains a copy of the referenced data.
std::string ToString() const { return std::string(data_, size_); }
private:
// Start of the referenced bytes; not owned.
const char* data_;
// Number of bytes referenced starting at data_.
size_t size_;
// Intentionally copyable
};
// Three-way byte-wise comparison: negative if a < b, zero if equal,
// positive if a > b. When one piece is a prefix of the other, the
// shorter one orders first.
int Compare(StringPiece a, StringPiece b);
// Equality compares length and bytes; the ordering operators are
// defined in terms of Compare.
bool operator==(StringPiece x, StringPiece y);
bool operator!=(StringPiece x, StringPiece y);
bool operator<(StringPiece x, StringPiece y);
bool operator>(StringPiece x, StringPiece y);
bool operator<=(StringPiece x, StringPiece y);
bool operator>=(StringPiece x, StringPiece y);
// Return whether s starts (or ends) with the given piece.
bool HasPrefix(StringPiece s, StringPiece prefix);
bool HasSuffix(StringPiece s, StringPiece suffix);
// Return s without its first (or last) n bytes. Throw
// std::invalid_argument if n is larger than the length of s.
StringPiece SkipPrefix(StringPiece s, size_t n);
StringPiece SkipSuffix(StringPiece s, size_t n);
// Skip the prefix (or suffix) if s starts (or ends) with it;
// otherwise return s unchanged.
StringPiece TrimPrefix(StringPiece s, StringPiece prefix);
StringPiece TrimSuffix(StringPiece s, StringPiece suffix);
// Returns whether s contains sub. Any non-empty s contains an empty
// sub; an empty s contains nothing, not even an empty sub.
bool Contains(StringPiece s, StringPiece sub);
// Return the offset of the first occurrence of sub in s, or npos. If
// both s and sub are empty, it returns npos; otherwise, if only sub is
// empty, it returns 0.
size_t Index(StringPiece s, StringPiece sub);
// Return the offset (from the start of s) of the first occurrence of
// c in s[pos:end], or npos. Returns npos if pos >= s.len().
size_t Find(StringPiece s, char c, size_t pos);
// Return the offset of the last occurrence of c, or npos. The search
// range is [0..pos] inclusive. If pos == npos (or any value past the
// last index), search everything.
size_t RFind(StringPiece s, char c, size_t pos);
// Return at most n bytes of s starting at pos. pos is clamped to the
// length of s and n to the number of bytes remaining after pos.
StringPiece SubStr(StringPiece s, size_t pos, size_t n);
// Allow StringPiece to be logged.
// NOTE(review): this header includes only <string>; confirm that
// std::ostream is declared (e.g. via <iosfwd>) wherever it is included.
std::ostream& operator<<(std::ostream& o, StringPiece piece);
} // namespace paddle
/*
Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
#include "paddle/strings/stringpiece.h"
#include <sstream>
#include "gtest/gtest.h"
// Exercises every constructor on its "empty" input.
TEST(StringPiece, Construct) {
{
// Default construction: NULL data and zero length.
paddle::StringPiece s;
EXPECT_EQ(NULL, s.data());
EXPECT_EQ(0U, s.len());
}
// NULL data with a non-zero length is rejected.
{ EXPECT_THROW(paddle::StringPiece s(NULL, 10000U), std::invalid_argument); }
{
// A NULL C string gives an empty piece.
paddle::StringPiece s(NULL);
EXPECT_EQ(0U, s.len());
}
{
// An empty std::string gives an empty piece.
std::string a;
EXPECT_EQ(0U, a.size());
paddle::StringPiece s(a);
EXPECT_EQ(0U, s.len());
}
}
// Copies compare equal to their source, and equality is by content rather
// than by pointer.
TEST(StringPiece, CopyAndAssign) {
paddle::StringPiece empty;
EXPECT_EQ(0U, empty.len());
paddle::StringPiece a("hello");
paddle::StringPiece b = a;
EXPECT_EQ(b.len(), strlen("hello"));
EXPECT_EQ(a, b);
// c refers to a different buffer holding the same characters.
std::string storage("hello");
paddle::StringPiece c(storage);
EXPECT_EQ(a, c);
EXPECT_NE(a.data(), c.data());
}
// Checks Compare and all six comparison operators for a strictly ordered
// pair and for a pair of default-constructed (empty) pieces.
TEST(StringPiece, Compare) {
{
// "hello" orders before "world".
paddle::StringPiece a("hello");
paddle::StringPiece b("world");
EXPECT_TRUE(a != b);
EXPECT_FALSE(a == b);
EXPECT_TRUE(a < b);
EXPECT_TRUE(a <= b);
EXPECT_FALSE(a > b);
EXPECT_FALSE(a >= b);
EXPECT_LT(Compare(a, b), 0);
EXPECT_GT(Compare(b, a), 0);
}
{
// Two empty pieces are equal.
paddle::StringPiece a, b;
EXPECT_TRUE(a == b);
EXPECT_FALSE(a != b);
EXPECT_FALSE(a < b);
EXPECT_FALSE(a > b);
EXPECT_TRUE(a <= b);
EXPECT_TRUE(a >= b);
EXPECT_EQ(0, Compare(a, b));
EXPECT_EQ(0, Compare(b, a));
}
}
// ToString copies the referenced bytes; empty pieces (default-constructed or
// built from NULL) produce an empty std::string.
TEST(StringPiece, ToString) {
{
paddle::StringPiece s;
EXPECT_EQ(std::string(""), s.ToString());
}
{
paddle::StringPiece s(NULL);
EXPECT_EQ(std::string(""), s.ToString());
}
{
paddle::StringPiece s("hello");
EXPECT_EQ(std::string("hello"), s.ToString());
}
}
// HasPrefix/HasSuffix on an empty piece and on "app", including the empty
// prefix/suffix and the whole string.
TEST(StringPiece, HasPrefixSuffix) {
using paddle::HasPrefix;
using paddle::HasSuffix;
{
// An empty piece has only the empty prefix and suffix.
paddle::StringPiece s;
EXPECT_FALSE(HasPrefix(s, "something"));
EXPECT_TRUE(HasPrefix(s, ""));
EXPECT_FALSE(HasSuffix(s, "something"));
EXPECT_TRUE(HasSuffix(s, ""));
}
{
paddle::StringPiece s("app");
EXPECT_TRUE(HasPrefix(s, ""));
EXPECT_TRUE(HasPrefix(s, "a"));
EXPECT_TRUE(HasPrefix(s, "ap"));
EXPECT_TRUE(HasPrefix(s, "app"));
EXPECT_TRUE(HasSuffix(s, ""));
EXPECT_TRUE(HasSuffix(s, "p"));
EXPECT_TRUE(HasSuffix(s, "pp"));
EXPECT_TRUE(HasSuffix(s, "app"));
}
}
// SkipPrefix/SkipSuffix drop exactly n bytes and throw
// std::invalid_argument when n exceeds the length.
TEST(StringPiece, SkipPrefixSuffix) {
using paddle::SkipPrefix;
using paddle::SkipSuffix;
{
// On an empty piece only n == 0 is valid.
paddle::StringPiece s;
EXPECT_EQ("", SkipPrefix(s, 0));
EXPECT_THROW(SkipPrefix(s, 1), std::invalid_argument);
EXPECT_EQ("", SkipSuffix(s, 0));
EXPECT_THROW(SkipSuffix(s, 1), std::invalid_argument);
}
{
paddle::StringPiece s("app");
EXPECT_EQ("app", SkipPrefix(s, 0));
EXPECT_EQ("pp", SkipPrefix(s, 1));
EXPECT_EQ("p", SkipPrefix(s, 2));
EXPECT_EQ("", SkipPrefix(s, 3));
EXPECT_THROW(SkipPrefix(s, 4), std::invalid_argument);
EXPECT_EQ("app", SkipSuffix(s, 0));
EXPECT_EQ("ap", SkipSuffix(s, 1));
EXPECT_EQ("a", SkipSuffix(s, 2));
EXPECT_EQ("", SkipSuffix(s, 3));
EXPECT_THROW(SkipSuffix(s, 4), std::invalid_argument);
}
}
// TrimPrefix/TrimSuffix remove a matching prefix/suffix and leave the piece
// untouched when there is no match.
TEST(StringPiece, TrimPrefixSuffix) {
using paddle::TrimPrefix;
using paddle::TrimSuffix;
{
// Trimming an empty piece always yields an empty piece.
paddle::StringPiece s;
EXPECT_EQ("", TrimPrefix(s, ""));
EXPECT_EQ("", TrimPrefix(s, "something"));
EXPECT_EQ("", TrimSuffix(s, ""));
EXPECT_EQ("", TrimSuffix(s, "something"));
}
{
paddle::StringPiece s("app");
EXPECT_EQ("app", TrimPrefix(s, ""));
EXPECT_EQ("pp", TrimPrefix(s, "a"));
EXPECT_EQ("p", TrimPrefix(s, "ap"));
EXPECT_EQ("", TrimPrefix(s, "app"));
EXPECT_EQ("app", TrimPrefix(s, "something"));
EXPECT_EQ("app", TrimSuffix(s, ""));
EXPECT_EQ("ap", TrimSuffix(s, "p"));
EXPECT_EQ("a", TrimSuffix(s, "pp"));
EXPECT_EQ("", TrimSuffix(s, "app"));
EXPECT_EQ("app", TrimSuffix(s, "something"));
}
}
// Contains finds substrings anywhere in the piece; an empty piece contains
// nothing, not even the empty string.
TEST(StringPiece, Contains) {
using paddle::Contains;
{
paddle::StringPiece s;
EXPECT_FALSE(Contains(s, ""));
EXPECT_FALSE(Contains(s, "something"));
}
{
paddle::StringPiece s("app");
EXPECT_TRUE(Contains(s, ""));
EXPECT_TRUE(Contains(s, "a"));
EXPECT_TRUE(Contains(s, "p"));
EXPECT_TRUE(Contains(s, "ap"));
EXPECT_TRUE(Contains(s, "pp"));
EXPECT_TRUE(Contains(s, "app"));
EXPECT_FALSE(Contains(s, "something"));
}
}
// Index returns the offset of the first match, or npos; searching an empty
// piece always gives npos.
TEST(StringPiece, Index) {
using paddle::Index;
auto npos = paddle::StringPiece::npos;
{
paddle::StringPiece s;
EXPECT_EQ(npos, Index(s, ""));
EXPECT_EQ(npos, Index(s, "something"));
}
{
paddle::StringPiece s("app");
EXPECT_EQ(0U, Index(s, ""));
EXPECT_EQ(0U, Index(s, "a"));
EXPECT_EQ(1U, Index(s, "p"));
EXPECT_EQ(0U, Index(s, "ap"));
EXPECT_EQ(1U, Index(s, "pp"));
EXPECT_EQ(0U, Index(s, "app"));
EXPECT_EQ(npos, Index(s, "something"));
}
}
// Find(s, c, pos) returns the offset of the first c located at or after pos,
// or npos when there is none.
TEST(StringPiece, Find) {
  using paddle::Find;
  auto npos = paddle::StringPiece::npos;
  {
    // An empty piece never contains the character.
    paddle::StringPiece s;
    EXPECT_EQ(npos, Find(s, 'a', 0U));
  }
  {
    paddle::StringPiece s("app");
    EXPECT_EQ(0U, Find(s, 'a', 0U));
    EXPECT_EQ(1U, Find(s, 'p', 0U));
    EXPECT_EQ(1U, Find(s, 'p', 1U));
    EXPECT_EQ(2U, Find(s, 'p', 2U));
    EXPECT_EQ(npos, Find(s, 'z', 2U));
    // A match that lies before the start position must not be reported.
    EXPECT_EQ(npos, Find(s, 'a', 1U));
    // Start positions at or beyond the end yield npos instead of reading
    // past the buffer.
    EXPECT_EQ(npos, Find(s, 'p', 3U));
    EXPECT_EQ(npos, Find(s, 'p', npos));
  }
}
// RFind(s, c, pos) returns the offset of the last c within s[0..pos]
// (inclusive), or npos when there is none.
TEST(StringPiece, RFind) {
  using paddle::RFind;
  auto npos = paddle::StringPiece::npos;
  {
    // An empty piece never contains the character.
    paddle::StringPiece s;
    EXPECT_EQ(npos, RFind(s, 'a', 0U));
  }
  {
    paddle::StringPiece s("app");
    EXPECT_EQ(2U, RFind(s, 'p', 2U));
    EXPECT_EQ(0U, RFind(s, 'a', 2U));
    EXPECT_EQ(1U, RFind(s, 'p', 1U));
    EXPECT_EQ(0U, RFind(s, 'a', 0));
    EXPECT_EQ(npos, RFind(s, 'z', 2U));
    // pos == npos (or any position past the last index) searches the whole
    // piece, as documented on the declaration.
    EXPECT_EQ(2U, RFind(s, 'p', npos));
    EXPECT_EQ(0U, RFind(s, 'a', 100U));
    // The backwards walk must stop cleanly at index 0 when nothing matches.
    EXPECT_EQ(npos, RFind(s, 'p', 0U));
  }
}
// SubStr clamps both the start position and the length, so every
// (pos, n) combination yields a valid, possibly empty, piece.
TEST(StringPiece, SubStr) {
using paddle::SubStr;
{
// Any request on an empty piece gives an empty piece.
paddle::StringPiece s;
EXPECT_EQ("", SubStr(s, 0, 0));
EXPECT_EQ("", SubStr(s, 0, 1));
EXPECT_EQ("", SubStr(s, 1, 0));
}
{
// All start positions 0..3 combined with lengths 0..3 on "app".
paddle::StringPiece s("app");
EXPECT_EQ("", SubStr(s, 0, 0));
EXPECT_EQ("", SubStr(s, 1, 0));
EXPECT_EQ("", SubStr(s, 2, 0));
EXPECT_EQ("", SubStr(s, 3, 0));
EXPECT_EQ("a", SubStr(s, 0, 1));
EXPECT_EQ("p", SubStr(s, 1, 1));
EXPECT_EQ("p", SubStr(s, 2, 1));
EXPECT_EQ("", SubStr(s, 3, 1));
EXPECT_EQ("ap", SubStr(s, 0, 2));
EXPECT_EQ("pp", SubStr(s, 1, 2));
EXPECT_EQ("p", SubStr(s, 2, 2));
EXPECT_EQ("", SubStr(s, 3, 2));
EXPECT_EQ("app", SubStr(s, 0, 3));
EXPECT_EQ("pp", SubStr(s, 1, 3));
EXPECT_EQ("p", SubStr(s, 2, 3));
EXPECT_EQ("", SubStr(s, 3, 3));
}
}
// Streaming a StringPiece appends its bytes; streaming an empty piece
// appends nothing. The checks accumulate in one stream, so their order
// matters.
TEST(StringPiece, StreamOutput) {
using paddle::StringPiece;
std::stringstream o;
o << StringPiece();
EXPECT_EQ("", o.str());
o << StringPiece("hello");
EXPECT_EQ("hello", o.str());
o << StringPiece();
EXPECT_EQ("hello", o.str());
}
......@@ -3563,11 +3563,7 @@ def update_g_config():
return g_config
def begin_parse(config_arg_str=''):
'''
@param config_arg_str: a string of the form var1=val1,var2=val2. It will be
passed to config script as a dictionary CONFIG_ARGS
'''
def begin_parse():
init_config_environment()
for hook in _parse_config_hooks:
hook()
......@@ -3585,8 +3581,12 @@ def begin_parse(config_arg_str=''):
def parse_config(trainer_config, config_arg_str):
begin_parse(config_arg_str)
'''
@param config_arg_str: a string of the form var1=val1,var2=val2. It will be
passed to config script as a dictionary CONFIG_ARGS
'''
begin_parse()
config_args = {}
if config_arg_str:
......
......@@ -122,6 +122,7 @@ __all__ = [
'layer_support',
'multiplex_layer',
'row_conv_layer',
'dropout_layer',
'prelu_layer',
]
......@@ -3773,7 +3774,6 @@ def beam_search(step,
assert generated_input_index != -1
gipt = input[generated_input_index]
assert isinstance(gipt, BaseGeneratedInput)
gipt.bos_id = bos_id
gipt.eos_id = eos_id
......@@ -3793,7 +3793,6 @@ def beam_search(step,
predict = gipt.after_real_step(step(*args))
eos_layer(input=predict, eos_id=eos_id, name=eos_name)
return predict
tmp = recurrent_group(
......@@ -5569,6 +5568,24 @@ def multiplex_layer(input, name=None, layer_attr=None):
size=l.config.size)
@wrap_name_default("dropout")
def dropout_layer(input, dropout_rate, name=None):
"""
@TODO(yuyang18): Add comments.
:param name:
:param input:
:param dropout_rate:
:return:
"""
return addto_layer(
name=name,
input=input,
act=LinearActivation(),
bias_attr=False,
layer_attr=ExtraAttr(drop_rate=dropout_rate))
@wrap_name_default()
@wrap_act_default(act=LinearActivation())
@wrap_param_attr_default()
......
......@@ -26,10 +26,10 @@ from paddle.trainer.config_parser import *
__all__ = [
'sequence_conv_pool', 'simple_lstm', "simple_img_conv_pool",
"img_conv_bn_pool", 'dropout_layer', 'lstmemory_group', 'lstmemory_unit',
'small_vgg', 'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group',
'simple_gru', 'simple_attention', 'simple_gru2', 'bidirectional_gru',
'text_conv_pool', 'bidirectional_lstm', 'inputs', 'outputs'
"img_conv_bn_pool", 'lstmemory_group', 'lstmemory_unit', 'small_vgg',
'img_conv_group', 'vgg_16_network', 'gru_unit', 'gru_group', 'simple_gru',
'simple_attention', 'simple_gru2', 'bidirectional_gru', 'text_conv_pool',
'bidirectional_lstm', 'inputs', 'outputs'
]
######################################################
......@@ -1366,29 +1366,6 @@ def simple_attention(encoded_sequence,
input=scaled, pooling_type=SumPooling(), name="%s_pooling" % name)
############################################################################
# Miscs #
############################################################################
@wrap_name_default("dropout")
def dropout_layer(input, dropout_rate, name=None):
"""
@TODO(yuyang18): Add comments.
:param name:
:param input:
:param dropout_rate:
:return:
"""
return addto_layer(
name=name,
input=input,
act=LinearActivation(),
bias_attr=False,
layer_attr=ExtraAttr(drop_rate=dropout_rate))
def inputs(layers, *args):
"""
Declare the inputs of network. The order of input should be as same as
......
......@@ -13,7 +13,7 @@
# limitations under the License.
"""
`paddle.v2.layer` is a part of model config packages in paddle.v2. In API v2,
we want to make Paddle a plain Python package. The model config package defined
we want to make Paddle a plain Python package. The model config package defines
the way how to configure a neural network topology in Paddle Python code.
The primary usage shows below.
......@@ -30,7 +30,6 @@ The primary usage shows below.
# use prediction instance where needed.
parameters = paddle.parameters.create(cost)
"""
import collections
import copy
import re
......@@ -44,9 +43,10 @@ __all__ = ['data', 'parse_network']
def __need_to_keep__(name):
if name in ['StaticInput', 'LayerType', 'layer_support']:
return False
return True
return name in [
'StaticInput', 'SubsequenceInput', 'GeneratedInput', 'LayerType',
'layer_support'
]
def __need_to_wrap__(name):
......@@ -54,6 +54,8 @@ def __need_to_wrap__(name):
def __convert_name__(inname):
if __need_to_keep__(inname):
return inname
if inname == 'maxid_layer':
return 'max_id'
elif inname.endswith('memory') or inname.endswith(
......@@ -74,8 +76,6 @@ def __convert_name__(inname):
for name in v1_layers.__all__:
obj = getattr(v1_layers, name)
if not __need_to_keep__(name):
continue
new_name = __convert_name__(name)
if callable(obj) and __need_to_wrap__(name):
globals()[new_name] = __convert_to_v2__(obj, new_name, __name__)
......@@ -107,7 +107,7 @@ __data_layer__.__doc__ = __map_data_docstr__(v1_layers.data_layer.__doc__)
data = __convert_to_v2__(__data_layer__, 'name', __name__)
def __get_used_layers__(output_layers, extra_layers=None):
def __get_used_layers__(output_layers):
layer_names = set()
parents = {}
......@@ -132,6 +132,13 @@ def __get_used_layers__(output_layers, extra_layers=None):
add_parent(mem.layer_name, mem.boot_layer_name)
add_parent(mem.link_name, mem.layer_name)
if sub_model.HasField('generator'):
# according to the implementation of text generation
# in recurrent layer group, the generated word must be
# the first out link
add_parent(sub_model.out_links[0].layer_name,
sub_model.generator.eos_layer_name)
def dfs_travel(layer_name):
if layer_name in layer_names:
return
......@@ -247,9 +254,9 @@ def __trim_submodel__(old_submodel, layer_names, input_layer_names,
def parse_network(output_layers, extra_layers=None):
if not isinstance(output_layers, collections.Sequence):
output_layers = [output_layers]
if extra_layers is not None and not isinstance(extra_layers,
collections.Sequence):
extra_layers = [extra_layers]
if extra_layers is not None:
if not isinstance(extra_layers, collections.Sequence):
extra_layers = [extra_layers]
else:
extra_layers = []
......@@ -262,18 +269,29 @@ def parse_network(output_layers, extra_layers=None):
model_config = ModelConfig()
model_config.type = cp.g_config.model_config.type
for layer in output_layers:
model_config.output_layer_names.append(layer.full_name)
output_layer_names.add(layer.full_name)
for l in cp.g_config.model_config.layers:
if l.name not in layer_names:
continue
model_config.layers.extend([l])
if l.type == 'data':
if l.name in model_config.output_layer_names:
"""
In text generation, the outlink to save the generated word
indices is a data_layer defined in recurrent_group. This
data_layer is sure to be the output of the network in text
generation task, so this statement excludes such a special
data_layer from being inputs of the network, otherwise an error
will occur during data feeding.
"""
continue
model_config.input_layer_names.append(l.name)
input_layer_names.add(l.name)
for layer in output_layers:
model_config.output_layer_names.append(layer.full_name)
output_layer_names.add(layer.full_name)
for e in cp.g_config.model_config.evaluators:
if e.name in evaluator_names:
model_config.evaluators.extend([e])
......
......@@ -31,7 +31,6 @@ class Topology(object):
def __init__(self, layers, extra_layers=None):
def __check__(layers):
if not isinstance(layers, collections.Sequence):
__check_layer_type__(layers)
layers = [layers]
for layer in layers:
__check_layer_type__(layer)
......@@ -91,6 +90,7 @@ class Topology(object):
[('image', dense_vector(768)), ('label', integer_value(10))]
"""
data_layers = self.data_layers()
return [(nm, data_layers[nm].data_type)
for nm in self.proto().input_layer_names]
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册