未验证 提交 f1366d58 编写于 作者: C Chen Weihang 提交者: GitHub

replace contextt to context (#38619)

上级 a1275c8b
......@@ -18,7 +18,7 @@ limitations under the License. */
namespace pten {
template <typename T, typename ContextT>
void Conj(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
} // namespace pten
......@@ -18,8 +18,8 @@ limitations under the License. */
namespace pten {
template <typename ContextT>
void Copy(const ContextT& dev_ctx,
template <typename Context>
void Copy(const Context& dev_ctx,
const DenseTensor& src,
bool blocking,
DenseTensor* dst);
......
......@@ -25,8 +25,8 @@ limitations under the License. */
namespace pten {
// NOTE(chenweihang): blocking is useless in cpu kernel
template <typename ContextT>
void Copy(const ContextT& dev_ctx,
template <typename Context>
void Copy(const Context& dev_ctx,
const DenseTensor& src,
bool blocking,
DenseTensor* dst) {
......
......@@ -22,8 +22,8 @@
namespace pten {
template <typename T, typename ContextT>
void Dot(const ContextT& dev_ctx,
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
......
......@@ -18,8 +18,8 @@
namespace pten {
template <typename T, typename ContextT>
void Dot(const ContextT& dev_ctx,
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out);
......
......@@ -21,8 +21,8 @@
namespace pten {
template <typename T, typename ContextT>
void Flatten(const ContextT& dev_ctx,
template <typename T, typename Context>
void Flatten(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
......@@ -35,14 +35,14 @@ void Flatten(const ContextT& dev_ctx,
// TODO(yuanrisheng): this kernel is for training and xshape is a Intermediate
// Output Tensor,
// is there a more flexible way to deal with this case?
template <typename T, typename ContextT>
void FlattenWithXShape(const ContextT& dev_ctx,
template <typename T, typename Context>
void FlattenWithXShape(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out,
DenseTensor* xshape) {
Flatten<T, ContextT>(dev_ctx, x, start_axis, stop_axis, out);
Flatten<T, Context>(dev_ctx, x, start_axis, stop_axis, out);
funcs::SetXShape(x, xshape);
}
......
......@@ -18,15 +18,15 @@ limitations under the License. */
namespace pten {
template <typename T, typename ContextT>
void Flatten(const ContextT& dev_ctx,
template <typename T, typename Context>
void Flatten(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
DenseTensor* out);
template <typename T, typename ContextT>
void FlattenWithXShape(const ContextT& dev_ctx,
template <typename T, typename Context>
void FlattenWithXShape(const Context& dev_ctx,
const DenseTensor& x,
int start_axis,
int stop_axis,
......
......@@ -20,13 +20,13 @@
namespace pten {
template <typename T, typename ContextT>
void Full(const ContextT& dev_ctx,
template <typename T, typename Context>
void Full(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out);
template <typename T, typename ContextT>
void FullLike(const ContextT& dev_ctx, const Scalar& val, DenseTensor* out);
template <typename T, typename Context>
void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out);
} // namespace pten
......@@ -24,8 +24,8 @@ limitations under the License. */
namespace pten {
template <typename ContextT>
void Copy(const ContextT& dev_ctx,
template <typename Context>
void Copy(const Context& dev_ctx,
const DenseTensor& src,
bool blocking,
DenseTensor* dst) {
......
......@@ -24,8 +24,8 @@
namespace pten {
template <typename T, typename ContextT>
void Dot(const ContextT& dev_ctx,
template <typename T, typename Context>
void Dot(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
DenseTensor* out) {
......
......@@ -20,13 +20,13 @@
namespace pten {
template <typename T, typename ContextT>
void Conj(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out) {
template <typename T, typename Context>
void Conj(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
auto numel = x.numel();
auto* x_data = x.data<T>();
auto* out_data = out->mutable_data<T>();
paddle::platform::ForRange<ContextT> for_range(dev_ctx, numel);
paddle::platform::ForRange<Context> for_range(dev_ctx, numel);
paddle::operators::math::ConjFunctor<T> functor(x_data, numel, out_data);
for_range(functor);
}
......
......@@ -24,24 +24,24 @@ limitations under the License. */
namespace pten {
template <typename DeviceContext, typename T, typename VType>
void fill_(const DeviceContext& context, DenseTensor* tensor, VType val) {
template <typename Context, typename T, typename VType>
void FullValue(const Context& dev_ctx, DenseTensor* tensor, VType val) {
tensor->mutable_data<T>();
auto t = pten::EigenVector<T>::Flatten(*tensor);
t.device(*context.eigen_device()) = t.constant(static_cast<T>(val));
t.device(*dev_ctx.eigen_device()) = t.constant(static_cast<T>(val));
}
template <typename T, typename ContextT>
void Full(const ContextT& dev_ctx,
template <typename T, typename Context>
void Full(const Context& dev_ctx,
const ScalarArray& shape,
const Scalar& val,
DenseTensor* out) {
out->Resize(paddle::framework::make_ddim(shape.GetData()));
fill_<ContextT, T>(dev_ctx, out, val.to<T>());
FullValue<Context, T>(dev_ctx, out, val.to<T>());
}
template <typename T, typename ContextT>
void FullLike(const ContextT& dev_ctx, const Scalar& val, DenseTensor* out) {
template <typename T, typename Context>
void FullLike(const Context& dev_ctx, const Scalar& val, DenseTensor* out) {
auto value = val.to<float>();
using CommonType = typename std::common_type<
float,
......@@ -66,7 +66,7 @@ void FullLike(const ContextT& dev_ctx, const Scalar& val, DenseTensor* out) {
static_cast<CommonType>(std::numeric_limits<T>::lowest()),
static_cast<CommonType>(std::numeric_limits<T>::max()),
static_cast<float>(value)));
fill_<ContextT, T>(dev_ctx, out, value);
FullValue<Context, T>(dev_ctx, out, value);
}
} // namespace pten
......@@ -23,8 +23,8 @@ limitations under the License. */
namespace pten {
template <typename T, typename ContextT>
void Scale(const ContextT& dev_ctx,
template <typename T, typename Context>
void Scale(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
......
......@@ -22,8 +22,8 @@ limitations under the License. */
namespace pten {
template <typename T, typename ContextT>
void Sign(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out) {
template <typename T, typename Context>
void Sign(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out) {
out->mutable_data<T>();
auto eigen_out = pten::EigenVector<T>::Flatten(*out);
auto eigen_x = pten::EigenVector<T>::Flatten(x);
......
......@@ -21,8 +21,8 @@
namespace pten {
template <typename ContextT>
void Reshape(const ContextT& dev_ctx,
template <typename Context>
void Reshape(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* out) {
......@@ -36,8 +36,8 @@ void Reshape(const ContextT& dev_ctx,
out->ResetLoD(x.lod());
}
template <typename ContextT>
void ReshapeWithXShape(const ContextT& dev_ctx,
template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* xshape,
......
......@@ -19,14 +19,14 @@ limitations under the License. */
namespace pten {
template <typename ContextT>
void Reshape(const ContextT& dev_ctx,
template <typename Context>
void Reshape(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* out);
template <typename ContextT>
void ReshapeWithXShape(const ContextT& dev_ctx,
template <typename Context>
void ReshapeWithXShape(const Context& dev_ctx,
const DenseTensor& x,
const ScalarArray& shape,
DenseTensor* xshape,
......
......@@ -19,8 +19,8 @@ limitations under the License. */
namespace pten {
template <typename T, typename ContextT>
void Scale(const ContextT& dev_ctx,
template <typename T, typename Context>
void Scale(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
......
......@@ -18,7 +18,7 @@ limitations under the License. */
namespace pten {
template <typename T, typename ContextT>
void Sign(const ContextT& dev_ctx, const DenseTensor& x, DenseTensor* out);
template <typename T, typename Context>
void Sign(const Context& dev_ctx, const DenseTensor& x, DenseTensor* out);
} // namespace pten
......@@ -24,8 +24,8 @@ limitations under the License. */
namespace pten {
template <typename ContextT>
void Copy(const ContextT& dev_ctx,
template <typename Context>
void Copy(const Context& dev_ctx,
const DenseTensor& src,
bool blocking,
DenseTensor* dst) {
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册