Unverified commit 9f7b027d authored by Zeng Jinle, committed by GitHub

fix activation grad op desc maker (#16715)

test=develop
Parent 9bd44b94
...@@ -233,6 +233,12 @@ struct OpInfoFiller<T, kNoNeedBufferVarsInference> {
}
};
// A fake OpInfoFiller of void
template <>
struct OpInfoFiller<void, kUnknown> {
void operator()(const char* op_type, OpInfo* info) const {}
};
} // namespace details
} // namespace framework
...
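The `OpInfoFiller<void, kUnknown>` specialization above is a no-op filler, so registration templates can pass `void` for a component an operator does not provide. Below is a minimal standalone sketch of the same idea; the surrounding types (`OpInfo`, the fill-type enum, `ReluDoc`) are made up for illustration and are not Paddle's real definitions.

```cpp
// Standalone sketch (illustrative types, not Paddle's): a no-op filler
// specialization for `void` lets a registration helper accept `void`
// in slots where an op has nothing to register.
#include <iostream>
#include <string>

struct OpInfo {
  std::string comment;
};

enum OpInfoFillType { kUnknown = -1, kCommentFiller = 0 };

// A real filler writes one piece of metadata into OpInfo.
template <typename T, OpInfoFillType kType>
struct OpInfoFiller {
  void operator()(const char* op_type, OpInfo* info) const {
    info->comment = std::string(op_type) + ": " + T::Comment();
  }
};

// The "fake" filler: `void` means "nothing to fill", so it does nothing.
template <>
struct OpInfoFiller<void, kUnknown> {
  void operator()(const char* op_type, OpInfo* info) const {}
};

struct ReluDoc {
  static std::string Comment() { return "y = max(0, x)"; }
};

int main() {
  OpInfo info;
  OpInfoFiller<ReluDoc, kCommentFiller>()("relu", &info);
  OpInfoFiller<void, kUnknown>()("relu", &info);  // compiles, fills nothing
  std::cout << info.comment << std::endl;         // relu: y = max(0, x)
  return 0;
}
```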
abs
acos
asin
atan
attention_lstm
brelu
conv_shift
cos
cos_sim
dequantize
elu
fc
flatten
fsp
...@@ -21,13 +14,8 @@ fusion_seqconv_eltadd_relu
fusion_seqexpand_concat_fc
fusion_seqpool_concat
fusion_squared_mat_sub
gelu
gru
hard_shrink
hierarchical_sigmoid
leaky_relu
log
logsigmoid
lrn
lstm_unit
lstmp
...@@ -38,7 +26,6 @@ modified_huber_loss
nce
pool2d
pool3d
pow
prelu
quantize
rank_loss
...@@ -50,20 +37,10 @@ reduce_sum
requantize
reshape
rnn_memory_helper
round
sequence_softmax
sin
softplus
softshrink
softsign
spp
square
squeeze
stanh
swish
tanh_shrink
tensor_array_to_tensor
thresholded_relu
transpose
unpool
unsqueeze
...@@ -12,6 +12,9 @@
// See the License for the specific language governing permissions and
// limitations under the License.
#include <memory>
#include <string>
#include <vector>
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/operators/activation_op.h" #include "paddle/fluid/operators/activation_op.h"
#include "paddle/fluid/platform/cudnn_desc.h" #include "paddle/fluid/platform/cudnn_desc.h"
...@@ -82,6 +85,8 @@ template <typename T> ...@@ -82,6 +85,8 @@ template <typename T>
struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> { struct CudnnReluGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx) explicit CudnnReluGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {} : CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_RELU) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
...@@ -94,6 +99,8 @@ struct CudnnRelu6GradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnRelu6GradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 6.0, CUDNN_ACTIVATION_CLIPPED_RELU) {
}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
...@@ -105,6 +112,8 @@ template <typename T>
struct CudnnSigmoidGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnSigmoidGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_SIGMOID) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename T>
...@@ -116,6 +125,8 @@ template <typename T>
struct CudnnTanhGradFunctor : public CudnnActivationGradFunctor<T> {
explicit CudnnTanhGradFunctor(const CUDADeviceContext& ctx)
: CudnnActivationGradFunctor<T>(ctx, 0.0, CUDNN_ACTIVATION_TANH) {}
static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};
template <typename Functor>
...@@ -140,10 +151,13 @@ class CudnnActivationGradKernel
public:
using T = typename Functor::ELEMENT_TYPE;
void Compute(const framework::ExecutionContext& context) const override {
static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");
const framework::Tensor *X, *Out, *dOut;
X = Out = dOut = nullptr;
framework::Tensor* dX = nullptr;
ExtractActivationGradTensor<Functor::FwdDeps()>(context, &X, &Out, &dOut, &dX);
dX->mutable_data<T>(context.GetPlace());
auto& dev_ctx = context.template device_context<CUDADeviceContext>();
Functor functor(dev_ctx);
...
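The new `FwdDeps()` members advertise, at compile time, which forward tensors a grad functor needs; the cuDNN grad kernels only ever work from `Out`, hence the `static_assert`. A small self-contained sketch of that pattern follows (the functor and kernel names are stand-ins, not the real Paddle classes):

```cpp
// Sketch of the compile-time dependency contract (assumed names): the grad
// functor declares what it needs via FwdDeps(), and the kernel checks it.
#include <iostream>

enum ActBwdOpFwdDeps { kNoDeps = 0x00, kDepX = 0x01, kDepOut = 0x02 };

struct FakeReluGradFunctor {
  // relu's gradient can be evaluated from the forward output alone:
  // dX = dOut * (Out > 0), so only Out is required.
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
  float operator()(float out, float dout) const {
    return out > 0.f ? dout : 0.f;
  }
};

template <typename Functor>
float RunGradKernel(float out, float dout) {
  // Mirrors the static_assert added to CudnnActivationGradKernel::Compute.
  static_assert(Functor::FwdDeps() == kDepOut, "Forward deps must be Out.");
  return Functor()(out, dout);
}

int main() {
  std::cout << RunGradKernel<FakeReluGradFunctor>(2.f, 1.f) << std::endl;  // 1
  return 0;
}
```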
...@@ -15,7 +15,9 @@ limitations under the License. */
#include "paddle/fluid/operators/activation_op.h"
#include <memory>
#include <string>
#include <type_traits>
#include <unordered_map>
#include <vector>
#include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h" #include "paddle/fluid/operators/mkldnn/mkldnn_activation_op.h"
#include "paddle/fluid/platform/port.h" #include "paddle/fluid/platform/port.h"
#ifdef PADDLE_WITH_CUDA #ifdef PADDLE_WITH_CUDA
...@@ -27,6 +29,25 @@ namespace operators { ...@@ -27,6 +29,25 @@ namespace operators {
using paddle::framework::Tensor; using paddle::framework::Tensor;
template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}
std::unique_ptr<std::unordered_set<std::string>> GetInplaceOpSet() {
std::unique_ptr<std::unordered_set<std::string>> ret(
new std::unordered_set<std::string>());
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, \
bwd_functor) \
if (CanInplaceAct<bwd_functor<float>>()) { \
ret->insert(#op_type); \
}
FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET
return ret;
}
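`GetInplaceOpSet` walks the single `FOR_EACH_ACTIVATION_OP` list (defined in activation_op.h) and keeps only the ops whose grad functor can run in place, i.e. whose backward pass does not need the forward input `X`. A reduced sketch of how that X-macro expansion behaves, using a two-entry stand-in list and made-up functors:

```cpp
// Reduced sketch (stand-in list and functors, not the real ones from
// activation_op.h): one tuple list drives every per-op expansion; here it
// filters ops whose grad only needs Out, so the forward may overwrite X.
#include <iostream>
#include <string>
#include <unordered_set>

enum ActBwdOpFwdDeps { kNoDeps = 0x00, kDepX = 0x01, kDepOut = 0x02 };

template <typename T>
struct FakeReluGradFunctor {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepOut; }
};

template <typename T>
struct FakeGeluGradFunctor {
  static constexpr ActBwdOpFwdDeps FwdDeps() { return kDepX; }  // needs X
};

template <typename GradFunctor>
static constexpr bool CanInplaceAct() {
  return GradFunctor::FwdDeps() == kDepOut || GradFunctor::FwdDeps() == kNoDeps;
}

// Two-entry stand-in for the real FOR_EACH_ACTIVATION_OP list.
#define FOR_EACH_ACTIVATION_OP(__macro)                      \
  __macro(relu, Relu, FakeReluFunctor, FakeReluGradFunctor); \
  __macro(gelu, Gelu, FakeGeluFunctor, FakeGeluGradFunctor)

int main() {
  std::unordered_set<std::string> inplace_ops;
#define INSERT_INTO_INPLACE_OP_SET(op_type, __omitted, fwd_functor, bwd_functor) \
  if (CanInplaceAct<bwd_functor<float>>()) inplace_ops.insert(#op_type)
  FOR_EACH_ACTIVATION_OP(INSERT_INTO_INPLACE_OP_SET);
#undef INSERT_INTO_INPLACE_OP_SET

  for (const std::string& op : inplace_ops) std::cout << op << std::endl;  // relu
  return 0;
}
```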
#define REGISTER_ACTIVATION_OP_MAKER(OP_NAME, OP_COMMENT) \
class OP_NAME##OpMaker \
: public ::paddle::framework::OpProtoAndCheckerMaker { \
...@@ -50,26 +71,32 @@ using paddle::framework::Tensor;
} \
}
template <ActBwdOpFwdDeps kDepValue>
class ActivationGradOpDescMaker : public framework::SingleGradOpDescMaker {
public:
using framework::SingleGradOpDescMaker::SingleGradOpDescMaker;

protected:
std::unique_ptr<framework::OpDesc> Apply() const override {
std::unique_ptr<framework::OpDesc> op(new framework::OpDesc());
op->SetType(ForwardOpType() + "_grad");
op->SetInput(framework::GradVarName("Out"), OutputGrad("Out"));
op->SetOutput(framework::GradVarName("X"), InputGrad("X"));
op->SetAttrMap(Attrs());

if (static_cast<int>(kDepValue) & static_cast<int>(ActBwdOpFwdDeps::kDepX)) {
op->SetInput("X", Input("X"));
}

if (static_cast<int>(kDepValue) & static_cast<int>(ActBwdOpFwdDeps::kDepOut)) {
op->SetInput("Out", Output("Out"));
}

return op;
}
};
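With this maker, the generated grad op only receives the forward tensors its functor actually needs, rather than always keeping both `X` and `Out` alive as a default grad-op maker would. A toy illustration of the resulting input lists follows; the helper below is purely hypothetical, and the relu/gelu dependency values are used only as examples:

```cpp
// Hypothetical helper mirroring the branch logic in Apply(): which inputs
// the generated "<op>_grad" receives for a given dependency mask.
#include <iostream>
#include <string>
#include <vector>

enum ActBwdOpFwdDeps { kNoDeps = 0x00, kDepX = 0x01, kDepOut = 0x02 };

template <ActBwdOpFwdDeps kDepValue>
std::vector<std::string> GradOpInputs() {
  std::vector<std::string> inputs = {"Out@GRAD"};  // always forwarded
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepX)) {
    inputs.push_back("X");
  }
  if (static_cast<int>(kDepValue) & static_cast<int>(kDepOut)) {
    inputs.push_back("Out");
  }
  return inputs;
}

int main() {
  // e.g. a relu-like op (grad depends on Out only): Out@GRAD, Out
  for (const auto& in : GradOpInputs<kDepOut>()) std::cout << in << " ";
  std::cout << std::endl;
  // e.g. a gelu-like op (grad depends on X only): Out@GRAD, X
  for (const auto& in : GradOpInputs<kDepX>()) std::cout << in << " ";
  std::cout << std::endl;
  return 0;
}
```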
framework::OpKernelType GetKernelType(const framework::ExecutionContext& ctx,
const framework::OperatorWithKernel& oper,
...@@ -129,14 +156,15 @@ class ActivationOpGrad : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
auto out_grad_name = framework::GradVarName("Out");
ctx->ShareDim(out_grad_name, framework::GradVarName("X"));
ctx->ShareLoD(out_grad_name, framework::GradVarName("X"));
}

protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return GetKernelType(ctx, *this, framework::GradVarName("Out"));
}
};
...@@ -558,79 +586,27 @@ REGISTER_ACTIVATION_OP_MAKER(Log, LogDoc);
REGISTER_ACTIVATION_OP_MAKER(Square, SquareDoc);
REGISTER_ACTIVATION_OP_MAKER(Softplus, SoftplusDoc);
REGISTER_ACTIVATION_OP_MAKER(Softsign, SoftsignDoc);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sigmoid, sigmoid);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu, relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Gelu, gelu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Exp, exp);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Tanh, tanh);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Ceil, ceil);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Floor, floor);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Sqrt, sqrt);
REGISTER_ACTIVATION_OP_GRAD_MAKER(SoftRelu, soft_relu);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Relu6, relu6);
REGISTER_ACTIVATION_OP_GRAD_MAKER(Reciprocal, reciprocal);
REGISTER_ACTIVATION_OP_GRAD_MAKER(HardSigmoid, hard_sigmoid);
} // namespace operators
} // namespace paddle

namespace ops = paddle::operators;
#define REGISTER_ACTIVATION_OP(KERNEL_TYPE, OP_NAME, functor, grad_functor) \
REGISTER_OPERATOR( \
KERNEL_TYPE, ops::ActivationOp, ops::OP_NAME##OpMaker, \
ops::ActivationOpInferVarType, \
ops::ActivationGradOpDescMaker<ops::grad_functor<float>::FwdDeps()>, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type); \
REGISTER_OPERATOR( \
KERNEL_TYPE##_grad, ops::ActivationOpGrad, \
std::conditional<ops::CanInplaceAct<ops::grad_functor<float>>(), \
::paddle::framework::SingleOpInplaceInToOut, \
void>::type)

#define REGISTER_ACTIVATION_CPU_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CPU_KERNEL( \
act_type, ops::ActivationKernel<paddle::platform::CPUDeviceContext, \
ops::functor<float>>, \
...@@ -643,6 +619,5 @@ namespace ops = paddle::operators;
ops::ActivationGradKernel<paddle::platform::CPUDeviceContext, \
ops::grad_functor<double>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_OP);
FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CPU_KERNEL);
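In the new `REGISTER_ACTIVATION_OP`, `std::conditional` selects the inplace-inference class only when `CanInplaceAct` holds for the op's grad functor; otherwise `void` is passed and silently swallowed by the no-op `OpInfoFiller<void, kUnknown>` specialization added earlier. A minimal sketch of that type-selection step (the inference class here is a stand-in):

```cpp
// Minimal sketch of the std::conditional dispatch (stand-in types): ops whose
// grad only needs Out get an inplace-inference class, the rest get `void`.
#include <iostream>
#include <type_traits>

struct SingleOpInplaceInToOut {};  // stand-in for the real inference class

template <bool kCanInplace>
using InplaceInferOrVoid =
    typename std::conditional<kCanInplace, SingleOpInplaceInToOut, void>::type;

int main() {
  std::cout << std::is_same<InplaceInferOrVoid<true>, SingleOpInplaceInToOut>::value
            << std::endl;  // 1: relu-like ops register inplace inference
  std::cout << std::is_void<InplaceInferOrVoid<false>>::value
            << std::endl;  // 1: X-dependent ops register nothing extra
  return 0;
}
```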
...@@ -15,7 +15,8 @@ limitations under the License. */
namespace ops = paddle::operators;
namespace plat = paddle::platform;

#define REGISTER_ACTIVATION_CUDA_KERNEL(act_type, op_name, functor, \
grad_functor) \
REGISTER_OP_CUDA_KERNEL( \
act_type, \
ops::ActivationKernel<plat::CUDADeviceContext, ops::functor<float>>, \
...@@ -30,4 +31,4 @@ namespace plat = paddle::platform;
ops::ActivationGradKernel<plat::CUDADeviceContext, \
ops::grad_functor<plat::float16>>);

FOR_EACH_ACTIVATION_OP(REGISTER_ACTIVATION_CUDA_KERNEL);