未验证 提交 0a59825e 编写于 作者: W wz1qqx 提交者: GitHub

[XPU] Optimize fp16 xpu models (#53523)

上级 186f5e0f
...@@ -34,6 +34,25 @@ class Scope; ...@@ -34,6 +34,25 @@ class Scope;
} // namespace framework } // namespace framework
} // namespace paddle } // namespace paddle
namespace {
template <typename T1, typename T2>
void ConvertTensorType(phi::DenseTensor* tensor) {
phi::DenseTensor tmp_tensor;
tmp_tensor.set_type(phi::CppTypeToDataType<T2>::Type());
tmp_tensor.Resize(tensor->dims());
auto* tmp_data = tmp_tensor.mutable_data<T2>(paddle::platform::CPUPlace());
auto* data = tensor->mutable_data<T1>(paddle::platform::CPUPlace());
for (int i = 0; i < tensor->numel(); i++) {
tmp_data[i] = static_cast<T2>(data[i]);
}
tensor->clear();
paddle::framework::TensorCopySync(
tmp_tensor, paddle::platform::CPUPlace(), tensor);
}
} // namespace
namespace paddle { namespace paddle {
namespace framework { namespace framework {
namespace ir { namespace ir {
...@@ -157,15 +176,23 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern, ...@@ -157,15 +176,23 @@ Conv2dXPUPattern::Conv2dXPUPattern(PDPattern* pattern,
if (with_bn_) { if (with_bn_) {
ew_bias_add_out->assert_is_op_input("batch_norm", "X"); ew_bias_add_out->assert_is_op_input("batch_norm", "X");
bn_bias = pattern->NewNode(bn_bias_repr()) bn_bias = pattern->NewNode(bn_bias_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Bias") ->assert_is_op_input("batch_norm", "Bias")
->assert_has_n_outputs(1); ->assert_has_n_outputs(1);
bn_mean = pattern->NewNode(bn_mean_repr()) bn_mean = pattern->NewNode(bn_mean_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Mean") ->assert_is_op_input("batch_norm", "Mean")
->assert_has_n_outputs(1); ->assert_has_n_outputs(1);
bn_scale = pattern->NewNode(bn_scale_repr()) bn_scale = pattern->NewNode(bn_scale_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Scale") ->assert_is_op_input("batch_norm", "Scale")
->assert_has_n_outputs(1); ->assert_has_n_outputs(1);
bn_var = pattern->NewNode(bn_var_repr()) bn_var = pattern->NewNode(bn_var_repr())
->AsInput()
->assert_is_persistable_var()
->assert_is_op_input("batch_norm", "Variance") ->assert_is_op_input("batch_norm", "Variance")
->assert_has_n_outputs(1); ->assert_has_n_outputs(1);
bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm"); bn = pattern->NewNode(bn_repr())->assert_is_op("batch_norm");
...@@ -420,13 +447,17 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -420,13 +447,17 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
// recompute bias and weight for conv2d_xpu op // recompute bias and weight for conv2d_xpu op
auto* filter_t = auto* filter_t =
scope->FindVar(conv_filter->Name())->GetMutable<phi::DenseTensor>(); scope->FindVar(conv_filter->Name())->GetMutable<phi::DenseTensor>();
// conv_filter fp16 --> fp32
auto tensor_type = filter_t->dtype();
if (tensor_type == phi::DataType::FLOAT16) {
ConvertTensorType<float16, float>(filter_t);
}
auto filter_dims = filter_t->dims(); auto filter_dims = filter_t->dims();
bool has_bias = with_bn || with_conv_bias; bool has_bias = with_bn || with_conv_bias;
bool has_branch = with_branch_x || with_branch_y;
// Create conv_fusion_bias (conv bias) variable // Create conv_fusion_bias (conv bias) variable
Node* fusion_bias_node = nullptr; Node* fusion_bias_node = nullptr;
if (has_bias) { if (has_bias) {
if (ew_bias_add != nullptr) { if (with_conv_bias) {
auto* ew_bias_add_y_t = scope->FindVar(ew_bias_add_y->Name()) auto* ew_bias_add_y_t = scope->FindVar(ew_bias_add_y->Name())
->GetMutable<phi::DenseTensor>(); ->GetMutable<phi::DenseTensor>();
auto ew_bias_add_y_dims = ew_bias_add_y_t->dims(); auto ew_bias_add_y_dims = ew_bias_add_y_t->dims();
...@@ -439,7 +470,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -439,7 +470,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
filter_dims[0])); filter_dims[0]));
PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node); PrepareBias(graph, scope, block, ew_bias_add_y, &fusion_bias_node);
} }
if (bn != nullptr) { if (with_bn) {
auto bn_bias_t = auto bn_bias_t =
scope->Var(bn_bias->Name())->GetMutable<phi::DenseTensor>(); scope->Var(bn_bias->Name())->GetMutable<phi::DenseTensor>();
PADDLE_ENFORCE_EQ(filter_dims[0], PADDLE_ENFORCE_EQ(filter_dims[0],
...@@ -469,7 +500,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -469,7 +500,7 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
auto filter_len = filter_t->numel(); auto filter_len = filter_t->numel();
auto filter_stride = filter_len / mean_len; auto filter_stride = filter_len / mean_len;
float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon")); float epsilon = PADDLE_GET_CONST(float, bn->Op()->GetAttr("epsilon"));
if (fusion_bias_node == nullptr) { // prev node is conv if (!with_conv_bias) { // prev node is conv
PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node); PrepareBias(graph, scope, block, bn_bias, &fusion_bias_node);
} }
auto fusion_bias_t = scope->Var(fusion_bias_node->Name()) auto fusion_bias_t = scope->Var(fusion_bias_node->Name())
...@@ -477,10 +508,10 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -477,10 +508,10 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
float* fusion_bias_ptr = float* fusion_bias_ptr =
fusion_bias_t->mutable_data<float>(paddle::platform::CPUPlace()); fusion_bias_t->mutable_data<float>(paddle::platform::CPUPlace());
// recompute bias and weights // recompute bias and weights
if (ew_bias_add == nullptr) { if (!with_conv_bias) { // prev node is conv
for (int i = 0; i < mean_len; ++i) { for (int i = 0; i < mean_len; ++i) {
bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon);
fusion_bias_ptr[i] += (0.f - bn_mean_ptr[i]) * bn_scale_ptr[i]; fusion_bias_ptr[i] += (0.0f - bn_mean_ptr[i]) * bn_scale_ptr[i];
for (int j = 0; j < filter_stride; j++) { for (int j = 0; j < filter_stride; j++) {
filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i];
} }
...@@ -488,21 +519,25 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph, ...@@ -488,21 +519,25 @@ int Conv2dXPUFusePass::ApplyImpl(ir::Graph* graph,
} else { } else {
for (int i = 0; i < mean_len; ++i) { for (int i = 0; i < mean_len; ++i) {
bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon); bn_scale_ptr[i] = bn_scale_ptr[i] / sqrtf(bn_var_ptr[i] + epsilon);
bn_bias_ptr[i] += fusion_bias_ptr[i] =
bn_bias_ptr[i] +
(fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i]; (fusion_bias_ptr[i] - bn_mean_ptr[i]) * bn_scale_ptr[i];
for (int j = 0; j < filter_stride; j++) { for (int j = 0; j < filter_stride; j++) {
filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i]; filter_ptr[i * filter_stride + j] *= bn_scale_ptr[i];
} }
} }
memcpy(fusion_bias_ptr, bn_bias_ptr, mean_len * sizeof(float));
} }
} }
} }
if (tensor_type == phi::DataType::FLOAT16) {
ConvertTensorType<float, float16>(filter_t);
}
// filter max // filter max
Node* filter_int16 = nullptr; Node* filter_int16 = nullptr;
Node* filter_max = nullptr; Node* filter_max = nullptr;
PrepareWeight<int16_t>( PrepareWeight<int16_t>(
graph, scope, block, conv_filter, &filter_int16, &filter_max, false); graph, scope, block, conv_filter, &filter_int16, &filter_max, false);
bool has_branch = with_branch_x || with_branch_y;
// output && output max // output && output max
std::string conv2d_xpu_out_name; std::string conv2d_xpu_out_name;
if (!act_type.empty()) { if (!act_type.empty()) {
......
...@@ -5,14 +5,14 @@ ...@@ -5,14 +5,14 @@
# otherwise the operator only could be used in static mode. # otherwise the operator only could be used in static mode.
- op : conv2d_xpu - op : conv2d_xpu
args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param) args : (Tensor x, Tensor x_max, Tensor filter, Tensor filter_max, Tensor bias, Tensor branch, Tensor branch_max, int[] paddings, int[] dilations, int[] strides, str padding_algorithm, int groups, bool has_bias, bool has_branch, int act_type, float act_param)
output : Tensor(out), Tensor(out_max) output : Tensor(out), Tensor(out_max)
infer_meta : infer_meta :
func : Conv2dXPUInferMeta func : Conv2dXPUInferMeta
kernel : kernel :
func : conv2d_xpu func : conv2d_xpu
data_type : x data_type : x
optional : bias, branch, x_max optional : bias, branch, branch_max ,x_max
- op : embedding_with_eltwise_add_xpu - op : embedding_with_eltwise_add_xpu
args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx) args : (Tensor[] ids, Tensor[] tables, int64_t padding_idx)
......
...@@ -58,7 +58,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -58,7 +58,8 @@ XPUOpMap& get_kl2_ops() {
{"atan_grad", {"atan_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"batch_norm_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"batch_norm", XPUKernelSet({phi::DataType::FLOAT32})}, {"batch_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"bmm", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"bmm_grad", {"bmm_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
...@@ -401,7 +402,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -401,7 +402,8 @@ XPUOpMap& get_kl2_ops() {
{"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"grid_sampler_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})}, {"grid_sampler", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"hard_sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"hard_sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, {"hard_sigmoid",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"hard_swish_grad", {"hard_swish_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})}, {"hard_swish", XPUKernelSet({phi::DataType::FLOAT32})},
...@@ -438,7 +440,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -438,7 +440,8 @@ XPUOpMap& get_kl2_ops() {
{"layer_norm", {"layer_norm",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"leaky_relu_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"leaky_relu", XPUKernelSet({phi::DataType::FLOAT32})}, {"leaky_relu",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"less_equal", {"less_equal",
XPUKernelSet({phi::DataType::INT64, XPUKernelSet({phi::DataType::INT64,
phi::DataType::INT32, phi::DataType::INT32,
...@@ -554,7 +557,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -554,7 +557,8 @@ XPUOpMap& get_kl2_ops() {
{"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_max_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_max", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_max", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_mean_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_mean", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_mean",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_min_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_min", XPUKernelSet({phi::DataType::FLOAT32})},
{"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})}, {"reduce_prod", XPUKernelSet({phi::DataType::FLOAT32})},
...@@ -646,7 +650,8 @@ XPUOpMap& get_kl2_ops() { ...@@ -646,7 +650,8 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT64, phi::DataType::INT64,
phi::DataType::INT32, phi::DataType::INT32,
phi::DataType::FLOAT16})}, phi::DataType::FLOAT16})},
{"sigmoid", XPUKernelSet({phi::DataType::FLOAT32})}, {"sigmoid",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"sigmoid_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"sign", XPUKernelSet({phi::DataType::FLOAT32})}, {"sign", XPUKernelSet({phi::DataType::FLOAT32})},
{"slice_grad", {"slice_grad",
...@@ -676,7 +681,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -676,7 +681,7 @@ XPUOpMap& get_kl2_ops() {
XPUKernelSet({phi::DataType::FLOAT32, XPUKernelSet({phi::DataType::FLOAT32,
phi::DataType::FLOAT16, phi::DataType::FLOAT16,
phi::DataType::INT32})}, phi::DataType::INT32})},
{"sqrt", XPUKernelSet({phi::DataType::FLOAT32})}, {"sqrt", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"sqrt_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"square_grad", {"square_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
...@@ -733,7 +738,7 @@ XPUOpMap& get_kl2_ops() { ...@@ -733,7 +738,7 @@ XPUOpMap& get_kl2_ops() {
phi::DataType::INT16, phi::DataType::INT16,
phi::DataType::INT32})}, phi::DataType::INT32})},
{"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, {"sum", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish", XPUKernelSet({phi::DataType::FLOAT32})}, {"swish", XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
{"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})}, {"swish_grad", XPUKernelSet({phi::DataType::FLOAT32})},
{"tanh_grad", {"tanh_grad",
XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})}, XPUKernelSet({phi::DataType::FLOAT32, phi::DataType::FLOAT16})},
......
...@@ -41,6 +41,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x, ...@@ -41,6 +41,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& filter_max, const MetaTensor& filter_max,
const MetaTensor& bias, const MetaTensor& bias,
const MetaTensor& branch, const MetaTensor& branch,
const MetaTensor& branch_max,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
......
...@@ -28,6 +28,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x, ...@@ -28,6 +28,7 @@ void Conv2dXPUInferMeta(const MetaTensor& x,
const MetaTensor& filter_max, const MetaTensor& filter_max,
const MetaTensor& bias, const MetaTensor& bias,
const MetaTensor& branch, const MetaTensor& branch,
const MetaTensor& branch_max,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
......
...@@ -63,7 +63,8 @@ PD_REGISTER_KERNEL(swish, ...@@ -63,7 +63,8 @@ PD_REGISTER_KERNEL(swish,
#if defined PADDLE_WITH_XPU #if defined PADDLE_WITH_XPU
PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {} PD_REGISTER_KERNEL(relu6, XPU, ALL_LAYOUT, phi::Relu6Kernel, float) {}
PD_REGISTER_KERNEL(swish, XPU, ALL_LAYOUT, phi::SwishKernel, float) {} PD_REGISTER_KERNEL(
swish, XPU, ALL_LAYOUT, phi::SwishKernel, float, phi::dtype::float16) {}
#endif #endif
#ifdef PADDLE_WITH_MKLDNN #ifdef PADDLE_WITH_MKLDNN
......
...@@ -106,6 +106,10 @@ PD_REGISTER_KERNEL(batch_norm_infer, ...@@ -106,6 +106,10 @@ PD_REGISTER_KERNEL(batch_norm_infer,
phi::dtype::float16) {} phi::dtype::float16) {}
#endif #endif
#ifdef PADDLE_WITH_XPU #ifdef PADDLE_WITH_XPU
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(batch_norm_infer,
batch_norm_infer, XPU, ALL_LAYOUT, phi::BatchNormInferKernel, float) {} XPU,
ALL_LAYOUT,
phi::BatchNormInferKernel,
float,
phi::dtype::float16) {}
#endif #endif
...@@ -27,6 +27,7 @@ void Conv2dXPUKernel(const Context& ctx, ...@@ -27,6 +27,7 @@ void Conv2dXPUKernel(const Context& ctx,
const DenseTensor& filter_max, const DenseTensor& filter_max,
const paddle::optional<DenseTensor>& bias, const paddle::optional<DenseTensor>& bias,
const paddle::optional<DenseTensor>& branch, const paddle::optional<DenseTensor>& branch,
const paddle::optional<DenseTensor>& branch_max,
const std::vector<int>& paddings, const std::vector<int>& paddings,
const std::vector<int>& dilations, const std::vector<int>& dilations,
const std::vector<int>& strides, const std::vector<int>& strides,
...@@ -69,10 +70,12 @@ void Conv2dXPUKernel(const Context& ctx, ...@@ -69,10 +70,12 @@ void Conv2dXPUKernel(const Context& ctx,
branch.get_ptr() == nullptr branch.get_ptr() == nullptr
? nullptr ? nullptr
: reinterpret_cast<const XPUType*>(branch.get_ptr()->data<T>()); : reinterpret_cast<const XPUType*>(branch.get_ptr()->data<T>());
const float* branch_max_data = branch_max.get_ptr() == nullptr
? nullptr
: branch_max.get_ptr()->data<float>();
const float* bias_data = const float* bias_data =
bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>(); bias.get_ptr() == nullptr ? nullptr : bias.get_ptr()->data<float>();
auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out)); auto* out_data = reinterpret_cast<XPUType*>(ctx.template Alloc<T>(out));
xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type)); xpu::Activation_t act(static_cast<xpu::Activation_t::act_enum>(act_type));
if (act_type == xpu::Activation_t::LEAKY_RELU) { if (act_type == xpu::Activation_t::LEAKY_RELU) {
act.leaky_alpha = act_param; act.leaky_alpha = act_param;
...@@ -102,7 +105,7 @@ void Conv2dXPUKernel(const Context& ctx, ...@@ -102,7 +105,7 @@ void Conv2dXPUKernel(const Context& ctx,
/* const float* bias */ bias_data, /* const float* bias */ bias_data,
/* const TY* branch */ branch_data, /* const TY* branch */ branch_data,
/* const baidu::xpu::api::Activation_t& act */ act, /* const baidu::xpu::api::Activation_t& act */ act,
/* const float* branch_maxptr */ nullptr, /* const float* branch_maxptr */ branch_max_data,
/* const float* scale */ nullptr); /* const float* scale */ nullptr);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu"); PADDLE_ENFORCE_XDNN_SUCCESS(r, "conv2d_xpu");
} }
......
...@@ -195,6 +195,13 @@ void PowKernel(const Context& dev_ctx, ...@@ -195,6 +195,13 @@ void PowKernel(const Context& dev_ctx,
const DenseTensor& x, const DenseTensor& x,
const Scalar& factor, const Scalar& factor,
DenseTensor* out) { DenseTensor* out) {
// using XPUType = typename XPUTypeTrait<T>::Type;
// // dev_ctx.template Alloc<T>(out);
// auto pow_factor = factor.to<T>();
// const auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
// auto* y_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(out));
// // const T* x_data = x.data<T>();
// // T* y_data = out->data<T>();
dev_ctx.template Alloc<T>(out); dev_ctx.template Alloc<T>(out);
float pow_factor = factor.to<float>(); float pow_factor = factor.to<float>();
const T* x_data = x.data<T>(); const T* x_data = x.data<T>();
...@@ -534,9 +541,28 @@ PD_REGISTER_KERNEL( ...@@ -534,9 +541,28 @@ PD_REGISTER_KERNEL(
relu, XPU, ALL_LAYOUT, phi::ReluKernel, float, phi::dtype::float16) {} relu, XPU, ALL_LAYOUT, phi::ReluKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
silu, XPU, ALL_LAYOUT, phi::SiluKernel, float, phi::dtype::float16) {} silu, XPU, ALL_LAYOUT, phi::SiluKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \ sigmoid, XPU, ALL_LAYOUT, phi::SigmoidKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {} PD_REGISTER_KERNEL(swish_raw,
XPU,
ALL_LAYOUT,
phi::SwishRawKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(hard_sigmoid,
XPU,
ALL_LAYOUT,
phi::HardSigmoidKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(leaky_relu,
XPU,
ALL_LAYOUT,
phi::LeakyReluKernel,
float,
phi::dtype::float16) {}
PD_REGISTER_KERNEL(
sqrt, XPU, ALL_LAYOUT, phi::SqrtKernel, float, phi::dtype::float16) {}
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {} tanh, XPU, ALL_LAYOUT, phi::TanhKernel, float, phi::dtype::float16) {}
...@@ -547,18 +573,21 @@ PD_REGISTER_KERNEL( ...@@ -547,18 +573,21 @@ PD_REGISTER_KERNEL(
PD_REGISTER_KERNEL( PD_REGISTER_KERNEL(
log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {} log, XPU, ALL_LAYOUT, phi::LogKernel, float, phi::dtype::float16) {}
#define PD_REGISTER_ACTIVATION_KERNEL(name, func) \
PD_REGISTER_KERNEL(name, XPU, ALL_LAYOUT, phi::func, float) {}
PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad PD_REGISTER_ACTIVATION_KERNEL(exp, ExpKernel) // no grad
PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel) PD_REGISTER_ACTIVATION_KERNEL(floor, FloorKernel)
PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel) // PD_REGISTER_ACTIVATION_KERNEL(leaky_relu, LeakyReluKernel)
PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel) // PD_REGISTER_ACTIVATION_KERNEL(hard_sigmoid, HardSigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel) PD_REGISTER_ACTIVATION_KERNEL(hardswish, HardSwishKernel)
PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel) PD_REGISTER_ACTIVATION_KERNEL(mish, MishKernel)
PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel) PD_REGISTER_ACTIVATION_KERNEL(pow, PowKernel)
PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel) PD_REGISTER_ACTIVATION_KERNEL(reciprocal, ReciprocalKernel)
PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel) PD_REGISTER_ACTIVATION_KERNEL(relu6_raw, Relu6RawKernel)
PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel) // PD_REGISTER_ACTIVATION_KERNEL(sigmoid, SigmoidKernel)
PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel) // PD_REGISTER_ACTIVATION_KERNEL(sqrt, SqrtKernel)
PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel) // PD_REGISTER_ACTIVATION_KERNEL(swish_raw, SwishRawKernel)
PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel) PD_REGISTER_ACTIVATION_KERNEL(softplus, SoftplusKernel)
PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel) PD_REGISTER_ACTIVATION_KERNEL(sin, SinKernel)
PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel) PD_REGISTER_ACTIVATION_KERNEL(cos, CosKernel)
...@@ -39,6 +39,7 @@ void BatchNormKernel(const Context& dev_ctx, ...@@ -39,6 +39,7 @@ void BatchNormKernel(const Context& dev_ctx,
DenseTensor* saved_mean, DenseTensor* saved_mean,
DenseTensor* saved_variance, DenseTensor* saved_variance,
DenseTensor* reserve_space) { DenseTensor* reserve_space) {
using XPUType = typename XPUTypeTrait<T>::Type;
bool test_mode = is_test && (!trainable_statistics); bool test_mode = is_test && (!trainable_statistics);
bool global_stats = test_mode || use_global_stats; bool global_stats = test_mode || use_global_stats;
const auto data_layout = phi::StringToDataLayout(data_layout_str); const auto data_layout = phi::StringToDataLayout(data_layout_str);
...@@ -68,12 +69,12 @@ void BatchNormKernel(const Context& dev_ctx, ...@@ -68,12 +69,12 @@ void BatchNormKernel(const Context& dev_ctx,
W = W * D; W = W * D;
const auto* x_data = x.data<T>(); const auto* x_data = reinterpret_cast<const XPUType*>(x.data<T>());
const auto* scale_data = scale.data<float>(); const auto* scale_data = scale.data<float>();
const auto* bias_data = bias.data<float>(); const auto* bias_data = bias.data<float>();
// alloc memory // alloc memory
auto* y_data = dev_ctx.template Alloc<T>(y); auto* y_data = reinterpret_cast<XPUType*>(dev_ctx.template Alloc<T>(y));
dev_ctx.template Alloc<float>(mean_out); dev_ctx.template Alloc<float>(mean_out);
dev_ctx.template Alloc<float>(variance_out); dev_ctx.template Alloc<float>(variance_out);
dev_ctx.template Alloc<float>(saved_mean); dev_ctx.template Alloc<float>(saved_mean);
...@@ -95,7 +96,7 @@ void BatchNormKernel(const Context& dev_ctx, ...@@ -95,7 +96,7 @@ void BatchNormKernel(const Context& dev_ctx,
auto* saved_mean_data = saved_mean->data<float>(); auto* saved_mean_data = saved_mean->data<float>();
auto* saved_variance_data = saved_variance->data<float>(); auto* saved_variance_data = saved_variance->data<float>();
int r = xpu::batch_norm<T>(dev_ctx.x_context(), int r = xpu::batch_norm<XPUType>(dev_ctx.x_context(),
x_data, x_data,
y_data, y_data,
N, N,
...@@ -115,7 +116,7 @@ void BatchNormKernel(const Context& dev_ctx, ...@@ -115,7 +116,7 @@ void BatchNormKernel(const Context& dev_ctx,
} else { } else {
const auto* mean_data = mean.data<float>(); const auto* mean_data = mean.data<float>();
const auto* variance_data = variance.data<float>(); const auto* variance_data = variance.data<float>();
int r = xpu::batch_norm_infer(dev_ctx.x_context(), int r = xpu::batch_norm_infer<XPUType>(dev_ctx.x_context(),
x_data, x_data,
y_data, y_data,
N, N,
...@@ -134,4 +135,9 @@ void BatchNormKernel(const Context& dev_ctx, ...@@ -134,4 +135,9 @@ void BatchNormKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(batch_norm, XPU, ALL_LAYOUT, phi::BatchNormKernel, float) {} PD_REGISTER_KERNEL(batch_norm,
XPU,
ALL_LAYOUT,
phi::BatchNormKernel,
float,
phi::dtype::float16) {}
...@@ -50,4 +50,6 @@ void MeanRawKernel(const Context& dev_ctx, ...@@ -50,4 +50,6 @@ void MeanRawKernel(const Context& dev_ctx,
} // namespace phi } // namespace phi
PD_REGISTER_KERNEL(mean_raw, XPU, ALL_LAYOUT, phi::MeanRawKernel, float) {} PD_REGISTER_KERNEL(
mean_raw, XPU, ALL_LAYOUT, phi::MeanRawKernel, float, phi::dtype::float16) {
}
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册