// attr_test_op.cc
// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include <algorithm>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <stdexcept>
#include <string>
#include <vector>

#include "paddle/extension.h"

// Element-wise copy of `x_numel` values from `x_data` into `out_data`.
//
// The loop index is int64_t to match the type of the element count. A plain
// `int` index would overflow (undefined behaviour) once a tensor holds more
// than INT_MAX elements. A non-positive `x_numel` copies nothing.
template <typename data_t>
void assign_cpu_kernel(const data_t* x_data,
                       data_t* out_data,
                       int64_t x_numel) {
  for (int64_t i = 0; i < x_numel; ++i) {
    out_data[i] = x_data[i];
  }
}

// Verifies that every forward attribute reached the kernel with the value the
// test supplies. Throws std::runtime_error naming the first attribute that
// does not match. Checks run in declaration order.
//
// Expected values:
//   bool_attr       == true
//   int_attr        == 10
//   float_attr      within 1e-6 of 3.14
//   int64_attr      == 10000000000
//   str_attr        == "StrAttr"
//   int_vec_attr    exactly 3 elements, each == 10
//   float_vec_attr  exactly 3 elements, each within 1e-6 of 3.14
//   int64_vec_attr  exactly 3 elements, each == 10000000000
//   str_vec_attr    exactly 3 elements, each == "StrAttr"
void CheckAllForwardAttrs(const bool& bool_attr,
                          const int& int_attr,
                          const float& float_attr,
                          const int64_t& int64_attr,
                          const std::string& str_attr,
                          const std::vector<int>& int_vec_attr,
                          const std::vector<float>& float_vec_attr,
                          const std::vector<int64_t>& int64_vec_attr,
                          const std::vector<std::string>& str_vec_attr) {
  constexpr int kExpectedInt = 10;
  constexpr int64_t kExpectedInt64 = 10000000000;
  constexpr double kExpectedFloat = 3.14;
  constexpr double kFloatTolerance = 1e-6;
  constexpr std::size_t kExpectedVecSize = 3;
  const char* const kExpectedStr = "StrAttr";

  // Phrased as "inside the tolerance" rather than "outside it" so that NaN,
  // which compares false against everything, is rejected instead of being
  // silently accepted.
  const auto is_expected_float = [&](double value) {
    return std::abs(value - kExpectedFloat) <= kFloatTolerance;
  };

  if (!bool_attr) {
    throw std::runtime_error("bool_attr value error.");
  }
  if (int_attr != kExpectedInt) {
    throw std::runtime_error("int_attr value error.");
  }
  if (!is_expected_float(float_attr)) {
    throw std::runtime_error("float_attr value error.");
  }
  if (int64_attr != kExpectedInt64) {
    throw std::runtime_error("int64_attr value error.");
  }
  if (str_attr != kExpectedStr) {
    throw std::runtime_error("str_attr value error.");
  }

  if (int_vec_attr.size() != kExpectedVecSize) {
    throw std::runtime_error("int_vec_attr size error.");
  }
  if (!std::all_of(int_vec_attr.begin(), int_vec_attr.end(), [&](int value) {
        return value == kExpectedInt;
      })) {
    throw std::runtime_error("int_vec_attr value error.");
  }

  if (float_vec_attr.size() != kExpectedVecSize) {
    throw std::runtime_error("float_vec_attr size error.");
  }
  if (!std::all_of(
          float_vec_attr.begin(), float_vec_attr.end(), is_expected_float)) {
    throw std::runtime_error("float_vec_attr value error.");
  }

  if (int64_vec_attr.size() != kExpectedVecSize) {
    throw std::runtime_error("int64_vec_attr size error.");
  }
  if (!std::all_of(int64_vec_attr.begin(),
                   int64_vec_attr.end(),
                   [&](int64_t value) { return value == kExpectedInt64; })) {
    throw std::runtime_error("int64_vec_attr value error.");
  }

  if (str_vec_attr.size() != kExpectedVecSize) {
    throw std::runtime_error("str_vec_attr size error.");
  }
  if (!std::all_of(str_vec_attr.begin(),
                   str_vec_attr.end(),
                   [&](const std::string& value) {
                     return value == kExpectedStr;
                   })) {
    throw std::runtime_error("str_vec_attr value error.");
  }
}

