From 3f4aca618f4ef0b87881c4a7ae62af2490e35780 Mon Sep 17 00:00:00 2001 From: chengduozh Date: Fri, 30 Nov 2018 21:03:28 +0800 Subject: [PATCH] code refine test=develop --- paddle/fluid/operators/cudnn_lstm_op.cc | 26 ++++++------- paddle/fluid/operators/cudnn_lstm_op.cu.cc | 17 ++++---- paddle/fluid/operators/cudnn_lstm_op.h | 45 ---------------------- 3 files changed, 20 insertions(+), 68 deletions(-) delete mode 100644 paddle/fluid/operators/cudnn_lstm_op.h diff --git a/paddle/fluid/operators/cudnn_lstm_op.cc b/paddle/fluid/operators/cudnn_lstm_op.cc index 86632fc9f..e63d57be5 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cc @@ -12,12 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/cudnn_lstm_op.h" #include - -#ifdef PADDLE_WITH_CUDA -#include "paddle/fluid/platform/cudnn_helper.h" -#endif +#include "paddle/fluid/framework/op_registry.h" namespace paddle { namespace operators { @@ -201,18 +197,22 @@ class CudnnLSTMGradOp : public framework::OperatorWithKernel { } }; +template +class NotImpleKernel : public framework::OpKernel { + public: + void Compute(const framework::ExecutionContext& ctx) const override { + PADDLE_THROW( + "CPU is not support for this kernel now. Will be add in the future"); + } +}; + } // namespace operators } // namespace paddle namespace ops = paddle::operators; REGISTER_OPERATOR(cudnn_lstm, ops::CudnnLSTMOp, ops::CudnnLSTMOpMaker, paddle::framework::DefaultGradOpDescMaker); -REGISTER_OPERATOR(lstm_cudnn_grad, ops::CudnnLSTMGradOp); - -REGISTER_OP_CPU_KERNEL( - cudnn_lstm, - ops::CudnnLSTMKernel); +REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp); -REGISTER_OP_CPU_KERNEL( - lstm_cudnn_grad, - ops::CudnnLSTMGradKernel); +REGISTER_OP_CPU_KERNEL(cudnn_lstm, ops::NotImpleKernel); +REGISTER_OP_CPU_KERNEL(cudnn_lstm_grad, ops::NotImpleKernel); diff --git a/paddle/fluid/operators/cudnn_lstm_op.cu.cc b/paddle/fluid/operators/cudnn_lstm_op.cu.cc index 811975a9f..cad62de75 100644 --- a/paddle/fluid/operators/cudnn_lstm_op.cu.cc +++ b/paddle/fluid/operators/cudnn_lstm_op.cu.cc @@ -12,7 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. */ -#include "paddle/fluid/operators/cudnn_lstm_op.h" +#include "paddle/fluid/framework/op_registry.h" +#include "paddle/fluid/operators/math/math_function.h" #include "paddle/fluid/platform/cudnn_helper.h" namespace paddle { @@ -246,7 +247,7 @@ struct CudnnRNNCache { } }; -template +template class CudnnLSTMGPUKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { @@ -343,7 +344,7 @@ class CudnnLSTMGPUKernel : public framework::OpKernel { } }; -template +template class CudnnLSTMGPUGradKernel : public framework::OpKernel { public: void Compute(const framework::ExecutionContext &ctx) const override { @@ -380,7 +381,7 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { auto init_c_dims = init_c->dims(); in_grad->mutable_data(ctx.GetPlace()); weight_grad->mutable_data(ctx.GetPlace()); - math::SetConstant zero; + math::SetConstant zero; zero(dev_ctx, in_grad, static_cast(0.0)); zero(dev_ctx, weight_grad, static_cast(0.0)); @@ -486,9 +487,5 @@ class CudnnLSTMGPUGradKernel : public framework::OpKernel { } // namespace paddle namespace ops = paddle::operators; -REGISTER_OP_CUDA_KERNEL( - cudnn_lstm, - ops::CudnnLSTMGPUKernel); -REGISTER_OP_CUDA_KERNEL( - cudnn_lstm_grad, - ops::CudnnLSTMGPUGradKernel); +REGISTER_OP_CUDA_KERNEL(cudnn_lstm, ops::CudnnLSTMGPUKernel); +REGISTER_OP_CUDA_KERNEL(cudnn_lstm_grad, ops::CudnnLSTMGPUGradKernel); diff --git a/paddle/fluid/operators/cudnn_lstm_op.h b/paddle/fluid/operators/cudnn_lstm_op.h deleted file mode 100644 index fc329cc23..000000000 --- a/paddle/fluid/operators/cudnn_lstm_op.h +++ /dev/null @@ -1,45 +0,0 @@ -/* Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - -http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. */ - -#pragma once -#include -#include "paddle/fluid/framework/op_registry.h" -#include "paddle/fluid/operators/math/blas.h" -#include "paddle/fluid/operators/math/detail/activation_functions.h" -#include "paddle/fluid/operators/math/lstm_compute.h" -#include "paddle/fluid/operators/math/sequence2batch.h" - -namespace paddle { -namespace operators { - -using LoDTensor = framework::LoDTensor; -using Tensor = framework::Tensor; - -template -class CudnnLSTMKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override { - PADDLE_THROW( - "CPU is not support for this kernel now. Will be add in the future"); - } -}; - -template -class CudnnLSTMGradKernel : public framework::OpKernel { - public: - void Compute(const framework::ExecutionContext& ctx) const override {} -}; - -} // namespace operators -} // namespace paddle -- GitLab