Unverified commit d9a41fc4, authored by H hong, committed by GitHub

Change bn mutable data to phi (#40748)

* move mutable_data to context alloc

* remove duplicate code
Parent 71b813f0
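Every hunk below applies the same migration: the legacy `Tensor::mutable_data<T>(place)` call, in which the tensor allocates its own storage, becomes `ctx.template Alloc<T>(tensor)`, in which the phi device context performs the allocation. A minimal sketch of the pattern (the kernel is illustrative only, not part of this PR; the `Context` and `DenseTensor` names follow the phi convention visible in the hunks):

// Hypothetical phi-style kernel showing only the allocation change.
template <typename T, typename Context>
void ExampleKernel(const Context& ctx, const DenseTensor& x, DenseTensor* out) {
  out->Resize(x.dims());
  // Before: out->mutable_data<T>(ctx.GetPlace());
  // After: the device context owns the allocation and returns the typed pointer.
  T* out_data = ctx.template Alloc<T>(out);
  // ... compute into out_data ...
}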
@@ -36,8 +36,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context,
     in_dims_vec[3] = input->dims()[2];
     in_dims_vec[4] = input->dims()[3];
     transformed_input->Resize(make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
+    context.template Alloc<T>(transformed_input);
   } else if (dim == 2) {
     // input
     transformed_input->Resize(input->dims());
@@ -47,7 +46,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context,
     in_dims_vec[2] = input->dims()[1];
     in_dims_vec[3] = input->dims()[2];
     transformed_input->Resize(make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
+    context.template Alloc<T>(transformed_input);
   } else if (dim == 1) {
     transformed_input->Resize(input->dims());
@@ -55,7 +54,7 @@ inline void ResizeToChannelFirst(const DeviceContext& context,
     in_dims_vec[1] = input->dims()[2];
     in_dims_vec[2] = input->dims()[1];
     transformed_input->Resize(make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
+    context.template Alloc<T>(transformed_input);
   }
 }
@@ -74,7 +73,7 @@ inline void ResizeToChannelLast(const DeviceContext& context,
     in_dims_vec[3] = input->dims()[4];
     in_dims_vec[4] = input->dims()[1];
     transformed_input->Resize(make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
+    context.template Alloc<T>(transformed_input);
   } else if (dim == 2) {
     // input
@@ -85,7 +84,7 @@ inline void ResizeToChannelLast(const DeviceContext& context,
     in_dims_vec[2] = input->dims()[3];
     in_dims_vec[3] = input->dims()[1];
     transformed_input->Resize(make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
+    context.template Alloc<T>(transformed_input);
   } else if (dim == 1) {
     transformed_input->Resize(input->dims());
@@ -93,7 +92,7 @@ inline void ResizeToChannelLast(const DeviceContext& context,
     in_dims_vec[1] = input->dims()[2];
     in_dims_vec[2] = input->dims()[1];
     transformed_input->Resize(make_ddim(in_dims_vec));
-    transformed_input->mutable_data<T>(context.GetPlace());
+    context.template Alloc<T>(transformed_input);
   }
 }
...
@@ -359,8 +359,8 @@ void BatchNormGradRawKernel(const Context &ctx,
   }
   if (d_scale && d_bias) {
-    d_scale->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-    d_bias->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+    ctx.template Alloc<BatchNormParamType<T>>(d_scale);
+    ctx.template Alloc<BatchNormParamType<T>>(d_bias);
   }
   PADDLE_ENFORCE_EQ(
@@ -569,8 +569,8 @@ void BatchNormGradRawKernel(const Context &ctx,
                 /*activationDesc=*/nullptr,
                 /*sizeInBytes=*/&workspace_size));
-        workspace_ptr = workspace_tensor.mutable_data(
-            ctx.GetPlace(), transformed_x.type(), workspace_size);
+        workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
+        workspace_ptr = ctx.template Alloc<T>(&workspace_tensor);
         PADDLE_ENFORCE_GPU_SUCCESS(
             paddle::platform::dynload::cudnnBatchNormalizationBackwardEx(
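Note: this hunk (and its counterpart in `BatchNormKernel` further down) is the one place where the rewrite is more than a one-line substitution. The legacy overload allocated a raw buffer of `workspace_size` bytes with an explicit dtype; the phi form first resizes the tensor to `workspace_size` elements and then allocates through the context, so assuming `Alloc<T>` reserves `numel() * sizeof(T)` bytes, the buffer is at least as large as the byte count cuDNN requested. A byte-exact variant would look like this (sketch only, assuming the context allocator can be instantiated at `uint8_t`; not what this PR does):

// Sketch only: size the workspace in bytes and allocate one-byte elements.
workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
workspace_ptr = ctx.template Alloc<uint8_t>(&workspace_tensor);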
@@ -594,12 +594,9 @@ void BatchNormGradRawKernel(const Context &ctx,
                 /*dBnScaleBiasDesc=*/bn_param_desc_,
                 /*bnScaleData=*/scale.template data<BatchNormParamType<T>>(),
                 /*bnBiasData=*/nullptr,
-                /*dBnScaleData=*/d_scale
-                    ->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace()),
-                /*dBnBiasData=*/d_bias
-                    ->template mutable_data<BatchNormParamType<T>>(
-                        ctx.GetPlace()),
+                /*dBnScaleData=*/ctx.template Alloc<BatchNormParamType<T>>(
+                    d_scale),
+                /*dBnBiasData=*/ctx.template Alloc<BatchNormParamType<T>>(d_bias),
                 /*epsilon=*/epsilon,
                 /*savedMean=*/saved_mean_data,
                 /*savedInvVariance=*/saved_var_data,
@@ -626,10 +623,8 @@ void BatchNormGradRawKernel(const Context &ctx,
               H * W * D,
               epsilon,
               transformed_d_x.template data<T>(),
-              d_scale->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
-              d_bias->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()));
+              ctx.template Alloc<BatchNormParamType<T>>(d_scale),
+              ctx.template Alloc<BatchNormParamType<T>>(d_bias));
         } else {
           BNBackward<T,
                      block,
@@ -644,10 +639,8 @@ void BatchNormGradRawKernel(const Context &ctx,
               H * W * D,
               epsilon,
               transformed_d_x.template data<T>(),
-              d_scale->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()),
-              d_bias->template mutable_data<BatchNormParamType<T>>(
-                  ctx.GetPlace()));
+              ctx.template Alloc<BatchNormParamType<T>>(d_scale),
+              ctx.template Alloc<BatchNormParamType<T>>(d_bias));
         }
         // TODO(wangran16): wait for MIOpen to improve the performance of BN
@@ -682,10 +675,8 @@ void BatchNormGradRawKernel(const Context &ctx,
             ctx.template Alloc<T>(&transformed_d_x),
             bn_param_desc_,
             scale.template data<BatchNormParamType<T>>(),
-            d_scale->template mutable_data<BatchNormParamType<T>>(
-                ctx.GetPlace()),
-            d_bias->template mutable_data<BatchNormParamType<T>>(
-                ctx.GetPlace()),
+            ctx.template Alloc<BatchNormParamType<T>>(d_scale),
+            ctx.template Alloc<BatchNormParamType<T>>(d_bias),
             epsilon,
             saved_mean_data,
             saved_var_data));
...
@@ -439,11 +439,11 @@ void BatchNormKernel(const Context &ctx,
       // Run training mode.
       // obtain running mean and running inv var, and there is no need
       // to initialize them.
-      mean_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-      variance_out->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-      saved_mean->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
-      saved_variance->mutable_data<BatchNormParamType<T>>(ctx.GetPlace());
+      ctx.template Alloc<BatchNormParamType<T>>(mean_out);
+      ctx.template Alloc<BatchNormParamType<T>>(variance_out);
+      ctx.template Alloc<BatchNormParamType<T>>(saved_mean);
+      ctx.template Alloc<BatchNormParamType<T>>(saved_variance);
       if ((N * H * W * D) == 1) {
         // Only 1 element in normalization dimension,
@@ -497,10 +497,10 @@ void BatchNormKernel(const Context &ctx,
                 /*xDesc=*/data_desc_,
                 /*sizeInBytes=*/&reserve_space_size));
-        reserve_space_ptr = reserve_space->mutable_data(
-            ctx.GetPlace(), transformed_x.type(), reserve_space_size);
-        workspace_ptr = workspace_tensor.mutable_data(
-            ctx.GetPlace(), transformed_x.type(), workspace_size);
+        reserve_space->Resize({static_cast<int64_t>(reserve_space_size)});
+        reserve_space_ptr = ctx.template Alloc<T>(reserve_space);
+        workspace_tensor.Resize({static_cast<int64_t>(workspace_size)});
+        workspace_ptr = ctx.template Alloc<T>(&workspace_tensor);
         PADDLE_ENFORCE_GPU_SUCCESS(
             paddle::platform::dynload::cudnnBatchNormalizationForwardTrainingEx(
                 handle,
@@ -518,15 +518,11 @@ void BatchNormKernel(const Context &ctx,
                 scale.template data<BatchNormParamType<T>>(),
                 bias.template data<BatchNormParamType<T>>(),
                 this_factor,
-                mean_out->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                variance_out->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
+                ctx.template Alloc<BatchNormParamType<T>>(mean_out),
+                ctx.template Alloc<BatchNormParamType<T>>(variance_out),
                 epsilon,
-                saved_mean->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                saved_variance->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
+                ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
+                ctx.template Alloc<BatchNormParamType<T>>(saved_variance),
                 nullptr,
                 workspace_ptr,
                 workspace_size,
@@ -621,15 +617,11 @@ void BatchNormKernel(const Context &ctx,
                 scale.template data<BatchNormParamType<T>>(),
                 bias.template data<BatchNormParamType<T>>(),
                 this_factor,
-                mean_out->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                variance_out->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
+                ctx.template Alloc<BatchNormParamType<T>>(mean_out),
+                ctx.template Alloc<BatchNormParamType<T>>(variance_out),
                 epsilon,
-                saved_mean->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace()),
-                saved_variance->template mutable_data<BatchNormParamType<T>>(
-                    ctx.GetPlace())));
+                ctx.template Alloc<BatchNormParamType<T>>(saved_mean),
+                ctx.template Alloc<BatchNormParamType<T>>(saved_variance)));
 #endif
       }
     }
...
@@ -71,15 +71,15 @@ void ConvCudnnGradGradKernel(
   auto dW = filter_grad;
   auto dX = input_grad;
   if (ddO) {
-    ddO->mutable_data<T>(ctx.GetPlace());
+    ctx.template Alloc<T>(ddO);
     phi::funcs::SetConstant<Context, T> set_zero;
     set_zero(ctx, ddO, static_cast<T>(0));
   }
   if (dW) {
-    dW->mutable_data<T>(ctx.GetPlace());
+    ctx.template Alloc<T>(dW);
   }
   if (dX) {
-    dX->mutable_data<T>(ctx.GetPlace());
+    ctx.template Alloc<T>(dX);
   }
   // const T* x = X->data<T>();
@@ -131,7 +131,7 @@ void ConvCudnnGradGradKernel(
     }
     if (dX) {
       ResizeToChannelFirst<Context, T>(ctx, dX, &transformed_dX_channel);
-      transformed_dX_channel.mutable_data<T>(ctx.GetPlace());
+      ctx.template Alloc<T>(&transformed_dX_channel);
     }
   } else {
@@ -186,13 +186,13 @@ void ConvCudnnGradGradKernel(
     transformed_ddX.Resize(new_input_shape);
     transformed_dX.Resize(new_input_shape);
-    transformed_X.mutable_data<T>(ctx.GetPlace());
+    ctx.template Alloc<T>(&transformed_X);
     if (ddX) {
-      transformed_ddX.mutable_data<T>(ctx.GetPlace());
+      ctx.template Alloc<T>(&transformed_ddX);
     }
     if (dX) {
-      transformed_dX.mutable_data<T>(ctx.GetPlace());
+      ctx.template Alloc<T>(&transformed_dX);
     }
     // pad for input
...
@@ -58,10 +58,10 @@ void ConvCudnnGradKernel(const Context& ctx,
                          DenseTensor* input_grad,
                          DenseTensor* filter_grad) {
   if (input_grad) {
-    input_grad->mutable_data<T>(ctx.GetPlace());
+    ctx.template Alloc<T>(input_grad);
   }
   if (filter_grad) {
-    filter_grad->mutable_data<T>(ctx.GetPlace());
+    ctx.template Alloc<T>(filter_grad);
   }
   std::vector<int> dilations = dilations_t;
@@ -204,12 +204,12 @@ void ConvCudnnGradKernel(const Context& ctx,
     }
     DDim new_input_shape(make_ddim(new_input_shape_vec));
     transformed_input.Resize(new_input_shape);
-    transformed_input.mutable_data<T>(ctx.GetPlace());
+    ctx.template Alloc<T>(&transformed_input);
     transformed_input_grad.Resize(new_input_shape);
     if (input_grad) {
-      transformed_input_grad.mutable_data<T>(ctx.GetPlace());
+      ctx.template Alloc<T>(&transformed_input_grad);
     }
     // pad for input
     const int rank = transformed_input_channel.dims().size();
@@ -427,7 +427,7 @@ void ConvCudnnGradKernel(const Context& ctx,
         if (use_addto) {
           DenseTensor temp_tensor(transformed_input_grad.type());
           temp_tensor.Resize(transformed_input_grad.dims());
-          T* temp_tensor_data = temp_tensor.mutable_data<T>(ctx.GetPlace());
+          T* temp_tensor_data = ctx.template Alloc<T>(&temp_tensor);
           workspace_handle.RunFunc(
               [&](void* cudnn_workspace_ptr) {
                 PADDLE_ENFORCE_GPU_SUCCESS(
@@ -513,7 +513,7 @@ void ConvCudnnGradKernel(const Context& ctx,
         axes[i] = i;
       }
-      transformed_input_grad_channel.mutable_data(ctx.GetPlace());
+      ctx.template Alloc<T>(&transformed_input_grad_channel);
       if (transformed_input_channel.dims().size() == 4) {
         paddle::operators::RemovePaddingSlice<Context, T, 4>(
             ctx,
...
@@ -54,7 +54,7 @@ void ConvCudnnKernel(const Context& ctx,
                      int workspace_size_MB,
                      bool exhaustive_search_t,
                      DenseTensor* output) {
-  output->mutable_data<T>(ctx.GetPlace());
+  ctx.template Alloc<T>(output);
   std::vector<int> paddings = paddings_t;
   std::vector<int> dilations = dilations_t;
@@ -170,7 +170,7 @@ void ConvCudnnKernel(const Context& ctx,
     }
     DDim new_input_shape(make_ddim(new_input_shape_vec));
     transformed_input.Resize(new_input_shape);
-    transformed_input.mutable_data<T>(ctx.GetPlace());
+    ctx.template Alloc<T>(&transformed_input);
     const int rank = transformed_input_channel.dims().size();
     T pad_value(0.0);
...
@@ -129,7 +129,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
   DenseTensor col_matrix;
   if (is_expand) {
     col.Resize(col_shape);
-    col.mutable_data<T>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<T>(&col);
     col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
   }
@@ -143,7 +143,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
   if (dX && ddW_in) {
     Tensor ddW;
     ddW.ShareDataWith(*ddW_in).Resize(filter_matrix_shape);
-    dX->mutable_data<T>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<T>(dX);
     DenseTensor transformed_dX(dX->type());
@@ -201,7 +201,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
   // oH, oW)
   // dw convolution double grad: im2col(vol2col) + gemm
   if (dW && ddX) {
-    dW->mutable_data<T>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<T>(dW);
     set_zero(dev_ctx, dW, static_cast<T>(0));
     DenseTensor dW_arr = *dW;
     dW_arr.Resize(filter_matrix_shape);
@@ -244,7 +244,7 @@ void ConvGradGradKernel(const Context& dev_ctx,
   // w/ddw(Cout, Cin, kh, kw)
   // ddy convolution double grad: im2col(vol2col) + gemm
   if (ddY) {
-    ddY->mutable_data<T>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<T>(ddY);
     DenseTensor transformed_ddY(ddY->type());
     if (channel_last) {
...
@@ -128,7 +128,7 @@ void ConvGradKernel(const Context& dev_ctx,
   DenseTensor col_matrix;
   if (is_expand) {
     col.Resize(col_shape);
-    col.mutable_data<T>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<T>(&col);
     col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
   }
@@ -137,7 +137,7 @@ void ConvGradKernel(const Context& dev_ctx,
   auto blas = phi::funcs::GetBlas<Context, T>(dev_ctx);
   if (input_grad) {
-    input_grad->mutable_data<T>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<T>(input_grad);
     DenseTensor transformed_input_grad(input_grad->type());
     if (channel_last) {
       ResizeToChannelFirst<Context, T>(
@@ -203,7 +203,7 @@ void ConvGradKernel(const Context& dev_ctx,
   }
   if (filter_grad) {
-    filter_grad->mutable_data<T>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<T>(filter_grad);
     Tensor filter_grad_ = *filter_grad;
     filter_grad_.Resize(filter_matrix_shape);
     set_zero(dev_ctx, filter_grad, static_cast<T>(0));
...
@@ -44,7 +44,7 @@ void ConvKernel(const Context& dev_ctx,
   // The filter will be reshaped in the calculations,
   // so here use an assignment operation,
   // that avoids modifying the variable in the Scope.
-  output->mutable_data<T>(dev_ctx.GetPlace());
+  dev_ctx.template Alloc<T>(output);
   const bool channel_last = (data_format == "NHWC" || data_format == "NDHWC");
@@ -115,7 +115,7 @@ void ConvKernel(const Context& dev_ctx,
   if (is_expand) {
     // col = context.AllocateTmpTensor<T, DeviceContext>(col_shape, dev_ctx);
     col.Resize(col_shape);
-    col.mutable_data<T>(dev_ctx.GetPlace());
+    dev_ctx.template Alloc<T>(&col);
     col_matrix.ShareDataWith(col);
     col_matrix.Resize(col_matrix_shape);
   }
...
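Taken together, the call sites above follow one mechanical rewrite: `t->mutable_data<T>(ctx.GetPlace())` becomes `ctx.template Alloc<T>(t)` (passing `&t` for tensors held by value), and the byte-sized `mutable_data(place, dtype, nbytes)` overload becomes a `Resize` followed by a context `Alloc`. The toy program below models the shape of the new API; `DenseTensor` and `CPUContext` here are simplified stand-ins for illustration, not the real phi types:

#include <cstdint>
#include <utility>
#include <vector>

// Simplified stand-in for phi::DenseTensor: a shape plus backing storage.
struct DenseTensor {
  std::vector<int64_t> dims;
  std::vector<char> buffer;
  int64_t numel() const {
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    return n;
  }
  void Resize(std::vector<int64_t> d) { dims = std::move(d); }
};

// Simplified stand-in for a phi device context: the context, not the
// tensor, owns the allocation policy (device, allocator, stream, ...).
struct CPUContext {
  template <typename T>
  T* Alloc(DenseTensor* t) const {
    t->buffer.resize(static_cast<size_t>(t->numel()) * sizeof(T));
    return reinterpret_cast<T*>(t->buffer.data());
  }
};

int main() {
  CPUContext ctx;
  DenseTensor out;
  out.Resize({2, 3});
  // Replaces out.mutable_data<float>(place). Inside a kernel template the
  // context type is dependent, so the call must be spelled
  // ctx.template Alloc<T>(&out), which is why the diff uses that form.
  float* data = ctx.Alloc<float>(&out);
  data[0] = 1.0f;
  return 0;
}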