Unverified commit 657c69bc, authored by Wilber, committed by GitHub

Fix fused cuda op's mutable data [3] (#45564)

Parent 9034ca70
@@ -23,17 +23,22 @@ limitations under the License. */
 namespace paddle {
 namespace operators {
+template <typename DeviceContext>
 static void MutableMultiTypeData(
     std::vector<paddle::framework::LoDTensor*>* var,
     const std::vector<int>& data_type,
+    const DeviceContext& dev_ctx,
     const platform::Place& place) {
   for (size_t i = 0; i < var->size(); i++) {
     if (data_type[i] == framework::proto::VarType::FP32) {
-      (*var)[i]->mutable_data<float>(place);
+      dev_ctx.template Alloc<float>((*var)[i],
+                                    (*var)[i]->numel() * sizeof(float));
     } else if (data_type[i] == framework::proto::VarType::FP16) {
-      (*var)[i]->mutable_data<paddle::platform::float16>(place);
+      dev_ctx.template Alloc<paddle::platform::float16>(
+          (*var)[i], (*var)[i]->numel() * sizeof(paddle::platform::float16));
     } else if (data_type[i] == framework::proto::VarType::FP64) {
-      (*var)[i]->mutable_data<double>(place);
+      dev_ctx.template Alloc<double>((*var)[i],
+                                     (*var)[i]->numel() * sizeof(double));
     }
   }
 }
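
Note: the helper above now lets the device context own the allocation and dispatches on the runtime dtype tag; because Alloc is a member template reached through the dependent DeviceContext type, the call has to be spelled dev_ctx.template Alloc<T>(...). What follows is a minimal standalone sketch of that dispatch pattern, using hypothetical ToyTensor/ToyContext stand-ins with plain malloc, not the framework's actual types.

// Minimal standalone sketch (hypothetical ToyTensor/ToyContext stand-ins, not
// Paddle code): allocate each output according to its runtime dtype tag, the
// way the kernel above does with dev_ctx.template Alloc<T>(tensor, bytes).
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <vector>

enum class DType { FP32, FP64 };

struct ToyTensor {
  std::int64_t numel = 0;
  void* data = nullptr;
};

struct ToyContext {
  // Member template; through a dependent context type it must be invoked as
  // dev_ctx.template Alloc<T>(...), mirroring the calls in the hunk above.
  template <typename T>
  T* Alloc(ToyTensor* t, std::size_t bytes) const {
    t->data = std::malloc(bytes);
    return static_cast<T*>(t->data);
  }
};

template <typename DeviceContext>
void AllocMultiType(std::vector<ToyTensor*>* outs,
                    const std::vector<DType>& dtypes,
                    const DeviceContext& dev_ctx) {
  for (std::size_t i = 0; i < outs->size(); ++i) {
    ToyTensor* t = (*outs)[i];
    if (dtypes[i] == DType::FP32) {
      dev_ctx.template Alloc<float>(t, t->numel * sizeof(float));
    } else if (dtypes[i] == DType::FP64) {
      dev_ctx.template Alloc<double>(t, t->numel * sizeof(double));
    }
  }
}

int main() {
  ToyContext ctx;
  ToyTensor a, b;
  a.numel = 4;
  b.numel = 4;
  std::vector<ToyTensor*> outs = {&a, &b};
  AllocMultiType(&outs, {DType::FP32, DType::FP64}, ctx);
  std::free(a.data);
  std::free(b.data);
  return 0;
}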
@@ -52,8 +57,9 @@ class FusionGroupKernel : public framework::OpKernel<T> {
     size_t num_outs = outs.size();
     auto place = ctx.GetPlace();
-    MutableMultiTypeData(&outs, outs_dtype, place);
+    auto& dev_ctx = ctx.template device_context<DeviceContext>();
+    MutableMultiTypeData(&outs, outs_dtype, dev_ctx, place);
     std::string func_name = ctx.Attr<std::string>("func_name");
     platform::DeviceCode* dev_code =
......
@@ -30,7 +30,8 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
   void Compute(const framework::ExecutionContext& ctx) const override {
     auto ins = ctx.MultiInput<framework::Tensor>("X");
     auto* out = ctx.Output<framework::Tensor>("Out");
-    out->mutable_data<T>(ctx.GetPlace());
+    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
+    dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
     auto odims = out->dims();
     std::vector<int> trans_axis = ctx.Attr<std::vector<int>>("trans_axis");
@@ -52,7 +53,6 @@ class TransposeFlattenConcatFusionKernel : public framework::OpKernel<T> {
         platform::dynload::cudnnCreateTensorDescriptor(&out_desc));
     cudnnDataType_t cudnn_dtype = CudnnDataType<T>::type;
-    auto& dev_ctx = ctx.template device_context<phi::GPUContext>();
     auto handle = dev_ctx.cudnn_handle();
     T* odata = out->data<T>();
......
@@ -44,7 +44,8 @@ class SkipLayerNormKernel : public framework::OpKernel<T> {
     auto *out = context.Output<framework::Tensor>("Out");
     out->Resize(X->dims());
-    auto *output_d = out->mutable_data<T>(context.GetPlace());
+    auto &dev_ctx = context.template device_context<phi::GPUContext>();
+    auto *output_d = dev_ctx.Alloc<T>(out, out->numel() * sizeof(T));
     size_t num = 1;
     for (size_t i = 0; i < X->dims().size(); i++) {
......
@@ -81,7 +81,7 @@ class YoloBoxHeadKernel : public framework::OpKernel<T> {
     const int grid_size_y = h;
     const int anchors_num = anchors.size() / 2;
     const T* input_data = x->data<T>();
-    T* output_data = out->mutable_data<T>(context.GetPlace());
+    T* output_data = device_ctx.Alloc<T>(out, out->numel() * sizeof(T));
     auto stream = device_ctx.stream();
     const int volume = x_dims[1] * h * w;
     dim3 block(16, 16, 4);
......
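
The remaining hunks are the same mechanical rewrite: out->mutable_data<T>(place) becomes dev_ctx.Alloc<T>(out, out->numel() * sizeof(T)), which returns the typed pointer, so the device context has to be fetched before the allocation (which is why the later duplicate dev_ctx declaration in the cuDNN kernel above is dropped). Below is a minimal standalone sketch of that call-site shape, again with hypothetical toy types rather than Paddle's framework::Tensor and phi::GPUContext.

// Hypothetical stand-ins for framework::Tensor and phi::GPUContext; the point
// is only the call-site shape: obtain the context first, then let it allocate
// the output and hand back a typed pointer (old style: out->mutable_data<T>(place)).
#include <cstddef>
#include <cstdint>
#include <cstdlib>

struct ToyTensor {
  std::int64_t numel = 0;
  void* data = nullptr;
};

struct ToyGPUContext {
  template <typename T>
  T* Alloc(ToyTensor* t, std::size_t bytes) const {
    t->data = std::malloc(bytes);
    return static_cast<T*>(t->data);
  }
};

template <typename T>
void ComputeLike(const ToyGPUContext& dev_ctx, ToyTensor* out) {
  // The context owns the allocation and returns T*, replacing mutable_data<T>().
  T* output_d = dev_ctx.Alloc<T>(out, out->numel * sizeof(T));
  for (std::int64_t i = 0; i < out->numel; ++i) output_d[i] = T(0);
}

int main() {
  ToyGPUContext dev_ctx;
  ToyTensor out;
  out.numel = 8;
  ComputeLike<float>(dev_ctx, &out);
  std::free(out.data);
  return 0;
}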