Commit 8560ce69 authored by xutianbing

Address Daoyuan's review comments about SequenceArg.

Parent 9edfd200
@@ -74,7 +74,7 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
/**
* Paddle Function for Context Projection Forward.
* Calculate the value for the output layer with context projection.
* Calculate the output sequence after context projection.
*
* What is Context Projection?
* For example, assumed input (x) has 4 words and the dimension of each word
@@ -92,10 +92,12 @@ void ContextProjectionForward<DEVICE_TYPE_CPU>(CpuMatrix& out_mat,
* c1, c2, d1, d2, 0, 0]
* @endcode
*
* \param outputs[0] output value.
* \param inputs[0] input value.
* \param inputs[1] input weight.
* \param inputs[2] input sequence.
* \param outputs[0].matrix output value, n * (d * l)
* \param outputs[0].vector input sequence, n * 1
* \param inputs[0].matrix input value, n * d
* \param inputs[0].vector input sequence, n * 1
* \param inputs[1].matrix input weight, pad * d
* \param inputs[1].vector input sequence, n * 1
*/
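/**
* A hedged sketch, not part of this change: with zero padding,
* context_start = -1 and context_length = 3, the forward pass conceptually
* copies neighboring input rows into the slots of each output row. All
* names here (in, out, dim, num_words) are illustrative assumptions.
*
* @code
* for (int i = 0; i < num_words; ++i) {
*   for (int c = 0; c < context_length; ++c) {
*     int j = i + context_start + c;            // neighboring word index
*     real* dst = out[i] + c * dim;             // c-th slot of output row i
*     if (j < 0 || j >= num_words) {
*       memset(dst, 0, dim * sizeof(real));     // out of range: zero pad
*     } else {
*       memcpy(dst, in[j], dim * sizeof(real)); // copy the d-dim word vector
*     }
*   }
* }
* @endcode
*/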
template <DeviceType Device>
class ContextProjectionForwardFunc : public FunctionBase {
@@ -107,28 +109,40 @@ public:
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ((size_t)3, inputs.size());
CHECK(1 == inputs.size() || 2 == inputs.size());
CHECK_EQ((size_t)1, outputs.size());
CHECK(outputs[0].data() && inputs[0].data() && inputs[2].data());
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[1].shape().ndims(), (size_t)2);
CHECK_EQ(inputs[2].shape().ndims(), (size_t)1);
const auto val_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
const auto w_seqs = inputs.size() <= 1
? nullptr
: dynamic_cast<const SequenceArg*>(&inputs[1]);
auto out_seqs = dynamic_cast<const SequenceArg&>(outputs[0]);
CHECK(out_seqs.data() && val_seqs.data() &&
val_seqs.getSequenceIds().data());
CHECK_EQ(out_seqs.shape().ndims(), (size_t)2);
CHECK_EQ(val_seqs.shape().ndims(), (size_t)2);
CHECK_EQ(val_seqs.getSequenceIds().shape().ndims(), (size_t)1);
if (w_seqs) {
CHECK_EQ(w_seqs->shape().ndims(), (size_t)2);
CHECK_EQ(w_seqs->getSequenceIds().shape().ndims(), (size_t)1);
}
/// dim of output = dim of input * context_length
CHECK_EQ(outputs[0].shape()[1], inputs[0].shape()[1] * context_length_);
/// dim of input == dim of weight
CHECK_EQ(inputs[0].shape()[1], inputs[1].shape()[1]);
CHECK_EQ(out_seqs.shape()[1], val_seqs.shape()[1] * context_length_);
/// input and output have the same batch_size
CHECK_EQ(inputs[0].shape()[0], outputs[0].shape()[0]);
CHECK_EQ(val_seqs.shape()[0], out_seqs.shape()[0]);
/// dim of input == dim of weight
if (w_seqs) {
CHECK_EQ(val_seqs.shape()[1], w_seqs->shape()[1]);
}
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
auto out_mat = outputs[0].matrix<Device>();
const auto in_mat = inputs[0].matrix<Device>();
CHECK_EQ(out_seqs.getArgType(), ADD_TO);
auto out_mat = out_seqs.matrix<Device>();
const auto in_mat = val_seqs.matrix<Device>();
const auto w_mat =
!inputs[1].data() ? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
: inputs[1].matrix<Device>();
const auto seq_vec = inputs[2].vector<int, Device>();
w_seqs ? w_seqs->matrix<Device>()
: typename Tensor<real, Device>::Matrix(nullptr, 0, 0);
const auto seq_vec = val_seqs.getSequenceIds().vector<int, Device>();
ContextProjectionForward<Device>(out_mat,
in_mat,
w_mat,
@@ -227,25 +241,25 @@ public:
CHECK_EQ((size_t)1, inputs.size());
CHECK_EQ((size_t)2, outputs.size());
const auto seqArg = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(seqArg.data() && inputs[0].data());
CHECK_EQ(seqArg.shape().ndims(), (size_t)2);
CHECK_EQ(seqArg.getSequenceIds().shape().ndims(), (size_t)1);
const auto seq_arg = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(seq_arg.data() && inputs[0].data());
CHECK_EQ(seq_arg.shape().ndims(), (size_t)2);
CHECK_EQ(seq_arg.getSequenceIds().shape().ndims(), (size_t)1);
CHECK_EQ(outputs[0].shape().ndims(), (size_t)2);
CHECK_EQ(outputs[1].shape().ndims(), (size_t)2);
/// dim of input grad == dim of weight
CHECK_EQ(outputs[0].shape()[1], outputs[1].shape()[1]);
/// input and output grad have the same batch_size
CHECK_EQ(outputs[0].shape()[0], seqArg.shape()[0]);
CHECK_EQ(outputs[0].shape()[0], seq_arg.shape()[0]);
/// dim of output val = dim of input grad * context_length
CHECK_EQ(seqArg.shape()[1], outputs[0].shape()[1] * context_length_);
CHECK_EQ(seq_arg.shape()[1], outputs[0].shape()[1] * context_length_);
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK_EQ(outputs[1].getArgType(), ADD_TO);
const auto seq_vec = seqArg.getSequenceIds().vector<int, Device>();
const auto out_grad_mat = seqArg.matrix<Device>();
const auto seq_vec = seq_arg.getSequenceIds().vector<int, Device>();
const auto out_grad_mat = seq_arg.matrix<Device>();
auto in_grad_mat =
!outputs[0].data()
? typename Tensor<real, Device>::Matrix(nullptr, 0, 0)
@@ -272,6 +286,91 @@ private:
size_t total_pad_;
};
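/**
* Hedged sketch of the backward-data step, assuming the same indexing as the
* forward sketch above (names are illustrative): each context slot of the
* output gradient is accumulated back onto the input-gradient row it was
* copied from in the forward pass.
*
* @code
* for (int i = 0; i < num_words; ++i) {
*   for (int c = 0; c < context_length; ++c) {
*     int j = i + context_start + c;            // forward-pass source row
*     if (j >= 0 && j < num_words) {
*       for (int k = 0; k < dim; ++k) {
*         in_grad[j][k] += out_grad[i][c * dim + k];
*       }
*     }
*   }
* }
* @endcode
*/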
/**
* \param inputs[0].matrix input grad, n * d
* \param inputs[0].vector input sequence, n * 1
* \param outputs[0] output grad, n * (d * l)
*/
template <DeviceType Device>
class ContextProjectionBackwardDataFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
context_length_ = config.get<size_t>("context_length");
context_start_ = config.get<int>("context_start");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seqs.data() && outputs[0].data() &&
in_seqs.getSequenceIds().data());
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1);
/// dim of output grad = dim of input grad * context_length
CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_);
/// input and output have the same batch_size
CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]);
const auto out_grad_mat = outputs[0].matrix<Device>();
auto in_grad_mat = in_seqs.matrix<Device>();
const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();
ContextProjectionBackwardData<Device>(
out_grad_mat, in_grad_mat, seq_vec, context_length_, context_start_);
}
private:
size_t context_length_;
int context_start_;
};
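/**
* Hedged sketch of the backward-weight step (illustrative names; the mapping
* from an out-of-range position to a padding row depends on begin_pad and is
* left abstract here): slots that read a learned padding row in the forward
* pass route their gradient into w_grad instead of in_grad.
*
* @code
* for (int i = 0; i < num_words; ++i) {
*   for (int c = 0; c < context_length; ++c) {
*     int j = i + context_start + c;
*     if (j < 0 || j >= num_words) {           // this slot was padded
*       int p = padRowIndex(j, begin_pad);     // hypothetical helper
*       for (int k = 0; k < dim; ++k) {
*         w_grad[p][k] += out_grad[i][c * dim + k];
*       }
*     }
*   }
* }
* @endcode
*/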
/**
* \param inputs[0].matrix weight grad, pad * d
* \param inputs[0].vector input sequence, n * 1
* \param outputs[0] output grad, n * (d * l)
*/
template <DeviceType Device>
class ContextProjectionBackwardWeightFunc : public FunctionBase {
public:
void init(const FuncConfig& config) override {
context_length_ = config.get<size_t>("context_length");
context_start_ = config.get<int>("context_start");
begin_pad_ = config.get<size_t>("begin_pad");
total_pad_ = config.get<size_t>("total_pad");
}
void calc(const BufferArgs& inputs, const BufferArgs& outputs) override {
CHECK_EQ(1, static_cast<int>(inputs.size()));
CHECK_EQ(1, static_cast<int>(outputs.size()));
const auto in_seqs = dynamic_cast<const SequenceArg&>(inputs[0]);
CHECK(in_seqs.data() && in_seqs.getSequenceIds().data() &&
outputs[0].data());
CHECK_EQ(static_cast<int>(outputs[0].shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.shape().ndims()), 2);
CHECK_EQ(static_cast<int>(in_seqs.getSequenceIds().shape().ndims()), 1);
CHECK_EQ(in_seqs.shape()[0], outputs[0].shape()[0]);
CHECK_EQ(outputs[0].shape()[1], in_seqs.shape()[1] * context_length_);
const auto out_grad_mat = outputs[0].matrix<Device>();
auto w_grad_mat = inputs[0].matrix<Device>();
const auto seq_vec = in_seqs.getSequenceIds().vector<int, Device>();
ContextProjectionBackwardWeight<Device>(out_grad_mat,
w_grad_mat,
seq_vec,
context_length_,
context_start_,
total_pad_,
begin_pad_);
}
private:
size_t context_length_;
int context_start_;
size_t begin_pad_;
size_t total_pad_;
};
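/**
* Hedged usage sketch: how a caller might drive one of these functions. The
* registrar access and type-name string below are assumptions; the
* addArg(matrix, seq_ids) form follows the test code changed in this commit.
*
* @code
* auto func = FunctionBase::funcRegistrar_.createByType(
*     "ContextProjectionForward-CPU");         // assumed name format
* func->init(FuncConfig()
*                .set("context_length", (size_t)3)
*                .set("context_start", -1));   // other keys may be required
* BufferArgs inputs, outputs;
* inputs.addArg(in_value, seq_ids);            // SequenceArg: value + ids
* outputs.addArg(out_value, seq_ids, ADD_TO);
* func->calc(inputs, outputs);
* @endcode
*/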
REGISTER_TYPED_FUNC(ContextProjectionForward,
CPU,
ContextProjectionForwardFunc);
@@ -285,5 +384,11 @@ REGISTER_TYPED_FUNC(ContextProjectionForward,
REGISTER_TYPED_FUNC(ContextProjectionBackward,
GPU,
ContextProjectionBackwardFunc);
REGISTER_TYPED_FUNC(ContextProjectionBackwardData,
GPU,
ContextProjectionBackwardDataFunc);
REGISTER_TYPED_FUNC(ContextProjectionBackwardWeight,
GPU,
ContextProjectionBackwardWeightFunc);
#endif
} // namespace paddle
@@ -58,21 +58,21 @@ void testMatrixProjectionForward(int context_start,
BufferArgs cpu_inputs;
BufferArgs cpu_outputs;
cpu_inputs.addArg(cpu_in);
cpu_inputs.addArg(cpu_weight ? *cpu_weight
: CpuMatrix(nullptr, 0, input_dim));
cpu_inputs.addArg(*cpu_seq);
cpu_outputs.addArg(cpu_out, ADD_TO);
cpu_inputs.addArg(cpu_in, *cpu_seq);
if (cpu_weight) {
cpu_inputs.addArg(*cpu_weight, *cpu_seq);
}
cpu_outputs.addArg(cpu_out, *cpu_seq, ADD_TO);
compare.getCpuFunction()->calc(cpu_inputs, cpu_outputs);
BufferArgs gpu_inputs;
BufferArgs gpu_outputs;
gpu_inputs.addArg(gpu_in);
gpu_inputs.addArg(gpu_weight ? *gpu_weight
: GpuMatrix(nullptr, 0, input_dim));
gpu_inputs.addArg(*gpu_seq);
gpu_outputs.addArg(gpu_out, ADD_TO);
gpu_inputs.addArg(gpu_in, *gpu_seq);
if (gpu_weight) {
gpu_inputs.addArg(*gpu_weight, *gpu_seq);
}
gpu_outputs.addArg(gpu_out, *gpu_seq, ADD_TO);
compare.getGpuFunction()->calc(gpu_inputs, gpu_outputs);
......
@@ -118,16 +118,15 @@ void ContextProjection::forward() {
/// use state_ first if set; otherwise use weight_ (is_padding false means w_ptr is nullptr)
auto w_ptr =
state_ ? state_.get() : is_padding ? weight_->getW().get() : nullptr;
auto start_pos = in_->sequenceStartPositions;
const auto start_pos = in_->sequenceStartPositions->getVector(useGpu_);
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*in_->value);
inputs.addArg(CpuMatrix(w_ptr ? w_ptr->getData() : nullptr,
w_ptr ? w_ptr->getHeight() : 0,
input_dim));
inputs.addArg(*in_->sequenceStartPositions->getVector(useGpu_));
outputs.addArg(*out_->value, ADD_TO);
inputs.addArg(*in_->value, *start_pos);
if (w_ptr) {
inputs.addArg(CpuMatrix(w_ptr->getData(), w_ptr->getHeight(), input_dim),
*start_pos);
}
outputs.addArg(*out_->value, *start_pos, ADD_TO);
forward_[0]->calc(inputs, outputs);
if (state_ && config_.context_start() < 0) {
......