提交 d1c3fef4 编写于 作者: L liuqi

Rename space_to_batch helper function name for readability.

上级 5030c087
...@@ -24,18 +24,18 @@ class BatchToSpaceNDOp : public Operator<D, T> { ...@@ -24,18 +24,18 @@ class BatchToSpaceNDOp : public Operator<D, T> {
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *batch_tensor = this->Input(INPUT); const Tensor *batch_tensor = this->Input(INPUT);
Tensor *space_tensor= this->Output(OUTPUT); Tensor *space_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0); std::vector<index_t> output_shape(4, 0);
BatchToSpaceHelper(batch_tensor, space_tensor, output_shape); CalculateOutputShape(batch_tensor, space_tensor, output_shape.data());
functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor), future); functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor), future);
return true; return true;
} }
private: private:
inline void BatchToSpaceHelper(const Tensor *input_tensor, inline void CalculateOutputShape(const Tensor *input_tensor,
Tensor *output, Tensor *output,
std::vector<index_t> &output_shape) { index_t *output_shape) {
auto crops = OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0}); auto crops = OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0});
auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}); auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D"); MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
......
...@@ -12,7 +12,6 @@ ...@@ -12,7 +12,6 @@
namespace mace { namespace mace {
template<DeviceType D, typename T> template<DeviceType D, typename T>
class SpaceToBatchNDOp : public Operator<D, T> { class SpaceToBatchNDOp : public Operator<D, T> {
public: public:
...@@ -24,20 +23,20 @@ class SpaceToBatchNDOp : public Operator<D, T> { ...@@ -24,20 +23,20 @@ class SpaceToBatchNDOp : public Operator<D, T> {
false) {} false) {}
bool Run(StatsFuture *future) override { bool Run(StatsFuture *future) override {
const Tensor *space_tensor= this->Input(INPUT); const Tensor *space_tensor = this->Input(INPUT);
Tensor *batch_tensor= this->Output(OUTPUT); Tensor *batch_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0); std::vector<index_t> output_shape(4, 0);
SpaceToBatchHelper(space_tensor, batch_tensor, output_shape); CalculateOutputShape(space_tensor, batch_tensor, output_shape.data());
functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor, future); functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor, future);
return true; return true;
} }
private: private:
inline void SpaceToBatchHelper(const Tensor *input_tensor, inline void CalculateOutputShape(const Tensor *input_tensor,
Tensor *output, Tensor *output,
std::vector<index_t> &output_shape) { index_t *output_shape) {
auto paddings = OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0}); auto paddings = OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0});
auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1}); auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D"); MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册