Commit 49c31feb authored by T tensor-tang

fix typo and op test

Parent 02909335
@@ -63,7 +63,7 @@ void FusionSeqExpandConcatFCOp::InferShape(
 framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
     const framework::ExecutionContext& ctx) const {
   return framework::OpKernelType(
-      framework::ToDataType(ctx.Input<framework::LoDTensor>("X")->type()),
+      framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
       ctx.device_context());
 }
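The change above swaps ctx.Input for ctx.MultiInput because "X" is a duplicable (list) input of this fused op, so the kernel dtype has to be read from the first tensor in the list. A minimal NumPy sketch of that idea (names and shapes are illustrative, not the framework API):

import numpy as np

# "X" is a list of tensors; the expected kernel dtype mirrors
# ctx.MultiInput<LoDTensor>("X")[0]->type(), i.e. the first element's dtype.
xs = [np.random.rand(6, 8).astype('float32'), np.random.rand(3, 4).astype('float32')]
kernel_dtype = xs[0].dtype
# assumption: all list inputs are expected to share one dtype
assert all(x.dtype == kernel_dtype for x in xs)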
@@ -113,7 +113,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
     auto* w = ctx.Input<Tensor>("FCWeight");
     auto* b = ctx.Input<Tensor>("FCBias");
     auto* out = ctx.Output<LoDTensor>("Out");
-    auto* fc_out = ctx.Output<Tensor>("FCOUT");
+    auto* fc_out = ctx.Output<Tensor>("FCOut");
     auto* ref_in = ins[0];
     auto ref_lod = ref_in->lod();
@@ -164,7 +164,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
     math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
                                       out_data, b ? b->data<T>() : NULL);
     w_data = w_data + M0 * D;
-    // first one use write on
+    // first write on
     blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
     w_data = w_data + M1 * D;
     for (size_t i = 2; i < ins.size(); ++i) {
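The comment being fixed refers to the write-then-accumulate pattern in this hunk: the LoD input goes through FCCompute straight into out_data, the first non-LoD input's MatMul writes fc_out, and the remaining inputs are accumulated into fc_out_data with a scale factor of 1 (the static_cast<T>(1) argument visible in the next hunk). A rough NumPy sketch of that pattern, with illustrative shapes, not the kernel itself:

import numpy as np

N, D = 4, 6                                           # batch size and FC output width
ins = [np.random.rand(N, 3), np.random.rand(N, 5)]    # non-LoD inputs x1, x2, ...
ws = [np.random.rand(3, D), np.random.rand(5, D)]     # matching slices of FCWeight

fc_out = ins[0] @ ws[0]                               # "first write on": plain MatMul
for x, w in zip(ins[1:], ws[1:]):
    fc_out += x @ w                                   # later inputs accumulate (beta = 1)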
@@ -175,16 +175,15 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
                 K, w_data, D, static_cast<T>(1), fc_out_data, D);
       w_data = w_data + K * D;
     }
+    T* cur_out_data = out_data;
     for (int i = 0; i < N; ++i) {
       int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
       T* src = fc_out_data + i * D;
       for (int step = 0; step < seq_len; ++step) {
-        blas.VADD(D, out_data, src, out_data);
-        out_data = out_data + D;
+        blas.VADD(D, cur_out_data, src, cur_out_data);
+        cur_out_data = cur_out_data + D;
       }
     }
     fc_act(total_T * D, out_data, out_data);
   }
 };
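The cur_out_data change above is the functional part of the fix: the old loop advanced out_data itself while adding each sequence's fc_out row into the output, so by the time fc_act(total_T * D, out_data, out_data) ran, out_data no longer pointed at the start of the buffer. Keeping a separate cursor leaves out_data intact. A minimal NumPy sketch of the expand-and-add step (not the PaddlePaddle kernel; offsets and shapes are illustrative), assuming ref_lod holds offsets and fc_out has one row per sequence:

import numpy as np

def expand_and_add(out, fc_out, ref_lod):
    # out:    (total_T, D) buffer already holding the FC result of the LoD input
    # fc_out: (N, D) FC result of the non-LoD inputs, one row per sequence
    cur = 0                                   # plays the role of cur_out_data
    for i in range(len(ref_lod) - 1):
        seq_len = ref_lod[i + 1] - ref_lod[i]
        out[cur:cur + seq_len] += fc_out[i]   # one VADD per time step of sequence i
        cur += seq_len
    return out                                # fc_act then runs over the whole buffer

lod = [0, 2, 5]                               # two sequences of lengths 2 and 3
out = np.ones((5, 4), dtype='float32')
fc_out = np.arange(8, dtype='float32').reshape(2, 4)
print(expand_and_add(out, fc_out, lod))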
@@ -80,16 +80,16 @@ class TestFusionSeqExpandConcatFCOp(OpTest):
         out = fusion_seqexpand_concat_fc(xs, self.lod, w, b,
                                          ACTIVATION[self.fc_act])
-        self.inputs = {'X': [(x0, self.lod)], 'FCWeight': w}
-        normal_lod = [i for i in range(bs + 1)]
+        self.inputs = {'X': [('x0', (x0, self.lod))], 'FCWeight': w}
+        normal_lod = [[1] * bs]
         for i in range(num_inputs - 1):
-            self.inputs['X'].append((xs[i + 1], normal_lod))
+            self.inputs['X'].append(('x%d' % (i + 1), (xs[i + 1], normal_lod)))
         if self.with_bias:
             self.inputs['FCBias'] = b
         self.outputs = {'Out': (out, self.lod)}
-        self.attrs = {'fc_activation': self.fc_act, }
+        self.attrs = {'fc_activation': self.fc_act}

     def test_check_output(self):
         self.check_output()
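The test-side fix adjusts how the duplicable input 'X' is fed to OpTest: each list element becomes a (variable_name, value) pair, where value is an (ndarray, lod) tuple for LoDTensor inputs, and normal_lod becomes [[1] * bs], i.e. every row of a non-LoD input is treated as a length-1 sequence. A hypothetical, self-contained sketch of that layout (names, shapes, and the length-based LoD convention are assumptions taken from this test, not a general API reference):

import numpy as np

bs = 3                                        # batch size: number of sequences
lod = [[2, 1, 3]]                             # length-based LoD of the first input
x0 = np.random.rand(sum(lod[0]), 4).astype('float32')    # one row per time step
x1 = np.random.rand(bs, 5).astype('float32')              # one row per sequence

inputs = {
    'X': [
        ('x0', (x0, lod)),                    # LoD input keeps its real lod
        ('x1', (x1, [[1] * bs])),             # non-LoD input: bs length-1 sequences
    ],
    'FCWeight': np.random.rand(4 + 5, 6).astype('float32'),  # concat width -> D
}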