// Verifies the subset of attributes declared on the backward op. Throws
// std::runtime_error naming the first attribute that does not match.
//
// Expected values:
//   int_attr        == 10
//   float_vec_attr  exactly 3 elements, each within 1e-6 of 3.14
//   str_vec_attr    exactly 3 elements, each == "StrAttr"
void CheckAllBackwardAttrs(const int& int_attr,
                           const std::vector<float>& float_vec_attr,
                           const std::vector<std::string>& str_vec_attr) {
  constexpr int kExpectedInt = 10;
  constexpr double kExpectedFloat = 3.14;
  constexpr double kFloatTolerance = 1e-6;
  constexpr std::size_t kExpectedVecSize = 3;
  const char* const kExpectedStr = "StrAttr";

  if (int_attr != kExpectedInt) {
    throw std::runtime_error("int_attr value error.");
  }

  if (float_vec_attr.size() != kExpectedVecSize) {
    throw std::runtime_error("float_vec_attr size error.");
  }
  // "Inside the tolerance" form so that NaN elements are rejected rather than
  // silently passing the comparison.
  if (!std::all_of(float_vec_attr.begin(),
                   float_vec_attr.end(),
                   [&](float value) {
                     return std::abs(value - kExpectedFloat) <=
                            kFloatTolerance;
                   })) {
    throw std::runtime_error("float_vec_attr value error.");
  }

  if (str_vec_attr.size() != kExpectedVecSize) {
    throw std::runtime_error("str_vec_attr size error.");
  }
  if (!std::all_of(str_vec_attr.begin(),
                   str_vec_attr.end(),
                   [&](const std::string& value) {
                     return value == kExpectedStr;
                   })) {
    throw std::runtime_error("str_vec_attr value error.");
  }
}

// Forward kernel of the `attr_test` op (attributes taken by value).
//
// Copies `x` element-wise into a newly created CPU tensor of the same shape.
// It then validates every attribute with CheckAllForwardAttrs, which throws
// std::runtime_error on the first mismatch. The attributes do not influence
// the output; they are only checked.
//
// Returns a single-element vector holding the copy of `x`.
std::vector<paddle::Tensor> AttrTestForward(
    const paddle::Tensor& x,
    bool bool_attr,
    int int_attr,
    float float_attr,
    int64_t int64_attr,
    const std::string& str_attr,
    const std::vector<int>& int_vec_attr,
    const std::vector<float>& float_vec_attr,
    const std::vector<int64_t>& int64_vec_attr,
    const std::vector<std::string>& str_vec_attr) {
  auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());

  // PD_DISPATCH_FLOATING_TYPES selects the element type from x.type() and
  // makes it available to the lambda as `data_t`.
  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "assign_cpu_kernel", ([&] {
        assign_cpu_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(), x.size());
      }));

  // Check attrs value
  CheckAllForwardAttrs(bool_attr,
                       int_attr,
                       float_attr,
                       int64_attr,
                       str_attr,
                       int_vec_attr,
                       float_vec_attr,
                       int64_vec_attr,
                       str_vec_attr);

  return {out};
}

// Infer-shape function of the `attr_test` op. The single output "Out" has
// exactly the shape of input "X". The attribute parameters are accepted but
// unused.
//
// NOTE(review): unlike AttrTestForward, this signature has no int64_vec_attr
// parameter, although the op declares that attribute. Presumably the
// framework does not forward std::vector<int64_t> attributes to infer-shape
// functions. Confirm against the Paddle version in use before adding it.
std::vector<std::vector<int64_t>> AttrTestInferShape(
    const std::vector<int64_t>& x_shape,
    bool bool_attr,
    int int_attr,
    float float_attr,
    int64_t int64_attr,
    const std::string& str_attr,
    const std::vector<int>& int_vec_attr,
    const std::vector<float>& float_vec_attr,
    const std::vector<std::string>& str_vec_attr) {
  return {x_shape};
}

170 171 172 173
// The attrs of backward op must be the subset of attrs of forward op
std::vector<paddle::Tensor> AttrTestBackward(
    const paddle::Tensor& grad_out,
    int int_attr,
174 175
    const std::vector<float>& float_vec_attr,
    const std::vector<std::string>& str_vec_attr) {
176
  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, grad_out.shape());
177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200

  PD_DISPATCH_FLOATING_TYPES(grad_out.type(), "assign_cpu_kernel", ([&] {
                               assign_cpu_kernel<data_t>(
                                   grad_out.data<data_t>(),
                                   grad_x.mutable_data<data_t>(),
                                   grad_out.size());
                             }));

  CheckAllBackwardAttrs(int_attr, float_vec_attr, str_vec_attr);

  return {grad_x};
}

// Forward kernel of the `const_attr_test` op. It behaves like AttrTestForward,
// but every attribute, including the scalar ones, is taken by const
// reference.
//
// Copies `x` element-wise into a newly created CPU tensor of the same shape.
// It then validates every attribute with CheckAllForwardAttrs, which throws
// std::runtime_error on the first mismatch.
//
// Returns a single-element vector holding the copy of `x`.
std::vector<paddle::Tensor> ConstAttrTestForward(
    const paddle::Tensor& x,
    const bool& bool_attr,
    const int& int_attr,
    const float& float_attr,
    const int64_t& int64_attr,
    const std::string& str_attr,
    const std::vector<int>& int_vec_attr,
    const std::vector<float>& float_vec_attr,
    const std::vector<int64_t>& int64_vec_attr,
    const std::vector<std::string>& str_vec_attr) {
  auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());

  // PD_DISPATCH_FLOATING_TYPES selects the element type from x.type() and
  // makes it available to the lambda as `data_t`.
  PD_DISPATCH_FLOATING_TYPES(
      x.type(), "assign_cpu_kernel", ([&] {
        assign_cpu_kernel<data_t>(
            x.data<data_t>(), out.mutable_data<data_t>(), x.size());
      }));

  // Check attrs value
  CheckAllForwardAttrs(bool_attr,
                       int_attr,
                       float_attr,
                       int64_attr,
                       str_attr,
                       int_vec_attr,
                       float_vec_attr,
                       int64_vec_attr,
                       str_vec_attr);

  return {out};
}

