static_prim_api.cc 3.7 KB
Newer Older
J
Jiabin Yang 已提交
1 2 3 4 5 6 7 8 9 10 11 12 13 14
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

15
#include <string.h>
J
Jiabin Yang 已提交
16 17 18 19 20 21 22 23 24 25 26 27 28
#include <cstdio>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <unordered_set>
#include <vector>

#include "paddle/fluid/framework/block_desc.h"
#include "paddle/fluid/framework/op_desc.h"
#include "paddle/fluid/framework/op_proto_maker.h"
#include "paddle/fluid/framework/operator.h"
#include "paddle/fluid/framework/program_desc.h"

#include "paddle/fluid/framework/convert_utils.h"
29 30
#include "paddle/fluid/prim/api/manual_prim/prim_manual_api.h"
#include "paddle/fluid/prim/api/manual_prim/utils/utils.h"
J
Jiabin Yang 已提交
31 32 33
#include "paddle/fluid/prim/utils/static/composite_grad_desc_maker.h"
#include "paddle/fluid/prim/utils/static/desc_tensor.h"
#include "paddle/phi/api/include/tensor.h"
34 35 36
#include "paddle/phi/common/data_type.h"
#include "paddle/phi/common/float16.h"
#include "paddle/phi/core/enforce.h"
J
Jiabin Yang 已提交
37 38 39 40
namespace paddle {
namespace prim {

template <>
41
Tensor reshape<DescTensor>(const Tensor& x, const IntArray& shape) {
42 43
  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
  framework::OpDesc* op = block->AppendOp();
44 45
  // TODO(cxxly): Fix test_resnet_prim_cinn error when SetType("reshape2")
  op->SetType("reshape");
46 47
  op->SetInput("X",
               {std::static_pointer_cast<prim::DescTensor>(x.impl())->Name()});
48 49
  // Tensor out = empty<DescTensor>({}, x.dtype(), paddle::Place());
  auto out = empty<DescTensor>({}, x.dtype(), paddle::Place());
50 51
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
52
  op->SetAttr("shape", unsafe_vector_cast<int64_t, int>(shape.GetData()));
53 54
  op->CheckAttrs();
  op->InferVarType(block);
X
xiaoguoguo626807 已提交
55
  op->InferShape(*block);
56 57 58
  return out;
}

59
template <>
X
xiaoguoguo626807 已提交
60 61 62 63
Tensor full<DescTensor>(const IntArray& shape,
                        const Scalar& value,
                        DataType dtype,
                        const Place& place) {
64 65 66 67 68 69 70
  // Grad infershape
  Tensor out = empty<DescTensor>({}, dtype, place);
  framework::BlockDesc* block = StaticCompositeContext::Instance().GetBlock();
  framework::OpDesc* op = block->AppendOp();
  op->SetType("fill_constant");
  op->SetAttr("shape", shape.GetData());
  PADDLE_ENFORCE_EQ(
X
xiaoguoguo626807 已提交
71 72
      ((dtype == DataType::FLOAT32) || (dtype == DataType::FLOAT64) ||
       (dtype == DataType::FLOAT16)),
73 74 75 76
      true,
      phi::errors::InvalidArgument(
          "We only support float32/float16 for full, but we got data type: %s",
          phi::DataTypeToString(dtype)));
77 78 79 80 81 82 83 84 85 86
  if (dtype == phi::DataType::FLOAT32) {
    op->SetAttr("value", value.to<float>());
  } else if (dtype == phi::DataType::FLOAT64) {
    op->SetAttr("str_value", std::to_string(value.to<double>()));
  } else if (dtype == phi::DataType::FLOAT16) {
    op->SetAttr("str_value", std::to_string(value.to<float>()));
  } else {
    PADDLE_THROW(phi::errors::Unimplemented(
        "We only support float64/float32/float16 for full"));
  }
87 88 89 90 91 92 93 94
  op->SetAttr("dtype", paddle::framework::TransToProtoVarType(dtype));
  op->SetOutput(
      "Out", {std::static_pointer_cast<prim::DescTensor>(out.impl())->Name()});
  op->CheckAttrs();
  op->InferVarType(block);
  op->InferShape(*block);
  return out;
}
95

J
Jiabin Yang 已提交
96 97
}  // namespace prim
}  // namespace paddle