// Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved. // // Licensed under the Apache License, Version 2.0 (the "License"); // you may not use this file except in compliance with the License. // You may obtain a copy of the License at // // http://www.apache.org/licenses/LICENSE-2.0 // // Unless required by applicable law or agreed to in writing, software // distributed under the License is distributed on an "AS IS" BASIS, // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // See the License for the specific language governing permissions and // limitations under the License. #pragma once #include #include #include #include #include #include #include #include "paddle/fluid/operators/math/blas.h" #include "paddle/fluid/platform/device_context.h" namespace paddle { namespace distributed { template inline paddle::operators::math::BlasT GetBlas() { auto cpu_ctx = paddle::platform::CPUDeviceContext(); return paddle::operators::math::GetBlas(cpu_ctx); } template inline void SQRT(int n, const T* x, T* z) { for (int i = 0; i < n; ++i) { z[i] = sqrt(x[i]); } } template inline void ADD(int n, const T* x, const T y, T* z) { for (int i = 0; i < n; ++i) { z[i] = x[i] + y; } } static bool StartWith(const std::string& str, const std::string& substr) { return str.find(substr) == 0; } static bool EndWith(const std::string& str, const std::string& substr) { return str.rfind(substr) == (str.length() - substr.length()); } inline std::vector bucket(const int v_size, const int b_size) { int remainder = v_size % b_size; int bucket = v_size / b_size; std::vector ret_vec(b_size, bucket); for (int i = 0; i < remainder; ++i) { ret_vec[i] = ret_vec[i] + 1; } int cur_bucket = 0; for (int& j : ret_vec) { int tmp = j; j = cur_bucket; cur_bucket += tmp; } ret_vec.push_back(cur_bucket); return ret_vec; } template std::string to_string(const std::vector& vec) { std::stringstream ss; for (const auto& c : vec) { ss << c << " "; } return ss.str(); } } }