提交 d1c3fef4 编写于 作者: L liuqi

Rename space_to_batch helper function name for readability.

上级 5030c087
......@@ -24,18 +24,18 @@ class BatchToSpaceNDOp : public Operator<D, T> {
bool Run(StatsFuture *future) override {
const Tensor *batch_tensor = this->Input(INPUT);
Tensor *space_tensor= this->Output(OUTPUT);
Tensor *space_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0);
BatchToSpaceHelper(batch_tensor, space_tensor, output_shape);
CalculateOutputShape(batch_tensor, space_tensor, output_shape.data());
functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor), future);
return true;
}
private:
inline void BatchToSpaceHelper(const Tensor *input_tensor,
Tensor *output,
std::vector<index_t> &output_shape) {
inline void CalculateOutputShape(const Tensor *input_tensor,
Tensor *output,
index_t *output_shape) {
auto crops = OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0});
auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
......
......@@ -12,7 +12,6 @@
namespace mace {
template<DeviceType D, typename T>
class SpaceToBatchNDOp : public Operator<D, T> {
public:
......@@ -24,20 +23,20 @@ class SpaceToBatchNDOp : public Operator<D, T> {
false) {}
bool Run(StatsFuture *future) override {
const Tensor *space_tensor= this->Input(INPUT);
Tensor *batch_tensor= this->Output(OUTPUT);
const Tensor *space_tensor = this->Input(INPUT);
Tensor *batch_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0);
SpaceToBatchHelper(space_tensor, batch_tensor, output_shape);
CalculateOutputShape(space_tensor, batch_tensor, output_shape.data());
functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor, future);
return true;
}
private:
inline void SpaceToBatchHelper(const Tensor *input_tensor,
Tensor *output,
std::vector<index_t> &output_shape) {
inline void CalculateOutputShape(const Tensor *input_tensor,
Tensor *output,
index_t *output_shape) {
auto paddings = OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0});
auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
......@@ -47,15 +46,15 @@ class SpaceToBatchNDOp : public Operator<D, T> {
const index_t block_dims = block_shape.size();
index_t block_shape_product = 1;
for (uint32_t block_dim = 0; block_dim < block_dims; ++block_dim) {
MACE_CHECK(block_shape[block_dim] > 1, "block_shape's value should be great to 1");
const index_t block_shape_value = block_shape[block_dim];
const index_t padded_input_size = input_tensor->dim(block_dim + 1)
+ paddings[block_dim * 2]
+ paddings[block_dim * 2 + 1];
MACE_CHECK(padded_input_size % block_shape_value == 0,
"padded input ", padded_input_size, " is not divisible by block_shape");
block_shape_product *= block_shape_value;
output_shape[block_dim + 1] = padded_input_size / block_shape_value;
MACE_CHECK(block_shape[block_dim] > 1, "block_shape's value should be great to 1");
const index_t block_shape_value = block_shape[block_dim];
const index_t padded_input_size = input_tensor->dim(block_dim + 1)
+ paddings[block_dim * 2]
+ paddings[block_dim * 2 + 1];
MACE_CHECK(padded_input_size % block_shape_value == 0,
"padded input ", padded_input_size, " is not divisible by block_shape");
block_shape_product *= block_shape_value;
output_shape[block_dim + 1] = padded_input_size / block_shape_value;
}
output_shape[0] = input_tensor->dim(0) * block_shape_product;
output_shape[3] = input_tensor->dim(3);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册