pad2d_op.cc 2.5 KB
Newer Older
Y
Yan Chunwei 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "lite/operators/pad2d_op.h"
#include "lite/core/op_lite.h"
#include "lite/core/op_registry.h"

namespace paddle {
namespace lite {
namespace operators {

// Validates pad2d parameters before shape inference.
//
// Returns false (without aborting) when X or Out is missing, or when X is not
// 4-D: InferShapeImpl indexes dims 0..3 directly, so the previous "rank > 1"
// check let rank-2/3 inputs through to out-of-range indexing.
// Aborts via CHECK on an unsupported `mode` or when `paddings` does not hold
// exactly four values (paddings[0..1] pad the height, paddings[2..3] the
// width, as used by InferShapeImpl).
bool Pad2dOpLite::CheckShape() const {
  CHECK_OR_FALSE(param_.X);
  CHECK_OR_FALSE(param_.Out);
  CHECK_OR_FALSE(param_.X->dims().size() == 4UL);
  CHECK(param_.mode == "constant" || param_.mode == "reflect" ||
        param_.mode == "edge")
      << "Invalid mode: " << param_.mode;
  CHECK_EQ(param_.paddings.size(), 4UL);
  return true;
}

33
// Infers the shape of Out from the shape of X and param_.paddings.
//
// paddings[0] + paddings[1] are added to the height and
// paddings[2] + paddings[3] to the width; batch and channel extents are
// copied unchanged. The layout follows param_.data_format: "NHWC" pads dims
// 1 (H) and 2 (W); any other value is treated as NCHW and pads dims 2 (H)
// and 3 (W), which keeps NCHW as the default layout.
// Returns false if X is not 4-D or paddings does not hold four values.
bool Pad2dOpLite::InferShapeImpl() const {
  auto x_dims = param_.X->dims();
  // Guard the fixed-index accesses below; indexing past the end of the dims
  // or paddings would be undefined behavior.
  if (x_dims.size() != 4UL || param_.paddings.size() != 4UL) {
    return false;
  }
  // Keep the arithmetic in int64_t, the element type of the dims, instead of
  // narrowing to int.
  const int64_t pad_h =
      static_cast<int64_t>(param_.paddings[0]) + param_.paddings[1];
  const int64_t pad_w =
      static_cast<int64_t>(param_.paddings[2]) + param_.paddings[3];
  if (param_.data_format == "NHWC") {
    param_.Out->Resize(lite::DDim(
        {x_dims[0], x_dims[1] + pad_h, x_dims[2] + pad_w, x_dims[3]}));
  } else {
    // NCHW
    param_.Out->Resize(lite::DDim(
        {x_dims[0], x_dims[1], x_dims[2] + pad_h, x_dims[3] + pad_w}));
  }
  return true;
}

// TODO(Superjomn) replace framework::OpDesc with a lite one.
// Binds the op's X/Out tensors from `scope` and reads its attributes.
//
// Paddings normally come from the "paddings" attribute. When the boolean
// variable-paddings attribute is present and true, they are instead read from
// the first four int elements of the "Paddings" input tensor.
// Returns false if that tensor holds fewer than four elements.
bool Pad2dOpLite::AttachImpl(const cpp::OpDesc &op_desc, lite::Scope *scope) {
  param_.X = scope->FindVar(op_desc.Input("X").front())->GetMutable<Tensor>();
  param_.Out =
      scope->FindVar(op_desc.Output("Out").front())->GetMutable<Tensor>();
  param_.mode = op_desc.GetAttr<std::string>("mode");
  param_.pad_value = op_desc.GetAttr<float>("pad_value");

  // The flag used to be probed as "variable_padding" but read as
  // "variable_paddings", so it only worked when both attributes existed.
  // Accept either spelling and always read the name that was probed;
  // "variable_paddings" wins when both are present, as before.
  bool variable_paddings = false;
  for (const char *attr_name : {"variable_paddings", "variable_padding"}) {
    if (op_desc.HasAttr(attr_name)) {
      variable_paddings = op_desc.GetAttr<bool>(attr_name);
      break;
    }
  }

  if (variable_paddings) {
    auto *paddings_tensor =
        scope->FindVar(op_desc.Input("Paddings").front())->GetMutable<Tensor>();
    // Check the element count, not the rank: a 1-D tensor of shape [4] has
    // rank 1 but holds the four required values.
    const int64_t count = paddings_tensor->dims().production();
    if (count < 4) {
      printf("Paddings size must be four: %d \n", static_cast<int>(count));
      return false;
    }
    const int *ptr = paddings_tensor->data<int>();
    param_.paddings = {ptr[0], ptr[1], ptr[2], ptr[3]};
  } else {
    param_.paddings = op_desc.GetAttr<std::vector<int>>("paddings");
  }

  // Default to "NCHW" (the layout InferShapeImpl is written for) when the
  // attribute is absent, rather than failing inside GetAttr.
  param_.data_format = op_desc.HasAttr("data_format")
                           ? op_desc.GetAttr<std::string>("data_format")
                           : std::string("NCHW");
  return true;
}

}  // namespace operators
}  // namespace lite
}  // namespace paddle

// Register Pad2dOpLite as the implementation of the `pad2d` op
// (REGISTER_LITE_OP is presumably provided by the op_registry.h include).
REGISTER_LITE_OP(pad2d, paddle::lite::operators::Pad2dOpLite);