提交 95a7bc01 编写于 作者: H hedaoyuan

follow comments

上级 784e2184
......@@ -20,9 +20,9 @@ limitations under the License. */
namespace paddle {
enum TestType {
FORWARD_TEST = 0,
BACKWARD_INPUT_TEST = 1,
BACKWARD_FILTER_TEST = 2,
kForwardTest = 0,
kBackwardInputTest = 1,
kBackwardFilterTest = 2,
};
template <DeviceType DType1, DeviceType DType2>
......@@ -43,16 +43,16 @@ public:
if (padding >= filterSize) break;
size_t outputSize =
(inputSize - filterSize + 2 * padding + stride) / stride;
LOG(INFO) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
VLOG(3) << " batchSize=" << batchSize
<< " inputChannels=" << inputChannels
<< " inputHeight=" << inputSize
<< " inputWidth=" << inputSize
<< " outputChannels=" << outputChannels
<< " filterHeight=" << filterSize
<< " filterWidth=" << filterSize
<< " outputHeight=" << outputSize
<< " outputWidth=" << outputSize
<< " stride=" << stride << " padding=" << padding;
std::vector<size_t> paddings = {padding, padding};
std::vector<size_t> strides = {stride, stride};
......@@ -72,17 +72,17 @@ public:
TensorShape output{
batchSize, outputChannels, outputSize, outputSize};
if (type == FORWARD_TEST) {
if (type == kForwardTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.run();
} else if (type == BACKWARD_INPUT_TEST) {
} else if (type == kBackwardInputTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, filter));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, input), ADD_TO);
test.run();
} else if (type == BACKWARD_FILTER_TEST) {
} else if (type == kBackwardFilterTest) {
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, output));
test.addInputs(BufferArg(VALUE_TYPE_FLOAT, input));
test.addOutputs(BufferArg(VALUE_TYPE_FLOAT, filter));
......@@ -100,23 +100,23 @@ public:
TEST(Forward, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_CPU> test(
"NaiveConv-CPU", "GemmConv-CPU", FORWARD_TEST);
"NaiveConv-CPU", "GemmConv-CPU", kForwardTest);
}
#ifndef PADDLE_ONLY_CPU
TEST(Forward, GEMM2) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConv-CPU", "GemmConv-GPU", FORWARD_TEST);
"GemmConv-CPU", "GemmConv-GPU", kForwardTest);
}
TEST(BackwardInput, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", BACKWARD_INPUT_TEST);
"GemmConvGradInput-CPU", "GemmConvGradInput-GPU", kBackwardInputTest);
}
TEST(BackwardFilter, GEMM) {
ConvolutionTest<DEVICE_TYPE_CPU, DEVICE_TYPE_GPU> test(
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", BACKWARD_FILTER_TEST);
"GemmConvGradFilter-CPU", "GemmConvGradFilter-GPU", kBackwardFilterTest);
}
#endif
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册