Unverified commit 733eca85 authored by Sonder, committed by GitHub

remove ops from OpsWithFluidKernelNeedMoveToPhi set (#54007)

* remove ops from OpsWithFluidKernelNeedMoveToPhi set

* open static build flag

* OpsWithFluidKernelNeedMoveToPhi

* open new_executor_static_build

* add infermeta for cudnn_lstm

* fix

* update

* fix

* update

* update

* update

* fix pow2 decay

* fix pow2 decay

* recover analysis_predictor.cc

* fix pow2 decay

* fix cudnn lstm

* add output register info for svd

* fix pow2_decay_with_linear_warmup_kernel

* recover test lstm cudnn

* recover svd register codes

* fix register info

* fix reduce sum register info

* add output info for adadelta

* add output info for adadelta

* add output info for adamax

* fix complex abs register info

* add register info for cudnn_lstm_grad

* recover

* fix lstm cudnn

* fix

* fix xpu output register info

* remove std::cout

* add backend

* remove output info in pow2_decay_with_linear_warmup_kernel

* add judgment in TensorShouldBeFakeInitialized

* recover power_

* close new_executor_static_build

* fix set_value_xpu
Parent 59dd97af
......@@ -40,20 +40,6 @@ std::set<std::string> OpsCanSkipedFakeAllocInStaticBuild = {
"fetch_v2",
"nop"};
// Cannot static analysis these Ops' output dtype or backend because their
// kernels have not moved to PHI yet.
std::set<std::string> OpsWithFluidKernelNeedMoveToPhi = {
"cudnn_lstm",
"dequantize",
"distributed_fused_lamb",
"fused_batch_norm_act",
"fused_batch_norm_act_grad",
"fusion_group",
"pow2_decay_with_linear_warmup",
"sequence_mask",
"sequence_pool",
"stft"};
std::set<std::string> StaticBuildBlackList = {
"batch_norm" /*: to handle reserve_space output*/,
"cinn_instruction_run" /*: to handle subgraph infermeta*/,
......@@ -95,8 +81,7 @@ bool BlockCanBeStaticBuilt(const framework::BlockDesc& block) {
bool has_fluid_kernel = OperatorWithKernel::AllOpKernels().count(op_type);
bool has_structured_kernel =
phi::KernelFactory::Instance().HasStructuredKernel(op_type);
bool need_move_to_phi = (has_fluid_kernel || has_structured_kernel) &&
OpsWithFluidKernelNeedMoveToPhi.count(op_type);
bool need_move_to_phi = (has_fluid_kernel || has_structured_kernel);
KernelCode kernel_code =
(in_black_list << 7) + (is_operator_base << 6) + (is_custom_op << 5) +
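For context, BlockCanBeStaticBuilt packs these per-op booleans into a single KernelCode so that each disqualifying property occupies one bit. A standalone sketch of that packing, limited to the three flags visible in this hunk (the type alias and the helper name are assumptions for illustration, not the exact Paddle definitions):
// Minimal sketch of the KernelCode bit-packing used above.
using KernelCode = int;  // assumed alias; the real type lives in static_build.cc
KernelCode MakeKernelCode(bool in_black_list,
                          bool is_operator_base,
                          bool is_custom_op) {
  // A non-zero code records exactly which checks an op failed.
  return (in_black_list << 7) + (is_operator_base << 6) + (is_custom_op << 5);
}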
......@@ -162,6 +147,11 @@ bool TensorShouldBeFakeInitialized(const OperatorBase& op,
return false;
}
if (op_type == "distributed_fused_lamb" && parameter_name == "ParamOut") {
VLOG(2) << "Skip fake initialization for: " << parameter_name;
return false;
}
if (op_type == "fake_quantize_range_abs_max") {
if (op.Attr<bool>("is_test") &&
(parameter_name == "OutScale" || parameter_name == "OutScales")) {
......@@ -467,11 +457,14 @@ void FakeInitializeOutputsForFunctionKernel(
} else if (op_type == "layer_norm") {
dtype = InferMPDType(runtime_ctx, "X");
} else if (op_type == "reduce_sum") {
phi::DataType in_dtype = GetInputDType(runtime_ctx, "X");
int dtype_attr = op.Attr<int>("out_dtype");
if (dtype_attr != -1) {
dtype = phi::TransToPhiDataType(dtype_attr);
if (dtype == DataType::UNDEFINED) {
dtype = in_dtype;
}
} else {
phi::DataType in_dtype = GetInputDType(runtime_ctx, "X");
dtype =
(in_dtype == DataType::BOOL || in_dtype == DataType::INT32)
? DataType::INT64
......@@ -489,7 +482,6 @@ void FakeInitializeOutputsForFunctionKernel(
// analyze layout
phi::DataLayout layout = tensor_arg_def.layout;
FakeInitializeTensorBase(dev_ctx, place, dtype, layout, out_tensor);
}
}
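The reduce_sum branch above is clipped at the hunk boundary; distilled into a standalone helper (names follow the diff, and the tail of the else branch is filled in on the assumption that other input dtypes pass through unchanged), the new output-dtype rule is:
phi::DataType InferReduceSumOutDType(int out_dtype_attr, phi::DataType in_dtype) {
  if (out_dtype_attr != -1) {
    phi::DataType dtype = phi::TransToPhiDataType(out_dtype_attr);
    // An attribute that maps to UNDEFINED now falls back to the input dtype.
    return dtype == phi::DataType::UNDEFINED ? in_dtype : dtype;
  }
  // Bool and int32 inputs are accumulated as int64; other dtypes pass through (assumed).
  return (in_dtype == phi::DataType::BOOL || in_dtype == phi::DataType::INT32)
             ? phi::DataType::INT64
             : in_dtype;
}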
......
......@@ -44,7 +44,6 @@ void TensorCopyImpl(const TENSOR& src,
TensorCopyImpl(src_copy, dst_place, ctx, dst);
return;
}
VLOG(3) << "TensorCopy " << src.dims() << " from " << src.place() << " to "
<< dst_place;
src.check_memory_size();
......@@ -325,7 +324,6 @@ void TensorCopySync(const phi::DenseTensor& src,
<< dst_place;
return;
}
auto size = src.numel() * phi::SizeOf(src.dtype());
if (platform::is_cpu_place(src_place) && platform::is_cpu_place(dst_place)) {
memory::Copy(dst_place, dst_ptr, src_place, src_ptr, size);
......
......@@ -15,8 +15,12 @@ limitations under the License. */
#include <memory>
#include <string>
#include "paddle/fluid/framework/infershape_utils.h"
#include "paddle/fluid/framework/op_registry.h"
#include "paddle/fluid/framework/op_version_registry.h"
#include "paddle/phi/core/infermeta_utils.h"
#include "paddle/phi/infermeta/multiary.h"
namespace paddle {
namespace operators {
......@@ -25,75 +29,6 @@ class CudnnLSTMOp : public framework::OperatorWithKernel {
public:
using framework::OperatorWithKernel::OperatorWithKernel;
void InferShape(framework::InferShapeContext* ctx) const override {
OP_INOUT_CHECK(ctx->HasInput("Input"), "Input", "Input", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitH"), "Input", "InitH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasInput("InitC"), "Input", "InitC", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("Reserve"), "Output", "Reserve", "CudnnLSTM");
OP_INOUT_CHECK(
ctx->HasOutput("StateOut"), "Output", "StateOut", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("Out"), "Output", "Out", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("LastH"), "Output", "LastH", "CudnnLSTM");
OP_INOUT_CHECK(ctx->HasOutput("LastC"), "Output", "LastC", "CudnnLSTM");
auto in_dims = ctx->GetInputDim("Input");
auto init_h_dims = ctx->GetInputDim("InitH");
auto init_c_dims = ctx->GetInputDim("InitC");
PADDLE_ENFORCE_EQ(in_dims.size(),
3,
platform::errors::InvalidArgument(
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
PADDLE_ENFORCE_EQ(init_h_dims.size(),
3,
platform::errors::InvalidArgument(
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d.",
init_h_dims.size()));
if (ctx->HasInput("SequenceLength")) {
auto seq_dims = ctx->GetInputDim("SequenceLength");
PADDLE_ENFORCE_EQ(
in_dims[1],
seq_dims[0],
platform::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1],
seq_dims[0]));
}
PADDLE_ENFORCE_EQ(
in_dims[1],
init_h_dims[1],
platform::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_h_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_h_dims[1] is %d.",
in_dims[1],
init_h_dims[1]));
PADDLE_ENFORCE_EQ(init_c_dims,
init_h_dims,
platform::errors::InvalidArgument(
"The InitC dims and InitH "
"dims should be equal. But "
"received init_c_dims is %d and init_h_dims is %d.",
init_c_dims,
init_h_dims));
auto out_dims = in_dims;
auto hidden_size = ctx->Attrs().Get<int>("hidden_size");
bool is_bidirec = ctx->Attrs().Get<bool>("is_bidirec");
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
ctx->SetOutputDim("Out", out_dims);
ctx->SetOutputDim("LastH", init_c_dims);
ctx->SetOutputDim("LastC", init_h_dims);
}
protected:
phi::KernelKey GetExpectedKernelType(
const framework::ExecutionContext& ctx) const override {
......@@ -295,12 +230,18 @@ class CudnnLSTMGradOpMaker : public framework::SingleGradOpMaker<T> {
} // namespace operators
} // namespace paddle
DECLARE_INFER_SHAPE_FUNCTOR(cudnn_lstm,
CudnnLSTMInferShapeFunctor,
PD_INFER_META(phi::CudnnLSTMInferMeta));
namespace ops = paddle::operators;
REGISTER_OPERATOR(cudnn_lstm,
ops::CudnnLSTMOp,
ops::CudnnLSTMOpMaker,
ops::CudnnLSTMGradOpMaker<paddle::framework::OpDesc>,
ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>);
ops::CudnnLSTMGradOpMaker<paddle::imperative::OpBase>,
CudnnLSTMInferShapeFunctor);
REGISTER_OPERATOR(cudnn_lstm_grad, ops::CudnnLSTMGradOp);
// TODO(Shixiaowei02) Add ModifyInput support
......
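The cudnn_lstm registration above shows the general migration pattern this commit applies: the operator's hand-written InferShape is deleted and replaced by a phi InferMeta hooked in through an infer-shape functor. The same wiring for a hypothetical operator (my_op, MyOp, MyOpMaker and phi::MyOpInferMeta are placeholders, not real Paddle symbols) would look like:
// Declare a functor that forwards shape/dtype inference to the phi InferMeta.
DECLARE_INFER_SHAPE_FUNCTOR(my_op,
                            MyOpInferShapeFunctor,
                            PD_INFER_META(phi::MyOpInferMeta));
// Register the operator with the functor instead of overriding InferShape.
REGISTER_OPERATOR(my_op,
                  ops::MyOp,
                  ops::MyOpMaker,
                  MyOpInferShapeFunctor);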
......@@ -270,7 +270,7 @@ static const T *GetInputTensorPtr(const DenseTensor *in_tensor,
PADDLE_ENFORCE_NOT_NULL(
in_tensor,
phi::errors::InvalidArgument("Input(%s) cannot be NULL.", in_name));
if (in_tensor->IsInitialized()) {
if (in_tensor->initialized()) {
if (numel) *numel = in_tensor->numel();
return in_tensor->data<T>();
} else {
......@@ -286,7 +286,7 @@ static T *GetSameInOutTensorPtr(const Context &dev_ctx,
const char *in_name,
const char *out_name,
int64_t *numel = nullptr) {
if (in_tensor == nullptr || !in_tensor->IsInitialized()) {
if (in_tensor == nullptr || !in_tensor->initialized()) {
PADDLE_ENFORCE_EQ(
AllowNotExist,
true,
......@@ -1160,7 +1160,7 @@ struct VisitDTypeFunctor {
static std::string GetMinMaxStr(const phi::DenseTensor *x) {
if (x == nullptr) return "null";
if (!x->IsInitialized()) return "not_inited";
if (!x->initialized()) return "not_inited";
if (x->place().GetType() != phi::AllocationType::GPU) return "CPUTensor";
std::string str;
VisitDTypeFunctor functor(x, &str);
......@@ -1354,9 +1354,7 @@ void DistributedFusedLambKernel(
#if defined(PADDLE_WITH_NCCL) || defined(PADDLE_WITH_RCCL)
auto stream = dev_ctx.stream();
auto place = dev_ctx.GetPlace();
found_inf->Resize({1});
// Step 1: Get fp16 param and grad tensors
int64_t fp16_numel;
auto *fp16_param_data =
......@@ -1374,7 +1372,6 @@ void DistributedFusedLambKernel(
} else {
fp16_param_data = nullptr;
}
// Step 2: Get fp32 param and grad tensors
int64_t fp32_numel = 0;
auto *fp32_param_data =
......@@ -1389,7 +1386,6 @@ void DistributedFusedLambKernel(
phi::errors::InvalidArgument(
"The element number in FP32FusedParam should be not "
"less than FP16FusedParam."));
fp32_numel -= fp16_numel; // the FP32FusedParam contains fp32 param and
// fp16 master weight
bool has_fp32_param = (fp32_numel > 0);
......@@ -1404,7 +1400,6 @@ void DistributedFusedLambKernel(
phi::errors::InvalidArgument(
"Either FP32FusedGrad or FP16FusedGrad cannot be NULL."));
}
auto numel = fp32_numel + fp16_numel;
VLOG(1) << "numel = " << numel << " , fp32_numel = " << fp32_numel
<< " , fp16_numel = " << fp16_numel;
......@@ -1426,7 +1421,7 @@ void DistributedFusedLambKernel(
acc_step,
phi::errors::InvalidArgument(
"Output(AccStep) cannot be nullptr when Attr(acc_steps) > 1."));
bool is_initialized = acc_step->IsInitialized();
bool is_initialized = acc_step->initialized();
int64_t *acc_step_data;
if (is_initialized) {
acc_step_data = dev_ctx.template HostAlloc<int64_t>(acc_step);
......@@ -1437,14 +1432,13 @@ void DistributedFusedLambKernel(
*acc_step_data = 1;
}
int64_t rounded_step = (*acc_step_data) % acc_steps;
float *fp32_acc_grad_data = nullptr;
if (has_fp32_param) {
PADDLE_ENFORCE_NOT_NULL(fp32_acc_grad,
phi::errors::InvalidArgument(
"Output(FP32AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1."));
if (!fp32_acc_grad->IsInitialized()) {
if (!fp32_acc_grad->initialized()) {
fp32_acc_grad->Resize({static_cast<int64_t>(fp32_numel)});
fp32_acc_grad_data = dev_ctx.template Alloc<float>(fp32_acc_grad);
} else {
......@@ -1459,7 +1453,7 @@ void DistributedFusedLambKernel(
phi::errors::InvalidArgument(
"Output(FP16AccFusedGrad) cannot be nullptr "
"when Attr(acc_steps) > 1."));
if (!fp16_acc_grad->IsInitialized()) {
if (!fp16_acc_grad->initialized()) {
auto acc_grad_size =
use_master_acc_grad ? (3 * fp16_numel) : fp16_numel;
fp16_acc_grad->Resize({static_cast<int64_t>(acc_grad_size)});
......@@ -1544,11 +1538,9 @@ void DistributedFusedLambKernel(
}
}
}
stop_update->Resize({1});
auto *stop_update_data = dev_ctx.template HostAlloc<bool>(stop_update);
auto *found_inf_cpu = dev_ctx.template HostAlloc<bool>(found_inf);
if (rounded_step != 0) {
*stop_update_data = true;
*found_inf_cpu = false;
......@@ -1599,7 +1591,6 @@ void DistributedFusedLambKernel(
int64_t partial_numel = 0;
auto *moment1_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &moment1, moment1_out, "Moment1", "Moment1Out", &partial_numel);
PADDLE_ENFORCE_EQ(numel % partial_numel,
0,
phi::errors::InvalidArgument(
......@@ -1629,16 +1620,13 @@ void DistributedFusedLambKernel(
"exactly by the device number %d.",
fp16_numel,
num_devices));
auto *moment2_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &moment2, moment2_out, "Moment2", "Moment2Out");
auto *beta1_pow_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &beta1_pow, beta1_pow_out, "Beta1Pow", "Beta1PowOut");
auto *beta2_pow_data = GetSameInOutTensorPtr<float, Context>(
dev_ctx, &beta2_pow, beta2_pow_out, "Beta2Pow", "Beta2PowOut");
auto *found_inf_data = dev_ctx.template Alloc<bool>(found_inf);
// Step 5: Get attributes weight_decay, beta1, beta2, epsilon,
// max_grad_norm, ring_id,
// use_master_param_norm, is_grad_scaled_by_nranks
......@@ -1668,7 +1656,6 @@ void DistributedFusedLambKernel(
paddle::platform::NCCLCommContext::Instance().Get(ring_ids[0], place);
global_comm = nccl_comm_handle->comm();
global_rank = nccl_comm_handle->rank();
if (local_shard) {
auto *local_nccl_comm_handle =
paddle::platform::NCCLCommContext::Instance().Get(ring_ids[1], place);
......@@ -1684,11 +1671,9 @@ void DistributedFusedLambKernel(
local_rank = global_rank;
}
}
memory_utils::Buffer grad_norm_square_buffer(place);
auto *fp32_square_grad_norm = grad_norm_square_buffer.Alloc<float>(2);
memory_utils::Buffer cub_tmp_buffer(place);
memory_utils::Buffer sum_grad_buffer(place);
float *fp32_sum_grad;
dtype::float16 *fp16_sum_grad;
......@@ -1721,7 +1706,6 @@ void DistributedFusedLambKernel(
fp32_sum_grad = const_cast<float *>(fp32_grad_data);
fp16_sum_grad = const_cast<dtype::float16 *>(fp16_grad_data);
}
float rescale_grad = 1.0f;
if (!is_grad_scaled_by_nranks) {
rescale_grad /= nranks;
......@@ -1845,7 +1829,6 @@ void DistributedFusedLambKernel(
} else {
fp16_scale = cub_tmp_buffer.Alloc<dtype::float16>(1);
}
float clip_scale = 1.0f;
if (is_grad_scaled_by_nranks) {
clip_scale *= nranks;
......@@ -1988,7 +1971,6 @@ void DistributedFusedLambKernel(
external_comm,
stream,
dev_ctx);
NCCLReduceScatterWithScale(
fp16_grad_data,
fp16_sum_grad + local_rank * fp16_numel_each_device,
......@@ -2064,9 +2046,7 @@ void DistributedFusedLambKernel(
auto *param_offsets_data = param_offsets.data<int>();
const auto *fp32_partial_offsets_data = fp32_partial_offsets.data<int>();
const auto *fp16_partial_offsets_data = fp16_partial_offsets.data<int>();
auto *step_data = step->data<int64_t>();
VLOG(1) << "FusedParamOffsets: "
<< FlattenToString(param_offsets_data,
param_offsets.numel(),
......@@ -2079,7 +2059,6 @@ void DistributedFusedLambKernel(
<< FlattenToString(fp16_partial_offsets_data,
fp16_partial_offsets.numel(),
fp16_partial_offsets.place());
memory_utils::Buffer trust_ratio_div_buffer(place);
auto *trust_ratio_div = trust_ratio_div_buffer.Alloc<float>(partial_numel);
auto fp32_offset = local_rank * fp32_numel_each_device;
......@@ -2178,7 +2157,6 @@ void DistributedFusedLambKernel(
fp16_local_param_num,
param_square_norm + fp16_local_start_idx);
}
MultiTensorL2Norm(place,
stream,
trust_ratio_div,
......@@ -2191,7 +2169,6 @@ void DistributedFusedLambKernel(
fp16_partial_offsets_data,
fp16_local_param_num,
trust_ratio_div_square_norm + fp16_local_start_idx);
VLOG(1) << "TrustRatioDiv L2-Norm before allreduce: "
<< FlattenToString(trust_ratio_div_square_norm, param_num, place);
if (num_devices > 1) {
......@@ -2296,6 +2273,12 @@ PD_REGISTER_KERNEL(distributed_fused_lamb,
ALL_LAYOUT,
phi::fusion::DistributedFusedLambKernel,
float) {
kernel->InputAt(10).SetBackend(phi::Backend::CPU);
kernel->InputAt(11).SetBackend(phi::Backend::CPU);
kernel->InputAt(12).SetBackend(phi::Backend::CPU);
kernel->InputAt(13).SetBackend(phi::Backend::CPU);
kernel->InputAt(14).SetBackend(phi::Backend::CPU);
kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT16);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
......
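The args-def callback added to the distributed_fused_lamb registration is the mechanism most of the kernel-registration changes in this commit rely on: per-argument backend and dtype overrides are recorded at registration time, and the static build can read them when fake-initializing outputs. A minimal sketch of the same callback shape for a hypothetical kernel (my_kernel, phi::MyKernel and the argument indices are placeholders):
PD_REGISTER_KERNEL(my_kernel, GPU, ALL_LAYOUT, phi::MyKernel, float) {
  // Inputs read on the host side are pinned to the CPU backend.
  kernel->InputAt(1).SetBackend(phi::Backend::CPU);
  // Outputs whose dtype does not follow the kernel's registered dtype are set explicitly.
  kernel->OutputAt(0).SetDataType(phi::DataType::FLOAT32);
}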
......@@ -978,6 +978,85 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
out->share_lod(*x.at(0));
}
void CudnnLSTMInferMeta(
const MetaTensor& x,
const MetaTensor& init_h,
const MetaTensor& init_c,
const MetaTensor& w,
const paddle::optional<std::vector<const MetaTensor*>>& weight_list,
const MetaTensor& sequence_length,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
MetaTensor* out,
MetaTensor* last_h,
MetaTensor* last_c,
MetaTensor* reserve,
MetaTensor* state_out) {
auto in_dims = x.dims();
auto init_h_dims = init_h.dims();
auto init_c_dims = init_c.dims();
PADDLE_ENFORCE_EQ(in_dims.size(),
3,
phi::errors::InvalidArgument(
"The rank of Input in CudnnLSTM must be 3. But "
"received Input's rank is %d.",
in_dims.size()));
PADDLE_ENFORCE_EQ(init_h_dims.size(),
3,
phi::errors::InvalidArgument(
"The rank of InitH in CudnnLSTM must be 3. But "
"received InitH's rank is %d.",
init_h_dims.size()));
if (sequence_length) {
auto seq_dims = sequence_length.dims();
PADDLE_ENFORCE_EQ(
in_dims[1],
seq_dims[0],
phi::errors::InvalidArgument(
"The size of SequenceLength has to equal the batch_size. But "
"received batch_size is %d and the size of SequenceLength is %d.",
in_dims[1],
seq_dims[0]));
}
PADDLE_ENFORCE_EQ(in_dims[1],
init_h_dims[1],
phi::errors::InvalidArgument(
"The in_dims[1] (Input dims) and init_h_dims[1] (InitH "
"dims) should be equal. But "
"received in_dims[1] is %d and init_h_dims[1] is %d.",
in_dims[1],
init_h_dims[1]));
PADDLE_ENFORCE_EQ(init_c_dims,
init_h_dims,
phi::errors::InvalidArgument(
"The InitC dims and InitH "
"dims should be equal. But "
"received init_c_dims is %d and init_h_dims is %d.",
init_c_dims,
init_h_dims));
auto out_dims = in_dims;
out_dims[2] = is_bidirec ? hidden_size * 2 : hidden_size;
out->set_dims(out_dims);
out->set_dtype(x.dtype());
last_h->set_dims(init_c_dims);
last_h->set_dtype(x.dtype());
last_c->set_dims(init_h_dims);
last_c->set_dtype(x.dtype());
reserve->set_dtype(phi::DataType::UINT8);
state_out->set_dtype(phi::DataType::UINT8);
}
inline int ConvOutputSize(
int input_size, int filter_size, int dilation, int padding, int stride) {
const int dkernel = dilation * (filter_size - 1) + 1;
......
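As a concrete reading of the CudnnLSTMInferMeta added above, take an input of shape [seq_len=20, batch=8, input_size=32] with hidden_size=64 and is_bidirec=true (the numbers are illustrative only):
// Out      -> [20, 8, 128]   (out_dims[2] = 2 * hidden_size for a bidirectional LSTM)
// LastH    -> init_c dims, LastC -> init_h dims (mirroring the assignments above)
// Reserve  -> dtype UINT8, dims left to the kernel at run time
// StateOut -> dtype UINT8, dims left to the kernel at run time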
......@@ -239,6 +239,25 @@ void ConcatInferMeta(const std::vector<const MetaTensor*>& x,
MetaTensor* out,
MetaConfig config = MetaConfig());
void CudnnLSTMInferMeta(
const MetaTensor& x,
const MetaTensor& init_h,
const MetaTensor& init_c,
const MetaTensor& w,
const paddle::optional<std::vector<const MetaTensor*>>& weight_list,
const MetaTensor& sequence_length,
float dropout_prob,
bool is_bidirec,
int hidden_size,
int num_layers,
bool is_test,
int seed,
MetaTensor* out,
MetaTensor* last_h,
MetaTensor* last_c,
MetaTensor* reserve,
MetaTensor* state_out);
void DeformableConvInferMeta(const MetaTensor& x,
const MetaTensor& offset,
const MetaTensor& filter,
......
......@@ -43,5 +43,4 @@ PD_REGISTER_KERNEL(eigh,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
......@@ -22,4 +22,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float,
double) {}
double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(sequence_mask,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -101,4 +101,6 @@ PD_REGISTER_KERNEL(fusion_group,
phi::fusion::FusionGroupKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -77,5 +77,5 @@ PD_REGISTER_KERNEL(abs,
phi::dtype::bfloat16,
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->InputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
......@@ -24,4 +24,10 @@ PD_REGISTER_KERNEL(adadelta,
phi::AdadeltaKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
}
}
......@@ -132,4 +132,10 @@ PD_REGISTER_KERNEL(adamax,
phi::AdamaxKernel,
float,
double,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(2).SetDataType(phi::DataType::FLOAT32);
kernel->OutputAt(3).SetDataType(phi::DataType::FLOAT32);
}
}
......@@ -176,7 +176,7 @@ void CudnnLSTMKernel(
int seq_length = x.dims()[0];
int batch_size = x.dims()[1];
int input_size = x.dims()[2];
bool state_initialized = state_out->IsInitialized() ? true : false;
bool state_initialized = state_out->initialized() ? true : false;
size_t workspace_size;
size_t reserve_size;
......@@ -188,7 +188,7 @@ void CudnnLSTMKernel(
auto stream = ctx.stream();
auto *running_w = w.get_ptr();
if (is_test && running_w != nullptr) {
w_initialized = running_w->IsInitialized() ? true : false;
w_initialized = running_w->initialized() ? true : false;
weight_numel = running_w->numel();
}
if (!w_initialized) {
......@@ -362,12 +362,14 @@ void CudnnLSTMKernel(
#ifdef PADDLE_WITH_HIP
PD_REGISTER_KERNEL(cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float) {
kernel->InputAt(5).SetDataType(phi::DataType::INT32);
kernel->OutputAt(3).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(4).SetDataType(phi::DataType::UINT8);
}
#else
PD_REGISTER_KERNEL(
cudnn_lstm, GPU, ALL_LAYOUT, phi::CudnnLSTMKernel, float, double) {
kernel->InputAt(5).SetDataType(phi::DataType::INT32);
kernel->OutputAt(3).SetDataType(phi::DataType::UINT8);
kernel->OutputAt(4).SetDataType(phi::DataType::UINT8);
}
......
......@@ -46,7 +46,6 @@ PD_REGISTER_KERNEL(eigh, // cuda_only
phi::dtype::complex<float>,
phi::dtype::complex<double>) {
kernel->OutputAt(0).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
kernel->OutputAt(1).SetDataType(phi::dtype::ToReal(kernel_key.dtype()));
}
#endif // not PADDLE_WITH_HIP
......@@ -22,4 +22,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float,
double) {}
double) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -24,4 +24,6 @@ PD_REGISTER_KERNEL(sequence_mask,
float,
double,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}
......@@ -85,7 +85,7 @@ void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
phi::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_EQ(
step.IsInitialized(),
step.initialized(),
true,
phi::errors::InvalidArgument("Input(Step) must be initialized."));
......
......@@ -280,7 +280,13 @@ PD_REGISTER_KERNEL(update_loss_scaling,
ALL_LAYOUT,
phi::UpdateLossScalingKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
if (kernel_key.dtype() == phi::DataType::FLOAT16) {
kernel->OutputAt(1).SetDataType(phi::DataType::FLOAT32);
}
kernel->OutputAt(2).SetDataType(phi::DataType::INT32);
kernel->OutputAt(3).SetDataType(phi::DataType::INT32);
}
PD_REGISTER_KERNEL(check_finite_and_unscale,
XPU,
......
......@@ -369,4 +369,6 @@ PD_REGISTER_KERNEL(max_pool2d_with_index,
ALL_LAYOUT,
phi::MaxPool2dWithIndexKernel,
float,
phi::dtype::float16) {}
phi::dtype::float16) {
kernel->OutputAt(1).SetDataType(phi::DataType::INT32);
}
......@@ -41,7 +41,7 @@ void Pow2DecayWithLinearWarmupKernel(const Context& dev_ctx,
phi::errors::InvalidArgument(
"Input(Step) and Output(StepOut) must be the same."));
PADDLE_ENFORCE_EQ(
step.IsInitialized(),
step.initialized(),
true,
phi::errors::InvalidArgument("Input(Step) must be initialized."));
......@@ -68,4 +68,7 @@ PD_REGISTER_KERNEL(pow2_decay_with_linear_warmup,
XPU,
ALL_LAYOUT,
phi::Pow2DecayWithLinearWarmupKernel,
float) {}
float) {
kernel->InputAt(1).SetDataType(phi::DataType::INT64);
kernel->OutputAt(1).SetDataType(phi::DataType::INT64);
}
......@@ -45,4 +45,6 @@ PD_REGISTER_KERNEL(sum_raw,
phi::dtype::float16,
int8_t,
int,
int64_t) {}
int64_t) {
kernel->OutputAt(0).SetDataType(phi::DataType::UNDEFINED);
}