Unverified commit f4abe34b, authored by Yiqun Liu, committed by GitHub

Try to increase the repeat of autotune and fix the setting of allow_tf32_cublas. (#53622)

* Try to increase the repeat of autotune and fix the setting of allow_tf32_cublas.

* Change the repeat of cublaslt to 10.

* Use FLAGS_cublaslt_exhaustive_search_times as repeats.

* Fix compiling error on CI.

* Polish the key and simplify codes.
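
For context, `FLAGS_cublaslt_exhaustive_search_times` is read as the per-algorithm measurement repeat count, and non-positive values disable the exhaustive search entirely (see the hunks below). A minimal standalone sketch of that contract, assuming `PHI_DECLARE_int64` expands to roughly an `extern` declaration of an `int64_t` flag variable (the variable here is a hypothetical stand-in, not the real definition site):

```cpp
#include <cstdint>

// Hypothetical stand-in for the flag that PHI_DECLARE_int64 would declare.
int64_t FLAGS_cublaslt_exhaustive_search_times = 10;

// Mirrors how the diff consumes the flag: > 0 means "time each candidate
// algorithm this many times (first run is warmup)"; <= 0 means "skip the
// search and trust cublasLt's heuristic pick".
inline int SearchRepeats() {
  int64_t repeats = FLAGS_cublaslt_exhaustive_search_times;
  return repeats > 0 ? static_cast<int>(repeats) : 0;
}
```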
Parent f488e3fd
@@ -52,8 +52,8 @@ phi::funcs::MatmulFusedType GetFwdFusedEpilogueType(
     }
   } else {
     PADDLE_THROW(platform::errors::InvalidArgument(
-        "Fued linear epilogue type should be one of {none, relu, gelu}."
-        "But received activation is %s, please check",
+        "fused_gemm_epilogue's activate should be one of {none, relu, gelu},"
+        " but received %s, please check",
         activation));
   }
 }
......
@@ -25,10 +25,10 @@ limitations under the License. */
 #include "glog/logging.h"
 #include "paddle/phi/api/ext/exception.h"
+#include "paddle/phi/backends/context_pool.h"
 #include "paddle/phi/backends/gpu/gpu_decls.h"
 #include "paddle/phi/backends/gpu/gpu_info.h"
 #include "paddle/phi/backends/gpu/gpu_resources.h"
+#include "paddle/phi/common/float16.h"
 #include "paddle/phi/common/place.h"
 #include "paddle/phi/core/allocator.h"
 #include "paddle/phi/core/cuda_stream.h"
@@ -601,7 +601,7 @@ struct GPUContext::Impl {
 #endif
 #endif
     });
-    if (blas_tf32_tensor_core_handle_ != nullptr) {
+    if (blas_tf32_tensor_core_handle_ && phi::AllowTF32Cublas()) {
       std::lock_guard<std::mutex> guard(blas_tf32_mtx_);
       callback(blas_tf32_tensor_core_handle_);
     } else {
......
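
The gpu_context.cc change above makes the TF32 cuBLAS handle conditional on the runtime switch, not merely on the handle existing. A minimal sketch of the dispatch pattern, with hypothetical `Handle`/`Callback` types standing in for the real cuBLAS handle and callback, and `allow_tf32` standing in for `phi::AllowTF32Cublas()`:

```cpp
#include <mutex>

template <typename Handle, typename Callback>
void TF32CallSketch(Handle* tf32_handle,
                    Handle* fp32_handle,
                    bool allow_tf32,
                    std::mutex& tf32_mtx,
                    const Callback& callback) {
  if (tf32_handle != nullptr && allow_tf32) {
    // TF32 tensor-core path: taken only when the handle exists AND the
    // runtime switch is on, which is the fix in the hunk above.
    std::lock_guard<std::mutex> guard(tf32_mtx);
    callback(tf32_handle);
  } else {
    callback(fp32_handle);  // plain FP32 path
  }
}
```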
@@ -29,7 +29,7 @@ class KernelCallback {
   using FuncType = ReturnType (*)(Args...);
 
   KernelCallback() {}
-  explicit KernelCallback(FuncType func_) : func(func_) {}
+  explicit KernelCallback(FuncType f) : func(f) {}
   virtual ~KernelCallback() {}
 
   ReturnType Run(Args... args) { return func(args...); }
@@ -50,8 +50,8 @@ class AutoTuneBase {
   AutoTuneBase() {}
   virtual ~AutoTuneBase() {}
 
-  explicit AutoTuneBase(KernelType kernel) {
-    kernels_.push_back(/*default=*/kernel);
+  explicit AutoTuneBase(KernelType default_kernel) {
+    kernels_.push_back(default_kernel);
   }
 
   template <typename ReturnType, typename... Args>
@@ -121,7 +121,7 @@ class AutoTuneBase {
   float RunAndMeasureKernel(const Context& ctx, const int idx, Args&&... args) {
     // Regard 1st run as warmup, judge the compare result by the time cost
     // of rest cycles.
-    constexpr int repeats = 6;
+    constexpr int repeats = 11;
     phi::GpuTimer timer;
     float time_cost = 0;
     const auto& stream = ctx.stream();
......
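
Raising `repeats` from 6 to 11 leaves 10 measured cycles once the warmup run is discarded. A sketch of the averaging scheme this loop implements (the callable is a hypothetical stand-in for one timed kernel launch):

```cpp
// Returns the mean runtime over (repeats - 1) cycles, discarding the first
// run as warmup, as RunAndMeasureKernel does with repeats = 11.
template <typename RunOnce>
float MeasureSketch(int repeats, RunOnce&& run_once) {
  float total = 0.f;
  for (int i = 0; i < repeats; ++i) {
    float t = run_once();  // time of one launch, in seconds
    if (i > 0) {
      total += t;  // skip i == 0, the warmup iteration
    }
  }
  return total / (repeats - 1);
}
```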
@@ -25,8 +25,11 @@ limitations under the License. */
 #include "paddle/phi/common/amp_type_traits.h"
 #include "paddle/phi/common/memory_utils.h"
+#include "paddle/phi/core/flags.h"
 #include "paddle/phi/kernels/autotune/gpu_timer.h"
 #include "paddle/phi/kernels/autotune/switch_autotune.h"
+
+PHI_DECLARE_int64(cublaslt_exhaustive_search_times);
 #endif
 
 namespace phi {
@@ -41,21 +44,42 @@ namespace funcs {
 // no matter forward or backward, they could share the same descriptor
 // cache, in that the descriptor is for description of matmul operation.
 enum MatmulFusedType {
-  kMatmul = CUBLASLT_EPILOGUE_DEFAULT,
-  kMatmulGrad = CUBLASLT_EPILOGUE_DEFAULT,
-  kMatmulGradWithoutBias = CUBLASLT_EPILOGUE_DEFAULT,
-  kMatmulBias = CUBLASLT_EPILOGUE_BIAS,
-  kMatmulRelu = CUBLASLT_EPILOGUE_RELU,
-  kMatmulBiasRelu = CUBLASLT_EPILOGUE_RELU_BIAS,
-  kMatmulBiasGelu = CUBLASLT_EPILOGUE_GELU_BIAS,
-  kMatmulBiasReluWithReservedData = CUBLASLT_EPILOGUE_RELU_AUX_BIAS,
-  kMatmulBiasGeluWithReservedData = CUBLASLT_EPILOGUE_GELU_AUX_BIAS,
-  kMatmulReluGrad = CUBLASLT_EPILOGUE_DRELU,
-  kMatmulGeluGrad = CUBLASLT_EPILOGUE_DGELU,
-  kMatmulBiasGradToA = CUBLASLT_EPILOGUE_BGRADA,
-  kMatmulBiasGradToB = CUBLASLT_EPILOGUE_BGRADB
+  kMatmul = 0,
+  kMatmulGrad = 1,
+  kMatmulGradWithoutBias = 2,
+  kMatmulBias = 3,
+  kMatmulRelu = 4,
+  kMatmulBiasRelu = 5,
+  kMatmulBiasGelu = 6,
+  kMatmulBiasReluWithReservedData = 7,
+  kMatmulBiasGeluWithReservedData = 8,
+  kMatmulReluGrad = 9,
+  kMatmulGeluGrad = 10,
+  kMatmulBiasGradToA = 11,
+  kMatmulBiasGradToB = 12
 };
 
+static cublasLtEpilogue_t ConvertFusedType(MatmulFusedType fused_type) {
+  static std::map<MatmulFusedType, cublasLtEpilogue_t> fused_type_map = {
+      {MatmulFusedType::kMatmul, CUBLASLT_EPILOGUE_DEFAULT},
+      {MatmulFusedType::kMatmulGrad, CUBLASLT_EPILOGUE_DEFAULT},
+      {MatmulFusedType::kMatmulGradWithoutBias, CUBLASLT_EPILOGUE_DEFAULT},
+      {MatmulFusedType::kMatmulBias, CUBLASLT_EPILOGUE_BIAS},
+      {MatmulFusedType::kMatmulRelu, CUBLASLT_EPILOGUE_RELU},
+      {MatmulFusedType::kMatmulBiasRelu, CUBLASLT_EPILOGUE_RELU_BIAS},
+      {MatmulFusedType::kMatmulBiasGelu, CUBLASLT_EPILOGUE_GELU_BIAS},
+      {MatmulFusedType::kMatmulBiasReluWithReservedData,
+       CUBLASLT_EPILOGUE_RELU_AUX_BIAS},
+      {MatmulFusedType::kMatmulBiasGeluWithReservedData,
+       CUBLASLT_EPILOGUE_GELU_AUX_BIAS},
+      {MatmulFusedType::kMatmulReluGrad, CUBLASLT_EPILOGUE_DRELU},
+      {MatmulFusedType::kMatmulGeluGrad, CUBLASLT_EPILOGUE_DGELU},
+      {MatmulFusedType::kMatmulBiasGradToA, CUBLASLT_EPILOGUE_BGRADA},
+      {MatmulFusedType::kMatmulBiasGradToB, CUBLASLT_EPILOGUE_BGRADB}};
+  return fused_type_map[fused_type];
+}
+
 enum FusedGEMMGradInType { kDX = 0, kDY = 1, kDZ = 2 };
 
 template <bool TransX, bool TransY>
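
The reason for decoupling `MatmulFusedType` from the cublasLt constants: several members (`kMatmul`, `kMatmulGrad`, `kMatmulGradWithoutBias`) all aliased `CUBLASLT_EPILOGUE_DEFAULT`, so casting the enum to `int` for a cache key could not tell them apart. A tiny self-contained illustration of the collision (values are illustrative; `CUBLASLT_EPILOGUE_DEFAULT` is assumed here to equal 1):

```cpp
#include <iostream>

// Old scheme (sketch): distinct names, identical values -> identical keys.
enum OldFusedType { kOldMatmul = 1, kOldMatmulGrad = 1 };
// New scheme: dense unique ordinals; ConvertFusedType restores the cublasLt
// epilogue only at descriptor-setup time.
enum NewFusedType { kNewMatmul = 0, kNewMatmulGrad = 1 };

int main() {
  std::cout << (kOldMatmul == kOldMatmulGrad) << "\n";  // 1: key collision
  std::cout << (kNewMatmul == kNewMatmulGrad) << "\n";  // 0: distinct keys
  return 0;
}
```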
@@ -125,31 +149,31 @@ struct MatmulPlanner {
                 const bool trans_x,
                 const bool trans_y,
                 phi::DataType dtype,
-                MatmulFusedType impl_type,
+                MatmulFusedType fused_type,
                 const void* bias_data = nullptr,
                 void* reserve_data = nullptr,  // Commonly for ReLu bit-mask.
                 bool use_addto = false,
                 bool no_exchange = true)
-      : bias(bias_data), aux_data(reserve_data), impl_type_(impl_type) {
+      : bias(bias_data), aux_data(reserve_data), fused_type_(fused_type) {
     use_addto_ = use_addto;
     key_ = phi::autotune::GenKey(x_dims,
                                  y_dims,
                                  static_cast<int>(trans_x),
                                  static_cast<int>(trans_y),
                                  static_cast<int>(dtype),
+                                 static_cast<int>(fused_type_),
+                                 static_cast<int>(use_addto_),
                                  static_cast<int>(no_exchange));
   }
 
   bool UseAddTo() const { return use_addto_; }
   size_t GetKey() const { return key_; }
-  MatmulFusedType ImplType() const { return impl_type_; }
-  size_t GenSubKey(int idx) const {
-    return phi::autotune::GenKey(key_, static_cast<int>(use_addto_), idx);
-  }
+  MatmulFusedType GetFusedType() const { return fused_type_; }
+  size_t GenSubKey() const { return key_; }
 
  private:
-  MatmulFusedType impl_type_;
+  MatmulFusedType fused_type_;
   bool use_addto_;
   size_t key_;
 };
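
With `fused_type_` and `use_addto_` folded into the main key, `GenSubKey()` collapses to returning `key_`. The implementation of `phi::autotune::GenKey` is not part of this diff; the sketch below shows the usual variadic hash-combine pattern such a helper follows, which is why every extra attribute changes the final key (`GenKeySketch` is a hypothetical stand-in, not Paddle's actual code):

```cpp
#include <cstddef>
#include <functional>

// Boost-style hash combine.
template <typename T>
void HashCombine(std::size_t* seed, const T& v) {
  *seed ^= std::hash<T>{}(v) + 0x9e3779b9 + (*seed << 6) + (*seed >> 2);
}

template <typename... Args>
std::size_t GenKeySketch(const Args&... args) {
  std::size_t seed = 0;
  (HashCombine(&seed, args), ...);  // fold every attribute into one key
  return seed;
}

// Usage sketch: GenKeySketch(m, n, static_cast<int>(fused_type),
//                            static_cast<int>(use_addto));
```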
@@ -265,23 +289,28 @@ struct MatmulDescriptor {
                                  bool has_algo = true) const {
     std::ostringstream out;
     out << prefix << " \n";
-#define GET_DESC_DATA_INFO(src)                      \
+#define GET_DESC_DATA_STRING(src)                    \
   do {                                               \
-    out << #src << "= [";                            \
+    out << " " << #src << " = [";                    \
     int num = sizeof((*src)) / sizeof(src->data[0]); \
     for (int i = 0; i < num; ++i) {                  \
-      out << src->data[i] << ", ";                   \
+      if (i == 0) {                                  \
+        out << src->data[i];                         \
+      } else {                                       \
+        out << ", " << src->data[i];                 \
+      }                                              \
     }                                                \
     out << "]\n";                                    \
   } while (0);
 
     if (has_algo) {
-      GET_DESC_DATA_INFO(&algo);
+      GET_DESC_DATA_STRING(algo);
     }
-    GET_DESC_DATA_INFO(x_desc);
-    GET_DESC_DATA_INFO(y_desc);
-    GET_DESC_DATA_INFO(out_desc);
-    GET_DESC_DATA_INFO(op_desc);
+    GET_DESC_DATA_STRING(x_desc);
+    GET_DESC_DATA_STRING(y_desc);
+    GET_DESC_DATA_STRING(out_desc);
+    GET_DESC_DATA_STRING(op_desc);
+#undef GET_DESC_DATA_STRING
     return out.str();
   }
@@ -304,12 +333,13 @@ struct MatmulDescriptor {
                                                 CUBLASLT_MATMUL_DESC_TRANSA,
                                                 &cublas_trans_y,
                                                 sizeof(cublas_trans_y)));
-    if (planner->ImplType() != kMatmul) {
-      auto fused_type = static_cast<cublasLtEpilogue_t>(planner->ImplType());
+    MatmulFusedType fused_type = planner->GetFusedType();
+    if (fused_type != MatmulFusedType::kMatmul) {
+      cublasLtEpilogue_t cublaslt_fused_type = ConvertFusedType(fused_type);
       PADDLE_ENFORCE_GPU_SUCCESS(
           dynload::cublasLtMatmulDescSetAttribute(op_desc,
                                                   CUBLASLT_MATMUL_DESC_EPILOGUE,
-                                                  &fused_type,
+                                                  &cublaslt_fused_type,
                                                   sizeof(fused_type)));
     }
     if (planner->aux_data) {
@@ -452,7 +482,7 @@ struct CublasLtBase {
       }
     }
 
-    VLOG(6) << desc->GetDescResultString("[Impl CublasltDescriptor] ");
+    VLOG(7) << desc->GetDescResultString("[Impl CublasltDescriptor] ");
     PADDLE_ENFORCE_GPU_SUCCESS(
         dynload::cublasLtMatmul(cublaslt_handle,
                                 desc->op_desc,
@@ -482,10 +512,6 @@ struct CublasLtBase {
                                 void* out_data,
                                 void* workspace_ptr,
                                 size_t workspace_size) {
-    cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo();
-    const auto& stream = ctx.stream();
-    int returned_results = 0;
-    constexpr int requested_algo_count = 10;
     cublasLtMatmulPreference_t preference;
     PADDLE_ENFORCE_GPU_SUCCESS(
         dynload::cublasLtMatmulPreferenceCreate(&preference));
@@ -494,6 +520,9 @@ struct CublasLtBase {
             CUBLASLT_MATMUL_PREF_MAX_WORKSPACE_BYTES,
             &workspace_size,
             sizeof(workspace_size)));
+
+    int returned_results = 0;
+    constexpr int requested_algo_count = 10;
     std::vector<cublasLtMatmulHeuristicResult_t> heuristic_results(
         requested_algo_count);
     PADDLE_ENFORCE_GPU_SUCCESS(
@@ -510,52 +539,90 @@ struct CublasLtBase {
     PADDLE_ENFORCE_GT(returned_results,
                       0,
                       phi::errors::Unavailable("No GEMM algorithm avaliable."));
-    phi::GpuTimer timer;
     int best_algo_idx = -1;
-    constexpr int repeats = 6;
-    float min_time_cost = std::numeric_limits<float>::max();
-    for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
-      ctx.Wait();
-      float cur_time = 0.f;
-      for (int i = 0; i < repeats; ++i) {
-        timer.Start(stream);
-        PADDLE_ENFORCE_GPU_SUCCESS(
-            dynload::cublasLtMatmul(lt_handle,
-                                    desc->op_desc,
-                                    alpha,
-                                    y_data,
-                                    desc->y_desc,
-                                    x_data,
-                                    desc->x_desc,
-                                    beta,
-                                    out_data,
-                                    desc->out_desc,
-                                    out_data,
-                                    desc->out_desc,
-                                    &(heuristic_results[algo_idx].algo),
-                                    workspace_ptr,
-                                    workspace_size,
-                                    stream));
-        timer.Stop(stream);
-        auto time = timer.ElapsedTime();
-        if (i > 0) {
-          cur_time += time;
-        }
-      }
-      float time_cnt = (cur_time / (repeats - 1));
-      VLOG(6) << "Time cost in MatmulWithCublaslt algo[" << algo_idx << "]"
-              << "is : " << time_cnt << " s";
-      if (cur_time < min_time_cost) {
-        best_algo_idx = algo_idx;
-        min_time_cost = cur_time;
+    if (returned_results == 1 || FLAGS_cublaslt_exhaustive_search_times <= 0) {
+      best_algo_idx = 0;
+    } else {
+      float min_time_cost = std::numeric_limits<float>::max();
+      for (int algo_idx = 0; algo_idx < returned_results; ++algo_idx) {
+        float cur_time_cost =
+            RunAndMeasureAlgo(ctx,
+                              lt_handle,
+                              desc,
+                              alpha,
+                              beta,
+                              y_data,
+                              x_data,
+                              out_data,
+                              workspace_ptr,
+                              workspace_size,
+                              &(heuristic_results[algo_idx].algo));
+        VLOG(6) << "[MatmulWithCublaslt] algo[" << algo_idx
+                << "] time: " << cur_time_cost << " s";
+
+        if ((best_algo_idx == 0 && (1.05 * cur_time_cost < min_time_cost)) ||
+            (cur_time_cost < min_time_cost)) {
+          best_algo_idx = algo_idx;
+          min_time_cost = cur_time_cost;
+        }
       }
     }
-    VLOG(6) << "Best_algo_idx in MatmulWithCublaslt is : " << best_algo_idx;
+    VLOG(6) << "[MatmulWithCublaslt] best_algo_idx: " << best_algo_idx;
+    cublasLtMatmulAlgo_t* best_algo = desc->SetAlgo();
     *best_algo = heuristic_results[best_algo_idx].algo;
     PADDLE_ENFORCE_GPU_SUCCESS(
         dynload::cublasLtMatmulPreferenceDestroy(preference));
   }
+
+  static float RunAndMeasureAlgo(const phi::GPUContext& ctx,
+                                 const cublasLtHandle_t& lt_handle,
+                                 MatmulDescT* desc,
+                                 const void* alpha,
+                                 const void* beta,
+                                 const void* y_data,
+                                 const void* x_data,
+                                 void* out_data,
+                                 void* workspace_ptr,
+                                 size_t workspace_size,
+                                 cublasLtMatmulAlgo_t* algo) {
+    int repeats = FLAGS_cublaslt_exhaustive_search_times;
+    if (repeats <= 0) {
+      return std::numeric_limits<float>::max();
+    }
+
+    phi::GpuTimer timer;
+    float time_cost = 0.f;
+    const auto& stream = ctx.stream();
+    for (int i = 0; i < repeats; ++i) {
+      timer.Start(stream);
+      PADDLE_ENFORCE_GPU_SUCCESS(dynload::cublasLtMatmul(lt_handle,
+                                                         desc->op_desc,
+                                                         alpha,
+                                                         y_data,
+                                                         desc->y_desc,
+                                                         x_data,
+                                                         desc->x_desc,
+                                                         beta,
+                                                         out_data,
+                                                         desc->out_desc,
+                                                         out_data,
+                                                         desc->out_desc,
+                                                         algo,
+                                                         workspace_ptr,
+                                                         workspace_size,
+                                                         stream));
+      timer.Stop(stream);
+      ctx.Wait();
+      auto time = timer.ElapsedTime();
+      if (i > 0) {
+        // Exclude the warmup runtime.
+        time_cost += time;
+      }
+    }
+    return (time_cost / (repeats - 1));
+  }
 };
 
 // To judge if desc is cached or not.
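
Pulling the timing loop out into `RunAndMeasureAlgo` leaves the selection policy readable at the call site. A standalone sketch of that policy (the `measure` callable and the function name are hypothetical; the acceptance test is copied from the loop above):

```cpp
#include <cstdint>
#include <functional>
#include <limits>

int SelectBestAlgoSketch(int num_candidates,
                         int64_t search_times,
                         const std::function<float(int)>& measure) {
  // Candidate 0 is cublasLt's own top recommendation; skip the measured
  // search when the budget is non-positive or there is nothing to compare.
  if (num_candidates == 1 || search_times <= 0) {
    return 0;
  }
  int best = -1;
  float min_cost = std::numeric_limits<float>::max();
  for (int i = 0; i < num_candidates; ++i) {
    float cost = measure(i);  // warmup-excluded average, as in the diff
    if ((best == 0 && 1.05f * cost < min_cost) || cost < min_cost) {
      best = i;
      min_cost = cost;
    }
  }
  return best;
}
```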
@@ -583,14 +650,14 @@ struct DescriptorSetter {
                     const bool no_exchange = true,
                     bool grad_for_dx = true) {
     if (planner != nullptr) {
-      sub_key = planner->GenSubKey(static_cast<size_t>(planner->ImplType()));
+      sub_key = planner->GenSubKey();
     }
 
     auto& mamtul_cache = phi::autotune::AutoTuneCache::Instance().GetMatmul();
     if (mamtul_cache.FindSubKey(sub_key)) {
       desc = *(reinterpret_cast<DescT*>(mamtul_cache.GetSubKey(sub_key)));
       desc.template SetFusedEpiloguePtr<DYT>(planner);
-      VLOG(6) << desc.GetDescResultString("[Heap CublasltDescriptor] ");
+      VLOG(7) << desc.GetDescResultString("[Heap CublasltDescriptor] ");
     } else {
       desc.template Create<T, DXT, DYT, TransX, TransY>(M,
                                                         N,
@@ -607,7 +674,7 @@ struct DescriptorSetter {
       if (planner != nullptr) {
         desc.template SetFusedEpiloguePtr<DYT>(planner);
       }
-      VLOG(6) << desc.GetDescResultString("[Stack CublasltDescriptor] ", false);
+      VLOG(7) << desc.GetDescResultString("[Stack CublasltDescriptor] ", false);
     }
   }
 };
......
@@ -945,75 +945,37 @@ void ComputeFusedGemmEpilogueBackward(const phi::GPUContext& dev_ctx,
           << ", trans_y=" << trans_y
           << ", activation_grad=" << activation_grad;
 
+#define CALL_FUSED_GRAD_IMPL(TransX, TransY)                         \
+  ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, TransX, TransY>( \
+      dev_ctx,                                                       \
+      dout,                                                          \
+      x,                                                             \
+      y,                                                             \
+      reserve_space,                                                 \
+      M,                                                             \
+      N,                                                             \
+      K,                                                             \
+      activation_grad,                                               \
+      dx,                                                            \
+      dy,                                                            \
+      dbias,                                                         \
+      use_addto_dx,                                                  \
+      use_addto_dy)
+
   if (trans_x) {
     if (trans_y) {
-      ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, true, true>(
-          dev_ctx,
-          dout,
-          x,
-          y,
-          reserve_space,
-          M,
-          N,
-          K,
-          activation_grad,
-          dx,
-          dy,
-          dbias,
-          use_addto_dx,
-          use_addto_dy);
+      CALL_FUSED_GRAD_IMPL(true, true);
     } else {
-      ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, true, false>(
-          dev_ctx,
-          dout,
-          x,
-          y,
-          reserve_space,
-          M,
-          N,
-          K,
-          activation_grad,
-          dx,
-          dy,
-          dbias,
-          use_addto_dx,
-          use_addto_dy);
+      CALL_FUSED_GRAD_IMPL(true, false);
     }
   } else {
     if (trans_y) {
-      ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, false, true>(
-          dev_ctx,
-          dout,
-          x,
-          y,
-          reserve_space,
-          M,
-          N,
-          K,
-          activation_grad,
-          dx,
-          dy,
-          dbias,
-          use_addto_dx,
-          use_addto_dy);
+      CALL_FUSED_GRAD_IMPL(false, true);
    } else {
-      ComputeFusedGemmEpilogueBackwardImpl<T, DXT, DYT, false, false>(
-          dev_ctx,
-          dout,
-          x,
-          y,
-          reserve_space,
-          M,
-          N,
-          K,
-          activation_grad,
-          dx,
-          dy,
-          dbias,
-          use_addto_dx,
-          use_addto_dy);
+      CALL_FUSED_GRAD_IMPL(false, false);
     }
   }
+
+#undef CALL_FUSED_GRAD_IMPL
 }
 
 }  // namespace funcs
......
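
The new `CALL_FUSED_GRAD_IMPL` macro only removes the four duplicated argument lists; the underlying pattern is the usual runtime-bool to compile-time template-parameter dispatch. A minimal sketch with hypothetical names:

```cpp
#include <iostream>

template <bool TransX, bool TransY>
void GemmGradImplSketch(int m, int n, int k) {
  // Each instantiation is a separate function specialized at compile time.
  std::cout << "trans_x=" << TransX << " trans_y=" << TransY
            << " m=" << m << " n=" << n << " k=" << k << "\n";
}

// Same trick as CALL_FUSED_GRAD_IMPL: capture the shared argument list once
// so each of the four branches stays a single line.
#define CALL_GRAD_IMPL_SKETCH(TX, TY) GemmGradImplSketch<TX, TY>(m, n, k)

void DispatchSketch(bool trans_x, bool trans_y, int m, int n, int k) {
  if (trans_x) {
    if (trans_y) {
      CALL_GRAD_IMPL_SKETCH(true, true);
    } else {
      CALL_GRAD_IMPL_SKETCH(true, false);
    }
  } else {
    if (trans_y) {
      CALL_GRAD_IMPL_SKETCH(false, true);
    } else {
      CALL_GRAD_IMPL_SKETCH(false, false);
    }
  }
}
#undef CALL_GRAD_IMPL_SKETCH
```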
@@ -925,7 +925,11 @@ struct MatMulDispatcher<phi::GPUContext, T> {
                                         trans_x,
                                         trans_y,
                                         phi::CppTypeToDataType<T>::Type(),
-                                        funcs::MatmulFusedType::kMatmul);
+                                        funcs::MatmulFusedType::kMatmul,
+                                        /* bias_data */ nullptr,
+                                        /* reserve_data */ nullptr,
+                                        /* use_addto */ flag,
+                                        /* no_exchange */ true);
     tuner->Run(ctx,
                matmul_planner.GetKey(),
                ctx,
......