未验证 提交 76d5483a 编写于 作者: C Chen Weihang 提交者: GitHub

[CustomOp]Add new method for custom double grad (#41538) (#41781)

* add new method for custom double grad

* add tanh double grad unittest

* change year

* revert tensor init method
上级 5450e42c
......@@ -67,9 +67,17 @@ inline static bool IsDuplicableVar(const std::string& var_name) {
return var_name.rfind(suffix) != std::string::npos;
}
inline static std::string NoGrad(const std::string& var_name) {
inline static std::string NoGrad(const std::string& var_name,
bool is_double_grad = false) {
std::string suffix = kGradVarSuffix;
return var_name.substr(0, var_name.size() - kGradVarSuffixSize);
std::string new_out_suffix = kDoubleGradNewOutSuffix;
std::string tmp_var_name(var_name);
if (is_double_grad &&
(tmp_var_name.rfind(new_out_suffix) != std::string::npos)) {
tmp_var_name = tmp_var_name.substr(
0, tmp_var_name.size() - /*kDoubleGradNewOutSuffix length*/ 4);
}
return tmp_var_name.substr(0, tmp_var_name.size() - kGradVarSuffixSize);
}
inline static bool IsGradVar(const std::string& var_name, bool is_double_grad) {
......@@ -533,11 +541,12 @@ class CustomGradOpMaker<OpDesc> : public SingleGradOpMaker<OpDesc> {
for (auto& out_name : outputs_) {
VLOG(3) << "Custom Operator: GradOpDescMaker - output: " << out_name;
if (detail::IsDuplicableVar(out_name)) {
grad_op->SetOutput(out_name,
this->InputGrad(detail::NoGrad(out_name),
/*drop_empty_grad=*/false));
grad_op->SetOutput(
out_name, this->InputGrad(detail::NoGrad(out_name, is_double_grad_),
/*drop_empty_grad=*/false));
} else {
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(
out_name, is_double_grad_)));
}
}
grad_op->SetAttrMap(this->Attrs());
......@@ -600,7 +609,8 @@ class CustomGradOpMaker<imperative::OpBase>
}
for (auto& out_name : outputs_) {
VLOG(3) << "Custom Operator: GradOpBaseMaker - output: " << out_name;
grad_op->SetOutput(out_name, this->InputGrad(detail::NoGrad(out_name)));
grad_op->SetOutput(
out_name, this->InputGrad(detail::NoGrad(out_name, is_double_grad_)));
}
grad_op->SetAttrMap(this->Attrs());
}
......@@ -885,8 +895,8 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
// Grad InferShape
if (grad_infer_shape_fn == nullptr) {
grad_info.infer_shape_ = [grad_op_inputs,
grad_op_outputs](InferShapeContext* ctx) {
grad_info.infer_shape_ = [grad_op_inputs, grad_op_outputs,
is_double_grad](InferShapeContext* ctx) {
// 1. if forward input exists, gradient's shape is same with forward
// input
// default
......@@ -897,7 +907,7 @@ void RegisterOperatorWithMetaInfo(const std::vector<OpMetaInfo>& op_meta_infos,
// [Suitable for the situation that forward input is not used as
// backward input]
for (auto& out_name : grad_op_outputs) {
auto fwd_name = detail::NoGrad(out_name);
auto fwd_name = detail::NoGrad(out_name, is_double_grad);
if (detail::IsDuplicableVar(fwd_name)) {
// Duplicable forward var must as backward input
ctx->ShareDim(fwd_name, out_name);
......
......@@ -58,6 +58,7 @@ using Tensor = paddle::experimental::Tensor;
constexpr char kGradTensorSuffix[] = "@GRAD";
constexpr char kTensorVectorSuffix[] = "@VECTOR";
constexpr char kDoubleGradNewOutSuffix[] = "@NEW";
// Used for Construct Grad Tensor name
inline std::string Grad(const std::string& t_name) {
......@@ -77,6 +78,15 @@ inline std::string Vec(const std::string& t_name) {
return result;
}
// Used for Construct double grad output name
inline std::string New(const std::string& t_name) {
std::string result;
result.reserve(t_name.size() + 4U);
result += t_name;
result += kDoubleGradNewOutSuffix;
return result;
}
PADDLE_API void AssignTensorImpl(const Tensor& src, Tensor* dst);
////////////////////// Kernel Context ////////////////////////
......
......@@ -23,6 +23,7 @@ py_test(test_custom_concat SRCS test_custom_concat.py)
py_test(test_custom_conj SRCS test_custom_conj.py)
py_test(test_custom_linear SRCS test_custom_linear.py)
py_test(test_custom_simple_slice SRCS test_custom_simple_slice.py)
py_test(test_custom_tanh_double_grad SRCS test_custom_tanh_double_grad.py)
# other tests
py_test(test_sysconfig SRCS test_sysconfig.py)
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include <cmath>
#include <iostream>
#include <vector>
#include "paddle/extension.h"
#define CHECK_CPU_INPUT(x) \
PD_CHECK(x.place() == paddle::PlaceType::kCPU, #x " must be a CPU Tensor.")
template <typename data_t>
void tanh_cpu_forward_kernel(const data_t* x_data,
data_t* out_data,
int64_t x_numel) {
PD_CHECK(x_data != nullptr, "x_data is nullptr.");
PD_CHECK(out_data != nullptr, "out_data is nullptr.");
for (int64_t i = 0; i < x_numel; ++i) {
out_data[i] = std::tanh(x_data[i]);
}
}
template <typename data_t>
void tanh_cpu_backward_kernel(const data_t* grad_out_data,
const data_t* out_data,
data_t* grad_x_data,
int64_t out_numel) {
PD_CHECK(grad_out_data != nullptr, "grad_out_data is nullptr.");
PD_CHECK(out_data != nullptr, "out_data is nullptr.");
PD_CHECK(grad_x_data != nullptr, "grad_x_data is nullptr.");
for (int64_t i = 0; i < out_numel; ++i) {
grad_x_data[i] =
grad_out_data[i] * (static_cast<data_t>(1) - out_data[i] * out_data[i]);
}
}
template <typename data_t>
void tanh_cpu_double_backward_kernel(const data_t* out_data,
const data_t* ddx_data,
const data_t* dout_data,
data_t* dout_new_data,
data_t* ddout_data,
int64_t ddout_numel) {
PD_CHECK(out_data != nullptr, "out_data is nullptr.");
PD_CHECK(ddx_data != nullptr, "ddx_data is nullptr.");
PD_CHECK(dout_data != nullptr, "dout_data is nullptr.");
PD_CHECK(dout_new_data != nullptr, "dout_new_data is nullptr.");
PD_CHECK(ddout_data != nullptr, "ddout_data is nullptr.");
for (int64_t i = 0; i < ddout_numel; ++i) {
dout_new_data[i] = static_cast<data_t>(-1) * dout_data[i] *
static_cast<data_t>(2) * out_data[i] * ddx_data[i];
ddout_data[i] =
ddx_data[i] * (static_cast<data_t>(1) - out_data[i] * out_data[i]);
}
}
std::vector<paddle::Tensor> TanhForward(const paddle::Tensor& x) {
CHECK_CPU_INPUT(x);
auto out = paddle::Tensor(paddle::PlaceType::kCPU, x.shape());
PD_DISPATCH_FLOATING_TYPES(
x.dtype(), "tanh_cpu_forward", ([&] {
tanh_cpu_forward_kernel<data_t>(
x.data<data_t>(), out.mutable_data<data_t>(x.place()), x.size());
}));
return {out};
}
std::vector<paddle::Tensor> TanhBackward(const paddle::Tensor& out,
const paddle::Tensor& grad_out) {
CHECK_CPU_INPUT(out);
auto grad_x = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
PD_DISPATCH_FLOATING_TYPES(out.dtype(), "tanh_cpu_backward", ([&] {
tanh_cpu_backward_kernel<data_t>(
grad_out.data<data_t>(),
out.data<data_t>(),
grad_x.mutable_data<data_t>(out.place()),
out.size());
}));
return {grad_x};
}
std::vector<paddle::Tensor> TanhDoubleBackward(const paddle::Tensor& out,
const paddle::Tensor& ddx,
const paddle::Tensor& dout) {
CHECK_CPU_INPUT(out);
CHECK_CPU_INPUT(ddx);
CHECK_CPU_INPUT(dout);
auto dout_new = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
auto ddout = paddle::Tensor(paddle::PlaceType::kCPU, out.shape());
PD_DISPATCH_FLOATING_TYPES(out.dtype(), "tanh_cpu_double_backward", ([&] {
tanh_cpu_double_backward_kernel<data_t>(
out.data<data_t>(),
ddx.data<data_t>(),
dout.data<data_t>(),
dout_new.mutable_data<data_t>(out.place()),
ddout.mutable_data<data_t>(out.place()),
ddout.size());
}));
return {dout_new, ddout};
}
std::vector<std::vector<int64_t>> TanhBackwardInferShape(
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& dout_shape) {
return {out_shape};
}
std::vector<std::vector<int64_t>> TanhDoubleBackwardInferShape(
const std::vector<int64_t>& out_shape,
const std::vector<int64_t>& ddx_shape,
const std::vector<int64_t>& dout_shape) {
return {dout_shape, dout_shape};
}
PD_BUILD_OP(custom_tanh)
.Inputs({"X"})
.Outputs({"Out"})
.SetKernelFn(PD_KERNEL(TanhForward));
PD_BUILD_GRAD_OP(custom_tanh)
.Inputs({"Out", paddle::Grad("Out")})
.Outputs({paddle::Grad("X")})
.SetKernelFn(PD_KERNEL(TanhBackward))
.SetInferShapeFn(PD_INFER_SHAPE(TanhBackwardInferShape));
PD_BUILD_DOUBLE_GRAD_OP(custom_tanh)
.Inputs({"Out", paddle::Grad(paddle::Grad("X")), paddle::Grad("Out")})
.Outputs({paddle::New(paddle::Grad("Out")),
paddle::Grad(paddle::Grad("Out"))})
.SetKernelFn(PD_KERNEL(TanhDoubleBackward))
.SetInferShapeFn(PD_INFER_SHAPE(TanhDoubleBackwardInferShape));
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import os
import unittest
import numpy as np
import paddle
import paddle.static as static
from paddle.utils.cpp_extension import load, get_build_directory
from paddle.utils.cpp_extension.extension_utils import run_cmd
from utils import paddle_includes, extra_cc_args, extra_nvcc_args
from paddle.fluid.framework import _test_eager_guard
# Because Windows don't use docker, the shared lib already exists in the
# cache dir, it will not be compiled again unless the shared lib is removed.
file = '{}\\custom_tanh\\custom_tanh.pyd'.format(get_build_directory())
if os.name == 'nt' and os.path.isfile(file):
cmd = 'del {}'.format(file)
run_cmd(cmd, True)
custom_ops = load(
name='custom_tanh_jit',
sources=['custom_tanh_op.cc'],
extra_include_paths=paddle_includes, # add for Coverage CI
extra_cxx_cflags=extra_cc_args, # test for cc flags
extra_cuda_cflags=extra_nvcc_args, # test for nvcc flags
verbose=True)
def custom_tanh_double_grad_dynamic(func, device, dtype, np_x):
paddle.set_device(device)
t = paddle.to_tensor(np_x, dtype=dtype, stop_gradient=False)
out = func(t)
out.stop_gradient = False
dx = paddle.grad(
outputs=[out], inputs=[t], create_graph=True, retain_graph=True)
dx[0].backward()
assert out.grad is not None
assert dx[0].grad is not None
return dx[0].numpy(), dx[0].grad.numpy(), out.grad.numpy()
class TestCustomTanhDoubleGradJit(unittest.TestCase):
def setUp(self):
paddle.set_device('cpu')
self.dtypes = ['float32', 'float64']
self.devices = ['cpu']
def test_func_double_grad_dynamic(self):
for device in self.devices:
for dtype in self.dtypes:
x = np.random.uniform(-1, 1, [4, 8]).astype(dtype)
out, dx_grad, dout = custom_tanh_double_grad_dynamic(
custom_ops.custom_tanh, device, dtype, x)
pd_out, pd_dx_grad, pd_dout = custom_tanh_double_grad_dynamic(
paddle.tanh, device, dtype, x)
self.assertTrue(
np.allclose(out, pd_out),
"custom op out: {},\n paddle api out: {}".format(out,
pd_out))
self.assertTrue(
np.allclose(dx_grad, pd_dx_grad),
"custom op dx grad: {},\n paddle api dx grad: {}".format(
dx_grad, pd_dx_grad))
self.assertTrue(
np.allclose(dout, pd_dout),
"custom op out grad: {},\n paddle api out grad: {}".format(
dout, pd_dout))
if __name__ == "__main__":
unittest.main()
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册