Unverified commit bf50784c, authored by Ruibiao Chen, committed by GitHub

New executor static build for fluid kernel (#50670)

* Check structured kernel for new executor static build

* Update code

* Ready for resnet50

* Move transfer_dtype to phi

* Ready for transformer

* Fix CI errors

* Fix layer_norm InferMeta

* Remove layer_norm infermeta fix
Parent 819f8939
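The feature introduced here is gated by the FLAGS_new_executor_static_build flag added in this commit. Below is a minimal sketch of how a static-graph program could opt in; the toy network and feed shapes are illustrative assumptions and not part of this commit, while the flag name, set_flags, and the static Executor usage follow the unit test updated at the end of this diff.

import numpy as np
import paddle
from paddle.framework import set_flags

paddle.enable_static()

# Opt in to the static build: the InterpreterCore is constructed ahead of time
# without running kernels; outputs are "fake initialized" from kernel metadata.
set_flags({'FLAGS_new_executor_static_build': 1})

# A small illustrative network (assumption, not from this commit).
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
with paddle.static.program_guard(main_program, startup_program):
    image = paddle.static.data(name='image', shape=[-1, 784], dtype='float32')
    label = paddle.static.data(name='label', shape=[-1, 1], dtype='int64')
    prediction = paddle.static.nn.fc(image, size=10)
    loss = paddle.mean(
        paddle.nn.functional.cross_entropy(input=prediction, label=label)
    )
    paddle.optimizer.Adam(learning_rate=0.001).minimize(loss)

exe = paddle.static.Executor()
with paddle.static.scope_guard(paddle.static.Scope()):
    exe.run(startup_program)
    feed = {
        'image': np.random.rand(4, 784).astype('float32'),
        'label': np.random.randint(0, 10, size=[4, 1]).astype('int64'),
    }
    loss_val = exe.run(main_program, feed=feed, fetch_list=[loss])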
......@@ -176,33 +176,10 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
new_op_func_node.input_index["X"] = {var_scope_->VarId(var_name)};
new_op_func_node.output_index["Out"] = {var_scope_->VarId(new_var_name)};
if (!run_phi_kernel) {
op_with_kernel->ChooseKernel(exec_ctx);
new_op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
new_op_func_node.kernel_func_(exec_ctx);
} else {
new_op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context);
if (!skip_run) {
(*new_op_func_node.phi_kernel_)(&phi_kernel_context);
} else {
FakeInitializeOutputs(new_op_func_node.phi_kernel_,
op_with_kernel->PhiKernelSignature(),
&phi_kernel_context);
}
}
new_op_func_node.dev_ctx_ = dev_ctx;
new_op_func_node.operator_base_ = op;
const phi::Place& place = dev_ctx->GetPlace();
// NOTE(winter-wang): in npu and custom device, the D2H kernel is asynchronous,
// so explicit synchronization is needed.
if ((platform::is_npu_place(place) || platform::is_custom_place(place)) &&
op_type == kMemcpyD2H) {
dev_ctx->Wait();
}
if (platform::is_cpu_place(place)) {
new_op_func_node.type_ = OpFuncType::kCpuSync;
} else if (platform::is_gpu_place(place)) {
......@@ -218,8 +195,35 @@ void DataTranferHelper::RunAndConstructOpFuncNode(
// Memcpy in npu and custom devices is asynchronous
new_op_func_node.type_ = OpFuncType::kGpuAsync;
}
new_op_func_node.dev_ctx_ = dev_ctx;
new_op_func_node.operator_base_ = op;
if (!run_phi_kernel) {
op_with_kernel->ChooseKernel(exec_ctx);
new_op_func_node.kernel_func_ = *op_with_kernel->kernel_func();
new_op_func_node.kernel_func_(exec_ctx);
} else {
new_op_func_node.phi_kernel_ = op_with_kernel->PhiKernel();
if (skip_run) {
FakeInitializeOutputsForFunctionKernel(
*(new_op_func_node.phi_kernel_),
*(op_with_kernel->PhiKernelSignature()),
runtime_context,
*dev_ctx);
} else {
phi::KernelContext phi_kernel_context;
op_with_kernel->BuildPhiKernelContext(
runtime_context, dev_ctx, &phi_kernel_context);
(*new_op_func_node.phi_kernel_)(&phi_kernel_context);
}
}
// NOTE(winter-wang): in npu and custom device, the D2H kernel is asynchronous,
// so explicit synchronization is needed.
if ((platform::is_npu_place(place) || platform::is_custom_place(place)) &&
op_type == kMemcpyD2H) {
dev_ctx->Wait();
}
VLOG(3) << "Run " << op_type << " done.";
new_op_func_nodes->emplace_back(std::move(new_op_func_node));
......
......@@ -42,18 +42,18 @@ class DataTranferHelper {
std::vector<OpFuncNode>* new_op_func_nodes,
bool use_local_scope,
bool is_fetch_v2,
bool skip_run = false);
bool static_build = false);
void RunAndConstructShareNode(const std::string& src_var_name,
const std::string& dst_var_name,
std::vector<OpFuncNode>* op_func_nodes,
bool skip_run = false);
bool static_build = false);
void RunAndConstructOpFuncNode(const std::shared_ptr<OperatorBase>& op,
const std::string& var_name,
const std::string& new_var_name,
std::vector<OpFuncNode>* op_func_nodes,
bool skip_run = false);
bool static_build = false);
private:
platform::Place place_;
......@@ -69,7 +69,7 @@ void ApplyDataTransform(const OpKernelType& expected_kernel_key,
OpFuncNode* op_func_node,
std::vector<OpFuncNode>* op_func_nodes,
bool use_local_scope = true,
bool skip_run = false);
bool static_build = false);
void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
const platform::Place& place,
......@@ -78,7 +78,7 @@ void HandleComplexGradToRealGrad(const OpFuncNode& op_func_node,
VariableScope* var_scope,
std::vector<OpFuncNode>* op_func_nodes,
framework::Scope* local_scope,
bool skip_run = false);
bool static_build = false);
inline bool need_device_transform(const phi::KernelKey& kernel_type_for_var,
const phi::DenseTensor* tensor,
......
......@@ -65,12 +65,16 @@ class AsyncWorkQueue {
std::unique_ptr<WorkQueueGroup> queue_group_;
};
bool BlockCanBeStaticBuilt(const framework::BlockDesc& block);
bool IsCommunicationOp(const std::string& op_name);
bool IsCommunicationOp(const Instruction& instr);
bool IsCpuOp(const Instruction& instr);
bool IsGradOp(const std::string& op_name);
bool IsMemcpyD2H(const Instruction& instr);
bool IsMemcpyH2D(const Instruction& instr);
......@@ -82,23 +86,30 @@ bool IsSupportedHeterPlace(const phi::Place& place);
void AddFetch(const std::vector<std::string>& fetch_names,
framework::BlockDesc* block);
bool BuildOpFuncList(const platform::Place& place,
void BuildOpFuncList(const platform::Place& place,
const framework::BlockDesc& block,
const std::set<std::string>& skip_gc_vars,
std::vector<OpFuncNode>* vec_func_list,
VariableScope* scope,
const ExecutionConfig& execution_config,
bool use_local_scope = true);
bool use_local_scope = true,
bool static_build = false);
void BuildVariableScope(const framework::BlockDesc& block,
const ExecutionConfig& execution_config,
VariableScope* var_scope);
void LogDeviceMemoryStats(const platform::Place& place);
void FakeInitializeOutputsForFunctionKernel(
const phi::Kernel& phi_kernel,
const phi::KernelSignature& kernel_sig,
const RuntimeContext& ctx,
const platform::DeviceContext& dev_ctx);
void FakeInitializeOutputs(phi::Kernel* phi_kernel,
phi::KernelSignature* kernel_sig,
phi::KernelContext* phi_kernel_context);
void FakeInitializeOutputsForStructureKernel(
const framework::OpKernelType& op_kernel_type,
ExecutionContext* execution_context);
void LogDeviceMemoryStats(const platform::Place& place);
} // namespace interpreter
} // namespace framework
......
......@@ -38,6 +38,10 @@ PADDLE_DEFINE_EXPORTED_bool(
new_executor_serial_run,
false,
"Enable serial execution for standalone executor, used for debug.");
PADDLE_DEFINE_EXPORTED_bool(
new_executor_static_build,
false,
"Build the interpreterCore statically without running kernels.");
PADDLE_DEFINE_EXPORTED_bool(new_executor_use_inplace,
false,
"Use inplace in new executor");
......@@ -117,6 +121,9 @@ InterpreterCore::InterpreterCore(const platform::Place& place,
var_scope_(scope) {
VLOG(4) << "InterpreterCore(): " << this << " on " << place_;
static_build_ = FLAGS_new_executor_static_build &&
interpreter::BlockCanBeStaticBuilt(block);
exception_notifier_ = main_thread_blocker_.RegisterEvent(kExceptionCaught);
completion_notifier_ = main_thread_blocker_.RegisterEvent(kTaskCompletion);
......@@ -275,20 +282,21 @@ paddle::framework::FetchList InterpreterCore::Run(
block_, execution_config_, &var_scope_);
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
auto skip_run = paddle::framework::interpreter::BuildOpFuncList(
paddle::framework::interpreter::BuildOpFuncList(
place_,
block_,
execution_config_.skip_gc_vars,
&op_func_nodes,
&var_scope_,
execution_config_,
HasLocalScope());
HasLocalScope(),
static_build_);
SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph
Convert(&op_func_nodes);
is_build_ = true;
UpdateSyncOpNum();
if (skip_run) {
if (static_build_) {
VLOG(4) << "RUN impl";
RunImpl();
}
......@@ -1270,20 +1278,21 @@ void InterpreterCore::Prepare(const std::vector<std::string>& feed_names,
block_, execution_config_, &var_scope_);
FeedInput();
std::vector<paddle::framework::OpFuncNode> op_func_nodes;
auto skip_run = paddle::framework::interpreter::BuildOpFuncList(
paddle::framework::interpreter::BuildOpFuncList(
place_,
block_,
execution_config_.skip_gc_vars,
&op_func_nodes,
&var_scope_,
execution_config_,
HasLocalScope());
HasLocalScope(),
static_build_);
SetFeedVarsInplaceSkip(feed_names);
// convert vec func_list to graph
Convert(&op_func_nodes);
UpdateSyncOpNum();
is_build_ = true;
if (skip_run) {
if (static_build_) {
VLOG(4) << "RUN impl";
RunImpl();
}
......
......@@ -133,6 +133,7 @@ class InterpreterCore {
private:
bool is_build_{false};
bool static_build_{false};
const platform::Place place_;
const BlockDesc& block_; // not owned
......
......@@ -12,13 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include <memory>
#include "paddle/fluid/framework/convert_utils.h"
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/platform/float16.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/unary.h"
#ifdef PADDLE_WITH_MLU
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#endif
......@@ -89,13 +91,6 @@ class CastOp : public framework::OperatorWithKernel {
using framework::OperatorWithKernel::OperatorWithKernel;
protected:
void InferShape(framework::InferShapeContext *context) const override {
OP_INOUT_CHECK(context->HasInput("X"), "Input", "X", "cast");
OP_INOUT_CHECK(context->HasOutput("Out"), "Output", "Out", "cast");
context->SetOutputDim("Out", context->GetInputDim("X"));
context->ShareLoD("X", "Out");
}
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
// CastOp kernel's device type is decided by input tensor place
......@@ -150,13 +145,18 @@ class CastOp : public framework::OperatorWithKernel {
namespace ops = paddle::operators;
using CPU = phi::CPUContext;
DECLARE_INFER_SHAPE_FUNCTOR(cast,
CastInferShapeFunctor,
PD_INFER_META(phi::CastInferMeta));
// cast uses a phi kernel, so there is no need to REGISTER_OP_CPU_KERNEL here.
REGISTER_OPERATOR(cast,
ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastCompositeGradOpMaker,
ops::CastOpProtoMaker);
ops::CastOpProtoMaker,
CastInferShapeFunctor);
// [ why register transfer_dtype_op alias with cast_op? ]
// In case of InterpreterCore, if we reuse cast_op, we cannot distinguish
......@@ -165,19 +165,5 @@ REGISTER_OPERATOR(transfer_dtype,
ops::CastOp,
ops::CastOpGradMaker<paddle::framework::OpDesc>,
ops::CastOpGradMaker<paddle::imperative::OpBase>,
ops::CastOpProtoMaker);
REGISTER_OP_CPU_KERNEL(
transfer_dtype,
ops::CastOpKernel<CPU, float>,
ops::CastOpKernel<CPU, double>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int64_t>,
ops::CastOpKernel<CPU, int>,
ops::CastOpKernel<CPU, int16_t>,
ops::CastOpKernel<CPU, bool>,
ops::CastOpKernel<CPU, uint8_t>,
ops::CastOpKernel<CPU, int8_t>,
ops::CastOpKernel<CPU, paddle::platform::float16>,
ops::CastOpKernel<CPU, paddle::platform::bfloat16>,
ops::CastOpKernel<CPU, paddle::platform::complex<float>>,
ops::CastOpKernel<CPU, paddle::platform::complex<double>>);
ops::CastOpProtoMaker,
CastInferShapeFunctor);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/float16.h"
namespace ops = paddle::operators;
namespace plat = paddle::platform;
using CUDA = phi::GPUContext;
// See [ why register transfer_dtype_op alias with cast_op? ] in cast_op.cc
REGISTER_OP_CUDA_KERNEL(transfer_dtype,
ops::CastOpKernel<CUDA, float>,
ops::CastOpKernel<CUDA, double>,
ops::CastOpKernel<CUDA, int>,
ops::CastOpKernel<CUDA, int64_t>,
ops::CastOpKernel<CUDA, int16_t>,
ops::CastOpKernel<CUDA, bool>,
ops::CastOpKernel<CUDA, uint8_t>,
ops::CastOpKernel<CUDA, plat::float16>,
ops::CastOpKernel<CUDA, plat::complex<float>>,
ops::CastOpKernel<CUDA, plat::complex<double>>,
ops::CastOpKernel<CUDA, plat::bfloat16>);
/* Copyright (c) 2016 PaddlePaddle Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#pragma once
#include "paddle/fluid/framework/data_type.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/phi_utils.h"
#include "paddle/phi/common/transform.h"
#include "paddle/phi/kernels/cast_kernel.h"
namespace paddle {
namespace operators {
template <typename InT, typename OutT>
struct CastOpTransformFunctor {
HOSTDEVICE OutT operator()(InT in) const { return static_cast<OutT>(in); }
};
template <typename DeviceContext, typename InT>
struct CastOpFunctor {
const phi::DenseTensor* in_;
phi::DenseTensor* out_;
const DeviceContext& ctx_;
CastOpFunctor(const phi::DenseTensor* in,
phi::DenseTensor* out,
const DeviceContext& ctx)
: in_(in), out_(out), ctx_(ctx) {}
template <typename OutT>
void apply() const {
auto* in_begin = in_->data<InT>();
auto numel = in_->numel();
auto* in_end = in_begin + numel;
auto* out_begin = out_->mutable_data<OutT>(ctx_.GetPlace());
phi::Transform<DeviceContext> trans;
trans(
ctx_, in_begin, in_end, out_begin, CastOpTransformFunctor<InT, OutT>());
}
};
template <typename DeviceContext, typename InT>
class CastOpKernel : public framework::OpKernel<InT> {
public:
void Compute(const framework::ExecutionContext& context) const override {
auto* in = context.Input<phi::DenseTensor>("X");
auto* out = context.Output<phi::DenseTensor>("Out");
auto out_dtype = context.Attr<int>("out_dtype");
auto& dev_ctx = context.device_context<DeviceContext>();
out->mutable_data(dev_ctx.GetPlace(),
static_cast<framework::proto::VarType::Type>(out_dtype));
auto pt_out_dtype = framework::TransToPhiDataType(
static_cast<framework::proto::VarType::Type>(out_dtype));
// call new kernel
phi::CastKernel<InT>(
static_cast<const typename paddle::framework::ConvertToPhiContext<
DeviceContext>::TYPE&>(dev_ctx),
*in,
pt_out_dtype,
out);
}
};
} // namespace operators
} // namespace paddle
......@@ -12,7 +12,6 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License. */
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/mlu/mlu_baseop.h"
#include "paddle/fluid/platform/device/mlu/device_context.h"
......
......@@ -15,7 +15,6 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/platform/device/npu/npu_op_runner.h"
namespace paddle {
......
......@@ -121,29 +121,38 @@ class FeedOp : public framework::OperatorWithKernel {
if (ctx->IsRuntime()) {
framework::Variable* x_var =
PADDLE_GET(framework::Variable*, ctx->GetInputVarPtrs("X")[0]);
framework::Variable* out_var =
PADDLE_GET(framework::Variable*, ctx->GetOutputVarPtrs("Out")[0]);
auto& x = x_var->Get<framework::FeedList>();
int col = ctx->Attrs().Get<int>("col");
auto& feed_item = x[col];
if (feed_item.index() == 0) {
const auto& feed_item = CheckAndGetFeedItem(x, col);
const auto& feed_item = CheckAndGetFeedItem(x, col);
if (feed_item.index() == 0) { // DenseTensor
auto& feed_tensor = PADDLE_GET_CONST(phi::DenseTensor, feed_item);
ctx->SetOutputDim("Out", feed_tensor.dims());
} else if (feed_item.index() == 1) {
phi::DenseTensor* out_tensor = out_var->GetMutable<phi::DenseTensor>();
phi::DenseTensorMeta meta = out_tensor->meta();
meta.dims = feed_tensor.dims();
meta.dtype = feed_tensor.dtype();
meta.layout = feed_tensor.layout();
meta.lod = feed_tensor.lod();
out_tensor->set_meta(meta);
} else if (feed_item.index() == 1) { // Strings
auto& feed_str = PADDLE_GET_CONST(framework::Strings, feed_item);
framework::Variable* out_var =
PADDLE_GET(framework::Variable*, ctx->GetOutputVarPtrs("Out")[0]);
out_var->GetMutable<framework::Strings>()->resize(feed_str.size());
} else {
} else if (feed_item.index() == 2) { // SparseCooTensor
auto& feed_sparse_tensor =
PADDLE_GET_CONST(phi::SparseCooTensor, feed_item);
framework::Variable* out_var =
PADDLE_GET(framework::Variable*, ctx->GetOutputVarPtrs("Out")[0]);
out_var->GetMutable<phi::SparseCooTensor>()->set_meta(
feed_sparse_tensor.meta());
out_var->GetMutable<phi::SparseCooTensor>()->SetCoalesced(
feed_sparse_tensor.coalesced());
out_var->GetMutable<phi::SparseCooTensor>()->SetIndicesDict(
feed_sparse_tensor.GetIndicesDict());
} else {
PADDLE_THROW(
phi::errors::Unimplemented("Only support DenseTensor, Strings, and "
"SparseCooTensor for feed op now."));
}
}
}
......@@ -151,7 +160,23 @@ class FeedOp : public framework::OperatorWithKernel {
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
return phi::KernelKey(framework::proto::VarType::FP32, ctx.GetPlace());
const framework::Variable* x_var = ctx.InputVar("X");
auto& x = x_var->Get<framework::FeedList>();
int col = ctx.Attr<int>("col");
auto& feed_item = x[col];
framework::proto::VarType::Type expected_data_type;
if (feed_item.index() == 0) { // DenseTensor
expected_data_type = framework::TransToProtoVarType(
PADDLE_GET_CONST(phi::DenseTensor, feed_item).dtype());
} else if (feed_item.index() == 2) { // SparseCooTensor
expected_data_type = framework::TransToProtoVarType(
PADDLE_GET_CONST(phi::SparseCooTensor, feed_item).dtype());
} else { // Strings
expected_data_type = framework::proto::VarType::FP32;
}
return phi::KernelKey(expected_data_type, ctx.GetPlace());
}
};
......
......@@ -21,7 +21,6 @@ limitations under the License. */
#include "paddle/fluid/framework/data_type_transform.h"
#include "paddle/fluid/framework/tensor_util.h"
#include "paddle/fluid/operators/cast_op.h"
#include "paddle/fluid/operators/reduce_ops/reduce_op_function.h"
#include "paddle/phi/kernels/funcs/math_function.h"
// only can include the headers in paddle/phi/api dirs
......
......@@ -111,7 +111,9 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
ALL_LAYOUT,
phi::CheckFiniteAndUnscaleKernel,
float,
double) {}
double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(update_loss_scaling,
CPU,
......
......@@ -79,31 +79,37 @@ PD_REGISTER_KERNEL(equal_all,
int,
int64_t,
float,
double) {}
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16) {} \
PD_REGISTER_KERNEL(name##_raw, \
CPU, \
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16) {}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
CPU, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
CPU, \
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
......
......@@ -209,7 +209,11 @@ PD_REGISTER_KERNEL(dropout,
phi::DropoutRawKernel,
float,
double,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
}
PD_REGISTER_KERNEL(
dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) {}
dropout_nd, CPU, ALL_LAYOUT, phi::DropoutNdKernel, float, double) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
}
......@@ -83,4 +83,6 @@ void OneHotRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {}
one_hot_raw, CPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
......@@ -358,7 +358,9 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
float,
double,
phi::dtype::float16,
phi::dtype::bfloat16) {}
phi::dtype::bfloat16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(update_loss_scaling,
GPU,
......
......@@ -90,6 +90,7 @@ PD_REGISTER_KERNEL(dropout,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
}
PD_REGISTER_KERNEL(dropout_nd,
......@@ -101,4 +102,5 @@ PD_REGISTER_KERNEL(dropout_nd,
phi::dtype::bfloat16,
phi::dtype::float16) {
kernel->InputAt(1).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::UINT8);
}
......@@ -91,4 +91,6 @@ void OneHotRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {}
one_hot_raw, GPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
......@@ -95,26 +95,49 @@ inline void CompareAllKernelImpl(const Context& ctx,
} // namespace phi
#ifdef PADDLE_WITH_XPU_KP
PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {}
PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {}
PD_REGISTER_KERNEL(less_than, KPS, ALL_LAYOUT, phi::LessThanKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(less_equal, KPS, ALL_LAYOUT, phi::LessEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(greater_than, KPS, ALL_LAYOUT, phi::GreaterThanKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(
greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {}
PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {}
PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {}
greater_equal, KPS, ALL_LAYOUT, phi::GreaterEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(equal, KPS, ALL_LAYOUT, phi::EqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(not_equal, KPS, ALL_LAYOUT, phi::NotEqualKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(
less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) {}
less_than_raw, KPS, ALL_LAYOUT, phi::LessThanRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(
less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) {}
less_equal_raw, KPS, ALL_LAYOUT, phi::LessEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(
greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) {}
greater_than_raw, KPS, ALL_LAYOUT, phi::GreaterThanRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(
greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) {}
PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) {}
greater_equal_raw, KPS, ALL_LAYOUT, phi::GreaterEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(equal_raw, KPS, ALL_LAYOUT, phi::EqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(
not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) {}
not_equal_raw, KPS, ALL_LAYOUT, phi::NotEqualRawKernel, int) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
#else
......@@ -126,33 +149,39 @@ PD_REGISTER_KERNEL(equal_all,
int,
int64_t,
float,
double) {}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
KPS, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) {} \
PD_REGISTER_KERNEL(name##_raw, \
KPS, \
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) {}
double) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, \
KPS, \
ALL_LAYOUT, \
phi::func##Kernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
KPS, \
ALL_LAYOUT, \
phi::func##RawKernel, \
bool, \
int16_t, \
int, \
int64_t, \
float, \
double, \
phi::dtype::float16, \
phi::dtype::bfloat16) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_than, LessThan)
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
......
......@@ -30,12 +30,18 @@ void OneHotKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
PD_REGISTER_KERNEL(one_hot, CPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
PD_REGISTER_KERNEL(one_hot, GPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
}
#endif
#ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {}
PD_REGISTER_KERNEL(one_hot, XPU, ALL_LAYOUT, phi::OneHotKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::FLOAT32);
}
#endif
......@@ -47,6 +47,8 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
}
#if defined(PADDLE_WITH_CUDA) || defined(PADDLE_WITH_HIP)
......@@ -65,6 +67,8 @@ PD_REGISTER_KERNEL(shape,
phi::dtype::complex<double>,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
}
#endif
......@@ -80,5 +84,7 @@ PD_REGISTER_KERNEL(shape,
double,
phi::dtype::float16) {
kernel->InputAt(0).SetBackend(phi::Backend::ALL_BACKEND);
kernel->OutputAt(0).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::INT32);
}
#endif
......@@ -285,4 +285,6 @@ PD_REGISTER_KERNEL(check_finite_and_unscale,
ALL_LAYOUT,
phi::CheckFiniteAndUnscaleKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(paddle::experimental::DataType::BOOL);
}
......@@ -90,7 +90,9 @@ DEFINE_XPU_COMPARE_KERNEL(GreaterEqual, xpu::broadcast_greater_equal<XPUType>)
} // namespace phi
PD_REGISTER_KERNEL(
less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {}
less_than, XPU, ALL_LAYOUT, phi::LessThanKernel, int, int64_t, float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
PD_REGISTER_KERNEL(less_than_raw,
XPU,
......@@ -98,18 +100,24 @@ PD_REGISTER_KERNEL(less_than_raw,
phi::LessThanRawKernel,
int,
int64_t,
float) {}
float) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL);
}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) {} \
PD_REGISTER_KERNEL(name##_raw, \
XPU, \
ALL_LAYOUT, \
phi::func##RawKernel, \
int, \
int64_t, \
float) {}
#define PD_REGISTER_COMPARE_KERNEL(name, func) \
PD_REGISTER_KERNEL( \
name, XPU, ALL_LAYOUT, phi::func##Kernel, int, int64_t, float) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
} \
PD_REGISTER_KERNEL(name##_raw, \
XPU, \
ALL_LAYOUT, \
phi::func##RawKernel, \
int, \
int64_t, \
float) { \
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::BOOL); \
}
PD_REGISTER_COMPARE_KERNEL(less_equal, LessEqual)
PD_REGISTER_COMPARE_KERNEL(greater_than, GreaterThan)
......
......@@ -62,4 +62,6 @@ void OneHotRawKernel(const Context& dev_ctx,
} // namespace phi
PD_REGISTER_KERNEL(
one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {}
one_hot_raw, XPU, ALL_LAYOUT, phi::OneHotRawKernel, int, int64_t) {
kernel->OutputAt(0).SetDataType(paddle::experimental::DataType::UNDEFINED);
}
......@@ -22,4 +22,7 @@ KernelSignature CastOpArgumentMapping(const ArgumentMappingContext& ctx) {
} // namespace phi
PD_REGISTER_BASE_KERNEL_NAME(transfer_dtype, cast);
PD_REGISTER_ARG_MAPPING_FN(cast, phi::CastOpArgumentMapping);
PD_REGISTER_ARG_MAPPING_FN(transfer_dtype, phi::CastOpArgumentMapping);
......@@ -22,7 +22,7 @@ from paddle.framework import set_flags
paddle.enable_static()
def build_resnet50():
def build_resnet50(use_amp=False):
main_program = paddle.static.Program()
startup_program = paddle.static.Program()
......@@ -36,49 +36,80 @@ def build_resnet50():
loss = paddle.nn.functional.cross_entropy(input=prediction, label=label)
loss = paddle.mean(loss)
adam = paddle.optimizer.Adam(learning_rate=0.001)
if use_amp:
adam = paddle.static.amp.decorate(
optimizer=adam,
init_loss_scaling=1.0,
use_dynamic_loss_scaling=False,
use_pure_fp16=True,
use_fp16_guard=False,
)
adam.minimize(loss)
return main_program, startup_program, loss
build_strategy = paddle.static.BuildStrategy()
build_strategy.enable_addto = True
build_strategy.fuse_elewise_add_act_ops = True
if use_amp:
build_strategy.fuse_bn_act_ops = True
build_strategy.fuse_bn_add_act_ops = True
main_program = paddle.static.CompiledProgram(
main_program, build_strategy=build_strategy
)
class TestAOTChooseKernel(unittest.TestCase):
def test_aot_choose_kernel(self):
if not paddle.fluid.core.is_compiled_with_cuda():
return
return main_program, startup_program, loss, adam
def run(aot_choose_kernel=None):
paddle.seed(2022)
np.random.seed(2022)
main_program, startup_program, loss = build_resnet50()
def run_resnet50(aot_choose_kernel=False, use_amp=False):
paddle.seed(2022)
np.random.seed(2022)
scope = paddle.static.Scope()
exe = paddle.static.Executor()
main_program, startup_program, loss, optimizer = build_resnet50(use_amp)
set_flags({'FLAGS_cudnn_deterministic': 1})
if aot_choose_kernel:
set_flags({'FLAGS_new_executor_static_build': 1})
else:
set_flags({'FLAGS_new_executor_static_build': 0})
place = paddle.CUDAPlace(0)
exe = paddle.static.Executor(place)
scope = paddle.static.Scope()
with paddle.static.scope_guard(scope):
exe.run(startup_program)
set_flags({'FLAGS_cudnn_deterministic': 1})
if aot_choose_kernel:
set_flags({'FLAGS_new_executor_static_build': 1})
for i in range(10):
feed = {
'image': np.random.randint(
0, 256, size=[32, 3, 224, 224]
).astype('float32'),
'label': np.random.randint(0, 1000, size=[32]).astype(
'int64'
),
}
loss_ = exe.run(main_program, feed=feed, fetch_list=[loss])
return loss_
if use_amp:
set_flags({'FLAGS_conv_workspace_size_limit': 1500})
set_flags({'FLAGS_max_inplace_grad_add': 8})
set_flags({'FLAGS_cudnn_batchnorm_spatial_persistent': 1})
loss1 = run(aot_choose_kernel=True)
loss2 = run(aot_choose_kernel=False)
with paddle.static.scope_guard(scope):
exe.run(startup_program)
if use_amp:
optimizer.amp_init(place)
feed_dtype = 'float16' if use_amp else 'float32'
for i in range(1):
feed = {
'image': np.random.randint(
0, 256, size=[32, 3, 224, 224]
).astype(feed_dtype),
'label': np.random.randint(0, 1000, size=[32]).astype('int64'),
}
loss_ = exe.run(main_program, feed=feed, fetch_list=[loss])
return loss_
class TestAOTChooseKernel(unittest.TestCase):
def test_resnet50_aot_choose_kernel(self):
if not paddle.fluid.core.is_compiled_with_cuda():
return
loss1 = run_resnet50(aot_choose_kernel=True)
loss2 = run_resnet50(aot_choose_kernel=False)
self.assertEqual(loss1, loss2)
def test_resnet50_amp_aot_choose_kernel(self):
if not paddle.fluid.core.is_compiled_with_cuda():
return
loss1 = run_resnet50(aot_choose_kernel=True, use_amp=True)
loss2 = run_resnet50(aot_choose_kernel=False, use_amp=True)
self.assertEqual(loss1, loss2)
......