未验证 提交 75080988 编写于 作者: W wangzhen38 提交者: GitHub

【code format】fix cpplint style 5 (#43733)

* fix cpplint style 5

* fix cpplint style 5

* fix cpplint style 5

* fix cpplint style 5

* fix cpplint style 5

* fix cpplint style 5

* fix cpplint style 5
上级 f9198372
...@@ -228,32 +228,36 @@ void testFeatureNodeSerializeFloat64() { ...@@ -228,32 +228,36 @@ void testFeatureNodeSerializeFloat64() {
// void testCache(); // void testCache();
void testGraphToBuffer(); void testGraphToBuffer();
std::string edges[] = { const char* edges[] = {"37\t45\t0.34",
std::string("37\t45\t0.34"), std::string("37\t145\t0.31"), "37\t145\t0.31",
std::string("37\t112\t0.21"), std::string("96\t48\t1.4"), "37\t112\t0.21",
std::string("96\t247\t0.31"), std::string("96\t111\t1.21"), "96\t48\t1.4",
std::string("59\t45\t0.34"), std::string("59\t145\t0.31"), "96\t247\t0.31",
std::string("59\t122\t0.21"), std::string("97\t48\t0.34"), "96\t111\t1.21",
std::string("97\t247\t0.31"), std::string("97\t111\t0.21")}; "59\t45\t0.34",
"59\t145\t0.31",
"59\t122\t0.21",
"97\t48\t0.34",
"97\t247\t0.31",
"97\t111\t0.21"};
char edge_file_name[] = "edges.txt"; char edge_file_name[] = "edges.txt";
std::string nodes[] = { const char* nodes[] = {"user\t37\ta 0.34\tb 13 14\tc hello\td abc",
std::string("user\t37\ta 0.34\tb 13 14\tc hello\td abc"), "user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd",
std::string("user\t96\ta 0.31\tb 15 10\tc 96hello\td abcd"), "user\t59\ta 0.11\tb 11 14",
std::string("user\t59\ta 0.11\tb 11 14"), "user\t97\ta 0.11\tb 12 11",
std::string("user\t97\ta 0.11\tb 12 11"), "item\t45\ta 0.21",
std::string("item\t45\ta 0.21"), "item\t145\ta 0.21",
std::string("item\t145\ta 0.21"), "item\t112\ta 0.21",
std::string("item\t112\ta 0.21"), "item\t48\ta 0.21",
std::string("item\t48\ta 0.21"), "item\t247\ta 0.21",
std::string("item\t247\ta 0.21"), "item\t111\ta 0.21",
std::string("item\t111\ta 0.21"), "item\t46\ta 0.21",
std::string("item\t46\ta 0.21"), "item\t146\ta 0.21",
std::string("item\t146\ta 0.21"), "item\t122\ta 0.21",
std::string("item\t122\ta 0.21"), "item\t49\ta 0.21",
std::string("item\t49\ta 0.21"), "item\t248\ta 0.21",
std::string("item\t248\ta 0.21"), "item\t113\ta 0.21"};
std::string("item\t113\ta 0.21")};
char node_file_name[] = "nodes.txt"; char node_file_name[] = "nodes.txt";
void prepare_file(char file_name[], bool load_edge) { void prepare_file(char file_name[], bool load_edge) {
...@@ -335,7 +339,8 @@ void GetDownpourSparseTableProto( ...@@ -335,7 +339,8 @@ void GetDownpourSparseTableProto(
/*-------------------------------------------------------------------------*/ /*-------------------------------------------------------------------------*/
std::string ip_ = "127.0.0.1", ip2 = "127.0.0.1"; const char* ip_ = "127.0.0.1";
const char* ip2 = "127.0.0.1";
uint32_t port_ = 5209, port2 = 5210; uint32_t port_ = 5209, port2 = 5210;
std::vector<std::string> host_sign_list_; std::vector<std::string> host_sign_list_;
...@@ -382,8 +387,10 @@ void RunServer2() { ...@@ -382,8 +387,10 @@ void RunServer2() {
} }
void RunClient( void RunClient(
std::map<uint64_t, std::vector<paddle::distributed::Region>>& dense_regions, const std::map<uint64_t, std::vector<paddle::distributed::Region>>&
int index, paddle::distributed::PsBaseService* service) { dense_regions,
int index,
paddle::distributed::PsBaseService* service) {
::paddle::distributed::PSParameter worker_proto = GetWorkerProto(); ::paddle::distributed::PSParameter worker_proto = GetWorkerProto();
paddle::distributed::PaddlePSEnvironment _ps_env; paddle::distributed::PaddlePSEnvironment _ps_env;
auto servers_ = host_sign_list_.size(); auto servers_ = host_sign_list_.size();
...@@ -533,8 +540,8 @@ void RunBrpcPushSparse() { ...@@ -533,8 +540,8 @@ void RunBrpcPushSparse() {
VLOG(0) << "second bound"; VLOG(0) << "second bound";
client1.load_node_file(std::string("user"), std::string(node_file_name)); client1.load_node_file(std::string("user"), std::string(node_file_name));
client1.load_node_file(std::string("item"), std::string(node_file_name)); client1.load_node_file(std::string("item"), std::string(node_file_name));
client1.load_edge_file(std::string("user2item"), std::string(edge_file_name), client1.load_edge_file(
0); std::string("user2item"), std::string(edge_file_name), 0);
nodes.clear(); nodes.clear();
VLOG(0) << "start to pull graph list"; VLOG(0) << "start to pull graph list";
nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1); nodes = client1.pull_graph_list(std::string("user"), 0, 1, 4, 1);
...@@ -551,8 +558,8 @@ void RunBrpcPushSparse() { ...@@ -551,8 +558,8 @@ void RunBrpcPushSparse() {
std::cout << "check pull graph list by step " << test_step << std::endl; std::cout << "check pull graph list by step " << test_step << std::endl;
for (int server_id = 0; server_id < 2; server_id++) { for (int server_id = 0; server_id < 2; server_id++) {
for (int start_step = 0; start_step < test_step; start_step++) { for (int start_step = 0; start_step < test_step; start_step++) {
nodes = client1.pull_graph_list(std::string("item"), server_id, nodes = client1.pull_graph_list(
start_step, 12, test_step); std::string("item"), server_id, start_step, 12, test_step);
for (auto g : nodes) { for (auto g : nodes) {
count_item_nodes.insert(g.get_id()); count_item_nodes.insert(g.get_id());
} }
...@@ -570,14 +577,15 @@ void RunBrpcPushSparse() { ...@@ -570,14 +577,15 @@ void RunBrpcPushSparse() {
std::vector<int64_t> node_ids; std::vector<int64_t> node_ids;
node_ids.push_back(96); node_ids.push_back(96);
node_ids.push_back(37); node_ids.push_back(37);
res = client1.batch_sample_neighbors(std::string("user2item"), node_ids, 4, res = client1.batch_sample_neighbors(
true, false); std::string("user2item"), node_ids, 4, true, false);
ASSERT_EQ(res.first[1].size(), 1); ASSERT_EQ(res.first[1].size(), 1);
std::vector<int64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6); std::vector<int64_t> nodes_ids = client2.random_sample_nodes("user", 0, 6);
ASSERT_EQ(nodes_ids.size(), 2); ASSERT_EQ(nodes_ids.size(), 2);
ASSERT_EQ(true, (nodes_ids[0] == 59 && nodes_ids[1] == 37) || ASSERT_EQ(true,
(nodes_ids[0] == 37 && nodes_ids[1] == 59)); (nodes_ids[0] == 59 && nodes_ids[1] == 37) ||
(nodes_ids[0] == 37 && nodes_ids[1] == 59));
VLOG(0) << "start to test get node feat"; VLOG(0) << "start to test get node feat";
// Test get node feat // Test get node feat
...@@ -598,14 +606,14 @@ void RunBrpcPushSparse() { ...@@ -598,14 +606,14 @@ void RunBrpcPushSparse() {
node_feat[1][0] = "helloworld"; node_feat[1][0] = "helloworld";
client1.set_node_feat(std::string("user"), node_ids, feature_names, client1.set_node_feat(
node_feat); std::string("user"), node_ids, feature_names, node_feat);
// sleep(5); // sleep(5);
node_feat = node_feat =
client1.get_node_feat(std::string("user"), node_ids, feature_names); client1.get_node_feat(std::string("user"), node_ids, feature_names);
VLOG(0) << "get_node_feat: " << node_feat[1][0]; VLOG(0) << "get_node_feat: " << node_feat[1][0];
ASSERT_TRUE(node_feat[1][0] == "helloworld"); ASSERT_EQ(node_feat[1][0], "helloworld");
// Test string // Test string
node_ids.clear(); node_ids.clear();
...@@ -698,7 +706,7 @@ void testGraphToBuffer() { ...@@ -698,7 +706,7 @@ void testGraphToBuffer() {
s.set_feature(0, std::string("hhhh")); s.set_feature(0, std::string("hhhh"));
s.set_id(65); s.set_id(65);
int size = s.get_size(true); int size = s.get_size(true);
char str[size]; char str[size]; // NOLINT
s.to_buffer(str, true); s.to_buffer(str, true);
s1.recover_from_buffer(str); s1.recover_from_buffer(str);
ASSERT_EQ(s.get_id(), s1.get_id()); ASSERT_EQ(s.get_id(), s1.get_id());
......
...@@ -11,7 +11,6 @@ ...@@ -11,7 +11,6 @@
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and // See the License for the specific language governing permissions and
// limitations under the License. // limitations under the License.
#pragma once #pragma once
#if defined _WIN32 || defined __APPLE__ #if defined _WIN32 || defined __APPLE__
...@@ -92,17 +91,21 @@ class ArchiveBase { ...@@ -92,17 +91,21 @@ class ArchiveBase {
char* Buffer() { return buffer_; } char* Buffer() { return buffer_; }
void SetReadBuffer(char* buffer, size_t length, void SetReadBuffer(char* buffer,
size_t length,
std::function<void(char*)>&& deleter) { std::function<void(char*)>&& deleter) {
SetBuffer(buffer, length, length, std::move(deleter)); SetBuffer(buffer, length, length, std::move(deleter));
} }
void SetWriteBuffer(char* buffer, size_t capacity, void SetWriteBuffer(char* buffer,
size_t capacity,
std::function<void(char*)>&& deleter) { std::function<void(char*)>&& deleter) {
SetBuffer(buffer, 0, capacity, std::move(deleter)); SetBuffer(buffer, 0, capacity, std::move(deleter));
} }
void SetBuffer(char* buffer, size_t length, size_t capacity, void SetBuffer(char* buffer,
size_t length,
size_t capacity,
std::function<void(char*)>&& deleter) { std::function<void(char*)>&& deleter) {
CHECK(length <= capacity); CHECK(length <= capacity);
FreeBuffer(); FreeBuffer();
...@@ -324,9 +327,10 @@ class Archive<BinaryArchiveType> : public ArchiveBase { ...@@ -324,9 +327,10 @@ class Archive<BinaryArchiveType> : public ArchiveBase {
size_t temp = Limit() - Finish(); size_t temp = Limit() - Finish();
int len = snprintf(Finish(), temp, fmt, args...); int len = snprintf(Finish(), temp, fmt, args...);
CHECK(len >= 0); // NOLINT CHECK(len >= 0); // NOLINT
if ((size_t)len >= temp) { if (static_cast<size_t>(len) >= temp) {
PrepareWrite(len + 1); PrepareWrite(len + 1);
CHECK(snprintf(Finish(), (size_t)len + 1, fmt, args...) == len); CHECK(snprintf(Finish(), static_cast<size_t>(len) + 1, fmt, args...) ==
len);
} }
AdvanceFinish(len); AdvanceFinish(len);
} }
...@@ -351,7 +355,7 @@ Archive<AR>& operator>>(Archive<AR>& ar, T (&p)[N]) { ...@@ -351,7 +355,7 @@ Archive<AR>& operator>>(Archive<AR>& ar, T (&p)[N]) {
template <class AR, class T> template <class AR, class T>
Archive<AR>& operator<<(Archive<AR>& ar, const std::vector<T>& p) { Archive<AR>& operator<<(Archive<AR>& ar, const std::vector<T>& p) {
#ifdef _LINUX #ifdef _LINUX
ar << (size_t)p.size(); ar << static_cast<size_t>(p.size());
#else #else
ar << (uint64_t)p.size(); ar << (uint64_t)p.size();
#endif #endif
...@@ -377,7 +381,7 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::vector<T>& p) { ...@@ -377,7 +381,7 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::vector<T>& p) {
template <class AR, class T> template <class AR, class T>
Archive<AR>& operator<<(Archive<AR>& ar, const std::valarray<T>& p) { Archive<AR>& operator<<(Archive<AR>& ar, const std::valarray<T>& p) {
#ifdef _LINUX #ifdef _LINUX
ar << (size_t)p.size(); ar << static_cast<size_t>(p.size());
#else #else
ar << (uint64_t)p.size(); ar << (uint64_t)p.size();
#endif #endif
...@@ -402,7 +406,7 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::valarray<T>& p) { ...@@ -402,7 +406,7 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::valarray<T>& p) {
inline BinaryArchive& operator<<(BinaryArchive& ar, const std::string& s) { inline BinaryArchive& operator<<(BinaryArchive& ar, const std::string& s) {
#ifdef _LINUX #ifdef _LINUX
ar << (size_t)s.length(); ar << static_cast<size_t>(s.length());
#else #else
ar << (uint64_t)s.length(); ar << (uint64_t)s.length();
#endif #endif
...@@ -482,13 +486,15 @@ Archive<AR>& operator<<(Archive<AR>& ar, const std::tuple<T...>& x) { ...@@ -482,13 +486,15 @@ Archive<AR>& operator<<(Archive<AR>& ar, const std::tuple<T...>& x) {
#ifdef _LINUX #ifdef _LINUX
template <class AR, class... T> template <class AR, class... T>
Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT Archive<AR>& DeserializeTuple(const Archive<AR>& ar,
std::tuple<T...>& x, // NOLINT
std::integral_constant<size_t, 0> n) { std::integral_constant<size_t, 0> n) {
return ar; return ar;
} }
#else #else
template <class AR, class... T> template <class AR, class... T>
Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT Archive<AR>& DeserializeTuple(const Archive<AR>& ar,
std::tuple<T...>& x, // NOLINT
std::integral_constant<uint64_t, 0> n) { std::integral_constant<uint64_t, 0> n) {
return ar; return ar;
} }
...@@ -496,14 +502,16 @@ Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT ...@@ -496,14 +502,16 @@ Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT
#ifdef _LINUX #ifdef _LINUX
template <class AR, class... T, size_t N> template <class AR, class... T, size_t N>
Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT Archive<AR>& DeserializeTuple(const Archive<AR>& ar,
std::tuple<T...>& x, // NOLINT
std::integral_constant<size_t, N> n) { std::integral_constant<size_t, N> n) {
return DeserializeTuple(ar, x, std::integral_constant<size_t, N - 1>()) >> return DeserializeTuple(ar, x, std::integral_constant<size_t, N - 1>()) >>
std::get<N - 1>(x); std::get<N - 1>(x);
} }
#else #else
template <class AR, class... T, uint64_t N> template <class AR, class... T, uint64_t N>
Archive<AR>& DeserializeTuple(Archive<AR>& ar, std::tuple<T...>& x, // NOLINT Archive<AR>& DeserializeTuple(const Archive<AR>& ar,
std::tuple<T...>& x, // NOLINT
std::integral_constant<uint64_t, N> n) { std::integral_constant<uint64_t, N> n) {
return DeserializeTuple(ar, x, std::integral_constant<uint64_t, N - 1>()) >> return DeserializeTuple(ar, x, std::integral_constant<uint64_t, N - 1>()) >>
std::get<N - 1>(x); std::get<N - 1>(x);
...@@ -529,7 +537,7 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::tuple<T...>& x) { ...@@ -529,7 +537,7 @@ Archive<AR>& operator>>(Archive<AR>& ar, std::tuple<T...>& x) {
template <class AR, class KEY, class VALUE, class... ARGS> \ template <class AR, class KEY, class VALUE, class... ARGS> \
Archive<AR>& operator<<(Archive<AR>& ar, \ Archive<AR>& operator<<(Archive<AR>& ar, \
const MAP_TYPE<KEY, VALUE, ARGS...>& p) { \ const MAP_TYPE<KEY, VALUE, ARGS...>& p) { \
ar << (size_t)p.size(); \ ar << static_cast<size_t>(p.size()); \
for (auto it = p.begin(); it != p.end(); ++it) { \ for (auto it = p.begin(); it != p.end(); ++it) { \
ar << *it; \ ar << *it; \
} \ } \
...@@ -579,7 +587,7 @@ ARCHIVE_REPEAT(std::unordered_multimap, p.reserve(size)) ...@@ -579,7 +587,7 @@ ARCHIVE_REPEAT(std::unordered_multimap, p.reserve(size))
#define ARCHIVE_REPEAT(SET_TYPE, RESERVE_STATEMENT) \ #define ARCHIVE_REPEAT(SET_TYPE, RESERVE_STATEMENT) \
template <class AR, class KEY, class... ARGS> \ template <class AR, class KEY, class... ARGS> \
Archive<AR>& operator<<(Archive<AR>& ar, const SET_TYPE<KEY, ARGS...>& p) { \ Archive<AR>& operator<<(Archive<AR>& ar, const SET_TYPE<KEY, ARGS...>& p) { \
ar << (size_t)p.size(); \ ar << static_cast<size_t>(p.size()); \
for (auto it = p.begin(); it != p.end(); ++it) { \ for (auto it = p.begin(); it != p.end(); ++it) { \
ar << *it; \ ar << *it; \
} \ } \
......
...@@ -11,16 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS, ...@@ -11,16 +11,17 @@ distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and See the License for the specific language governing permissions and
limitations under the License. */ limitations under the License. */
#pragma once #pragma once
#include <fstream> #include <fstream>
#include <map>
#include <memory> #include <memory>
#include <mutex> // NOLINT #include <mutex> // NOLINT
#include <string> #include <string>
#include <thread> // NOLINT #include <thread> // NOLINT
#include <unordered_map> // NOLINT #include <unordered_map> // NOLINT
#include <unordered_set> // NOLINT #include <unordered_set> // NOLINT
#include <utility>
#include <vector> #include <vector>
#if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS) #if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS)
#include "bthread/bthread.h" #include "bthread/bthread.h"
...@@ -76,9 +77,13 @@ class HeterTask { ...@@ -76,9 +77,13 @@ class HeterTask {
<< std::endl; << std::endl;
} }
} }
void PackTask(Scope* scope, int taskid, DataFeed* reader, int cur_batch, void PackTask(Scope* scope,
int taskid,
DataFeed* reader,
int cur_batch,
const ProgramDesc& program); const ProgramDesc& program);
void PackGpuTask(Scope* thread_scope, DataFeed* reader, void PackGpuTask(Scope* thread_scope,
DataFeed* reader,
const ProgramDesc& program); const ProgramDesc& program);
Scope* scope_{nullptr}; Scope* scope_{nullptr};
...@@ -145,7 +150,7 @@ class HeterObjectPool { ...@@ -145,7 +150,7 @@ class HeterObjectPool {
#if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS) #if defined(PADDLE_WITH_PSLIB) && !defined(PADDLE_WITH_HETERPS)
struct BthreadMutextGuard { struct BthreadMutextGuard {
BthreadMutextGuard(bthread_mutex_t* rho) { explicit BthreadMutextGuard(bthread_mutex_t* rho) {
mutex_ = rho; mutex_ = rho;
bthread_mutex_lock(mutex_); bthread_mutex_lock(mutex_);
} }
...@@ -220,7 +225,7 @@ class HeterList { ...@@ -220,7 +225,7 @@ class HeterList {
void SetCap(int num) { cap_ = num; } void SetCap(int num) { cap_ = num; }
bool TryPut(K& key, T& value) { bool TryPut(const K& key, const T& value) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cond_.wait(lock, [this] { return size < cap_; }); cond_.wait(lock, [this] { return size < cap_; });
if (task_map_.find(key) != task_map_.end()) { if (task_map_.find(key) != task_map_.end()) {
...@@ -236,7 +241,7 @@ class HeterList { ...@@ -236,7 +241,7 @@ class HeterList {
} }
} }
bool Put(K& key, T& value) { bool Put(const K& key, const T& value) {
std::unique_lock<std::mutex> lock(mutex_); std::unique_lock<std::mutex> lock(mutex_);
cond_.wait(lock, [this] { return size < cap_; }); cond_.wait(lock, [this] { return size < cap_; });
HeterNode<K, T>* node = new HeterNode<K, T>; HeterNode<K, T>* node = new HeterNode<K, T>;
......
...@@ -11,7 +11,9 @@ limitations under the License. */ ...@@ -11,7 +11,9 @@ limitations under the License. */
#include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h" #include "paddle/fluid/framework/ir/mkldnn/cpu_bfloat16_pass.h"
#include <map>
#include <string> #include <string>
#include <unordered_map>
#include <vector> #include <vector>
#include "paddle/fluid/framework/ir/graph_pattern_detector.h" #include "paddle/fluid/framework/ir/graph_pattern_detector.h"
...@@ -43,7 +45,7 @@ class Quanter { ...@@ -43,7 +45,7 @@ class Quanter {
VarDesc quant_x_desc( VarDesc quant_x_desc(
patterns::PDNodeName(get_op_type(), get_op_edge())); patterns::PDNodeName(get_op_type(), get_op_edge()));
auto quant_x_node = graph.CreateVarNode(&quant_x_desc); auto quant_x_node = graph->CreateVarNode(&quant_x_desc);
const auto xput_name = quant_x_node->Name(); const auto xput_name = quant_x_node->Name();
quant_xput_names.emplace_back(xput_name); quant_xput_names.emplace_back(xput_name);
...@@ -64,7 +66,7 @@ class Quanter { ...@@ -64,7 +66,7 @@ class Quanter {
virtual ~Quanter() = default; virtual ~Quanter() = default;
protected: protected:
Graph& graph; Graph* graph;
ir::Node* const op; ir::Node* const op;
std::map<std::string, ir::Node*> xputs_map; std::map<std::string, ir::Node*> xputs_map;
...@@ -72,8 +74,10 @@ class Quanter { ...@@ -72,8 +74,10 @@ class Quanter {
int counter = 0; int counter = 0;
Quanter(Graph& graph, ir::Node* const op, const VariableNameMap& op_xputs) Quanter(Graph* const graph,
: graph(graph), op(op), op_xputs(op_xputs){}; ir::Node* const op,
const VariableNameMap& op_xputs)
: graph(graph), op(op), op_xputs(op_xputs) {}
virtual bool IsNotPermittedOpType() const = 0; virtual bool IsNotPermittedOpType() const = 0;
virtual bool IsNotPermittedName(const std::string& input_name) const = 0; virtual bool IsNotPermittedName(const std::string& input_name) const = 0;
...@@ -101,10 +105,11 @@ class Quanter { ...@@ -101,10 +105,11 @@ class Quanter {
op_desc.SetAttr("Scale", 1.f); op_desc.SetAttr("Scale", 1.f);
op_desc.SetAttr("Shift", 0.0f); op_desc.SetAttr("Shift", 0.0f);
op_desc.SetAttr("bfloat16", true); op_desc.SetAttr("bfloat16", true);
op_desc.SetAttr("output_format", op->Op()->HasAttr("data_layout") op_desc.SetAttr("output_format",
? op->Op()->GetAttr("data_layout") op->Op()->HasAttr("data_layout")
: std::string("NCHW")); ? op->Op()->GetAttr("data_layout")
return graph.CreateOpNode(&op_desc); // OpDesc will be copied. : std::string("NCHW"));
return graph->CreateOpNode(&op_desc); // OpDesc will be copied.
} }
void UnlinkNodes(ir::Node* a, ir::Node* b) const { void UnlinkNodes(ir::Node* a, ir::Node* b) const {
...@@ -118,16 +123,18 @@ class Quanter { ...@@ -118,16 +123,18 @@ class Quanter {
class Quantizer final : public Quanter { class Quantizer final : public Quanter {
public: public:
Quantizer(Graph* const graph, ir::Node* const op) Quantizer(Graph* const graph, ir::Node* const op)
: Quanter(*graph, op, op->Op()->Inputs()) { : Quanter(graph, op, op->Op()->Inputs()) {
auto inputs = op->inputs; auto inputs = op->inputs;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
inputs.size(), 1, inputs.size(),
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"OP(%s)'s inputs(%d) must be equal or greater than 1.", op->Name(), "OP(%s)'s inputs(%d) must be equal or greater than 1.",
op->Name(),
inputs.size())); inputs.size()));
for (auto input : inputs) xputs_map[input->Name()] = input; for (auto input : inputs) xputs_map[input->Name()] = input;
}; }
protected: protected:
bool IsNotPermittedOpType() const override { return false; } bool IsNotPermittedOpType() const override { return false; }
...@@ -138,18 +145,20 @@ class Quantizer final : public Quanter { ...@@ -138,18 +145,20 @@ class Quantizer final : public Quanter {
// Only the inputs listed in \"permitted_names\" // Only the inputs listed in \"permitted_names\"
// requires quanitization before the bfloat16 operator. // requires quanitization before the bfloat16 operator.
// Other inputs, such as Filter and Bias are reordered in the kernel. // Other inputs, such as Filter and Bias are reordered in the kernel.
const std::vector<std::string> permitted_names = {"X", "Y", "Input", const std::vector<std::string> permitted_names = {
"ResidualData"}; "X", "Y", "Input", "ResidualData"};
return std::none_of( return std::none_of(
permitted_names.begin(), permitted_names.end(), permitted_names.begin(),
permitted_names.end(),
[&input_name](const std::string& name) { return name == input_name; }); [&input_name](const std::string& name) { return name == input_name; });
} }
std::string get_op_type() const override { return "quantize"; }; std::string get_op_type() const override { return "quantize"; };
std::string get_op_edge() const override { return "out"; }; std::string get_op_edge() const override { return "out"; };
void link_nodes(ir::Node* const physical_xput_node, ir::Node* const quant_op, void link_nodes(ir::Node* const physical_xput_node,
ir::Node* const quant_op,
ir::Node* const quant_x_node) override { ir::Node* const quant_x_node) override {
UnlinkNodes(physical_xput_node, op); UnlinkNodes(physical_xput_node, op);
IR_NODE_LINK_TO(physical_xput_node, quant_op); IR_NODE_LINK_TO(physical_xput_node, quant_op);
...@@ -166,16 +175,18 @@ class Quantizer final : public Quanter { ...@@ -166,16 +175,18 @@ class Quantizer final : public Quanter {
class DeQuantizer final : public Quanter { class DeQuantizer final : public Quanter {
public: public:
DeQuantizer(Graph* const graph, ir::Node* const op) DeQuantizer(Graph* const graph, ir::Node* const op)
: Quanter(*graph, op, op->Op()->Outputs()) { : Quanter(graph, op, op->Op()->Outputs()) {
auto outputs = op->outputs; auto outputs = op->outputs;
PADDLE_ENFORCE_GE( PADDLE_ENFORCE_GE(
outputs.size(), 1, outputs.size(),
1,
platform::errors::InvalidArgument( platform::errors::InvalidArgument(
"OP(%s)'s outputs(%d) must be equal or greater than 1.", op->Name(), "OP(%s)'s outputs(%d) must be equal or greater than 1.",
op->Name(),
outputs.size())); outputs.size()));
for (auto output : outputs) xputs_map[output->Name()] = output; for (auto output : outputs) xputs_map[output->Name()] = output;
}; }
protected: protected:
bool IsNotPermittedOpType() const override { bool IsNotPermittedOpType() const override {
...@@ -195,11 +206,12 @@ class DeQuantizer final : public Quanter { ...@@ -195,11 +206,12 @@ class DeQuantizer final : public Quanter {
auto op_name = op->Name(); auto op_name = op->Name();
if (block_list.count(op_name)) { if (block_list.count(op_name)) {
const auto& op_blocklist = block_list[op_name]; const auto& op_blocklist = block_list[op_name];
blocked_outputs.insert(blocked_outputs.begin(), op_blocklist.begin(), blocked_outputs.insert(
op_blocklist.end()); blocked_outputs.begin(), op_blocklist.begin(), op_blocklist.end());
} }
return std::any_of(blocked_outputs.begin(), blocked_outputs.end(), return std::any_of(blocked_outputs.begin(),
blocked_outputs.end(),
[&output_name](const std::string& name) { [&output_name](const std::string& name) {
return name == output_name; return name == output_name;
}); });
...@@ -208,7 +220,8 @@ class DeQuantizer final : public Quanter { ...@@ -208,7 +220,8 @@ class DeQuantizer final : public Quanter {
std::string get_op_type() const override { return "dequantize"; }; std::string get_op_type() const override { return "dequantize"; };
std::string get_op_edge() const override { return "in"; }; std::string get_op_edge() const override { return "in"; };
void link_nodes(ir::Node* const physical_xput_node, ir::Node* const quant_op, void link_nodes(ir::Node* const physical_xput_node,
ir::Node* const quant_op,
ir::Node* const quant_x_node) override { ir::Node* const quant_x_node) override {
UnlinkNodes(op, physical_xput_node); UnlinkNodes(op, physical_xput_node);
IR_NODE_LINK_TO(quant_op, physical_xput_node); IR_NODE_LINK_TO(quant_op, physical_xput_node);
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册