From c213c27fa2142cfbc0ad0c882feddf1006466086 Mon Sep 17 00:00:00 2001 From: liuqi Date: Fri, 1 Dec 2017 10:32:29 +0800 Subject: [PATCH] Update the logic : output type equals op type. --- mace/kernels/opencl/helper.cc | 6 ++---- mace/ops/conv_2d_test.cc | 24 ++++-------------------- 2 files changed, 6 insertions(+), 24 deletions(-) diff --git a/mace/kernels/opencl/helper.cc b/mace/kernels/opencl/helper.cc index 05221e55..4f4d1c56 100644 --- a/mace/kernels/opencl/helper.cc +++ b/mace/kernels/opencl/helper.cc @@ -57,9 +57,8 @@ void CalImage2DShape(const std::vector &shape, /* NHWC */ std::string DataTypeToCLType(const DataType dt) { switch (dt) { case DT_FLOAT: - return "float"; case DT_HALF: - return "half"; + return "float"; case DT_UINT8: return "uchar"; case DT_INT8: @@ -85,9 +84,8 @@ std::string DataTypeToCLType(const DataType dt) { std::string DataTypeToOPENCLCMDDataType(const DataType dt) { switch (dt) { case DT_FLOAT: - return "f"; case DT_HALF: - return "h"; + return "f"; default: LOG(FATAL) << "Not supported data type for opencl cmd data type"; return ""; diff --git a/mace/ops/conv_2d_test.cc b/mace/ops/conv_2d_test.cc index faaf508c..b4fd374b 100644 --- a/mace/ops/conv_2d_test.cc +++ b/mace/ops/conv_2d_test.cc @@ -194,16 +194,12 @@ void TestNHWCSimple3x3SAME() { TEST_F(Conv2dOpTest, CPUSimple) { TestNHWCSimple3x3VALID(); - TestNHWCSimple3x3VALID(); TestNHWCSimple3x3SAME(); - TestNHWCSimple3x3SAME(); } TEST_F(Conv2dOpTest, OPENCLSimple) { TestNHWCSimple3x3VALID(); - TestNHWCSimple3x3VALID(); TestNHWCSimple3x3SAME(); - TestNHWCSimple3x3SAME(); } template @@ -294,12 +290,10 @@ void TestNHWCSimple3x3WithoutBias() { TEST_F(Conv2dOpTest, CPUWithoutBias) { TestNHWCSimple3x3WithoutBias(); - TestNHWCSimple3x3WithoutBias(); } TEST_F(Conv2dOpTest, OPENCLWithoutBias) { TestNHWCSimple3x3WithoutBias(); - TestNHWCSimple3x3WithoutBias(); } template @@ -408,12 +402,10 @@ static void TestNHWCCombined3x3() { TEST_F(Conv2dOpTest, CPUStride2) { TestNHWCCombined3x3(); - TestNHWCCombined3x3(); } TEST_F(Conv2dOpTest, OPENCLStride2) { TestNHWCCombined3x3(); - TestNHWCCombined3x3(); } template @@ -608,19 +600,10 @@ static void TestHalfComplexConvNxNS12(const std::vector &shape) { Tensor expected; expected.Copy(*net.GetOutput("Output")); - - std::vector input_data(float_input_data.begin(), float_input_data.end()); - std::vector filter_data(float_filter_data.begin(), float_filter_data.end()); - std::vector bias_data(float_bias_data.begin(), float_bias_data.end()); - net.AddInputFromArray("InputHalf", {batch, height, width, input_channels}, input_data); - net.AddInputFromArray( - "FilterHalf", {kernel_h, kernel_w, input_channels, output_channels}, filter_data); - net.AddInputFromArray("BiasHalf", {output_channels}, bias_data); - // run on gpu - BufferToImage(net, "InputHalf", "InputImage", kernels::BufferType::IN_OUT); - BufferToImage(net, "FilterHalf", "FilterImage", kernels::BufferType::FILTER); - BufferToImage(net, "BiasHalf", "BiasImage", kernels::BufferType::ARGUMENT); + BufferToImage(net, "Input", "InputImage", kernels::BufferType::IN_OUT); + BufferToImage(net, "Filter", "FilterImage", kernels::BufferType::FILTER); + BufferToImage(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT); OpDefBuilder("Conv2D", "Conv2dTest") .Input("InputImage") @@ -630,6 +613,7 @@ static void TestHalfComplexConvNxNS12(const std::vector &shape) { .AddIntsArg("strides", {stride_h, stride_w}) .AddIntArg("padding", type) .AddIntsArg("dilations", {1, 1}) + .AddIntArg("T", static_cast(DataType::DT_HALF)) .Finalize(net.NewOperatorDef()); // Run on device net.RunOp(D); -- GitLab