Commit 86743817 authored by liutuo

remove cwise

Parent 5085a24f
......@@ -83,7 +83,6 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry);
extern void Register_ChannelShuffle(OperatorRegistry *op_registry);
extern void Register_Concat(OperatorRegistry *op_registry);
extern void Register_Conv2D(OperatorRegistry *op_registry);
extern void Register_CWise(OperatorRegistry *op_registry);
extern void Register_DepthToSpace(OperatorRegistry *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry);
extern void Register_Dequantize(OperatorRegistry *op_registry);
......@@ -123,7 +122,6 @@ OperatorRegistry::OperatorRegistry() {
ops::Register_ChannelShuffle(this);
ops::Register_Concat(this);
ops::Register_Conv2D(this);
ops::Register_CWise(this);
ops::Register_DepthToSpace(this);
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_CWISE_H_
#define MACE_KERNELS_CWISE_H_
#include <algorithm>
#include <memory>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/runtime/opencl/cl2_header.h"
#include "mace/core/tensor.h"
namespace mace {
namespace kernels {
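// CWiseType enumerates the element-wise operations applied between each
// tensor element and the scalar coefficient (coeff_, the op's "x" argument).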
enum CWiseType {
MUL = 0,
ADD = 1,
MAX = 2,
MIN = 3,
SUB = 4,
DIV = 5,
NEG = 6,
ABS = 7,
SQR_DIFF = 8,
};
struct CWiseFunctorBase {
CWiseFunctorBase(const CWiseType type, const float coeff)
: type_(type), coeff_(coeff) {}
CWiseType type_;
float coeff_;
};
template <DeviceType D, typename T>
struct CWiseFunctor : CWiseFunctorBase {
CWiseFunctor(const CWiseType type, const float coeff)
: CWiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
Tensor::MappingGuard input_guard(input);
Tensor::MappingGuard output_guard(output);
const T *input_ptr = input->data<T>();
T *output_ptr = output->mutable_data<T>();
const index_t size = input->size();
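// Dispatch on the requested op; each branch applies the scalar coeff_ to
// every element independently, parallelized over the flattened tensor.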
switch (type_) {
case MUL:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = coeff_ * input_ptr[i];
}
break;
case ADD:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = coeff_ + input_ptr[i];
}
break;
case MAX:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::max<T>(input_ptr[i], coeff_);
}
break;
case MIN:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = std::min<T>(input_ptr[i], coeff_);
}
break;
case SUB:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = input_ptr[i] - coeff_;
}
break;
case DIV:
MACE_CHECK(fabs(coeff_) > 1e-6, "cannot divide by 0.");
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = input_ptr[i] / coeff_;
}
break;
case NEG:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
output_ptr[i] = -input_ptr[i];
}
break;
case ABS:
#pragma omp parallel for
for (index_t i = 0; i < size; ++i) {
T val = input_ptr[i];
output_ptr[i] = (val > 0) ? val : -val;
}
break;
default:
LOG(FATAL) << "CWise op does not support type " << type_;
}
}
};
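// OpenCL specialization; the kernel body (cwise.cl) and the host-side
// operator() that launches it appear later in this diff.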
template <typename T>
struct CWiseFunctor<DeviceType::OPENCL, T> : CWiseFunctorBase {
CWiseFunctor(const CWiseType type, const float coeff)
: CWiseFunctorBase(type, coeff) {}
void operator()(const Tensor *input,
Tensor *output,
StatsFuture *future);
cl::Kernel kernel_;
uint32_t kwg_size_;
std::unique_ptr<BufferBase> kernel_error_;
std::vector<index_t> input_shape_;
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_CWISE_H_
#include <common.h>
__kernel void cwise(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM2
__read_only image2d_t input, /* [c%4 * w * c/4, h * b] */
__private const int width,
__private const int channel,
__private const float value,
__write_only image2d_t output) {
const int w = get_global_id(0);
const int hb = get_global_id(1);
#ifndef NON_UNIFORM_WORK_GROUP
if (w >= global_size_dim0 || hb >= global_size_dim1) return;
#endif
const int remain_chan = channel - mul24((w / width), 4);
DATA_TYPE4 in0 = READ_IMAGET(input, SAMPLER, (int2)(w, hb));
DATA_TYPE4 in1 = (DATA_TYPE4){value, value, value, value};
DATA_TYPE4 out = (DATA_TYPE4)(0, 0, 0, 0);  // zero-init so DIV writes defined values when the divisor is ~0
#if CWISE_TYPE == 0
out = in0 * in1;
#elif CWISE_TYPE == 1
out = in0 + in1;
#elif CWISE_TYPE == 2
out = fmax(in0, in1);
#elif CWISE_TYPE == 3
out = fmin(in0, in1);
#elif CWISE_TYPE == 4
out = in0 - in1;
#elif CWISE_TYPE == 5
if (fabs(in1.x) > 0.000001f)
out.x = in0.x / in1.x;
if (fabs(in1.y) > 0.000001f)
out.y = in0.y / in1.y;
if (fabs(in1.z) > 0.000001f)
out.z = in0.z / in1.z;
if (fabs(in1.w) > 0.000001f)
out.w = in0.w / in1.w;
#elif CWISE_TYPE == 6
in1 = (DATA_TYPE4)(0, 0, 0, 0);
out = in1 - in0;
#elif CWISE_TYPE == 7
out = fabs(in0);
#endif
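// For ops that can leave non-zero values in the padded channel lanes
// (ADD/MAX/MIN/SUB), clear the unused lanes of the last channel block.
// The switch below falls through intentionally.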
#if CWISE_TYPE == 1 || CWISE_TYPE == 2 || CWISE_TYPE == 3 || CWISE_TYPE == 4
if (remain_chan < 4) {
switch (remain_chan) {
case 1:
out.y = 0;
case 2:
out.z = 0;
case 3:
out.w = 0;
}
}
#endif
WRITE_IMAGET(output, (int2)(w, hb), out);
}
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/cwise.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
namespace mace {
namespace kernels {
template <typename T>
void CWiseFunctor<DeviceType::OPENCL, T>::operator()(const Tensor *input,
Tensor *output,
StatsFuture *future) {
const index_t batch = input->dim(0);
const index_t height = input->dim(1);
const index_t width = input->dim(2);
const index_t channels = input->dim(3);
const index_t channel_blocks = RoundUpDiv4(channels);
const index_t width_pixels = channel_blocks * width;
const index_t batch_height_pixels = batch * height;
auto runtime = OpenCLRuntime::Global();
const uint32_t gws[2] = {static_cast<uint32_t>(width_pixels),
static_cast<uint32_t>(batch_height_pixels)};
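// One work-item per image pixel: gws = (channel_blocks * width, batch * height),
// matching the [c%4 * w * c/4, h * b] image layout noted in the kernel.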
if (kernel_.get() == nullptr) {
std::set<std::string> built_options;
auto dt = DataTypeToEnum<T>::value;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("cwise");
built_options.emplace("-Dcwise=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToUpstreamCLDt(dt));
built_options.emplace("-DCMD_DATA_TYPE=" + DtToUpstreamCLCMDDt(dt));
built_options.emplace(MakeString("-DCWISE_TYPE=", type_));
if (runtime->IsOutOfRangeCheckEnabled()) {
built_options.emplace("-DOUT_OF_RANGE_CHECK");
kernel_error_ = std::move(std::unique_ptr<Buffer>(
new Buffer(GetDeviceAllocator(DeviceType::OPENCL), 1)));
kernel_error_->Map(nullptr);
*(kernel_error_->mutable_data<char>()) = 0;
kernel_error_->UnMap();
}
if (runtime->IsNonUniformWorkgroupsSupported()) {
built_options.emplace("-DNON_UNIFORM_WORK_GROUP");
}
kernel_ = runtime->BuildKernel("cwise", kernel_name, built_options);
kwg_size_ =
static_cast<uint32_t>(runtime->GetKernelMaxWorkGroupSize(kernel_));
}
if (!IsVecEqual(input_shape_, input->shape())) {
uint32_t idx = 0;
if (runtime->IsOutOfRangeCheckEnabled()) {
kernel_.setArg(idx++,
*(static_cast<cl::Buffer *>(kernel_error_->buffer())));
}
if (!runtime->IsNonUniformWorkgroupsSupported()) {
kernel_.setArg(idx++, gws[0]);
kernel_.setArg(idx++, gws[1]);
}
kernel_.setArg(idx++, *(input->opencl_image()));
kernel_.setArg(idx++, static_cast<int32_t>(width));
kernel_.setArg(idx++, static_cast<int32_t>(channels));
kernel_.setArg(idx++, static_cast<float>(coeff_));
kernel_.setArg(idx++, *(output->opencl_image()));
input_shape_ = input->shape();
}
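// Default local work size; TuningOrRun2DKernel launches the kernel (tuning the
// work-group size when tuning is enabled) and reports completion through `future`.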
const std::vector<uint32_t> lws = {kwg_size_ / 16, 16, 0};
std::stringstream ss;
ss << "cwise_opencl_kernel_" << output->dim(0) << "_" << output->dim(1)
<< "_" << output->dim(2) << "_" << output->dim(3);
TuningOrRun2DKernel(kernel_, ss.str(), gws, lws, future);
if (runtime->IsOutOfRangeCheckEnabled()) {
kernel_error_->Map(nullptr);
char *kerror_code = kernel_error_->mutable_data<char>();
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;
kernel_error_->UnMap();
}
}
template struct CWiseFunctor<DeviceType::OPENCL, float>;
template struct CWiseFunctor<DeviceType::OPENCL, half>;
} // namespace kernels
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/cwise.h"
namespace mace {
namespace ops {
void Register_CWise(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
CWiseOp<DeviceType::CPU, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
CWiseOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("CWise")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
CWiseOp<DeviceType::OPENCL, half>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_CWISE_H_
#define MACE_OPS_CWISE_H_
#include <string>
#include "mace/core/operator.h"
#include "mace/kernels/cwise.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class CWiseOp : public Operator<D, T> {
public:
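// "x" is the scalar operand (default 1.0); "type" selects the CWiseType
// (default ADD). Both are read from the operator definition's arguments.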
CWiseOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
x_(OperatorBase::GetSingleArgument<float>("x", 1.0)),
functor_(static_cast<kernels::CWiseType>(
OperatorBase::GetSingleArgument<int>(
"type", static_cast<int>(
kernels::CWiseType::ADD))),
this->x_) {}
bool Run(StatsFuture *future) override {
const Tensor *input_tensor = this->Input(INPUT);
Tensor *output_tensor = this->Output(OUTPUT);
output_tensor->ResizeLike(input_tensor);
functor_(input_tensor, output_tensor, future);
return true;
}
protected:
const float x_;
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
private:
kernels::CWiseFunctor<D, T> functor_;
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_CWISE_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
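// Benchmark helper: builds a CWise op on the given device, warms it up, then
// times `iters` runs. `x` is the scalar operand, `type` maps to CWiseType.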
template <DeviceType D, typename T>
void CWise(int iters, int batch, int channels,
int height, int width, float x, int type) {
mace::testing::StopTiming();
OpsTestNet net;
// Add input data
net.AddRandomInput<D, T>("Input", {batch, height, width, channels});
if (D == DeviceType::OPENCL) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("CWise", "CWiseBM")
.Input("InputImage")
.Output("Output")
.AddIntArg("type", type)
.AddFloatArg("x", x)
.Finalize(net.NewOperatorDef());
} else {
OpDefBuilder("CWise", "CWiseBM")
.Input("Input")
.Output("Output")
.AddIntArg("type", type)
.AddFloatArg("x", x)
.Finalize(net.NewOperatorDef());
}
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define BM_CWISE_MACRO(N, C, H, W, X, G, TYPE, DEVICE) \
static void \
BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
CWise<DEVICE, TYPE>(iters, N, C, H, W, X, G); \
} \
BENCHMARK( \
BM_CWISE_##N##_##C##_##H##_##W##_##X##_##G##_##TYPE##_##DEVICE)
#define BM_CWISE(N, C, H, W, X, G) \
BM_CWISE_MACRO(N, C, H, W, X, G, float, CPU); \
BM_CWISE_MACRO(N, C, H, W, X, G, float, OPENCL); \
BM_CWISE_MACRO(N, C, H, W, X, G, half, OPENCL);
BM_CWISE(1, 1, 512, 512, 2, 0);
BM_CWISE(1, 3, 128, 128, 2, 1);
BM_CWISE(1, 3, 512, 512, 2, 4);
BM_CWISE(1, 32, 112, 112, 2, 5);
BM_CWISE(1, 32, 112, 112, 2, 6);
BM_CWISE(1, 32, 112, 112, 2, 7);
BM_CWISE(1, 64, 256, 256, 3, 0);
BM_CWISE(1, 64, 512, 512, 3, 1);
BM_CWISE(1, 128, 56, 56, 3, 4);
BM_CWISE(1, 128, 256, 256, 3, 5);
BM_CWISE(1, 64, 512, 512, 3, 6);
BM_CWISE(1, 64, 512, 512, 3, 7);
BM_CWISE(1, 256, 14, 14, 3, 0);
BM_CWISE(1, 512, 14, 14, 3, 1);
BM_CWISE(1, 1024, 7, 7, 3, 4);
BM_CWISE(32, 1, 256, 256, 3, 5);
BM_CWISE(32, 1, 256, 256, 3, 6);
BM_CWISE(32, 1, 256, 256, 3, 7);
} // namespace test
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
#include "../kernels/cwise.h"
namespace mace {
namespace ops {
namespace test {
class CWiseOpTest : public OpsTestBase {};
namespace {
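// Runs one CWise op with scalar operand `x` on device D and checks the result
// against the expected output.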
template <DeviceType D>
void Simple(const kernels::CWiseType type,
const std::vector<index_t> &shape,
const std::vector<float> &input0,
const float x,
const std::vector<float> &output) {
// Construct graph
OpsTestNet net;
// Add input data
net.AddInputFromArray<D, float>("Input1", shape, input0);
if (D == DeviceType::CPU) {
OpDefBuilder("CWise", "CWiseTest")
.Input("Input1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
} else {
BufferToImage<D, half>(&net, "Input1", "InputImg1",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("CWise", "CWiseTest")
.Input("InputImg1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", x)
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImg", "Output",
kernels::BufferType::IN_OUT_CHANNEL);
}
auto expected = CreateTensor<float>(shape, output);
ExpectTensorNear<float>(*expected, *net.GetOutput("Output"), 1e-5, 1e-3);
}
} // namespace
TEST_F(CWiseOpTest, CPUSimple) {
Simple<DeviceType::CPU>(kernels::CWiseType::MUL, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6});
Simple<DeviceType::CPU>(kernels::CWiseType::ADD, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8});
Simple<DeviceType::CPU>(kernels::CWiseType::DIV, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60});
Simple<DeviceType::CPU>(kernels::CWiseType::SUB, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4});
Simple<DeviceType::CPU>(kernels::CWiseType::NEG, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6});
Simple<DeviceType::CPU>(kernels::CWiseType::ABS, {1, 1, 2, 3},
{1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6});
}
TEST_F(CWiseOpTest, GPUSimple) {
Simple<DeviceType::OPENCL>(kernels::CWiseType::MUL, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {0.1, 0.2, .3, .4, .5, .6});
Simple<DeviceType::OPENCL>(kernels::CWiseType::ADD, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {3, 4, 5, 6, 7, 8});
Simple<DeviceType::OPENCL>(kernels::CWiseType::DIV, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 0.1, {10, 20, 30, 40, 50, 60});
Simple<DeviceType::OPENCL>(kernels::CWiseType::SUB, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, 0, 1, 2, 3, 4});
Simple<DeviceType::OPENCL>(kernels::CWiseType::NEG, {1, 1, 2, 3},
{1, 2, 3, 4, 5, 6}, 2.0, {-1, -2, -3, -4, -5, -6});
Simple<DeviceType::OPENCL>(kernels::CWiseType::ABS, {1, 1, 2, 3},
{1, -2, -0.0001, 4, 5, 6}, 2.0, {1, 2, 0.0001, 4, 5, 6});
}
namespace {
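// Runs the same CWise op on CPU (reference) and on OpenCL with random input,
// then compares the outputs, using a looser tolerance for half precision.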
template <DeviceType D, typename T>
void RandomTest(const kernels::CWiseType type,
const std::vector<index_t> &shape) {
testing::internal::LogToStderr();
srand(time(NULL));
// Construct graph
OpsTestNet net;
// Add input data
net.AddRandomInput<D, float>("Input1", shape);
OpDefBuilder("CWise", "CWiseTest")
.Input("Input1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", 1.2)
.Output("Output")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp();
BufferToImage<D, T>(&net, "Input1", "InputImg1",
kernels::BufferType::IN_OUT_CHANNEL);
OpDefBuilder("CWise", "CWiseTest")
.Input("InputImg1")
.AddIntArg("type", static_cast<int>(type))
.AddFloatArg("x", 1.2)
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Output("OutputImg")
.Finalize(net.NewOperatorDef());
// Run
net.RunOp(D);
ImageToBuffer<D, float>(&net, "OutputImg", "OPENCLOutput",
kernels::BufferType::IN_OUT_CHANNEL);
if (DataTypeToEnum<T>::value == DT_FLOAT) {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-5, 1e-4);
} else {
ExpectTensorNear<float>(*net.GetTensor("Output"),
*net.GetOutput("OPENCLOutput"), 1e-2, 1e-2);
}
}
} // namespace
TEST_F(CWiseOpTest, OPENCLRandomFloat) {
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::MUL,
{3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::ADD,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::SUB,
{3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::DIV,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, float>(kernels::CWiseType::NEG,
{13, 32, 32, 64});
}
TEST_F(CWiseOpTest, OPENCLRandomHalf) {
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::MUL,
{3, 23, 37, 19});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::ADD,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::SUB,
{3, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::DIV,
{13, 32, 32, 64});
RandomTest<DeviceType::OPENCL, half>(kernels::CWiseType::NEG,
{13, 32, 32, 64});
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -829,7 +829,7 @@ class TFConverter(object):
self.resolved_ops[op.name] = 1
self.unused_tensor.add(get_input_tensor(op, 1).name)
def convert_math(self, op, math_type):
def convert_eltwise(self, op, math_type):
op_def = self.net_def.op.add()
arg = op_def.arg.add()
arg.name = 'T'
......@@ -1144,11 +1144,11 @@ class TFConverter(object):
elif op.type == 'SpaceToDepth':
self.convert_depth_to_space(op, False)
elif op.type in ['Neg', 'neg', 'Negative', 'negative']:
self.convert_math(op, 'NEG')
self.convert_eltwise(op, 'NEG')
elif op.type == 'Mul':
self.convert_math(op, 'MUL')
self.convert_eltwise(op, 'MUL')
elif op.type == 'Sub':
self.convert_math(op, 'SUB')
self.convert_eltwise(op, 'SUB')
elif self.is_softmax(op):
self.convert_softmax(op)
elif op.type in ['Relu', 'Sigmoid', 'Tanh']:
......