未验证 提交 1053b1d5 编写于 作者: C Chen Weihang 提交者: GitHub

replace last contextT (#38971)

上级 88966b28
......@@ -43,8 +43,8 @@ struct ScaleFunctor {
}
};
template <typename T, typename ContextT>
void ScaleKernel(const ContextT& dev_ctx,
template <typename T, typename Context>
void ScaleKernel(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
......
......@@ -67,8 +67,8 @@ void SumKernel(const Context& dev_ctx,
DataType out_dtype,
DenseTensor* out);
template <typename T, typename ContextT>
DenseTensor Add(const ContextT& dev_ctx,
template <typename T, typename Context>
DenseTensor Add(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
......@@ -77,12 +77,12 @@ DenseTensor Add(const ContextT& dev_ctx,
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
AddKernel<T, ContextT>(dev_ctx, x, y, axis, &dense_out);
AddKernel<T, Context>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Subtract(const ContextT& dev_ctx,
template <typename T, typename Context>
DenseTensor Subtract(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
......@@ -91,12 +91,12 @@ DenseTensor Subtract(const ContextT& dev_ctx,
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
SubtractKernel<T, ContextT>(dev_ctx, x, y, axis, &dense_out);
SubtractKernel<T, Context>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Divide(const ContextT& dev_ctx,
template <typename T, typename Context>
DenseTensor Divide(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
......@@ -105,12 +105,12 @@ DenseTensor Divide(const ContextT& dev_ctx,
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
DivideKernel<T, ContextT>(dev_ctx, x, y, axis, &dense_out);
DivideKernel<T, Context>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}
template <typename T, typename ContextT>
DenseTensor Multiply(const ContextT& dev_ctx,
template <typename T, typename Context>
DenseTensor Multiply(const Context& dev_ctx,
const DenseTensor& x,
const DenseTensor& y,
int axis) {
......@@ -119,7 +119,7 @@ DenseTensor Multiply(const ContextT& dev_ctx,
pten::make_intrusive<paddle::experimental::SharedStorage>(
dev_ctx.GetPlace()),
std::move(out_meta));
MultiplyKernel<T, ContextT>(dev_ctx, x, y, axis, &dense_out);
MultiplyKernel<T, Context>(dev_ctx, x, y, axis, &dense_out);
return dense_out;
}
......
......@@ -28,15 +28,15 @@ void ScaleKernel(const Context& dev_ctx,
bool bias_after_scale,
DenseTensor* out);
template <typename T, typename ContextT>
DenseTensor Scale(const ContextT& dev_ctx,
template <typename T, typename Context>
DenseTensor Scale(const Context& dev_ctx,
const DenseTensor& x,
const Scalar& scale,
float bias,
bool bias_after_scale) {
auto out_meta = UnchangedInferMeta(x.meta());
auto dense_out = pten::Empty<T, ContextT>(dev_ctx, std::move(out_meta));
ScaleKernel<T, ContextT>(
auto dense_out = pten::Empty<T, Context>(dev_ctx, std::move(out_meta));
ScaleKernel<T, Context>(
dev_ctx, x, scale, bias, bias_after_scale, &dense_out);
return dense_out;
}
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册