From f63ff90b03b444ff7562bf72fca6877ad7b068a2 Mon Sep 17 00:00:00 2001
From: dzhwinter
Date: Thu, 3 May 2018 11:13:30 +0800
Subject: [PATCH] Fix/fp64 (#10346)

* "fix double type error"
* "fix ci"
* "softmax fp64"
* "fix momentum"
* "fix ci"
---
 paddle/fluid/operators/momentum_op.cc   |  8 ++++++++
 paddle/fluid/operators/scale_op.cc      | 10 ++++------
 paddle/fluid/operators/softmax_op.cc    |  6 ++++--
 paddle/fluid/operators/softmax_op.cu.cc |  6 ++++--
 paddle/fluid/operators/top_k_op.cc      |  3 ++-
 paddle/fluid/operators/top_k_op.cu      |  3 ++-
 6 files changed, 24 insertions(+), 12 deletions(-)

diff --git a/paddle/fluid/operators/momentum_op.cc b/paddle/fluid/operators/momentum_op.cc
index 6c70970e15..f13ec53905 100644
--- a/paddle/fluid/operators/momentum_op.cc
+++ b/paddle/fluid/operators/momentum_op.cc
@@ -17,6 +17,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 class MomentumOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -50,6 +52,12 @@ class MomentumOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("VelocityOut", param_dim);
   }
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
diff --git a/paddle/fluid/operators/scale_op.cc b/paddle/fluid/operators/scale_op.cc
index 1e938638c9..7dcf33c989 100644
--- a/paddle/fluid/operators/scale_op.cc
+++ b/paddle/fluid/operators/scale_op.cc
@@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename AttrType>
 class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -47,9 +46,9 @@ Scale operator
 
 $$Out = scale*X$$
 )DOC");
-    AddAttr<AttrType>("scale",
-                      "(float, default 1.0)"
-                      "The scaling factor of the scale operator.")
+    AddAttr<float>("scale",
+                   "(float, default 1.0)"
+                   "The scaling factor of the scale operator.")
         .SetDefault(1.0);
   }
 };
@@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
 
 namespace ops = paddle::operators;
 
-REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
-                  ops::ScaleGradMaker);
+REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker);
 REGISTER_OP_CPU_KERNEL(
     scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
diff --git a/paddle/fluid/operators/softmax_op.cc b/paddle/fluid/operators/softmax_op.cc
index 2741ba95bc..aa7b192e32 100644
--- a/paddle/fluid/operators/softmax_op.cc
+++ b/paddle/fluid/operators/softmax_op.cc
@@ -164,7 +164,9 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>);
+    softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     softmax_grad,
-    ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
diff --git a/paddle/fluid/operators/softmax_op.cu.cc b/paddle/fluid/operators/softmax_op.cu.cc
index 0c1f7cef7a..5fb4f011d9 100644
--- a/paddle/fluid/operators/softmax_op.cu.cc
+++ b/paddle/fluid/operators/softmax_op.cu.cc
@@ -19,6 +19,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     softmax, ops::SoftmaxKernel<plat::CUDADeviceContext, float>,
+    ops::SoftmaxKernel<plat::CUDADeviceContext, double>,
     ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(softmax_grad,
-                        ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
+    softmax_grad, ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>,
+    ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>);
diff --git a/paddle/fluid/operators/top_k_op.cc b/paddle/fluid/operators/top_k_op.cc
index 2e4e8caed5..942a5de3f9 100644
--- a/paddle/fluid/operators/top_k_op.cc
+++ b/paddle/fluid/operators/top_k_op.cc
@@ -75,4 +75,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
                   paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(top_k,
-                       ops::TopkKernel<paddle::platform::CPUPlace, float>);
+                       ops::TopkKernel<paddle::platform::CPUPlace, float>,
+                       ops::TopkKernel<paddle::platform::CPUPlace, double>);
diff --git a/paddle/fluid/operators/top_k_op.cu b/paddle/fluid/operators/top_k_op.cu
index d7f4d383ce..2ea9fd1d29 100644
--- a/paddle/fluid/operators/top_k_op.cu
+++ b/paddle/fluid/operators/top_k_op.cu
@@ -318,4 +318,5 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>);
+REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>,
+                        paddle::operators::TopkOpCUDAKernel<double>);
-- 
GitLab
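
Every hunk above follows the same pattern: each op keeps its float kernel registration and gains a second instantiation for double, and MomentumOp additionally overrides GetExpectedKernelType so the kernel is selected from the dtype of the "Param" input at runtime. The standalone C++ sketch below is not PaddlePaddle code; the Registry, Kernel alias, and RegisterScaleKernel helper are invented for illustration. It only shows how per-dtype kernel registration and runtime dispatch fit together, and why a missing double instantiation surfaces as a "double type error" when dispatch happens.

#include <cstddef>
#include <cstdio>
#include <functional>
#include <map>
#include <string>
#include <typeindex>
#include <utility>
#include <vector>

// A "kernel" is a callable that operates on a raw buffer of n elements.
using Kernel = std::function<void(void*, std::size_t)>;

// Registry keyed by (op name, element type), loosely mirroring the role of
// REGISTER_OP_CPU_KERNEL / REGISTER_OP_CUDA_KERNEL in the patch.
static std::map<std::pair<std::string, std::type_index>, Kernel>& Registry() {
  static std::map<std::pair<std::string, std::type_index>, Kernel> r;
  return r;
}

// Register one instantiation of a scale kernel for element type T.
template <typename T>
void RegisterScaleKernel(double scale) {
  auto key = std::make_pair(std::string("scale"), std::type_index(typeid(T)));
  Registry()[key] = [scale](void* data, std::size_t n) {
    T* p = static_cast<T*>(data);
    for (std::size_t i = 0; i < n; ++i) p[i] = static_cast<T>(scale) * p[i];
  };
}

int main() {
  // Before the patch only the float instantiation existed, so a double tensor
  // found no matching kernel; registering the double instantiation fixes it.
  RegisterScaleKernel<float>(2.0);
  RegisterScaleKernel<double>(2.0);

  std::vector<double> x = {1.0, 2.0, 3.0};
  auto key = std::make_pair(std::string("scale"),
                            std::type_index(typeid(double)));
  auto it = Registry().find(key);
  if (it != Registry().end()) {
    it->second(x.data(), x.size());  // dispatch on the runtime element type
  }
  for (double v : x) std::printf("%g ", v);  // prints: 2 4 6
  std::printf("\n");
  return 0;
}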