未验证 提交 a001f263 编写于 作者: C Chen Weihang 提交者: GitHub

[Phi] Migrate load kernel (#45891)

* migrate load kernel

* remove load op

* fix test failed
上级 3bad26ec
...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,10 +12,10 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/load_op.h"
#include <string> #include <string>
#include "paddle/fluid/framework/op_registry.h"
namespace paddle { namespace paddle {
namespace operators { namespace operators {
...@@ -65,12 +65,3 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker { ...@@ -65,12 +65,3 @@ class LoadOpProtoMaker : public framework::OpProtoAndCheckerMaker {
namespace ops = paddle::operators; namespace ops = paddle::operators;
REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker); REGISTER_OPERATOR(load, ops::LoadOp, ops::LoadOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
load,
ops::LoadOpKernel<phi::CPUContext, float>,
ops::LoadOpKernel<phi::CPUContext, double>,
ops::LoadOpKernel<phi::CPUContext, paddle::platform::bfloat16>,
ops::LoadOpKernel<phi::CPUContext, int>,
ops::LoadOpKernel<phi::CPUContext, int8_t>,
ops::LoadOpKernel<phi::CPUContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/load_op.h"
namespace ops = paddle::operators;
REGISTER_OP_CUDA_KERNEL(load,
ops::LoadOpKernel<phi::GPUContext, float>,
ops::LoadOpKernel<phi::GPUContext, double>,
ops::LoadOpKernel<phi::GPUContext, int>,
ops::LoadOpKernel<phi::GPUContext, int8_t>,
ops::LoadOpKernel<phi::GPUContext, int64_t>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include <fstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LoadOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = ctx.Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin),
true,
platform::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
auto out_var_name = ctx.OutputNames("Out").data();
auto *out_var = ctx.OutputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var,
platform::errors::InvalidArgument(
"The variable %s to be loaded cannot be found.", out_var_name));
if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var, ctx);
} else if (out_var->IsType<phi::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Load operator only supports loading LoDTensor and SelectedRows "
"variable, %s has wrong type",
out_var_name));
}
}
void LoadLodTensor(std::istream &fin,
const platform::Place &place,
framework::Variable *var,
const framework::ExecutionContext &ctx) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
auto seek = ctx.Attr<int64_t>("seek");
if (seek != -1) {
PADDLE_ENFORCE_GE(seek,
0,
platform::errors::InvalidArgument(
"seek witn tensor must great than or equal to 0"));
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
paddle::framework::DeserializeFromStream(
fin, tensor, dev_ctx, seek, shape);
} else {
paddle::framework::DeserializeFromStream(fin, tensor, dev_ctx);
}
auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
auto in_dtype = framework::TransToProtoVarType(tensor->dtype());
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(
in_kernel_type, out_kernel_type, *tensor, &fp16_tensor);
// reset output tensor
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(std::istream &fin,
const platform::Place &place,
framework::Variable *var) const {
auto *selectedRows = var->GetMutable<phi::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
}
};
} // namespace operators
} // namespace paddle
...@@ -12,7 +12,113 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. ...@@ -12,7 +12,113 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#include "paddle/fluid/operators/load_op.h" #include <fstream>
#include <string>
#include <vector>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/device_context.h"
#include "paddle/fluid/platform/profiler/event_tracing.h"
namespace paddle {
namespace operators {
template <typename DeviceContext, typename T>
class LoadOpKernel : public framework::OpKernel<T> {
public:
void Compute(const framework::ExecutionContext &ctx) const override {
auto place = ctx.GetPlace();
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
auto filename = ctx.Attr<std::string>("file_path");
std::ifstream fin(filename, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin),
true,
platform::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
filename));
auto out_var_name = ctx.OutputNames("Out").data();
auto *out_var = ctx.OutputVar("Out");
PADDLE_ENFORCE_NOT_NULL(
out_var,
platform::errors::InvalidArgument(
"The variable %s to be loaded cannot be found.", out_var_name));
if (out_var->IsType<framework::LoDTensor>()) {
LoadLodTensor(fin, place, out_var, ctx);
} else if (out_var->IsType<phi::SelectedRows>()) {
LoadSelectedRows(fin, place, out_var);
} else {
PADDLE_THROW(platform::errors::InvalidArgument(
"Load operator only supports loading LoDTensor and SelectedRows "
"variable, %s has wrong type",
out_var_name));
}
}
void LoadLodTensor(std::istream &fin,
const platform::Place &place,
framework::Variable *var,
const framework::ExecutionContext &ctx) const {
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
auto *tensor = var->GetMutable<framework::LoDTensor>();
auto seek = ctx.Attr<int64_t>("seek");
if (seek != -1) {
PADDLE_ENFORCE_GE(seek,
0,
platform::errors::InvalidArgument(
"seek witn tensor must great than or equal to 0"));
auto shape = ctx.Attr<std::vector<int64_t>>("shape");
paddle::framework::DeserializeFromStream(
fin, tensor, dev_ctx, seek, shape);
} else {
paddle::framework::DeserializeFromStream(fin, tensor, dev_ctx);
}
auto load_as_fp16 = ctx.Attr<bool>("load_as_fp16");
auto in_dtype = framework::TransToProtoVarType(tensor->dtype());
auto out_dtype = load_as_fp16 ? framework::proto::VarType::FP16 : in_dtype;
if (in_dtype != out_dtype) {
// convert to float16 tensor
auto in_kernel_type = framework::OpKernelType(in_dtype, place);
auto out_kernel_type = framework::OpKernelType(out_dtype, place);
framework::LoDTensor fp16_tensor;
// copy LoD info to the new tensor
fp16_tensor.set_lod(tensor->lod());
framework::TransDataType(
in_kernel_type, out_kernel_type, *tensor, &fp16_tensor);
// reset output tensor
var->Clear();
tensor = var->GetMutable<framework::LoDTensor>();
tensor->set_lod(fp16_tensor.lod());
tensor->ShareDataWith(fp16_tensor);
}
}
void LoadSelectedRows(std::istream &fin,
const platform::Place &place,
framework::Variable *var) const {
auto *selectedRows = var->GetMutable<phi::SelectedRows>();
// get device context from pool
platform::DeviceContextPool &pool = platform::DeviceContextPool::Instance();
auto &dev_ctx = *pool.Get(place);
framework::DeserializeFromStream(fin, selectedRows, dev_ctx);
selectedRows->SyncIndex();
}
};
} // namespace operators
} // namespace paddle
namespace ops = paddle::operators; namespace ops = paddle::operators;
......
/* Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#ifdef PADDLE_WITH_XPU
#include "paddle/fluid/operators/load_op.h"
namespace ops = paddle::operators;
REGISTER_OP_XPU_KERNEL(
load,
ops::LoadOpKernel<paddle::platform::XPUDeviceContext, float>,
ops::LoadOpKernel<paddle::platform::XPUDeviceContext, double>,
ops::LoadOpKernel<paddle::platform::XPUDeviceContext, int>,
ops::LoadOpKernel<paddle::platform::XPUDeviceContext, int8_t>,
ops::LoadOpKernel<paddle::platform::XPUDeviceContext, int64_t>);
#endif // PADDLE_WITH_XPU
...@@ -21,7 +21,9 @@ USE_OP_ITSELF(save); ...@@ -21,7 +21,9 @@ USE_OP_ITSELF(save);
PD_DECLARE_KERNEL(save, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(save, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(save_sr, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(save_sr, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(cast, CPU, ALL_LAYOUT); PD_DECLARE_KERNEL(cast, CPU, ALL_LAYOUT);
USE_CPU_ONLY_OP(load); USE_OP_ITSELF(load);
PD_DECLARE_KERNEL(load, CPU, ALL_LAYOUT);
PD_DECLARE_KERNEL(load_sr, CPU, ALL_LAYOUT);
TEST(SaveLoadOp, CPU) { TEST(SaveLoadOp, CPU) {
paddle::framework::Scope scope; paddle::framework::Scope scope;
......
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/load_kernel.h"
#include <fstream>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/serialization.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace phi {
template <typename T, typename Context>
void LoadKernel(const Context& dev_ctx,
const std::string& file_path,
int64_t seek,
const std::vector<int64_t>& shape,
bool load_as_fp16,
DenseTensor* out) {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin),
true,
phi::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
file_path));
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"The variable to be loaded cannot be found."));
if (seek != -1) {
PADDLE_ENFORCE_GE(seek,
0,
phi::errors::InvalidArgument(
"seek witn tensor must great than or equal to 0"));
DeserializeFromStream(fin, out, dev_ctx, seek, shape);
} else {
DeserializeFromStream(fin, out, dev_ctx);
}
auto in_dtype = out->dtype();
auto out_dtype = load_as_fp16 ? DataType::FLOAT16 : in_dtype;
if (in_dtype != out_dtype) {
CastKernel<T>(dev_ctx, *out, out_dtype, out);
}
}
} // namespace phi
PD_REGISTER_KERNEL(load, CPU, ALL_LAYOUT, phi::LoadKernel, float) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(load, GPU, ALL_LAYOUT, phi::LoadKernel, float) {}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(load, XPU, ALL_LAYOUT, phi::LoadKernel, float) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/dense_tensor.h"
namespace phi {
template <typename T, typename Context>
void LoadKernel(const Context& dev_ctx,
const std::string& file_path,
int64_t seek,
const std::vector<int64_t>& shape,
bool load_as_fp16,
DenseTensor* out);
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/kernels/selected_rows/load_kernel.h"
#include <fstream>
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/serialization.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void LoadKernel(const Context& dev_ctx,
const std::string& file_path,
int64_t seek,
const std::vector<int64_t>& shape,
bool load_as_fp16,
SelectedRows* out) {
// FIXME(yuyang18): We save variable to local file now, but we should change
// it to save an output stream.
std::ifstream fin(file_path, std::ios::binary);
PADDLE_ENFORCE_EQ(static_cast<bool>(fin),
true,
phi::errors::Unavailable(
"Load operator fail to open file %s, please check "
"whether the model file is complete or damaged.",
file_path));
PADDLE_ENFORCE_NOT_NULL(out,
phi::errors::InvalidArgument(
"The variable to be loaded cannot be found."));
DeserializeFromStream(fin, out, dev_ctx);
}
} // namespace sr
} // namespace phi
PD_REGISTER_KERNEL(load_sr, CPU, ALL_LAYOUT, phi::sr::LoadKernel, float) {}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(load_sr, GPU, ALL_LAYOUT, phi::sr::LoadKernel, float) {}
#endif
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#pragma once
#include <string>
#include "paddle/phi/core/selected_rows.h"
namespace phi {
namespace sr {
template <typename T, typename Context>
void LoadKernel(const Context& dev_ctx,
const std::string& file_path,
int64_t seek,
const std::vector<int64_t>& shape,
bool load_as_fp16,
SelectedRows* out);
} // namespace sr
} // namespace phi
// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "paddle/phi/core/compat/op_utils.h"
namespace phi {
KernelSignature LoadOpArgumentMapping(const ArgumentMappingContext& ctx) {
if (ctx.IsDenseTensorOutput("Out")) {
return KernelSignature(
"load", {}, {"file_path", "seek", "shape", "load_as_fp16"}, {"Out"});
} else if (ctx.IsSelectedRowsOutput("Out")) {
return KernelSignature(
"load_sr", {}, {"file_path", "seek", "shape", "load_as_fp16"}, {"Out"});
} else {
return KernelSignature("unregistered", {}, {}, {});
}
}
} // namespace phi
PD_REGISTER_ARG_MAPPING_FN(load, phi::LoadOpArgumentMapping);
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册