Unverified commit f63ff90b. Author: dzhwinter. Committer: GitHub.

Fix/fp64 (#10346)

* "fix double type error"

* "fix ci"

* "softmax fp64"

* "fix momentum"

* "fix ci"
Parent: 1ae086ed
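This commit extends the momentum, scale, softmax, and top_k operators with float64 (double) support. The recurring change is mechanical: each kernel-registration macro gains an additional instantiation of the same kernel template with T = double next to the existing float one. A minimal sketch of the pattern, using a hypothetical operator my_op and kernel MyOpKernel (names invented for illustration, not from this commit):

    namespace ops = paddle::operators;

    // Before: only a float kernel is registered, so running the op on
    // float64 tensors fails at kernel lookup time.
    // REGISTER_OP_CPU_KERNEL(
    //     my_op, ops::MyOpKernel<paddle::platform::CPUDeviceContext, float>);

    // After: the same kernel template is also instantiated for double.
    REGISTER_OP_CPU_KERNEL(
        my_op, ops::MyOpKernel<paddle::platform::CPUDeviceContext, float>,
        ops::MyOpKernel<paddle::platform::CPUDeviceContext, double>);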
@@ -17,6 +17,8 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
 
+using Tensor = framework::Tensor;
+
 class MomentumOp : public framework::OperatorWithKernel {
  public:
   using framework::OperatorWithKernel::OperatorWithKernel;
@@ -50,6 +52,12 @@ class MomentumOp : public framework::OperatorWithKernel {
     ctx->SetOutputDim("ParamOut", param_dim);
     ctx->SetOutputDim("VelocityOut", param_dim);
   }
+
+  framework::OpKernelType GetExpectedKernelType(
+      const framework::ExecutionContext &ctx) const override {
+    auto input_data_type =
+        framework::ToDataType(ctx.Input<Tensor>("Param")->type());
+    return framework::OpKernelType(input_data_type, ctx.GetPlace());
+  }
 };
 
 class MomentumOpMaker : public framework::OpProtoAndCheckerMaker {
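The MomentumOp hunk above is the "fix momentum" part of the commit: it overrides GetExpectedKernelType so that the kernel's data type is derived from the Param input alone. The sketch below restates that override with explanatory comments; the stated motivation is my reading, not something spelled out in the diff:

    // Same override as in the hunk above, annotated.
    framework::OpKernelType GetExpectedKernelType(
        const framework::ExecutionContext &ctx) const override {
      // Inspect the dtype of the "Param" tensor (float or double) ...
      auto input_data_type =
          framework::ToDataType(ctx.Input<Tensor>("Param")->type());
      // ... and request a kernel of exactly that type on the current place.
      // Assumption: without the override, the framework's default resolution
      // considers every input, which is brittle once Param, Grad, and
      // Velocity are float64.
      return framework::OpKernelType(input_data_type, ctx.GetPlace());
    }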
@@ -35,7 +35,6 @@ class ScaleOp : public framework::OperatorWithKernel {
   }
 };
 
-template <typename AttrType>
 class ScaleOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ScaleOpMaker(OpProto *proto, OpAttrChecker *op_checker)
@@ -47,7 +46,7 @@ Scale operator
 
 $$Out = scale*X$$
 )DOC");
-    AddAttr<AttrType>("scale",
+    AddAttr<float>("scale",
                 "(float, default 1.0)"
                 "The scaling factor of the scale operator.")
         .SetDefault(1.0);
@@ -73,8 +72,7 @@ class ScaleGradMaker : public framework::SingleGradOpDescMaker {
 
 namespace ops = paddle::operators;
 
-REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker<float>,
-                  ops::ScaleGradMaker);
+REGISTER_OPERATOR(scale, ops::ScaleOp, ops::ScaleOpMaker, ops::ScaleGradMaker);
 REGISTER_OP_CPU_KERNEL(
     scale, ops::ScaleKernel<paddle::platform::CPUDeviceContext, float>,
     ops::ScaleKernel<paddle::platform::CPUDeviceContext, double>,
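In the scale hunks, ScaleOpMaker loses its AttrType template parameter and the "scale" attribute is pinned to float, even though the kernel is registered for double data. That is safe because the attribute is a scalar hyperparameter rather than tensor data: inside the kernel it can be cast to the element type T. A simplified sketch of the idea (the real ScaleKernel uses Eigen; the plain loop and the name ScaleKernelSketch are illustrative):

    template <typename DeviceContext, typename T>
    class ScaleKernelSketch : public framework::OpKernel<T> {
     public:
      void Compute(const framework::ExecutionContext &ctx) const override {
        auto *in = ctx.Input<framework::Tensor>("X");
        auto *out = ctx.Output<framework::Tensor>("Out");
        const T *x = in->data<T>();
        T *y = out->mutable_data<T>(ctx.GetPlace());
        // The attribute is always stored as float and cast to the tensor
        // element type T, so one float attr serves float, double, and any
        // other registered data type.
        auto scale = static_cast<T>(ctx.Attr<float>("scale"));
        for (int64_t i = 0; i < in->numel(); ++i) y[i] = scale * x[i];
      }
    };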
@@ -164,7 +164,9 @@ REGISTER_OPERATOR(softmax, ops::SoftmaxOp, ops::SoftmaxOpMaker,
                   paddle::framework::DefaultGradOpDescMaker<true>);
 REGISTER_OPERATOR(softmax_grad, ops::SoftmaxOpGrad);
 REGISTER_OP_CPU_KERNEL(
-    softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>);
+    softmax, ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SoftmaxKernel<paddle::platform::CPUDeviceContext, double>);
 REGISTER_OP_CPU_KERNEL(
     softmax_grad,
-    ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>);
+    ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, float>,
+    ops::SoftmaxGradKernel<paddle::platform::CPUDeviceContext, double>);
@@ -19,6 +19,8 @@ namespace ops = paddle::operators;
 namespace plat = paddle::platform;
 REGISTER_OP_CUDA_KERNEL(
     softmax, ops::SoftmaxKernel<plat::CUDADeviceContext, float>,
+    ops::SoftmaxKernel<plat::CUDADeviceContext, double>,
     ops::SoftmaxKernel<plat::CUDADeviceContext, plat::float16>);
-REGISTER_OP_CUDA_KERNEL(softmax_grad,
-                        ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>);
+REGISTER_OP_CUDA_KERNEL(
+    softmax_grad, ops::SoftmaxGradKernel<plat::CUDADeviceContext, float>,
+    ops::SoftmaxGradKernel<plat::CUDADeviceContext, double>);
@@ -75,4 +75,5 @@ namespace ops = paddle::operators;
 REGISTER_OPERATOR(top_k, ops::TopkOp, ops::TopkOpMaker,
                   paddle::framework::EmptyGradOpMaker);
 REGISTER_OP_CPU_KERNEL(top_k,
-                       ops::TopkKernel<paddle::platform::CPUPlace, float>);
+                       ops::TopkKernel<paddle::platform::CPUPlace, float>,
+                       ops::TopkKernel<paddle::platform::CPUPlace, double>);
@@ -318,4 +318,5 @@ class TopkOpCUDAKernel : public framework::OpKernel<T> {
 }  // namespace operators
 }  // namespace paddle
 
-REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>);
+REGISTER_OP_CUDA_KERNEL(top_k, paddle::operators::TopkOpCUDAKernel<float>,
+                        paddle::operators::TopkOpCUDAKernel<double>);