diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc index 6c70970e15f0d63ebe2134c6bc8163339ba30e75..f13ec53905aa3d5b55b865c3514f36211c06a549 100644 --- a/paddle/fluid/operators/momentum_op.cc +++ b/paddle/fluid/operators/momentum_op.cc @@ -17,6 +17,8 @@ limitations under the License. */ namespace paddle { namespace operators { +using Tensor = framework::Tensor; + class MomentumOp : public framework::OperatorWithKernel { public: using framework::OperatorWithKernel::OperatorWithKernel; @@ -50,6 +52,12 @@ class MomentumOp : public framework::OperatorWithKernel { ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("VelocityOut", param_dim); } + framework::OpKernelType GetExpectedKernelType( + const framework::ExecutionContext &ctx) const override { + auto input_data_type = + framework::ToDataType(ctx.Input("Param")->type()); + return framework::OpKernelType(input_data_type, ctx.GetPlace()); + } }; class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc index 1e938638c9182972a2ae2436166ff0aa49efd4be..7dcf33c989c3bcd905da8017ee36ec8ce8032911 100644 --- a/paddle/fluid/operators/scale_op.cc +++ b/paddle/fluid/operators/scale_op.cc @@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel { } }; -template class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { public: ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker) @@ -47,9 +46,9 @@ Scale operator $$Out = scale*X$$ )DOC"); - AddAttr("scale", - "(float, default 1.0)" - "The scaling factor of the scale operator.") + AddAttr("scale", + "(float, default 1.0)" + "The scaling factor of the scale operator.") .SetDefault(1.0); } }; @@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { namespace ops = paddle::operators; -REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, - ops::ScaleGradMaker); +REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker); REGISTER_OP_CPU_KERNEL( scale, ops::ScaleKernel, ops::ScaleKernel, diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc index 2741ba95bcfc1db3d74e0fb8c3f6fddf7d5a2caa..aa7b192e327704c02a26c86cc208ebe8a5cd7ba5 100644 --- a/paddle/fluid/operators/softmax_op.cc +++ b/paddle/fluid/operators/softmax_op.cc @@ -164,7 +164,9 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, paddle::framework::DefaultGradOpDescMaker); REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad); REGISTER_OP_CPU_KERNEL( - softmax, ops::SoftmaxKernel); + softmax, ops::SoftmaxKernel, + ops::SoftmaxKernel); REGISTER_OP_CPU_KERNEL( softmax_grad, - ops::SoftmaxGradKernel); + ops::SoftmaxGradKernel, + ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc index 0c1f7cef7ab7b66358d80f6f0670e0d07536128c..5fb4f011d9b47cebc4a23bcce47eada825263343 100644 --- a/paddle/fluid/operators/softmax_op.cu.cc +++ b/paddle/fluid/operators/softmax_op.cu.cc @@ -19,6 +19,8 @@ namespace ops = paddle::operators; namespace plat = paddle::platform; REGISTER_OP_CUDA_KERNEL( softmax, ops::SoftmaxKernel, + ops::SoftmaxKernel, ops::SoftmaxKernel); -REGISTER_OP_CUDA_KERNEL(softmax_grad, - ops::SoftmaxGradKernel); +REGISTER_OP_CUDA_KERNEL( + softmax_grad, ops::SoftmaxGradKernel, + ops::SoftmaxGradKernel); diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc index 2e4e8caed5327f4ca9038c376de2ec831354917e..942a5de3f90f20eabe691924a570b61509eccf76 100644 --- a/paddle/fluid/operators/top_k_op.cc +++ b/paddle/fluid/operators/top_k_op.cc @@ -75,4 +75,5 @@ namespace ops = paddle::operators; REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker, paddle::framework::EmptyGradOpMaker); REGISTER_OP_CPU_KERNEL(top_k, - ops::TopkKernel); + ops::TopkKernel, + ops::TopkKernel); diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu index d7f4d383ce0d9e1ff42fc12c96aaf0ceb532e5db..2ea9fd1d29936357bbb5f6819c813e4bd1ebe44c 100644 --- a/paddle/fluid/operators/top_k_op.cu +++ b/paddle/fluid/operators/top_k_op.cu @@ -318,4 +318,5 @@ class TopkOpCUDAKernel : public framework::OpKernel { } // namespace operators } // namespace paddle -REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel); +REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel, + paddle::operators::TopkOpCUDAKernel);