未验证 提交 f1366d58 编写于 作者: C Chen Weihang 提交者: GitHub

replace contextt to context (#38619)

上级 a1275c8b
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Conj(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out); void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
} // namespace pten } // namespace pten
...@@ -18,8 +18,8 @@ limitations under the License. */ ...@@ -18,8 +18,8 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename ContextT> template <typename Context>
void Copy(const ContextT& dev_ctx, void Copy(const Context& dev_ctx,
const DenseTensor& src, const DenseTensor& src,
bool blocking, bool blocking,
DenseTensor* dst); DenseTensor* dst);
......
...@@ -25,8 +25,8 @@ limitations under the License. */ ...@@ -25,8 +25,8 @@ limitations under the License. */
namespace pten { namespace pten {
// NOTE(chenweihang): blocking is useless in cpu kernel // NOTE(chenweihang): blocking is useless in cpu kernel
template <typename ContextT> template <typename Context>
void Copy(const ContextT& dev_ctx, void Copy(const Context& dev_ctx,
const DenseTensor& src, const DenseTensor& src,
bool blocking, bool blocking,
DenseTensor* dst) { DenseTensor* dst) {
......
...@@ -22,8 +22,8 @@ ...@@ -22,8 +22,8 @@
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Dot(const ContextT& dev_ctx, void Dot(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out) { DenseTensor* out) {
......
...@@ -18,8 +18,8 @@ ...@@ -18,8 +18,8 @@
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Dot(const ContextT& dev_ctx, void Dot(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out); DenseTensor* out);
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Flatten(const ContextT& dev_ctx, void Flatten(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int start_axis, int start_axis,
int stop_axis, int stop_axis,
...@@ -35,14 +35,14 @@ void Flatten(const ContextT& dev_ctx, ...@@ -35,14 +35,14 @@ void Flatten(const ContextT& dev_ctx,
// TODO(yuanrisheng): this kernel is for training and xshape is a Intermediate // TODO(yuanrisheng): this kernel is for training and xshape is a Intermediate
// Output Tensor, // Output Tensor,
// is there a more flexible way to deal with this case? // is there a more flexible way to deal with this case?
template <typename T, typename ContextT> template <typename T, typename Context>
void FlattenWithXShape(const ContextT& dev_ctx, void FlattenWithXShape(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int start_axis, int start_axis,
int stop_axis, int stop_axis,
DenseTensor* out, DenseTensor* out,
DenseTensor* xshape) { DenseTensor* xshape) {
Flatten<T, ContextT>(dev_ctx, x, start_axis, stop_axis, out); Flatten<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
funcs::SetXShape(x, xshape); funcs::SetXShape(x, xshape);
} }
......
...@@ -18,15 +18,15 @@ limitations under the License. */ ...@@ -18,15 +18,15 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Flatten(const ContextT& dev_ctx, void Flatten(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int start_axis, int start_axis,
int stop_axis, int stop_axis,
DenseTensor* out); DenseTensor* out);
template <typename T, typename ContextT> template <typename T, typename Context>
void FlattenWithXShape(const ContextT& dev_ctx, void FlattenWithXShape(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
int start_axis, int start_axis,
int stop_axis, int stop_axis,
......
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Full(const ContextT& dev_ctx, void Full(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
const Scalar& val, const Scalar& val,
DenseTensor* out); DenseTensor* out);
template <typename T, typename ContextT> template <typename T, typename Context>
void FullLike(const ContextT& dev_ctx, const Scalar& val, DenseTensor* out); void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out);
} // namespace pten } // namespace pten
...@@ -24,8 +24,8 @@ limitations under the License. */ ...@@ -24,8 +24,8 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename ContextT> template <typename Context>
void Copy(const ContextT& dev_ctx, void Copy(const Context& dev_ctx,
const DenseTensor& src, const DenseTensor& src,
bool blocking, bool blocking,
DenseTensor* dst) { DenseTensor* dst) {
......
...@@ -24,8 +24,8 @@ ...@@ -24,8 +24,8 @@
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Dot(const ContextT& dev_ctx, void Dot(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const DenseTensor& y, const DenseTensor& y,
DenseTensor* out) { DenseTensor* out) {
......
...@@ -20,13 +20,13 @@ ...@@ -20,13 +20,13 @@
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Conj(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out) { void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
auto numel = x.numel(); auto numel = x.numel();
auto* x_data = x.data<T>(); auto* x_data = x.data<T>();
auto* out_data = out->mutable_data<T>(); auto* out_data = out->mutable_data<T>();
paddle::platform::ForRange<ContextT> for_range(dev_ctx, numel); paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
paddle::operators::math::ConjFunctor<T> functor(x_data, numel, out_data); paddle::operators::math::ConjFunctor<T> functor(x_data, numel, out_data);
for_range(functor); for_range(functor);
} }
......
...@@ -24,24 +24,24 @@ limitations under the License. */ ...@@ -24,24 +24,24 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename DeviceContext, typename T, typename VType> template <typename Context, typename T, typename VType>
void fill_(const DeviceContext& context, DenseTensor* tensor, VType val) { void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
tensor->mutable_data<T>(); tensor->mutable_data<T>();
auto t = pten::EigenVector<T>::Flatten(*tensor); auto t = pten::EigenVector<T>::Flatten(*tensor);
t.device(*context.eigen_device()) = t.constant(static_cast<T>(val)); t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(val));
} }
template <typename T, typename ContextT> template <typename T, typename Context>
void Full(const ContextT& dev_ctx, void Full(const Context& dev_ctx,
const ScalarArray& shape, const ScalarArray& shape,
const Scalar& val, const Scalar& val,
DenseTensor* out) { DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData())); out->Resize(paddle::framework::make_ddim(shape.GetData()));
fill_<ContextT, T>(dev_ctx, out, val.to<T>()); FullValue<Context, T>(dev_ctx, out, val.to<T>());
} }
template <typename T, typename ContextT> template <typename T, typename Context>
void FullLike(const ContextT& dev_ctx, const Scalar& val, DenseTensor* out) { void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) {
auto value = val.to<float>(); auto value = val.to<float>();
using CommonType = typename std::common_type< using CommonType = typename std::common_type<
float, float,
...@@ -66,7 +66,7 @@ void FullLike(const ContextT& dev_ctx, const Scalar& val, DenseTensor* out) { ...@@ -66,7 +66,7 @@ void FullLike(const ContextT& dev_ctx, const Scalar& val, DenseTensor* out) {
static_cast<CommonType>(std::numeric_limits<T>::lowest()), static_cast<CommonType>(std::numeric_limits<T>::lowest()),
static_cast<CommonType>(std::numeric_limits<T>::max()), static_cast<CommonType>(std::numeric_limits<T>::max()),
static_cast<float>(value))); static_cast<float>(value)));
fill_<ContextT, T>(dev_ctx, out, value); FullValue<Context, T>(dev_ctx, out, value);
} }
} // namespace pten } // namespace pten
...@@ -23,8 +23,8 @@ limitations under the License. */ ...@@ -23,8 +23,8 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Scale(const ContextT& dev_ctx, void Scale(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const Scalar& scale, const Scalar& scale,
float bias, float bias,
......
...@@ -22,8 +22,8 @@ limitations under the License. */ ...@@ -22,8 +22,8 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Sign(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out) { void Sign(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
out->mutable_data<T>(); out->mutable_data<T>();
auto eigen_out = pten::EigenVector<T>::Flatten(*out); auto eigen_out = pten::EigenVector<T>::Flatten(*out);
auto eigen_x = pten::EigenVector<T>::Flatten(x); auto eigen_x = pten::EigenVector<T>::Flatten(x);
......
...@@ -21,8 +21,8 @@ ...@@ -21,8 +21,8 @@
namespace pten { namespace pten {
template <typename ContextT> template <typename Context>
void Reshape(const ContextT& dev_ctx, void Reshape(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* out) { DenseTensor* out) {
...@@ -36,8 +36,8 @@ void Reshape(const ContextT& dev_ctx, ...@@ -36,8 +36,8 @@ void Reshape(const ContextT& dev_ctx,
out->ResetLoD(x.lod()); out->ResetLoD(x.lod());
} }
template <typename ContextT> template <typename Context>
void ReshapeWithXShape(const ContextT& dev_ctx, void ReshapeWithXShape(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* xshape, DenseTensor* xshape,
......
...@@ -19,14 +19,14 @@ limitations under the License. */ ...@@ -19,14 +19,14 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename ContextT> template <typename Context>
void Reshape(const ContextT& dev_ctx, void Reshape(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* out); DenseTensor* out);
template <typename ContextT> template <typename Context>
void ReshapeWithXShape(const ContextT& dev_ctx, void ReshapeWithXShape(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const ScalarArray& shape, const ScalarArray& shape,
DenseTensor* xshape, DenseTensor* xshape,
......
...@@ -19,8 +19,8 @@ limitations under the License. */ ...@@ -19,8 +19,8 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Scale(const ContextT& dev_ctx, void Scale(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const Scalar& scale, const Scalar& scale,
float bias, float bias,
......
...@@ -18,7 +18,7 @@ limitations under the License. */ ...@@ -18,7 +18,7 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename T, typename ContextT> template <typename T, typename Context>
void Sign(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out); void Sign(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
} // namespace pten } // namespace pten
...@@ -24,8 +24,8 @@ limitations under the License. */ ...@@ -24,8 +24,8 @@ limitations under the License. */
namespace pten { namespace pten {
template <typename ContextT> template <typename Context>
void Copy(const ContextT& dev_ctx, void Copy(const Context& dev_ctx,
const DenseTensor& src, const DenseTensor& src,
bool blocking, bool blocking,
DenseTensor* dst) { DenseTensor* dst) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册