提交 95a7bc01 编写于 作者: H hedaoyuan

follow comments

上级 784e2184
...@@ -20,9 +20,9 @@ limitations under the License. */ ...@@ -20,9 +20,9 @@ limitations under the License. */
namespace paddle { namespace paddle {
enum TestType { enum TestType {
FORWARD_TEST = 0, kForwardTest = 0,
BACKWARD_INPUT_TEST = 1, kBackwardInputTest = 1,
BACKWARD_FILTER_TEST = 2, kBackwardFilterTest = 2,
}; };
template <DeviceType DType1, DeviceType DType2> template <DeviceType DType1, DeviceType DType2>
...@@ -43,7 +43,7 @@ public: ...@@ -43,7 +43,7 @@ public:
if (padding >= filterSize) break; if (padding >= filterSize) break;
size_t outputSize = size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride; (inputSize - filterSize + 2 * padding + stride) / stride;
LOG(INFO) << " batchSize=" << batchSize VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels << " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize << " inputHeight=" << inputSize
<< " inputWidth=" << inputSize << " inputWidth=" << inputSize
...@@ -72,17 +72,17 @@ public: ...@@ -72,17 +72,17 @@ public:
TensorShape output{ TensorShape output{
batchSize, outputChannels, outputSize, outputSize}; batchSize, outputChannels, outputSize, outputSize};
if (type == FORWARD_TEST) { if (type == kForwardTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run(); test.run();
} else if (type == BACKWARD_INPUT_TEST) { } else if (type == kBackwardInputTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run(); test.run();
} else if (type == BACKWARD_FILTER_TEST) { } else if (type == kBackwardFilterTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input)); test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter)); test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
...@@ -100,23 +100,23 @@ public: ...@@ -100,23 +100,23 @@ public:
TEST(Forward, GEMM) { TEST(Forward, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
"NaiveConv-CPU", "GemmConv-CPU", FORWARD_TEST); "NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
} }
#ifndef PADDLE_ONLY_CPU #ifndef PADDLE_ONLY_CPU
TEST(Forward, GEMM2) { TEST(Forward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST); "GemmConv-CPU", "GemmConv-GPU", kForwardTest);
} }
TEST(BackwardInput, GEMM) { TEST(BackwardInput, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", BACKWARD_INPUT_TEST); "GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
} }
TEST(BackwardFilter, GEMM) { TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test( ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST); "GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
} }
#endif #endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册