From 40f6633a935cefaaf3ac92d13f77ccbe18d50dbe Mon Sep 17 00:00:00 2001 From: Unknown Date: Mon, 19 Mar 2018 19:18:25 +0800 Subject: [PATCH] add Register depth_to_space in operator.cc --- mace/core/operator.cc | 2 ++ mace/ops/depth_to_space.cc | 7 +++++++ mace/ops/depth_to_space.h | 2 +- 3 files changed, 10 insertions(+), 1 deletion(-) diff --git a/mace/core/operator.cc b/mace/core/operator.cc index ad3c8e58..710806f8 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -73,6 +73,7 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry); +extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); extern void Register_FusedConv2D(OperatorRegistry *op_registry); @@ -103,6 +104,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_Concat(this); ops::Register_Conv2D(this); ops::Register_DepthwiseConv2d(this); + ops::Register_DepthToSpace(this); ops::Register_FoldedBatchNorm(this); ops::Register_FusedConv2D(this); ops::Register_GlobalAvgPooling(this); diff --git a/mace/ops/depth_to_space.cc b/mace/ops/depth_to_space.cc index e34d5871..7a71e507 100644 --- a/mace/ops/depth_to_space.cc +++ b/mace/ops/depth_to_space.cc @@ -8,11 +8,18 @@ namespace mace { namespace ops { void Register_DepthToSpace(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + DepthToSpaceOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") .Device(DeviceType::OPENCL) .TypeConstraint("T") .Build(), DepthToSpaceOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/depth_to_space.h b/mace/ops/depth_to_space.h index e05152a8..fe0aee92 100644 --- a/mace/ops/depth_to_space.h +++ b/mace/ops/depth_to_space.h @@ -20,7 +20,7 @@ class DepthToSpaceOp : public Operator { DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws) : Operator(op_def, ws), functor_(OperatorBase::GetRepeatedArgument("crops", {0, 0, 0, 0}), - OperatorBase::GetRepeatedArgument("block_shape", {1, 1}), + OperatorBase::GetSingleArgument("block_size", 1), true) {} bool Run(StatsFuture *future) override { -- GitLab