Commit f18d83f3 authored by dangqingqing

follow comments

Parent 37015fad
@@ -136,6 +136,7 @@ public:
// check
CHECK_EQ(2UL, inputs.size());
CHECK_EQ(1UL, outputs.size());
// TODO(qingqing): support ASSIGN_TO.
CHECK_EQ(outputs[0].getArgType(), ADD_TO);
CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
<< "SequenceArg required here.";
@@ -144,9 +145,7 @@ public:
auto w = inputs[1];
CHECK(in.data() && out.data() && in.getSequenceId().data());
CHECK_EQ(in.shape().ndims(), 2UL);
CHECK_EQ(out.shape().ndims(), 2UL);
CHECK_EQ(in.shape()[1], out.shape()[1]);
CHECK_EQ(in.shape()[0], out.shape()[0]);
CHECK(in.shape() == out.shape());
CHECK_EQ(w.shape()[1], in.shape()[1]);
auto outMat = out.matrix<Device>();
@@ -176,6 +175,7 @@ public:
template <DeviceType Device>
class RowConvGradFunc : public FunctionBase {
// TODO(qingqing): split into RowConvDataFunc and RowConvWeightFunc
public:
void init(const FuncConfig& config) override {}
@@ -196,9 +196,8 @@ public:
auto wGrad = outputs[1];
CHECK_EQ(in.shape().ndims(), 2UL);
CHECK_EQ(outGrad.shape().ndims(), 2UL);
CHECK_EQ(in.shape()[1], outGrad.shape()[1]);
CHECK_EQ(in.shape()[0], outGrad.shape()[0]);
CHECK(in.shape() == inGrad.shape());
CHECK(in.shape() == outGrad.shape());
CHECK_EQ(wGrad.shape()[1], in.shape()[1]);
const auto outGMat = outGrad.matrix<Device>();
......
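Both hunks above collapse the separate per-dimension CHECK_EQs on the input/output shapes into a single whole-shape comparison. Below is a minimal, self-contained sketch, using a hypothetical Shape class rather than PaddlePaddle's actual TensorShape, of why one equality check expresses the same invariant once the shape type compares both rank and every extent:

```cpp
// Hypothetical Shape type (not the real paddle TensorShape): equality compares
// the whole dimension vector, i.e. the rank and every extent at once.
#include <cassert>
#include <initializer_list>
#include <vector>

class Shape {
public:
  Shape(std::initializer_list<size_t> dims) : dims_(dims) {}
  size_t ndims() const { return dims_.size(); }
  size_t operator[](size_t i) const { return dims_[i]; }
  bool operator==(const Shape& other) const { return dims_ == other.dims_; }

private:
  std::vector<size_t> dims_;
};

int main() {
  Shape in{8, 16};
  Shape out{8, 16};

  // Old style: rank and each dimension checked one by one.
  assert(in.ndims() == 2 && out.ndims() == 2);
  assert(in[0] == out[0] && in[1] == out[1]);

  // New style: a single comparison covers the same checks.
  assert(in == out);
  return 0;
}
```

Note that the explicit ndims() checks are still kept in the diff: the 2-D requirement is a precondition in its own right, while the shape equality expresses input/output consistency.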
@@ -43,13 +43,14 @@ void RowConvLayer::forward(PassType passType) {
resetOutput(height, width);
const auto startPos = getInput(0).sequenceStartPositions->getVector(useGpu_);
wDims_ = TensorShape({contexLength_, width});
MatrixPtr w = weight_->getW();
wDims_ = TensorShape({w->getHeight(), w->getWidth()});
MatrixPtr outV = getOutputValue();
BufferArgs inputs;
BufferArgs outputs;
inputs.addArg(*getInputValue(0), *startPos);
inputs.addArg(*weight_->getW(), wDims_);
inputs.addArg(*w, wDims_);
outputs.addArg(*getOutputValue(), *startPos, ADD_TO);
{
......
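The RowConvLayer::forward hunk also stops rebuilding wDims_ from the separately tracked contexLength_ and width, and instead reads the dimensions off the weight matrix that is passed to the function. A small sketch of that idea, with hypothetical stand-in types rather than Paddle's Matrix/TensorShape:

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Hypothetical stand-ins, not the PaddlePaddle API.
struct Matrix {
  size_t height;
  size_t width;
};
using Dims = std::vector<size_t>;

int main() {
  // The layer's weight, conceptually contextLength x width.
  Matrix w{20, 256};

  // Old approach: the shape is rebuilt from separately tracked scalars,
  // which can drift from the matrix actually handed to the function.
  size_t contextLength = 20, width = 256;
  Dims fromScalars{contextLength, width};

  // New approach: read the dimensions off the matrix itself, so the shape
  // argument always describes the buffer passed alongside it.
  Dims fromWeight{w.height, w.width};

  assert(fromScalars == fromWeight);
  return 0;
}
```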
@@ -191,6 +191,14 @@ class LayerType(object):
PAD_LAYER = "pad"
MULTIPLEX_LAYER = "multiplex"
ROW_CONV_LAYER = "row_conv"
PRINT_LAYER = 'print'
PRIORBOX_LAYER = 'priorbox'
CTC_LAYER = 'ctc'
WARP_CTC_LAYER = 'warp_ctc'
CRF_LAYER = 'crf'
CRF_DECODING_LAYER = 'crf_decoding'
NCE_LAYER = 'nce'
RANK_COST = 'rank-cost'
......