Unverified commit 9f7b027d, authored by Zeng Jinle, committed by GitHub

fix activation grad op desc maker (#16715)

test=develop
Parent 9bd44b94
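In short: each activation's backward pass now declares which forward tensors it actually reads (X, Out, or neither), and a single templated grad-op-desc maker wires up only those inputs, instead of the old per-op grad makers and the DefaultGradOpDescMaker that kept everything alive. A minimal sketch of the dependency mask this diff relies on (values inferred from the bit tests below; see activation_op.h for the authoritative definition):

// Bitmask describing which forward tensors a grad functor reads.
enum ActBwdOpFwdDeps {
  kNoDeps = 0x00,   // backward needs neither X nor Out
  kDepX = 0x01,     // backward reads the forward input X
  kDepOut = 0x02,   // backward reads the forward output Out
  kDepXOut = 0x03,  // backward reads both
};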
......@@ -233,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
}
};
// A no-op OpInfoFiller specialization for void: it fills nothing into the OpInfo.
template <>
struct OpInfoFiller<void, kUnknown> {
void operator()(const char* op_type, OpInfo* info) const {}
};
} // namespace details
} // namespace framework
......
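This no-op specialization matters later in the diff: REGISTER_ACTIVATION_OP passes std::conditional<..., SingleOpInplaceInToOut, void>::type as a filler argument, and the void branch must still compile. A hedged sketch of the pattern (the functor name is a placeholder, not a real Paddle type):

// If the grad functor cannot run in place, the conditional collapses to void,
// and the OpInfoFiller<void, kUnknown> specialization above turns that
// registration slot into a no-op.
using MaybeInplaceFiller =
    std::conditional<CanInplaceAct<SomeGradFunctor<float>>(),
                     paddle::framework::SingleOpInplaceInToOut,
                     void>::type;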
abs
acos
asin
atan
attention_lstm
brelu
conv_shift
cos
cos_sim
dequantize
elu
fc
flatten
fsp
......@@ -21,13 +14,8 @@ fusion_seqconv_eltadd_relu
fusion_seqexpand_concat_fc
fusion_seqpool_concat
fusion_squared_mat_sub
gelu
gru
hard_shrink
hierarchical_sigmoid
leaky_relu
log
logsigmoid
lrn
lstm_unit
lstmp
......@@ -38,7 +26,6 @@ modified_huber_loss
nce
pool2d
pool3d
pow
prelu
quantize
rank_loss
......@@ -50,20 +37,10 @@ reduce_sum
requantize
reshape
rnn_memory_helper
round
sequence_softmax
sin
softplus
softshrink
softsign
spp
square
squeeze
stanh
swish
tanh_shrink
tensor_array_to_tensor
thresholded_relu
transpose
unpool
unsqueeze
......@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/cudnn_desc.h"
......@@ -82,6 +85,8 @@ template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -94,6 +99,8 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -105,6 +112,8 @@ template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
......@@ -116,6 +125,8 @@ template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename Functor>
......@@ -140,10 +151,13 @@ class CudnnActivationGradKernel
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");
const framework::Tensor *X, *Out, *dOut;
X = Out = dOut = nullptr;
framework::Tensor* dX = nullptr;
ExtractActivationGradTensor(context, &X, &Out, &dOut, &dX);
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut,
&dX);
dX->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx);
......
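The FwdDeps() additions record that every cuDNN-backed activation gradient here is computed from the forward output Out (plus dOut) and never from X, which the kernel now enforces with a static_assert. The element-wise formulas make this plausible (a sketch of the math, not the cuDNN calls):

// dX = dOut * (Out > 0)           relu (and clipped relu inside the clip range)
// dX = dOut * Out * (1 - Out)     sigmoid
// dX = dOut * (1 - Out * Out)     tanh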
......@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA
......@@ -27,6 +29,25 @@ namespace operators {
using paddle::framework::Tensor;
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}
std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet() {
std::unique_ptr<std::unordered_set<std::string>> ret(
new std::unordered_set<std::string>());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
bwd_functor) \
if (CanInplaceAct<bwd_functor<float>>()) { \
ret->insert(#op_type); \
}
FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET
return ret;
}
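GetInplaceOpSet walks the FOR_EACH_ACTIVATION_OP X-macro (which replaces the hand-maintained op lists removed further down) and collects every activation whose gradient needs only Out or nothing, since only those can safely overwrite X in place. One hypothetical expansion of the macro above, assuming the X-macro carries an entry like (relu, Relu, ReluFunctor, ReluGradFunctor):

// ReluGradFunctor::FwdDeps() is kDepOut, so "relu" joins the in-place set.
if (CanInplaceAct<ReluGradFunctor<float>>()) {
  ret->insert("relu");
}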
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
......@@ -50,26 +71,32 @@ using paddle::framework::Tensor;
} \
}
#define REGISTER_ACTIVATION_OP_GRAD_MAKER(OP_NAME, KERNEL_TYPE) \
class OP_NAME##GradMaker \
: public ::paddle::framework::SingleGradOpDescMaker { \
public: \
using ::paddle::framework::SingleGradOpDescMaker::SingleGradOpDescMaker; \
\
protected: \
std::unique_ptr<::paddle::framework::OpDesc> Apply() const override { \
auto* op = new ::paddle::framework::OpDesc(); \
op->SetType(#KERNEL_TYPE "_grad"); \
op->SetInput("Out", Output("Out")); \
op->SetInput(::paddle::framework::GradVarName("Out"), \
OutputGrad("Out")); \
\
op->SetAttrMap(Attrs()); \
\
op->SetOutput(::paddle::framework::GradVarName("X"), InputGrad("X")); \
return std::unique_ptr<::paddle::framework::OpDesc>(op); \
} \
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;
protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType(ForwardOpType() + "_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());
if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
op->SetInput("X", Input("X"));
}
if (static_cast<int>(kDepValue) &
static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
op->SetInput("Out", Output("Out"));
}
return op;
}
};
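This single templated maker replaces the per-op GradMaker classes deleted above: it always wires Out@GRAD to X@GRAD and then adds X and/or Out only when the dependency mask requests them. For a kDepOut op such as relu, the generated grad op therefore reads only Out and Out@GRAD, while a kDepX op would read X and Out@GRAD instead (a sketch of the resulting op descs, not literal framework output):

// relu_grad (kDepOut):         inputs {Out, Out@GRAD} -> outputs {X@GRAD}
// gelu_grad (kDepX, assumed):  inputs {X, Out@GRAD}   -> outputs {X@GRAD}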
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
......@@ -129,14 +156,15 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
ctx->ShareDim("Out", framework::GradVarName("X"));
ctx->ShareLoD("Out", framework::GradVarName("X"));
auto out_grad_name = framework::GradVarName("Out");
ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
}
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, "Out");
return GetKernelType(ctx, *this, framework::GradVarName("Out"));
}
};
......@@ -558,79 +586,27 @@ REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators;
#define FOR_EACH_INPLACE_OP_FUNCTOR(__macro) \
__macro(Sigmoid, sigmoid); \
__macro(Relu, relu); \
__macro(Exp, exp); \
__macro(Tanh, tanh); \
__macro(Ceil, ceil); \
__macro(Floor, floor); \
__macro(Sqrt, sqrt); \
__macro(SoftRelu, soft_relu); \
__macro(Relu6, relu6); \
__macro(Reciprocal, reciprocal); \
__macro(HardSigmoid, hard_sigmoid);
#define FOR_EACH_OP_FUNCTOR(__macro) \
__macro(LogSigmoid, logsigmoid); \
__macro(SoftShrink, softshrink); \
__macro(Abs, abs); \
__macro(Cos, cos); \
__macro(Acos, acos); \
__macro(Sin, sin); \
__macro(Asin, asin); \
__macro(Atan, atan); \
__macro(Round, round); \
__macro(Log, log); \
__macro(Square, square); \
__macro(Gelu, gelu); \
__macro(BRelu, brelu); \
__macro(Pow, pow); \
__macro(STanh, stanh); \
__macro(Softplus, softplus); \
__macro(Softsign, softsign); \
__macro(LeakyRelu, leaky_relu); \
__macro(TanhShrink, tanh_shrink); \
__macro(ELU, elu); \
__macro(HardShrink, hard_shrink); \
__macro(Swish, swish); \
__macro(ThresholdedRelu, thresholded_relu);
#define REGISTER_INPLACE_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::operators::OP_NAME##GradMaker, \
::paddle::framework::SingleOpInplaceInToOut); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad, \
::paddle::framework::SingleOpInplaceInToOut)
#define REGISTER_ACTIVATION_OP(OP_NAME, KERNEL_TYPE) \
REGISTER_OPERATOR(KERNEL_TYPE, ::paddle::operators::ActivationOp, \
::paddle::operators::OP_NAME##OpMaker, \
::paddle::operators::ActivationOpInferVarType, \
::paddle::framework::DefaultGradOpDescMaker<true>); \
REGISTER_OPERATOR(KERNEL_TYPE##_grad, ::paddle::operators::ActivationOpGrad)
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
REGISTER_OPERATOR( \
KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \
ops::ActivationOpInferVarType, \
ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \
REGISTER_OPERATOR( \
KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type)
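For a concrete picture, one expansion of the new REGISTER_ACTIVATION_OP for an Out-dependent op would register roughly the following (a sketch assuming an X-macro entry (relu, Relu, ReluFunctor, ReluGradFunctor); the std::conditional is shown already resolved because relu qualifies for in-place):

// Forward op: grad maker parameterized by the grad functor's dependency mask,
// plus the in-place filler because relu's gradient needs only Out.
REGISTER_OPERATOR(relu, ops::ActivationOp, ops::ReluOpMaker,
                  ops::ActivationOpInferVarType,
                  ops::ActivationGradOpDescMaker<
                      ops::ReluGradFunctor<float>::FwdDeps()>,
                  ::paddle::framework::SingleOpInplaceInToOut);
// Backward op: also registered with the in-place filler.
REGISTER_OPERATOR(relu_grad, ops::ActivationOpGrad,
                  ::paddle::framework::SingleOpInplaceInToOut);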
#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<float>>, \
......@@ -643,6 +619,5 @@ namespace ops = paddle::operators;
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<double>>);
FOR_EACH_OP_FUNCTOR(REGISTER_ACTIVATION_OP);
FOR_EACH_INPLACE_OP_FUNCTOR(REGISTER_INPLACE_ACTIVATION_OP);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CPU_KERNEL);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
......@@ -15,7 +15,8 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, functor, grad_functor) \
#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
......@@ -30,4 +31,4 @@ namespace plat = paddle::platform;
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);
FOR_EACH_KERNEL_FUNCTOR(REGISTER_ACTIVATION_CUDA_KERNEL);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);