“46ed40a8aef7eb57f5740206aff6faf1aad61446”上不存在“mobile/src/operators/kernel/compare_kernel.h”
提交 70a41900 编写于 作者: 李寅

Merge branch 'support_basiclstm_cpu' into 'master'

Support basiclstm cpu

See merge request !733
...@@ -82,6 +82,8 @@ in one deployment file. ...@@ -82,6 +82,8 @@ in one deployment file.
- The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU. - The running device, one of [cpu, gpu, dsp, cpu_gpu]. cpu_gpu contains CPU and GPU model definition so you can run the model on both CPU and GPU.
* - data_type * - data_type
- [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU, default is fp16_fp32, [fp32] for CPU and [uint8] for DSP. - [optional] The data type used for specified runtime. [fp16_fp32, fp32_fp32] for GPU, default is fp16_fp32, [fp32] for CPU and [uint8] for DSP.
* - input_data_types
- [optional] The input data types for specific ops (e.g. gather), each of which can be one of [int32, float32]; defaults to float32.
* - limit_opencl_kernel_time * - limit_opencl_kernel_time
- [optional] Whether splitting the OpenCL kernel within 1 ms to keep UI responsiveness, default is 0. - [optional] Whether splitting the OpenCL kernel within 1 ms to keep UI responsiveness, default is 0.
* - obfuscate * - obfuscate
......
...@@ -12,7 +12,7 @@ Operator lists ...@@ -12,7 +12,7 @@ Operator lists
"BIAS_ADD","Y","" "BIAS_ADD","Y",""
"CAST","Y","Only CPU and TensorFlow model is supported." "CAST","Y","Only CPU and TensorFlow model is supported."
"CHANNEL_SHUFFLE","Y","" "CHANNEL_SHUFFLE","Y",""
"CONCATENATION","Y","Only support channel axis concatenation." "CONCATENATION","Y","For GPU only support channel axis concatenation."
"CONV_2D","Y","Fusion with BN and activation layer is supported." "CONV_2D","Y","Fusion with BN and activation layer is supported."
"CROP","Y","Only Caffe's crop layer is supported (in GPU, offset on channel-dim should be dividable by 4)." "CROP","Y","Only Caffe's crop layer is supported (in GPU, offset on channel-dim should be dividable by 4)."
"DECONV_2D","Y","Supports Caffe's Deconvolution and TensorFlow's tf.layers.conv2d_transpose." "DECONV_2D","Y","Supports Caffe's Deconvolution and TensorFlow's tf.layers.conv2d_transpose."
...@@ -20,7 +20,7 @@ Operator lists ...@@ -20,7 +20,7 @@ Operator lists
"DEPTH_TO_SPACE","Y","" "DEPTH_TO_SPACE","Y",""
"DEQUANTIZE","Y","Model quantization will be supported later." "DEQUANTIZE","Y","Model quantization will be supported later."
"ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL" "ELEMENT_WISE","Y","ADD/MUL/DIV/MIN/MAX/NEG/ABS/SQR_DIFF/POW/RSQRT/EQUAL"
"EMBEDDING_LOOKUP","Y","Only support channel axis concatenation." "EMBEDDING_LOOKUP","Y",""
"FULLY_CONNECTED","Y","" "FULLY_CONNECTED","Y",""
"GROUP_CONV_2D","","Caffe model with group count = channel count is supported." "GROUP_CONV_2D","","Caffe model with group count = channel count is supported."
"IDENTITY","Y","Only TensorFlow model is supported." "IDENTITY","Y","Only TensorFlow model is supported."
...@@ -44,7 +44,7 @@ Operator lists ...@@ -44,7 +44,7 @@ Operator lists
"SHAPE","Y","Only CPU and TensorFlow is supported." "SHAPE","Y","Only CPU and TensorFlow is supported."
"STACK","Y","Only CPU and TensorFlow is supported." "STACK","Y","Only CPU and TensorFlow is supported."
"STRIDEDSLICE","Y","Only CPU and TensorFlow is supported." "STRIDEDSLICE","Y","Only CPU and TensorFlow is supported."
"SLICE","Y","In TensorFlow, this op is equivalent to SPLIT; Only support channel axis slice." "SPLIT","Y","In Caffe, this op is equivalent to SLICE; For GPU only support channel axis slice."
"SOFTMAX","Y","" "SOFTMAX","Y",""
"SPACE_TO_BATCH_ND", "Y","" "SPACE_TO_BATCH_ND", "Y",""
"SPACE_TO_DEPTH","Y","" "SPACE_TO_DEPTH","Y",""
......
...@@ -70,7 +70,7 @@ MaceStatus OpenCLAllocator::New(size_t nbytes, void **result) const { ...@@ -70,7 +70,7 @@ MaceStatus OpenCLAllocator::New(size_t nbytes, void **result) const {
MaceStatus OpenCLAllocator::NewImage(const std::vector<size_t> &image_shape, MaceStatus OpenCLAllocator::NewImage(const std::vector<size_t> &image_shape,
const DataType dt, const DataType dt,
void **result) const { void **result) const {
MACE_CHECK(image_shape.size() == 2) << "Image shape's size must equal 2"; MACE_CHECK(image_shape.size() == 2, "Image shape's size must equal 2");
VLOG(3) << "Allocate OpenCL image: " << image_shape[0] << ", " VLOG(3) << "Allocate OpenCL image: " << image_shape[0] << ", "
<< image_shape[1]; << image_shape[1];
...@@ -134,7 +134,7 @@ void *OpenCLAllocator::Map(void *buffer, size_t offset, size_t nbytes) const { ...@@ -134,7 +134,7 @@ void *OpenCLAllocator::Map(void *buffer, size_t offset, size_t nbytes) const {
void *OpenCLAllocator::MapImage(void *buffer, void *OpenCLAllocator::MapImage(void *buffer,
const std::vector<size_t> &image_shape, const std::vector<size_t> &image_shape,
std::vector<size_t> *mapped_image_pitch) const { std::vector<size_t> *mapped_image_pitch) const {
MACE_CHECK(image_shape.size() == 2) << "Just support map 2d image"; MACE_CHECK(image_shape.size() == 2, "Just support map 2d image");
auto cl_image = static_cast<cl::Image2D *>(buffer); auto cl_image = static_cast<cl::Image2D *>(buffer);
std::array<size_t, 3> origin = {0, 0, 0}; std::array<size_t, 3> origin = {0, 0, 0};
std::array<size_t, 3> region = {image_shape[0], image_shape[1], 1}; std::array<size_t, 3> region = {image_shape[0], image_shape[1], 1};
......
...@@ -39,7 +39,7 @@ std::string DataTypeToString(const DataType dt) { ...@@ -39,7 +39,7 @@ std::string DataTypeToString(const DataType dt) {
#endif #endif
{DT_UINT8, "DT_UINT8"}, {DT_UINT8, "DT_UINT8"},
{DT_INT32, "DT_UINT32"}}; {DT_INT32, "DT_UINT32"}};
MACE_CHECK(dt != DT_INVALID) << "Not support Invalid data type"; MACE_CHECK(dt != DT_INVALID, "Not support Invalid data type");
return dtype_string_map[dt]; return dtype_string_map[dt];
} }
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_KERNELS_FILL_H_
#define MACE_KERNELS_FILL_H_
#include <algorithm>
#include <functional>
#include <vector>
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/public/mace.h"
namespace mace {
namespace kernels {
// Primary template of the Fill kernel functor, parameterized on the device
// type and the element type. It is only declared here; a usable body is
// provided by explicit specialization.
template <DeviceType D, class T>
struct FillFunctor;
// CPU/float specialization of the Fill kernel: resizes `output` to the
// dimensions listed in `shape` and sets every element to the first element
// of `value`.
template <>
struct FillFunctor<DeviceType::CPU, float> {
  FillFunctor() {}

  // shape:  1-D int32 tensor; entry i is the size of output dimension i and
  //         must be strictly positive.
  // value:  float tensor; its first element is used as the fill value.
  // output: resized to the requested shape, then filled.
  // future: unused by this implementation.
  // Returns MACE_SUCCESS, or the error propagated from output->Resize().
  MaceStatus operator()(const Tensor *shape,
                        const Tensor *value,
                        Tensor *output,
                        StatsFuture *future) {
    MACE_UNUSED(future);
    MACE_CHECK(shape->dim_size() == 1, "Shape must be 1-D");
    // Guard the dereference of the value buffer below.
    MACE_CHECK(value->size() > 0, "Value must not be empty");

    const index_t num_dims = shape->dim(0);
    Tensor::MappingGuard shape_guard(shape);
    const int32_t *shape_data = shape->data<int32_t>();

    std::vector<index_t> output_shape;
    output_shape.reserve(static_cast<size_t>(num_dims));
    for (index_t i = 0; i < num_dims; ++i) {
      // The condition rejects zero as well as negatives, so the message says
      // "positive" rather than "non-negative".
      MACE_CHECK(shape_data[i] > 0, "Shape must be positive: ",
                 shape_data[i]);
      output_shape.push_back(shape_data[i]);
    }

    Tensor::MappingGuard value_guard(value);
    const float *value_data = value->data<float>();

    MACE_RETURN_IF_ERROR(output->Resize(output_shape));
    // Map the output only after the resize so the guard covers the buffer
    // that is actually written.
    Tensor::MappingGuard output_guard(output);
    float *output_data = output->mutable_data<float>();
    std::fill(output_data, output_data + output->size(), *value_data);

    return MACE_SUCCESS;
  }
};
} // namespace kernels
} // namespace mace
#endif // MACE_KERNELS_FILL_H_
#include <common.h> #include <common.h>
__kernel void slice(KERNEL_ERROR_PARAMS __kernel void split(KERNEL_ERROR_PARAMS
GLOBAL_WORK_GROUP_SIZE_DIM3 GLOBAL_WORK_GROUP_SIZE_DIM3
__read_only image2d_t input, __read_only image2d_t input,
__private const int chan_blk_offset, __private const int chan_blk_offset,
......
...@@ -58,7 +58,7 @@ namespace kernels { ...@@ -58,7 +58,7 @@ namespace kernels {
if (runtime->IsOutOfRangeCheckEnabled()) { \ if (runtime->IsOutOfRangeCheckEnabled()) { \
(kernel_error)->Map(nullptr); \ (kernel_error)->Map(nullptr); \
char *kerror_code = (kernel_error)->mutable_data<char>(); \ char *kerror_code = (kernel_error)->mutable_data<char>(); \
MACE_CHECK(*kerror_code == 0) << "Kernel error code: " << *kerror_code;\ MACE_CHECK(*kerror_code == 0, "Kernel error code: ", *kerror_code);\
(kernel_error)->UnMap(); \ (kernel_error)->UnMap(); \
} }
......
...@@ -12,7 +12,7 @@ ...@@ -12,7 +12,7 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/kernels/slice.h" #include "mace/kernels/split.h"
#include "mace/core/runtime/opencl/opencl_runtime.h" #include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/kernels/opencl/helper.h" #include "mace/kernels/opencl/helper.h"
#include "mace/utils/tuner.h" #include "mace/utils/tuner.h"
...@@ -21,7 +21,7 @@ namespace mace { ...@@ -21,7 +21,7 @@ namespace mace {
namespace kernels { namespace kernels {
template <typename T> template <typename T>
MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()( MaceStatus SplitFunctor<DeviceType::GPU, T>::operator()(
const Tensor *input, const Tensor *input,
const std::vector<Tensor *> &output_list, const std::vector<Tensor *> &output_list,
StatsFuture *future) { StatsFuture *future) {
...@@ -29,7 +29,7 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()( ...@@ -29,7 +29,7 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
const size_t outputs_count = output_list.size(); const size_t outputs_count = output_list.size();
const index_t output_channels = input_channels / outputs_count; const index_t output_channels = input_channels / outputs_count;
MACE_CHECK(output_channels % 4 == 0) MACE_CHECK(output_channels % 4 == 0)
<< "output channels of slice op must be divisible by 4"; << "output channels of split op must be divisible by 4";
std::vector<index_t> output_shape( std::vector<index_t> output_shape(
{input->dim(0), input->dim(1), input->dim(2), output_channels}); {input->dim(0), input->dim(1), input->dim(2), output_channels});
...@@ -46,12 +46,12 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()( ...@@ -46,12 +46,12 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
std::set<std::string> built_options; std::set<std::string> built_options;
OUT_OF_RANGE_CONFIG(kernel_error_); OUT_OF_RANGE_CONFIG(kernel_error_);
NON_UNIFORM_WG_CONFIG; NON_UNIFORM_WG_CONFIG;
std::string kernel_name = MACE_OBFUSCATE_SYMBOL("slice"); std::string kernel_name = MACE_OBFUSCATE_SYMBOL("split");
built_options.emplace("-Dslice=" + kernel_name); built_options.emplace("-Dsplit=" + kernel_name);
built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value)); built_options.emplace("-DDATA_TYPE=" + DtToCLDt(DataTypeToEnum<T>::value));
built_options.emplace("-DCMD_DATA_TYPE=" + built_options.emplace("-DCMD_DATA_TYPE=" +
DtToCLCMDDt(DataTypeToEnum<T>::value)); DtToCLCMDDt(DataTypeToEnum<T>::value));
MACE_RETURN_IF_ERROR(runtime->BuildKernel("slice", MACE_RETURN_IF_ERROR(runtime->BuildKernel("split",
kernel_name, kernel_name,
built_options, built_options,
&kernel_)); &kernel_));
...@@ -116,8 +116,8 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()( ...@@ -116,8 +116,8 @@ MaceStatus SliceFunctor<DeviceType::GPU, T>::operator()(
return MACE_SUCCESS; return MACE_SUCCESS;
} }
template struct SliceFunctor<DeviceType::GPU, float>; template struct SplitFunctor<DeviceType::GPU, float>;
template struct SliceFunctor<DeviceType::GPU, half>; template struct SplitFunctor<DeviceType::GPU, half>;
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
...@@ -12,8 +12,8 @@ ...@@ -12,8 +12,8 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_KERNELS_SLICE_H_ #ifndef MACE_KERNELS_SPLIT_H_
#define MACE_KERNELS_SLICE_H_ #define MACE_KERNELS_SPLIT_H_
#include <memory> #include <memory>
#include <functional> #include <functional>
...@@ -31,15 +31,15 @@ ...@@ -31,15 +31,15 @@
namespace mace { namespace mace {
namespace kernels { namespace kernels {
struct SliceFunctorBase { struct SplitFunctorBase {
explicit SliceFunctorBase(const int32_t axis) : axis_(axis) {} explicit SplitFunctorBase(const int32_t axis) : axis_(axis) {}
int32_t axis_; int32_t axis_;
}; };
template<DeviceType D, typename T> template<DeviceType D, typename T>
struct SliceFunctor : SliceFunctorBase { struct SplitFunctor : SplitFunctorBase {
explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {} explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {}
MaceStatus operator()(const Tensor *input, MaceStatus operator()(const Tensor *input,
const std::vector<Tensor *> &output_list, const std::vector<Tensor *> &output_list,
...@@ -89,8 +89,8 @@ struct SliceFunctor : SliceFunctorBase { ...@@ -89,8 +89,8 @@ struct SliceFunctor : SliceFunctorBase {
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
template<typename T> template<typename T>
struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase { struct SplitFunctor<DeviceType::GPU, T> : SplitFunctorBase {
explicit SliceFunctor(const int32_t axis) : SliceFunctorBase(axis) {} explicit SplitFunctor(const int32_t axis) : SplitFunctorBase(axis) {}
MaceStatus operator()(const Tensor *input, MaceStatus operator()(const Tensor *input,
const std::vector<Tensor *> &output_list, const std::vector<Tensor *> &output_list,
...@@ -104,4 +104,4 @@ struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase { ...@@ -104,4 +104,4 @@ struct SliceFunctor<DeviceType::GPU, T> : SliceFunctorBase {
} // namespace kernels } // namespace kernels
} // namespace mace } // namespace mace
#endif // MACE_KERNELS_SLICE_H_ #endif // MACE_KERNELS_SPLIT_H_
...@@ -169,7 +169,6 @@ struct StridedSliceFunctor { ...@@ -169,7 +169,6 @@ struct StridedSliceFunctor {
i += strides_data[0]) { i += strides_data[0]) {
*output_data++ = input_data[i]; *output_data++ = input_data[i];
} }
} else if (input->dim_size() == 2) { } else if (input->dim_size() == 2) {
for (index_t i = real_begin_indices[0]; for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0] strides_data[0] > 0 ? i < real_end_indices[0]
...@@ -179,7 +178,25 @@ struct StridedSliceFunctor { ...@@ -179,7 +178,25 @@ struct StridedSliceFunctor {
strides_data[1] > 0 ? j < real_end_indices[1] strides_data[1] > 0 ? j < real_end_indices[1]
: j > real_end_indices[1]; : j > real_end_indices[1];
j += strides_data[1]) { j += strides_data[1]) {
*output_data++ = input_data[i * dim_stride[0] + j]; *output_data++ = input_data[i * input->dim(1) + j];
}
}
} else if (input->dim_size() == 3) {
for (index_t i = real_begin_indices[0];
strides_data[0] > 0 ? i < real_end_indices[0]
: i > real_end_indices[0];
i += strides_data[0]) {
for (index_t j = real_begin_indices[1];
strides_data[1] > 0 ? j < real_end_indices[1]
: j > real_end_indices[1];
j += strides_data[1]) {
for (index_t k = real_begin_indices[2];
strides_data[2] > 0 ? k < real_end_indices[2]
: k > real_end_indices[2];
k += strides_data[2]) {
*output_data++ =
input_data[(i * input->dim(1) + j) * input->dim(2) + k];
}
} }
} }
} else { } else {
......
...@@ -55,10 +55,10 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) { ...@@ -55,10 +55,10 @@ TEST_F(ConcatOpTest, CPUSimpleHorizon) {
const float *output_ptr = output->data<float>(); const float *output_ptr = output->data<float>();
for (auto f : input0) { for (auto f : input0) {
ASSERT_EQ(f, *output_ptr++); EXPECT_EQ(f, *output_ptr++);
} }
for (auto f : input1) { for (auto f : input1) {
ASSERT_EQ(f, *output_ptr++); EXPECT_EQ(f, *output_ptr++);
} }
} }
...@@ -93,10 +93,10 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) { ...@@ -93,10 +93,10 @@ TEST_F(ConcatOpTest, CPUSimpleVertical) {
const float *output_ptr = output->data<float>(); const float *output_ptr = output->data<float>();
for (int i = 0; i < 4; ++i) { for (int i = 0; i < 4; ++i) {
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
ASSERT_EQ(input0[i * 4 + j], *output_ptr++); EXPECT_EQ(input0[i * 4 + j], *output_ptr++);
} }
for (int j = 0; j < 4; ++j) { for (int j = 0; j < 4; ++j) {
ASSERT_EQ(input1[i * 4 + j], *output_ptr++); EXPECT_EQ(input1[i * 4 + j], *output_ptr++);
} }
} }
} }
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/ops/fill.h"
namespace mace {
namespace ops {
// Registers the "Fill" operator with the given registry. Only one variant is
// registered here: DeviceType::CPU with type constraint T = float, backed by
// FillOp<DeviceType::CPU, float>.
void Register_Fill(OperatorRegistryBase *op_registry) {
  MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Fill")
                                          .Device(DeviceType::CPU)
                                          .TypeConstraint<float>("T")
                                          .Build(),
                         FillOp<DeviceType::CPU, float>);
}
} // namespace ops
} // namespace mace
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef MACE_OPS_FILL_H_
#define MACE_OPS_FILL_H_
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/fill.h"
namespace mace {
namespace ops {
// Operator wrapper for Fill. It looks up the SHAPE and VALUE inputs and the
// OUTPUT tensor, then hands them to kernels::FillFunctor<D, T>.
template <DeviceType D, class T>
class FillOp : public Operator<D, T> {
 public:
  FillOp(const OperatorDef &operator_def, Workspace *ws)
      : Operator<D, T>(operator_def, ws),
        functor_() {}

  // Invokes the fill functor and returns its status unchanged.
  MaceStatus Run(StatsFuture *future) override {
    const Tensor *dims_tensor = this->Input(SHAPE);
    const Tensor *fill_value = this->Input(VALUE);
    Tensor *result = this->Output(OUTPUT);
    return functor_(dims_tensor, fill_value, result, future);
  }

 private:
  // Kernel implementation selected by device and element type.
  kernels::FillFunctor<D, T> functor_;

  // Symbolic tags used above to index this operator's inputs and outputs.
  MACE_OP_INPUT_TAGS(SHAPE, VALUE);
  MACE_OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_FILL_H_
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
// Test fixture for the Fill operator; adds nothing on top of OpsTestBase.
class FillTest : public OpsTestBase {};
namespace {
void TestFill(const std::vector<int32_t> &shape,
const float &value) {
// Construct graph
OpsTestNet net;
OpDefBuilder("Fill", "FillTest")
.Input("Shape")
.Input("Value")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Add input data
net.AddInputFromArray<DeviceType::CPU, int32_t>(
"Shape",
{static_cast<index_t>(shape.size())},
shape);
net.AddInputFromArray<DeviceType::CPU, float>("Value", {}, {value});
// Run
net.RunOp();
auto output = net.GetTensor("Output");
for (index_t i = 0; i < output->dim_size(); ++i) {
EXPECT_EQ(output->dim(i), shape[i]);
}
const float *output_ptr = output->data<float>();
const index_t size = output->size();
for (index_t i = 0; i < size; ++i) {
EXPECT_EQ(output_ptr[i], value);
}
}
} // namespace
// Smoke test: fills a rank-3 tensor with a positive constant and a rank-2
// tensor with a negative constant.
TEST_F(FillTest, Simple) {
  const std::vector<int32_t> rank3_shape = {3, 2, 1};
  TestFill(rank3_shape, 5.0f);

  const std::vector<int32_t> rank2_shape = {1, 3};
  TestFill(rank2_shape, -1.0f);
}
} // namespace test
} // namespace ops
} // namespace mace
...@@ -46,7 +46,7 @@ void TestIdentity(const std::vector<index_t> &shape) { ...@@ -46,7 +46,7 @@ void TestIdentity(const std::vector<index_t> &shape) {
const float *output_ptr = output->data<float>(); const float *output_ptr = output->data<float>();
const int size = output->size(); const int size = output->size();
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_ptr[i], output_ptr[i]); EXPECT_EQ(input_ptr[i], output_ptr[i]);
} }
} }
} // namespace } // namespace
......
...@@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry); ...@@ -34,6 +34,7 @@ extern void Register_DepthToSpace(OperatorRegistryBase *op_registry);
extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistryBase *op_registry);
extern void Register_Dequantize(OperatorRegistryBase *op_registry); extern void Register_Dequantize(OperatorRegistryBase *op_registry);
extern void Register_Eltwise(OperatorRegistryBase *op_registry); extern void Register_Eltwise(OperatorRegistryBase *op_registry);
extern void Register_Fill(OperatorRegistryBase *op_registry);
extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistryBase *op_registry);
extern void Register_FullyConnected(OperatorRegistryBase *op_registry); extern void Register_FullyConnected(OperatorRegistryBase *op_registry);
extern void Register_Gather(OperatorRegistryBase *op_registry); extern void Register_Gather(OperatorRegistryBase *op_registry);
...@@ -48,7 +49,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry); ...@@ -48,7 +49,7 @@ extern void Register_ReduceMean(OperatorRegistryBase *op_registry);
extern void Register_Reshape(OperatorRegistryBase *op_registry); extern void Register_Reshape(OperatorRegistryBase *op_registry);
extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry); extern void Register_ResizeBilinear(OperatorRegistryBase *op_registry);
extern void Register_Shape(OperatorRegistryBase *op_registry); extern void Register_Shape(OperatorRegistryBase *op_registry);
extern void Register_Slice(OperatorRegistryBase *op_registry); extern void Register_Split(OperatorRegistryBase *op_registry);
extern void Register_Softmax(OperatorRegistryBase *op_registry); extern void Register_Softmax(OperatorRegistryBase *op_registry);
extern void Register_Stack(OperatorRegistryBase *op_registry); extern void Register_Stack(OperatorRegistryBase *op_registry);
extern void Register_StridedSlice(OperatorRegistryBase *op_registry); extern void Register_StridedSlice(OperatorRegistryBase *op_registry);
...@@ -84,6 +85,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ...@@ -84,6 +85,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_DepthwiseConv2d(this); ops::Register_DepthwiseConv2d(this);
ops::Register_Dequantize(this); ops::Register_Dequantize(this);
ops::Register_Eltwise(this); ops::Register_Eltwise(this);
ops::Register_Fill(this);
ops::Register_FoldedBatchNorm(this); ops::Register_FoldedBatchNorm(this);
ops::Register_FullyConnected(this); ops::Register_FullyConnected(this);
ops::Register_Gather(this); ops::Register_Gather(this);
...@@ -98,7 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() { ...@@ -98,7 +100,7 @@ OperatorRegistry::OperatorRegistry() : OperatorRegistryBase() {
ops::Register_Reshape(this); ops::Register_Reshape(this);
ops::Register_ResizeBilinear(this); ops::Register_ResizeBilinear(this);
ops::Register_Shape(this); ops::Register_Shape(this);
ops::Register_Slice(this); ops::Register_Split(this);
ops::Register_Softmax(this); ops::Register_Softmax(this);
ops::Register_Stack(this); ops::Register_Stack(this);
ops::Register_StridedSlice(this); ops::Register_StridedSlice(this);
......
...@@ -42,12 +42,12 @@ class ReshapeOp : public Operator<D, T> { ...@@ -42,12 +42,12 @@ class ReshapeOp : public Operator<D, T> {
for (int i = 0; i < num_dims; ++i) { for (int i = 0; i < num_dims; ++i) {
if (shape_data[i] == -1) { if (shape_data[i] == -1) {
MACE_CHECK(unknown_idx == -1) << "Only one input size may be -1"; MACE_CHECK(unknown_idx == -1, "Only one input size may be -1");
unknown_idx = i; unknown_idx = i;
out_shape.push_back(1); out_shape.push_back(1);
} else { } else {
MACE_CHECK(shape_data[i] >= 0) << "Shape must be non-negative: " MACE_CHECK(shape_data[i] >= 0, "Shape must be non-negative: ",
<< shape_data[i]; shape_data[i]);
out_shape.push_back(shape_data[i]); out_shape.push_back(shape_data[i]);
product *= shape_data[i]; product *= shape_data[i];
} }
......
...@@ -53,7 +53,7 @@ void TestReshape(const std::vector<index_t> &org_shape, ...@@ -53,7 +53,7 @@ void TestReshape(const std::vector<index_t> &org_shape,
const float *output_ptr = output->data<float>(); const float *output_ptr = output->data<float>();
const int size = output->size(); const int size = output->size();
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_ptr[i], output_ptr[i]); EXPECT_EQ(input_ptr[i], output_ptr[i]);
} }
} }
} // namespace } // namespace
......
...@@ -12,30 +12,30 @@ ...@@ -12,30 +12,30 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#include "mace/ops/slice.h" #include "mace/ops/split.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
void Register_Slice(OperatorRegistryBase *op_registry) { void Register_Split(OperatorRegistryBase *op_registry) {
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
.Device(DeviceType::CPU) .Device(DeviceType::CPU)
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
SliceOp<DeviceType::CPU, float>); SplitOp<DeviceType::CPU, float>);
#ifdef MACE_ENABLE_OPENCL #ifdef MACE_ENABLE_OPENCL
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
.Device(DeviceType::GPU) .Device(DeviceType::GPU)
.TypeConstraint<float>("T") .TypeConstraint<float>("T")
.Build(), .Build(),
SliceOp<DeviceType::GPU, float>); SplitOp<DeviceType::GPU, float>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Slice") MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("Split")
.Device(DeviceType::GPU) .Device(DeviceType::GPU)
.TypeConstraint<half>("T") .TypeConstraint<half>("T")
.Build(), .Build(),
SliceOp<DeviceType::GPU, half>); SplitOp<DeviceType::GPU, half>);
#endif // MACE_ENABLE_OPENCL #endif // MACE_ENABLE_OPENCL
} }
......
...@@ -12,21 +12,21 @@ ...@@ -12,21 +12,21 @@
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#ifndef MACE_OPS_SLICE_H_ #ifndef MACE_OPS_SPLIT_H_
#define MACE_OPS_SLICE_H_ #define MACE_OPS_SPLIT_H_
#include <vector> #include <vector>
#include "mace/core/operator.h" #include "mace/core/operator.h"
#include "mace/kernels/slice.h" #include "mace/kernels/split.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
template <DeviceType D, typename T> template <DeviceType D, typename T>
class SliceOp : public Operator<D, T> { class SplitOp : public Operator<D, T> {
public: public:
SliceOp(const OperatorDef &op_def, Workspace *ws) SplitOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws), : Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetOptionalArg<int>("axis", 3)) {} functor_(OperatorBase::GetOptionalArg<int>("axis", 3)) {}
...@@ -35,15 +35,15 @@ class SliceOp : public Operator<D, T> { ...@@ -35,15 +35,15 @@ class SliceOp : public Operator<D, T> {
<< "There must be at least two outputs for slicing"; << "There must be at least two outputs for slicing";
const Tensor *input = this->Input(INPUT); const Tensor *input = this->Input(INPUT);
const std::vector<Tensor *> output_list = this->Outputs(); const std::vector<Tensor *> output_list = this->Outputs();
const int32_t slice_axis = OperatorBase::GetOptionalArg<int>("axis", 3); const int32_t split_axis = OperatorBase::GetOptionalArg<int>("axis", 3);
MACE_CHECK((input->dim(slice_axis) % this->OutputSize()) == 0) MACE_CHECK((input->dim(split_axis) % this->OutputSize()) == 0)
<< "Outputs do not split input equally."; << "Outputs do not split input equally.";
return functor_(input, output_list, future); return functor_(input, output_list, future);
} }
private: private:
kernels::SliceFunctor<D, T> functor_; kernels::SplitFunctor<D, T> functor_;
private: private:
MACE_OP_INPUT_TAGS(INPUT); MACE_OP_INPUT_TAGS(INPUT);
...@@ -52,4 +52,4 @@ class SliceOp : public Operator<D, T> { ...@@ -52,4 +52,4 @@ class SliceOp : public Operator<D, T> {
} // namespace ops } // namespace ops
} // namespace mace } // namespace mace
#endif // MACE_OPS_SLICE_H_ #endif // MACE_OPS_SPLIT_H_
...@@ -22,7 +22,7 @@ namespace test { ...@@ -22,7 +22,7 @@ namespace test {
namespace { namespace {
template<DeviceType D, typename T> template<DeviceType D, typename T>
void BMSliceHelper(int iters, void BMSplitHelper(int iters,
const std::vector<index_t> &input_shape, const std::vector<index_t> &input_shape,
const index_t num_outputs) { const index_t num_outputs) {
mace::testing::StopTiming(); mace::testing::StopTiming();
...@@ -42,7 +42,7 @@ void BMSliceHelper(int iters, ...@@ -42,7 +42,7 @@ void BMSliceHelper(int iters,
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
auto builder = OpDefBuilder("Slice", "SliceTest"); auto builder = OpDefBuilder("Split", "SplitTest");
builder.Input("InputImage"); builder.Input("InputImage");
for (int i = 0; i < num_outputs; ++i) { for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("OutputImage", i)); builder = builder.Output(MakeString("OutputImage", i));
...@@ -51,7 +51,7 @@ void BMSliceHelper(int iters, ...@@ -51,7 +51,7 @@ void BMSliceHelper(int iters,
.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) .AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else { } else {
auto builder = OpDefBuilder("Slice", "SliceTest"); auto builder = OpDefBuilder("Split", "SplitTest");
builder.Input("Input"); builder.Input("Input");
for (int i = 0; i < num_outputs; ++i) { for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("Output", i)); builder = builder.Output(MakeString("Output", i));
...@@ -73,28 +73,28 @@ void BMSliceHelper(int iters, ...@@ -73,28 +73,28 @@ void BMSliceHelper(int iters,
} }
} // namespace } // namespace
#define MACE_BM_SLICE_MACRO(N, H, W, C, NO, TYPE, DEVICE) \ #define MACE_BM_SPLIT_MACRO(N, H, W, C, NO, TYPE, DEVICE) \
static void \ static void \
MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \ MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE( \
int iters) { \ int iters) { \
const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \ const int64_t tot = static_cast<int64_t>(iters) * N * H * W * C; \
mace::testing::MaccProcessed(tot); \ mace::testing::MaccProcessed(tot); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \ mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
BMSliceHelper<DEVICE, TYPE>(iters, {N, H, W, C}, NO); \ BMSplitHelper<DEVICE, TYPE>(iters, {N, H, W, C}, NO); \
} \ } \
MACE_BENCHMARK( \ MACE_BENCHMARK( \
MACE_BM_SLICE_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE) MACE_BM_SPLIT_##N##_##H##_##W##_##C##_##NO##_##TYPE##_##DEVICE)
#define MACE_BM_SLICE(N, H, W, C, NO) \ #define MACE_BM_SPLIT(N, H, W, C, NO) \
MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, CPU); \ MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, CPU); \
MACE_BM_SLICE_MACRO(N, H, W, C, NO, float, GPU); \ MACE_BM_SPLIT_MACRO(N, H, W, C, NO, float, GPU); \
MACE_BM_SLICE_MACRO(N, H, W, C, NO, half, GPU); MACE_BM_SPLIT_MACRO(N, H, W, C, NO, half, GPU);
MACE_BM_SLICE(1, 32, 32, 32, 2); MACE_BM_SPLIT(1, 32, 32, 32, 2);
MACE_BM_SLICE(1, 32, 32, 128, 2); MACE_BM_SPLIT(1, 32, 32, 128, 2);
MACE_BM_SLICE(1, 32, 32, 256, 2); MACE_BM_SPLIT(1, 32, 32, 256, 2);
MACE_BM_SLICE(1, 128, 128, 32, 2); MACE_BM_SPLIT(1, 128, 128, 32, 2);
MACE_BM_SLICE(1, 128, 128, 128, 2); MACE_BM_SPLIT(1, 128, 128, 128, 2);
} // namespace test } // namespace test
} // namespace ops } // namespace ops
......
...@@ -17,13 +17,13 @@ ...@@ -17,13 +17,13 @@
#include "gmock/gmock.h" #include "gmock/gmock.h"
#include "mace/ops/ops_test_util.h" #include "mace/ops/ops_test_util.h"
#include "mace/ops/slice.h" #include "mace/ops/split.h"
namespace mace { namespace mace {
namespace ops { namespace ops {
namespace test { namespace test {
class SliceOpTest : public OpsTestBase {}; class SplitOpTest : public OpsTestBase {};
namespace { namespace {
template <DeviceType D, typename T> template <DeviceType D, typename T>
...@@ -53,7 +53,7 @@ void RandomTest(const int num_outputs, const int axis) { ...@@ -53,7 +53,7 @@ void RandomTest(const int num_outputs, const int axis) {
BufferToImage<D, T>(&net, "Input", "InputImage", BufferToImage<D, T>(&net, "Input", "InputImage",
kernels::BufferType::IN_OUT_CHANNEL); kernels::BufferType::IN_OUT_CHANNEL);
auto builder = OpDefBuilder("Slice", "SliceTest"); auto builder = OpDefBuilder("Split", "SplitTest");
builder.Input("InputImage"); builder.Input("InputImage");
for (int i = 0; i < num_outputs; ++i) { for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("OutputImage", i)); builder = builder.Output(MakeString("OutputImage", i));
...@@ -61,7 +61,7 @@ void RandomTest(const int num_outputs, const int axis) { ...@@ -61,7 +61,7 @@ void RandomTest(const int num_outputs, const int axis) {
builder.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value)) builder.AddIntArg("T", static_cast<int>(DataTypeToEnum<T>::value))
.Finalize(net.NewOperatorDef()); .Finalize(net.NewOperatorDef());
} else { } else {
auto builder = OpDefBuilder("Slice", "SliceTest").AddIntArg("axis", axis); auto builder = OpDefBuilder("Split", "SplitTest").AddIntArg("axis", axis);
builder.Input("Input"); builder.Input("Input");
for (int i = 0; i < num_outputs; ++i) { for (int i = 0; i < num_outputs; ++i) {
builder = builder.Output(MakeString("Output", i)); builder = builder.Output(MakeString("Output", i));
...@@ -111,25 +111,25 @@ void RandomTest(const int num_outputs, const int axis) { ...@@ -111,25 +111,25 @@ void RandomTest(const int num_outputs, const int axis) {
} }
} // namespace } // namespace
TEST_F(SliceOpTest, CPU) { TEST_F(SplitOpTest, CPU) {
RandomTest<DeviceType::CPU, float>(2, 3); RandomTest<DeviceType::CPU, float>(2, 3);
RandomTest<DeviceType::CPU, float>(4, 3); RandomTest<DeviceType::CPU, float>(4, 3);
RandomTest<DeviceType::CPU, float>(11, 3); RandomTest<DeviceType::CPU, float>(11, 3);
} }
TEST_F(SliceOpTest, CPUAxis1) { TEST_F(SplitOpTest, CPUAxis1) {
RandomTest<DeviceType::CPU, float>(2, 1); RandomTest<DeviceType::CPU, float>(2, 1);
RandomTest<DeviceType::CPU, float>(4, 1); RandomTest<DeviceType::CPU, float>(4, 1);
RandomTest<DeviceType::CPU, float>(11, 1); RandomTest<DeviceType::CPU, float>(11, 1);
} }
TEST_F(SliceOpTest, OPENCLFloat) { TEST_F(SplitOpTest, OPENCLFloat) {
RandomTest<DeviceType::GPU, float>(2, 3); RandomTest<DeviceType::GPU, float>(2, 3);
RandomTest<DeviceType::GPU, float>(4, 3); RandomTest<DeviceType::GPU, float>(4, 3);
RandomTest<DeviceType::GPU, float>(11, 3); RandomTest<DeviceType::GPU, float>(11, 3);
} }
TEST_F(SliceOpTest, OPENCLHalf) { TEST_F(SplitOpTest, OPENCLHalf) {
RandomTest<DeviceType::GPU, half>(2, 3); RandomTest<DeviceType::GPU, half>(2, 3);
RandomTest<DeviceType::GPU, half>(4, 3); RandomTest<DeviceType::GPU, half>(4, 3);
RandomTest<DeviceType::GPU, half>(11, 3); RandomTest<DeviceType::GPU, half>(11, 3);
......
...@@ -49,7 +49,7 @@ void TestSqueeze(const std::vector<index_t> &org_shape, ...@@ -49,7 +49,7 @@ void TestSqueeze(const std::vector<index_t> &org_shape,
const float *output_ptr = output->data<float>(); const float *output_ptr = output->data<float>();
const int size = output->size(); const int size = output->size();
for (int i = 0; i < size; ++i) { for (int i = 0; i < size; ++i) {
ASSERT_EQ(input_ptr[i], output_ptr[i]); EXPECT_EQ(input_ptr[i], output_ptr[i]);
} }
} }
} // namespace } // namespace
......
...@@ -146,6 +146,18 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank2) { ...@@ -146,6 +146,18 @@ TEST_F(StridedSliceOpTest, TestStridedSliceRank2) {
0, 3, {}, {6}); 0, 3, {}, {6});
} }
// Rank-3 coverage for the TestStridedSlice helper (defined outside this
// excerpt).
// NOTE(review): the argument order appears to be: input shape, input data,
// begin, end, strides, five integer mask/flag values (all 0 here), expected
// output shape, expected output data -- confirm against the helper signature.
TEST_F(StridedSliceOpTest, TestStridedSliceRank3) {
// {2, 3, 2} input holding 1..12; stride 2 on the middle axis keeps middle
// indices 0 and 2, giving {1, 2, 5, 6} and {7, 8, 11, 12}.
TestStridedSlice({2, 3, 2}, {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12},
{0, 0, 0}, {2, 3, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2, 2, 2},
{1, 2, 5, 6, 7, 8, 11, 12});
// {3, 2, 3} input, unit strides; begin {1, 0, 0} / end {2, 1, 3} selects one
// three-element row, expected to be {3, 3, 3}.
TestStridedSlice({3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6,
6, 6}, {1, 0, 0}, {2, 1, 3}, {1, 1, 1}, 0, 0, 0, 0, 0, {1,
1, 3}, {3, 3, 3});
// Same input; begin at the origin, end {2, 2, 2}, stride 2 on the middle
// axis: expected shape {2, 1, 2} with values {1, 1, 3, 3}.
TestStridedSlice({3, 2, 3}, {1, 1, 1, 2, 2, 2, 3, 3, 3, 4, 4, 4, 5, 5, 5, 6,
6, 6}, {0, 0, 0}, {2, 2, 2}, {1, 2, 1}, 0, 0, 0, 0, 0, {2,
1, 2}, {1, 1, 3, 3});
}
TEST_F(StridedSliceOpTest, TestSlice) { TEST_F(StridedSliceOpTest, TestSlice) {
TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3}, TestSlice({2, 3}, {1, 2, 3, 4, 5, 6}, {0, 0}, {2, 3}, {2, 3},
{1, 2, 3, 4, 5, 6}); {1, 2, 3, 4, 5, 6});
......
...@@ -88,6 +88,7 @@ MaceSupportedOps = [ ...@@ -88,6 +88,7 @@ MaceSupportedOps = [
'Dequantize', 'Dequantize',
'Eltwise', 'Eltwise',
'FoldedBatchNorm', 'FoldedBatchNorm',
'Fill',
'FullyConnected', 'FullyConnected',
'Gather', 'Gather',
'Identity', 'Identity',
...@@ -101,6 +102,7 @@ MaceSupportedOps = [ ...@@ -101,6 +102,7 @@ MaceSupportedOps = [
'Reshape', 'Reshape',
'ResizeBilinear', 'ResizeBilinear',
'Slice', 'Slice',
'Split',
'Shape', 'Shape',
'Squeeze', 'Squeeze',
'Stack', 'Stack',
...@@ -146,6 +148,7 @@ class MaceKeyword(object): ...@@ -146,6 +148,7 @@ class MaceKeyword(object):
mace_constant_value_str = 'constant_value' mace_constant_value_str = 'constant_value'
mace_dims_str = 'dims' mace_dims_str = 'dims'
mace_axis_str = 'axis' mace_axis_str = 'axis'
mace_num_split_str = 'num_split'
mace_keepdims_str = 'keepdims' mace_keepdims_str = 'keepdims'
mace_shape_str = 'shape' mace_shape_str = 'shape'
mace_winograd_filter_transformed = 'is_filter_transformed' mace_winograd_filter_transformed = 'is_filter_transformed'
......
...@@ -68,6 +68,7 @@ TFSupportedOps = [ ...@@ -68,6 +68,7 @@ TFSupportedOps = [
'Relu6', 'Relu6',
'Tanh', 'Tanh',
'Sigmoid', 'Sigmoid',
'Fill',
'FusedBatchNorm', 'FusedBatchNorm',
'AvgPool', 'AvgPool',
'MaxPool', 'MaxPool',
...@@ -165,6 +166,7 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -165,6 +166,7 @@ class TensorflowConverter(base_converter.ConverterInterface):
TFOpType.Relu6.name: self.convert_activation, TFOpType.Relu6.name: self.convert_activation,
TFOpType.Tanh.name: self.convert_activation, TFOpType.Tanh.name: self.convert_activation,
TFOpType.Sigmoid.name: self.convert_activation, TFOpType.Sigmoid.name: self.convert_activation,
TFOpType.Fill.name: self.convert_fill,
TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm, TFOpType.FusedBatchNorm.name: self.convert_fused_batchnorm,
TFOpType.AvgPool.name: self.convert_pooling, TFOpType.AvgPool.name: self.convert_pooling,
TFOpType.MaxPool.name: self.convert_pooling, TFOpType.MaxPool.name: self.convert_pooling,
...@@ -458,6 +460,10 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -458,6 +460,10 @@ class TensorflowConverter(base_converter.ConverterInterface):
limit_arg.name = MaceKeyword.mace_activation_max_limit_str limit_arg.name = MaceKeyword.mace_activation_max_limit_str
limit_arg.f = 6.0 limit_arg.f = 6.0
def convert_fill(self, tf_op):
    """Convert a TensorFlow Fill op into a MACE Fill op.

    The op is built by convert_general_op and only its type is overridden
    afterwards. Nothing is returned.
    """
    fill_op = self.convert_general_op(tf_op)
    fill_op.type = MaceOp.Fill.name
def convert_fused_batchnorm(self, tf_op): def convert_fused_batchnorm(self, tf_op):
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
op.type = MaceOp.FoldedBatchNorm.name op.type = MaceOp.FoldedBatchNorm.name
...@@ -763,19 +769,19 @@ class TensorflowConverter(base_converter.ConverterInterface): ...@@ -763,19 +769,19 @@ class TensorflowConverter(base_converter.ConverterInterface):
op.output_type.extend([mace_pb2.DT_INT32]) op.output_type.extend([mace_pb2.DT_INT32])
def convert_split(self, tf_op): def convert_split(self, tf_op):
# inputs: [dim, input]
axis = tf_op.inputs[0].eval().astype(np.int32) axis = tf_op.inputs[0].eval().astype(np.int32)
axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis axis = len(op.output_shape[0].dims) + axis if axis < 0 else axis
mace_check(axis == 3, 'Split with %d axis only support' % axis)
input_shape = self.infer_tensor_shape(tf_op.inputs[1]) input_shape = self.infer_tensor_shape(tf_op.inputs[1])
mace_check(len(input_shape) == 4 and (input_shape[3] % 4 == 0),
"The input's 4th dimension should be a multiple of 4")
op = self.convert_general_op(tf_op) op = self.convert_general_op(tf_op)
op.type = MaceOp.Slice.name op.type = MaceOp.Split.name
del op.input[0] del op.input[0]
axis_arg = op.arg.add() axis_arg = op.arg.add()
axis_arg.name = MaceKeyword.mace_axis_str axis_arg.name = MaceKeyword.mace_axis_str
axis_arg.i = axis axis_arg.i = axis
num_split_arg = op.arg.add()
num_split_arg.name = MaceKeyword.mace_num_split_str
num_split_arg.i = tf_op.get_attr('num_split')
self._skip_tensor.add(tf_op.inputs[0].name) self._skip_tensor.add(tf_op.inputs[0].name)
...@@ -812,6 +812,7 @@ class Transformer(base_converter.ConverterInterface): ...@@ -812,6 +812,7 @@ class Transformer(base_converter.ConverterInterface):
"only support concat at " "only support concat at "
"channel dimension") "channel dimension")
arg.i = 3 arg.i = 3
producer = self._producer[op.input[0]] producer = self._producer[op.input[0]]
input_shape = producer.output_shape[0].dims input_shape = producer.output_shape[0].dims
if producer.type == MaceOp.FullyConnected.name and \ if producer.type == MaceOp.FullyConnected.name and \
......
...@@ -342,7 +342,7 @@ void MaceRunFunc(const int in_out_size) { ...@@ -342,7 +342,7 @@ void MaceRunFunc(const int in_out_size) {
MaceEngine engine(device); MaceEngine engine(device);
MaceStatus status = engine.Init(net_def.get(), input_names, output_names, MaceStatus status = engine.Init(net_def.get(), input_names, output_names,
reinterpret_cast<unsigned char *>(data.data())); reinterpret_cast<unsigned char *>(data.data()));
ASSERT_EQ(status, MaceStatus::MACE_SUCCESS); EXPECT_EQ(status, MaceStatus::MACE_SUCCESS);
std::map<std::string, mace::MaceTensor> inputs; std::map<std::string, mace::MaceTensor> inputs;
std::map<std::string, mace::MaceTensor> outputs; std::map<std::string, mace::MaceTensor> outputs;
......
...@@ -336,7 +336,7 @@ void MaceRun(const int in_out_size, ...@@ -336,7 +336,7 @@ void MaceRun(const int in_out_size,
MaceEngine engine(device); MaceEngine engine(device);
MaceStatus status = engine.Init(net_def.get(), input_names, output_names, MaceStatus status = engine.Init(net_def.get(), input_names, output_names,
reinterpret_cast<unsigned char *>(data.data())); reinterpret_cast<unsigned char *>(data.data()));
ASSERT_EQ(status, MaceStatus::MACE_SUCCESS); EXPECT_EQ(status, MaceStatus::MACE_SUCCESS);
std::map<std::string, mace::MaceTensor> inputs; std::map<std::string, mace::MaceTensor> inputs;
std::map<std::string, mace::MaceTensor> outputs; std::map<std::string, mace::MaceTensor> outputs;
......
...@@ -42,7 +42,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx): ...@@ -42,7 +42,7 @@ def _opencl_encrypt_kernel_impl(repository_ctx):
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/pooling.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/reduce_mean.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/resize_bilinear.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/slice.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/split.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/softmax.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/space_to_batch.cl"))
unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl")) unused_var = repository_ctx.path(Label("//:mace/kernels/opencl/cl/winograd_transform.cl"))
......
...@@ -130,6 +130,16 @@ class RuntimeType(object): ...@@ -130,6 +130,16 @@ class RuntimeType(object):
cpu_gpu = 'cpu+gpu' cpu_gpu = 'cpu+gpu'
# Type names accepted for the 'input_data_types' entry of the YAML config.
InputDataTypeStrs = ["int32", "float32"]

# str-based enum: every member's name and value are both the type string.
InputDataType = Enum(
    'InputDataType',
    names=[(ele, ele) for ele in InputDataTypeStrs],
    type=str)
CPUDataTypeStrs = [ CPUDataTypeStrs = [
"fp32", "fp32",
] ]
...@@ -183,6 +193,7 @@ class YAMLKeyword(object): ...@@ -183,6 +193,7 @@ class YAMLKeyword(object):
output_shapes = 'output_shapes' output_shapes = 'output_shapes'
runtime = 'runtime' runtime = 'runtime'
data_type = 'data_type' data_type = 'data_type'
input_data_types = 'input_data_types'
limit_opencl_kernel_time = 'limit_opencl_kernel_time' limit_opencl_kernel_time = 'limit_opencl_kernel_time'
nnlib_graph_mode = 'nnlib_graph_mode' nnlib_graph_mode = 'nnlib_graph_mode'
obfuscate = 'obfuscate' obfuscate = 'obfuscate'
...@@ -447,6 +458,18 @@ def format_model_config(flags): ...@@ -447,6 +458,18 @@ def format_model_config(flags):
if not isinstance(value, list): if not isinstance(value, list):
subgraph[key] = [value] subgraph[key] = [value]
# Normalize the optional 'input_data_types' entry of this subgraph to a list
# and check every element against InputDataTypeStrs. A missing or empty entry
# becomes an empty list.
input_data_types = subgraph.get(YAMLKeyword.input_data_types, "")
if input_data_types:
    if not isinstance(input_data_types, list):
        subgraph[YAMLKeyword.input_data_types] = [input_data_types]
    # Iterate the normalized list, not the raw value: looping over a bare
    # string such as "int32" would validate it character by character and
    # reject a valid configuration.
    for input_data_type in subgraph[YAMLKeyword.input_data_types]:
        mace_check(input_data_type in InputDataTypeStrs,
                   ModuleName.YAML_CONFIG,
                   "'input_data_types' must be in "
                   + str(InputDataTypeStrs))
else:
    subgraph[YAMLKeyword.input_data_types] = []
validation_threshold = subgraph.get( validation_threshold = subgraph.get(
YAMLKeyword.validation_threshold, {}) YAMLKeyword.validation_threshold, {})
if not isinstance(validation_threshold, dict): if not isinstance(validation_threshold, dict):
...@@ -1025,7 +1048,8 @@ def tuning(library_name, model_name, model_config, ...@@ -1025,7 +1048,8 @@ def tuning(library_name, model_name, model_config,
subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data], subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
sh_commands.tuning_run( sh_commands.tuning_run(
abi=target_abi, abi=target_abi,
...@@ -1170,7 +1194,8 @@ def run_specific_target(flags, configs, target_abi, ...@@ -1170,7 +1194,8 @@ def run_specific_target(flags, configs, target_abi,
subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data], subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
runtime_list = [] runtime_list = []
if target_abi == ABIType.host: if target_abi == ABIType.host:
...@@ -1236,6 +1261,7 @@ def run_specific_target(flags, configs, target_abi, ...@@ -1236,6 +1261,7 @@ def run_specific_target(flags, configs, target_abi,
output_shapes=subgraphs[0][YAMLKeyword.output_shapes], output_shapes=subgraphs[0][YAMLKeyword.output_shapes],
model_output_dir=model_output_dir, model_output_dir=model_output_dir,
phone_data_dir=PHONE_DATA_DIR, phone_data_dir=PHONE_DATA_DIR,
input_data_types=subgraphs[0][YAMLKeyword.input_data_types], # noqa
caffe_env=flags.caffe_env, caffe_env=flags.caffe_env,
validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type]) # noqa validation_threshold=subgraphs[0][YAMLKeyword.validation_threshold][device_type]) # noqa
if flags.report and flags.round > 0: if flags.report and flags.round > 0:
...@@ -1478,7 +1504,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num): ...@@ -1478,7 +1504,8 @@ def bm_specific_target(flags, configs, target_abi, target_soc, serial_num):
subgraphs[0][YAMLKeyword.input_tensors], subgraphs[0][YAMLKeyword.input_tensors],
subgraphs[0][YAMLKeyword.input_shapes], subgraphs[0][YAMLKeyword.input_shapes],
subgraphs[0][YAMLKeyword.validation_inputs_data], subgraphs[0][YAMLKeyword.validation_inputs_data],
input_ranges=subgraphs[0][YAMLKeyword.input_ranges]) input_ranges=subgraphs[0][YAMLKeyword.input_ranges],
input_data_types=subgraphs[0][YAMLKeyword.input_data_types])
runtime_list = [] runtime_list = []
if target_abi == ABIType.host: if target_abi == ABIType.host:
runtime_list.extend([RuntimeType.cpu]) runtime_list.extend([RuntimeType.cpu])
......
...@@ -27,30 +27,37 @@ import common ...@@ -27,30 +27,37 @@ import common
# --input_ranges -1,1 # --input_ranges -1,1
def generate_data(name, shape, input_file, tensor_range): def generate_data(name, shape, input_file, tensor_range, input_data_type):
np.random.seed() np.random.seed()
data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \ data = np.random.random(shape) * (tensor_range[1] - tensor_range[0]) \
+ tensor_range[0] + tensor_range[0]
input_file_name = common.formatted_file_name(input_file, name) input_file_name = common.formatted_file_name(input_file, name)
print 'Generate input file: ', input_file_name print 'Generate input file: ', input_file_name
data.astype(np.float32).tofile(input_file_name) if input_data_type == 'float32':
np_data_type = np.float32
elif input_data_type == 'int32':
np_data_type = np.int32
data.astype(np_data_type).tofile(input_file_name)
def generate_input_data(input_file, input_node, input_shape, input_ranges): def generate_input_data(input_file, input_node, input_shape, input_ranges,
input_data_type):
input_names = [name for name in input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shapes = [shape for shape in input_shape.split(':')] input_shapes = [shape for shape in input_shape.split(':')]
if input_ranges: if input_ranges:
input_ranges = [r for r in input_ranges.split(':')] input_ranges = [r for r in input_ranges.split(':')]
else: else:
input_ranges = None input_ranges = [[-1, 1]] * len(input_names)
assert len(input_names) == len(input_shapes) if input_data_type:
input_data_types = [data_type
for data_type in input_data_type.split(',')]
else:
input_data_types = ['float32'] * len(input_names)
assert len(input_names) == len(input_shapes) == len(input_ranges) == len(input_data_types) # noqa
for i in range(len(input_names)): for i in range(len(input_names)):
shape = [int(x) for x in input_shapes[i].split(',')] shape = [int(x) for x in input_shapes[i].split(',')]
if input_ranges: generate_data(input_names[i], shape, input_file, input_ranges[i],
input_range = [float(x) for x in input_ranges[i].split(',')] input_data_types[i])
else:
input_range = [-1, 1]
generate_data(input_names[i], shape, input_file, input_range)
print "Generate input file done." print "Generate input file done."
...@@ -66,6 +73,8 @@ def parse_args(): ...@@ -66,6 +73,8 @@ def parse_args():
"--input_shape", type=str, default="1,64,64,3", help="input shape.") "--input_shape", type=str, default="1,64,64,3", help="input shape.")
parser.add_argument( parser.add_argument(
"--input_ranges", type=str, default="-1,1", help="input range.") "--input_ranges", type=str, default="-1,1", help="input range.")
# Comma-separated element type for each input; generate_input_data falls back
# to float32 for every input when this is left empty.
parser.add_argument(
    "--input_data_type", type=str, default="", help="input data type.")
return parser.parse_known_args() return parser.parse_known_args()
...@@ -73,4 +82,4 @@ def parse_args(): ...@@ -73,4 +82,4 @@ def parse_args():
if __name__ == '__main__': if __name__ == '__main__':
FLAGS, unparsed = parse_args() FLAGS, unparsed = parse_args()
generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape, generate_input_data(FLAGS.input_file, FLAGS.input_node, FLAGS.input_shape,
FLAGS.input_ranges) FLAGS.input_ranges, FLAGS.input_data_type)
...@@ -536,6 +536,7 @@ def gen_random_input(model_output_dir, ...@@ -536,6 +536,7 @@ def gen_random_input(model_output_dir,
input_shapes, input_shapes,
input_files, input_files,
input_ranges, input_ranges,
input_data_types,
input_file_name="model_input"): input_file_name="model_input"):
for input_name in input_nodes: for input_name in input_nodes:
formatted_name = common.formatted_file_name( formatted_name = common.formatted_file_name(
...@@ -545,10 +546,12 @@ def gen_random_input(model_output_dir, ...@@ -545,10 +546,12 @@ def gen_random_input(model_output_dir,
input_nodes_str = ",".join(input_nodes) input_nodes_str = ",".join(input_nodes)
input_shapes_str = ":".join(input_shapes) input_shapes_str = ":".join(input_shapes)
input_ranges_str = ":".join(input_ranges) input_ranges_str = ":".join(input_ranges)
input_data_types_str = ",".join(input_data_types)
generate_input_data("%s/%s" % (model_output_dir, input_file_name), generate_input_data("%s/%s" % (model_output_dir, input_file_name),
input_nodes_str, input_nodes_str,
input_shapes_str, input_shapes_str,
input_ranges_str) input_ranges_str,
input_data_types_str)
input_file_list = [] input_file_list = []
if isinstance(input_files, list): if isinstance(input_files, list):
...@@ -800,6 +803,7 @@ def validate_model(abi, ...@@ -800,6 +803,7 @@ def validate_model(abi,
output_shapes, output_shapes,
model_output_dir, model_output_dir,
phone_data_dir, phone_data_dir,
input_data_types,
caffe_env, caffe_env,
input_file_name="model_input", input_file_name="model_input",
output_file_name="model_out", output_file_name="model_out",
...@@ -821,7 +825,7 @@ def validate_model(abi, ...@@ -821,7 +825,7 @@ def validate_model(abi,
"%s/%s" % (model_output_dir, output_file_name), device_type, "%s/%s" % (model_output_dir, output_file_name), device_type,
":".join(input_shapes), ":".join(output_shapes), ":".join(input_shapes), ":".join(output_shapes),
",".join(input_nodes), ",".join(output_nodes), ",".join(input_nodes), ",".join(output_nodes),
validation_threshold) validation_threshold, ",".join(input_data_types))
elif platform == "caffe": elif platform == "caffe":
image_name = "mace-caffe:latest" image_name = "mace-caffe:latest"
container_name = "mace_caffe_validator" container_name = "mace_caffe_validator"
......
...@@ -40,10 +40,12 @@ import common ...@@ -40,10 +40,12 @@ import common
VALIDATION_MODULE = 'VALIDATION' VALIDATION_MODULE = 'VALIDATION'
def load_data(file): def load_data(file, data_type='float32'):
if os.path.isfile(file): if os.path.isfile(file):
if data_type == 'float32':
return np.fromfile(file=file, dtype=np.float32) return np.fromfile(file=file, dtype=np.float32)
else: elif data_type == 'int32':
return np.fromfile(file=file, dtype=np.int32)
return np.empty([0]) return np.empty([0])
...@@ -78,7 +80,7 @@ def normalize_tf_tensor_name(name): ...@@ -78,7 +80,7 @@ def normalize_tf_tensor_name(name):
def validate_tf_model(platform, device_type, model_file, input_file, def validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
output_names, validation_threshold): output_names, validation_threshold, input_data_types):
import tensorflow as tf import tensorflow as tf
if not os.path.isfile(model_file): if not os.path.isfile(model_file):
common.MaceLogger.error( common.MaceLogger.error(
...@@ -98,7 +100,8 @@ def validate_tf_model(platform, device_type, model_file, input_file, ...@@ -98,7 +100,8 @@ def validate_tf_model(platform, device_type, model_file, input_file,
input_dict = {} input_dict = {}
for i in range(len(input_names)): for i in range(len(input_names)):
input_value = load_data( input_value = load_data(
common.formatted_file_name(input_file, input_names[i])) common.formatted_file_name(input_file, input_names[i]),
input_data_types[i])
input_value = input_value.reshape(input_shapes[i]) input_value = input_value.reshape(input_shapes[i])
input_node = graph.get_tensor_by_name( input_node = graph.get_tensor_by_name(
normalize_tf_tensor_name(input_names[i])) normalize_tf_tensor_name(input_names[i]))
...@@ -168,18 +171,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file, ...@@ -168,18 +171,23 @@ def validate_caffe_model(platform, device_type, model_file, input_file,
def validate(platform, model_file, weight_file, input_file, mace_out_file, def validate(platform, model_file, weight_file, input_file, mace_out_file,
device_type, input_shape, output_shape, input_node, output_node, device_type, input_shape, output_shape, input_node, output_node,
validation_threshold): validation_threshold, input_data_type):
input_names = [name for name in input_node.split(',')] input_names = [name for name in input_node.split(',')]
input_shape_strs = [shape for shape in input_shape.split(':')] input_shape_strs = [shape for shape in input_shape.split(':')]
input_shapes = [[int(x) for x in shape.split(',')] input_shapes = [[int(x) for x in shape.split(',')]
for shape in input_shape_strs] for shape in input_shape_strs]
if input_data_type:
input_data_types = [data_type
for data_type in input_data_type.split(',')]
else:
input_data_types = ['float32'] * len(input_names)
output_names = [name for name in output_node.split(',')] output_names = [name for name in output_node.split(',')]
assert len(input_names) == len(input_shapes) assert len(input_names) == len(input_shapes)
if platform == 'tensorflow': if platform == 'tensorflow':
validate_tf_model(platform, device_type, model_file, input_file, validate_tf_model(platform, device_type, model_file, input_file,
mace_out_file, input_names, input_shapes, mace_out_file, input_names, input_shapes,
output_names, validation_threshold) output_names, validation_threshold, input_data_types)
elif platform == 'caffe': elif platform == 'caffe':
output_shape_strs = [shape for shape in output_shape.split(':')] output_shape_strs = [shape for shape in output_shape.split(':')]
output_shapes = [[int(x) for x in shape.split(',')] output_shapes = [[int(x) for x in shape.split(',')]
...@@ -220,6 +228,11 @@ def parse_args(): ...@@ -220,6 +228,11 @@ def parse_args():
"--output_shape", type=str, default="1,64,64,2", help="output shape.") "--output_shape", type=str, default="1,64,64,2", help="output shape.")
parser.add_argument( parser.add_argument(
"--input_node", type=str, default="input_node", help="input node") "--input_node", type=str, default="input_node", help="input node")
parser.add_argument(
"--input_data_type",
type=str,
default="",
help="input data type")
parser.add_argument( parser.add_argument(
"--output_node", type=str, default="output_node", help="output node") "--output_node", type=str, default="output_node", help="output node")
parser.add_argument( parser.add_argument(
...@@ -241,4 +254,5 @@ if __name__ == '__main__': ...@@ -241,4 +254,5 @@ if __name__ == '__main__':
FLAGS.output_shape, FLAGS.output_shape,
FLAGS.input_node, FLAGS.input_node,
FLAGS.output_node, FLAGS.output_node,
FLAGS.validation_threshold) FLAGS.validation_threshold,
FLAGS.input_data_type)
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册