提交 b36205e2 编写于 作者: L liaogang

Refine compute code in operators

上级 bfaea910
......@@ -28,6 +28,7 @@ struct EigenDim {
static Type From(const DDim& dims) {
PADDLE_ENFORCE(arity(dims) == D, "D must match arity(DDim)");
Type ret;
#pragma unroll
for (int d = 0; d < arity(dims); d++) {
ret[d] = dims[d];
}
......
......@@ -28,10 +28,13 @@ public:
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
framework::EigenVector<T>::Flatten(*input0) +
framework::EigenVector<T>::Flatten(*input1);
auto X = EigenVector<T>::Flatten(*input0);
auto Y = EigenVector<T>::Flatten(*input1);
auto Z = EigenVector<T>::Flatten(*output);
auto place = *context.GetEigenDevice<Place>();
Z.device(place) = X + Y;
}
};
......
......@@ -27,8 +27,11 @@ public:
output->mutable_data<T>(context.GetPlace());
EigenScalar<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*input).mean();
auto X = EigenVector<T>::Flatten(*input);
auto y = EigenScalar<T>::From(*output);
auto place = *context.GetEigenDevice<Place>();
y.device(place) = X.mean();
}
};
......
......@@ -26,13 +26,18 @@ public:
Eigen::array<Eigen::IndexPair<Eigen::DenseIndex>, 1> dim_pair = {
{Eigen::IndexPair<Eigen::DenseIndex>(1, 0)}};
auto input0 = context.Input<Tensor>("X");
auto input1 = context.Input<Tensor>("Y");
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenMatrix<T>::From(*output).device(*(context.GetEigenDevice<Place>())) =
EigenMatrix<T>::From(*context.Input<Tensor>("X"))
.contract(EigenMatrix<T>::From(*context.Input<Tensor>("Y")),
dim_pair);
auto X = EigenMatrix<T>::From(*input0);
auto Y = EigenMatrix<T>::From(*input1);
auto Z = EigenMatrix<T>::From(*output);
auto place = *context.GetEigenDevice<Place>();
Z.device(place) = X.contract(Y, dim_pair);
}
};
} // namespace operators
......
......@@ -29,8 +29,12 @@ public:
param_out->mutable_data<T>(ctx.GetPlace());
EigenVector<T>::Flatten(*param_out).device(*(ctx.GetEigenDevice<Place>())) =
EigenVector<T>::Flatten(*param) - lr * EigenVector<T>::Flatten(*grad);
auto p = EigenVector<T>::Flatten(*param);
auto g = EigenVector<T>::Flatten(*grad);
auto o = EigenVector<T>::Flatten(*param_out);
auto place = *ctx.GetEigenDevice<Place>();
o.device(place) = p - lr * g;
}
};
......
......@@ -27,9 +27,11 @@ public:
auto output = context.Output<Tensor>(0);
output->mutable_data<T>(context.GetPlace());
EigenVector<T>::Flatten(*output).device(
*(context.GetEigenDevice<Place>())) =
1.0 / (1.0 + (-1.0 * EigenVector<T>::Flatten(*input)).exp());
auto X = EigenVector<T>::Flatten(*input);
auto Y = EigenVector<T>::Flatten(*output);
auto place = *context.GetEigenDevice<Place>();
Y.device(place) = 1.0 / (1.0 + (-1.0 * X).exp());
}
};
} // namespace operators
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册