提交 12139efe 编写于 作者: Y yejianwu

update memory optimize for expand_dims, add reverse benchmark

上级 f3de19de
......@@ -19,6 +19,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/kernels/kernel.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
......@@ -27,18 +28,13 @@
namespace mace {
namespace kernels {
struct ExpandDimsBase {
explicit ExpandDimsBase(int axis) : axis_(axis) {}
int axis_;
};
template <DeviceType D, typename T>
struct ExpandDimsFunctor;
template <typename T>
struct ExpandDimsFunctor<DeviceType::CPU, T> : ExpandDimsBase {
explicit ExpandDimsFunctor(int axis) : ExpandDimsBase(axis) {}
struct ExpandDimsFunctor<DeviceType::CPU, T> : OpKernel {
explicit ExpandDimsFunctor(OpKernelContext *context, int axis)
: OpKernel(context), axis_(axis) {}
MaceStatus operator()(const Tensor *input,
Tensor *output,
......@@ -64,6 +60,8 @@ struct ExpandDimsFunctor<DeviceType::CPU, T> : ExpandDimsBase {
return MACE_SUCCESS;
}
int axis_;
};
} // namespace kernels
......
......@@ -20,6 +20,7 @@
#include "mace/core/future.h"
#include "mace/core/tensor.h"
#include "mace/kernels/kernel.h"
#ifdef MACE_ENABLE_OPENCL
#include "mace/core/runtime/opencl/cl2_header.h"
......@@ -32,7 +33,8 @@ template <DeviceType D, typename T>
struct ReverseFunctor;
template <typename T>
struct ReverseFunctor<DeviceType::CPU, T> {
struct ReverseFunctor<DeviceType::CPU, T> : OpKernel {
explicit ReverseFunctor(OpKernelContext *context) : OpKernel(context) {}
MaceStatus operator()(const Tensor *input,
const Tensor *axis,
Tensor *output,
......
......@@ -28,6 +28,11 @@ void Register_ExpandDims(OperatorRegistryBase *op_registry) {
.TypeConstraint<int32_t>("T")
.Build(),
ExpandDimsOp<DeviceType::CPU, int32_t>);
MACE_REGISTER_OPERATOR(op_registry, OpKeyBuilder("ExpandDims")
.Device(DeviceType::CPU)
.TypeConstraint<uint8_t>("T")
.Build(),
ExpandDimsOp<DeviceType::CPU, uint8_t>);
}
} // namespace ops
......
......@@ -26,9 +26,9 @@ namespace ops {
template <DeviceType D, typename T>
class ExpandDimsOp : public Operator<D, T> {
public:
ExpandDimsOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetOptionalArg<int>("axis", 0)) {}
ExpandDimsOp(const OperatorDef &op_def, OpKernelContext *context)
: Operator<D, T>(op_def, context),
functor_(context, OperatorBase::GetOptionalArg<int>("axis", 0)) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
......@@ -26,8 +26,8 @@ namespace ops {
template <DeviceType D, class T>
class ReverseOp : public Operator<D, T> {
public:
ReverseOp(const OperatorDef &operator_def, Workspace *ws)
: Operator<D, T>(operator_def, ws) {}
ReverseOp(const OperatorDef &operator_def, OpKernelContext *context)
: Operator<D, T>(operator_def, context), functor_(context) {}
MaceStatus Run(StatsFuture *future) override {
const Tensor *input = this->Input(INPUT);
......
// Copyright 2018 Xiaomi, Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "mace/core/operator.h"
#include "mace/core/runtime/opencl/opencl_runtime.h"
#include "mace/core/testing/test_benchmark.h"
#include "mace/ops/ops_test_util.h"
namespace mace {
namespace ops {
namespace test {
namespace {
template <DeviceType D, typename T>
void Reverse(int iters, int batch, int channels, int height, int width) {
mace::testing::StopTiming();
OpsTestNet net;
net.AddRandomInput<D, T>("Input", {batch, channels, height, width});
net.AddRandomInput<D, int32_t>("Axis", {1});
OpDefBuilder("Reverse", "ReverseOpTest")
.Input("Input")
.Input("Axis")
.Output("Output")
.Finalize(net.NewOperatorDef());
// Warm-up
for (int i = 0; i < 5; ++i) {
net.RunOp(D);
}
net.Sync();
mace::testing::StartTiming();
while (iters--) {
net.RunOp(D);
}
net.Sync();
}
} // namespace
#define MACE_BM_REVERSE_MACRO(N, C, H, W, TYPE, DEVICE) \
static void MACE_BM_REVERSE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE( \
int iters) { \
const int64_t macc = \
static_cast<int64_t>(iters) * N * C * H * W; \
const int64_t tot = static_cast<int64_t>(iters) * N * C * H * W; \
mace::testing::MaccProcessed(macc); \
mace::testing::BytesProcessed(tot *(sizeof(TYPE))); \
Reverse<DEVICE, TYPE>(iters, N, C, H, W); \
} \
MACE_BENCHMARK(MACE_BM_REVERSE_##N##_##C##_##H##_##W##_##TYPE##_##DEVICE)
#define MACE_BM_REVERSE(N, C, H, W) \
MACE_BM_REVERSE_MACRO(N, C, H, W, float, CPU);
MACE_BM_REVERSE(1, 1, 99, 256);
MACE_BM_REVERSE(1, 30, 99, 256);
MACE_BM_REVERSE(1, 50, 99, 256);
} // namespace test
} // namespace ops
} // namespace mace
......@@ -124,7 +124,7 @@ class MemoryOptimizer(object):
@staticmethod
def is_memory_reuse_op(op):
return op.type == 'Reshape' or op.type == 'Identity' \
or op.type == 'Squeeze'
or op.type == 'Squeeze' or op.type == 'ExpandDims'
def optimize(self):
for op in self.net_def.op:
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册