未验证 提交 0daa5c97 编写于 作者: Z Zeng Jinle 提交者: GitHub

Make leaky relu inplacable (#19676)

* make leaky relu inplacable, test=develop

* force add unittests to pass coverage, test=develop
上级 c308c88d
......@@ -250,6 +250,8 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
VLOG(10) << "Start to apply buffer_shared_inplace_pass";
graph = inplace_pass->Apply(graph);
VLOG(10) << "buffer_shared_inplace_pass Applied";
LOG(INFO) << "Inplace strategy is enabled, when "
"build_strategy.enable_inplace = True";
}
/**
......@@ -278,6 +280,9 @@ ir::Graph *ParallelExecutorPrivate::ApplyMemoryOptimizePass(ir::Graph *graph) {
VLOG(10) << "Start to apply buffer_shared_cross_op_memory_reuse_pass";
graph = cross_op_memory_reuse_pass->Apply(graph);
VLOG(10) << "buffer_shared_cross_op_memory_reuse_pass Applied";
LOG(INFO) << "Cross op memory reuse strategy is enabled, when "
"build_strategy.memory_optimize = True or garbage collection "
"strategy is disabled, which is not recommended";
}
if (!is_gc_enabled) {
......
......@@ -116,6 +116,11 @@ cc_test(strided_memcpy_test SRCS strided_memcpy_test.cc DEPS tensor memory)
cc_test(save_load_op_test SRCS save_load_op_test.cc DEPS save_op load_op)
cc_test(save_load_combine_op_test SRCS save_load_combine_op_test.cc DEPS save_combine_op load_combine_op)
nv_test(dropout_op_test SRCS dropout_op_test.cc DEPS dropout_op tensor)
if (WITH_GPU)
nv_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc test_leaky_relu_grad_grad_functor.cu DEPS tensor device_context eigen3)
else()
cc_test(test_leaky_relu_grad_grad_functor SRCS test_leaky_relu_grad_grad_functor.cc DEPS tensor device_context eigen3)
endif()
if (WITH_PYTHON)
cc_library(py_func_op SRCS py_func_op.cc DEPS op_registry python pybind)
......
......@@ -716,8 +716,8 @@ class LeakyReluDoubleGradMaker
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override {
auto* op = new ::paddle::framework::OpDesc();
op->SetType("leaky_relu_grad_grad");
// input1: X
op->SetInput("X", Input("X"));
// input1: Out
op->SetInput("Out", Input("Out"));
// X@GRAD@GRAD: ddx
op->SetInput("DDX", OutputGrad(framework::GradVarName("X")));
op->SetAttrMap(Attrs());
......
......@@ -463,8 +463,8 @@ struct HardShrinkFunctor : public BaseActivationFunctor<T> {
}
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
out.device(d) = x * (temp1 + temp2);
}
};
......@@ -480,8 +480,8 @@ struct HardShrinkGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>().eval();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>().eval();
auto temp1 = (x < static_cast<T>(threshold * -1)).template cast<T>();
auto temp2 = (x > static_cast<T>(threshold)).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
......@@ -500,8 +500,8 @@ struct SoftShrinkFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out>
void operator()(Device d, X x, Out out) const {
auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
auto temp1 = (x > lambdaT).template cast<T>();
auto temp2 = (x < -lambdaT).template cast<T>();
out.device(d) = temp1 * (x - lambdaT) + temp2 * (x + lambdaT);
}
};
......@@ -516,8 +516,8 @@ struct SoftShrinkGradFunctor : public BaseActivationFunctor<T> {
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto lambdaT = static_cast<T>(lambda);
auto temp1 = (x > lambdaT).template cast<T>().eval();
auto temp2 = (x < -lambdaT).template cast<T>().eval();
auto temp1 = (x > lambdaT).template cast<T>();
auto temp2 = (x < -lambdaT).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
......@@ -1043,7 +1043,7 @@ struct SoftReluGradFunctor : public BaseActivationFunctor<T> {
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto tmp = static_cast<T>(threshold);
auto temp = ((out > -tmp) * (out < tmp)).template cast<T>().eval();
auto temp = ((out > -tmp) * (out < tmp)).template cast<T>();
dx.device(d) = dout * (static_cast<T>(1) - (-out).exp()) * temp;
}
......@@ -1072,13 +1072,13 @@ struct LeakyReluGradFunctor : public BaseActivationFunctor<T> {
template <typename Device, typename X, typename Out, typename dOut,
typename dX>
void operator()(Device d, X x, Out out, dOut dout, dX dx) const {
auto temp1 = static_cast<T>(alpha) *
(x < static_cast<T>(0)).template cast<T>().eval();
auto temp2 = (x >= static_cast<T>(0)).template cast<T>().eval();
auto temp1 =
static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>();
auto temp2 = (out >= static_cast<T>(0)).template cast<T>();
dx.device(d) = dout * (temp1 + temp2).template cast<T>();
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -1413,19 +1413,19 @@ struct LeakyReluGradGradFunctor : public BaseActivationFunctor<T> {
const framework::Tensor* Out, const framework::Tensor* ddX,
framework::Tensor* ddOut, framework::Tensor* dOut,
framework::Tensor* dX) const {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto x = framework::EigenVector<T>::Flatten(detail::Ref(X));
if (ddOut) {
auto* d = dev.eigen_device();
auto ddx = framework::EigenVector<T>::Flatten(detail::Ref(ddX));
auto out = framework::EigenVector<T>::Flatten(detail::Ref(Out));
auto ddout = framework::EigenVector<T>::Flatten(detail::Ref(ddOut));
ddout.device(*d) = ddx *
((x >= static_cast<T>(0)).template cast<T>().eval() +
static_cast<T>(alpha) *
(x < static_cast<T>(0)).template cast<T>().eval())
.template cast<T>();
ddout.device(*d) =
ddx *
((out >= static_cast<T>(0)).template cast<T>() +
static_cast<T>(alpha) * (out < static_cast<T>(0)).template cast<T>())
.template cast<T>();
}
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h"
namespace paddle {
namespace operators {
TEST(leaky_relu_grad_grad, test_cpu) {
ASSERT_TRUE(
TestLeakyReluGradGradMain<float>({32, 64}, platform::CPUPlace(), 0.02));
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/fluid/operators/test_leaky_relu_grad_grad_functor.h"
namespace paddle {
namespace operators {
TEST(leaky_relu_grad_grad, test_gpu) {
ASSERT_TRUE(
TestLeakyReluGradGradMain<float>({32, 64}, platform::CUDAPlace(0), 0.15));
}
} // namespace operators
} // namespace paddle
// Copyright (c) 2019 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <algorithm>
#include <random>
#include "gtest/gtest.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/for_range.h"
namespace paddle {
namespace operators {
template <typename T>
static void InitRandom(framework::Tensor *tensor,
const platform::Place &place) {
framework::Tensor cpu_tensor;
auto *cpu_ptr =
cpu_tensor.mutable_data<T>(tensor->dims(), platform::CPUPlace());
int64_t numel = cpu_tensor.numel();
std::mt19937 engine;
std::uniform_real_distribution<T> dist(static_cast<T>(-2.0),
static_cast<T>(2.0));
for (int64_t i = 0; i < numel; ++i) {
cpu_ptr[i] = dist(engine);
}
framework::TensorCopySync(cpu_tensor, place, tensor);
}
template <typename T>
struct LeakyReluGradGradEachElementFunctor {
LeakyReluGradGradEachElementFunctor(const T *ddx, const T *out, T alpha,
T *ddout)
: ddx_(ddx), out_(out), alpha_(alpha), ddout_(ddout) {}
HOSTDEVICE void operator()(int idx) {
if (out_[idx] >= 0) {
ddout_[idx] = ddx_[idx];
} else {
ddout_[idx] = ddx_[idx] * alpha_;
}
}
const T *ddx_;
const T *out_;
T alpha_;
T *ddout_;
};
template <typename T>
static bool TestLeakyReluGradGradMain(const framework::DDim &dim,
const platform::Place &place,
float alpha) {
LeakyReluGradGradFunctor<T> functor;
functor.alpha = alpha;
auto &dev_ctx = *platform::DeviceContextPool::Instance().Get(place);
framework::Tensor *x = nullptr;
framework::Tensor *dout = nullptr;
framework::Tensor *dx = nullptr;
framework::Tensor out;
out.Resize(dim);
InitRandom<T>(&out, place);
framework::Tensor ddx;
ddx.Resize(dim);
InitRandom<T>(&ddx, place);
framework::Tensor ddout;
ddout.Resize(dim);
InitRandom<T>(&ddout, place);
framework::Tensor ddout_actual;
ddout_actual.mutable_data<T>(dim, place);
LeakyReluGradGradEachElementFunctor<T> actual_functor(
ddx.data<T>(), out.data<T>(), static_cast<T>(alpha),
ddout_actual.data<T>());
int64_t limit = out.numel();
#ifdef __NVCC__
if (platform::is_gpu_place(place)) {
auto &cuda_dev_ctx = dynamic_cast<platform::CUDADeviceContext &>(dev_ctx);
functor(cuda_dev_ctx, x, &out, &ddx, &ddout, dout, dx);
platform::ForRange<platform::CUDADeviceContext> for_range(cuda_dev_ctx,
limit);
for_range(actual_functor);
} else {
#endif
auto &cpu_dev_ctx = dynamic_cast<platform::CPUDeviceContext &>(dev_ctx);
functor(cpu_dev_ctx, x, &out, &ddx, &ddout, dout, dx);
platform::ForRange<platform::CPUDeviceContext> for_range(cpu_dev_ctx,
limit);
for_range(actual_functor);
#ifdef __NVCC__
}
#endif
dev_ctx.Wait();
framework::Tensor ddout_cpu, ddout_actual_cpu;
framework::TensorCopySync(ddout, platform::CPUPlace(), &ddout_cpu);
framework::TensorCopySync(ddout_actual, platform::CPUPlace(),
&ddout_actual_cpu);
bool is_equal = std::equal(ddout_cpu.data<T>(), ddout_cpu.data<T>() + limit,
ddout_actual_cpu.data<T>());
return is_equal;
}
} // namespace operators
} // namespace paddle
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册