提交 e811c865 编写于 作者: S sweetsky0901

for epsilon dataType

上级 8a7c309d
......@@ -16,6 +16,7 @@ limitations under the License. */
namespace paddle {
namespace operators {
template <typename AttrType>
class NormOpMaker : public framework::OpProtoAndCheckerMaker {
public:
NormOpMaker(OpProto* proto, OpAttrChecker* op_checker)
......@@ -28,9 +29,9 @@ class NormOpMaker : public framework::OpProtoAndCheckerMaker {
AddInput("Scale",
"(Tensor) The input tensor of norm operator. "
"The format of input tensor is C * 1.");
AddAttr<float>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
AddAttr<AttrType>("epsilon",
"(float, default 1e-10) Constant "
"for numerical stability.")
.SetDefault(1.0e-10f);
AddOutput("Out",
"(Tensor) The output tensor of norm operator."
......@@ -100,7 +101,8 @@ class NormOpGrad : public framework::OperatorWithKernel {
} // namespace paddle
namespace ops = paddle::operators;
REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker, norm_grad, ops::NormOpGrad);
REGISTER_OP(norm, ops::NormOp, ops::NormOpMaker<float>, norm_grad,
ops::NormOpGrad);
REGISTER_OP_CPU_KERNEL(
norm, ops::NormKernel<paddle::platform::CPUDeviceContext, float>,
ops::NormKernel<paddle::platform::CPUDeviceContext, double>);
......
......@@ -26,14 +26,14 @@ template <typename T, int MajorType = Eigen::RowMajor,
typename IndexType = Eigen::DenseIndex>
using EigenMatrix = framework::EigenMatrix<T, MajorType, IndexType>;
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T, typename AttrType = T>
class NormKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
const framework::Tensor* in_x = context.Input<framework::Tensor>("X");
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
auto* out = context.Output<framework::Tensor>("Out");
T epsilon = context.Attr<T>("epsilon");
auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
out->mutable_data<T>(context.GetPlace());
int batch_size = in_x->dims()[0];
int channels = in_x->dims()[1];
......@@ -82,7 +82,7 @@ class NormKernel : public framework::OpKernel<T> {
}
}
};
template <typename DeviceContext, typename T>
template <typename DeviceContext, typename T, typename AttrType = T>
class NormGradKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext& context) const override {
......@@ -90,7 +90,7 @@ class NormGradKernel : public framework::OpKernel<T> {
const framework::Tensor* scale = context.Input<framework::Tensor>("Scale");
const framework::Tensor* out_grad =
context.Input<framework::Tensor>(framework::GradVarName("Out"));
T epsilon = context.Attr<T>("epsilon");
auto epsilon = static_cast<T>(context.Attr<AttrType>("epsilon"));
framework::Tensor* in_x_grad =
context.Output<framework::Tensor>(framework::GradVarName("X"));
in_x_grad->mutable_data<T>(context.GetPlace());
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册