提交 9a17b40b 编写于 作者: U Unknown 提交者: liutuo

Add ops/depth_to_space files

上级 69baafe9
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#include "mace/ops/depth_to_space.h"
namespace mace {
namespace ops {
void Register_DepthToSpace(OperatorRegistry *op_registry) {
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
.Device(DeviceType::OPENCL)
.TypeConstraint<float>("T")
.Build(),
DepthToSpaceOp<DeviceType::OPENCL, float>);
REGISTER_OPERATOR(op_registry, OpKeyBuilder("DepthToSpace")
.Device(DeviceType::OPENCL)
.TypeConstraint<half>("T")
.Build(),
DepthToSpaceOp<DeviceType::OPENCL, half>);
}
} // namespace ops
} // namespace mace
//
// Copyright (c) 2017 XiaoMi All rights reserved.
//
#ifndef MACE_OPS_DEPTH_TO_SPACE_H_
#define MACE_OPS_DEPTH_TO_SPACE_H_
#include <memory>
#include <vector>
#include "mace/core/operator.h"
#include "mace/kernels/depth_to_space.h"
namespace mace {
namespace ops {
template <DeviceType D, typename T>
class DepthToSpaceOp : public Operator<D, T> {
public:
DepthToSpaceOp(const OperatorDef &op_def, Workspace *ws)
: Operator<D, T>(op_def, ws),
functor_(OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0}),
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}),
true) {}
bool Run(StatsFuture *future) override {
const Tensor *batch_tensor = this->Input(INPUT);
Tensor *space_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0);
CalculateOutputShape(batch_tensor, space_tensor, output_shape.data());
functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor),
future);
return true;
}
private:
inline void CalculateOutputShape(const Tensor *input_tensor,
Tensor *output,
index_t *output_shape) {
auto crops = OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0});
auto block_shape =
OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
MACE_CHECK(block_shape.size() == 2, "Block's shape should be 1D");
MACE_CHECK(crops.size() == 4, "Crops' shape should be 2D");
const index_t block_dims = block_shape.size();
index_t block_shape_product = 1;
for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
MACE_CHECK(block_shape[block_dim] > 1,
"block_shape's value should be great to 1");
const index_t block_shape_value = block_shape[block_dim];
const index_t cropped_input_size =
input_tensor->dim(block_dim + 1) * block_shape_value -
crops[block_dim * 2] - crops[block_dim * 2 + 1];
MACE_CHECK(cropped_input_size >= 0, "cropped size must be non-negative");
block_shape_product *= block_shape_value;
output_shape[block_dim + 1] = cropped_input_size;
}
output_shape[0] = input_tensor->dim(0) / block_shape_product;
output_shape[3] = input_tensor->dim(3);
}
private:
kernels::DepthToSpaceOpFunctor<D, T> functor_;
protected:
OP_INPUT_TAGS(INPUT);
OP_OUTPUT_TAGS(OUTPUT);
};
} // namespace ops
} // namespace mace
#endif // MACE_OPS_DEPTH_TO_SPACE_H_
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册