Unverified commit b74e00e1, authored by H HongyuJia, committed by GitHub

[CustomOP Optional] CustomOP supports optional Tensor (#51923)

* [CustomOP Optional] CustomOP supports optional Tensor

* fix test_custom_concat, add pytest to CMakeLists
Parent af2fa429
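At a glance, the change lets a custom operator declare an input that callers may omit. A condensed sketch of the user-facing pieces (the full version is the new custom_optional.cc test file further down this page):

#include "paddle/extension.h"

// The kernel receives an omitted input as paddle::none.
std::vector<paddle::Tensor> AddForward(
    const paddle::Tensor& x,
    const paddle::optional<paddle::Tensor>& y);

PD_BUILD_OP(custom_add)
    .Inputs({"X", paddle::Optional("Y")})  // "Y" becomes "Y@OPTIONAL"
    .Outputs({"Out"})
    .SetKernelFn(PD_KERNEL(AddForward));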
......@@ -191,6 +191,8 @@ RunCustomOpNode::operator()(paddle::small_vector<std::vector<paddle::Tensor>,
}
for (auto it : fwd_ins) {
// NOTE(HongyuJia): the returned tensor may be an undefined tensor when the
// input is optional<Tensor>
VLOG(7) << "Insert fwd_ins to grad_inputs: " << it.first;
tmp_ins[it.first] = RunCustomOpNode::Recover(&(it.second));
}
......
......@@ -76,6 +76,11 @@ inline static bool IsDuplicableVar(const std::string& var_name) {
return var_name.rfind(suffix) != std::string::npos;
}
inline static bool IsOptionalVar(const std::string& var_name) {
std::string suffix = kOptionalSuffix;
return var_name.rfind(suffix) != std::string::npos;
}
inline static std::string NoGrad(const std::string& var_name,
bool is_double_grad = false) {
std::string suffix = kGradVarSuffix;
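The optional marker is a plain name suffix, resolved the same way as the existing @GRAD and @VECTOR markers. A minimal, self-contained sketch of the convention (the real helpers are file-local to custom_operator.cc; paddle::Optional is added to op_meta_info.h later in this commit):

#include <cassert>
#include <string>

constexpr char kOptionalSuffix[] = "@OPTIONAL";

bool IsOptionalVar(const std::string& var_name) {
  return var_name.rfind(kOptionalSuffix) != std::string::npos;
}

int main() {
  // paddle::Optional("Y") produces the decorated name "Y@OPTIONAL".
  std::string decorated = std::string("Y") + kOptionalSuffix;
  assert(IsOptionalVar(decorated));
  assert(!IsOptionalVar("Y"));
  return 0;
}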
......@@ -141,57 +146,79 @@ static void RunKernelFunc(
paddle::CustomOpKernelContext kernel_ctx;
for (auto& in_name : inputs) {
VLOG(3) << "Custom Operator: input name - " << in_name;
if (detail::IsDuplicableVar(in_name)) {  // inputs vector<Tensor>
  std::vector<paddle::Tensor> custom_vec_in;
  if (ctx.HasInputs(in_name)) {  // general inputs
    // return const std::vector<const phi::DenseTensor*>
    auto vec_x = ctx.MultiInput<phi::DenseTensor>(in_name);
    PADDLE_ENFORCE_NE(vec_x.empty(),
                      true,
                      platform::errors::NotFound(
                          "Input vector<tensor> (%s) is empty.", in_name));
    for (size_t i = 0; i < vec_x.size(); ++i) {
      auto* x = vec_x[i];
      PADDLE_ENFORCE_NOT_NULL(
          x,
          platform::errors::NotFound(
              "The %d-th tensor in input vector<tensor> (%s) is nullptr.",
              i,
              in_name));
      PADDLE_ENFORCE_EQ(x->IsInitialized(),
                        true,
                        platform::errors::InvalidArgument(
                            "The %d-th tensor in input vector<tensor> (%s) "
                            "is not initialized.",
                            i,
                            in_name));
      paddle::Tensor custom_t;
      custom_t.set_impl(std::make_shared<phi::DenseTensor>(*x));
      custom_vec_in.emplace_back(custom_t);
    }
  } else {  // optional inputs, `custom_vec_in` is empty
    PADDLE_ENFORCE(
        detail::IsOptionalVar(in_name),
        phi::errors::NotFound("Your custom operator's KernelFunc cannot "
                              "find input parameter `%s`",
                              in_name));
    VLOG(3) << "Custom Operator: KernelFunc's vector input " << in_name
            << " is optional dtype with None input";
  }
  kernel_ctx.EmplaceBackInputs(std::move(custom_vec_in));
} else {  // inputs Tensor
  if (ctx.HasInput(in_name)) {  // general inputs
    auto* x = ctx.Input<phi::DenseTensor>(in_name);
    PADDLE_ENFORCE_NOT_NULL(x,
                            platform::errors::NotFound(
                                "Input tensor (%s) is nullptr.", in_name));
    PADDLE_ENFORCE_EQ(
        x->IsInitialized(),
        true,
        platform::errors::InvalidArgument(
            "Input tensor (%s) is not initialized.", in_name));
    paddle::Tensor custom_in;
    custom_in.set_impl(std::make_shared<phi::DenseTensor>(*x));
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
    if (custom_in.is_gpu_pinned()) {
      VLOG(3) << "Custom Operator: custom input is gpu pinned tensor";
      auto gpu_place = phi::GPUPlace(platform::GetCurrentDeviceId());
      auto custom_gpu_in = custom_in.copy_to(gpu_place, true);
      kernel_ctx.EmplaceBackInput(std::move(custom_gpu_in));
    } else {
      kernel_ctx.EmplaceBackInput(std::move(custom_in));
    }
#else
    kernel_ctx.EmplaceBackInput(std::move(custom_in));
#endif
  } else {  // optional inputs
    PADDLE_ENFORCE(
        detail::IsOptionalVar(in_name),
        phi::errors::NotFound("Your custom operator's KernelFunc cannot "
                              "find input parameter `%s`",
                              in_name));
    VLOG(3) << "Custom Operator: KernelFunc's input " << in_name
            << " is optional dtype with None input";
    kernel_ctx.EmplaceBackInput(std::move(paddle::Tensor()));
  }
}
}
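Note the convention at the end of the Tensor branch above: an absent optional input is not skipped but pushed as a default-constructed paddle::Tensor, so positional indexing into the kernel context stays aligned. A small sketch of the sentinel (assuming the paddle extension headers):

#include "paddle/extension.h"

void SentinelDemo() {
  // What EmplaceBackInput receives for a missing optional input:
  paddle::Tensor none_like;
  // The dispatch layer (ComputeCallHelper, below) keys off this check to
  // substitute paddle::none for the kernel argument.
  PD_CHECK(!none_like.is_initialized(), "sentinel must be uninitialized");
}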
......@@ -337,21 +364,41 @@ static void RunInferShapeFunc(framework::InferShapeContext* ctx,
VLOG(3) << "Custom Operator: InferShape - get input ddim.";
for (auto& in_name : inputs) {
if (detail::IsDuplicableVar(in_name)) {
  std::vector<std::vector<int64_t>> vec_shape;
  if (ctx->HasInputs(in_name)) {  // general inputs
    auto vec_ddim = ctx->GetInputsDim(in_name);
    vec_shape.reserve(vec_ddim.size());
    std::transform(vec_ddim.begin(),
                   vec_ddim.end(),
                   std::back_inserter(vec_shape),
                   [&](const DDim& ddim) -> std::vector<int64_t> {
                     return phi::vectorize(ddim);
                   });
  } else {  // optional inputs, `vec_shape` is empty
    PADDLE_ENFORCE(
        detail::IsOptionalVar(in_name),
        phi::errors::NotFound("Your custom operator's InferShapeFunc "
                              "cannot find input parameter `%s`",
                              in_name));
    VLOG(3) << "Custom Operator: InferShapeFunc's vector input " << in_name
            << " is optional dtype with None input";
  }
  vec_input_shapes.emplace_back(vec_shape);
} else {
  if (ctx->HasInput(in_name)) {  // general inputs
    auto ddim = ctx->GetInputDim(in_name);
    input_shapes.emplace_back(phi::vectorize(ddim));
  } else {  // optional inputs
    PADDLE_ENFORCE(
        detail::IsOptionalVar(in_name),
        phi::errors::NotFound("Your custom operator's InferShapeFunc "
                              "cannot find input parameter `%s`",
                              in_name));
    input_shapes.emplace_back(std::vector<int64_t>());
    VLOG(3) << "Custom Operator: InferShapeFunc's input " << in_name
            << " is optional dtype with None input";
  }
}
}
......@@ -468,11 +515,13 @@ class CustomOpMaker : public OpProtoAndCheckerMaker {
void Make() override {
for (auto& in_name : inputs_) {
auto input_var_builder =
    AddInput(in_name, "The input " + in_name + "of Custom operator.");
if (detail::IsDuplicableVar(in_name)) {
  input_var_builder.AsDuplicable();
}
if (detail::IsOptionalVar(in_name)) {
input_var_builder.AsDispensable();
}
}
for (auto& out_name : outputs_) {
......@@ -893,16 +942,37 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
for (auto& in_name : op_inputs) {
if (detail::IsDuplicableVar(in_name)) {
std::vector<DataType> vec_custom_dtype;
if (ctx->HasInput(in_name)) { // general inputs
for (size_t i = 0; i < ctx->InputSize(in_name); ++i) {
auto dtype = ctx->GetInputDataType(in_name, i);
vec_custom_dtype.emplace_back(
paddle::framework::TransToPhiDataType(dtype));
}
} else { // optional inputs, `vec_custom_dtype` is empty
PADDLE_ENFORCE(
detail::IsOptionalVar(in_name),
phi::errors::NotFound("Your custom operator's InferDtypeFn "
"cannot find input parameter `%s`",
in_name));
VLOG(3) << "Custom Operator: InferDtypeFn's vector input "
<< in_name << " is optional dtype with None input";
}
vec_input_dtypes.emplace_back(vec_custom_dtype);
} else {
if (ctx->HasInput(in_name)) { // general inputs
auto dtype = ctx->GetInputDataType(in_name);
input_dtypes.emplace_back(
paddle::framework::TransToPhiDataType(dtype));
} else { // optional inputs
PADDLE_ENFORCE(
detail::IsOptionalVar(in_name),
phi::errors::NotFound("Your custom operator's InferDtypeFn "
"cannot find input parameter `%s`",
in_name));
input_dtypes.emplace_back(DataType::UNDEFINED);
VLOG(3) << "Custom Operator: InferDtypeFn's input " << in_name
<< " is optional dtype with None input";
}
}
}
......@@ -1047,7 +1117,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
"If a custom grad operator contains only one input and "
"only one output, the input shape will be directly set "
"to the output shape. Otherwise, Please set the forward "
"input as the grad operator's input or set the "
"input as the grad operator's input or set the "
"InferShapeFn of custom grad operator by "
".SetInferShapeFn(PD_INFER_SHAPE(...))"));
ctx->ShareDim(grad_op_inputs[0], out_name);
......
......@@ -1060,6 +1060,8 @@ PYBIND11_MODULE(libpaddle, m) {
if (PyList_Check(obj) || PyTuple_Check(obj)) {
self.EmplaceBackInputs(
std::move(CastPyArg2VectorOfTensor(obj, 1)));
} else if (obj == Py_None) { // check optional Tensor
self.EmplaceBackInput(std::move(paddle::Tensor()));
} else {
self.EmplaceBackInput(std::move(CastPyArg2Tensor(obj, 1)));
}
......
......@@ -24,6 +24,8 @@ limitations under the License. */
#include "paddle/phi/api/include/dll_decl.h"
#include "paddle/phi/api/include/tensor.h"
#include "paddle/utils/any.h"
#include "paddle/utils/none.h"
#include "paddle/utils/optional.h"
/**
* Op Meta Info Related Define.
......@@ -57,6 +59,7 @@ using Tensor = paddle::Tensor;
constexpr char kGradTensorSuffix[] = "@GRAD";
constexpr char kTensorVectorSuffix[] = "@VECTOR";
constexpr char kDoubleGradNewOutSuffix[] = "@NEW";
constexpr char kOptionalSuffix[] = "@OPTIONAL";
// Used for Construct Grad Tensor name
inline std::string Grad(const std::string& t_name) {
......@@ -85,6 +88,15 @@ inline std::string New(const std::string& t_name) {
return result;
}
// Used for Construct paddle::optional name
inline std::string Optional(const std::string& t_name) {
std::string result;
result.reserve(t_name.size() + 9U);
result += t_name;
result += kOptionalSuffix;
return result;
}
PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst);
////////////////////// Kernel Context ////////////////////////
......@@ -197,6 +209,25 @@ struct KernelFuncImpl<Return (*)(Args...), impl_fn> {
}
};
template <typename... Tail>
struct ComputeCallHelper<const paddle::optional<paddle::Tensor>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
static void Compute(CustomOpKernelContext* ctx, PreviousArgs&... pargs) {
auto& range = ctx->InputRangeAt(in_idx);
auto& arg = ctx->InputAt(range.first);
if (!arg.is_initialized()) {
ComputeCallHelper<Tail...>::
template Compute<in_idx + 1, attr_idx, out_idx>(
ctx, pargs..., paddle::none);
} else {
ComputeCallHelper<
Tail...>::template Compute<in_idx + 1, attr_idx, out_idx>(ctx,
pargs...,
arg);
}
}
};
template <typename... Tail>
struct ComputeCallHelper<const std::vector<Tensor>&, Tail...> {
template <int in_idx, int attr_idx, int out_idx, typename... PreviousArgs>
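The specialization above follows the header's recursive call-helper pattern: each specialization consumes one declared kernel parameter, converts it (here, an uninitialized input becomes paddle::none), and recurses with the converted argument appended to the pack. Below is a standalone sketch of that pattern, with std::optional<int> standing in for paddle::optional<paddle::Tensor> and a negative slot as the "absent" sentinel; all names are illustrative, not Paddle code:

#include <iostream>
#include <optional>
#include <vector>

template <typename T>
struct TypeTag {};

template <typename F, F f>
struct KernelImpl;

template <typename Return, typename... Args, Return (*impl_fn)(Args...)>
struct KernelImpl<Return (*)(Args...), impl_fn> {
  // Walk the declared parameter types left to right; idx tracks the next
  // slot in the flat input list, prev... accumulates converted arguments.
  static Return Run(const std::vector<int>& inputs) {
    return Helper<Args..., TypeTag<int>>::template Call<0>(inputs);
  }

 private:
  template <typename... RemainingArgs>
  struct Helper;

  // Consume one optional parameter: a negative slot means "absent".
  template <typename... Tail>
  struct Helper<const std::optional<int>&, Tail...> {
    template <int idx, typename... Prev>
    static Return Call(const std::vector<int>& in, const Prev&... prev) {
      if (in[idx] < 0) {
        return Helper<Tail...>::template Call<idx + 1>(in, prev...,
                                                       std::optional<int>{});
      }
      return Helper<Tail...>::template Call<idx + 1>(
          in, prev..., std::optional<int>{in[idx]});
    }
  };

  // Consume one required parameter.
  template <typename... Tail>
  struct Helper<int, Tail...> {
    template <int idx, typename... Prev>
    static Return Call(const std::vector<int>& in, const Prev&... prev) {
      return Helper<Tail...>::template Call<idx + 1>(in, prev..., in[idx]);
    }
  };

  // All parameters consumed: invoke the kernel.
  template <typename T>
  struct Helper<TypeTag<T>> {
    template <int idx, typename... Prev>
    static Return Call(const std::vector<int>&, const Prev&... prev) {
      return impl_fn(prev...);
    }
  };
};

// y is optional: absent -> x + x, present -> x + y (mirrors custom_add below).
int Add(int x, const std::optional<int>& y) { return y ? x + *y : x + x; }

int main() {
  std::cout << KernelImpl<decltype(&Add), &Add>::Run({3, 4}) << "\n";   // 7
  std::cout << KernelImpl<decltype(&Add), &Add>::Run({3, -1}) << "\n";  // 6
}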
......@@ -430,6 +461,31 @@ struct InferShapeFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPES(
const std::vector<std::vector<int64_t>>&);
template <typename... Tail>
struct InferShapeCallHelper<const paddle::optional<std::vector<int64_t>>&,
Tail...> {
template <int in_idx,
int vec_in_idx,
int attr_idx,
typename... PreviousArgs>
static Return InferShape(
const std::vector<std::vector<int64_t>>& input_shapes,
const std::vector<std::vector<std::vector<int64_t>>>& vec_input_shapes,
const std::vector<paddle::any>& attrs,
const PreviousArgs&... pargs) {
const std::vector<int64_t>& arg = input_shapes[in_idx];
if (arg.empty()) {
return InferShapeCallHelper<Tail...>::
template InferShape<in_idx + 1, vec_in_idx, attr_idx>(
input_shapes, vec_input_shapes, attrs, pargs..., paddle::none);
} else {
return InferShapeCallHelper<Tail...>::
template InferShape<in_idx + 1, vec_in_idx, attr_idx>(
input_shapes, vec_input_shapes, attrs, pargs..., arg);
}
}
};
// NOTE(chenweihang): Used to be compatible with the 2.0.1 released
// interface, and will be deprecated in the future
PD_SPECIALIZE_InferShapeCallHelper_FOR_SHAPE(std::vector<int64_t>);
......@@ -536,6 +592,27 @@ struct InferDtypeFuncImpl<Return (*)(Args...), impl_fn> {
PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(const DataType&);
PD_SPECIALIZE_InferDtypeCallHelper_FOR_DTYPES(const std::vector<DataType>&);
template <typename... Tail>
struct InferDtypeCallHelper<const paddle::optional<paddle::DataType>&,
Tail...> {
template <int in_idx, int vec_in_idx, typename... PreviousArgs>
static Return InferDtype(
const std::vector<DataType>& input_dtypes,
const std::vector<std::vector<DataType>>& vec_input_dtypes,
const PreviousArgs&... pargs) {
const DataType& arg = input_dtypes[in_idx];
if (arg == DataType::UNDEFINED) {
return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1,
vec_in_idx>(
input_dtypes, vec_input_dtypes, pargs..., paddle::none);
} else {
return InferDtypeCallHelper<Tail...>::template InferDtype<in_idx + 1,
vec_in_idx>(
input_dtypes, vec_input_dtypes, pargs..., arg);
}
}
};
// NOTE(chenweihang): Used to be compatible with the 2.0.1 released
// interface, and will be deprecated in the future
PD_SPECIALIZE_InferDtypeCallHelper_TO_DTYPE(DataType);
......
......@@ -51,6 +51,7 @@ py_test(test_custom_linear SRCS test_custom_linear.py)
py_test(test_custom_simple_slice SRCS test_custom_simple_slice.py)
py_test(test_custom_tanh_double_grad SRCS test_custom_tanh_double_grad.py)
py_test(test_custom_inplace SRCS test_custom_inplace.py)
py_test(test_custom_optional SRCS test_custom_optional.py)
# other tests
py_test(test_sysconfig SRCS test_sysconfig.py)
......
// Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <iostream>
#include <vector>
#include "paddle/extension.h"
template <typename data_t>
void add_forward_kernel(const data_t* x_data,
const data_t* y_data,
data_t* out_data,
int64_t numel) {
for (int64_t i = 0; i < numel; ++i) {
out_data[i] = x_data[i] + y_data[i];
}
}
template <typename data_t>
void add_backward_kernel(data_t* x_grad_data,
const data_t* out_grad_data,
int64_t numel) {
for (int64_t i = 0; i < numel; ++i) {
x_grad_data[i] += out_grad_data[i];
}
}
/*
if (y) {
out = x + y;
} else {
out = x + x;
}
*/
std::vector<paddle::Tensor> AddForward(
const paddle::Tensor& x,
const paddle::optional<paddle::Tensor>& y) { // NOLINT
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor out = paddle::empty(x.shape(), x.dtype(), x.place());
PD_DISPATCH_FLOATING_TYPES(
x.type(), "AddForward", ([&] {
if (y) {
add_forward_kernel<data_t>(x.data<data_t>(),
y->data<data_t>(),
out.data<data_t>(),
x.size());
} else {
add_forward_kernel<data_t>(
x.data<data_t>(), x.data<data_t>(), out.data<data_t>(), x.size());
}
}));
return {out};
}
std::vector<paddle::DataType> AddInferDtype(
const paddle::DataType& x_dtype,
const paddle::optional<paddle::DataType>& y_dtype) {
if (y_dtype) {
std::cout << "DEBUG AddInferDtype" << *y_dtype << std::endl;
return {*y_dtype};
}
return {x_dtype};
}
std::vector<std::vector<int64_t>> AddInferShape(
const std::vector<int64_t>& x_shape,
const paddle::optional<std::vector<int64_t>>& y_shape) {
if (y_shape) {
return {*y_shape};
}
return {x_shape};
}
/*
if (y) {
x_grad = out_grad;
} else {
x_grad = out_grad + out_grad;
}
*/
std::vector<paddle::Tensor> AddBackward(
const paddle::Tensor& x,
const paddle::optional<paddle::Tensor>& y,
const paddle::Tensor& out_grad) { // NOLINT
PD_CHECK(x.place() == paddle::PlaceType::kCPU, "x must be a CPU Tensor.");
paddle::Tensor x_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
paddle::Tensor y_grad = paddle::zeros(x.shape(), x.dtype(), x.place());
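// NOTE: `y_grad` below is computed only to exercise the optional branch;
// the grad op registered at the end of this file outputs just
// paddle::Grad("X"), so y_grad is never returned.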
PD_DISPATCH_FLOATING_TYPES(
out_grad.type(), "AddBackward", ([&] {
add_backward_kernel<data_t>(
x_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
if (y) {
add_backward_kernel<data_t>(
y_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
} else {
add_backward_kernel<data_t>(
x_grad.data<data_t>(), out_grad.data<data_t>(), out_grad.size());
}
}));
return {x_grad};
}
PD_BUILD_OP(custom_add)
.Inputs({"X", paddle::Optional("Y")})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(AddForward))
.SetInferShapeFn(PD_INFER_SHAPE(AddInferShape))
.SetInferDtypeFn(PD_INFER_DTYPE(AddInferDtype));
PD_BUILD_GRAD_OP(custom_add)
.Inputs({"X", paddle::Optional("Y"), paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(AddBackward));
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
from utils import extra_cc_args, extra_nvcc_args, paddle_includes
import paddle
import paddle.static as static
from paddle.utils.cpp_extension import get_build_directory, load
from paddle.utils.cpp_extension.extension_utils import run_cmd
# Because Windows doesn't use docker, the shared lib already exists in the
# cache dir; it will not be recompiled unless the shared lib is removed.
file = '{}\\custom_optional\\custom_optional.pyd'.format(get_build_directory())
if os.name == 'nt' and os.path.isfile(file):
cmd = 'del {}'.format(file)
run_cmd(cmd, True)
# Compile and load custom op Just-In-Time.
custom_optional = load(
name='custom_optional',
sources=['custom_optional.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cflags
extra_cuda_cflags=extra_nvcc_args, # test for cflags
verbose=True,
)
def optional_dynamic_add(phi_func, device, dtype, np_x, np_y):
paddle.set_device(device)
x = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
if np_y is not None:
y = paddle.to_tensor(np_y, dtype=dtype, stop_gradient=False)
else:
y = x
if phi_func:
out = custom_optional.custom_add(x, y if np_y is not None else None)
else:
out = paddle.add(x, y)
out.backward()
return x.numpy(), out.numpy(), x.grad.numpy()
def optional_static_add(phi_func, device, dtype, np_x, np_y):
paddle.enable_static()
paddle.set_device(device)
with static.scope_guard(static.Scope()):
with static.program_guard(static.Program()):
x = static.data(name="x", shape=[None, np_x.shape[1]], dtype=dtype)
x.stop_gradient = False
if np_y is not None:
y = static.data(
name="y", shape=[None, np_x.shape[1]], dtype=dtype
)
y.stop_gradient = False
feed_dict = {
"x": np_x.astype(dtype),
"y": np_y.astype(dtype),
}
else:
y = x
feed_dict = {
"x": np_x.astype(dtype),
}
if phi_func:
out = custom_optional.custom_add(
x, y if np_y is not None else None
)
else:
out = paddle.add(x, y)
mean_out = paddle.mean(out)
static.append_backward(mean_out)
exe = static.Executor()
exe.run(static.default_startup_program())
x_v, out_v, x_grad_v = exe.run(
static.default_main_program(),
feed=feed_dict,
fetch_list=[
x.name,
out.name,
x.name + "@GRAD",
],
)
paddle.disable_static()
return x_v, out_v, x_grad_v
class TestCustomOptionalJit(unittest.TestCase):
def setUp(self):
self.dtypes = ['float32', 'float64']
self.devices = ['cpu']
self.np_x = np.random.random((3, 2)).astype("float32")
self.np_y = np.random.random((3, 2)).astype("float32")
def check_output(self, out, pd_out, name):
np.testing.assert_array_equal(
out,
pd_out,
err_msg='custom op {}: {},\n paddle api {}: {}'.format(
name, out, name, pd_out
),
)
def check_output_allclose(self, out, pd_out, name):
np.testing.assert_allclose(
out,
pd_out,
rtol=5e-5,
atol=1e-2,
err_msg='custom op {}: {},\n paddle api {}: {}'.format(
name, out, name, pd_out
),
)
def test_static_add(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_x, pd_out, pd_x_grad,) = optional_static_add(
False,
device,
dtype,
self.np_x,
self.np_y,
)
(phi_x, phi_out, phi_x_grad,) = optional_static_add(
True,
device,
dtype,
self.np_x,
self.np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
def test_dynamic_add(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_x, pd_out, pd_x_grad,) = optional_dynamic_add(
False,
device,
dtype,
self.np_x,
self.np_y,
)
(phi_x, phi_out, phi_x_grad,) = optional_dynamic_add(
True,
device,
dtype,
self.np_x,
self.np_y,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
def test_optional_static_add(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_x, pd_out, pd_x_grad,) = optional_static_add(
False,
device,
dtype,
self.np_x,
None,
)
(phi_x, phi_out, phi_x_grad,) = optional_static_add(
True,
device,
dtype,
self.np_x,
None,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
def test_optional_dynamic_add(self):
for device in self.devices:
for dtype in self.dtypes:
(pd_x, pd_out, pd_x_grad,) = optional_dynamic_add(
False,
device,
dtype,
self.np_x,
None,
)
(phi_x, phi_out, phi_x_grad,) = optional_dynamic_add(
True,
device,
dtype,
self.np_x,
None,
)
self.check_output(phi_x, pd_x, "x")
self.check_output(phi_out, pd_out, "out")
self.check_output(phi_x_grad, pd_x_grad, "x_grad")
if __name__ == "__main__":
unittest.main()
......@@ -1057,7 +1057,6 @@ def _custom_api_content(op_name):
def {op_name}({inputs}):
# prepare inputs and outputs
attrs = {attrs}
outs = {{}}
out_names = {out_names}
......@@ -1075,6 +1074,11 @@ def _custom_api_content(op_name):
ctx.add_outputs(outs[out_name])
core.eager._run_custom_op(ctx, "{op_name}", True)
else:
ins = {{}}
for key, value in dict({ins}).items():
# handle optional inputs
if value is not None:
ins[key] = value
helper = LayerHelper("{op_name}", **locals())
for out_name in out_names:
outs[out_name] = helper.create_variable(dtype='float32')
......