Unverified · Commit cb81befa authored by huangjiyi, committed by GitHub

register fluid kernels to phi [part 6.5] (#52882)

* update

* fix bug

* update

* fix bug
Parent bc91012f
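A note on the recurring pattern in the hunks below: each migrated operator swaps its kernel template parameters from `<DeviceContext, T>` to `<T, DeviceContext>`, and its `REGISTER_OP_CPU_KERNEL` / `REGISTER_OP_CUDA_KERNEL` call becomes a `PD_REGISTER_STRUCT_KERNEL`, which names the backend and layout once and then lists the dtypes. A minimal before/after sketch — the op `foo` and `FooKernel` are hypothetical, purely to show the shape:

```cpp
// Before (fluid): the device context is baked into each instantiation.
// template <typename DeviceContext, typename T>
// class FooKernel : public framework::OpKernel<T> { ... };
REGISTER_OP_CPU_KERNEL(foo,
                       ops::FooKernel<phi::CPUContext, float>,
                       ops::FooKernel<phi::CPUContext, double>);

// After (phi struct kernel): T comes first; the macro takes the backend,
// the layout, the kernel class, and the dtype list, and ends in a
// (usually empty) brace body.
// template <typename T, typename DeviceContext>
// class FooKernel : public framework::OpKernel<T> { ... };
PD_REGISTER_STRUCT_KERNEL(foo, CPU, ALL_LAYOUT, ops::FooKernel, float, double) {}
```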
@@ -2483,21 +2483,32 @@ Scope* OperatorWithKernel::PrepareData(
       }
 
       std::unique_ptr<phi::KernelKey> new_expected_kernel_key = nullptr;
-      if (run_phi_kernel_ && in_def != nullptr &&
-          in_def->backend != phi::Backend::ALL_BACKEND) {
-        auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
-        if ((in_def->backend != tensor_backend &&
-             !(in_def->backend == phi::Backend::GPUDNN &&
-               tensor_backend == phi::Backend::GPU) &&
-             !(in_def->backend == phi::Backend::KPS &&
-               tensor_backend == phi::Backend::XPU) &&
-             !(in_def->backend == phi::Backend::ONEDNN &&
-               tensor_backend == phi::Backend::CPU)) ||
-            tensor_in->place().GetType() == AllocationType::GPUPINNED) {
-          new_expected_kernel_key =
-              std::make_unique<phi::KernelKey>(in_def->backend,
-                                               expected_kernel_key.layout(),
-                                               expected_kernel_key.dtype());
+      if (run_phi_kernel_) {
+        if (phi_kernel_->GetKernelRegisteredType() ==
+            phi::KernelRegisteredType::STRUCTURE) {
+          if (!backends_are_same_class(kernel_type_for_var.backend(),
+                                       expected_kernel_key.backend())) {
+            new_expected_kernel_key =
+                std::make_unique<phi::KernelKey>(expected_kernel_key.backend(),
+                                                 expected_kernel_key.layout(),
+                                                 expected_kernel_key.dtype());
+          }
+        } else if (in_def != nullptr &&
+                   in_def->backend != phi::Backend::ALL_BACKEND) {
+          auto tensor_backend = phi::TransToPhiBackend(tensor_in->place());
+          if ((in_def->backend != tensor_backend &&
+               !(in_def->backend == phi::Backend::GPUDNN &&
+                 tensor_backend == phi::Backend::GPU) &&
+               !(in_def->backend == phi::Backend::KPS &&
+                 tensor_backend == phi::Backend::XPU) &&
+               !(in_def->backend == phi::Backend::ONEDNN &&
+                 tensor_backend == phi::Backend::CPU)) ||
+              tensor_in->place().GetType() == AllocationType::GPUPINNED) {
+            new_expected_kernel_key =
+                std::make_unique<phi::KernelKey>(in_def->backend,
+                                                 expected_kernel_key.layout(),
+                                                 expected_kernel_key.dtype());
+          }
         }
       }
......
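The PrepareData hunk above adds a dedicated branch for structure-registered kernels: instead of comparing each input's `in_def` backend against the tensor's actual backend, it asks whether the variable's kernel-type backend and the expected kernel key's backend belong to the same backend class, and only forces a new kernel key when they do not. `backends_are_same_class` itself is not shown in this diff; judging purely from the backend pairs the legacy branch tolerates (GPUDNN with GPU, KPS with XPU, ONEDNN with CPU), a minimal sketch of such a check might look like the following — an inference, not Paddle's actual helper:

```cpp
// Sketch only: the equivalence pairs are inferred from the legacy branch in
// this hunk; the real backends_are_same_class in Paddle may differ.
inline bool same_backend_class(phi::Backend a, phi::Backend b) {
  auto canonical = [](phi::Backend be) {
    switch (be) {
      case phi::Backend::GPUDNN: return phi::Backend::GPU;  // GPUDNN ~ GPU
      case phi::Backend::KPS:    return phi::Backend::XPU;  // KPS ~ XPU
      case phi::Backend::ONEDNN: return phi::Backend::CPU;  // ONEDNN ~ CPU
      default:                   return be;
    }
  };
  return canonical(a) == canonical(b);
}
```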
@@ -73,6 +73,8 @@ class MpAllReduceSumOpGradMaker : public framework::SingleGradOpMaker<T> {
 DECLARE_INPLACE_OP_INFERER(MpAllReduceSumInplaceInferer, {"X", "Out"});
 
+DEFINE_C_ALLREDUCE_CPU_KERNEL(MpAllReduceSum, kRedSum);
+
 }  // namespace operators
 }  // namespace paddle
@@ -86,9 +88,12 @@ REGISTER_OPERATOR(mp_allreduce_sum,
                   ops::MpAllReduceSumOpMaker,
                   ops::MpAllReduceSumInplaceInferer);
 
-REGISTER_OP_CPU_KERNEL(mp_allreduce_sum,
-                       ops::CAllReduceOpCPUKernel<ops::kRedSum, float>,
-                       ops::CAllReduceOpCPUKernel<ops::kRedSum, double>,
-                       ops::CAllReduceOpCPUKernel<ops::kRedSum, int>,
-                       ops::CAllReduceOpCPUKernel<ops::kRedSum, int64_t>,
-                       ops::CAllReduceOpCPUKernel<ops::kRedSum, plat::float16>)
+PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::MpAllReduceSumCPUKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
+                          plat::float16) {}
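`ops::MpAllReduceSumCPUKernel` is not written out anywhere by hand; it comes from the `DEFINE_C_ALLREDUCE_CPU_KERNEL(MpAllReduceSum, kRedSum)` line added in the previous hunk. The macro body lives in c_allreduce_op.h and is not part of this diff; assuming it merely adapts the existing `CAllReduceOpCPUKernel` to the `<T, DeviceContext>` shape that struct registration expects, a plausible expansion would be:

```cpp
// Assumed expansion, for orientation only; the actual macro in
// c_allreduce_op.h may differ in detail.
#define DEFINE_C_ALLREDUCE_CPU_KERNEL(op_name, red_type)                  \
  template <typename T, typename DeviceContext>                           \
  class op_name##CPUKernel : public CAllReduceOpCPUKernel<red_type, T> {};
```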
@@ -15,16 +15,24 @@
 #include "paddle/fluid/framework/op_registry.h"
 #include "paddle/fluid/operators/collective/c_allreduce_op.h"
 
+namespace paddle {
+namespace operators {
+DEFINE_C_ALLREDUCE_CUDA_KERNEL(MpAllReduceSum, kRedSum)
+}  // namespace operators
+}  // namespace paddle
+
 namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 
-REGISTER_OP_CUDA_KERNEL(
-    mp_allreduce_sum,
-    ops::CAllReduceOpCUDAKernel<ops::kRedSum, float>,
+PD_REGISTER_STRUCT_KERNEL(mp_allreduce_sum,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::MpAllReduceSumCUDAKernel,
+                          float,
+                          double,
+                          int,
+                          int64_t,
 #if NCCL_VERSION_CODE >= 21000
-    ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::bfloat16>,
+                          plat::bfloat16,
 #endif
-    ops::CAllReduceOpCUDAKernel<ops::kRedSum, double>,
-    ops::CAllReduceOpCUDAKernel<ops::kRedSum, int>,
-    ops::CAllReduceOpCUDAKernel<ops::kRedSum, int64_t>,
-    ops::CAllReduceOpCUDAKernel<ops::kRedSum, plat::float16>)
+                          plat::float16) {
+}
@@ -114,6 +114,5 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 
-REGISTER_OP_CPU_KERNEL(iou_similarity,
-                       ops::IOUSimilarityKernel<phi::CPUContext, float>,
-                       ops::IOUSimilarityKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    iou_similarity, CPU, ALL_LAYOUT, ops::IOUSimilarityKernel, float, double) {}
@@ -15,6 +15,5 @@ limitations under the License. */
 #include "paddle/fluid/operators/detection/iou_similarity_op.h"
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(iou_similarity,
-                        ops::IOUSimilarityKernel<phi::GPUContext, float>,
-                        ops::IOUSimilarityKernel<phi::GPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(
+    iou_similarity, GPU, ALL_LAYOUT, ops::IOUSimilarityKernel, float, double) {}
@@ -105,7 +105,7 @@ struct IOUSimilarityFunctor {
 namespace paddle {
 namespace operators {
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class IOUSimilarityKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -160,7 +160,7 @@ void GetMaxScoreIndexWithLocalityAware(
   }
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class LocalityAwareNMSKernel : public framework::OpKernel<T> {
  public:
   void LocalityAwareNMSFast(phi::DenseTensor* bbox,
@@ -520,6 +520,9 @@ REGISTER_OPERATOR(
     ops::LocalityAwareNMSOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(locality_aware_nms,
-                       ops::LocalityAwareNMSKernel<float>,
-                       ops::LocalityAwareNMSKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(locality_aware_nms,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::LocalityAwareNMSKernel,
+                          float,
+                          double) {}
@@ -49,7 +49,7 @@ inline MiningType GetMiningType(std::string str) {
   }
 }
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class MineHardExamplesKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -403,6 +403,9 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
 
-REGISTER_OP_CPU_KERNEL(mine_hard_examples,
-                       ops::MineHardExamplesKernel<phi::CPUContext, float>,
-                       ops::MineHardExamplesKernel<phi::CPUContext, double>);
+PD_REGISTER_STRUCT_KERNEL(mine_hard_examples,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::MineHardExamplesKernel,
+                          float,
+                          double) {}
@@ -143,7 +143,7 @@ void SliceOneClass(const platform::DeviceContext& ctx,
   }
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class MultiClassNMSKernel : public framework::OpKernel<T> {
  public:
   void NMSFast(const phi::DenseTensor& bbox,
@@ -629,6 +629,9 @@ class MultiClassNMS3OpMaker : public MultiClassNMS2OpMaker {
   }
 };
 
+template <typename T, typename DeviceContext>
+class MultiClassNMS2Kernel : public MultiClassNMSKernel<T, DeviceContext> {};
+
 }  // namespace operators
 }  // namespace paddle
@@ -643,18 +646,21 @@ REGISTER_OPERATOR(
     ops::MultiClassNMSOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(multiclass_nms,
-                       ops::MultiClassNMSKernel<float>,
-                       ops::MultiClassNMSKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(
+    multiclass_nms, CPU, ALL_LAYOUT, ops::MultiClassNMSKernel, float, double) {}
 
 REGISTER_OPERATOR(
     multiclass_nms2,
     ops::MultiClassNMS2Op,
     ops::MultiClassNMS2OpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(multiclass_nms2,
-                       ops::MultiClassNMSKernel<float>,
-                       ops::MultiClassNMSKernel<double>);
+PD_REGISTER_STRUCT_KERNEL(multiclass_nms2,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::MultiClassNMS2Kernel,
+                          float,
+                          double) {}
 REGISTER_OPERATOR(
     multiclass_nms3,
......
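The empty `MultiClassNMS2Kernel` introduced above is worth a second look: `multiclass_nms` and `multiclass_nms2` share one implementation, but each `PD_REGISTER_STRUCT_KERNEL` call apparently needs its own kernel class name (plausibly because the macro derives registration symbols from the class). The empty-derived-class idiom resolves that. Generically, with hypothetical ops `bar` / `bar2`:

```cpp
// One implementation, two op registrations: give the second op a distinct
// class via an empty derived template, then register each class once.
template <typename T, typename DeviceContext>
class Bar2Kernel : public BarKernel<T, DeviceContext> {};

PD_REGISTER_STRUCT_KERNEL(bar, CPU, ALL_LAYOUT, ops::BarKernel, float) {}
PD_REGISTER_STRUCT_KERNEL(bar2, CPU, ALL_LAYOUT, ops::Bar2Kernel, float) {}
```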
@@ -270,7 +270,7 @@ __global__ void broadcast_batch_head_number(const T *src,
   }
 }
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext &context) const override {
@@ -423,12 +423,15 @@ class MultiHeadMatMulV2Kernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
 #if defined(PADDLE_WITH_CUDA) && CUDA_VERSION >= 10000
-REGISTER_OP_CUDA_KERNEL(
-    multihead_matmul,
-    ops::MultiHeadMatMulV2Kernel<phi::GPUContext, paddle::platform::float16>,
-    ops::MultiHeadMatMulV2Kernel<phi::GPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(multihead_matmul,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::MultiHeadMatMulV2Kernel,
+                          float,
+                          plat::float16) {}
 #else
-REGISTER_OP_CUDA_KERNEL(multihead_matmul,
-                        ops::MultiHeadMatMulV2Kernel<phi::GPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(
+    multihead_matmul, GPU, ALL_LAYOUT, ops::MultiHeadMatMulV2Kernel, float) {}
 #endif
@@ -240,7 +240,7 @@ void MatchMatrixTensorOpMaker::Make() {
 )DOC");
 }
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -321,7 +321,7 @@ class CPUMatchMatrixTensorOPKernel : public framework::OpKernel<T> {
   }
 };
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class CPUMatchMatrixTensorOPGradKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -458,10 +458,13 @@ REGISTER_OPERATOR(
     ops::MatchMatrixTensorGradOpMaker<paddle::imperative::OpBase>);
 REGISTER_OPERATOR(match_matrix_tensor_grad, ops::MatchMatrixTensorOpGrad);
 
-REGISTER_OP_CPU_KERNEL(
-    match_matrix_tensor,
-    ops::CPUMatchMatrixTensorOPKernel<phi::CPUContext, float>);
-
-REGISTER_OP_CPU_KERNEL(
-    match_matrix_tensor_grad,
-    ops::CPUMatchMatrixTensorOPGradKernel<phi::CPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(match_matrix_tensor,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::CPUMatchMatrixTensorOPKernel,
+                          float) {}
+PD_REGISTER_STRUCT_KERNEL(match_matrix_tensor_grad,
+                          CPU,
+                          ALL_LAYOUT,
+                          ops::CPUMatchMatrixTensorOPGradKernel,
+                          float) {}
@@ -107,7 +107,6 @@ REGISTER_OPERATOR(
     ops::MeanIoUOpMaker,
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>);
-REGISTER_OP_CPU_KERNEL(mean_iou,
-                       ops::MeanIoUKernel<int>,
-                       ops::MeanIoUKernel<int32_t>,
-                       ops::MeanIoUKernel<int64_t>);
+
+PD_REGISTER_STRUCT_KERNEL(
+    mean_iou, CPU, ALL_LAYOUT, ops::MeanIoUKernel, int, int64_t) {}
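A small cleanup hides in the mean_iou hunk above: the old list registered both `ops::MeanIoUKernel<int>` and `ops::MeanIoUKernel<int32_t>`, and on the targets Paddle supports `int32_t` is a typedef for `int`, so the same instantiation was named twice; the new registration keeps just `int` and `int64_t`. A quick way to convince yourself of the aliasing (holds on mainstream LP64/LLP64 targets):

```cpp
#include <cstdint>
#include <type_traits>
// On mainstream targets int32_t aliases int, so the old registration
// effectively listed one kernel instantiation twice.
static_assert(std::is_same<int, std::int32_t>::value,
              "int32_t aliases int here");
```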
@@ -88,7 +88,7 @@ __global__ void ComputeIoUCUDAKernel(
   }
 }
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
@@ -166,7 +166,5 @@ class MeanIoUCUDAOpKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
-REGISTER_OP_CUDA_KERNEL(mean_iou,
-                        ops::MeanIoUCUDAOpKernel<int>,
-                        ops::MeanIoUCUDAOpKernel<int64_t>,
-                        ops::MeanIoUCUDAOpKernel<int32_t>);
+PD_REGISTER_STRUCT_KERNEL(
+    mean_iou, GPU, ALL_LAYOUT, ops::MeanIoUCUDAOpKernel, int, int64_t) {}
@@ -27,7 +27,7 @@ template <typename T,
           typename IndexType = Eigen::DenseIndex>
 using EigenTensor = framework::EigenTensor<T, D, MajorType, IndexType>;
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class MeanIoUKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -155,6 +155,8 @@ REGISTER_OPERATOR(minus,
                   ops::MinusOpMaker,
                   ops::MinusGradDescMaker,
                   ops::MinusGradMaker);
 
-REGISTER_OP_CPU_KERNEL(minus, ops::MinusKernel<phi::CPUContext, float>);
-REGISTER_OP_CUDA_KERNEL(minus, ops::MinusKernel<phi::GPUContext, float>);
+PD_REGISTER_STRUCT_KERNEL(minus, CPU, ALL_LAYOUT, ops::MinusKernel, float) {}
+#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
+PD_REGISTER_STRUCT_KERNEL(minus, GPU, ALL_LAYOUT, ops::MinusKernel, float) {}
+#endif
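In the minus hunk, the CPU and GPU registrations now sit together in one .cc file, with the GPU one fenced by a preprocessor guard so CPU-only builds still compile; presumably this explicit guard is what lets the separate `REGISTER_OP_CUDA_KERNEL` call be folded in. The pattern in isolation, with a hypothetical op `baz`:

```cpp
PD_REGISTER_STRUCT_KERNEL(baz, CPU, ALL_LAYOUT, ops::BazKernel, float) {}
// Only register the GPU kernel when building with a CUDA or ROCm toolchain.
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_STRUCT_KERNEL(baz, GPU, ALL_LAYOUT, ops::BazKernel, float) {}
#endif
```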
@@ -20,7 +20,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class MinusKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& context) const override {
......
@@ -233,6 +233,6 @@ REGISTER_OPERATOR(
     paddle::framework::EmptyGradOpMaker<paddle::framework::OpDesc>,
     paddle::framework::EmptyGradOpMaker<paddle::imperative::OpBase>,
     ops::LarsMomentumOpVarTypeInference);
-REGISTER_OP_CPU_KERNEL(lars_momentum,
-                       ops::LarsMomentumOpKernel<float>,
-                       ops::LarsMomentumOpKernel<double>);
+
+PD_REGISTER_STRUCT_KERNEL(
+    lars_momentum, CPU, ALL_LAYOUT, ops::LarsMomentumOpKernel, float, double) {}
@@ -474,7 +474,7 @@ inline void SeparatedLarsMomentumOpCUDAKernel(const phi::GPUContext& cuda_ctx,
                                               is_amp);
 }
 
-template <typename DeviceContext, typename T>
+template <typename T, typename DeviceContext>
 class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
   using MT = MultiPrecisionType<T>;
 
@@ -679,8 +679,11 @@ class LarsMomentumOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace paddle
 
 namespace ops = paddle::operators;
+namespace plat = paddle::platform;
-REGISTER_OP_CUDA_KERNEL(
-    lars_momentum,
-    ops::LarsMomentumOpCUDAKernel<phi::GPUContext, float>,
-    ops::LarsMomentumOpCUDAKernel<phi::GPUContext, double>,
-    ops::LarsMomentumOpCUDAKernel<phi::GPUContext, paddle::platform::float16>);
+PD_REGISTER_STRUCT_KERNEL(lars_momentum,
+                          GPU,
+                          ALL_LAYOUT,
+                          ops::LarsMomentumOpCUDAKernel,
+                          float,
+                          double,
+                          plat::float16) {}
@@ -19,7 +19,7 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
-template <typename T>
+template <typename T, typename DeviceContext>
 class LarsMomentumOpKernel : public framework::OpKernel<T> {
  public:
   void Compute(const framework::ExecutionContext& ctx) const override {
......
@@ -10,7 +10,6 @@ register_unity_group(
   lars_momentum_op.cc
   proximal_adagrad_op.cc
   adam_op.cc
-  dgc_momentum_op.cc
   proximal_gd_op.cc
   decayed_adagrad_op.cc
   adadelta_op.cc
......