提交 d1c3fef4 编写于 作者: L liuqi

Rename space_to_batch helper function name for readability.

上级 5030c087
......@@ -24,18 +24,18 @@ class BatchToSpaceNDOp : public Operator<D, T> {
bool Run(StatsFuture *future) override {
const Tensor *batch_tensor = this->Input(INPUT);
Tensor *space_tensor= this->Output(OUTPUT);
Tensor *space_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0);
BatchToSpaceHelper(batch_tensor, space_tensor, output_shape);
CalculateOutputShape(batch_tensor, space_tensor, output_shape.data());
functor_(space_tensor, output_shape, const_cast<Tensor *>(batch_tensor), future);
return true;
}
private:
inline void BatchToSpaceHelper(const Tensor *input_tensor,
inline void CalculateOutputShape(const Tensor *input_tensor,
Tensor *output,
std::vector<index_t> &output_shape) {
index_t *output_shape) {
auto crops = OperatorBase::GetRepeatedArgument<int>("crops", {0, 0, 0, 0});
auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
......
......@@ -12,7 +12,6 @@
namespace mace {
template<DeviceType D, typename T>
class SpaceToBatchNDOp : public Operator<D, T> {
public:
......@@ -24,20 +23,20 @@ class SpaceToBatchNDOp : public Operator<D, T> {
false) {}
bool Run(StatsFuture *future) override {
const Tensor *space_tensor= this->Input(INPUT);
Tensor *batch_tensor= this->Output(OUTPUT);
const Tensor *space_tensor = this->Input(INPUT);
Tensor *batch_tensor = this->Output(OUTPUT);
std::vector<index_t> output_shape(4, 0);
SpaceToBatchHelper(space_tensor, batch_tensor, output_shape);
CalculateOutputShape(space_tensor, batch_tensor, output_shape.data());
functor_(const_cast<Tensor *>(space_tensor), output_shape, batch_tensor, future);
return true;
}
private:
inline void SpaceToBatchHelper(const Tensor *input_tensor,
inline void CalculateOutputShape(const Tensor *input_tensor,
Tensor *output,
std::vector<index_t> &output_shape) {
index_t *output_shape) {
auto paddings = OperatorBase::GetRepeatedArgument<int>("paddings", {0, 0, 0, 0});
auto block_shape = OperatorBase::GetRepeatedArgument<int>("block_shape", {1, 1});
MACE_CHECK(input_tensor->dim_size() == 4, "Input's shape should be 4D");
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册