Commit f18d83f3 authored by dangqingqing

follow comments

Parent 37015fad
@@ -136,6 +136,7 @@ public:
     // check
     CHECK_EQ(2UL, inputs.size());
     CHECK_EQ(1UL, outputs.size());
+    // TODO(qingqing): support ASSIGN_TO.
     CHECK_EQ(outputs[0].getArgType(), ADD_TO);
     CHECK(inputs[0].isSequenceArg() && outputs[0].isSequenceArg())
         << "SequenceArg required here.";
@@ -144,9 +145,7 @@ public:
     auto w = inputs[1];
     CHECK(in.data() && out.data() && in.getSequenceId().data());
     CHECK_EQ(in.shape().ndims(), 2UL);
-    CHECK_EQ(out.shape().ndims(), 2UL);
-    CHECK_EQ(in.shape()[1], out.shape()[1]);
-    CHECK_EQ(in.shape()[0], out.shape()[0]);
+    CHECK(in.shape() == out.shape());
     CHECK_EQ(w.shape()[1], in.shape()[1]);
     auto outMat = out.matrix<Device>();
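
Collapsing the three CHECK_EQs into one shape comparison is safe provided TensorShape's operator== compares both the rank and every dimension; together with the ndims check on the input above, equality then implies all three removed conditions. A standalone sketch of that reasoning (ShapeSketch is an illustrative stand-in, not Paddle's TensorShape):

    #include <cassert>
    #include <cstddef>
    #include <vector>

    // Illustrative stand-in for TensorShape.
    struct ShapeSketch {
      std::vector<std::size_t> dims;  // dims.size() plays the role of ndims()
      bool operator==(const ShapeSketch& other) const {
        return dims == other.dims;  // equal rank, equal extent on every axis
      }
    };

    int main() {
      ShapeSketch in{{8, 16}};
      ShapeSketch out{{8, 16}};
      assert(in.dims.size() == 2);  // mirrors CHECK_EQ(in.shape().ndims(), 2UL)
      // Given the rank check above, this single comparison implies the three
      // removed CHECK_EQs: out ndims == 2, out[0] == in[0], out[1] == in[1].
      assert(in == out);
      return 0;
    }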
@@ -176,6 +175,7 @@ public:
 template <DeviceType Device>
 class RowConvGradFunc : public FunctionBase {
+  // TODO(qingqing): split into RowConvDataFunc and RowConvWeightFunc
 public:
   void init(const FuncConfig& config) override {}
@@ -196,9 +196,8 @@ public:
     auto wGrad = outputs[1];
     CHECK_EQ(in.shape().ndims(), 2UL);
-    CHECK_EQ(outGrad.shape().ndims(), 2UL);
-    CHECK_EQ(in.shape()[1], outGrad.shape()[1]);
-    CHECK_EQ(in.shape()[0], outGrad.shape()[0]);
+    CHECK(in.shape() == inGrad.shape());
+    CHECK(in.shape() == outGrad.shape());
     CHECK_EQ(wGrad.shape()[1], in.shape()[1]);
     const auto outGMat = outGrad.matrix<Device>();
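
Note: besides collapsing the per-dimension checks, the gradient path now also verifies that inGrad matches the input shape, a condition the removed checks never covered.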
@@ -43,13 +43,14 @@ void RowConvLayer::forward(PassType passType) {
   resetOutput(height, width);
   const auto startPos = getInput(0).sequenceStartPositions->getVector(useGpu_);
-  wDims_ = TensorShape({contexLength_, width});
+  MatrixPtr w = weight_->getW();
+  wDims_ = TensorShape({w->getHeight(), w->getWidth()});
   MatrixPtr outV = getOutputValue();
   BufferArgs inputs;
   BufferArgs outputs;
   inputs.addArg(*getInputValue(0), *startPos);
-  inputs.addArg(*weight_->getW(), wDims_);
+  inputs.addArg(*w, wDims_);
   outputs.addArg(*getOutputValue(), *startPos, ADD_TO);
   {
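
Note: the weight matrix is presumably created as contexLength_ rows by width columns, so deriving wDims_ from the matrix's own height and width keeps the shape in sync with the parameter instead of restating those sizes; the fetched MatrixPtr is also reused in addArg rather than calling weight_->getW() a second time.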
@@ -191,6 +191,14 @@ class LayerType(object):
     PAD_LAYER = "pad"
     MULTIPLEX_LAYER = "multiplex"
     ROW_CONV_LAYER = "row_conv"
+    PRINT_LAYER = 'print'
+    PRIORBOX_LAYER = 'priorbox'
+    CTC_LAYER = 'ctc'
+    WARP_CTC_LAYER = 'warp_ctc'
+    CRF_LAYER = 'crf'
+    CRF_DECODING_LAYER = 'crf_decoding'
     NCE_LAYER = 'nce'
     RANK_COST = 'rank-cost'
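
Note: these LayerType constants appear to be the layer type strings emitted into the generated config; defining them here lets the Python layer helpers refer to the print, priorbox, CTC, and CRF layer types symbolically instead of via raw string literals.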