Unverified Commit a5e73f9e authored by Yu Yang, committed by GitHub

Support many data types of several operators (#5731)

* Support many data types of several operators

* SeqConv only support float/double

* Revert adagrad
Parent f2ca07e8
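The diff below applies one pattern throughout: the OpMaker classes drop their `AttrType` template parameter and declare attributes as plain `float`, the kernels read those attributes via `static_cast<T>(ctx.Attr<float>(...))` so the compute type `T` can be `float` or `double`, and the optimizer and sequence_conv kernels are registered for both types. As a minimal, self-contained C++ sketch of that idea (the `Attrs` struct and `LeakyReluKernel` function here are illustrative stand-ins, not Paddle's actual API):

#include <algorithm>
#include <cstdio>
#include <map>
#include <string>

// Hypothetical attribute store: attributes are always declared and stored as
// float, mirroring AddAttr<float>(...) in the OpMakers below.
struct Attrs {
  std::map<std::string, float> values;
  float Get(const std::string& name) const { return values.at(name); }
};

// The kernel is templated on the compute type T; the float attribute is cast
// to T once, so the same code serves float and double tensors
// (y = max(x, alpha * x), as in the LeakyRelu operator below).
template <typename T>
void LeakyReluKernel(const Attrs& attrs, const T* x, T* y, int n) {
  T alpha = static_cast<T>(attrs.Get("alpha"));
  for (int i = 0; i < n; ++i) {
    y[i] = std::max(x[i], alpha * x[i]);
  }
}

int main() {
  Attrs attrs{{{"alpha", 0.02f}}};
  double xd[3] = {-1.0, 0.0, 2.0}, yd[3];
  float xf[3] = {-1.0f, 0.0f, 2.0f}, yf[3];
  LeakyReluKernel<double>(attrs, xd, yd, 3);  // instantiated for double
  LeakyReluKernel<float>(attrs, xf, yf, 3);   // and for float
  std::printf("%f %f %f\n", yd[0], yd[1], yd[2]);  // prints the double results
  return 0;
}

In the real operators the same effect comes from adding a second template instantiation to each REGISTER_OP_CPU_KERNEL / REGISTER_OP_GPU_KERNEL call, as the hunks below show.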
@@ -98,7 +98,6 @@ $y = \max(x, 0)$
 }
 };
-template <typename AttrType>
 class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   LeakyReluOpMaker(framework::OpProto *proto,
@@ -106,8 +105,7 @@ class LeakyReluOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of LeakyRelu operator");
     AddOutput("Y", "Output of LeakyRelu operator");
-    AddAttr<AttrType>("alpha", "The small negative slope")
-        .SetDefault(static_cast<AttrType>(0.02f));
+    AddAttr<float>("alpha", "The small negative slope").SetDefault(0.02f);
     AddComment(R"DOC(
 LeakyRelu Activation Operator.
@@ -117,7 +115,6 @@ $y = \max(x, \alpha * x)$
 }
 };
-template <typename AttrType>
 class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SoftShrinkOpMaker(framework::OpProto *proto,
@@ -125,8 +122,7 @@ class SoftShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Softshrink operator");
     AddOutput("Y", "Output of Softshrink operator");
-    AddAttr<AttrType>("lambda", "non-negative offset")
-        .SetDefault(static_cast<AttrType>(0.5f));
+    AddAttr<float>("lambda", "non-negative offset").SetDefault(0.5f);
     AddComment(R"DOC(
 Softshrink Activation Operator.
@@ -173,7 +169,6 @@ $$y = x - \frac{e^{x} - e^{-x}}{e^{x} + e^{-x}}$$
 }
 };
-template <typename AttrType>
 class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   HardShrinkOpMaker(framework::OpProto *proto,
@@ -181,8 +176,8 @@ class HardShrinkOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of HardShrink operator");
     AddOutput("Y", "Output of HardShrink operator");
-    AddAttr<AttrType>("threshold", "The value of threshold for HardShrink")
-        .SetDefault(static_cast<AttrType>(0.5));
+    AddAttr<float>("threshold", "The value of threshold for HardShrink")
+        .SetDefault(0.5f);
     AddComment(R"DOC(
 HardShrink Activation Operator.
@@ -308,17 +303,16 @@ $$y = \frac{x}{1 + |x|}$$
 }
 };
-template <typename AttrType>
 class BReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   BReluOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of BRelu operator");
     AddOutput("Y", "Output of BRelu operator");
-    AddAttr<AttrType>("t_min", "The min marginal value of BRelu")
-        .SetDefault(static_cast<AttrType>(0));
-    AddAttr<AttrType>("t_max", "The max marginal value of BRelu")
-        .SetDefault(static_cast<AttrType>(24));
+    AddAttr<float>("t_min", "The min marginal value of BRelu")
+        .SetDefault(static_cast<float>(0));
+    AddAttr<float>("t_max", "The max marginal value of BRelu")
+        .SetDefault(static_cast<float>(24));
     AddComment(R"DOC(
 BRelu Activation Operator.
@@ -328,7 +322,6 @@ $y = \max(\min(x, t_{min}), t_{max})$
 }
 };
-template <typename AttrType>
 class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   SoftReluOpMaker(framework::OpProto *proto,
@@ -336,8 +329,8 @@ class SoftReluOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of SoftRelu operator");
     AddOutput("Y", "Output of SoftRelu operator");
-    AddAttr<AttrType>("threshold", "The threshold value of SoftRelu")
-        .SetDefault(static_cast<AttrType>(40));
+    AddAttr<float>("threshold", "The threshold value of SoftRelu")
+        .SetDefault(40.0f);
     AddComment(R"DOC(
 SoftRelu Activation Operator.
@@ -347,15 +340,13 @@ $y = \ln(1 + \exp(\max(\min(x, threshold), threshold))$
 }
 };
-template <typename AttrType>
 class ELUOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ELUOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of ELU operator");
     AddOutput("Y", "Output of ELU operator");
-    AddAttr<AttrType>("alpha", "The alpha value of ELU")
-        .SetDefault(static_cast<AttrType>(1.0f));
+    AddAttr<float>("alpha", "The alpha value of ELU").SetDefault(1.0f);
     AddComment(R"DOC(
 ELU Activation Operator.
@@ -368,15 +359,14 @@ $y = \max(0, x) + \min(0, \alpha * (e^x - 1))$
 }
 };
-template <typename AttrType>
 class Relu6OpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   Relu6OpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Relu6 operator");
     AddOutput("Y", "Output of Relu6 operator");
-    AddAttr<AttrType>("threshold", "The threshold value of Relu6")
-        .SetDefault(static_cast<AttrType>(6));
+    AddAttr<float>("threshold", "The threshold value of Relu6")
+        .SetDefault(6.0f);
     AddComment(R"DOC(
 Relu6 Activation Operator.
@@ -386,15 +376,13 @@ $y = \min(\max(0, x), 6)$
 }
 };
-template <typename AttrType>
 class PowOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   PowOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of Pow operator");
     AddOutput("Y", "Output of Pow operator");
-    AddAttr<AttrType>("factor", "The exponential factor of Pow")
-        .SetDefault(static_cast<AttrType>(1));
+    AddAttr<float>("factor", "The exponential factor of Pow").SetDefault(1.0f);
     AddComment(R"DOC(
 Pow Activation Operator.
@@ -404,17 +392,16 @@ $y = x^{factor}$
 }
 };
-template <typename AttrType>
 class STanhOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   STanhOpMaker(framework::OpProto *proto, framework::OpAttrChecker *op_checker)
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of STanh operator");
     AddOutput("Y", "Output of STanh operator");
-    AddAttr<AttrType>("scale_a", "The scale parameter of a for the input")
-        .SetDefault(static_cast<AttrType>(2 / 3));
-    AddAttr<AttrType>("scale_b", "The scale parameter of b for the input")
-        .SetDefault(static_cast<AttrType>(1.7159));
+    AddAttr<float>("scale_a", "The scale parameter of a for the input")
+        .SetDefault(2.0f / 3.0f);
+    AddAttr<float>("scale_b", "The scale parameter of b for the input")
+        .SetDefault(1.7159f);
     AddComment(R"DOC(
 STanh Activation Operator.
@@ -424,7 +411,6 @@ $$y = b * \frac{e^{a * x} - e^{-a * x}}{e^{a * x} + e^{-a * x}}$$
 }
 };
-template <typename AttrType>
 class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   ThresholdedReluOpMaker(framework::OpProto *proto,
@@ -432,8 +418,8 @@ class ThresholdedReluOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of ThresholdedRelu operator");
     AddOutput("Y", "Output of ThresholdedRelu operator");
-    AddAttr<AttrType>("threshold", "The threshold location of activation")
-        .SetDefault(static_cast<AttrType>(1.0));
+    AddAttr<float>("threshold", "The threshold location of activation")
+        .SetDefault(1.0f);
     AddComment(R"DOC(
 ThresholdedRelu Activation Operator.
@@ -448,7 +434,6 @@ $$
 }
 };
-template <typename AttrType>
 class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
  public:
   HardSigmoidOpMaker(framework::OpProto *proto,
@@ -456,10 +441,10 @@ class HardSigmoidOpMaker : public framework::OpProtoAndCheckerMaker {
       : OpProtoAndCheckerMaker(proto, op_checker) {
     AddInput("X", "Input of HardSigmoid operator");
     AddOutput("Y", "Output of HardSigmoid operator");
-    AddAttr<AttrType>("slope", "Slope for linear approximation of sigmoid")
-        .SetDefault(static_cast<AttrType>(0.2));
-    AddAttr<AttrType>("offset", "Offset for linear approximation of sigmoid")
-        .SetDefault(static_cast<AttrType>(0.5));
+    AddAttr<float>("slope", "Slope for linear approximation of sigmoid")
+        .SetDefault(0.2f);
+    AddAttr<float>("offset", "Offset for linear approximation of sigmoid")
+        .SetDefault(0.5f);
     AddComment(R"DOC(
 HardSigmoid Activation Operator.
@@ -499,7 +484,7 @@ REGISTER_OP(tanh, ops::ActivationOp, ops::TanhOpMaker, tanh_grad,
 REGISTER_OP(tanh_shrink, ops::ActivationOp, ops::TanhShrinkOpMaker,
             tanh_shrink_grad, ops::ActivationOpGrad);
-REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker<float>,
+REGISTER_OP(softshrink, ops::ActivationOp, ops::SoftShrinkOpMaker,
             softshrink_grad, ops::ActivationOpGrad);
 REGISTER_OP(sqrt, ops::ActivationOp, ops::SqrtOpMaker, sqrt_grad,
@@ -523,35 +508,34 @@ REGISTER_OP(softplus, ops::ActivationOp, ops::SoftplusOpMaker, softplus_grad,
 REGISTER_OP(softsign, ops::ActivationOp, ops::SoftsignOpMaker, softsign_grad,
             ops::ActivationOpGrad);
-REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker<float>, brelu_grad,
+REGISTER_OP(brelu, ops::ActivationOp, ops::BReluOpMaker, brelu_grad,
             ops::ActivationOpGrad);
-REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker<float>,
+REGISTER_OP(leaky_relu, ops::ActivationOp, ops::LeakyReluOpMaker,
             leaky_relu_grad, ops::ActivationOpGrad);
-REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker<float>,
-            soft_relu_grad, ops::ActivationOpGrad);
+REGISTER_OP(soft_relu, ops::ActivationOp, ops::SoftReluOpMaker, soft_relu_grad,
+            ops::ActivationOpGrad);
-REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker<float>, elu_grad,
+REGISTER_OP(elu, ops::ActivationOp, ops::ELUOpMaker, elu_grad,
             ops::ActivationOpGrad);
-REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker<float>, relu6_grad,
+REGISTER_OP(relu6, ops::ActivationOp, ops::Relu6OpMaker, relu6_grad,
             ops::ActivationOpGrad);
-REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker<float>, pow_grad,
+REGISTER_OP(pow, ops::ActivationOp, ops::PowOpMaker, pow_grad,
             ops::ActivationOpGrad);
-REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker<float>, stanh_grad,
+REGISTER_OP(stanh, ops::ActivationOp, ops::STanhOpMaker, stanh_grad,
             ops::ActivationOpGrad);
-REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker<float>,
+REGISTER_OP(hard_shrink, ops::ActivationOp, ops::HardShrinkOpMaker,
             hard_shrink_grad, ops::ActivationOpGrad);
-REGISTER_OP(thresholded_relu, ops::ActivationOp,
-            ops::ThresholdedReluOpMaker<float>, thresholded_relu_grad,
-            ops::ActivationOpGrad);
+REGISTER_OP(thresholded_relu, ops::ActivationOp, ops::ThresholdedReluOpMaker,
+            thresholded_relu_grad, ops::ActivationOpGrad);
-REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker<float>,
+REGISTER_OP(hard_sigmoid, ops::ActivationOp, ops::HardSigmoidOpMaker,
             hard_sigmoid_grad, ops::ActivationOpGrad);
 #define REGISTER_ACTIVATION_CPU_KERNEL(act_type, functor, grad_functor) \
@@ -109,4 +109,5 @@ paramOut = param + paramUpdate$$
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adadelta, ops::AdadeltaOp, ops::AdadeltaOpMaker);
 REGISTER_OP_CPU_KERNEL(
-    adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>);
+    adadelta, ops::AdadeltaOpKernel<paddle::platform::CPUPlace, float>,
+    ops::AdadeltaOpKernel<paddle::platform::CPUPlace, double>);
@@ -17,4 +17,5 @@
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>);
+    adadelta, ops::AdadeltaOpKernel<paddle::platform::GPUPlace, float>,
+    ops::AdadeltaOpKernel<paddle::platform::GPUPlace, double>);
@@ -33,8 +33,8 @@ class AdadeltaOpKernel : public framework::OpKernel<T> {
     avg_squared_grad_out_tensor->mutable_data<T>(ctx.GetPlace());
     avg_squared_update_out_tensor->mutable_data<T>(ctx.GetPlace());
-    float rho = ctx.Attr<float>("rho");
-    float epsilon = ctx.Attr<float>("epsilon");
+    T rho = static_cast<T>(ctx.Attr<float>("rho"));
+    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
     auto param = framework::EigenVector<T>::Flatten(
         *ctx.Input<framework::Tensor>("Param"));
@@ -14,8 +14,8 @@
 #define EIGEN_USE_GPU
 #include "paddle/operators/adagrad_op.h"
-#include "paddle/operators/math/selected_rows_functor.h"
 #include "paddle/operators/math/math_function.h"
+#include "paddle/operators/math/selected_rows_functor.h"
 #include "paddle/platform/cuda_helper.h"
 namespace paddle {
@@ -134,8 +134,8 @@ struct SparseAdagradFunctor<platform::GPUPlace, T> {
         T, 256><<<grid2, threads, 0,
                   reinterpret_cast<const platform::CUDADeviceContext&>(context)
                       .stream()>>>(grad_merge_data, grad_merge->rows().data(),
-                                   lr, param_data,
-                                   moment_data, grad_width, epsilon);
+                                   lr, param_data, moment_data, grad_width,
+                                   epsilon);
   }
 };
@@ -127,4 +127,5 @@ paramOut = param - learningRate * moment_1/ ($\sqrt{(moment_2)} + \epsilon)$$
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adam, ops::AdamOp, ops::AdamOpMaker);
 REGISTER_OP_CPU_KERNEL(adam,
-                       ops::AdamOpKernel<paddle::platform::CPUPlace, float>);
+                       ops::AdamOpKernel<paddle::platform::CPUPlace, float>,
+                       ops::AdamOpKernel<paddle::platform::CPUPlace, double>);
@@ -17,4 +17,5 @@
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(adam,
-                       ops::AdamOpKernel<paddle::platform::GPUPlace, float>);
+                       ops::AdamOpKernel<paddle::platform::GPUPlace, float>,
+                       ops::AdamOpKernel<paddle::platform::GPUPlace, double>);
@@ -31,9 +31,9 @@ class AdamOpKernel : public framework::OpKernel<T> {
     moment1_out_tensor->mutable_data<T>(ctx.GetPlace());
     moment2_out_tensor->mutable_data<T>(ctx.GetPlace());
-    float beta1 = ctx.Attr<float>("beta1");
-    float beta2 = ctx.Attr<float>("beta2");
-    float epsilon = ctx.Attr<float>("epsilon");
+    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
+    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
+    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
     auto param = framework::EigenVector<T>::Flatten(
         *ctx.Input<framework::Tensor>("Param"));
@@ -126,4 +126,5 @@ division by 0 error.
 namespace ops = paddle::operators;
 REGISTER_OP_WITHOUT_GRADIENT(adamax, ops::AdamaxOp, ops::AdamaxOpMaker);
 REGISTER_OP_CPU_KERNEL(adamax,
-                       ops::AdamaxOpKernel<paddle::platform::CPUPlace, float>);
+                       ops::AdamaxOpKernel<paddle::platform::CPUPlace, float>,
+                       ops::AdamaxOpKernel<paddle::platform::CPUPlace, double>);
@@ -17,4 +17,5 @@
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(adamax,
-                       ops::AdamaxOpKernel<paddle::platform::GPUPlace, float>);
+                       ops::AdamaxOpKernel<paddle::platform::GPUPlace, float>,
+                       ops::AdamaxOpKernel<paddle::platform::GPUPlace, double>);
@@ -31,9 +31,9 @@ class AdamaxOpKernel : public framework::OpKernel<T> {
     moment_out_tensor->mutable_data<T>(ctx.GetPlace());
     inf_norm_out_tensor->mutable_data<T>(ctx.GetPlace());
-    float beta1 = ctx.Attr<float>("beta1");
-    float beta2 = ctx.Attr<float>("beta2");
-    float epsilon = ctx.Attr<float>("epsilon");
+    T beta1 = static_cast<T>(ctx.Attr<float>("beta1"));
+    T beta2 = static_cast<T>(ctx.Attr<float>("beta2"));
+    T epsilon = static_cast<T>(ctx.Attr<float>("epsilon"));
     auto param = framework::EigenVector<T>::Flatten(
         *ctx.Input<framework::Tensor>("Param"));
@@ -179,7 +179,9 @@ REGISTER_OP(sequence_conv, ops::SequenceConvOp, ops::SequenceConvOpMaker,
             sequence_conv_grad, ops::SequenceConvGradOp);
 REGISTER_OP_CPU_KERNEL(
-    sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>);
+    sequence_conv, ops::SequenceConvKernel<paddle::platform::CPUPlace, float>,
+    ops::SequenceConvKernel<paddle::platform::CPUPlace, double>);
 REGISTER_OP_CPU_KERNEL(
     sequence_conv_grad,
-    ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>);
+    ops::SequenceConvGradKernel<paddle::platform::CPUPlace, float>,
+    ops::SequenceConvGradKernel<paddle::platform::CPUPlace, double>);
@@ -16,7 +16,9 @@
 namespace ops = paddle::operators;
 REGISTER_OP_GPU_KERNEL(
-    sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>);
+    sequence_conv, ops::SequenceConvKernel<paddle::platform::GPUPlace, float>,
+    ops::SequenceConvKernel<paddle::platform::GPUPlace, double>);
 REGISTER_OP_GPU_KERNEL(
     sequence_conv_grad,
-    ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>);
+    ops::SequenceConvGradKernel<paddle::platform::GPUPlace, float>,
+    ops::SequenceConvGradKernel<paddle::platform::GPUPlace, double>);