提交 92f18fc6 编写于 作者: Y yejianwu

support tf basic lstm on cpu

上级 b77d694f
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_FILL_H_
#define MACE_KERNELS_FILL_H_
#include <algorithm>
#include <functional>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
struct FillBase {
explicit FillBase(float value) : value_(value) {}
int value_;
};
template <DeviceType D, class T>
struct FillFunctor;
template <>
struct FillFunctor<DeviceType::CPU, float> : FillBase {
explicit FillFunctor(float value) : FillBase(value) {}
MaceStatus operator()(const Tensor *shape,
Tensor *output,
StatsFuture *future) {
MACE_UNUSED(future);
MACE_CHECK(shape->dim_size() == 1) << "Shape must be 1-D";
const index_t num_dims = shape->dim(0);
Tensor::MappingGuard shape_guard(shape);
const int32_t *shape_data = shape->data<int32_t>();
std::vector<index_t> output_shape;
for (index_t i = 0; i < num_dims; ++i) {
MACE_CHECK(shape_data[i] > 0) << "Shape must be non-negative: "
<< shape_data[i];
output_shape.push_back(shape_data[i]);
}
MACE_RETURN_IF_ERROR(output->Resize(output_shape));
Tensor::MappingGuard output_guard(output);
float *output_data = output->mutable_data<float>();
std::fill(output_data, output_data + output->size(), value_);
return MACE_SUCCESS;
}
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_FILL_H_
#include <common.h>
__kernel void slice(KERNEL_ERROR_PARAMS
__kernel void split(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input,
__private const int chan_blk_offset,
......
......@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/kernels/slice.h"
#include "mace/kernels/split.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h"
......@@ -21,7 +21,7 @@ namespace mace {
namespace kernels {
template <typename T>
MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
MaceStatus SplitFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input,
const std::vector<Tensor *> &output_list,
StatsFuture *future) {
......@@ -29,7 +29,7 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
const size_t outputs_count = output_list.size();
const index_t output_channels = input_channels / outputs_count;
MACE_CHECK(output_channels % 4 == 0)
<< "output channels of slice op must be divisible by 4";
<< "output channels of split op must be divisible by 4";
std::vector<index_t> output_shape(
{input->dim(0), input->dim(1), input->dim(2), output_channels});
......@@ -46,12 +46,12 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
std::set<std::string> built_options;
OUT_OF_RANGE_CONFIG(kernel_error_);
NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice");
built_options.emplace("-Dslice=" + kernel_name);
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("split");
built_options.emplace("-Dsplit=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" +
DtToCLCMDDt(DataTypeToEnum<T>::value));
MACE_RETURN_IF_ERROR(runtime->BuildKernel("slice",
MACE_RETURN_IF_ERROR(runtime->BuildKernel("split",
kernel_name,
built_options,
&kernel_));
......@@ -116,8 +116,8 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
return MACE_SUCCESS;
}
template struct SliceFunctor<DeviceType::GPU, float>;
template struct SliceFunctor<DeviceType::GPU, half>;
template struct SplitFunctor<DeviceType::GPU, float>;
template struct SplitFunctor<DeviceType::GPU, half>;
} // namespace kernels
} // namespace mace
......@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_SLICE_H_
#define MACE_KERNELS_SLICE_H_
#ifndef MACE_KERNELS_SPLIT_H_
#define MACE_KERNELS_SPLIT_H_
#include <memory>
#include <functional>
......@@ -31,15 +31,15 @@
namespace mace {
namespace kernels {
struct SliceFunctorBase {
explicit SliceFunctorBase(const int32_t axis) : axis_(axis) {}
struct SplitFunctorBase {
explicit SplitFunctorBase(const int32_t axis) : axis_(axis) {}
int32_t axis_;
};
template<DeviceType D, typename T>
struct SliceFunctor : SliceFunctorBase {
explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {}
struct SplitFunctor : SplitFunctorBase {
explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {}
MaceStatus operator()(const Tensor *input,
const std::vector<Tensor *> &output_list,
......@@ -89,8 +89,8 @@ struct SliceFunctor : SliceFunctorBase {
#ifdef MACE_ENABLE_OPENCL
template<typename T>
struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase {
explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {}
struct SplitFunctor<DeviceType::GPU, T> : SplitFunctorBase {
explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {}
MaceStatus operator()(const Tensor *input,
const std::vector<Tensor *> &output_list,
......@@ -104,4 +104,4 @@ struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase {
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_SLICE_H_
#endif // MACE_KERNELS_SPLIT_H_
......@@ -169,7 +169,6 @@ struct StridedSliceFunctor {
i += strides_data[0]) {
*output_data++ = input_data[i];
}
} else if (input->dim_size() == 2) {
for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0]
......@@ -179,7 +178,24 @@ struct StridedSliceFunctor {
strides_data[1] > 0 ? j < real_end_indices[1]
: j > real_end_indices[1];
j += strides_data[1]) {
*output_data++ = input_data[i * dim_stride[0] + j];
*output_data++ = input_data[i * input->dim(1) + j];
}
}
} else if (input->dim_size() == 3) {
for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0]
: i > real_end_indices[0];
i += strides_data[0]) {
for (index_t j = real_begin_indices[1];
strides_data[1] > 0 ? j < real_end_indices[1]
: j > real_end_indices[1];
j += strides_data[1]) {
for (index_t k = real_begin_indices[2];
strides_data[2] > 0 ? k < real_end_indices[2]
: k > real_end_indices[2];
k += strides_data[2]) {
*output_data++ = input_data[(i * input->dim(1) + j) * input->dim(2) + k];
}
}
}
} else {
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/fill.h"
namespace mace {
namespace ops {
void Register_Fill(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Fill")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
FillOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_FILL_H_
#define MACE_OPS_FILL_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/fill.h"
namespace mace {
namespace ops {
template <DeviceType D, class T>
class FillOp : public Operator<D, T> {
public:
FillOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws),
functor_(OperatorBase::GetOptionalArg<float>("value", 0.0f)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *shape = this->Input(SHAPE);
Tensor *output = this->Output(OUTPUT);
return functor_(shape, output, future);
}
private:
kernels::FillFunctor<D, T> functor_;
MACE_OP_INPUT_TAGS(SHAPE);
MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_FILL_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class FillTest : public OpsTestBase {};
namespace {
void TestFill(const std::vector<int32_t> &shape,
const float &value) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Fill", "FillTest")
.Input("Shape")
.AddFloatArg("value", static_cast<float>(value))
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, int32_t>(
"Shape",
{static_cast<index_t>(shape.size())},
shape);
// Run
net.RunOp();
auto output = net.GetTensor("Output");
for (index_t i = 0; i < output->dim_size(); ++i) {
ASSERT_EQ(output->dim(i), shape[i]);
}
const float *output_ptr = output->data<float>();
const index_t size = output->size();
for (index_t i = 0; i < size; ++i) {
ASSERT_EQ(output_ptr[i], value);
}
}
} // namespace
TEST_F(FillTest, Simple) {
TestFill({3, 2, 1}, 5.0f);
TestFill({1, 3}, -1.0f);
}
} // namespace test
} // namespace ops
} // namespace mace
......@@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry);
extern void Register_Dequantize(OperatorRegistryBase *op_registry);
extern void Register_Eltwise(OperatorRegistryBase *op_registry);
extern void Register_Fill(OperatorRegistryBase *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
extern void Register_Gather(OperatorRegistryBase *op_registry);
......@@ -48,7 +49,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
extern void Register_Reshape(OperatorRegistryBase *op_registry);
extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
extern void Register_Shape(OperatorRegistryBase *op_registry);
extern void Register_Slice(OperatorRegistryBase *op_registry);
extern void Register_Split(OperatorRegistryBase *op_registry);
extern void Register_Softmax(OperatorRegistryBase *op_registry);
extern void Register_Stack(OperatorRegistryBase *op_registry);
extern void Register_StridedSlice(OperatorRegistryBase *op_registry);
......@@ -84,6 +85,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this);
ops::Register_Eltwise(this);
ops::Register_Fill(this);
ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this);
ops::Register_Gather(this);
......@@ -98,7 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this);
ops::Register_Shape(this);
ops::Register_Slice(this);
ops::Register_Split(this);
ops::Register_Softmax(this);
ops::Register_Stack(this);
ops::Register_StridedSlice(this);
......
......@@ -12,30 +12,30 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/slice.h"
#include "mace/ops/split.h"
namespace mace {
namespace ops {
void Register_Slice(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
void Register_Split(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
SliceOp<DeviceType::CPU, float>);
SplitOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
.Device(DeviceType::GPU)
.TypeConstraint<float>("T")
.Build(),
SliceOp<DeviceType::GPU, float>);
SplitOp<DeviceType::GPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
.Device(DeviceType::GPU)
.TypeConstraint<half>("T")
.Build(),
SliceOp<DeviceType::GPU, half>);
SplitOp<DeviceType::GPU, half>);
#endif // MACE_ENABLE_OPENCL
}
......
......@@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_SLICE_H_
#define MACE_OPS_SLICE_H_
#ifndef MACE_OPS_SPLIT_H_
#define MACE_OPS_SPLIT_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/slice.h"
#include "mace/kernels/split.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class SliceOp : public Operator<D, T> {
class SplitOp : public Operator<D, T> {
public:
SliceOp(const OperatorDef &op_def, Workspace *ws)
SplitOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetOptionalArg<int>("axis", 3)) {}
......@@ -35,15 +35,15 @@ class SliceOp : public Operator<D, T> {
<< "There must be at least two outputs for slicing";
const Tensor *input = this->Input(INPUT);
const std::vector<Tensor *> output_list = this->Outputs();
const int32_t slice_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0)
const int32_t split_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
MACE_CHECK((input->dim(split_axis) % this->OutputSize()) == 0)
<< "Outputs do not split input equally.";
return functor_(input, output_list, future);
}
private:
kernels::SliceFunctor<D, T> functor_;
kernels::SplitFunctor<D, T> functor_;
private:
MACE_OP_INPUT_TAGS(INPUT);
......@@ -52,4 +52,4 @@ class SliceOp : public Operator<D, T> {
} // namespace ops
} // namespace mace
#endif // MACE_OPS_SLICE_H_
#endif // MACE_OPS_SPLIT_H_
......@@ -22,7 +22,7 @@ namespace test {
namespace {
template<DeviceType D, typename T>
void BMSliceHelper(int iters,
void BMSplitHelper(int iters,
const std::vector<index_t> &input_shape,
const index_t num_outputs) {
mace::testing::StopTiming();
......@@ -42,7 +42,7 @@ void BMSliceHelper(int iters,
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
auto builder = OpDefBuilder("Slice", "SliceTest");
auto builder = OpDefBuilder("Split", "SplitTest");
builder.Input("InputImage");
for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("OutputImage", i));
......@@ -51,7 +51,7 @@ void BMSliceHelper(int iters,
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
auto builder = OpDefBuilder("Slice", "SliceTest");
auto builder = OpDefBuilder("Split", "SplitTest");
builder.Input("Input");
for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("Output", i));
......@@ -73,28 +73,28 @@ void BMSliceHelper(int iters,
}
} // namespace
#define MACE_BM_SLICE_MACRO(N, H, W, C, NO, TYPE, DEVICE) \
#define MACE_BM_SPLIT_MACRO(N, H, W, C, NO, TYPE, DEVICE) \
static void \
MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \
MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \
mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
BMSliceHelper<DEVICE, TYPE>(iters, {N, H, W, C}, NO); \
BMSplitHelper<DEVICE, TYPE>(iters, {N, H, W, C}, NO); \
} \
MACE_BENCHMARK( \
MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)
MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)
#define MACE_BM_SLICE(N, H, W, C, NO) \
MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, CPU); \
MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, GPU); \
MACE_BM_SLICE_MACRO(N, H, W, C, NO, half, GPU);
#define MACE_BM_SPLIT(N, H, W, C, NO) \
MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, CPU); \
MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, GPU); \
MACE_BM_SPLIT_MACRO(N, H, W, C, NO, half, GPU);
MACE_BM_SLICE(1, 32, 32, 32, 2);
MACE_BM_SLICE(1, 32, 32, 128, 2);
MACE_BM_SLICE(1, 32, 32, 256, 2);
MACE_BM_SLICE(1, 128, 128, 32, 2);
MACE_BM_SLICE(1, 128, 128, 128, 2);
MACE_BM_SPLIT(1, 32, 32, 32, 2);
MACE_BM_SPLIT(1, 32, 32, 128, 2);
MACE_BM_SPLIT(1, 32, 32, 256, 2);
MACE_BM_SPLIT(1, 128, 128, 32, 2);
MACE_BM_SPLIT(1, 128, 128, 128, 2);
} // namespace test
} // namespace ops
......
......@@ -17,13 +17,13 @@
#include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h"
#include "mace/ops/slice.h"
#include "mace/ops/split.h"
namespace mace {
namespace ops {
namespace test {
class SliceOpTest : public OpsTestBase {};
class SplitOpTest : public OpsTestBase {};
namespace {
template <DeviceType D, typename T>
......@@ -53,7 +53,7 @@ void RandomTest(const int num_outputs, const int axis) {
BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL);
auto builder = OpDefBuilder("Slice", "SliceTest");
auto builder = OpDefBuilder("Split", "SplitTest");
builder.Input("InputImage");
for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("OutputImage", i));
......@@ -61,7 +61,7 @@ void RandomTest(const int num_outputs, const int axis) {
builder.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef());
} else {
auto builder = OpDefBuilder("Slice", "SliceTest").AddIntArg("axis", axis);
auto builder = OpDefBuilder("Split", "SplitTest").AddIntArg("axis", axis);
builder.Input("Input");
for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("Output", i));
......@@ -111,25 +111,25 @@ void RandomTest(const int num_outputs, const int axis) {
}
} // namespace
TEST_F(SliceOpTest, CPU) {
TEST_F(SplitOpTest, CPU) {
RandomTest<DeviceType::CPU, float>(2, 3);
RandomTest<DeviceType::CPU, float>(4, 3);
RandomTest<DeviceType::CPU, float>(11, 3);
}
TEST_F(SliceOpTest, CPUAxis1) {
TEST_F(SplitOpTest, CPUAxis1) {
RandomTest<DeviceType::CPU, float>(2, 1);
RandomTest<DeviceType::CPU, float>(4, 1);
RandomTest<DeviceType::CPU, float>(11, 1);
}
TEST_F(SliceOpTest, OPENCLFloat) {
TEST_F(SplitOpTest, OPENCLFloat) {
RandomTest<DeviceType::GPU, float>(2, 3);
RandomTest<DeviceType::GPU, float>(4, 3);
RandomTest<DeviceType::GPU, float>(11, 3);
}
TEST_F(SliceOpTest, OPENCLHalf) {
TEST_F(SplitOpTest, OPENCLHalf) {
RandomTest<DeviceType::GPU, half>(2, 3);
RandomTest<DeviceType::GPU, half>(4, 3);
RandomTest<DeviceType::GPU, half>(11, 3);
......
......@@ -146,6 +146,18 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank2) {
0, 3, {}, {6});
}
TEST_F(StridedSliceOpTest, TestStridedSliceRank3) {
TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{0, 0, 0}, {2, 3, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2, 2, 2},
{1, 2, 5, 6, 7, 8, 11, 12});
TestStridedSlice({3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6,
6, 6}, {1, 0, 0}, {2, 1, 3}, {1, 1, 1}, 0, 0, 0, 0, 0, {1,
1, 3}, {3, 3, 3});
TestStridedSlice({3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6,
6, 6}, {0, 0, 0}, {2, 2, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2,
1, 2}, {1, 1, 3, 3});
}
TEST_F(StridedSliceOpTest, TestSlice) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3},
{1, 2, 3, 4, 5, 6});
......
......@@ -88,6 +88,7 @@ MaceSupportedOps = [
'Dequantize',
'Eltwise',
'FoldedBatchNorm',
'Fill',
'FullyConnected',
'Gather',
'Identity',
......@@ -101,6 +102,7 @@ MaceSupportedOps = [
'Reshape',
'ResizeBilinear',
'Slice',
'Split',
'Shape',
'Squeeze',
'Stack',
......@@ -146,6 +148,7 @@ class MaceKeyword(object):
mace_constant_value_str = 'constant_value'
mace_dims_str = 'dims'
mace_axis_str = 'axis'
mace_num_split_str = 'num_split'
mace_keepdims_str = 'keepdims'
mace_shape_str = 'shape'
mace_winograd_filter_transformed = 'is_filter_transformed'
......
......@@ -68,6 +68,7 @@ TFSupportedOps = [
'Relu6',
'Tanh',
'Sigmoid',
'Fill',
'FusedBatchNorm',
'AvgPool',
'MaxPool',
......@@ -165,6 +166,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Relu6.name: self.convert_activation,
TFOpType.Tanh.name: self.convert_activation,
TFOpType.Sigmoid.name: self.convert_activation,
TFOpType.Fill.name: self.convert_fill,
TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm,
TFOpType.AvgPool.name: self.convert_pooling,
TFOpType.MaxPool.name: self.convert_pooling,
......@@ -458,6 +460,14 @@ class TensorflowConverter(base_converter.ConverterInterface):
limit_arg.name = MaceKeyword.mace_activation_max_limit_str
limit_arg.f = 6.0
def convert_fill(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.Fill.name
value_arg = op.arg.add()
value_arg.name = MaceKeyword.mace_value_str
value_arg.f = tf_op.inputs[1].eval()
def convert_fused_batchnorm(self, tf_op):
op = self.convert_general_op(tf_op)
op.type = MaceOp.FoldedBatchNorm.name
......@@ -763,19 +773,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
op.output_type.extend([mace_pb2.DT_INT32])
def convert_split(self, tf_op):
# inputs: [dim, input]
axis = tf_op.inputs[0].eval().astype(np.int32)
axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
mace_check(axis == 3, 'Split with %d axis only support' % axis)
input_shape = self.infer_tensor_shape(tf_op.inputs[1])
mace_check(len(input_shape) == 4 and (input_shape[3] % 4 == 0),
"The input's 4th dimension should be a multiple of 4")
op = self.convert_general_op(tf_op)
op.type = MaceOp.Slice.name
op.type = MaceOp.Split.name
del op.input[0]
axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = axis
num_split_arg = op.arg.add()
num_split_arg.name = MaceKeyword.mace_num_split_str
num_split_arg.i = tf_op.get_attr('num_split')
self._skip_tensor.add(tf_op.inputs[0].name)
......@@ -812,6 +812,7 @@ class Transformer(base_converter.ConverterInterface):
"only support concat at "
"channel dimension")
arg.i = 3
producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and \
......
......@@ -42,7 +42,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/slice.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/split.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl"))
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册