提交 01d52ebf 编写于 作者: H hedaoyuan

Fix RowConvOpTest use CpuGpuFuncCompare.

上级 1e0cc741
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) { void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) {
FunctionCompare test("RowConv", FuncConfig()); CpuGpuFuncCompare test("RowConv", FuncConfig());
test.addSequence(SequenceIdArg(TensorShape{batchSize})); test.addSequence(SequenceIdArg(TensorShape{batchSize}));
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
...@@ -31,7 +31,7 @@ void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) { ...@@ -31,7 +31,7 @@ void testRowConvFw(size_t batchSize, size_t dim, size_t contextLength) {
} }
void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) { void testRowConvBw(size_t batchSize, size_t dim, size_t contextLength) {
FunctionCompare test("RowConvGrad", FuncConfig()); CpuGpuFuncCompare test("RowConvGrad", FuncConfig());
test.addSequence(SequenceIdArg(TensorShape{batchSize})); test.addSequence(SequenceIdArg(TensorShape{batchSize}));
test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim})); test.addInputs(SequenceArg(VALUE_TYPE_FLOAT, TensorShape{batchSize, dim}));
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册