diff --git a/mace/core/operator.cc b/mace/core/operator.cc index 10e74a8e6a5cf6449d98fdad573c510dc5603652..3b6b789dc7e8cfe2baeb3879d0a64a609baca005 100644 --- a/mace/core/operator.cc +++ b/mace/core/operator.cc @@ -74,6 +74,7 @@ extern void Register_BufferToImage(OperatorRegistry *op_registry); extern void Register_ChannelShuffle(OperatorRegistry *op_registry); extern void Register_Concat(OperatorRegistry *op_registry); extern void Register_Conv2D(OperatorRegistry *op_registry); +extern void Register_DepthToSpace(OperatorRegistry *op_registry); extern void Register_DepthwiseConv2d(OperatorRegistry *op_registry); extern void Register_Eltwise(OperatorRegistry *op_registry); extern void Register_FoldedBatchNorm(OperatorRegistry *op_registry); @@ -109,6 +110,7 @@ OperatorRegistry::OperatorRegistry() { ops::Register_Conv2D(this); ops::Register_DepthwiseConv2d(this); ops::Register_Eltwise(this); + ops::Register_DepthToSpace(this); ops::Register_FoldedBatchNorm(this); ops::Register_FullyConnected(this); ops::Register_FusedConv2D(this); diff --git a/mace/ops/depth_to_space.cc b/mace/ops/depth_to_space.cc index e34d587158b25f1e1913d349d5858f8c2226fd66..7a71e507987f4879f1213487ff112f2e7406888d 100644 --- a/mace/ops/depth_to_space.cc +++ b/mace/ops/depth_to_space.cc @@ -8,11 +8,18 @@ namespace mace { namespace ops { void Register_DepthToSpace(OperatorRegistry *op_registry) { + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") + .Device(DeviceType::CPU) + .TypeConstraint("T") + .Build(), + DepthToSpaceOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") .Device(DeviceType::OPENCL) .TypeConstraint("T") .Build(), DepthToSpaceOp); + REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace") .Device(DeviceType::OPENCL) .TypeConstraint("T") diff --git a/mace/ops/depth_to_space.h b/mace/ops/depth_to_space.h index e05152a81a38a4a088f47db4116474ed73a53a1b..fe0aee92e9ae491f4e1c4049483aef330f3fe469 100644 --- a/mace/ops/depth_to_space.h +++ b/mace/ops/depth_to_space.h @@ -20,7 +20,7 @@ class DepthToSpaceOp : public Operator { DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws) : Operator(op_def, ws), functor_(OperatorBase::GetRepeatedArgument("crops", {0, 0, 0, 0}), - OperatorBase::GetRepeatedArgument("block_shape", {1, 1}), + OperatorBase::GetSingleArgument("block_size", 1), true) {} bool Run(StatsFuture *future) override {