static_prim_api.cc 5.2 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <string.h>
J
Jiabin Yang 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28
#include <cmath>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"

#include "paddle/fluid/framework/convert_utils.h"
29 30
#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
J
Jiabin Yang 已提交
31 32 33
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/api/include/tensor.h"
34 35 36
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
J
Jiabin Yang 已提交
37 38 39 40
namespace paddle {
namespace prim {

template <>
41
Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
42 43
  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
  framework::OpDesc* op = block->AppendOp();
44 45
  // TODO(cxxly): Fix test_resnet_prim_cinn error when SetType("reshape2")
  op->SetType("reshape");
46 47
  op->SetInput("X",
               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
48 49
  // Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
50 51
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
52
  op->SetAttr("shape", unsafe_vector_cast<int64_t, int>(shape.GetData()));
53 54
  op->CheckAttrs();
  op->InferVarType(block);
X
xiaoguoguo626807 已提交
55
  op->InferShape(*block);
56 57 58
  return out;
}

59
template <>
X
xiaoguoguo626807 已提交
60 61 62 63
Tensor full<DescTensor>(const IntArray& shape,
                        const Scalar& value,
                        DataType dtype,
                        const Place& place) {
64 65 66 67 68 69
  // Grad infershape
  Tensor out = empty<DescTensor>({}, dtype, place);
  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
  framework::OpDesc* op = block->AppendOp();
  op->SetType("fill_constant");
  op->SetAttr("shape", shape.GetData());
70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111
  switch (dtype) {
    case phi::DataType::FLOAT16:
      op->SetAttr("str_value", std::to_string(value.to<float>()));
      break;
    case phi::DataType::FLOAT32:
      op->SetAttr("value", value.to<float>());
      break;
    case phi::DataType::FLOAT64:
      op->SetAttr("str_value", std::to_string(value.to<double>()));
      break;
    case phi::DataType::BOOL:
      op->SetAttr("str_value", std::to_string(value.to<bool>()));
      break;
    case phi::DataType::INT8:
      op->SetAttr("str_value", std::to_string(value.to<int8_t>()));
      break;
    case phi::DataType::UINT8:
      op->SetAttr("str_value", std::to_string(value.to<uint8_t>()));
      break;
    case phi::DataType::INT16:
      op->SetAttr("str_value", std::to_string(value.to<int16_t>()));
      break;
    case phi::DataType::UINT16:
      op->SetAttr("str_value", std::to_string(value.to<uint16_t>()));
      break;
    case phi::DataType::INT32:
      op->SetAttr("str_value", std::to_string(value.to<int32_t>()));
      break;
    case phi::DataType::UINT32:
      op->SetAttr("str_value", std::to_string(value.to<uint32_t>()));
      break;
    case phi::DataType::INT64:
      op->SetAttr("str_value", std::to_string(value.to<int64_t>()));
      break;
    case phi::DataType::UINT64:
      op->SetAttr("str_value", std::to_string(value.to<uint64_t>()));
      break;
    default:
      PADDLE_THROW(phi::errors::Unimplemented(
          "We support "
          "bool/float16/float32/float64/int8/int16/int32/int64/uint8/uint16/"
          "uint32/uint64 for full, but we got data type: %s",
112
          phi::DataTypeToString(dtype)));
113
  }
114

115 116 117 118 119 120 121 122
  op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
  op->CheckAttrs();
  op->InferVarType(block);
  op->InferShape(*block);
  return out;
}
123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139
// Appends a static-graph "cast" op that converts `x` to `dtype`, and returns
// a Tensor naming the op's output variable.
//
// The cast op's "in_dtype"/"out_dtype" attributes take proto VarType codes,
// not raw phi::DataType enumerators; the two enums number their members
// differently. Both are therefore translated with TransToProtoVarType, the
// same conversion full<DescTensor> applies to fill_constant's "dtype".
template <>
Tensor cast<DescTensor>(const Tensor& x, DataType dtype) {
  // Declare the output with the target dtype rather than a fixed FLOAT32
  // placeholder (InferShape below is expected to derive it from "out_dtype"
  // as well).
  Tensor out = empty<DescTensor>({}, dtype, paddle::Place());
  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
  framework::OpDesc* op = block->AppendOp();
  op->SetType("cast");
  op->SetInput("X",
               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
  op->SetAttr(
      "in_dtype",
      static_cast<int>(paddle::framework::TransToProtoVarType(x.dtype())));
  op->SetAttr("out_dtype",
              static_cast<int>(paddle::framework::TransToProtoVarType(dtype)));
  op->CheckAttrs();
  op->InferVarType(block);
  op->InferShape(*block);
  return out;
}
J
Jiabin Yang 已提交
140 141
}  // namespace prim
}  // namespace paddle