未验证 提交 f63ff90b 编写于 作者: D dzhwinter 提交者: GitHub

Fix/fp64 (#10346)

* "fix double type error"

* "fix ci"

* "softmax fp64"

* "fix momentum"

* "fix ci"
上级 1ae086ed
...@@ -17,6 +17,8 @@ limitations under the License. */ ...@@ -17,6 +17,8 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
using Tensor = framework::Tensor;
class MomentumOp : public framework::OperatorWithKernel { class MomentumOp : public framework::OperatorWithKernel {
public: public:
using framework::OperatorWithKernel::OperatorWithKernel; using framework::OperatorWithKernel::OperatorWithKernel;
...@@ -50,6 +52,12 @@ class MomentumOp : public framework::OperatorWithKernel { ...@@ -50,6 +52,12 @@ class MomentumOp : public framework::OperatorWithKernel {
ctx->SetOutputDim("ParamOut", param_dim); ctx->SetOutputDim("ParamOut", param_dim);
ctx->SetOutputDim("VelocityOut", param_dim); ctx->SetOutputDim("VelocityOut", param_dim);
} }
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto input_data_type =
framework::ToDataType(ctx.Input<Tensor>("Param")->type());
return framework::OpKernelType(input_data_type, ctx.GetPlace());
}
}; };
class MomentumOpMaker : public framework::OpProtoAndCheckerMaker { class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
......
...@@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel { ...@@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel {
} }
}; };
template <typename AttrType>
class ScaleOpMaker : public framework::OpProtoAndCheckerMaker { class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker) ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
...@@ -47,9 +46,9 @@ Scale operator ...@@ -47,9 +46,9 @@ Scale operator
$$Out = scale*X$$ $$Out = scale*X$$
)DOC"); )DOC");
AddAttr<AttrType>("scale", AddAttr<float>("scale",
"(float, default 1.0)" "(float, default 1.0)"
"The scaling factor of the scale operator.") "The scaling factor of the scale operator.")
.SetDefault(1.0); .SetDefault(1.0);
} }
}; };
...@@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker { ...@@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>, REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker);
ops::ScaleGradMaker);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>, scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>, ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
......
...@@ -164,7 +164,9 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker, ...@@ -164,7 +164,9 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
paddle::framework::DefaultGradOpDescMaker<true>); paddle::framework::DefaultGradOpDescMaker<true>);
REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad); REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>); softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
softmax_grad, softmax_grad,
ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>); ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
...@@ -19,6 +19,8 @@ namespace ops = paddle::operators; ...@@ -19,6 +19,8 @@ namespace ops = paddle::operators;
namespace plat = paddle::platform; namespace plat = paddle::platform;
REGISTER_OP_CUDA_KERNEL( REGISTER_OP_CUDA_KERNEL(
softmax, ops::SoftmaxKernel<plat::CUDADeviceContext, float>, softmax, ops::SoftmaxKernel<plat::CUDADeviceContext, float>,
ops::SoftmaxKernel<plat::CUDADeviceContext, double>,
ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>); ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
REGISTER_OP_CUDA_KERNEL(softmax_grad, REGISTER_OP_CUDA_KERNEL(
ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>); softmax_grad, ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>,
ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>);
...@@ -75,4 +75,5 @@ namespace ops = paddle::operators; ...@@ -75,4 +75,5 @@ namespace ops = paddle::operators;
REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker, REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
paddle::framework::EmptyGradOpMaker); paddle::framework::EmptyGradOpMaker);
REGISTER_OP_CPU_KERNEL(top_k, REGISTER_OP_CPU_KERNEL(top_k,
ops::TopkKernel<paddle::platform::CPUPlace, float>); ops::TopkKernel<paddle::platform::CPUPlace, float>,
ops::TopkKernel<paddle::platform::CPUPlace, double>);
...@@ -318,4 +318,5 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> { ...@@ -318,4 +318,5 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
} // namespace operators } // namespace operators
} // namespace paddle } // namespace paddle
REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>); REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>,
paddle::operators::TopkOpCUDAKernel<double>);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册