Unverified commit cfdde0ec authored by Jiabin Yang, committed by GitHub

【Deepmd Support】add IsInitialized and tanh double grad (#32188)

* add IsInitialized

* rm additional log and add tanh double grad

* rename is_initialized
Parent f946ba61
@@ -113,6 +113,9 @@ class PD_DLL_DECL Tensor {
/// \brief Cast datatype from one to another
Tensor cast(const DataType& target_type) const;
/// \brief Check whether the Tensor is initialized
bool is_initialized() const;
#ifdef PADDLE_WITH_CUDA
/// \brief Get the current stream of the Tensor
cudaStream_t stream() const;
......
@@ -103,15 +103,6 @@ void GpuCopy(T *src, T *dst, PlaceType src_plc, PlaceType dst_plc,
void Tensor::reshape(const std::vector<int64_t> &shape) {
GET_CASTED_TENSOR
auto new_dim = framework::make_ddim(shape);
if (tensor->numel() != framework::product(new_dim)) {
LOG(WARNING) << "Custom Op: Calling reshape to a new shape which is bigger "
"or smaller"
<< "than original shape will not change your tensor's memory "
"Please call"
<< "paddle::Tensor::mutable_data<T>() after to reallocate "
"your tensor's size."
<< std::endl;
}
tensor->Resize(new_dim);
}
@@ -393,6 +384,15 @@ int64_t Tensor::size() const {
return tensor->numel();
}
bool Tensor::is_initialized() const {
GET_CASTED_TENSOR;
return tensor->IsInitialized();
}
#ifdef PADDLE_WITH_CUDA
cudaStream_t Tensor::stream() const {
if (!stream_.IsStreamSet()) {
......
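A hedged usage sketch (not part of this commit; `MyRelu` is a hypothetical custom op and the guard placement is illustrative) showing how the new `is_initialized()` lets a custom operator fail fast before dereferencing an unallocated tensor:

```cpp
// Illustrative custom-op sketch; assumes the paddle/extension.h custom
// operator API of this era. MyRelu is a hypothetical example op.
#include "paddle/extension.h"

std::vector<paddle::Tensor> MyRelu(const paddle::Tensor& x) {
  // New in this commit: is_initialized() returns false until the tensor's
  // memory has been allocated (e.g. via mutable_data<T>() or a copy).
  PD_CHECK(x.is_initialized(), "Input of MyRelu is not initialized.");

  paddle::Tensor out(paddle::PlaceType::kCPU);
  out.reshape(x.shape());
  auto* out_data = out.mutable_data<float>();
  const auto* x_data = x.data<float>();
  for (int64_t i = 0; i < x.size(); ++i) {
    out_data[i] = x_data[i] > 0.f ? x_data[i] : 0.f;
  }
  return {out};
}
```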
@@ -220,6 +220,21 @@ void GroupTestDtypeConvert() {
paddle::DataType::FLOAT16);
}
void TestInitialized() {
paddle::Tensor test_tensor(paddle::PlaceType::kCPU);
CHECK(test_tensor.is_initialized() == false);
test_tensor.reshape({1, 1});
test_tensor.mutable_data<float>();
CHECK(test_tensor.is_initialized() == true);
float* tensor_data = test_tensor.data<float>();
for (int64_t i = 0; i < test_tensor.size(); i++) {
tensor_data[i] = 0.5;
}
for (int64_t i = 0; i < test_tensor.size(); i++) {
CHECK(tensor_data[i] == 0.5);
}
}
TEST(CustomTensor, copyTest) {
VLOG(2) << "TestCopy";
GroupTestCopy();
@@ -233,4 +248,6 @@ TEST(CustomTensor, copyTest) {
GroupTestCast();
VLOG(2) << "TestDtypeConvert";
GroupTestDtypeConvert();
VLOG(2) << "TestInitilized";
TestInitilized();
}
@@ -782,6 +782,26 @@ class ActivationOpDoubleGrad2 : public framework::OperatorWithKernel {
}
};
template <typename T>
class TanhDoubleGradMaker : public ::paddle::framework::SingleGradOpMaker<T> {
public:
using ::paddle::framework::SingleGradOpMaker<T>::SingleGradOpMaker;
protected:
void Apply(GradOpPtr<T> op) const override {
op->SetType("tanh_grad_grad");
// input1: Out
op->SetInput("Out", this->Input("Out"));
// input2: ddx, the incoming gradient of dX
op->SetInput("DDX", this->OutputGrad(framework::GradVarName("X")));
// input3: dout, the first-order gradient of Out
op->SetInput("DOut", this->Input(framework::GradVarName("Out")));
op->SetAttrMap(this->Attrs());
// output1: dout_new, the updated gradient w.r.t. Out
op->SetOutput("DOutNew", this->InputGrad("Out"));
// output2: ddout, the gradient passed on w.r.t. dOut
op->SetOutput("DDOut", this->InputGrad(framework::GradVarName("Out")));
}
};
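For orientation (editorial annotation, not part of the commit): writing the forward op as y = tanh(x) and its first backward as dx = dout * (1 - y^2), the slots wired above carry the following quantities.

```latex
% tanh_grad_grad slot mapping (annotation; names are the op's input/output slots)
\begin{aligned}
\mathrm{Out}     &= y = \tanh(x),\\
\mathrm{DOut}    &= \mathit{dout} = \partial L/\partial y \text{ from the first backward},\\
\mathrm{DDX}     &= \mathit{ddx}, \text{ the incoming gradient w.r.t. } dx,\\
\mathrm{DDOut}   &= \text{gradient passed on w.r.t. } \mathit{dout},\\
\mathrm{DOutNew} &= \text{additional gradient w.r.t. } y.
\end{aligned}
```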
// ReluGrad: dx = dy if y >= 0 else 0
// ReluGradGrad: ddy = ddx if y >= 0 else 0
template <typename T>
@@ -1041,6 +1061,34 @@ namespace plat = paddle::platform;
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
/* ========================== tanh register ============================= */
REGISTER_OPERATOR(
tanh, ops::ActivationOp, ops::TanhOpMaker, ops::ActivationOpInferVarType,
ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
paddle::framework::OpDesc>,
ops::ActivationGradOpMaker<ops::TanhGradFunctor<float>::FwdDeps(),
paddle::imperative::OpBase>,
std::conditional<ops::CanInplaceAct<ops::TanhGradFunctor<float>>(),
ops::ActFwdInplaceInferer, void>::type);
REGISTER_OPERATOR(tanh_grad, ops::ActivationOpGrad,
ops::ActivationGradOpInplaceInferer,
ops::TanhDoubleGradMaker<paddle::framework::OpDesc>,
ops::TanhDoubleGradMaker<paddle::imperative::OpBase>);
REGISTER_OPERATOR(
tanh_grad_grad,
ops::ActivationOpDoubleGrad<ops::TanhGradFunctor<float>::FwdDeps()>,
ops::ActivationDoubleGradOpInplaceInferer);
REGISTER_ACTIVATION_CPU_KERNEL(tanh, Tanh, TanhFunctor, TanhGradFunctor);
REGISTER_OP_CPU_KERNEL(
tanh_grad_grad, ops::TanhDoubleGradKernel<plat::CPUDeviceContext,
ops::TanhGradGradFunctor<float>>,
ops::TanhDoubleGradKernel<plat::CPUDeviceContext,
ops::TanhGradGradFunctor<double>>,
ops::TanhDoubleGradKernel<plat::CPUDeviceContext,
ops::TanhGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* ========================== relu register ============================= */
REGISTER_OPERATOR(
relu, ops::ActivationOp, ops::ReluOpMaker, ops::ActivationOpInferVarType,
......
@@ -468,6 +468,19 @@ REGISTER_OP_CUDA_KERNEL(
ops::ReluGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== tanh register ============================ */
REGISTER_ACTIVATION_CUDA_KERNEL(tanh, Tanh, TanhFunctor, TanhGradFunctor);
REGISTER_OP_CUDA_KERNEL(
tanh_grad_grad,
ops::TanhDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::TanhGradGradFunctor<float>>,
ops::TanhDoubleGradKernel<paddle::platform::CUDADeviceContext,
ops::TanhGradGradFunctor<double>>,
ops::TanhDoubleGradKernel<plat::CUDADeviceContext,
ops::TanhGradGradFunctor<plat::float16>>);
/* ========================================================================== */
/* =========================== sqrt register ============================= */
REGISTER_ACTIVATION_CUDA_KERNEL(sqrt, Sqrt, SqrtFunctor, SqrtGradFunctor);
......
@@ -366,6 +366,36 @@ struct TanhGradFunctor : public BaseActivationFunctor<T> {
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
struct TanhGradGradFunctor : public BaseActivationFunctor<T> {
template <typename Device>
void operator()(const Device& dev, const framework::Tensor* Out,
const framework::Tensor* ddX, const framework::Tensor* dOut,
framework::Tensor* dOutNew, framework::Tensor* ddOut) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddX, "Input", "DDX", "TanhGradGrad"));
auto out = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(Out, "Input", "Out", "TanhGradGrad"));
// tanh grad grad : ddout = (1 - out^2) * ddx
//                  dout_new = -(dout * 2 * out * ddx)
if (dOutNew) {
auto dout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOut, "Input", "DOut", "TanhGradGrad"));
auto dout_new = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(dOutNew, "Output", "DOutNew", "TanhGradGrad"));
dout_new.device(*d) =
static_cast<T>(-1) * dout * static_cast<T>(2) * out * ddx;
}
if (ddOut) {
auto ddout = framework::EigenVector<T>::Flatten(
GET_DATA_SAFELY(ddOut, "Output", "DDOut", "TanhGradGrad"));
ddout.device(*d) = (static_cast<T>(1) - out * out) * ddx;
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
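The two formulas in the functor follow from differentiating the first backward pass; a short derivation (editorial annotation, consistent with the comment in the code above):

```latex
% First backward: dx = dout * (1 - y^2), with y = tanh(x).
% With ddx the upstream gradient w.r.t. dx:
\begin{aligned}
\mathit{ddout} &= \frac{\partial\,dx}{\partial\,\mathit{dout}}\cdot\mathit{ddx}
                = (1 - y^2)\,\mathit{ddx},\\
\mathit{dout}_{\mathrm{new}} &= \frac{\partial\,dx}{\partial\,y}\cdot\mathit{ddx}
                = -2\,y\,\mathit{dout}\,\mathit{ddx}.
\end{aligned}
```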
// tanhshrink(x) = x - tanh(x)
// where tanh(x) = (exp(x) - exp(-x)) / (exp(x) + exp(-x))
template <typename T>
@@ -1734,6 +1764,58 @@ inline void ExtractDoubleGradTensorWithInputDOut(
}
}
template <typename DeviceContext, typename Functor>
class TanhDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& ctx) const override {
const framework::Tensor *Out = nullptr, *ddX = nullptr, *dOut = nullptr;
framework::Tensor *dOutNew = nullptr, *ddOut = nullptr;
// extract ddx (input) and out (input)
auto ddx_var = ctx.InputVar("DDX");
auto out_var = ctx.InputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
ddx_var, platform::errors::NotFound(
"Cannot get input Variable ddx, variable name = %s",
ctx.InputName("DDX")));
PADDLE_ENFORCE_NOT_NULL(
out_var, platform::errors::NotFound(
"Cannot get input Variable out, variable name = %s",
ctx.InputName("Out")));
ddX = ctx.Input<framework::Tensor>("DDX");
Out = ctx.Input<framework::Tensor>("Out");
// set output ddout
auto ddout_var = ctx.OutputVar("DDOut");
if (ddout_var) {
ddOut = ctx.Output<framework::Tensor>("DDOut");
}
// extract dOut (input)
auto dout_var = ctx.InputVar("DOut");
PADDLE_ENFORCE_NOT_NULL(
dout_var, platform::errors::NotFound(
"Cannot get input Variable dout_var, variable name = %s",
ctx.InputName("DOut")));
dOut = ctx.Input<framework::Tensor>("DOut");
// set output dout_new
auto dout_new_var = ctx.OutputVar("DOutNew");
if (dout_new_var) {
dOutNew = ctx.Output<framework::Tensor>("DOutNew");
}
if (dOutNew) dOutNew->mutable_data<T>(Out->dims(), ctx.GetPlace());
if (ddOut) ddOut->mutable_data<T>(Out->dims(), ctx.GetPlace());
auto& place = ctx.template device_context<DeviceContext>();
Functor functor;
functor(place, Out, ddX, dOut, dOutNew, ddOut);
}
};
template <typename DeviceContext, typename Functor>
class SquareDoubleGradKernel
: public framework::OpKernel<typename Functor::ELEMENT_TYPE> {
@@ -2048,7 +2130,6 @@ struct LogGradGradFunctor : public BaseActivationFunctor<T> {
#define FOR_EACH_ACTIVATION_OP(__macro) \
__macro(sigmoid, Sigmoid, SigmoidFunctor, SigmoidGradFunctor); \
__macro(logsigmoid, LogSigmoid, LogSigmoidFunctor, LogSigmoidGradFunctor); \
__macro(tanh, Tanh, TanhFunctor, TanhGradFunctor); \
__macro(atan, Atan, AtanFunctor, AtanGradFunctor); \
__macro(softshrink, SoftShrink, SoftShrinkFunctor, SoftShrinkGradFunctor); \
__macro(ceil, Ceil, CeilFunctor, ZeroGradFunctor); \
......
@@ -18,6 +18,7 @@ import unittest
import numpy as np
import paddle.fluid as fluid
import paddle
import paddle.fluid.layers as layers
import paddle.fluid.core as core
import gradient_checker
@@ -25,6 +26,28 @@ import gradient_checker
from decorator_helper import prog_scope
class TestTanhDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
shape = [2, 3, 7, 9]
eps = 0.0005
dtype = np.float64
x = layers.data('x', shape, False, dtype=dtype)
x.persistable = True
y = paddle.tanh(x)
x_arr = np.random.uniform(-1, 1, shape).astype(dtype)
x_arr[np.abs(x_arr) < 0.005] = 0.002
gradient_checker.double_grad_check(
[x], y, x_init=x_arr, place=place, eps=eps)
def test_grad(self):
places = [fluid.CPUPlace()]
if core.is_compiled_with_cuda():
places.append(fluid.CUDAPlace(0))
for p in places:
self.func(p)
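As a sanity reference (editorial note): for y = tanh(x) the double-grad path must reproduce the analytic second derivative, which `double_grad_check` compares against finite differences.

```latex
% Analytic derivatives of tanh that the numeric check targets.
\frac{d}{dx}\tanh(x) = 1 - \tanh^2(x),\qquad
\frac{d^2}{dx^2}\tanh(x) = -2\,\tanh(x)\,\bigl(1 - \tanh^2(x)\bigr).
```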
class TestReluDoubleGradCheck(unittest.TestCase):
@prog_scope()
def func(self, place):
......