Unverified commit 733eca85, authored by Sonder, committed by GitHub

remove ops from OpsWithFluidKernelNeedMoveToPhi set (#54007)

* remove ops from OpsWithFluidKernelNeedMoveToPhi set

* open static build flag

* OpsWithFluidKernelNeedMoveToPhi

* open new_executor_static_build

* add infermeta for cudnn_lstm

* fix

* update

* fix

* update

* update

* update

* fix pow2 decay

* fix pow2 decay

* recover analysis_predictor.cc

* fix pow2 decay

* fix cudnn lstm

* add output register info for svd

* fix pow2_decay_with_linear_warmup_kernel

* recover test lstm cudnn

* recover svd register codes

* fix register info

* fix reduce sum register info

* add output info for adadelta

* add output info for adadelta

* add output info for adamax

* fix complex abs register info

* add register info for cudnn_lstm_grad

* recover

* fix lstm cudnn

* fix

* fix xpu output register info

* remove std::cout

* add backend

* remove output info in pow2_decay_with_linear_warmup_kernel

* add judgment in TensorShouldBeFakeInitialized

* recover power_

* close new_executor_static_build

* fix set_value_xpu
Parent 59dd97af
...@@ -40,20 +40,6 @@ std::set<std::string> OpsCanSkipedFakeAllocInStaticBuild = {
"fetch_v2",
"nop"};
// Cannot static analysis these Ops' output dtype or backend because their
// kernels have not moved to PHI yet.
std::set<std::string> OpsWithFluidKernelNeedMoveToPhi = {
"cudnn_lstm",
"dequantize",
"distributed_fused_lamb",
"fused_batch_norm_act",
"fused_batch_norm_act_grad",
"fusion_group",
"pow2_decay_with_linear_warmup",
"sequence_mask",
"sequence_pool",
"stft"};
std::set<std::string> StaticBuildBlackList = {
"batch_norm" /*: to handle reserve_space output*/,
"cinn_instruction_run" /*: to handle subgraph infermeta*/,
...@@ -95,8 +81,7 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
bool has_fluid_kernel = OperatorWithKernel::AllOpKernels().count(op_type);
bool has_structured_kernel =
phi::KernelFactory::Instance().HasStructuredKernel(op_type);
bool need_move_to_phi = (has_fluid_kernel || has_structured_kernel) &&
OpsWithFluidKernelNeedMoveToPhi.count(op_type);
bool need_move_to_phi = (has_fluid_kernel || has_structured_kernel);
KernelCode kernel_code =
(in_black_list << 7) + (is_operator_base << 6) + (is_custom_op << 5) +
...@@ -162,6 +147,11 @@ bool TensorShouldBeFakeInitialized(const OperatorBase& op,
return false;
}
if (op_type == "distributed_fused_lamb" && parameter_name == "ParamOut") {
VLOG(2) << "Skip fake initialization for: " << parameter_name;
return false;
}
if (op_type == "fake_quantize_range_abs_max") { if (op_type == "fake_quantize_range_abs_max") {
if (op.Attr<bool>("is_test") && if (op.Attr<bool>("is_test") &&
(parameter_name == "OutScale" || parameter_name == "OutScales")) { (parameter_name == "OutScale" || parameter_name == "OutScales")) {
...@@ -467,11 +457,14 @@ void FakeInitializeOutputsForFunctionKernel( ...@@ -467,11 +457,14 @@ void FakeInitializeOutputsForFunctionKernel(
} else if (op_type == "layer_norm") { } else if (op_type == "layer_norm") {
dtype = InferMPDType(runtime_ctx, "X"); dtype = InferMPDType(runtime_ctx, "X");
} else if (op_type == "reduce_sum") { } else if (op_type == "reduce_sum") {
phi::DataType in_dtype = GetInputDType(runtime_ctx, "X");
int dtype_attr = op.Attr<int>("out_dtype");
if (dtype_attr != -1) {
dtype = phi::TransToPhiDataType(dtype_attr);
if (dtype == DataType::UNDEFINED) {
dtype = in_dtype;
}
} else {
phi::DataType in_dtype = GetInputDType(runtime_ctx, "X");
dtype =
(in_dtype == DataType::BOOL || in_dtype == DataType::INT32)
? DataType::INT64
...@@ -489,7 +482,6 @@ void FakeInitializeOutputsForFunctionKernel(
// analyze layout
phi::DataLayout layout = tensor_arg_def.layout;
FakeInitializeTensorBase(dev_ctx, place, dtype, layout, out_tensor);
}
}
......
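With the blacklist gone, an op whose output dtype or backend cannot be derived from the kernel's data type instead declares that information in its PHI kernel registration hook, which is what the per-kernel hunks further down add. A minimal illustrative sketch of the pattern (the op name `my_op`, `phi::MyOpKernel`, and the argument indices are placeholders, not part of this PR):

```cpp
// Sketch only: modeled on the hooks added in this PR (e.g. cudnn_lstm,
// sequence_mask, adadelta); the op name and indices below are hypothetical.
#include "paddle/phi/core/kernel_registry.h"

PD_REGISTER_KERNEL(my_op, CPU, ALL_LAYOUT, phi::MyOpKernel, float, double) {
  // Pin an input to the CPU backend so the static-build pass knows where
  // the tensor must live.
  kernel->InputAt(1).SetBackend(phi::Backend::CPU);
  // Declare an output whose dtype does not follow the kernel's data type,
  // so its storage can be fake-initialized without running the kernel.
  kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
```

The static-build pass can then consult these argument definitions directly instead of special-casing the op by name.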
...@@ -44,7 +44,6 @@ void TensorCopyImpl(const TENSOR& src,
TensorCopyImpl(src_copy, dst_place, ctx, dst);
return;
}
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
src.check_memory_size();
...@@ -325,7 +324,6 @@ void TensorCopySync(const phi::DenseTensor& src,
<< dst_place;
return;
}
auto size = src.numel() * phi::SizeOf(src.dtype());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
......
...@@ -15,8 +15,12 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h" #include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h" #include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
...@@ -25,75 +29,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM");
OP_INOUT_CHECK(
ctx->HasOutput("StateOut"), "Output", "StateOut", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
auto in_dims = ctx->GetInputDim("Input");
auto init_h_dims = ctx->GetInputDim("InitH");
auto init_c_dims = ctx->GetInputDim("InitC");
PADDLE_ENFORCE_EQ(in_dims.size(),
3,
platform::errors::InvalidArgument(
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
PADDLE_ENFORCE_EQ(init_h_dims.size(),
3,
platform::errors::InvalidArgument(
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d.",
init_h_dims.size()));
if (ctx->HasInput("SequenceLength")) {
auto seq_dims = ctx->GetInputDim("SequenceLength");
PADDLE_ENFORCE_EQ(
in_dims[1],
seq_dims[0],
platform::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1],
seq_dims[0]));
}
PADDLE_ENFORCE_EQ(
in_dims[1],
init_h_dims[1],
platform::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_h_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_h_dims[1] is %d.",
in_dims[1],
init_h_dims[1]));
PADDLE_ENFORCE_EQ(init_c_dims,
init_h_dims,
platform::errors::InvalidArgument(
"The InitC dims and InitH "
"dims should be equal. But "
"received init_c_dims is %d and init_h_dims is %d.",
init_c_dims,
init_h_dims));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("LastH", init_c_dims);
ctx->SetOutputDim("LastC", init_h_dims);
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
...@@ -295,12 +230,18 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(cudnn_lstm,
CudnnLSTMInferShapeFunctor,
PD_INFER_META(phi::CudnnLSTMInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(cudnn_lstm,
ops::CudnnLSTMOp,
ops::CudnnLSTMOpMaker,
ops::CudnnLSTMGradOpMaker<paddle::framework::OpDesc>,
ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>);
ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>,
CudnnLSTMInferShapeFunctor);
REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
// TODO(Shixiaowei02) Add ModifyInput support
......
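Net effect of the fragments above: CudnnLSTMOp no longer carries a hand-written InferShape; shape and dtype inference is delegated to the shared phi::CudnnLSTMInferMeta (added in multiary.cc below) through an infer-shape functor. Condensed into one place, the wiring is:

```cpp
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"

// Bind the operator type "cudnn_lstm" to the shared PHI InferMeta function.
DECLARE_INFER_SHAPE_FUNCTOR(cudnn_lstm,
                            CudnnLSTMInferShapeFunctor,
                            PD_INFER_META(phi::CudnnLSTMInferMeta));

// Pass the functor as the last argument so the framework uses it for
// shape/dtype inference instead of CudnnLSTMOp::InferShape.
REGISTER_OPERATOR(cudnn_lstm,
                  ops::CudnnLSTMOp,
                  ops::CudnnLSTMOpMaker,
                  ops::CudnnLSTMGradOpMaker<paddle::framework::OpDesc>,
                  ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>,
                  CudnnLSTMInferShapeFunctor);
```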
...@@ -270,7 +270,7 @@ static const T *GetInputTensorPtr(const DenseTensor *in_tensor,
PADDLE_ENFORCE_NOT_NULL(
in_tensor,
phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
if (in_tensor->IsInitialized()) {
if (in_tensor->initialized()) {
if (numel) *numel = in_tensor->numel();
return in_tensor->data<T>();
} else {
...@@ -286,7 +286,7 @@ static T *GetSameInOutTensorPtr(const Context &dev_ctx,
const char *in_name,
const char *out_name,
int64_t *numel = nullptr) {
if (in_tensor == nullptr || !in_tensor->IsInitialized()) {
if (in_tensor == nullptr || !in_tensor->initialized()) {
PADDLE_ENFORCE_EQ(
AllowNotExist,
true,
...@@ -1160,7 +1160,7 @@ struct VisitDTypeFunctor {
static std::string GetMinMaxStr(const phi::DenseTensor *x) {
if (x == nullptr) return "null";
if (!x->IsInitialized()) return "not_inited";
if (!x->initialized()) return "not_inited";
if (x->place().GetType() != phi::AllocationType::GPU) return "CPUTensor";
std::string str;
VisitDTypeFunctor functor(x, &str);
...@@ -1354,9 +1354,7 @@ void DistributedFusedLambKernel(
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto stream = dev_ctx.stream();
auto place = dev_ctx.GetPlace();
found_inf->Resize({1});
// Step 1: Get fp16 param and grad tensors
int64_t fp16_numel;
auto *fp16_param_data =
...@@ -1374,7 +1372,6 @@ void DistributedFusedLambKernel(
} else {
fp16_param_data = nullptr;
}
// Step 2: Get fp32 param and grad tensors
int64_t fp32_numel = 0;
auto *fp32_param_data =
...@@ -1389,7 +1386,6 @@ void DistributedFusedLambKernel(
phi::errors::InvalidArgument(
"The element number in FP32FusedParam should be not "
"less than FP16FusedParam."));
fp32_numel -= fp16_numel; // the FP32FusedParam contains fp32 param and
// fp16 master weight
bool has_fp32_param = (fp32_numel > 0);
...@@ -1404,7 +1400,6 @@ void DistributedFusedLambKernel(
phi::errors::InvalidArgument(
"Either FP32FusedGrad or FP16FusedGrad cannot be NULL."));
}
auto numel = fp32_numel + fp16_numel;
VLOG(1) << "numel = " << numel << " , fp32_numel = " << fp32_numel
<< " , fp16_numel = " << fp16_numel;
...@@ -1426,7 +1421,7 @@ void DistributedFusedLambKernel(
acc_step,
phi::errors::InvalidArgument(
"Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
bool is_initialized = acc_step->IsInitialized();
bool is_initialized = acc_step->initialized();
int64_t *acc_step_data;
if (is_initialized) {
acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
...@@ -1437,14 +1432,13 @@ void DistributedFusedLambKernel(
*acc_step_data = 1;
}
int64_t rounded_step = (*acc_step_data) % acc_steps;
float *fp32_acc_grad_data = nullptr;
if (has_fp32_param) {
PADDLE_ENFORCE_NOT_NULL(fp32_acc_grad,
phi::errors::InvalidArgument(
"Output(FP32AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1."));
if (!fp32_acc_grad->IsInitialized()) {
if (!fp32_acc_grad->initialized()) {
fp32_acc_grad->Resize({static_cast<int64_t>(fp32_numel)});
fp32_acc_grad_data = dev_ctx.template Alloc<float>(fp32_acc_grad);
} else {
...@@ -1459,7 +1453,7 @@ void DistributedFusedLambKernel(
phi::errors::InvalidArgument(
"Output(FP16AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1."));
if (!fp16_acc_grad->IsInitialized()) {
if (!fp16_acc_grad->initialized()) {
auto acc_grad_size =
use_master_acc_grad ? (3 * fp16_numel) : fp16_numel;
fp16_acc_grad->Resize({static_cast<int64_t>(acc_grad_size)});
...@@ -1544,11 +1538,9 @@ void DistributedFusedLambKernel(
}
}
}
stop_update->Resize({1});
auto *stop_update_data = dev_ctx.template HostAlloc<bool>(stop_update);
auto *found_inf_cpu = dev_ctx.template HostAlloc<bool>(found_inf);
if (rounded_step != 0) {
*stop_update_data = true;
*found_inf_cpu = false;
...@@ -1599,7 +1591,6 @@ void DistributedFusedLambKernel(
int64_t partial_numel = 0;
auto *moment1_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &moment1, moment1_out, "Moment1", "Moment1Out", &partial_numel);
PADDLE_ENFORCE_EQ(numel % partial_numel,
0,
phi::errors::InvalidArgument(
...@@ -1629,16 +1620,13 @@ void DistributedFusedLambKernel(
"exactly by the device number %d.",
fp16_numel,
num_devices));
auto *moment2_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &moment2, moment2_out, "Moment2", "Moment2Out");
auto *beta1_pow_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &beta1_pow, beta1_pow_out, "Beta1Pow", "Beta1PowOut");
auto *beta2_pow_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &beta2_pow, beta2_pow_out, "Beta2Pow", "Beta2PowOut");
auto *found_inf_data = dev_ctx.template Alloc<bool>(found_inf);
// Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
// max_grad_norm, ring_id,
// use_master_param_norm, is_grad_scaled_by_nranks
...@@ -1668,7 +1656,6 @@ void DistributedFusedLambKernel(
paddle::platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
global_comm = nccl_comm_handle->comm();
global_rank = nccl_comm_handle->rank();
if (local_shard) {
auto *local_nccl_comm_handle =
paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
...@@ -1684,11 +1671,9 @@ void DistributedFusedLambKernel(
local_rank = global_rank;
}
}
memory_utils::Buffer grad_norm_square_buffer(place);
auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
memory_utils::Buffer cub_tmp_buffer(place);
memory_utils::Buffer sum_grad_buffer(place);
float *fp32_sum_grad;
dtype::float16 *fp16_sum_grad;
...@@ -1721,7 +1706,6 @@ void DistributedFusedLambKernel(
fp32_sum_grad = const_cast<float *>(fp32_grad_data);
fp16_sum_grad = const_cast<dtype::float16 *>(fp16_grad_data);
}
float rescale_grad = 1.0f;
if (!is_grad_scaled_by_nranks) {
rescale_grad /= nranks;
...@@ -1845,7 +1829,6 @@ void DistributedFusedLambKernel(
} else {
fp16_scale = cub_tmp_buffer.Alloc<dtype::float16>(1);
}
float clip_scale = 1.0f;
if (is_grad_scaled_by_nranks) {
clip_scale *= nranks;
...@@ -1988,7 +1971,6 @@ void DistributedFusedLambKernel(
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device,
...@@ -2064,9 +2046,7 @@ void DistributedFusedLambKernel(
auto *param_offsets_data = param_offsets.data<int>();
const auto *fp32_partial_offsets_data = fp32_partial_offsets.data<int>();
const auto *fp16_partial_offsets_data = fp16_partial_offsets.data<int>();
auto *step_data = step->data<int64_t>();
VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(param_offsets_data,
param_offsets.numel(),
...@@ -2079,7 +2059,6 @@ void DistributedFusedLambKernel(
<< FlattenToString(fp16_partial_offsets_data,
fp16_partial_offsets.numel(),
fp16_partial_offsets.place());
memory_utils::Buffer trust_ratio_div_buffer(place);
auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
auto fp32_offset = local_rank * fp32_numel_each_device;
...@@ -2178,7 +2157,6 @@ void DistributedFusedLambKernel(
fp16_local_param_num,
param_square_norm + fp16_local_start_idx);
}
MultiTensorL2Norm(place,
stream,
trust_ratio_div,
...@@ -2191,7 +2169,6 @@ void DistributedFusedLambKernel(
fp16_partial_offsets_data,
fp16_local_param_num,
trust_ratio_div_square_norm + fp16_local_start_idx);
VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
<< FlattenToString(trust_ratio_div_square_norm, param_num, place);
if (num_devices > 1) {
...@@ -2296,6 +2273,12 @@ PD_REGISTER_KERNEL(distributed_fused_lamb,
ALL_LAYOUT,
phi::fusion::DistributedFusedLambKernel,
float) {
kernel->InputAt(10).SetBackend(phi::Backend::CPU);
kernel->InputAt(11).SetBackend(phi::Backend::CPU);
kernel->InputAt(12).SetBackend(phi::Backend::CPU);
kernel->InputAt(13).SetBackend(phi::Backend::CPU);
kernel->InputAt(14).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
......
...@@ -978,6 +978,85 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
out->share_lod(*x.at(0));
}
void CudnnLSTMInferMeta(
const MetaTensor& x,
const MetaTensor& init_h,
const MetaTensor& init_c,
const MetaTensor& w,
const paddle::optional<std::vector<const MetaTensor*>>& weight_list,
const MetaTensor& sequence_length,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
MetaTensor* out,
MetaTensor* last_h,
MetaTensor* last_c,
MetaTensor* reserve,
MetaTensor* state_out) {
auto in_dims = x.dims();
auto init_h_dims = init_h.dims();
auto init_c_dims = init_c.dims();
PADDLE_ENFORCE_EQ(in_dims.size(),
3,
phi::errors::InvalidArgument(
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
PADDLE_ENFORCE_EQ(init_h_dims.size(),
3,
phi::errors::InvalidArgument(
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d.",
init_h_dims.size()));
if (sequence_length) {
auto seq_dims = sequence_length.dims();
PADDLE_ENFORCE_EQ(
in_dims[1],
seq_dims[0],
phi::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1],
seq_dims[0]));
}
PADDLE_ENFORCE_EQ(in_dims[1],
init_h_dims[1],
phi::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_h_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_h_dims[1] is %d.",
in_dims[1],
init_h_dims[1]));
PADDLE_ENFORCE_EQ(init_c_dims,
init_h_dims,
phi::errors::InvalidArgument(
"The InitC dims and InitH "
"dims should be equal. But "
"received init_c_dims is %d and init_h_dims is %d.",
init_c_dims,
init_h_dims));
auto out_dims = in_dims;
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
last_h->set_dims(init_c_dims);
last_h->set_dtype(x.dtype());
last_c->set_dims(init_h_dims);
last_c->set_dtype(x.dtype());
reserve->set_dtype(phi::DataType::UINT8);
state_out->set_dtype(phi::DataType::UINT8);
}
inline int ConvOutputSize(
int input_size, int filter_size, int dilation, int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
......
...@@ -239,6 +239,25 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void CudnnLSTMInferMeta(
const MetaTensor& x,
const MetaTensor& init_h,
const MetaTensor& init_c,
const MetaTensor& w,
const paddle::optional<std::vector<const MetaTensor*>>& weight_list,
const MetaTensor& sequence_length,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
MetaTensor* out,
MetaTensor* last_h,
MetaTensor* last_c,
MetaTensor* reserve,
MetaTensor* state_out);
void DeformableConvInferMeta(const MetaTensor& x,
const MetaTensor& offset,
const MetaTensor& filter,
......
...@@ -43,5 +43,4 @@ PD_REGISTER_KERNEL(eigh,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
...@@ -22,4 +22,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float,
double) {}
double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
...@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(sequence_mask,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -101,4 +101,6 @@ PD_REGISTER_KERNEL(fusion_group,
phi::fusion::FusionGroupKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -77,5 +77,5 @@ PD_REGISTER_KERNEL(abs,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
...@@ -24,4 +24,10 @@ PD_REGISTER_KERNEL(adadelta,
phi::AdadeltaKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
}
}
...@@ -132,4 +132,10 @@ PD_REGISTER_KERNEL(adamax,
phi::AdamaxKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
}
}
...@@ -176,7 +176,7 @@ void CudnnLSTMKernel(
int seq_length = x.dims()[0];
int batch_size = x.dims()[1];
int input_size = x.dims()[2];
bool state_initialized = state_out->IsInitialized() ? true : false;
bool state_initialized = state_out->initialized() ? true : false;
size_t workspace_size;
size_t reserve_size;
...@@ -188,7 +188,7 @@ void CudnnLSTMKernel(
auto stream = ctx.stream();
auto *running_w = w.get_ptr();
if (is_test && running_w != nullptr) {
w_initialized = running_w->IsInitialized() ? true : false;
w_initialized = running_w->initialized() ? true : false;
weight_numel = running_w->numel();
}
if (!w_initialized) {
...@@ -362,12 +362,14 @@ void CudnnLSTMKernel(
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float) {
kernel->InputAt(5).SetDataType(phi::DataType::INT32);
kernel->OutputAt(3).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(4).SetDataType(phi::DataType::UINT8);
}
#else
PD_REGISTER_KERNEL(
cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float, double) {
kernel->InputAt(5).SetDataType(phi::DataType::INT32);
kernel->OutputAt(3).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(4).SetDataType(phi::DataType::UINT8);
}
......
...@@ -46,7 +46,6 @@ PD_REGISTER_KERNEL(eigh, // cuda_only
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
#endif // not PADDLE_WITH_HIP
...@@ -22,4 +22,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float,
double) {}
double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
...@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(sequence_mask,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
...@@ -85,7 +85,7 @@ void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
phi::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_EQ(
step.IsInitialized(),
step.initialized(),
true,
phi::errors::InvalidArgument("Input(Step) must be initialized."));
......
...@@ -280,7 +280,13 @@ PD_REGISTER_KERNEL(update_loss_scaling,
ALL_LAYOUT,
phi::UpdateLossScalingKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
}
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
kernel->OutputAt(3).SetDataType(phi::DataType::INT32);
}
PD_REGISTER_KERNEL(check_finite_and_unscale,
XPU,
......
...@@ -369,4 +369,6 @@ PD_REGISTER_KERNEL(max_pool2d_with_index,
ALL_LAYOUT,
phi::MaxPool2dWithIndexKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
}
...@@ -41,7 +41,7 @@ void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
phi::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_EQ(
step.IsInitialized(),
step.initialized(),
true,
phi::errors::InvalidArgument("Input(Step) must be initialized."));
...@@ -68,4 +68,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
XPU,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float) {}
float) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
...@@ -45,4 +45,6 @@ PD_REGISTER_KERNEL(sum_raw,
phi::dtype::float16,
int8_t,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}