提交 b15a4783 编写于 作者: E emailweixu 提交者: luotao1

Correctly handling multiple inputs and integer inputs for recurrent_g… (#114)

* Correctly handling multiple inputs and integer inputs for recurrent_group

* Fix ScatterAgentLayer for generation

* Revert sequence_(nest)_rnn.conf
上级 ffc34167
...@@ -217,7 +217,7 @@ void hl_matrix_mul(real *A_d, hl_trans_op_t transa, ...@@ -217,7 +217,7 @@ void hl_matrix_mul(real *A_d, hl_trans_op_t transa,
} else { } else {
LOG(FATAL) << "parameter transa error!"; LOG(FATAL) << "parameter transa error!";
} }
CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS); CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
CHECK_SYNC("hl_matrix_mul failed"); CHECK_SYNC("hl_matrix_mul failed");
} }
...@@ -266,7 +266,7 @@ void hl_matrix_mul_vector(real *A_d, hl_trans_op_t trans, ...@@ -266,7 +266,7 @@ void hl_matrix_mul_vector(real *A_d, hl_trans_op_t trans,
LOG(FATAL) << "parameter transa error!"; LOG(FATAL) << "parameter transa error!";
} }
CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS); CHECK_EQ(stat, CUBLAS_STATUS_SUCCESS) << hl_cublas_get_error_string(stat);
CHECK_SYNC("hl_matrix_mul_vector"); CHECK_SYNC("hl_matrix_mul_vector");
} }
......
...@@ -497,20 +497,21 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs, ...@@ -497,20 +497,21 @@ void RecurrentGradientMachine::forward(const std::vector<Argument>& inArgs,
int idSize = 0; int idSize = 0;
// connect in_links // connect in_links
for (size_t j = 0; j < inFrameLines_.size(); ++j) { for (size_t j = 0; j < inFrameLines_.size(); ++j) {
Info& info = info_[shareInlinkInfo ? 0 : j];
// idSize denotes the sum number of tokens in each length i // idSize denotes the sum number of tokens in each length i
idSize = info_[j].idIndex[i + 1] - info_[j].idIndex[i]; idSize = info.idIndex[i + 1] - info.idIndex[i];
InFrameLine inFrameLine = inFrameLines_[j]; InFrameLine inFrameLine = inFrameLines_[j];
auto scatterAgent = auto scatterAgent =
dynamic_cast<ScatterAgentLayer*>(inFrameLine.agents[i].get()); dynamic_cast<ScatterAgentLayer*>(inFrameLine.agents[i].get());
scatterAgent->setRealLayerAndOutput(inFrameLine.inLayer, scatterAgent->setRealLayerAndOutput(inFrameLine.inLayer,
inFrameLine.outArg, info_[j].allIds, inFrameLine.outArg, info.allIds,
info_[j].idIndex[i], idSize); info.idIndex[i], idSize);
if (hasSubseq) { if (hasSubseq) {
// size: the length of subsequence // size: the length of subsequence
int size = int size =
info_[j].seqStartPosIndex[i + 1] - info_[j].seqStartPosIndex[i]; info.seqStartPosIndex[i + 1] - info.seqStartPosIndex[i];
scatterAgent->setSequenceStartPositions(info_[j].sequenceStartPositions, scatterAgent->setSequenceStartPositions(info.sequenceStartPositions,
info_[j].seqStartPosIndex[i], info.seqStartPosIndex[i],
size); size);
} }
} }
...@@ -744,16 +745,24 @@ void RecurrentGradientMachine::selectRowsOneTime(LayerPtr layer, ...@@ -744,16 +745,24 @@ void RecurrentGradientMachine::selectRowsOneTime(LayerPtr layer,
const IVectorPtr& allIds, const IVectorPtr& allIds,
Argument* arg, Argument* arg,
PassType passType) { PassType passType) {
const MatrixPtr& realV = layer->getOutputValue(); Argument& src = layer->getOutput();
int height = realV->getHeight(); if (src.value) {
int width = realV->getWidth(); const MatrixPtr& realV = src.value;
Matrix::resizeOrCreate(arg->value, height, width, /* trans */ false, useGpu_); int height = realV->getHeight();
arg->value->zeroMem(); int width = realV->getWidth();
arg->value->selectRows(*realV, *allIds); Matrix::resizeOrCreate(
if (passType != PASS_TEST) { arg->value, height, width, /* trans */ false, useGpu_);
Matrix::resizeOrCreate(arg->grad, height, width, /* trans */ false, arg->value->zeroMem();
useGpu_); arg->value->selectRows(*realV, *allIds);
arg->grad->zeroMem(); if (passType != PASS_TEST) {
Matrix::resizeOrCreate(arg->grad, height, width, /* trans */ false,
useGpu_);
arg->grad->zeroMem();
}
}
if (src.ids) {
IVector::resizeOrCreate(arg->ids, src.ids->getSize(), useGpu_);
arg->ids->selectFrom(*src.ids, *allIds);
} }
} }
......
...@@ -139,15 +139,16 @@ void ScatterAgentLayer::forward(PassType passType) { ...@@ -139,15 +139,16 @@ void ScatterAgentLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
if (realLayer_->getOutput().ids) { // ids scatter int width = this->getSize();
IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_); if (realOutArg_.value || realOutArg_.ids) {
output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_); output_.subArgFrom(realOutArg_, /* offset */ idIndex_, idSize_,
} else { // value scatter width, useGpu_);
int width = this->getSize(); } else { // used in generation
if (realOutArg_.value) { if (realLayer_->getOutput().ids) {
output_.subArgFrom(realOutArg_, /* offset */ idIndex_ * width, idSize_, IVector::resizeOrCreate(output_.ids, ids_->getSize(), useGpu_);
width, useGpu_); output_.ids->selectFrom(*realLayer_->getOutput().ids, *ids_);
} else { // used in generation }
if (realLayer_->getOutput().value) {
int height = ids_->getSize(); int height = ids_->getSize();
resetOutput(height, width); resetOutput(height, width);
...@@ -213,18 +214,17 @@ void SequenceGatherAgentLayer::forward(PassType passType) { ...@@ -213,18 +214,17 @@ void SequenceGatherAgentLayer::forward(PassType passType) {
void SequenceScatterAgentLayer::forward(PassType passType) { void SequenceScatterAgentLayer::forward(PassType passType) {
Layer::forward(passType); Layer::forward(passType);
CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId()); CHECK_EQ(realLayer_->getDeviceId(), this->getDeviceId());
CHECK(!realLayer_->getOutput().ids) << "Not supported";
const Argument& input = realLayer_->getOutput(); const Argument& input = realLayer_->getOutput();
CHECK_EQ(input.value->getWidth(), this->getSize()); CHECK_EQ(realLayer_->getSize(), this->getSize());
int width = this->getSize(); int width = this->getSize();
AsyncGpuBlock asyncGpuBlock; AsyncGpuBlock asyncGpuBlock;
REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str()); REGISTER_TIMER_INFO("SequenceAgentLayerForward", getName().c_str());
if (realOutArg_.value) { if (realOutArg_.value || realOutArg_.ids) {
CHECK(realOutArg_.sequenceStartPositions); CHECK(realOutArg_.sequenceStartPositions);
output_.subArgFrom(realOutArg_, /* offset */ idIndex_ * width, idSize_, output_.subArgFrom(realOutArg_, /* offset */ idIndex_, idSize_,
width, useGpu_, /* trans */ false, /* seqFlag */ true, width, useGpu_, /* trans */ false, /* seqFlag */ true,
/* seqStart */ seqStartPosIndex_, /* seqStart */ seqStartPosIndex_,
/* seqSize */ numSequences_); /* seqSize */ numSequences_);
......
...@@ -56,7 +56,6 @@ add_test(NAME test_RecurrentGradientMachine ...@@ -56,7 +56,6 @@ add_test(NAME test_RecurrentGradientMachine
COMMAND .set_python_path.sh -d COMMAND .set_python_path.sh -d
${PROJ_ROOT}/python:${PROJ_ROOT}/paddle/gserver/tests ${PROJ_ROOT}/python:${PROJ_ROOT}/paddle/gserver/tests
${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine ${CMAKE_CURRENT_BINARY_DIR}/test_RecurrentGradientMachine
--use_gpu=false
WORKING_DIRECTORY ${PROJ_ROOT}/paddle) WORKING_DIRECTORY ${PROJ_ROOT}/paddle)
add_unittest_without_exec(test_NetworkCompare add_unittest_without_exec(test_NetworkCompare
......
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
######################## data source ################################
define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
test_list=None,
module='rnn_data_provider',
obj='process_subseq')
settings(batch_size=2, learning_rate=0.01)
######################## network configure ################################
dict_dim = 10
word_dim = 8
hidden_dim = 8
label_dim = 3
data = data_layer(name="word", size=dict_dim)
emb = embedding_layer(input=data, size=word_dim)
# This hierachical RNN is designed to be equivalent to the simple RNN in
# sequence_rnn.conf
def outer_step(wid, x):
outer_mem = memory(name="outer_rnn_state", size=hidden_dim)
def inner_step(y, wid):
z = embedding_layer(input=wid, size=word_dim)
inner_mem = memory(name="inner_rnn_state",
size=hidden_dim,
boot_layer=outer_mem)
out = fc_layer(input=[y, z, inner_mem],
size=hidden_dim,
act=TanhActivation(),
bias_attr=True,
name="inner_rnn_state")
return out
inner_rnn_output = recurrent_group(
step=inner_step,
name="inner",
input=[x, wid])
last = last_seq(input=inner_rnn_output, name="outer_rnn_state")
# "return last" should also work. But currently RecurrentGradientMachine
# does not handle it correctly. Current implementation requires that
# all the out links are from sequences. However, it does not report error
# when the out links are not sequences.
return inner_rnn_output
out = recurrent_group(
name="outer",
step=outer_step,
input=[SubsequenceInput(data), SubsequenceInput(emb)])
rep = last_seq(input=out)
prob = fc_layer(size=label_dim,
input=rep,
act=SoftmaxActivation(),
bias_attr=True)
outputs(classification_cost(input=prob,
label=data_layer(name="label", size=label_dim)))
#edit-mode: -*- python -*-
# Copyright (c) 2016 Baidu, Inc. All Rights Reserved
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from paddle.trainer_config_helpers import *
######################## data source ################################
define_py_data_sources2(train_list='gserver/tests/Sequence/dummy.list',
test_list=None,
module='rnn_data_provider',
obj='process_seq')
settings(batch_size=2, learning_rate=0.01)
######################## network configure ################################
dict_dim = 10
word_dim = 8
hidden_dim = 8
label_dim = 3
data = data_layer(name="word", size=dict_dim)
emb = embedding_layer(input=data, size=word_dim)
def step(y, wid):
z = embedding_layer(input=wid, size=word_dim)
mem = memory(name="rnn_state", size=hidden_dim)
out = fc_layer(input=[y, z, mem],
size=hidden_dim,
act=TanhActivation(),
bias_attr=True,
name="rnn_state")
return out
out = recurrent_group(
name="rnn",
step=step,
input=[emb, data])
rep = last_seq(input=out)
prob = fc_layer(size=label_dim,
input=rep,
act=SoftmaxActivation(),
bias_attr=True)
outputs(classification_cost(input=prob,
label=data_layer(name="label", size=label_dim)))
...@@ -92,7 +92,11 @@ void CalCost(const string& conf, const string& dir, real* cost, ...@@ -92,7 +92,11 @@ void CalCost(const string& conf, const string& dir, real* cost,
rmDir(dir.c_str()); rmDir(dir.c_str());
} }
void test(const string& conf1, const string& conf2, double eps) { void test(const string& conf1, const string& conf2, double eps, bool useGpu) {
if (!paddle::version::isWithGpu() && useGpu) {
return;
}
FLAGS_use_gpu = useGpu;
int num_passes = 5; int num_passes = 5;
real* cost1 = new real[num_passes]; real* cost1 = new real[num_passes];
const string dir1 = "gserver/tests/t1"; const string dir1 = "gserver/tests/t1";
...@@ -113,17 +117,28 @@ void test(const string& conf1, const string& conf2, double eps) { ...@@ -113,17 +117,28 @@ void test(const string& conf1, const string& conf2, double eps) {
} }
TEST(RecurrentGradientMachine, HasSubSequence) { TEST(RecurrentGradientMachine, HasSubSequence) {
test("gserver/tests/sequence_layer_group.conf", for (bool useGpu : {false, true}) {
"gserver/tests/sequence_nest_layer_group.conf", test("gserver/tests/sequence_layer_group.conf",
1e-5); "gserver/tests/sequence_nest_layer_group.conf",
1e-5, useGpu);
}
} }
TEST(RecurrentGradientMachine, rnn) { TEST(RecurrentGradientMachine, rnn) {
test("gserver/tests/sequence_rnn.conf", for (bool useGpu : {false, true}) {
"gserver/tests/sequence_nest_rnn.conf", test("gserver/tests/sequence_rnn.conf",
0); "gserver/tests/sequence_nest_rnn.conf",
1e-6, useGpu);
}
} }
TEST(RecurrentGradientMachine, rnn_multi_input) {
for (bool useGpu : {false, true}) {
test("gserver/tests/sequence_rnn_multi_input.conf",
"gserver/tests/sequence_nest_rnn_multi_input.conf",
1e-6, useGpu);
}
}
int main(int argc, char** argv) { int main(int argc, char** argv) {
if (paddle::version::isWithPyDataProvider()) { if (paddle::version::isWithPyDataProvider()) {
......
...@@ -554,11 +554,16 @@ void Argument::degradeSequence(const Argument& input, bool useGpu) { ...@@ -554,11 +554,16 @@ void Argument::degradeSequence(const Argument& input, bool useGpu) {
void Argument::subArgFrom(const Argument& input, size_t offset, size_t height, void Argument::subArgFrom(const Argument& input, size_t offset, size_t height,
size_t width, bool useGpu, bool trans, bool seqFlag, size_t width, bool useGpu, bool trans, bool seqFlag,
size_t seqStart, size_t seqSize) { size_t seqStart, size_t seqSize) {
value = Matrix::create(input.value->getData() + offset, height, width, trans, if (input.value) {
useGpu); value = Matrix::create(input.value->getData() + offset * width,
height, width, trans, useGpu);
}
if (input.ids) {
ids = IVector::create(input.ids->getData() + offset, height, useGpu);
}
if (input.grad) { if (input.grad) {
grad = Matrix::create(input.grad->getData() + offset, height, width, trans, grad = Matrix::create(input.grad->getData() + offset * width,
useGpu); height, width, trans, useGpu);
} }
if (seqFlag) { if (seqFlag) {
sequenceStartPositions = std::make_shared<ICpuGpuVector>( sequenceStartPositions = std::make_shared<ICpuGpuVector>(
......
...@@ -177,11 +177,11 @@ struct Argument { ...@@ -177,11 +177,11 @@ struct Argument {
} }
/** /**
* @brief (value, grad, sequenceStartPositions) of output are subset of * @brief (value, ids, grad, sequenceStartPositions) of output are subset of
* input. Note that, output share the same memory of input. * input. Note that, output share the same memory of input.
* *
* @param input[in] input * @param input[in] input
* @param offset[in] offset of input.value * @param offset[in] offset in terms of rows
* @param height[in] height of output.value * @param height[in] height of output.value
* @param width[in] width of output.value * @param width[in] width of output.value
* @param useGpu[in] * @param useGpu[in]
......
...@@ -216,7 +216,7 @@ def check_input(input): ...@@ -216,7 +216,7 @@ def check_input(input):
""" """
if isinstance(input, LayerOutput): if isinstance(input, LayerOutput):
return [LayerOutput] return [input]
assert isinstance(input, list) assert isinstance(input, list)
for inp in input: for inp in input:
assert isinstance(inp, LayerOutput) assert isinstance(inp, LayerOutput)
...@@ -764,7 +764,7 @@ def print_layer(input, name=None): ...@@ -764,7 +764,7 @@ def print_layer(input, name=None):
:type input: LayerOutput|list|tuple :type input: LayerOutput|list|tuple
:return: No return :return: No return
""" """
check_input(input) input = check_input(input)
Layer( Layer(
name=name, name=name,
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册