Commit d5b0c57d authored by xzl

Merge branch 'develop' of https://github.com/PaddlePaddle/Paddle into mobilenet_gpu

// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package connection
import (
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
  go_test(master_test)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
go_library(paddle_master SHARED DEPS paddle_go_optimizer)
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
/*
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master_test
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import "sync"
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package master
import "testing"
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
  go_test(pserver_test DEPS paddle_go_optimizer)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
  go_test(pserver_client_test DEPS paddle_go_optimizer)
endif()
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_library(paddle_go_optimizer DEPS paddle_optimizer paddle_proto glog gflags protobuf)
target_link_libraries(paddle_go_optimizer stdc++ m)
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package main
/*
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
cc_test(test_cclient SRCS test_cclient.c DEPS paddle_pserver_cclient paddle_go_optimizer)
add_style_check_target(test_cclient test_cclient.c)
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include <stdio.h>
#include <stdlib.h>
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client_test
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package client
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
// #cgo CFLAGS: -I ../../
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pserver_test
import (
......
# Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
if(WITH_TESTING)
  go_test(network_helper_test)
endif()
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper
import (
......
// Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserve.
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
// http://www.apache.org/licenses/LICENSE-2.0
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package networkhelper
import "testing"
......
@@ -65,14 +65,15 @@ template <typename T, int MajorType = Eigen::RowMajor,
           typename IndexType = Eigen::DenseIndex>
 struct EigenVector : public EigenTensor<T, 1, MajorType, IndexType> {
   // Flatten is to reshape a Tensor into a one dimension EigenVector
-  static typename EigenTensor<T, 1>::Type Flatten(Tensor& tensor) {
-    return EigenTensor<T, 1>::From(
-        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+  using Parent = EigenTensor<T, 1, MajorType, IndexType>;
+  static typename Parent::Type Flatten(Tensor& tensor) {
+    return Parent::From(tensor,
+                        make_ddim({static_cast<int>(product(tensor.dims_))}));
   }
-  static typename EigenTensor<T, 1>::ConstType Flatten(const Tensor& tensor) {
-    return EigenTensor<T, 1>::From(
-        tensor, make_ddim({static_cast<int>(product(tensor.dims_))}));
+  static typename Parent::ConstType Flatten(const Tensor& tensor) {
+    return Parent::From(tensor,
+                        make_ddim({static_cast<int>(product(tensor.dims_))}));
   }
 };
......
@@ -3173,11 +3173,11 @@ def memory(name,
 @wrap_bias_attr_default()
-@wrap_act_default(
-    param_names=['gate_act', 'state_act'], act=SigmoidActivation())
+@wrap_act_default(param_names=['gate_act'], act=SigmoidActivation())
+@wrap_act_default(param_names=['state_act'], act=TanhActivation())
 @wrap_act_default(act=TanhActivation())
 @wrap_name_default('lstm_step')
-@layer_support()
+@layer_support(ERROR_CLIPPING, DROPOUT)
 def lstm_step_layer(input,
                     state,
                     size=None,
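Hedged illustration of the new activation defaults (the layers proj and prev_state and the size 256 are invented for this sketch, not part of the commit):

# After this change the two calls below are equivalent; previously,
# state_act also defaulted to SigmoidActivation().
step1 = lstm_step_layer(input=proj, state=prev_state, size=256)
step2 = lstm_step_layer(input=proj, state=prev_state, size=256,
                        gate_act=SigmoidActivation(),
                        state_act=TanhActivation())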
@@ -3531,12 +3531,7 @@ def SubsequenceInput(input):
 @wrap_name_default("recurrent_group")
-def recurrent_group(step,
-                    input,
-                    reverse=False,
-                    name=None,
-                    targetInlink=None,
-                    is_generating=False):
+def recurrent_group(step, input, reverse=False, name=None, targetInlink=None):
     """
     Recurrent layer group is an extremely flexible recurrent unit in
     PaddlePaddle. As long as the user defines the calculation done within a
@@ -3602,21 +3597,12 @@ def recurrent_group(step,
     :type targetInlink: LayerOutput|SubsequenceInput
 
-    :param is_generating: If is generating, none of input type should be LayerOutput;
-                          else, for training or testing, one of the input type must
-                          be LayerOutput.
-    :type is_generating: bool
-
     :return: LayerOutput object.
     :rtype: LayerOutput
     """
     model_type('recurrent_nn')
 
-    def is_single_input(x):
-        return isinstance(x, LayerOutput) or isinstance(x, StaticInput)
-
-    if is_single_input(input):
+    if isinstance(input, LayerOutput) or isinstance(input, StaticInput):
         input = [input]
     assert isinstance(input, collections.Sequence)
@@ -3630,13 +3616,8 @@ def recurrent_group(step,
         in_links=map(lambda x: x.name, in_links),
         seq_reversed=reverse)
 
     in_args = []
-    has_LayerOutput = False
     for each_input in input:
-        assert is_single_input(each_input)
-        if isinstance(each_input, LayerOutput):
-            in_args.append(each_input)
-            has_LayerOutput = True
-        else:  # StaticInput
+        if isinstance(each_input, StaticInput):  # StaticInput
             mem_name = "__%s_memory__" % each_input.input.name
             mem = memory(
                 name=None,
@@ -3644,24 +3625,26 @@ def recurrent_group(step,
                 boot_layer=each_input.input)
             mem.set_input(mem)
             in_args.append(mem)
-
-    assert (is_generating != has_LayerOutput)
+        else:
+            in_args.append(each_input)
 
     layer_outs = step(*in_args)
 
     if isinstance(layer_outs, LayerOutput):
         layer_outs = [layer_outs]
 
-    for ot in layer_outs:
-        assert isinstance(ot, LayerOutput)
-        ot.reverse = reverse
-        RecurrentLayerGroupSetOutLink(ot.name)
+    for layer_out in layer_outs:
+        assert isinstance(
+            layer_out, LayerOutput
+        ), "Type of step function's return value must be LayerOutput."
+        layer_out.reverse = reverse
+        RecurrentLayerGroupSetOutLink(layer_out.name)
 
     RecurrentLayerGroupEnd(name=name)
 
     for layer_out in layer_outs:
-        # Thee previous full_name is the name is the rnn group
-        # We need a full_name outside the rnn group
+        # The previous full_name is the name inside the recurrent group.
+        # We need a full_name outside the recurrent group.
         layer_out.full_name = MakeLayerNameInSubmodel(layer_out.name)
 
     if len(layer_outs) == 1:
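A minimal usage sketch of the simplified signature, assuming a hypothetical embedding layer emb and a size of 128; with is_generating removed, training and generation invoke recurrent_group identically:

def rnn_step(y):
    # Read this group's own output from the previous time step.
    mem = memory(name="rnn_state", size=128)
    return fc_layer(input=[y, mem], size=128,
                    act=TanhActivation(), name="rnn_state")

rnn_out = recurrent_group(step=rnn_step, input=emb, name="simple_rnn")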
@@ -3684,7 +3667,20 @@ class BaseGeneratedInput(object):
 class GeneratedInput(BaseGeneratedInput):
     def after_real_step(self, input):
-        return maxid_layer(input=input, name='__beam_search_predict__')
+        if isinstance(input, LayerOutput):
+            input = [input]
+        elif isinstance(input, collections.Sequence):
+            input = list(input)
+            if len(input) > 1:
+                logger.info(
+                    ("More than one layer inside the recurrent_group "
+                     "is returned as output of the entire recurrent_group. "
+                     "PLEASE guarantee that the first output is the "
+                     "probability of the predicted next word."))
+
+        return [maxid_layer(
+            input=input[0], name='__beam_search_predict__')] + (
+                input[1:] if len(input) > 1 else [])
 
     def before_real_step(self):
         predict_id = memory(
@@ -3871,6 +3867,7 @@ def beam_search(step,
     :type step: callable
     :param input: Input data for the recurrent unit, which should include the
                   previously generated words as a GeneratedInput object.
+                  In beam_search, none of the inputs should be of type LayerOutput.
     :type input: list
     :param bos_id: Index of the start symbol in the dictionary. The start symbol
                    is a special token for NLP task, which indicates the
@@ -3912,15 +3909,18 @@ def beam_search(step,
     real_input = []
     for i, each_input in enumerate(input):
-        assert isinstance(each_input, StaticInput) or isinstance(
-            each_input, BaseGeneratedInput)
+        assert not isinstance(each_input, LayerOutput), (
+            "in beam_search, none of the inputs should be of type "
+            "LayerOutput.")
         if isinstance(each_input, BaseGeneratedInput):
-            assert generated_input_index == -1
+            assert generated_input_index == -1, ("recurrent_group accepts "
+                                                 "only one GeneratedInput.")
             generated_input_index = i
         else:
             real_input.append(each_input)
 
-    assert generated_input_index != -1
+    assert generated_input_index != -1, "No GeneratedInput is given."
 
     gipt = input[generated_input_index]
@@ -3941,17 +3941,11 @@ def beam_search(step,
         predict = gipt.after_real_step(step(*args))
 
-        eos_layer(input=predict, eos_id=eos_id, name=eos_name)
+        eos_layer(input=predict[0], eos_id=eos_id, name=eos_name)
 
         return predict
 
-    tmp = recurrent_group(
-        step=__real_step__,
-        input=real_input,
-        reverse=False,
-        name=name,
-        is_generating=True)
-
-    return tmp
+    return recurrent_group(
+        step=__real_step__, input=real_input, reverse=False, name=name)
 
 def __cost_input__(input, label, weight=None):
......
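Taken together, generation now reuses the ordinary recurrent_group path. A hedged decoder sketch under assumed names (encoded_vector, dict_size, target_emb, and all sizes are placeholders, not from this commit):

def gen_step(enc, cur_word):
    # enc comes from the StaticInput; cur_word from the GeneratedInput.
    decoder_mem = memory(name="decoder", size=128)
    hidden = fc_layer(input=[enc, cur_word, decoder_mem], size=128,
                      act=TanhActivation(), name="decoder")
    return fc_layer(input=hidden, size=dict_size, act=SoftmaxActivation())

predict = beam_search(
    step=gen_step,
    input=[StaticInput(input=encoded_vector, is_seq=True),
           GeneratedInput(size=dict_size,
                          embedding_name="target_emb",
                          embedding_size=128)],
    bos_id=0, eos_id=1, beam_size=5)

Note that the input list carries exactly one GeneratedInput and no LayerOutput, matching the new assertions.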
@@ -614,18 +614,17 @@ def simple_lstm(input,
 @wrap_name_default('lstm_unit')
 def lstmemory_unit(input,
-                   memory_boot=None,
+                   out_memory=None,
                    name=None,
                    size=None,
                    param_attr=None,
                    act=None,
                    gate_act=None,
                    state_act=None,
-                   mixed_bias_attr=None,
+                   input_proj_bias_attr=None,
+                   input_proj_layer_attr=None,
                    lstm_bias_attr=None,
-                   mixed_layer_attr=None,
-                   lstm_layer_attr=None,
-                   get_output_layer_attr=None):
+                   lstm_layer_attr=None):
     """
     Define calculations that a LSTM unit performs during a single time step.
     This function itself is not a recurrent layer, so it can not be
@@ -662,8 +661,8 @@ def lstmemory_unit(input,
     :param input: input layer name.
     :type input: LayerOutput
-    :param memory_boot: the initialization state of the LSTM cell.
-    :type memory_boot: LayerOutput | None
+    :param out_memory: the output of the previous time step.
+    :type out_memory: LayerOutput | None
     :param name: lstmemory unit name.
     :type name: basestring
     :param size: lstmemory unit size.
@@ -676,33 +675,35 @@ def lstmemory_unit(input,
     :type gate_act: BaseActivation
     :param state_act: lstm state activation type.
     :type state_act: BaseActivation
-    :param mixed_bias_attr: bias parameter attribute of mixed layer.
-                            False means no bias, None means default bias.
-    :type mixed_bias_attr: ParameterAttribute|False
+    :param input_proj_bias_attr: bias attribute for the input-to-hidden
+                            projection. False means no bias, None means
+                            default bias.
+    :type input_proj_bias_attr: ParameterAttribute|False|None
+    :param input_proj_layer_attr: extra layer attribute for the input-to-hidden
+                            projection of the LSTM unit, such as dropout or
+                            error clipping.
+    :type input_proj_layer_attr: ExtraLayerAttribute
     :param lstm_bias_attr: bias parameter attribute of lstm layer.
                            False means no bias, None means default bias.
     :type lstm_bias_attr: ParameterAttribute|False
-    :param mixed_layer_attr: mixed layer's extra attribute.
-    :type mixed_layer_attr: ExtraLayerAttribute
     :param lstm_layer_attr: lstm layer's extra attribute.
     :type lstm_layer_attr: ExtraLayerAttribute
-    :param get_output_layer_attr: get output layer's extra attribute.
-    :type get_output_layer_attr: ExtraLayerAttribute
     :return: lstmemory unit name.
     :rtype: LayerOutput
     """
     if size is None:
         assert input.size % 4 == 0
         size = input.size / 4
-    out_mem = memory(name=name, size=size)
-    state_mem = memory(
-        name="%s_state" % name, size=size, boot_layer=memory_boot)
+    if out_memory is None:
+        out_mem = memory(name=name, size=size)
+    else:
+        out_mem = out_memory
+
+    state_mem = memory(name="%s_state" % name, size=size)
 
     with mixed_layer(
             name="%s_input_recurrent" % name,
             size=size * 4,
-            bias_attr=mixed_bias_attr,
-            layer_attr=mixed_layer_attr,
+            bias_attr=input_proj_bias_attr,
+            layer_attr=input_proj_layer_attr,
             act=IdentityActivation()) as m:
         m += identity_projection(input=input)
         m += full_matrix_projection(input=out_mem, param_attr=param_attr)
@@ -717,11 +718,7 @@ def lstmemory_unit(input,
         gate_act=gate_act,
         state_act=state_act,
         layer_attr=lstm_layer_attr)
 
-    get_output_layer(
-        name='%s_state' % name,
-        input=lstm_out,
-        arg_name='state',
-        layer_attr=get_output_layer_attr)
+    get_output_layer(name='%s_state' % name, input=lstm_out, arg_name='state')
 
     return lstm_out
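A hedged sketch of the new out_memory hook (layer names, sizes, and seq_input are invented); supplying the previous output explicitly replaces the removed memory_boot path:

def lstm_step(x):
    # Hypothetical: read this unit's own output from the previous step.
    prev_out = memory(name="lstm_unit", size=256)
    return lstmemory_unit(
        input=x, name="lstm_unit", size=256,
        out_memory=prev_out, input_proj_bias_attr=False)

out = recurrent_group(step=lstm_step, input=seq_input, name="lstm_rg")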
@@ -730,17 +727,16 @@ def lstmemory_unit(input,
 def lstmemory_group(input,
                     size=None,
                     name=None,
-                    memory_boot=None,
+                    out_memory=None,
                     reverse=False,
                     param_attr=None,
                     act=None,
                     gate_act=None,
                     state_act=None,
-                    mixed_bias_attr=None,
+                    input_proj_bias_attr=None,
+                    input_proj_layer_attr=None,
                     lstm_bias_attr=None,
-                    mixed_layer_attr=None,
-                    lstm_layer_attr=None,
-                    get_output_layer_attr=None):
+                    lstm_layer_attr=None):
     """
     lstm_group is a recurrent_group version of Long Short Term Memory. It
     does exactly the same calculation as the lstmemory layer (see lstmemory in
@@ -774,8 +770,8 @@ def lstmemory_group(input,
     :type size: int
     :param name: name of the lstmemory group.
     :type name: basestring
-    :param memory_boot: the initialization state of LSTM cell.
-    :type memory_boot: LayerOutput | None
+    :param out_memory: the output of the previous time step.
+    :type out_memory: LayerOutput | None
     :param reverse: is lstm reversed
     :type reverse: bool
     :param param_attr: Parameter config, None if use default.
@@ -786,18 +782,17 @@ def lstmemory_group(input,
     :type gate_act: BaseActivation
     :param state_act: lstm state activation type.
     :type state_act: BaseActivation
-    :param mixed_bias_attr: bias parameter attribute of mixed layer.
-                            False means no bias, None means default bias.
-    :type mixed_bias_attr: ParameterAttribute|False
     :param lstm_bias_attr: bias parameter attribute of lstm layer.
                            False means no bias, None means default bias.
     :type lstm_bias_attr: ParameterAttribute|False
-    :param mixed_layer_attr: mixed layer's extra attribute.
-    :type mixed_layer_attr: ExtraLayerAttribute
+    :param input_proj_bias_attr: bias attribute for the input-to-hidden
+                            projection. False means no bias, None means
+                            default bias.
+    :type input_proj_bias_attr: ParameterAttribute|False|None
+    :param input_proj_layer_attr: extra layer attribute for the input-to-hidden
+                            projection of the LSTM unit, such as dropout or
+                            error clipping.
+    :type input_proj_layer_attr: ExtraLayerAttribute
     :param lstm_layer_attr: lstm layer's extra attribute.
     :type lstm_layer_attr: ExtraLayerAttribute
-    :param get_output_layer_attr: get output layer's extra attribute.
-    :type get_output_layer_attr: ExtraLayerAttribute
     :return: the lstmemory group.
     :rtype: LayerOutput
     """
@@ -805,18 +800,17 @@ def lstmemory_group(input,
     def __lstm_step__(ipt):
         return lstmemory_unit(
             input=ipt,
-            memory_boot=memory_boot,
             name=name,
             size=size,
-            mixed_bias_attr=mixed_bias_attr,
-            mixed_layer_attr=mixed_layer_attr,
-            param_attr=param_attr,
-            lstm_bias_attr=lstm_bias_attr,
             act=act,
             gate_act=gate_act,
             state_act=state_act,
+            out_memory=out_memory,
+            input_proj_bias_attr=input_proj_bias_attr,
+            input_proj_layer_attr=input_proj_layer_attr,
+            param_attr=param_attr,
             lstm_layer_attr=lstm_layer_attr,
-            get_output_layer_attr=get_output_layer_attr)
+            lstm_bias_attr=lstm_bias_attr)
 
     return recurrent_group(
         name='%s_recurrent_group' % name,
......
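For callers migrating, a hedged before/after mapping of the renamed keyword arguments (the layer m, the attributes, and the size are placeholders):

# Old: lstmemory_group(input=m, lstm_bias_attr=bias, mixed_bias_attr=False,
#                      mixed_layer_attr=attr, get_output_layer_attr=attr2)
# New: mixed_* becomes input_proj_*; get_output_layer_attr is dropped.
lstm_out = lstmemory_group(
    input=m,
    size=256,
    lstm_bias_attr=bias,
    input_proj_bias_attr=False,
    input_proj_layer_attr=ExtraLayerAttribute(error_clipping_threshold=10.0))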
@@ -104,7 +104,7 @@ layers {
   }
   bias_parameter_name: "lstm_bias"
   active_gate_type: "sigmoid"
-  active_state_type: "sigmoid"
+  active_state_type: "tanh"
 }
 layers {
   name: "__lstm_group_0___state@__lstm_group_0___recurrent_group"
@@ -183,7 +183,7 @@ layers {
   }
   bias_parameter_name: "lstm_bias"
   active_gate_type: "sigmoid"
-  active_state_type: "sigmoid"
+  active_state_type: "tanh"
 }
 layers {
   name: "__lstm_group_1___state@__lstm_group_1___recurrent_group"
......
@@ -258,7 +258,7 @@ layers {
   }
   bias_parameter_name: "___lstm_group_0__@__lstm_group_0___recurrent_group.wbias"
   active_gate_type: "sigmoid"
-  active_state_type: "sigmoid"
+  active_state_type: "tanh"
 }
 layers {
  name: "__lstm_group_0___state@__lstm_group_0___recurrent_group"
......
@@ -20,12 +20,13 @@ lstm1 = lstmemory_group(
     input=m1,
     param_attr=lstm_param,
     lstm_bias_attr=lstm_bias,
-    mixed_bias_attr=False)
+    input_proj_bias_attr=False)
 
 lstm2 = lstmemory_group(
     input=m2,
     param_attr=lstm_param,
     lstm_bias_attr=lstm_bias,
-    mixed_bias_attr=False)
+    input_proj_bias_attr=False)
 
 softmax_param = ParamAttr(name='softmax_param')
......
@@ -116,7 +116,7 @@ def reader_creator(data_file,
             data = batch['data']
             labels = batch['label']
             for sample, label in itertools.izip(data, batch['label']):
-                yield sample, int(label)
+                yield sample, int(label) - 1
 
     if use_xmap:
         return xmap_readers(mapper, reader, cpu_count(), buffered_size)
......
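The dataset's raw labels appear to be 1-based, while classification cost layers expect 0-based class ids, hence the subtraction. A self-contained illustration with made-up values:

# Raw labels 1..N map to class ids 0..N-1.
raw_labels = [1, 34, 102]
class_ids = [int(label) - 1 for label in raw_labels]
assert class_ids == [0, 33, 101]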