Commit 8d6d95cc authored by Adam, committed by Tao Luo

paddle::framework::vectorize() templatization (#19611)

test=develop
Parent 75d15719
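In short: the fixed-type std::vector<int64_t> vectorize(const DDim&) in ddim.cc becomes a function template in ddim.h, template <typename T = int64_t> std::vector<T> vectorize(const DDim& ddim), and call sites that used the int-returning helper vectorize2int(dims) switch to vectorize<int>(dims). Because of the default template argument, existing vectorize(dims) callers are unaffected, which is why only the vectorize2int call sites change below. A minimal, self-contained sketch of the resulting API follows; the stand-in DDim type and the example shape are illustrative assumptions, not Paddle's real class:

#include <cstdint>
#include <iostream>
#include <vector>

// Illustrative stand-in for paddle::framework::DDim - just enough surface
// (Get() and size()) to demonstrate the templatized vectorize(). A sketch,
// not Paddle's real class.
struct DDim {
  std::vector<int64_t> d;
  const int64_t* Get() const { return d.data(); }
  int size() const { return static_cast<int>(d.size()); }
};

// New form of the API: T defaults to int64_t, so plain vectorize(dims) keeps
// its old meaning, while MKL-DNN code writes vectorize<int>(dims) instead of
// the former vectorize2int(dims).
template <typename T = int64_t>
std::vector<T> vectorize(const DDim& ddim) {
  return std::vector<T>(ddim.Get(), ddim.Get() + ddim.size());
}

int main() {
  DDim x{{8, 3, 224, 224}};                 // e.g. an NCHW shape
  std::vector<int64_t> v64 = vectorize(x);  // default template argument
  auto src_tz = vectorize<int>(x);          // replaces vectorize2int(x)
  std::cout << v64.size() << ' ' << src_tz[1] << '\n';  // prints: 4 3
  return 0;
}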
@@ -147,8 +147,8 @@ void innerTransDataLayoutFromMKLDNN(DataLayout in_layout, DataLayout out_layout,
   auto* dev_ctx = dynamic_cast<platform::MKLDNNDeviceContext*>(pool.Get(place));
   auto& cpu_engine = dev_ctx->GetEngine();
-  std::vector<int> in_tz = paddle::framework::vectorize2int(in.dims());
-  std::vector<int> out_tz = in_tz;
+  auto in_tz = paddle::framework::vectorize<int>(in.dims());
+  auto out_tz = in_tz;
   memory::data_type in_type = ToMKLDNNDataType(in.type());
   PADDLE_ENFORCE(in_type != memory::data_type::data_undef,
...
@@ -48,13 +48,6 @@ bool DDim::operator==(const DDim& d) const {
 bool DDim::operator!=(const DDim& d) const { return !(*this == d); }
-std::vector<int64_t> vectorize(const DDim& ddim) {
-  std::vector<int64_t> result(DDim::kMaxRank);
-  dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
-  result.resize(ddim.size());
-  return result;
-}
 // NOTE: framework::vectorize converts to type int64_t
 // which does not fit cudnn inputs.
 std::vector<int> vectorize2int(const DDim& ddim) {
...
@@ -170,7 +170,13 @@ DDim make_ddim(const std::vector<int>& dims);
 */
 DDim make_ddim(std::initializer_list<int64_t> dims);
-std::vector<int64_t> vectorize(const DDim& ddim);
+template <typename T = int64_t>
+std::vector<T> vectorize(const DDim& ddim) {
+  std::vector<T> result(DDim::kMaxRank);
+  dynamic_dim_assign(ddim.Get(), result.data(), ddim.size());
+  result.resize(ddim.size());
+  return result;
+}
 std::vector<int> vectorize2int(const DDim& ddim);
 int64_t product(const DDim& ddim);
...
@@ -816,7 +816,7 @@ void CompileTimeInferShapeContext::SetRepeatedDims(
   auto var = block_.FindVarRecursive(name);
   PADDLE_ENFORCE(var != nullptr, "Cannot find variable %s", name);
   std::vector<std::vector<int64_t>> dim_vec(dims.size());
-  std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize);
+  std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize<>);
   var->SetShapes(dim_vec);
 }
...
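One subtlety in the hunk above: after templatization the bare name vectorize no longer denotes a single concrete function, so it cannot be passed to std::transform as a callable; the call site therefore writes vectorize<>, naming the specialization with the default argument T = int64_t. A minimal sketch of this, assuming a simplified signature that takes the dims directly rather than a real DDim:

#include <algorithm>
#include <cstdint>
#include <vector>

// Simplified stand-in signature: takes a vector of dims instead of a DDim.
template <typename T = int64_t>
std::vector<T> vectorize(const std::vector<int64_t>& dims) {
  return std::vector<T>(dims.begin(), dims.end());
}

int main() {
  std::vector<std::vector<int64_t>> dims = {{2, 3}, {4, 5, 6}};
  std::vector<std::vector<int64_t>> dim_vec(dims.size());
  // `vectorize` alone is a template name; `vectorize<>` is the specialization
  // with T = int64_t, which decays to a function pointer that std::transform
  // accepts as its unary operation.
  std::transform(dims.begin(), dims.end(), dim_vec.begin(), vectorize<>);
  return 0;
}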
@@ -97,7 +97,7 @@ void eltwise_forward(const framework::ExecutionContext &ctx,
       x->dims().size() == 2 || x->dims().size() == 3 || x->dims().size() == 4,
       "Input dim must be with 2, 3 or 4");
-  std::vector<int> src_tz = framework::vectorize2int(x->dims());
+  auto src_tz = framework::vectorize<int>(x->dims());
   auto src_format = src_tz.size() == 2 ? MKLDNNMemoryFormat::nc : x->format();
@@ -149,7 +149,7 @@ void eltwise_grad(const framework::ExecutionContext &ctx,
   const T alpha = ctx.op().HasAttr("alpha") ? ctx.Attr<T>("alpha") : 0;
   const T beta = ctx.op().HasAttr("beta") ? ctx.Attr<T>("beta") : 0;
-  std::vector<int> diff_dst_tz = framework::vectorize2int(diff_y->dims());
+  auto diff_dst_tz = framework::vectorize<int>(diff_y->dims());
   // diff_dst and src dims should be the same
   auto src_format =
...
@@ -214,8 +214,8 @@ class BatchNormMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
             ? mkldnn::prop_kind::forward_scoring
             : mkldnn::prop_kind::forward_training;
-    auto src_tz = paddle::framework::vectorize2int(x->dims());
-    auto scale_tz = paddle::framework::vectorize2int(scale->dims());
+    auto src_tz = paddle::framework::vectorize<int>(x->dims());
+    auto scale_tz = paddle::framework::vectorize<int>(scale->dims());
     PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
     const unsigned int ic = scale_tz[0];
@@ -349,11 +349,11 @@ class BatchNormMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     T *diff_scale_data = diff_scale->mutable_data<T>(ctx.GetPlace());
     T *diff_shift_data = diff_shift->mutable_data<T>(ctx.GetPlace());
-    auto src_tz = paddle::framework::vectorize2int(x->dims());
+    auto src_tz = paddle::framework::vectorize<int>(x->dims());
     auto diff_src_tz = src_tz;
     auto dst_tz = src_tz;
     auto diff_dst_tz = dst_tz;
-    auto scale_tz = paddle::framework::vectorize2int(scale->dims());
+    auto scale_tz = paddle::framework::vectorize<int>(scale->dims());
     PADDLE_ENFORCE(scale_tz.size() == 1, "Dims of scale tensor is NOT 1");
     const unsigned int ic = scale_tz[0];
...
@@ -40,7 +40,7 @@ static void EnforceLayouts(const std::vector<const Tensor*> inputs) {
 static memory::primitive_desc CreateMemPrimDesc(const Tensor& input,
                                                 const mkldnn::engine& engine,
                                                 const memory::data_type& dt) {
-  const auto dims = paddle::framework::vectorize2int(input.dims());
+  const auto dims = paddle::framework::vectorize<int>(input.dims());
   const auto format = input.format();
   auto description = memory::desc(dims, dt, format);
   auto mem_prim_desc = memory::primitive_desc(description, engine);
@@ -73,7 +73,7 @@ std::string CreateKey(const paddle::framework::ExecutionContext& ctx,
   key.reserve(platform::MKLDNNHandler::MaxKeyLength);
   for (size_t i = 0; i < multi_input.size(); i++) {
     platform::MKLDNNHandler::AppendKeyDims(
-        &key, paddle::framework::vectorize2int(multi_input[i]->dims()));
+        &key, paddle::framework::vectorize<int>(multi_input[i]->dims()));
   }
   platform::MKLDNNHandler::AppendKey(&key, std::to_string(concat_axis));
   platform::MKLDNNHandler::AppendKey(&key, ctx.op().Output("Out"));
@@ -124,7 +124,7 @@ class ConcatPrimitiveFactory {
  private:
  memory::desc CreateDstMemDescriptor(Tensor* output,
                                      const memory::data_type& dt) {
-    auto dst_dims = paddle::framework::vectorize2int(output->dims());
+    auto dst_dims = paddle::framework::vectorize<int>(output->dims());
     return memory::desc(dst_dims, dt, MKLDNNMemoryFormat::any);
  }
...
@@ -183,12 +183,11 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const T* input_data = input->data<T>();
     const T* filter_data = filter->data<T>();
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> weights_tz =
-        paddle::framework::vectorize2int(filter->dims());
+    auto src_tz = paddle::framework::vectorize<int>(input->dims());
+    auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
     int g = std::max(groups, 1);
     GetWeightsTz(weights_tz, g, is_conv3d);
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    auto dst_tz = paddle::framework::vectorize<int>(output->dims());
     // Get unique name for storing MKLDNN primitives
     const std::string key = platform::ConvMKLDNNHandler::GetHash(
@@ -238,7 +237,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
                                  : mkldnn::prop_kind::forward_training;
     if (bias) {
-      bias_tz = paddle::framework::vectorize2int(bias->dims());
+      bias_tz = paddle::framework::vectorize<int>(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
       conv_pd = handler.AcquireConvolutionPrimitiveDescriptor(
@@ -281,7 +280,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       auto output_data =
           output->mutable_data<T>(ctx.GetPlace(), handler.GetDstMemorySize());
       auto residual_data_tz =
-          paddle::framework::vectorize2int(residual_param->dims());
+          paddle::framework::vectorize<int>(residual_param->dims());
       auto residual_data_type =
           paddle::framework::ToMKLDNNDataType(residual_param->type());
@@ -405,13 +404,12 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const T* input_data = input->data<T>();
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> weights_tz =
-        paddle::framework::vectorize2int(filter->dims());
+    auto src_tz = paddle::framework::vectorize<int>(input->dims());
+    auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
     int g = std::max(groups, 1);
     GetWeightsTz(weights_tz, g, is_conv3d);
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    auto dst_tz = paddle::framework::vectorize<int>(output->dims());
     mkldnn::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
@@ -514,7 +512,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
                                  : mkldnn::prop_kind::forward_training;
     if (bias) {
-      bias_tz = paddle::framework::vectorize2int(bias->dims());
+      bias_tz = paddle::framework::vectorize<int>(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(bias_tz, memory::data_type::s32,
                                              MKLDNNMemoryFormat::x);
       conv_pd = handler->AcquireConvolutionPrimitiveDescriptor(
@@ -554,7 +552,7 @@ class ConvMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
           paddle::framework::ToMKLDNNDataType(residual_param->type());
       if (residual_param->format() != handler->GetDstFormat()) {
         auto residual_data_tz =
-            paddle::framework::vectorize2int(residual_param->dims());
+            paddle::framework::vectorize<int>(residual_param->dims());
         auto user_residual_md = platform::MKLDNNMemDesc(
             residual_data_tz, residual_dt, residual_param->format());
         dst_memory_p = platform::SetDstMemory<T_out>(
@@ -705,13 +703,11 @@ class ConvMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     T* input_grad_data = nullptr;
     T* filter_grad_data = nullptr;
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> weights_tz =
-        paddle::framework::vectorize2int(filter->dims());
+    auto src_tz = paddle::framework::vectorize<int>(input->dims());
+    auto weights_tz = paddle::framework::vectorize<int>(filter->dims());
     int g = std::max(groups, 1);
     GetWeightsTz(weights_tz, g, is_conv3d);
-    std::vector<int> dst_tz =
-        paddle::framework::vectorize2int(output_grad->dims());
+    auto dst_tz = paddle::framework::vectorize<int>(output_grad->dims());
     auto src_format = input->format();
     MKLDNNMemoryFormat weights_format =
         GetWeightsFormat(filter->format(), g, is_conv3d);
...
@@ -82,10 +82,10 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const T* input_data = input->data<T>();
     const T* filter_data = filter->data<T>();
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> iohw_weights_tz =
-        paddle::framework::vectorize2int(filter->dims());
-    std::vector<int> weights_tz = iohw_weights_tz;
+    auto src_tz = paddle::framework::vectorize<int>(input->dims());
+    auto iohw_weights_tz = paddle::framework::vectorize<int>(filter->dims());
+    auto weights_tz = iohw_weights_tz;
     // IOHW -> OIHW
     weights_tz[0] = iohw_weights_tz[1];
     weights_tz[1] = iohw_weights_tz[0];
@@ -124,7 +124,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       weights_tz[3] = h;
       weights_tz[4] = w;
     }
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    auto dst_tz = paddle::framework::vectorize<int>(output->dims());
     // Get unique name for storing MKLDNN primitives
     const std::string key = platform::ConvTransposeMKLDNNHandler::GetHash(
@@ -166,7 +166,7 @@ class ConvTransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto fwd_prop_kind = is_test ? mkldnn::prop_kind::forward_inference
                                  : mkldnn::prop_kind::forward_training;
     if (bias) {
-      bias_tz = paddle::framework::vectorize2int(bias->dims());
+      bias_tz = paddle::framework::vectorize<int>(bias->dims());
       auto bias_md = platform::MKLDNNMemDesc(
           bias_tz, platform::MKLDNNGetDataType<T>(), MKLDNNMemoryFormat::x);
       conv_transpose_pd = handler.AcquireConvolutionPrimitiveDescriptor(
...
@@ -59,8 +59,8 @@ class DeQuantOpKernel : public framework::OpKernel<T> {
     std::vector<float> reorder_scale = {1.0f / scale_data};
     std::vector<primitive> pipeline;
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    auto src_tz = paddle::framework::vectorize<int>(input->dims());
+    auto dst_tz = paddle::framework::vectorize<int>(output->dims());
     mkldnn::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
     MKLDNNMemoryFormat src_fmt = input->format();
...
@@ -109,7 +109,7 @@ class FCPrimitiveFactory {
   static mkldnn::memory::desc CreateMemDescriptor(const Tensor* tensor,
                                                   MKLDNNMemoryFormat format) {
-    auto dims = framework::vectorize2int(tensor->dims());
+    auto dims = framework::vectorize<int>(tensor->dims());
     return CreateMemDescriptor(dims, format);
   }
@@ -124,7 +124,7 @@ class FCPrimitiveFactory {
   }
   mkldnn::memory TransposeWeights(const Tensor* weights) {
-    auto dims = framework::vectorize2int(weights->dims());
+    auto dims = framework::vectorize<int>(weights->dims());
     std::swap(dims[0], dims[1]);  // Correct output dimensions
     auto src_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::io);
     auto dst_desc = CreateMemDescriptor(dims, MKLDNNMemoryFormat::oi);
@@ -182,8 +182,8 @@ class FCPrimitiveFactory {
   mkldnn::memory CreateFourDimWeightsMemory(const Tensor* input,
                                             const Tensor* weights) {
-    auto input_dims = framework::vectorize2int(input->dims());
-    auto weight_dims = framework::vectorize2int(weights->dims());
+    auto input_dims = framework::vectorize<int>(input->dims());
+    auto weight_dims = framework::vectorize<int>(weights->dims());
     auto dims = {weight_dims[1], input_dims[1], input_dims[2], input_dims[3]};
     auto dst_format = MatchWeightFormat(input->format());
...
@@ -56,7 +56,7 @@ class LRNMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     auto e_mid = framework::EigenTensor<T, 4>::From(*mid);
     e_mid = e_mid.constant(k);
-    auto dims = paddle::framework::vectorize2int(x->dims());
+    auto dims = paddle::framework::vectorize<int>(x->dims());
     // Format and dims are assumed to be the same for dst and src
     auto md = paddle::platform::MKLDNNMemDesc(
@@ -119,7 +119,7 @@ class LRNMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     auto x_grad_data = x_grad->mutable_data<T>(ctx.GetPlace());
     auto out_grad_data = out_grad->data<T>();
-    auto dims = paddle::framework::vectorize2int(x->dims());
+    auto dims = paddle::framework::vectorize<int>(x->dims());
     const std::string key = platform::LRNMKLDNNHandler::GetHash(
         dims, n, alpha, beta, k, x->format(), ctx.op().Input("Out"));
...
@@ -116,7 +116,7 @@ class MulPrimitiveFactory {
   memory::desc CreateMemDescriptor(
       const Tensor *tensor, MKLDNNMemoryFormat format,
       memory::data_type type = platform::MKLDNNGetDataType<T>()) {
-    auto dims = framework::vectorize2int(tensor->dims());
+    auto dims = framework::vectorize<int>(tensor->dims());
     return platform::MKLDNNMemDesc(dims, type, format);
   }
@@ -156,7 +156,7 @@ class MulPrimitiveFactory {
   }
   memory TransposeInputY(const Tensor *input_y) {
-    auto dims = framework::vectorize2int(input_y->dims());
+    auto dims = framework::vectorize<int>(input_y->dims());
     std::swap(dims[0], dims[1]);  // Correct output dimensions
     auto src_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::io);
     auto dst_desc = CreateMemDescriptor<YT>(dims, MKLDNNMemoryFormat::oi);
...
@@ -69,8 +69,8 @@ class PoolMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     const T* input_data = input->data<T>();
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    auto src_tz = paddle::framework::vectorize<int>(input->dims());
+    auto dst_tz = paddle::framework::vectorize<int>(output->dims());
     auto input_format = input->format();
     MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef};
@@ -166,10 +166,8 @@ class PoolMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     T* in_x_grad_data = in_x_grad->mutable_data<T>(ctx.GetPlace());
     MKLDNNMemoryFormat in_x_grad_format{MKLDNNMemoryFormat::format_undef};
-    std::vector<int> diff_src_tz =
-        paddle::framework::vectorize2int(in_x_grad->dims());
-    std::vector<int> diff_dst_tz =
-        paddle::framework::vectorize2int(out_grad->dims());
+    auto diff_src_tz = paddle::framework::vectorize<int>(in_x_grad->dims());
+    auto diff_dst_tz = paddle::framework::vectorize<int>(out_grad->dims());
     // Get an unique name from "argument" name of "Out" variable
     // This name will be used as key when referring info from device context
...
@@ -54,8 +54,8 @@ class QuantOpKernel : public framework::OpKernel<T> {
     const auto& engine = dev_ctx.GetEngine();
     std::vector<primitive> pipeline;
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    auto src_tz = paddle::framework::vectorize<int>(input->dims());
+    auto dst_tz = paddle::framework::vectorize<int>(output->dims());
     const T* input_data = input->data<T>();
...
@@ -43,8 +43,8 @@ class ReQuantOpKernel : public framework::OpKernel<T> {
     const auto& engine = dev_ctx.GetEngine();
     std::vector<primitive> pipeline;
-    std::vector<int> src_tz = paddle::framework::vectorize2int(input->dims());
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(output->dims());
+    auto src_tz = paddle::framework::vectorize<int>(input->dims());
+    auto dst_tz = paddle::framework::vectorize<int>(output->dims());
     mkldnn::memory::data_type src_dt =
         paddle::framework::ToMKLDNNDataType(input->type());
     mkldnn::memory::data_type dst_dt = src_dt;
...
@@ -199,8 +199,8 @@ class SoftmaxMKLDNNKernel : public paddle::framework::OpKernel<T> {
     const T* input_data = flattened_input.data<T>();
     T* output_data = flattened_output.mutable_data<T>(ctx.GetPlace());
-    std::vector<int> src_tz = paddle::framework::vectorize2int(flattened_dims);
-    std::vector<int> dst_tz = src_tz;
+    auto src_tz = paddle::framework::vectorize<int>(flattened_dims);
+    auto dst_tz = src_tz;
     // Same memory descriptor to be used for input and output
     memory::dims softmax_tz = {src_tz[0], src_tz[1]};
     // Generate keys for storing/retriving primitives for this operator
@@ -268,8 +268,8 @@ class SoftmaxMKLDNNGradKernel : public paddle::framework::OpKernel<T> {
     const T* diff_dst_ptr = flattened_dout.template data<T>();
     T* diff_src_ptr = flattened_dx.template mutable_data<T>(ctx.GetPlace());
-    std::vector<int> dst_tz = paddle::framework::vectorize2int(flattened_dims);
-    std::vector<int> src_tz(dst_tz);
+    auto dst_tz = paddle::framework::vectorize<int>(flattened_dims);
+    auto src_tz(dst_tz);
     // Same memory descriptor to be used for input and output
     memory::dims softmax_tz = {src_tz[0], src_tz[1]};
...
@@ -63,7 +63,7 @@ class SumMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
     LoDTensor* output = ctx.Output<LoDTensor>("Out");
     T* output_data = output->mutable_data<T>(ctx.GetPlace());
-    std::vector<int> dst_tz = framework::vectorize2int(output->dims());
+    auto dst_tz = framework::vectorize<int>(output->dims());
     auto src_tz = dst_tz;
     MKLDNNMemoryFormat output_format{MKLDNNMemoryFormat::format_undef};
     std::vector<float> scales;
...
@@ -43,7 +43,7 @@ class TransposeMKLDNNOpKernel : public paddle::framework::OpKernel<T> {
       return;
     }
-    std::vector<int> nchw_tz = paddle::framework::vectorize2int(input->dims());
+    auto nchw_tz = paddle::framework::vectorize<int>(input->dims());
     const std::string key = platform::TransposeMKLDNNHandler::GetHash(
         nchw_tz, axis,
@@ -97,8 +97,7 @@ class TransposeMKLDNNGradOpKernel : public paddle::framework::OpKernel<T> {
     const T* out_grad_data = out_grad->data<T>();
     x_grad->mutable_data<T>(ctx.GetPlace());
-    std::vector<int> nchw_tz =
-        paddle::framework::vectorize2int(out_grad->dims());
+    auto nchw_tz = paddle::framework::vectorize<int>(out_grad->dims());
     const std::string key = platform::TransposeMKLDNNHandler::GetHash(
         nchw_tz, axis, ctx.op().Output(framework::GradVarName("X")));
...