提交 b36205e2 编写于 作者: L liaogang

Refine compute code in operators

上级 bfaea910
...@@ -28,6 +28,7 @@ struct EigenDim { ...@@ -28,6 +28,7 @@ struct EigenDim {
static Type From(const DDim& dims) { static Type From(const DDim& dims) {
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)"); PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
Type ret; Type ret;
#pragma unroll
for (int d = 0; d < arity(dims); d++) { for (int d = 0; d < arity(dims); d++) {
ret[d] = dims[d]; ret[d] = dims[d];
} }
......
...@@ -28,10 +28,13 @@ public: ...@@ -28,10 +28,13 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( auto X = EigenVector<T>::Flatten(*input0);
*(context.GetEigenDevice<Place>())) = auto Y = EigenVector<T>::Flatten(*input1);
framework::EigenVector<T>::Flatten(*input0) + auto Z = EigenVector<T>::Flatten(*output);
framework::EigenVector<T>::Flatten(*input1);
auto place = *context.GetEigenDevice<Place>();
Z.device(place) = X + Y;
} }
}; };
......
...@@ -27,8 +27,11 @@ public: ...@@ -27,8 +27,11 @@ public:
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = auto X = EigenVector<T>::Flatten(*input);
EigenVector<T>::Flatten(*input).mean(); auto y = EigenScalar<T>::From(*output);
auto place = *context.GetEigenDevice<Place>();
y.device(place) = X.mean();
} }
}; };
......
...@@ -26,13 +26,18 @@ public: ...@@ -26,13 +26,18 @@ public:
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = { Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}}; {Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input<Tensor>("X");
auto input1 = context.Input<Tensor>("Y");
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) = auto X = EigenMatrix<T>::From(*input0);
EigenMatrix<T>::From(*context.Input<Tensor>("X")) auto Y = EigenMatrix<T>::From(*input1);
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")), auto Z = EigenMatrix<T>::From(*output);
dim_pair); auto place = *context.GetEigenDevice<Place>();
Z.device(place) = X.contract(Y, dim_pair);
} }
}; };
} // namespace operators } // namespace operators
......
...@@ -29,8 +29,12 @@ public: ...@@ -29,8 +29,12 @@ public:
param_out->mutable_data<T>(ctx.GetPlace()); param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) = auto p = EigenVector<T>::Flatten(*param);
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad); auto g = EigenVector<T>::Flatten(*grad);
auto o = EigenVector<T>::Flatten(*param_out);
auto place = *ctx.GetEigenDevice<Place>();
o.device(place) = p - lr * g;
} }
}; };
......
...@@ -27,9 +27,11 @@ public: ...@@ -27,9 +27,11 @@ public:
auto output = context.Output<Tensor>(0); auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace()); output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device( auto X = EigenVector<T>::Flatten(*input);
*(context.GetEigenDevice<Place>())) = auto Y = EigenVector<T>::Flatten(*output);
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp()); auto place = *context.GetEigenDevice<Place>();
Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
} }
}; };
} // namespace operators } // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册