未验证 提交 c472d105 编写于 作者: H HongyuJia 提交者: GitHub

[CustomOP] InferDtypeFn supports attrs (#56141)

* [CustomOP] InferDtypeFn supports attrs

* Update paddle/fluid/framework/custom_operator.cc

* update cmake list

* fix cpu device

* change unittest time
上级 a1e2c63c
...@@ -669,6 +669,7 @@ static void RunInferDtypeFunc( ...@@ -669,6 +669,7 @@ static void RunInferDtypeFunc(
const paddle::InferDtypeFunc& func, const paddle::InferDtypeFunc& func,
const std::vector<std::string>& inputs, const std::vector<std::string>& inputs,
const std::vector<std::string>& outputs, const std::vector<std::string>& outputs,
const std::vector<std::string>& attrs,
const std::unordered_map<std::string, std::string>& inplace_map, const std::unordered_map<std::string, std::string>& inplace_map,
const std::unordered_map<std::string, std::string>& inplace_reverse_map) { const std::unordered_map<std::string, std::string>& inplace_reverse_map) {
std::vector<DataType> input_dtypes; std::vector<DataType> input_dtypes;
...@@ -711,8 +712,51 @@ static void RunInferDtypeFunc( ...@@ -711,8 +712,51 @@ static void RunInferDtypeFunc(
} }
} }
std::vector<paddle::any> custom_attrs;
for (auto& attr_str : attrs) {
auto attr_name_and_type = paddle::ParseAttrStr(attr_str);
auto attr_name = attr_name_and_type[0];
auto attr_type_str = attr_name_and_type[1];
if (attr_type_str == "bool") {
custom_attrs.emplace_back(
PADDLE_GET_CONST(bool, ctx->GetAttr(attr_name)));
} else if (attr_type_str == "int") {
custom_attrs.emplace_back(PADDLE_GET_CONST(int, ctx->GetAttr(attr_name)));
} else if (attr_type_str == "float") {
custom_attrs.emplace_back(
PADDLE_GET_CONST(float, ctx->GetAttr(attr_name)));
} else if (attr_type_str == "int64_t") {
custom_attrs.emplace_back(
PADDLE_GET_CONST(int64_t, ctx->GetAttr(attr_name)));
} else if (attr_type_str == "std::string") {
custom_attrs.emplace_back(
PADDLE_GET_CONST(std::string, ctx->GetAttr(attr_name)));
} else if (attr_type_str == "std::vector<int>") {
custom_attrs.emplace_back(
PADDLE_GET_CONST(std::vector<int>, ctx->GetAttr(attr_name)));
} else if (attr_type_str == "std::vector<float>") {
custom_attrs.emplace_back(
PADDLE_GET_CONST(std::vector<float>, ctx->GetAttr(attr_name)));
} else if (attr_type_str == "std::vector<int64_t>") {
custom_attrs.emplace_back(
PADDLE_GET_CONST(std::vector<int64_t>, ctx->GetAttr(attr_name)));
} else if (attr_type_str == "std::vector<std::string>") {
custom_attrs.emplace_back(
PADDLE_GET_CONST(std::vector<std::string>, ctx->GetAttr(attr_name)));
} else {
PADDLE_THROW(phi::errors::Unimplemented(
"Unsupported `%s` type value as custom attribute now. "
"Supported data types include `bool`, `int`, `float`, "
"`int64_t`, `std::string`, `std::vector<int>`, "
"`std::vector<float>`, `std::vector<int64_t>`, "
"`std::vector<std::string>`, Please check whether the attribute data "
"type and data type string are matched.",
attr_type_str));
}
}
VLOG(3) << "Custom Operator: InferDtype - infer output dtype."; VLOG(3) << "Custom Operator: InferDtype - infer output dtype.";
auto output_dtypes = func(input_dtypes, vec_input_dtypes); auto output_dtypes = func(input_dtypes, vec_input_dtypes, custom_attrs);
if (inplace_map.empty()) { if (inplace_map.empty()) {
PADDLE_ENFORCE_EQ(outputs.size(), PADDLE_ENFORCE_EQ(outputs.size(),
output_dtypes.size(), output_dtypes.size(),
...@@ -1016,6 +1060,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -1016,6 +1060,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
} else { } else {
info.infer_var_type_ = [op_inputs, info.infer_var_type_ = [op_inputs,
op_outputs, op_outputs,
op_attrs,
op_inplace_map, op_inplace_map,
op_inplace_reverse_map, op_inplace_reverse_map,
infer_dtype_func](InferVarTypeContext* ctx) { infer_dtype_func](InferVarTypeContext* ctx) {
...@@ -1023,6 +1068,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -1023,6 +1068,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
infer_dtype_func, infer_dtype_func,
op_inputs, op_inputs,
op_outputs, op_outputs,
op_attrs,
op_inplace_map, op_inplace_map,
op_inplace_reverse_map); op_inplace_reverse_map);
}; };
...@@ -1051,6 +1097,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -1051,6 +1097,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
OpMetaInfoHelper::GetInplaceReverseMap(cur_grad_op); OpMetaInfoHelper::GetInplaceReverseMap(cur_grad_op);
auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op); auto& grad_kernel_fn = OpMetaInfoHelper::GetKernelFn(cur_grad_op);
auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op); auto& grad_infer_shape_fn = OpMetaInfoHelper::GetInferShapeFn(cur_grad_op);
auto& grad_infer_dtype_fn = OpMetaInfoHelper::GetInferDtypeFn(cur_grad_op);
VLOG(3) << "Custom Operator: backward, op name: " << grad_op_name; VLOG(3) << "Custom Operator: backward, op name: " << grad_op_name;
VLOG(3) << "Custom Operator: backward, op inputs: " VLOG(3) << "Custom Operator: backward, op inputs: "
...@@ -1182,6 +1229,25 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos, ...@@ -1182,6 +1229,25 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
}; };
} }
// Grad InferDtype
if (grad_infer_dtype_fn != nullptr) {
grad_info.infer_var_type_ =
[grad_op_inputs,
grad_op_outputs,
grad_op_attrs,
grad_op_inplace_map,
grad_op_inplace_reverse_map,
grad_infer_dtype_fn](InferVarTypeContext* ctx) {
RunInferDtypeFunc(ctx,
grad_infer_dtype_fn,
grad_op_inputs,
grad_op_outputs,
grad_op_attrs,
grad_op_inplace_map,
grad_op_inplace_reverse_map);
};
}
// Kernel func // Kernel func
RegisterOperatorKernel(grad_op_name, RegisterOperatorKernel(grad_op_name,
grad_kernel_fn, grad_kernel_fn,
......
...@@ -643,35 +643,71 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -643,35 +643,71 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
// Record Op Infer dtype core function // Record Op Infer dtype core function
using InferDtypeFunc = std::vector<DataType> (*)( using InferDtypeFunc = std::vector<DataType> (*)(
const std::vector<DataType>& input_dtypes, const std::vector<DataType>& input_dtypes,
const std::vector<std::vector<DataType>>& vec_input_dtypes); const std::vector<std::vector<DataType>>& vec_input_dtypes,
const std::vector<paddle::any>& attrs);
#define PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(input_type) \ #define PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPE(input_type) \
template <typename... Tail> \ template <typename... Tail> \
struct InferDtypeCallHelper<input_type, Tail...> { \ struct InferDtypeCallHelper<input_type, Tail...> { \
template <int in_idx, int vec_in_idx, typename... PreviousArgs> \ template <int in_idx, \
int vec_in_idx, \
int attr_idx, \
typename... PreviousArgs> \
static Return InferDtype( \ static Return InferDtype( \
const std::vector<DataType>& input_dtypes, \ const std::vector<DataType>& input_dtypes, \
const std::vector<std::vector<DataType>>& vec_input_dtypes, \ const std::vector<std::vector<DataType>>& vec_input_dtypes, \
const std::vector<paddle::any>& attrs, \
const PreviousArgs&... pargs) { \ const PreviousArgs&... pargs) { \
input_type arg = input_dtypes[in_idx]; \ input_type arg = input_dtypes[in_idx]; \
return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1, \ return InferDtypeCallHelper<Tail...>:: \
vec_in_idx>( \ template InferDtype<in_idx + 1, vec_in_idx, attr_idx>( \
input_dtypes, vec_input_dtypes, pargs..., arg); \ input_dtypes, vec_input_dtypes, attrs, pargs..., arg); \
} \ } \
} }
#define PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(input_type) \ #define PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(input_type) \
template <typename... Tail> \ template <typename... Tail> \
struct InferDtypeCallHelper<input_type, Tail...> { \ struct InferDtypeCallHelper<input_type, Tail...> { \
template <int in_idx, int vec_in_idx, typename... PreviousArgs> \ template <int in_idx, \
int vec_in_idx, \
int attr_idx, \
typename... PreviousArgs> \
static Return InferDtype( \ static Return InferDtype( \
const std::vector<DataType>& input_dtypes, \ const std::vector<DataType>& input_dtypes, \
const std::vector<std::vector<DataType>>& vec_input_dtypes, \ const std::vector<std::vector<DataType>>& vec_input_dtypes, \
const std::vector<paddle::any>& attrs, \
const PreviousArgs&... pargs) { \ const PreviousArgs&... pargs) { \
input_type arg = vec_input_dtypes[vec_in_idx]; \ input_type arg = vec_input_dtypes[vec_in_idx]; \
return InferDtypeCallHelper<Tail...>:: \ return InferDtypeCallHelper<Tail...>:: \
template InferDtype<in_idx, vec_in_idx + 1>( \ template InferDtype<in_idx, vec_in_idx + 1, attr_idx>( \
input_dtypes, vec_input_dtypes, pargs..., arg); \ input_dtypes, vec_input_dtypes, attrs, pargs..., arg); \
} \
}
#define PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(attr_type) \
template <typename... Tail> \
struct InferDtypeCallHelper<attr_type, Tail...> { \
template <int in_idx, \
int vec_in_idx, \
int attr_idx, \
typename... PreviousArgs> \
static Return InferDtype( \
const std::vector<DataType>& input_dtypes, \
const std::vector<std::vector<DataType>>& vec_input_dtypes, \
const std::vector<paddle::any>& attrs, \
const PreviousArgs&... pargs) { \
try { \
attr_type arg = paddle::any_cast<attr_type>(attrs[attr_idx]); \
return InferDtypeCallHelper<Tail...>:: \
template InferDtype<in_idx, vec_in_idx, attr_idx + 1>( \
input_dtypes, vec_input_dtypes, attrs, pargs..., arg); \
} catch (paddle::bad_any_cast&) { \
PD_THROW( \
"Attribute cast error in custom operator InferDtypeFn. " \
"Expected " #attr_type \
" value. InferDtypeFn's attribute list must be exactly same as " \
"Forward KernelFn's attribute list"); \
} \
} \ } \
} }
...@@ -682,35 +718,39 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)> ...@@ -682,35 +718,39 @@ template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> { struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
static Return InferDtype( static Return InferDtype(
const std::vector<DataType>& input_dtypes, const std::vector<DataType>& input_dtypes,
const std::vector<std::vector<DataType>>& vec_input_dtypes) { const std::vector<std::vector<DataType>>& vec_input_dtypes,
return InferDtypeCallHelper<Args..., TypeTag<int>>::template InferDtype<0, const std::vector<paddle::any>& attrs) {
0>( return InferDtypeCallHelper<Args..., TypeTag<int>>::
input_dtypes, vec_input_dtypes); template InferDtype<0, 0, 0>(input_dtypes, vec_input_dtypes, attrs);
} }
private: private:
template <typename... RemainingArgs> template <typename... RemainingArgs>
struct InferDtypeCallHelper; struct InferDtypeCallHelper;
PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(const DataType&); PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPE(const DataType&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(const std::vector<DataType>&); PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(const std::vector<DataType>&);
template <typename... Tail> template <typename... Tail>
struct InferDtypeCallHelper<const paddle::optional<DataType>&, Tail...> { struct InferDtypeCallHelper<const paddle::optional<DataType>&, Tail...> {
template <int in_idx, int vec_in_idx, typename... PreviousArgs> template <int in_idx,
int vec_in_idx,
int attr_idx,
typename... PreviousArgs>
static Return InferDtype( static Return InferDtype(
const std::vector<DataType>& input_dtypes, const std::vector<DataType>& input_dtypes,
const std::vector<std::vector<DataType>>& vec_input_dtypes, const std::vector<std::vector<DataType>>& vec_input_dtypes,
const std::vector<paddle::any>& attrs,
const PreviousArgs&... pargs) { const PreviousArgs&... pargs) {
const DataType& arg = input_dtypes[in_idx]; const DataType& arg = input_dtypes[in_idx];
if (arg == DataType::UNDEFINED) { if (arg == DataType::UNDEFINED) {
return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1, return InferDtypeCallHelper<Tail...>::
vec_in_idx>( template InferDtype<in_idx + 1, vec_in_idx, attr_idx>(
input_dtypes, vec_input_dtypes, pargs..., paddle::none); input_dtypes, vec_input_dtypes, attrs, pargs..., paddle::none);
} else { } else {
return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1, return InferDtypeCallHelper<Tail...>::
vec_in_idx>( template InferDtype<in_idx + 1, vec_in_idx, attr_idx>(
input_dtypes, vec_input_dtypes, pargs..., arg); input_dtypes, vec_input_dtypes, attrs, pargs..., arg);
} }
} }
}; };
...@@ -718,36 +758,65 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> { ...@@ -718,36 +758,65 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
template <typename... Tail> template <typename... Tail>
struct InferDtypeCallHelper<const paddle::optional<std::vector<DataType>>&, struct InferDtypeCallHelper<const paddle::optional<std::vector<DataType>>&,
Tail...> { Tail...> {
template <int in_idx, int vec_in_idx, typename... PreviousArgs> template <int in_idx,
int vec_in_idx,
int attr_idx,
typename... PreviousArgs>
static Return InferDtype( static Return InferDtype(
const std::vector<DataType>& input_dtypes, const std::vector<DataType>& input_dtypes,
const std::vector<std::vector<DataType>>& vec_input_dtypes, const std::vector<std::vector<DataType>>& vec_input_dtypes,
const std::vector<paddle::any>& attrs,
const PreviousArgs&... pargs) { const PreviousArgs&... pargs) {
const std::vector<DataType>& arg = vec_input_dtypes[vec_in_idx]; const std::vector<DataType>& arg = vec_input_dtypes[vec_in_idx];
if (arg.empty()) { if (arg.empty()) {
return InferDtypeCallHelper<Tail...>:: return InferDtypeCallHelper<Tail...>::
template InferDtype<in_idx, vec_in_idx + 1>( template InferDtype<in_idx, vec_in_idx + 1, attr_idx>(
input_dtypes, vec_input_dtypes, pargs..., paddle::none); input_dtypes, vec_input_dtypes, attrs, pargs..., paddle::none);
} else { } else {
return InferDtypeCallHelper<Tail...>:: return InferDtypeCallHelper<Tail...>::
template InferDtype<in_idx, vec_in_idx + 1>( template InferDtype<in_idx, vec_in_idx + 1, attr_idx>(
input_dtypes, vec_input_dtypes, pargs..., arg); input_dtypes, vec_input_dtypes, attrs, pargs..., arg);
} }
} }
}; };
// NOTE(chenweihang): Used to be compatible with the 2.0.1 released // NOTE(HongyuJia): Used to be compatible with the 2.0.1 released
// interface, and will be deprecated in the future // interface, and will be deprecated in the future
PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(DataType); PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPE(DataType);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(std::vector<DataType>); PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(std::vector<DataType>);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(bool);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(int);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(float);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(int64_t);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::string&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::vector<int>&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::vector<float>&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::vector<int64_t>&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const std::vector<std::string>&);
// NOTE(HongyuJia): Used to be compatible with the 2.0.1 released
// interface, and will be deprecated in the future
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const bool&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const int&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const float&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(const int64_t&);
// NOTE(HongyuJia): Used to be compatible with the 2.1 released
// interface, but not recommended
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(std::string);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(std::vector<int>);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(std::vector<float>);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_ATTR(std::vector<std::string>);
// end: base template // end: base template
template <typename T> template <typename T>
struct InferDtypeCallHelper<TypeTag<T>> { struct InferDtypeCallHelper<TypeTag<T>> {
template <int in_idx, int vec_in_idx> template <int in_idx, int vec_in_idx, int attr_idx>
static Return InferDtype( static Return InferDtype(
const std::vector<DataType>& input_dtypes, const std::vector<DataType>& input_dtypes,
const std::vector<std::vector<DataType>>& vec_input_dtypes, const std::vector<std::vector<DataType>>& vec_input_dtypes,
const std::vector<paddle::any>& attrs,
const Args&... args) { const Args&... args) {
return impl_fn(args...); return impl_fn(args...);
} }
......
...@@ -506,13 +506,6 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) { ...@@ -506,13 +506,6 @@ OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferShapeFn(InferShapeFunc func) {
} }
OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) { OpMetaInfoBuilder& OpMetaInfoBuilder::SetInferDtypeFn(InferDtypeFunc func) {
PADDLE_ENFORCE_EQ(
index_,
0UL,
phi::errors::Unimplemented(
"Currently, the InferDtypeFn setting of Grad Op is not supported, "
"And backward Tensor `X@GRAD` will use the dtype of forward Tensor "
"`X` by default."));
info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func)); info_ptr_->SetInferDtypeFn(std::forward<InferDtypeFunc>(func));
return *this; return *this;
} }
......
...@@ -39,6 +39,8 @@ if(WITH_TESTING) ...@@ -39,6 +39,8 @@ if(WITH_TESTING)
py_test(test_dispatch_jit SRCS test_dispatch_jit.py) py_test(test_dispatch_jit SRCS test_dispatch_jit.py)
py_test(test_multi_out_jit SRCS test_multi_out_jit.py) py_test(test_multi_out_jit SRCS test_multi_out_jit.py)
py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py) py_test(test_custom_attrs_jit SRCS test_custom_attrs_jit.py)
py_test(test_custom_cast_op_jit SRCS test_custom_cast_op_jit.py)
set_tests_properties(test_custom_cast_op_jit PROPERTIES TIMEOUT 180)
py_test(test_custom_concat SRCS test_custom_concat.py) py_test(test_custom_concat SRCS test_custom_concat.py)
set_tests_properties( set_tests_properties(
test_custom_concat PROPERTIES ENVIRONMENT test_custom_concat PROPERTIES ENVIRONMENT
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
paddle::DataType ConvertDtype(const std::string& data_type) {
if (data_type == "float16") {
return paddle::DataType::FLOAT16;
} else if (data_type == "float32") {
return paddle::DataType::FLOAT32;
} else if (data_type == "float64") {
return paddle::DataType::FLOAT64;
} else {
PD_THROW("DataType Not Supported.");
}
}
std::vector<paddle::Tensor> CastForward(const paddle::Tensor& x,
const std::string& data_type) {
return {paddle::experimental::cast(x, ConvertDtype(data_type))};
}
std::vector<paddle::DataType> CastForwardInferDtype(
const paddle::DataType& input_dtype, const std::string& data_type) {
return {ConvertDtype(data_type)};
}
std::vector<paddle::Tensor> CastBackward(const paddle::Tensor& grad_out,
const std::string& data_type) {
return {paddle::experimental::cast(grad_out, ConvertDtype(data_type))};
}
std::vector<paddle::DataType> CastBackwardInferDtype(
const paddle::DataType& grad_out_dtype, const std::string& data_type) {
return {ConvertDtype(data_type)};
}
PD_BUILD_OP(custom_cast)
.Inputs({"X"})
.Attrs({"data_type: std::string"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(CastForward))
.SetInferDtypeFn(PD_INFER_DTYPE(CastForwardInferDtype));
PD_BUILD_GRAD_OP(custom_cast)
.Inputs({paddle::Grad("Out")})
.Attrs({"data_type: std::string"})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(CastBackward))
.SetInferDtypeFn(PD_INFER_DTYPE(CastBackwardInferDtype));
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from utils import (
extra_cc_args,
extra_nvcc_args,
paddle_includes,
paddle_libraries,
)
import paddle
from paddle import static
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd
# Because Windows don't use docker, the shared lib already exists in the
# cache dir, it will not be compiled again unless the shared lib is removed.
file = '{}\\custom_cast_module_jit\\custom_cast_module_jit.pyd'.format(
get_build_directory()
)
if os.name == 'nt' and os.path.isfile(file):
cmd = f'del {file}'
run_cmd(cmd, True)
custom_module = load(
name='custom_cast_module_jit',
sources=['custom_cast_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_library_paths=paddle_libraries,
extra_cxx_cflags=extra_cc_args, # test for cc flags
extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags
verbose=True,
)
def custom_cast_dynamic(device, dtype, np_x):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype="float32")
x.stop_gradient = False
out = custom_module.custom_cast(x, dtype)
out.stop_gradient = False
out.backward()
assert str(out.dtype).split(".")[-1] == dtype
assert str(x.grad.dtype).split(".")[-1] == dtype
def custom_cast_static(device, dtype, np_x):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name='X', shape=[None, 8], dtype="float32")
x.stop_gradient = False
out = custom_module.custom_cast(x, dtype)
static.append_backward(out)
exe = static.Executor()
exe.run(static.default_startup_program())
# in static graph mode, x data has been covered by out
out_v, x_grad_v = exe.run(
static.default_main_program(),
feed={'X': np_x},
fetch_list=[out.name, x.name + "@GRAD"],
)
assert x_grad_v[0].dtype == dtype
assert out_v[0].dtype == dtype
paddle.disable_static()
return out_v
class TestCustomCastOp(unittest.TestCase):
def setUp(self):
self.dtypes = ['float32', 'float64']
def test_static(self):
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype("float32")
custom_cast_static('cpu', dtype, x)
def test_dynamic(self):
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype("float32")
custom_cast_dynamic('cpu', dtype, x)
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册