提交 e811c865 编写于 作者: S sweetsky0901

Make the data type of the epsilon attribute a template parameter (AttrType) in the norm operator

上级 8a7c309d
...@@ -16,6 +16,7 @@ limitations under the License. */ ...@@ -16,6 +16,7 @@ limitations under the License. */
namespace paddle { namespace paddle {
namespace operators { namespace operators {
template <typename AttrType>
class NormOpMaker : public framework::OpProtoAndCheckerMaker { class NormOpMaker : public framework::OpProtoAndCheckerMaker {
public: public:
NormOpMaker(OpProto* proto, OpAttrChecker* op_checker) NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
...@@ -28,9 +29,9 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker { ...@@ -28,9 +29,9 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Scale", AddInput("Scale",
"(Tensor) The input tensor of norm operator. " "(Tensor) The input tensor of norm operator. "
"The format of input tensor is C * 1."); "The format of input tensor is C * 1.");
AddAttr<float>("epsilon", AddAttr<AttrType>("epsilon",
"(float, default 1e-10) Constant " "(float, default 1e-10) Constant "
"for numerical stability.") "for numerical stability.")
.SetDefault(1.0e-10f); .SetDefault(1.0e-10f);
AddOutput("Out", AddOutput("Out",
"(Tensor) The output tensor of norm operator." "(Tensor) The output tensor of norm operator."
...@@ -100,7 +101,8 @@ class NormOpGrad : public framework::OperatorWithKernel { ...@@ -100,7 +101,8 @@ class NormOpGrad : public framework::OperatorWithKernel {
} // namespace paddle } // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, ops::NormOpGrad); REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker<float>, norm_grad,
ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL( REGISTER_OP_CPU_KERNEL(
norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>, norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>,
ops::NormKernel<paddle::platform::CPUDeviceContext, double>); ops::NormKernel<paddle::platform::CPUDeviceContext, double>);
......
...@@ -26,14 +26,14 @@ template <typename T, int MajorType = Eigen::RowMajor, ...@@ -26,14 +26,14 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex> typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>; using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T, typename AttrType = T>
class NormKernel : public framework::OpKernel<T> { class NormKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X"); const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale"); const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
auto* out = context.Output<framework::Tensor>("Out"); auto* out = context.Output<framework::Tensor>("Out");
T epsilon = context.Attr<T>("epsilon"); auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
out->mutable_data<T>(context.GetPlace()); out->mutable_data<T>(context.GetPlace());
int batch_size = in_x->dims()[0]; int batch_size = in_x->dims()[0];
int channels = in_x->dims()[1]; int channels = in_x->dims()[1];
...@@ -82,7 +82,7 @@ class NormKernel : public framework::OpKernel<T> { ...@@ -82,7 +82,7 @@ class NormKernel : public framework::OpKernel<T> {
} }
} }
}; };
template <typename DeviceContext, typename T> template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradKernel : public framework::OpKernel<T> { class NormGradKernel : public framework::OpKernel<T> {
public: public:
void Compute(const framework::ExecutionContext& context) const override { void Compute(const framework::ExecutionContext& context) const override {
...@@ -90,7 +90,7 @@ class NormGradKernel : public framework::OpKernel<T> { ...@@ -90,7 +90,7 @@ class NormGradKernel : public framework::OpKernel<T> {
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale"); const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
const framework::Tensor* out_grad = const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out")); context.Input<framework::Tensor>(framework::GradVarName("Out"));
T epsilon = context.Attr<T>("epsilon"); auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
framework::Tensor* in_x_grad = framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X")); context.Output<framework::Tensor>(framework::GradVarName("X"));
in_x_grad->mutable_data<T>(context.GetPlace()); in_x_grad->mutable_data<T>(context.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册