提交 49c31feb,作者:tensor-tang

fix typo and op test

上级 02909335
...@@ -63,7 +63,7 @@ void FusionSeqExpandConcatFCOp::InferShape( ...@@ -63,7 +63,7 @@ void FusionSeqExpandConcatFCOp::InferShape(
// Selects the kernel to run based on the data type of the operator's input.
// "X" is a variadic (duplicable) input, so it must be fetched with
// MultiInput rather than Input; the first element is used as the reference
// tensor for type deduction. (The old code called ctx.Input on the
// duplicable "X", which is incorrect — fixed here by indexing MultiInput.)
framework::OpKernelType FusionSeqExpandConcatFCOp::GetExpectedKernelType(
    const framework::ExecutionContext& ctx) const {
  return framework::OpKernelType(
      framework::ToDataType(ctx.MultiInput<LoDTensor>("X")[0]->type()),
      ctx.device_context());
}
...@@ -113,7 +113,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> { ...@@ -113,7 +113,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
auto* w = ctx.Input<Tensor>("FCWeight"); auto* w = ctx.Input<Tensor>("FCWeight");
auto* b = ctx.Input<Tensor>("FCBias"); auto* b = ctx.Input<Tensor>("FCBias");
auto* out = ctx.Output<LoDTensor>("Out"); auto* out = ctx.Output<LoDTensor>("Out");
auto* fc_out = ctx.Output<Tensor>("FCOUT"); auto* fc_out = ctx.Output<Tensor>("FCOut");
auto* ref_in = ins[0]; auto* ref_in = ins[0];
auto ref_lod = ref_in->lod(); auto ref_lod = ref_in->lod();
...@@ -164,7 +164,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> { ...@@ -164,7 +164,7 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data, math::FCCompute<DeviceContext, T>(blas, total_T, D, M0, ref_in_data, w_data,
out_data, b ? b->data<T>() : NULL); out_data, b ? b->data<T>() : NULL);
w_data = w_data + M0 * D; w_data = w_data + M0 * D;
// first one use write on // first write on
blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data); blas.MatMul(N, D, M1, in1_data, w_data, fc_out_data);
w_data = w_data + M1 * D; w_data = w_data + M1 * D;
for (size_t i = 2; i < ins.size(); ++i) { for (size_t i = 2; i < ins.size(); ++i) {
...@@ -175,16 +175,15 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> { ...@@ -175,16 +175,15 @@ class FusionSeqExpandConcatFCOpKernel : public framework::OpKernel<T> {
K, w_data, D, static_cast<T>(1), fc_out_data, D); K, w_data, D, static_cast<T>(1), fc_out_data, D);
w_data = w_data + K * D; w_data = w_data + K * D;
} }
T* cur_out_data = out_data;
for (int i = 0; i < N; ++i) { for (int i = 0; i < N; ++i) {
int seq_len = ref_lod[0][i + 1] - ref_lod[0][i]; int seq_len = ref_lod[0][i + 1] - ref_lod[0][i];
T* src = fc_out_data + i * D; T* src = fc_out_data + i * D;
for (int step = 0; step < seq_len; ++step) { for (int step = 0; step < seq_len; ++step) {
blas.VADD(D, out_data, src, out_data); blas.VADD(D, cur_out_data, src, cur_out_data);
out_data = out_data + D; cur_out_data = cur_out_data + D;
} }
} }
fc_act(total_T * D, out_data, out_data); fc_act(total_T * D, out_data, out_data);
} }
}; };
......
...@@ -80,16 +80,16 @@ class TestFusionSeqExpandConcatFCOp(OpTest): ...@@ -80,16 +80,16 @@ class TestFusionSeqExpandConcatFCOp(OpTest):
out = fusion_seqexpand_concat_fc(xs, self.lod, w, b, out = fusion_seqexpand_concat_fc(xs, self.lod, w, b,
ACTIVATION[self.fc_act]) ACTIVATION[self.fc_act])
self.inputs = {'X': [(x0, self.lod)], 'FCWeight': w} self.inputs = {'X': [('x0', (x0, self.lod))], 'FCWeight': w}
normal_lod = [i for i in range(bs + 1)] normal_lod = [[1] * bs]
for i in range(num_inputs - 1): for i in range(num_inputs - 1):
self.inputs['X'].append((xs[i + 1], normal_lod)) self.inputs['X'].append(('x%d' % (i + 1), (xs[i + 1], normal_lod)))
if self.with_bias: if self.with_bias:
self.inputs['FCBias'] = b self.inputs['FCBias'] = b
self.outputs = {'Out': (out, self.lod)} self.outputs = {'Out': (out, self.lod)}
self.attrs = {'fc_activation': self.fc_act, } self.attrs = {'fc_activation': self.fc_act}
def test_check_output(self):
    """Run the fused op and compare its output with the Python reference."""
    self.check_output()
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册