// Infer-shape function of the `const_attr_test` op. The single output "Out"
// has exactly the shape of input "X". The attribute parameters, taken by
// const reference, are accepted but unused.
//
// NOTE(review): like AttrTestInferShape, this signature has no int64_vec_attr
// parameter even though the op declares that attribute. Presumably the
// framework does not forward std::vector<int64_t> attributes to infer-shape
// functions. Confirm against the Paddle version in use before adding it.
std::vector<std::vector<int64_t>> ConstAttrTestInferShape(
    const std::vector<int64_t>& x_shape,
    const bool& bool_attr,
    const int& int_attr,
    const float& float_attr,
    const int64_t& int64_attr,
    const std::string& str_attr,
    const std::vector<int>& int_vec_attr,
    const std::vector<float>& float_vec_attr,
    const std::vector<std::string>& str_vec_attr) {
  return {x_shape};
}

236 237 238 239 240 241
// The attrs of backward op must be the subset of attrs of forward op
std::vector<paddle::Tensor> ConstAttrTestBackward(
    const paddle::Tensor& grad_out,
    const int& int_attr,
    const std::vector<float>& float_vec_attr,
    const std::vector<std::string>& str_vec_attr) {
242
  auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, grad_out.shape());
243 244 245 246 247 248 249 250 251

  PD_DISPATCH_FLOATING_TYPES(grad_out.type(), "assign_cpu_kernel", ([&] {
                               assign_cpu_kernel<data_t>(
                                   grad_out.data<data_t>(),
                                   grad_x.mutable_data<data_t>(),
                                   grad_out.size());
                             }));

  CheckAllBackwardAttrs(int_attr, float_vec_attr, str_vec_attr);
252 253 254 255

  return {grad_x};
}

256
PD_BUILD_OP(attr_test)
257 258 259 260 261 262 263 264 265 266 267
    .Inputs({"X"})
    .Outputs({"Out"})
    .Attrs({"bool_attr: bool",
            "int_attr: int",
            "float_attr: float",
            "int64_attr: int64_t",
            "str_attr: std::string",
            "int_vec_attr: std::vector<int>",
            "float_vec_attr: std::vector<float>",
            "int64_vec_attr: std::vector<int64_t>",
            "str_vec_attr: std::vector<std::string>"})
268 269
    .SetKernelFn(PD_KERNEL(AttrTestForward))
    .SetInferShapeFn(PD_INFER_SHAPE(AttrTestInferShape));
270 271

// Registers the gradient op of `attr_test`. It consumes the gradient of "Out"
// and produces the gradient of "X". It declares only the three attributes
// that AttrTestBackward consumes, which are a subset of the forward op's
// attributes.
PD_BUILD_GRAD_OP(attr_test)
    .Inputs({paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .Attrs({"int_attr: int",
            "float_vec_attr: std::vector<float>",
            "str_vec_attr: std::vector<std::string>"})
    .SetKernelFn(PD_KERNEL(AttrTestBackward));

// Registers the custom op `const_attr_test`. It has the same inputs, outputs
// and attribute declarations as `attr_test`. Its kernel and infer-shape
// functions take their attributes by const reference: forward kernel
// ConstAttrTestForward, infer-shape function ConstAttrTestInferShape.
PD_BUILD_OP(const_attr_test)
    .Inputs({"X"})
    .Outputs({"Out"})
    .Attrs({"bool_attr: bool",
            "int_attr: int",
            "float_attr: float",
            "int64_attr: int64_t",
            "str_attr: std::string",
            "int_vec_attr: std::vector<int>",
            "float_vec_attr: std::vector<float>",
            "int64_vec_attr: std::vector<int64_t>",
            "str_vec_attr: std::vector<std::string>"})
    .SetKernelFn(PD_KERNEL(ConstAttrTestForward))
    .SetInferShapeFn(PD_INFER_SHAPE(ConstAttrTestInferShape));

// Registers the gradient op of `const_attr_test`. It consumes the gradient of
// "Out" and produces the gradient of "X". It declares only the three
// attributes that ConstAttrTestBackward consumes, which are a subset of the
// forward op's attributes.
PD_BUILD_GRAD_OP(const_attr_test)
    .Inputs({paddle::Grad("Out")})
    .Outputs({paddle::Grad("X")})
    .Attrs({"int_attr: int",
            "float_vec_attr: std::vector<float>",
            "str_vec_attr: std::vector<std::string>"})
    .SetKernelFn(PD_KERNEL(ConstAttrTestBackward));