提交 c213c27f 编写于 作者: L liuqi

Update the logic : output type equals op type.

上级 7f43b237
......@@ -57,9 +57,8 @@ void CalImage2DShape(const std::vector<index_t> &shape, /* NHWC */
std::string DataTypeToCLType(const DataType dt) {
switch (dt) {
case DT_FLOAT:
return "float";
case DT_HALF:
return "half";
return "float";
case DT_UINT8:
return "uchar";
case DT_INT8:
......@@ -85,9 +84,8 @@ std::string DataTypeToCLType(const DataType dt) {
std::string DataTypeToOPENCLCMDDataType(const DataType dt) {
switch (dt) {
case DT_FLOAT:
return "f";
case DT_HALF:
return "h";
return "f";
default:
LOG(FATAL) << "Not supported data type for opencl cmd data type";
return "";
......
......@@ -194,16 +194,12 @@ void TestNHWCSimple3x3SAME() {
TEST_F(Conv2dOpTest, CPUSimple) {
TestNHWCSimple3x3VALID<DeviceType::CPU, float>();
TestNHWCSimple3x3VALID<DeviceType::CPU, half>();
TestNHWCSimple3x3SAME<DeviceType::CPU, float>();
TestNHWCSimple3x3SAME<DeviceType::CPU, half>();
}
TEST_F(Conv2dOpTest, OPENCLSimple) {
TestNHWCSimple3x3VALID<DeviceType::OPENCL, float>();
TestNHWCSimple3x3VALID<DeviceType::OPENCL, half>();
TestNHWCSimple3x3SAME<DeviceType::OPENCL, float>();
TestNHWCSimple3x3SAME<DeviceType::OPENCL, half>();
}
template<DeviceType D>
......@@ -294,12 +290,10 @@ void TestNHWCSimple3x3WithoutBias() {
TEST_F(Conv2dOpTest, CPUWithoutBias) {
TestNHWCSimple3x3WithoutBias<DeviceType::CPU, float>();
TestNHWCSimple3x3WithoutBias<DeviceType::CPU, half>();
}
TEST_F(Conv2dOpTest, OPENCLWithoutBias) {
TestNHWCSimple3x3WithoutBias<DeviceType::OPENCL, float>();
TestNHWCSimple3x3WithoutBias<DeviceType::OPENCL, half>();
}
template<DeviceType D>
......@@ -408,12 +402,10 @@ static void TestNHWCCombined3x3() {
TEST_F(Conv2dOpTest, CPUStride2) {
TestNHWCCombined3x3<DeviceType::CPU, float>();
TestNHWCCombined3x3<DeviceType::CPU, half>();
}
TEST_F(Conv2dOpTest, OPENCLStride2) {
TestNHWCCombined3x3<DeviceType::OPENCL, float>();
TestNHWCCombined3x3<DeviceType::OPENCL, half>();
}
template<DeviceType D>
......@@ -608,19 +600,10 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
Tensor expected;
expected.Copy(*net.GetOutput("Output"));
std::vector<half> input_data(float_input_data.begin(), float_input_data.end());
std::vector<half> filter_data(float_filter_data.begin(), float_filter_data.end());
std::vector<half> bias_data(float_bias_data.begin(), float_bias_data.end());
net.AddInputFromArray<D, half>("InputHalf", {batch, height, width, input_channels}, input_data);
net.AddInputFromArray<D, half>(
"FilterHalf", {kernel_h, kernel_w, input_channels, output_channels}, filter_data);
net.AddInputFromArray<D, half>("BiasHalf", {output_channels}, bias_data);
// run on gpu
BufferToImage<D, half>(net, "InputHalf", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D, half>(net, "FilterHalf", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D, half>(net, "BiasHalf", "BiasImage", kernels::BufferType::ARGUMENT);
BufferToImage<D, half>(net, "Input", "InputImage", kernels::BufferType::IN_OUT);
BufferToImage<D, half>(net, "Filter", "FilterImage", kernels::BufferType::FILTER);
BufferToImage<D, half>(net, "Bias", "BiasImage", kernels::BufferType::ARGUMENT);
OpDefBuilder("Conv2D", "Conv2dTest")
.Input("InputImage")
......@@ -630,6 +613,7 @@ static void TestHalfComplexConvNxNS12(const std::vector<index_t> &shape) {
.AddIntsArg("strides", {stride_h, stride_w})
.AddIntArg("padding", type)
.AddIntsArg("dilations", {1, 1})
.AddIntArg("T", static_cast<int>(DataType::DT_HALF))
.Finalize(net.NewOperatorDef());
// Run on device
net.RunOp(D);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册