Commit 92f18fc6 authored by yejianwu

support tf basic lstm on cpu

Parent b77d694f
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_FILL_H_
#define MACE_KERNELS_FILL_H_

#include <algorithm>
#include <functional>
#include <vector>

#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"

namespace mace {
namespace kernels {

struct FillBase {
  explicit FillBase(float value) : value_(value) {}

  float value_;
};

template <DeviceType D, class T>
struct FillFunctor;

template <>
struct FillFunctor<DeviceType::CPU, float> : FillBase {
  explicit FillFunctor(float value) : FillBase(value) {}

  MaceStatus operator()(const Tensor *shape,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);
    MACE_CHECK(shape->dim_size() == 1) << "Shape must be 1-D";
    const index_t num_dims = shape->dim(0);
    Tensor::MappingGuard shape_guard(shape);
    const int32_t *shape_data = shape->data<int32_t>();

    std::vector<index_t> output_shape;
    for (index_t i = 0; i < num_dims; ++i) {
      MACE_CHECK(shape_data[i] > 0) << "Shape must be positive: "
                                    << shape_data[i];
      output_shape.push_back(shape_data[i]);
    }

    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
    Tensor::MappingGuard output_guard(output);
    float *output_data = output->mutable_data<float>();

    std::fill(output_data, output_data + output->size(), value_);

    return MACE_SUCCESS;
  }
};

}  // namespace kernels
}  // namespace mace

#endif  // MACE_KERNELS_FILL_H_
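Aside: the functor above reduces to "validate a 1-D shape tensor, resize, fill". A dependency-free sketch of the same semantics, using plain std::vector instead of MACE's Tensor (FillLike is an illustrative name, not part of this commit):

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Sketch of Fill: every dimension must be positive, the output holds
// prod(shape) elements, and each element is set to the constant value.
std::vector<float> FillLike(const std::vector<int32_t> &shape, float value) {
  int64_t size = 1;
  for (int32_t dim : shape) {
    if (dim <= 0) throw std::invalid_argument("Shape must be positive");
    size *= dim;
  }
  std::vector<float> output(static_cast<size_t>(size));
  std::fill(output.begin(), output.end(), value);
  return output;
}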
 #include <common.h>

-__kernel void slice(KERNEL_ERROR_PARAMS
+__kernel void split(KERNEL_ERROR_PARAMS
                     GLOBAL_WORK_GROUP_SIZE_DIM3
                     __read_only image2d_t input,
                     __private const int chan_blk_offset,
......
@@ -12,7 +12,7 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "mace/kernels/slice.h"
+#include "mace/kernels/split.h"

 #include "mace/core/runtime/opencl/opencl_runtime.h"
 #include "mace/kernels/opencl/helper.h"
 #include "mace/utils/tuner.h"
@@ -21,7 +21,7 @@ namespace mace {
 namespace kernels {

 template <typename T>
-MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
+MaceStatus SplitFunctor<DeviceType::GPU, T>::operator()(
     const Tensor *input,
     const std::vector<Tensor *> &output_list,
     StatsFuture *future) {
@@ -29,7 +29,7 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
   const size_t outputs_count = output_list.size();
   const index_t output_channels = input_channels / outputs_count;
   MACE_CHECK(output_channels % 4 == 0)
-      << "output channels of slice op must be divisible by 4";
+      << "output channels of split op must be divisible by 4";
   std::vector<index_t> output_shape(
       {input->dim(0), input->dim(1), input->dim(2), output_channels});

@@ -46,12 +46,12 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
     std::set<std::string> built_options;
     OUT_OF_RANGE_CONFIG(kernel_error_);
     NON_UNIFORM_WG_CONFIG;
-    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice");
-    built_options.emplace("-Dslice=" + kernel_name);
+    std::string kernel_name = MACE_OBFUSCATE_SYMBOL("split");
+    built_options.emplace("-Dsplit=" + kernel_name);
     built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
     built_options.emplace("-DCMD_DATA_TYPE=" +
                           DtToCLCMDDt(DataTypeToEnum<T>::value));
-    MACE_RETURN_IF_ERROR(runtime->BuildKernel("slice",
+    MACE_RETURN_IF_ERROR(runtime->BuildKernel("split",
                                               kernel_name,
                                               built_options,
                                               &kernel_));
@@ -116,8 +116,8 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
   return MACE_SUCCESS;
 }

-template struct SliceFunctor<DeviceType::GPU, float>;
-template struct SliceFunctor<DeviceType::GPU, half>;
+template struct SplitFunctor<DeviceType::GPU, float>;
+template struct SplitFunctor<DeviceType::GPU, half>;

 }  // namespace kernels
 }  // namespace mace
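Aside: the divisible-by-4 check exists because the GPU path stores tensors as RGBA image2d_t, i.e. channels are packed four per texel, so each output must begin on a whole channel block. A tiny sketch of that blocking arithmetic (RoundUpDiv4 and ChannelBlockOffset are illustrative helpers, not MACE's API):

#include <cstdint>

// Channels are packed 4 per texel, so output n starts n * (out_c / 4)
// channel blocks into the input image; an out_c that is not a multiple of 4
// would put an output boundary in the middle of a texel.
inline int64_t RoundUpDiv4(int64_t v) { return (v + 3) / 4; }

inline int64_t ChannelBlockOffset(int64_t output_index, int64_t out_c) {
  return output_index * RoundUpDiv4(out_c);
}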
@@ -12,8 +12,8 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#ifndef MACE_KERNELS_SLICE_H_
-#define MACE_KERNELS_SLICE_H_
+#ifndef MACE_KERNELS_SPLIT_H_
+#define MACE_KERNELS_SPLIT_H_

 #include <memory>
 #include <functional>
@@ -31,15 +31,15 @@
 namespace mace {
 namespace kernels {

-struct SliceFunctorBase {
-  explicit SliceFunctorBase(const int32_t axis) : axis_(axis) {}
+struct SplitFunctorBase {
+  explicit SplitFunctorBase(const int32_t axis) : axis_(axis) {}

   int32_t axis_;
 };

 template<DeviceType D, typename T>
-struct SliceFunctor : SliceFunctorBase {
-  explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {}
+struct SplitFunctor : SplitFunctorBase {
+  explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {}

   MaceStatus operator()(const Tensor *input,
                         const std::vector<Tensor *> &output_list,
@@ -89,8 +89,8 @@ struct SliceFunctor : SliceFunctorBase {

 #ifdef MACE_ENABLE_OPENCL
 template<typename T>
-struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase {
-  explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {}
+struct SplitFunctor<DeviceType::GPU, T> : SplitFunctorBase {
+  explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {}

   MaceStatus operator()(const Tensor *input,
                         const std::vector<Tensor *> &output_list,
@@ -104,4 +104,4 @@ struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase {
 }  // namespace kernels
 }  // namespace mace

-#endif  // MACE_KERNELS_SLICE_H_
+#endif  // MACE_KERNELS_SPLIT_H_
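Aside: what SplitFunctor computes for the default axis (the channel dimension of an NHWC tensor) can be sketched without MACE types: partition the last dimension into equal slices. SplitChannels and its parameters below are illustrative only:

#include <cstdint>
#include <vector>

// Split an [outer, channels] row-major buffer (outer = N*H*W) into
// num_outputs buffers of shape [outer, channels / num_outputs].
std::vector<std::vector<float>> SplitChannels(const std::vector<float> &input,
                                              int64_t outer, int64_t channels,
                                              int64_t num_outputs) {
  const int64_t out_c = channels / num_outputs;  // checked divisible upstream
  std::vector<std::vector<float>> outputs(
      num_outputs, std::vector<float>(outer * out_c));
  for (int64_t o = 0; o < outer; ++o)
    for (int64_t n = 0; n < num_outputs; ++n)
      for (int64_t c = 0; c < out_c; ++c)
        outputs[n][o * out_c + c] = input[o * channels + n * out_c + c];
  return outputs;
}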
@@ -169,7 +169,6 @@ struct StridedSliceFunctor {
           i += strides_data[0]) {
        *output_data++ = input_data[i];
      }
-
    } else if (input->dim_size() == 2) {
      for (index_t i = real_begin_indices[0];
           strides_data[0] > 0 ? i < real_end_indices[0]
@@ -179,7 +178,24 @@ struct StridedSliceFunctor {
            strides_data[1] > 0 ? j < real_end_indices[1]
                                : j > real_end_indices[1];
            j += strides_data[1]) {
-        *output_data++ = input_data[i * dim_stride[0] + j];
+        *output_data++ = input_data[i * input->dim(1) + j];
+      }
+    }
+  } else if (input->dim_size() == 3) {
+    for (index_t i = real_begin_indices[0];
+         strides_data[0] > 0 ? i < real_end_indices[0]
+                             : i > real_end_indices[0];
+         i += strides_data[0]) {
+      for (index_t j = real_begin_indices[1];
+           strides_data[1] > 0 ? j < real_end_indices[1]
+                               : j > real_end_indices[1];
+           j += strides_data[1]) {
+        for (index_t k = real_begin_indices[2];
+             strides_data[2] > 0 ? k < real_end_indices[2]
+                                 : k > real_end_indices[2];
+             k += strides_data[2]) {
+          *output_data++ = input_data[(i * input->dim(1) + j) * input->dim(2) + k];
+        }
      }
    }
  } else {
......
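Aside: the new rank-3 branch repeats the established pattern, one loop per dimension walking begin/end/strides, with the row-major offset (i * dim1 + j) * dim2 + k. The same logic as a standalone function (a sketch under the names used here, not MACE code):

#include <cstdint>
#include <vector>

// Rank-3 strided slice over a row-major [d0, d1, d2] buffer; mirrors the
// nested loops added above, including negative strides (end is exclusive
// in the direction of travel).
std::vector<float> StridedSlice3D(const std::vector<float> &input,
                                  int64_t d1, int64_t d2,
                                  const int64_t begin[3], const int64_t end[3],
                                  const int64_t strides[3]) {
  std::vector<float> output;
  for (int64_t i = begin[0]; strides[0] > 0 ? i < end[0] : i > end[0];
       i += strides[0])
    for (int64_t j = begin[1]; strides[1] > 0 ? j < end[1] : j > end[1];
         j += strides[1])
      for (int64_t k = begin[2]; strides[2] > 0 ? k < end[2] : k > end[2];
           k += strides[2])
        output.push_back(input[(i * d1 + j) * d2 + k]);
  return output;
}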
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/fill.h"
namespace mace {
namespace ops {
void Register_Fill(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Fill")
.Device(DeviceType::CPU)
.TypeConstraint<float>("T")
.Build(),
FillOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_FILL_H_
#define MACE_OPS_FILL_H_

#include <vector>

#include "mace/core/operator.h"
#include "mace/kernels/fill.h"

namespace mace {
namespace ops {

template <DeviceType D, class T>
class FillOp : public Operator<D, T> {
 public:
  FillOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_(OperatorBase::GetOptionalArg<float>("value", 0.0f)) {}

  MaceStatus Run(StatsFuture *future) override {
    const Tensor *shape = this->Input(SHAPE);
    Tensor *output = this->Output(OUTPUT);
    return functor_(shape, output, future);
  }

 private:
  kernels::FillFunctor<D, T> functor_;

  MACE_OP_INPUT_TAGS(SHAPE);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};

}  // namespace ops
}  // namespace mace

#endif  // MACE_OPS_FILL_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
class FillTest : public OpsTestBase {};
namespace {
void TestFill(const std::vector<int32_t> &shape,
const float &value) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Fill", "FillTest")
.Input("Shape")
.AddFloatArg("value", static_cast<float>(value))
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, int32_t>(
"Shape",
{static_cast<index_t>(shape.size())},
shape);
// Run
net.RunOp();
auto output = net.GetTensor("Output");
for (index_t i = 0; i < output->dim_size(); ++i) {
ASSERT_EQ(output->dim(i), shape[i]);
}
const float *output_ptr = output->data<float>();
const index_t size = output->size();
for (index_t i = 0; i < size; ++i) {
ASSERT_EQ(output_ptr[i], value);
}
}
} // namespace
TEST_F(FillTest, Simple) {
TestFill({3, 2, 1}, 5.0f);
TestFill({1, 3}, -1.0f);
}
} // namespace test
} // namespace ops
} // namespace mace
@@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
 extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry);
 extern void Register_Dequantize(OperatorRegistryBase *op_registry);
 extern void Register_Eltwise(OperatorRegistryBase *op_registry);
+extern void Register_Fill(OperatorRegistryBase *op_registry);
 extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
 extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
 extern void Register_Gather(OperatorRegistryBase *op_registry);
@@ -48,7 +49,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
 extern void Register_Reshape(OperatorRegistryBase *op_registry);
 extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
 extern void Register_Shape(OperatorRegistryBase *op_registry);
-extern void Register_Slice(OperatorRegistryBase *op_registry);
+extern void Register_Split(OperatorRegistryBase *op_registry);
 extern void Register_Softmax(OperatorRegistryBase *op_registry);
 extern void Register_Stack(OperatorRegistryBase *op_registry);
 extern void Register_StridedSlice(OperatorRegistryBase *op_registry);
@@ -84,6 +85,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
   ops::Register_DepthwiseConv2d(this);
   ops::Register_Dequantize(this);
   ops::Register_Eltwise(this);
+  ops::Register_Fill(this);
   ops::Register_FoldedBatchNorm(this);
   ops::Register_FullyConnected(this);
   ops::Register_Gather(this);
@@ -98,7 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
   ops::Register_Reshape(this);
   ops::Register_ResizeBilinear(this);
   ops::Register_Shape(this);
-  ops::Register_Slice(this);
+  ops::Register_Split(this);
   ops::Register_Softmax(this);
   ops::Register_Stack(this);
   ops::Register_StridedSlice(this);
......
@@ -12,30 +12,30 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#include "mace/ops/slice.h"
+#include "mace/ops/split.h"

 namespace mace {
 namespace ops {

-void Register_Slice(OperatorRegistryBase *op_registry) {
-  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
+void Register_Split(OperatorRegistryBase *op_registry) {
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
                                           .Device(DeviceType::CPU)
                                           .TypeConstraint<float>("T")
                                           .Build(),
-                         SliceOp<DeviceType::CPU, float>);
+                         SplitOp<DeviceType::CPU, float>);

 #ifdef MACE_ENABLE_OPENCL
-  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
                                           .Device(DeviceType::GPU)
                                           .TypeConstraint<float>("T")
                                           .Build(),
-                         SliceOp<DeviceType::GPU, float>);
+                         SplitOp<DeviceType::GPU, float>);

-  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice")
+  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
                                           .Device(DeviceType::GPU)
                                           .TypeConstraint<half>("T")
                                           .Build(),
-                         SliceOp<DeviceType::GPU, half>);
+                         SplitOp<DeviceType::GPU, half>);
 #endif  // MACE_ENABLE_OPENCL
 }
......
@@ -12,21 +12,21 @@
 // See the License for the specific language governing permissions and
 // limitations under the License.

-#ifndef MACE_OPS_SLICE_H_
-#define MACE_OPS_SLICE_H_
+#ifndef MACE_OPS_SPLIT_H_
+#define MACE_OPS_SPLIT_H_

 #include <vector>

 #include "mace/core/operator.h"
-#include "mace/kernels/slice.h"
+#include "mace/kernels/split.h"

 namespace mace {
 namespace ops {

 template <DeviceType D, typename T>
-class SliceOp : public Operator<D, T> {
+class SplitOp : public Operator<D, T> {
  public:
-  SliceOp(const OperatorDef &op_def, Workspace *ws)
+  SplitOp(const OperatorDef &op_def, Workspace *ws)
       : Operator<D, T>(op_def, ws),
         functor_(OperatorBase::GetOptionalArg<int>("axis", 3)) {}
@@ -35,15 +35,15 @@ class SliceOp : public Operator<D, T> {
         << "There must be at least two outputs for slicing";
     const Tensor *input = this->Input(INPUT);
     const std::vector<Tensor *> output_list = this->Outputs();
-    const int32_t slice_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
-    MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0)
+    const int32_t split_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
+    MACE_CHECK((input->dim(split_axis) % this->OutputSize()) == 0)
         << "Outputs do not split input equally.";

     return functor_(input, output_list, future);
   }

  private:
-  kernels::SliceFunctor<D, T> functor_;
+  kernels::SplitFunctor<D, T> functor_;

  private:
   MACE_OP_INPUT_TAGS(INPUT);
@@ -52,4 +52,4 @@ class SliceOp : public Operator<D, T> {
 }  // namespace ops
 }  // namespace mace

-#endif  // MACE_OPS_SLICE_H_
+#endif  // MACE_OPS_SPLIT_H_
@@ -22,7 +22,7 @@ namespace test {
 namespace {

 template<DeviceType D, typename T>
-void BMSliceHelper(int iters,
+void BMSplitHelper(int iters,
                    const std::vector<index_t> &input_shape,
                    const index_t num_outputs) {
   mace::testing::StopTiming();
@@ -42,7 +42,7 @@ void BMSliceHelper(int iters,
     BufferToImage<D, T>(&net, "Input", "InputImage",
                         kernels::BufferType::IN_OUT_CHANNEL);

-    auto builder = OpDefBuilder("Slice", "SliceTest");
+    auto builder = OpDefBuilder("Split", "SplitTest");
     builder.Input("InputImage");
     for (int i = 0; i < num_outputs; ++i) {
       builder = builder.Output(MakeString("OutputImage", i));
@@ -51,7 +51,7 @@ void BMSliceHelper(int iters,
         .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
   } else {
-    auto builder = OpDefBuilder("Slice", "SliceTest");
+    auto builder = OpDefBuilder("Split", "SplitTest");
     builder.Input("Input");
     for (int i = 0; i < num_outputs; ++i) {
       builder = builder.Output(MakeString("Output", i));
@@ -73,28 +73,28 @@ void BMSliceHelper(int iters,
 }
 }  // namespace

-#define MACE_BM_SLICE_MACRO(N, H, W, C, NO, TYPE, DEVICE) \
+#define MACE_BM_SPLIT_MACRO(N, H, W, C, NO, TYPE, DEVICE) \
   static void \
-      MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \
+      MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \
       int iters) { \
     const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \
     mace::testing::MaccProcessed(tot); \
     mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
-    BMSliceHelper<DEVICE, TYPE>(iters, {N, H, W, C}, NO); \
+    BMSplitHelper<DEVICE, TYPE>(iters, {N, H, W, C}, NO); \
   } \
   MACE_BENCHMARK( \
-      MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)
+      MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)

-#define MACE_BM_SLICE(N, H, W, C, NO) \
-  MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, CPU); \
-  MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, GPU); \
-  MACE_BM_SLICE_MACRO(N, H, W, C, NO, half, GPU);
+#define MACE_BM_SPLIT(N, H, W, C, NO) \
+  MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, CPU); \
+  MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, GPU); \
+  MACE_BM_SPLIT_MACRO(N, H, W, C, NO, half, GPU);

-MACE_BM_SLICE(1, 32, 32, 32, 2);
-MACE_BM_SLICE(1, 32, 32, 128, 2);
-MACE_BM_SLICE(1, 32, 32, 256, 2);
-MACE_BM_SLICE(1, 128, 128, 32, 2);
-MACE_BM_SLICE(1, 128, 128, 128, 2);
+MACE_BM_SPLIT(1, 32, 32, 32, 2);
+MACE_BM_SPLIT(1, 32, 32, 128, 2);
+MACE_BM_SPLIT(1, 32, 32, 256, 2);
+MACE_BM_SPLIT(1, 128, 128, 32, 2);
+MACE_BM_SPLIT(1, 128, 128, 128, 2);

 }  // namespace test
 }  // namespace ops
......
@@ -17,13 +17,13 @@
 #include "gmock/gmock.h"
 #include "mace/ops/ops_test_util.h"
-#include "mace/ops/slice.h"
+#include "mace/ops/split.h"

 namespace mace {
 namespace ops {
 namespace test {

-class SliceOpTest : public OpsTestBase {};
+class SplitOpTest : public OpsTestBase {};

 namespace {

 template <DeviceType D, typename T>
@@ -53,7 +53,7 @@ void RandomTest(const int num_outputs, const int axis) {
     BufferToImage<D, T>(&net, "Input", "InputImage",
                         kernels::BufferType::IN_OUT_CHANNEL);

-    auto builder = OpDefBuilder("Slice", "SliceTest");
+    auto builder = OpDefBuilder("Split", "SplitTest");
     builder.Input("InputImage");
     for (int i = 0; i < num_outputs; ++i) {
       builder = builder.Output(MakeString("OutputImage", i));
@@ -61,7 +61,7 @@ void RandomTest(const int num_outputs, const int axis) {
     builder.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
         .Finalize(net.NewOperatorDef());
   } else {
-    auto builder = OpDefBuilder("Slice", "SliceTest").AddIntArg("axis", axis);
+    auto builder = OpDefBuilder("Split", "SplitTest").AddIntArg("axis", axis);
     builder.Input("Input");
     for (int i = 0; i < num_outputs; ++i) {
       builder = builder.Output(MakeString("Output", i));
@@ -111,25 +111,25 @@ void RandomTest(const int num_outputs, const int axis) {
 }
 }  // namespace

-TEST_F(SliceOpTest, CPU) {
+TEST_F(SplitOpTest, CPU) {
   RandomTest<DeviceType::CPU, float>(2, 3);
   RandomTest<DeviceType::CPU, float>(4, 3);
   RandomTest<DeviceType::CPU, float>(11, 3);
 }

-TEST_F(SliceOpTest, CPUAxis1) {
+TEST_F(SplitOpTest, CPUAxis1) {
   RandomTest<DeviceType::CPU, float>(2, 1);
   RandomTest<DeviceType::CPU, float>(4, 1);
   RandomTest<DeviceType::CPU, float>(11, 1);
 }

-TEST_F(SliceOpTest, OPENCLFloat) {
+TEST_F(SplitOpTest, OPENCLFloat) {
   RandomTest<DeviceType::GPU, float>(2, 3);
   RandomTest<DeviceType::GPU, float>(4, 3);
   RandomTest<DeviceType::GPU, float>(11, 3);
 }

-TEST_F(SliceOpTest, OPENCLHalf) {
+TEST_F(SplitOpTest, OPENCLHalf) {
   RandomTest<DeviceType::GPU, half>(2, 3);
   RandomTest<DeviceType::GPU, half>(4, 3);
   RandomTest<DeviceType::GPU, half>(11, 3);
@@ -146,6 +146,18 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank2) {
                    0, 3, {}, {6});
 }

+TEST_F(StridedSliceOpTest, TestStridedSliceRank3) {
+  TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
+                   {0, 0, 0}, {2, 3, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2, 2, 2},
+                   {1, 2, 5, 6, 7, 8, 11, 12});
+  TestStridedSlice({3, 2, 3},
+                   {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6},
+                   {1, 0, 0}, {2, 1, 3}, {1, 1, 1}, 0, 0, 0, 0, 0,
+                   {1, 1, 3}, {3, 3, 3});
+  TestStridedSlice({3, 2, 3},
+                   {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6, 6, 6},
+                   {0, 0, 0}, {2, 2, 2}, {1, 2, 1}, 0, 0, 0, 0, 0,
+                   {2, 1, 2}, {1, 1, 3, 3});
+}
+
 TEST_F(StridedSliceOpTest, TestSlice) {
   TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3},
             {1, 2, 3, 4, 5, 6});
......
@@ -88,6 +88,7 @@ MaceSupportedOps = [
     'Dequantize',
     'Eltwise',
     'FoldedBatchNorm',
+    'Fill',
     'FullyConnected',
     'Gather',
     'Identity',
@@ -101,6 +102,7 @@ MaceSupportedOps = [
     'Reshape',
     'ResizeBilinear',
     'Slice',
+    'Split',
     'Shape',
     'Squeeze',
     'Stack',
@@ -146,6 +148,7 @@ class MaceKeyword(object):
     mace_constant_value_str = 'constant_value'
     mace_dims_str = 'dims'
     mace_axis_str = 'axis'
+    mace_num_split_str = 'num_split'
     mace_keepdims_str = 'keepdims'
     mace_shape_str = 'shape'
     mace_winograd_filter_transformed = 'is_filter_transformed'
......
@@ -68,6 +68,7 @@ TFSupportedOps = [
     'Relu6',
     'Tanh',
     'Sigmoid',
+    'Fill',
     'FusedBatchNorm',
     'AvgPool',
     'MaxPool',
@@ -165,6 +166,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
             TFOpType.Relu6.name: self.convert_activation,
             TFOpType.Tanh.name: self.convert_activation,
             TFOpType.Sigmoid.name: self.convert_activation,
+            TFOpType.Fill.name: self.convert_fill,
             TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm,
             TFOpType.AvgPool.name: self.convert_pooling,
             TFOpType.MaxPool.name: self.convert_pooling,
@@ -458,6 +460,14 @@ class TensorflowConverter(base_converter.ConverterInterface):
             limit_arg.name = MaceKeyword.mace_activation_max_limit_str
             limit_arg.f = 6.0

+    def convert_fill(self, tf_op):
+        op = self.convert_general_op(tf_op)
+        op.type = MaceOp.Fill.name
+
+        value_arg = op.arg.add()
+        value_arg.name = MaceKeyword.mace_value_str
+        value_arg.f = tf_op.inputs[1].eval()
+
     def convert_fused_batchnorm(self, tf_op):
         op = self.convert_general_op(tf_op)
         op.type = MaceOp.FoldedBatchNorm.name
@@ -763,19 +773,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
         op.output_type.extend([mace_pb2.DT_INT32])

     def convert_split(self, tf_op):
+        # inputs: [dim, input]
         axis = tf_op.inputs[0].eval().astype(np.int32)
         axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
-        mace_check(axis == 3, 'Split with %d axis only support' % axis)

         input_shape = self.infer_tensor_shape(tf_op.inputs[1])
-        mace_check(len(input_shape) == 4 and (input_shape[3] % 4 == 0),
-                   "The input's 4th dimension should be a multiple of 4")

         op = self.convert_general_op(tf_op)
-        op.type = MaceOp.Slice.name
+        op.type = MaceOp.Split.name
         del op.input[0]

         axis_arg = op.arg.add()
         axis_arg.name = MaceKeyword.mace_axis_str
         axis_arg.i = axis
+
+        num_split_arg = op.arg.add()
+        num_split_arg.name = MaceKeyword.mace_num_split_str
+        num_split_arg.i = tf_op.get_attr('num_split')

         self._skip_tensor.add(tf_op.inputs[0].name)
@@ -812,6 +812,7 @@ class Transformer(base_converter.ConverterInterface):
                            "only support concat at "
                            "channel dimension")
                 arg.i = 3
+
             producer = self._producer[op.input[0]]
             input_shape = producer.output_shape[0].dims
             if producer.type == MaceOp.FullyConnected.name and \
......
@@ -42,7 +42,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl"))
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl"))
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl"))
-    unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/slice.cl"))
+    unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/split.cl"))
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl"))
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl"))
     unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl"))
......