diff --git a/demo/demo3/data/cls4mrqa/dev.tsv b/demo/demo3/data/cls4mrqa/dev.tsv new file mode 100644 index 0000000000000000000000000000000000000000..8b565c40aa09ca2478763e6afa61156b93e5d4a1 --- /dev/null +++ b/demo/demo3/data/cls4mrqa/dev.tsv @@ -0,0 +1,62 @@ +label text_a +1 when was the last time the san antonio spurs missed the playoffshave only missed the playoffs four times since entering the NBA ; they have not missed the playoffs in the 20 seasons since Tim Duncan was drafted by the Spurs in 1997 . With their 50th win in the 2016 -- 17 season , the Spurs extended their record for most consecutive 50 - win seasons to 18 ( the Spurs did not +0 the creation of the federal reserve system was an attempt toReserve System ( also known as the Federal Reserve or simply the Fed ) is the central banking system of the United States of America . Over the years , events such as the Great Depression in the 1930s and the Great Recession during the 2000s have led to the expansion of the +2 group f / 64 was a major backlash against the earlier photographic movement off / 64 was formed , Edward Weston went to a meeting of the John Reed Club , which was founded to support Marxist artists and writers . These circumstances not only helped set up the situation in which a group +0 Bessarabia eventually became under the control of which country?city of Vilnius – its historical capital, which was under Polish control during the inter-war +0 Iran's inflation led to what in 1975-1976?the economy of Iran was flooded with foreign currency, which caused inflation. By 1974, the economy of Iran was experiencing double digit inflation, and despite many large projects to modernize the country, corruption was rampant and caused large +1 How many steam warships did Japan have in 1867?Yokosuka and Nagasaki. By the end of the Tokugawa shogunate in 1867, the Japanese navy of the shogun already possessed eight western-style steam warships around the flagship Kaiyō Maru, which were used against pro-imperial forces during the Boshin war, under the command +0 How many people were inside?f former NFL head coach Dan Reeves, suffered a broken back. DeCamillis was seen on a stretcher wearing a neck brace. A line of heavy thunderstorms was moving through the Dallas area at the time, he said, but no other damage to buildings was reported, said Mike Adams, a dispatcher for the Irving, Texas, fire department. Watch the roof collapse on players, coaches » Arnold Payne, a photographer for WFAA, was shooting the Cowboys' practice session when rain began falling "tremendously hard." "I noticed the walls started to waver ... and then I noticed that the lights that were hanging from the ceiling started to sway, and it wouldn't stop," Payne told CNN. Shortly after that, he said, "It was as if someone took a stick pin and hit a balloon." Watch Payne describe being inside when structure collpased » Payne said +0 Ishita Dutta is the sister of an actress who is typically cast in what genre of movies?he suspense thriller film "Drishyam" (2015) and the Hindi soap opera "Ek Ghar Banaunga", that aired on Star Plus. She is the younger sister of actress Tanushree Dutta. Dutta is the recipient of Femina Miss India Universe title in 2004. During the same year +3 when did the the civil war start and end/Th>

110,000 + killed in action / died of wounds 230,000 + accident / disease deaths 25,000 -- 30,000 died in Confederate prisons

365,000 + total dead +1 What has Pakistan told phone companies?Islamabad, Pakistan (CNN) -- Under heavy criticism for a telling cell phone carriers to ban certain words in text messages, the Pakistan Telecommunication Authority went into damage control mode Wednesday. PTA spokesman Mohammed Younis Wednesday denied the existence of the plan, which has met with derision from mobile phone users in the country. "If at all we finally decide to +0 What did Bush say the proposal was to a proposal he vetoed before?(CNN) -- President Bush vetoed an expansion of the federally funded, state-run health insurance program for poor children for a second time Wednesday, telling Congress the bill "moves our country's health care system in the wrong direction." In his veto message, President Bush calls on Congress to extend funding for the current program. "Because the Congress has chosen to send me an essentially identical bill that has the same problems as the flawed bill I previously vetoed, I must veto this legislation, too," he said in a statement released by the White House. The bill would +0 Where did the football team that Bob Simmons coached from 1995 to 2000 play their home games?Cowboys football team [SEP] The 1998 Oklahoma State Cowboys football team represented the Oklahoma State University during the 1998 NCAA Division I-A football season. They participated as members of the Big 12 Conference in the South Division. They were coached by head coach Bob Simmons. [PAR] [TLE] Bob Simmons (American football coach) [SEP] Bob +2 What anniversary was recently celebrated in Iran?us to move our policy in a new direction," Obama said. "So there are going to be a set of objectives that we have in these conversations, but I think that there's the possibility at least of a relationship of mutual respect and progress." The United States and Iran have not had diplomatic relations since 1979. During that year, the Shah of Iran was forced to flee the country and the Ayatollah Khomeini took power. Later that year, Iranian students took over and seized hostages at the U.S. Embassy. Relations have been cut since then. U.S. President George W. Bush labeled Iran as a member of the "axis of evil" after the Sept. 11, 2001 attacks. Iran celebrated the 30th anniversary of the revolution Tuesday with crowds chanting "Death to America." Watch the parade in Iran » Tensions have rippled over issues such as Iran's nuclear program, Israel, and Iraq, and have been aggravated since the outspoken Ahmadinejad came to power in 2005. Western +1 Which Italian composer did George Balanchine add in 1976?[PAR] [TLE] Arcangelo Corelli [SEP] Arcangelo Corelli ( ; 17 February 1653 – 8 January 1713) was an Italian violinist and composer of the Baroque era. His music +0 Will the playstation 4 be announced?a new system sometime in the next five years, of course. Sony continued to sell the PlayStation 2 system and games years after the PlayStation 3 debuted in stores. For Sony's next console, the company will not deploy a streaming delivery system like OnLive, or fully cut out disc retailers like Best Buy and GameStop, Hirai said. While Sony has increased the number of games and other media available for download or streaming through its networks, most people cannot be expected to frequently download several gigabytes worth of data, which can be a time-consuming process, he said. Sony Computer Entertainment president Andrew House said earlier that Sony is not planning to discuss a new console, the website ComputerAndVideogames.com reported on Monday. +1 How many children were the Americans trying to kidnap out of Haiti?Port-au-Prince, Haiti (CNN) -- A Haitian attorney representing 10 Americans charged with kidnapping for trying to take 33 children out of Haiti told CNN Sunday he has resigned. Edwin Coq said he had quit as a lawyer for the Americans. It wasn't immediately clear who would replace him. "I know that they have been looking at other lawyers," said Phyllis Allison, mother of one of those detained, Jim Allen. "They don't know what to do." The 10 missionaries, including group leader Laura Silsby, were charged Thursday with kidnapping children and criminal association. Coq had said that court hearings would be held Monday +0 who kills tree gelbman in happy death dayTree convinces Carter of her predicament by showing that she holds foreknowledge of the day 's events . Tree admits to Carter she does n't like who +0 What will no person be denied the enjoyment of in Georgia based on their religious principles?amended as follows: "Article IV. Section 10. No person within this state shall, upon any pretense, be deprived of the inestimable privilege of worshipping God in any +0 who came up with the idea of footballpass . The popularity of college football grew as it became the dominant version of the sport in the United States for the first half of the 20th century . Bowl games , a college football tradition , attracted a national audience for college +0 what is the name of the female smurfbefore the smurflings created Sassette , Smurfette was the only female smurf in the Smurf Village . +3 Who contributed to the American studies programs at Yale and University of Wyoming?struggle. Norman Holmes Pearson, who worked for the Office of Strategic Studies in London during World War II, returned to Yale and headed the new American studies program, in which scholarship quickly became an instrument of promoting +0 What is the group's former name that now has an office with the Chief Actuary besides the Social Security Administration?Office of the Chief Actuary [SEP] The Office of the Chief Actuary is a government agency that has responsibility for actuarial estimates regarding social welfare programs. In Canada, the Office of the Chief Actuary works with the Canada Pension Plan and the Old Age Security Program. In the United States, both the Social Security Administration and the Centers for Medicare and Medicaid Services have an Office of the Chief Actuary that deals with Social Security and Medicare, respectively. A similar agency in the United Kingdom is called the Government Actuary's Department +0 The actor that playes Han Solo in the "Star Wars" film series stars with Blake Lively and Michiel Huisman in a film directed by who?about a woman who stops aging after an accident at the age of 29. Mills Goodloe and Salvador Paskowitz. The film stars Blake Lively, Michiel Huisman, Kathy Baker, Amanda Crew, Harrison Ford, and Ellen Burstyn. The film was theatrically released on April 24, 2015 by Lionsgate. [PAR] [TLE] Harrison Ford [SEP] Harrison Ford +0 What historically black university's men's basketball coach was formerly head coach at Virginia Tech?well as an 1890 Historically Black Land-Grant University. The University is a member-school of the Thurgood Marshall College Fund. He was also the head coach at Virginia Tech, Tennessee +0 what year did syracuse win the ncaa tournament. Their combined record is 67 -- 39 . +1 where do i get chips at a casino

Money is exchanged for tokens in a casino at the casino cage , at the gaming tables , or at a cashier station . The tokens are +0 when was the winter fuel payment first introducedheating over the winter months . +0 Trophy hunting can include areas which would likely be unsuitable for what other types of ecotourism?study states that less than 3% of a trophy hunters' expenditures reach the local level, meaning that the economic incentive and benefit is "minimal, particularly when we consider the vast areas of +1 In simple language, what are the interconnections in an embedding matrix?Since it was quite easy to stack interconnections (wires) inside the embedding matrix, the approach allowed designers to forget completely about the routing of wires (usually a time-consuming operation of PCB design): Anywhere the designer needs a connection, the machine will draw a wire in straight line from one location/pin +2 rho has been to the most all star games in baseballn4

  • Stan Musial 24
  • +0 In 1169, Ireland was invaded by which people?High King to ensure the terms of the Treaty of Windsor led Henry II, as King of England, to rule as effective monarch under the title of Lord of Ireland. This title was granted to his younger son but when Henry's heir unexpectedly died the title of King of England and Lord of Ireland became entwined in one +1 What year did a biracial Populist fusion gain the Governors office?to the legislature and governor's office, but the Populists attracted voters displeased with them. In 1896 a biracial, Populist-Republican Fusionist coalition gained the governor's office. The Democrats regained control of the legislature +1 nearest metro station to majnu ka tilla delhiRing Road of Delhi . It is at a walkable distance from ISBT Kashmere Gate . It is approachable through the Kashmeri Gate station of the Delhi Metro , lies on both the Red ( Dilshad Garden - Rithala ) and Yellow Lines ( Samaypur +3 where is california located in the region of the united states

    California is a U.S. state in the Pacific Region of the United States . With 39.5 million residents , California is the most populous state in the United States and the third largest by area . The +1 when did the baptist church start in americacoworker for religious freedom , are variously credited as founding the earliest Baptist church in North America . In 1639 , Williams established a Baptist church in Providence , Rhode Island , and Clarke began a Baptist church in +0 where was the first capital of the united states locatedpassed to pave the way for a permanent capital . The decision to locate the capital was contentious , but Alexander Hamilton helped broker a compromise in which the federal government would take on war debt incurred during the American Revolutionary War , in exchange for support from northern states for locating the capital along the Potomac +0 What will new regulations will reduce?products off of the consumer market," said Michael Fry, director of conservation advocacy for the American Bird Conservancy. "By putting these restrictions in place, they are allowing a compromise to be made between themselves and organizations who have been working on this problem for a long time." The EPA's new measures, which were handed down Thursday, require that rat poisons be kept in bait stations above ground and in containers that meet agency standards. Loose bait, such as pellets, and the four most hazardous types of pesticides, known as "second-generation anticoagulants," will no longer be sold for personal use. Under the new restrictions, only farmers, livestock owners and certified rodent control employees will be allowed to purchase rat poison in bulk. Bags larger than 8 pounds will no longer be sold at hardware and home-improvement stores. Children who come into contact +0 who played lois lane in the man of steelmixture of toughness and vulnerability , but Peter Bradshaw thought that the character was `` sketchily conceived '' and criticized her lack of chemistry with Cavill . Even so , the film earned over $660 million to become one of her biggest box +0 What year did the writer of the 1968 novel "The Iron Man" become Poet Laurete?Giant is a 1999 American animated science-fiction comedy-drama action film using both traditional animation and computer animation, produced by and directed by Brad Bird in his directorial debut. It is based on the 1968 novel "The Iron Man" by Ted Hughes (which was published in the United States as "The Iron Giant") and was scripted by Tim McCanlies from a story treatment by Bird. The film stars the voices of Eli Marienthal, +2 The conquest of Nice was an effort by Suleiman and what French king?allies. A month prior to the siege of Nice, France supported the Ottomans with an artillery unit during the 1543 Ottoman conquest of Esztergom in northern Hungary. After further advances by the Turks, the Habsburg ruler Ferdinand officially recognized Ottoman ascendancy in Hungary in +0 when was the vaccine receivedfor swine flu, also known as 2009 H1N1, using reverse genetics, he said. "Suitable viruses will hopefully be sent to manufacturers by end of next week," Skinner wrote. Once that happens, vaccine makers will tweak the virus and have "pilot lots" of vaccine ready to be tested by mid- to late June. Several thousand cases have been reported +1 What is the nationality of the actor who costarred with Matt LeBlanc in "All the Queen's Men"?n approximate -99.92% return. [PAR] [TLE] Eddie Izzard [SEP] Edward John "Eddie" Izzard ( ; born 7 February 1962) is an English stand-up comedian, actor, writer and political activist. His comedic style takes the form of rambling, whimsical monologue, and self-referential pantomime. He +0 What sickened thousands of children?executives detained, a local official said, according to Xinhua, Initial tests showed more than 1,300 children in the Hunan province town of Wenping have excessive lead in their blood from the Wugang Manganese Smelting Plant. A second round of testing has been ordered to confirm the results. The plant opened in May 2008 without gaining the approval of the local environment protection bureau, said Huang Wenbin, a deputy environment chief in Wugang City, Xinhua reported. The plant was within 500 meters (about a quarter mile) of three schools. The +0 What percentage of the population are the Kpelle?are descendants of African American and West Indian, mostly Barbadian settlers, make up 2.5%. Congo people, descendants of repatriated Congo and Afro-Caribbean +1 Amount of people left homeless?86 dead, the state news agency said. About 30 people are missing, the official news agency Agencia Brasil said, citing civil defense officials. Earlier reports had indicated as many as 100 people were dead. In addition, more than 54,000 residents have been left homeless, and another 1.5 million have been affected by the heavy rains, the state news agency reported. Brazilian President Luiz Inacio Lula da Silva announced he will release nearly 700 million reais ($350 million) +2 What other countries were in disagreement with the United Nations decision on Burma ?that strongly called upon the government of Myanmar to end its systematic violations of human rights. In January 2007, Russia and China vetoed a +0 Besides Barcelona and Real Madrid, what other team has remained in the Primera Division?first football club to win six out of six competitions in a single year, completing the sextuple in also winning the Spanish Super Cup, UEFA Super Cup and FIFA Club World Cup. In 2011, the club became +0 William Frederick Truax, is a former professional American football tight end in the National Football League (NFL) from 1964 to 1973 for the Los Angeles Rams and the Dallas Cowboys, following the 1970 NFL season, Truax was traded by the Rams to the Cowboys for wide receiver Lance Rentzel, a former American football flanker, in which organization?in New Orleans and college football at Louisiana State University and was drafted in the second round of the 1964 NFL draft. Following the 1970 NFL season, Truax was traded by the Rams to the Cowboys for wide receiver Lance Rentzel. He was part of the Cowboys' Super Bowl VI championship team in 1971. He played +3 What year did Chopin learn that the uprising in Warsaw was crushed?enlist. Chopin, now alone in Vienna, was nostalgic for his homeland, and wrote to a friend, "I curse the moment of my departure." When in September 1831 he learned, while travelling from Vienna to Paris, that the uprising had been crushed, he expressed his anguish in the pages of his private journal: "Oh +1 where do they make money in washington dc; all coinage is produced by the United States Mint . With production facilities in Washington , DC , and Fort Worth , Texas , the Bureau of Engraving and Printing is the largest producer of government security documents in the United States .

    +0 What did a researcher compare this process to?which makes it one of the highest rates of maternal mortality in the Americas. In wealthy developed nations, only nine women die for every 100,000 births. The five main causes of pregnancy-related deaths in Peru are hemorrhage, pre-eclampsia, infection, complications following abortion and obstructed birth, according to Peru's Ministry of Health figures. Amnesty's Peru researcher Nuria Garcia said, in a written statement: "The rates of maternal mortality in Peru are scandalous. The fact that so many women are dying from preventable causes is a human rights violation. "The Peruvian state is simply ignoring +0 How many containers can Longtan Containers Port Area handle?Port of Nanjing is the largest inland port in China, with annual cargo tonnage reached 191,970,000 t in 2012. The port area is 98 kilometres (61 mi) in length and has 64 berths +0 The 2011 New York City Marathon was sponsored by which Dutch multinational banking corporation?are retail banking, direct banking, commercial banking, investment banking, asset management, and insurance services. ING is an abbreviation for "Internationale Nederlanden Groep " (English: International Netherlands Group). [PAR] [TLE] 2011 New York City Marathon [SEP] The 42nd New York City Marathon took +0 What is human flourishing?it does not involve believing that human nature is purely good or that all people can live up to the Humanist ideals without help. If anything, there is recognition that living up to one's potential is hard +0 What was the result of Dida appealto play in next month's Champions League match at Shakhtar Donetsk after partially winning his appeal to UEFA against a two-match ban. Dida has had one game of his two-match ban suspended for a year following an appeal to UEFA. Brazilian Dida was also fined 60,000 Swiss francs by European football's ruling body following an incident involving a supporter during the Champions clash against Celtic in Scotland on October 3. The 34-year-old Brazilian was initially banned for two games for his theatrics following a Celtic fan's encroachment onto the pitch during the 2-1 defeat at Celtic +1 What is more plentiful in capital projects?generates economic distortion in the public sector by diverting public investment into capital projects where bribes and kickbacks are more plentiful. Officials may increase the technical complexity of public sector projects to conceal or +0 where were band greeted with cheers?the United States for a show in Stamford, Connecticut, on Tuesday, after they have "a few days off to recuperate," Robinson said. The trio was the opening act for Nelson until they were loudly booed in Toronto, a day after the actor-musician's bizarre interview with a CBC radio host. Ironically, the comments that offended Canadians included Thornton's assessment that they were "very reserved" and "it doesn't matter what you say to them." "It's mashed potatoes with no gravy," Thornton told CBC host Jian Ghomeshi. "We tend to play places where people throw things at each other and here they just sort of sit there," he said. Watch Thornton's interview » The audience at Thursday night's show in Toronto loudly booed the Boxmasters, with some shouts of "Here comes the gravy!" The Toronto Star newspaper reported. Thornton's remarks about +0 What do Mexicans call Mexico City?the Federal District in Spanish: D.F., which is read "De-Efe"). They are formally called capitalinos (in reference to the city being the capital of the country), but "[p]erhaps because capitalino is the +0 where does lock stock and barrel come fromindividual components one at a time . One craftsman made the `` lock '' which would have been a `` match lock '' , `` wheel lock '' , `` flint lock '' etc . +1 who has the power to establish a prison system

    The Federal Bureau of Prisons ( BOP ) is a United States federal law enforcement agency . A subdivision of +0 what are south americas only 2 landlocked countriessuch countries , including five partially recognised states . diff --git a/demo/demo3/run.py b/demo/demo3/run.py index b7d8ea13002952fe42a4ec2bf51a8eebecf9ad13..09e3fd83b795fe07820855ee69aee74548344a5f 100644 --- a/demo/demo3/run.py +++ b/demo/demo3/run.py @@ -6,11 +6,12 @@ if __name__ == '__main__': max_seqlen = 512 batch_size = 4 - num_epochs = 20 + num_epochs = 2 lr = 1e-3 vocab_path = './pretrain/ernie/vocab.txt' train_file = './data/cls4mrqa/train.tsv' + predict_file = './data/cls4mrqa/dev.tsv' config = json.load(open('./pretrain/ernie/ernie_config.json')) # ernie = palm.backbone.ERNIE(...) @@ -23,15 +24,25 @@ if __name__ == '__main__': # 创建该分类任务的reader,由诸多参数控制数据集读入格式、文件数量、预处理规则等 cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen) + predict_cls_reader = palm.reader.ClassifyReader(vocab_path, max_seqlen, phase='predict') print(cls_reader.outputs_attr) + print(predict_cls_reader.outputs_attr) # 不同的backbone会对任务reader有不同的特征要求,例如对于分类任务,基本的输入feature为token_ids和label_ids,但是对于BERT,还要求从输入中额外提取position、segment、input_mask等特征,因此经过register后,reader会自动补充backbone所要求的字段 cls_reader.register_with(ernie) print(cls_reader.outputs_attr) + print(predict_cls_reader.outputs_attr) + + print("preparing data...") + print(cls_reader.num_examples) + cls_reader.load_data(train_file, batch_size, num_epochs=num_epochs) + print(cls_reader.num_examples) + print('done!') + # 创建任务头(task head),如分类、匹配、机器阅读理解等。每个任务头有跟该任务相关的必选/可选参数。注意,任务头与reader是解耦合的,只要任务头依赖的数据集侧的字段能被reader提供,那么就是合法的 cls_head = palm.head.Classify(4, 1024, 0.1) # 根据reader和任务头来创建一个训练器trainer,trainer代表了一个训练任务,内部维护着训练进程、和任务的关键信息,并完成合法性校验,该任务的模型保存、载入等相关规则控制 - trainer = palm.Trainer('senti_cls', cls_reader, cls_head) + trainer = palm.Trainer('senti_cls') # match4mrqa.reuse_head_with(mrc4mrqa) @@ -39,19 +50,16 @@ if __name__ == '__main__': # output_vars = ernie.build(data_vars) # cls_head.build({'backbone': output_vars, 'reader': data_vars}) - loss_var = trainer.build_forward(ernie) + loss_var = trainer.build_forward(ernie, cls_head) # controller.build_forward() # Error! a head/backbone can be only build once! Try NOT to call build_forward method for any Trainer! - print(trainer.num_examples) - iterator_fn = trainer.load_data(train_file, 'csv', num_epochs=num_epochs, batch_size=batch_size) - print(trainer.num_examples) - - n_steps = trainer.num_examples * num_epochs // batch_size - warmup_steps = int(0.1 * n_steps) - print(warmup_steps) - sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) + # n_steps = cls_reader.num_examples * num_epochs // batch_size + # warmup_steps = int(0.1 * n_steps) + # print(warmup_steps) + # sched = palm.lr_sched.TriangularSchedualer(warmup_steps, n_steps) + sched = None adam = palm.optimizer.Adam(loss_var, lr, sched) @@ -60,17 +68,22 @@ if __name__ == '__main__': trainer.random_init_params() trainer.load_pretrain('pretrain/ernie/params') - # print(trainer.train_one_step(next(iterator_fn()))) - # trainer.train_one_epoch() - # for save predict model. - pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred') - cls_pred_head = palm.head.Classify(4, 1024, phase='pred') - trainer.build_predict_head(cls_pred_head, pred_ernie) - # trainer.train(iterator_fn, print_steps=1, save_steps=5, save_path='outputs', save_type='ckpt,predict') - trainer.train(iterator_fn, print_steps=1) + trainer.fit_reader(cls_reader) + trainer.train(print_steps=1) # trainer.save() + print('prepare to predict...') + pred_ernie = palm.backbone.ERNIE.from_config(config, phase='pred') + cls_pred_head = palm.head.Classify(4, 1024, phase='pred') + trainer.build_predict_forward(pred_ernie, cls_pred_head) + + predict_cls_reader.load_data(predict_file, 8) + print(predict_cls_reader.num_examples) + predict_cls_reader.register_with(pred_ernie) + trainer.fit_reader(predict_cls_reader, phase='predict') + print('predicting..') + trainer.predict(print_steps=20) diff --git a/paddlepalm/.trainer.py.swp b/paddlepalm/.trainer.py.swp deleted file mode 100644 index 11a7749e8a5e0e2e3bc785053720f83bc0dfef15..0000000000000000000000000000000000000000 Binary files a/paddlepalm/.trainer.py.swp and /dev/null differ diff --git a/paddlepalm/backbone/base_backbone.py b/paddlepalm/backbone/base_backbone.py index aab1ddea30fb496e5161a05547bb5fed6b034078..d821470cd384fe772ff3c6af67b4c8aea469a5a7 100644 --- a/paddlepalm/backbone/base_backbone.py +++ b/paddlepalm/backbone/base_backbone.py @@ -58,52 +58,3 @@ class BaseBackbone(object): """ raise NotImplementedError() - - - -class task_paradigm(object): - - def __init__(self, config, phase, backbone_config): - """ - config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数 - phase: str类型。运行阶段,目前支持train和predict - """ - - @property - def inputs_attrs(self): - """描述task_layer需要从reader, backbone等输入对象集合所读取到的输入对象的属性,第一级key为对象集和的名字,如backbone,reader等(后续会支持更灵活的输入),第二级key为对象集和中各对象的属性,包括对象的名字,shape和dtype。当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 - Return: - dict类型。对各个对象集及其输入对象的属性描述。""" - raise NotImplementedError() - - @property - def outputs_attr(self): - """描述task输出对象的属性,包括对象的名字,shape和dtype。输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 - 当某个对象为标量数据类型(如str, int, float等)时,shape设置为空列表[],当某个对象的某个维度长度可变时,shape中的相应维度设置为-1。 - Return: - dict类型。对各个输入对象的属性描述。注意,训练阶段必须包含名为loss的输出对象。 - """ - - raise NotImplementedError() - - @property - def epoch_inputs_attrs(self): - return {} - - def build(self, inputs, scope_name=""): - """建立task_layer的计算图。将符合inputs_attrs描述的来自各个对象集的静态图Variables映射成符合outputs_attr描述的静态图Variable输出。 - Args: - inputs: dict类型。字典中包含inputs_attrs中的对象名到计算图Variable的映射,inputs中至少会包含inputs_attr中定义的对象 - Return: - 需要输出的计算图变量,输出对象会被加入到fetch_list中,从而在每个训练/推理step时得到runtime的计算结果,该计算结果会被传入postprocess方法中供用户处理。 - - """ - raise NotImplementedError() - - def postprocess(self, rt_outputs): - """每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。""" - pass - - def epoch_postprocess(self, post_inputs): - pass - diff --git a/paddlepalm/backbone/ernie.py b/paddlepalm/backbone/ernie.py index c52baed1311488d2933c79f0e90f7f8e2fea69e6..3300b11b4856025e9bbc416c033ede75b6fe1bf3 100644 --- a/paddlepalm/backbone/ernie.py +++ b/paddlepalm/backbone/ernie.py @@ -114,8 +114,6 @@ class ERNIE(BaseBackbone): input_mask = inputs['input_mask'] task_ids = inputs['task_ids'] - fluid.layers.Print(src_ids) - # padding id in vocabulary must be set to 0 emb_out = fluid.embedding( input=src_ids, diff --git a/paddlepalm/head/base_head.py b/paddlepalm/head/base_head.py index 09cce602e05746b4e75f43d1863d222b762fc496..7d24798a0afae7c813eb7366d64a298b8608bf0f 100644 --- a/paddlepalm/head/base_head.py +++ b/paddlepalm/head/base_head.py @@ -13,16 +13,20 @@ # See the License for the specific language governing permissions and # limitations under the License. +import os +import json class BaseHead(object): - def __init__(self, config, phase, backbone_config): + def __init__(self, phase='train'): """ config: dict类型。描述了 任务实例(task instance)+多任务配置文件 中定义超参数 phase: str类型。运行阶段,目前支持train和predict """ self._stop_gradient = {} + self._phase = phase self._prog = None + self._results_buffer = [] @property def inputs_attrs(self): @@ -67,10 +71,31 @@ class BaseHead(object): raise NotImplementedError() - def postprocess(self, rt_outputs): + def batch_postprocess(self, rt_outputs): """每个训练或推理step后针对当前batch的task_layer的runtime计算结果进行相关后处理。注意,rt_outputs除了包含build方法,还自动包含了loss的计算结果。""" - pass + if isinstance(rt_outputs, dict): + keys = rt_outputs.keys() + vals = [rt_outputs[k] for k in keys] + lens = [len(v) for v in vals] + if len(set(lens)) == 1: + results = [dict(zip(*[keys, i])) for i in zip(*vals)] + self._results_buffer.extend(results) + return results + else: + print('WARNING: irregular output results. visualize failed.') + self._results_buffer.append(rt_outputs) + return None + + def epoch_postprocess(self, post_inputs, output_dir=None): + if output_dir is not None: + for i in self._results_buffer: + print(i) + else: + if not os.path.exists(output_dir): + os.makedirs(output_dir) + with open(os.path.join(output_dir, self._phase), 'w') as writer: + for i in self._results_buffer: + writer.write(json.dumps(i)+'\n') + - def epoch_postprocess(self, post_inputs): - pass diff --git a/paddlepalm/head/cls.py b/paddlepalm/head/cls.py index 3f2afe9c66756e3ca57870a79539e70b41fd6fd6..133e01338e9eed34debe1882a91780a9961dbfac 100644 --- a/paddlepalm/head/cls.py +++ b/paddlepalm/head/cls.py @@ -87,14 +87,16 @@ class Classify(BaseHead): self._preds.extend(preds.tolist()) return preds - def epoch_postprocess(self, post_inputs): + def epoch_postprocess(self, post_inputs, output_dir=None): # there is no post_inputs needed and not declared in epoch_inputs_attrs, hence no elements exist in post_inputs if not self._is_training: - if self._pred_output_path is None: - raise ValueError('argument pred_output_path not found in config. Please add it into config dict/file.') - with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer: + if output_dir is None: for p in self._preds: - writer.write(str(p)+'\n') - print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json')) + print(p) + else: + with open(os.path.join(self._pred_output_path, 'predictions.json'), 'w') as writer: + for p in self._preds: + writer.write(str(p)+'\n') + print('Predictions saved at '+os.path.join(self._pred_output_path, 'predictions.json')) diff --git a/paddlepalm/mtl_controller.py b/paddlepalm/mtl_controller.py index 30fac4dd7802d091e1ccb37f92f6477de28d2144..6cca513af1a0fff715f7ae4f160f02ed691a8935 100755 --- a/paddlepalm/mtl_controller.py +++ b/paddlepalm/mtl_controller.py @@ -31,7 +31,7 @@ from paddlepalm.utils.saver import init_pretraining_params, init_checkpoint from paddlepalm.utils.config_helper import PDConfig from paddlepalm.utils.print_helper import print_dict from paddlepalm.utils.reader_helper import create_net_inputs, create_iterator_fn, create_joint_iterator_fn, merge_input_attrs -from paddlepalm.distribute import data_feeder +from paddlepalm.distribute import data_feeder, decode_fake from default_settings import * from task_instance import TaskInstance, check_instances @@ -186,13 +186,20 @@ def _fit_attr(conf, fit_attr, strict=False): def create_feed_batch_process_fn(net_inputs): - def feed_batch_process_fn(data): + + def feed_batch_process_fn(data, id=-1): + # temps = {} + # for i in range(len(net_inputs)): temp = {} - for q, var in net_inputs.items(): + inputs = net_inputs[id] if id != -1 else net_inputs + + for q, var in inputs.items(): if isinstance(var, str) or isinstance(var, unicode): temp[var] = data[q] else: temp[var.name] = data[q] + # temps[i] = temp + return temp return feed_batch_process_fn @@ -221,6 +228,7 @@ class Controller(object): exe, dev_count = _init_env(use_gpu=mtl_conf.get('use_gpu', True)) self.exe = exe self.dev_count = dev_count + self.batch_size = mtl_conf.get('batch_size') print_dict(mtl_conf, title='global configuration') @@ -343,6 +351,7 @@ class Controller(object): dev_count = self.dev_count num_instances = len(instances) mrs = self.mrs + branch = fluid.data(name="branch",shape=[1],dtype='int64') # set first_target/main task instance main_inst = None @@ -362,35 +371,51 @@ class Controller(object): # create reader, task # then check i/o across reader, backbone and task_layer - task_attrs = [] + + # check_fns = {} + task_attrs = {} pred_task_attrs = [] - for inst in instances: - train_reader = inst.Reader(inst.config, phase='train') - inst.reader['train'] = train_reader - train_parad = inst.Paradigm(inst.config, phase='train', backbone_config=bb_conf) - inst.task_layer['train'] = train_parad - task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], inst.name) - task_attrs.append(task_attr_from_reader) + joint_input_names = {} + joint_shape_and_dtypes = {} + name_to_position = {} + for i in range(num_instances): + # def check_tasks(): + # i = s + # def checkeach(): + + train_reader = instances[i].Reader(instances[i].config, phase='train') + instances[i].reader['train'] = train_reader + train_parad = instances[i].Paradigm(instances[i].config, phase='train', backbone_config=bb_conf) + instances[i].task_layer['train'] = train_parad + task_attr_from_reader = _encode_inputs(train_parad.inputs_attrs['reader'], instances[i].name) + task_attrs[i] = task_attr_from_reader _check_io(train_backbone.inputs_attr, train_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.train') _check_io(train_parad.inputs_attrs['reader'], train_reader.outputs_attr, in_name='task_paradigm.train.reader', out_name='reader.train') _check_io(train_parad.inputs_attrs['backbone'], train_backbone.outputs_attr, in_name='task_paradigm.train.backbone', out_name=bb_name+'_backbone') - - if inst.is_target: - if 'pred_file' not in inst.config: - inst.config['pred_file'] = '' - pred_reader = inst.Reader(inst.config, phase='pred') - pred_parad = inst.Paradigm(inst.config, phase='pred', backbone_config=bb_conf) - inst.task_layer['pred'] = pred_parad - task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], inst.name) + # merge reader input attrs from backbone and task_instances + # pred_joint_input_names = [] + # pred_joint_shape_and_dtypes = [] + if instances[i].is_target: + if 'pred_file' not in instances[i].config: + instances[i].config['pred_file'] = '' + pred_reader = instances[i].Reader(instances[i].config, phase='pred') + pred_parad = instances[i].Paradigm(instances[i].config, phase='pred', backbone_config=bb_conf) + instances[i].task_layer['pred'] = pred_parad + task_attr_from_reader = _encode_inputs(pred_parad.inputs_attrs['reader'], instances[i].name) pred_task_attrs.append(task_attr_from_reader) _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred') _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred') _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone') - - # merge reader input attrs from backbone and task_instances - joint_input_names, joint_shape_and_dtypes, name_to_position = merge_input_attrs(train_backbone.inputs_attr, task_attrs) + # pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) + # return joint_input_names[i], joint_shape_and_dtypes[i], name_to_position[i], pred_joint_input_names, pred_joint_shape_and_dtypes + # return checkeach + # check_fns[i] = check_tasks() + joint_input_names[i], joint_shape_and_dtypes[i], name_to_position[i] = merge_input_attrs(train_backbone.inputs_attr, task_attrs[i]) + pred_joint_input_names, pred_joint_shape_and_dtypes, _ = merge_input_attrs(pred_backbone.inputs_attr, pred_task_attrs, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) + + # shapes: [task_id, shapes_of_backbone, shapes_of_inst1, ..., shapes_of_instN] if DEBUG: @@ -400,16 +425,18 @@ class Controller(object): print('joint input shape and dtypes:') print(joint_shape_and_dtypes) - # load data - for inst in instances: - print(inst.name+": preparing data...", end='') - inst.reader['train'].load_data() + # load data + data_fns={} + for i in range(num_instances): + print(instances[i].name+": preparing data...", end='') + instances[i].reader['train'].load_data() print('ok!') # merge dataset iterators and create net input vars iterators = [] prefixes = [] mrs = [] + for inst in instances: iterators.append(inst.reader['train'].iterator()) prefixes.append(inst.name) @@ -418,65 +445,65 @@ class Controller(object): joint_iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict') self._joint_iterator_fn = joint_iterator_fn - input_attrs = [[i, j, k] for i, (j,k) in zip(joint_input_names, joint_shape_and_dtypes)] - pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] - # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3) - net_inputs = create_net_inputs(input_attrs, async=False) - self._net_inputs = net_inputs + input_attrs = {} + net_inputs = {} + bb_output_vars = {} + bb_output_fns = {} - # build backbone and task layers - train_prog = fluid.default_main_program() - train_init_prog = fluid.default_startup_program() - bb_output_vars = train_backbone.build(net_inputs, scope_name='__paddlepalm_') - assert sorted(bb_output_vars.keys()) == sorted(train_backbone.outputs_attr.keys()) - + # prepare predict vars for saving inference model + pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_joint_input_names, pred_joint_shape_and_dtypes)] pred_prog = fluid.Program() pred_init_prog = fluid.Program() + self._pred_prog = pred_prog with fluid.program_guard(main_program = pred_prog, startup_program = pred_init_prog): - pred_net_inputs = create_net_inputs(pred_input_attrs) - pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_') - - fluid.framework.switch_main_program(train_prog) - fluid.framework.switch_startup_program(train_init_prog) + pred_net_inputs = create_net_inputs(pred_input_attrs) + pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_') + task_inputs = {} task_output_vars = {} - for inst in instances: - task_inputs = {'backbone': bb_output_vars} - task_inputs_from_reader = _decode_inputs(net_inputs, inst.name) - task_inputs['reader'] = task_inputs_from_reader - - scope = inst.task_reuse_scope + '/' + task_fns = {} + + def get_loss(i): + input_attrs[i] = [[m, j, k] for m, (j,k) in zip(joint_input_names[i], joint_shape_and_dtypes[i])] + net_inputs[i] = create_net_inputs(input_attrs[i], async=False) + # net_inputs = create_net_inputs(input_attrs, async=True, iterator_fn=joint_iterator_fn, dev_count=dev_count, n_prefetch=3) + bb_output_vars[i] = train_backbone.build(net_inputs[i], scope_name='__paddlepalm_') + assert sorted(bb_output_vars[i].keys()) == sorted(train_backbone.outputs_attr.keys()) + + # build backbone and task layers + task_inputs[i] = {'backbone': bb_output_vars[i]} + task_inputs_from_reader = _decode_inputs(net_inputs[i], instances[i].name) + task_inputs[i]['reader'] = task_inputs_from_reader + + scope = instances[i].task_reuse_scope + '/' with fluid.unique_name.guard(scope): - output_vars = inst.build_task_layer(task_inputs, phase='train', scope=scope) - output_vars = {inst.name+'/'+key: val for key, val in output_vars.items()} - old = len(task_output_vars) # for debug - task_output_vars.update(output_vars) - assert len(task_output_vars) - old == len(output_vars) # for debug - - # prepare predict vars for saving inference model - if inst.is_target: + output_vars = instances[i].build_task_layer(task_inputs[i], phase='train', scope=scope) + output_vars = {instances[i].name+'/'+key: val for key, val in output_vars.items()} + loss_var = output_vars[instances[i].name+'/loss'] + task_output_vars[i] = output_vars + if instances[i].is_target: with fluid.program_guard(pred_prog, pred_init_prog): - cur_inputs = _decode_inputs(pred_net_inputs, inst.name) - inst.pred_input = cur_inputs + cur_inputs = _decode_inputs(pred_net_inputs, instances[i].name) + instances[i].pred_input = cur_inputs pred_task_inputs = {'backbone': pred_bb_output_vars, 'reader': cur_inputs} - scope = inst.task_reuse_scope + '/' + scope = instances[i].task_reuse_scope + '/' with fluid.unique_name.guard(scope): - inst.build_task_layer(pred_task_inputs, phase='pred', scope=scope) - - - bb_fetches = {k: v.name for k,v in bb_output_vars.items()} - task_fetches = {k: v.name for k,v in task_output_vars.items()} - fetches = task_fetches - fetches['__task_id'] = net_inputs['__task_id'].name - - # compute loss - task_id_var = net_inputs['__task_id'] - task_id_vec = fluid.one_hot(task_id_var, num_instances) - losses = fluid.layers.concat([task_output_vars[inst.name+'/loss'] for inst in instances], axis=0) - loss = layers.reduce_sum(task_id_vec * losses) - + instances[i].build_task_layer(pred_task_inputs, phase='pred', scope=scope) + return loss_var + + for i in range(num_instances): + def task_loss(): + task_id = i + return lambda: get_loss(task_id) + task_fns[i] = task_loss() + + loss = layers.switch_case( + branch_index=branch, + branch_fns=task_fns + ) + self._switched_loss = loss.name main_reader = main_inst.reader['train'] num_examples = main_reader.num_examples @@ -514,9 +541,9 @@ class Controller(object): self.saver_program = fluid.default_main_program() self.main_inst = main_inst - self.fetches = fetches self.has_init_train = True self.has_init_pred = True + self._net_inputs = net_inputs self.exe.run(fluid.default_startup_program()) print("\nRandomly initialize parameters...\n") @@ -569,8 +596,6 @@ class Controller(object): backbone = self.train_backbone train_program = self.train_program saver_program = self.saver_program - fetches = self.fetches - finish = [] for inst in instances: if inst.is_target: @@ -588,46 +613,45 @@ class Controller(object): return False return True - # do training - fetch_names, fetch_list = zip(*fetches.items()) + # do training + fetch_names = {} + fetch_list = [] main_step = 0 # only count for main task global_step = 0 # count for all tasks epoch = 0 time_begin = time.time() backbone_buffer = [] - + feed_batch_process_fn = create_feed_batch_process_fn(self._net_inputs) distribute_feeder = data_feeder(self._joint_iterator_fn, feed_batch_process_fn) - - # palm.distribute.reader(self._joint_iterator_fn, self._net_inputs, prefetch_steps=2) while not train_finish(): - feed, mask = next(distribute_feeder) + feed, mask, id = next(distribute_feeder) + for i in range(self.dev_count): + feed[i].update({'branch':np.array([id],dtype='int64')}) + fetch_list.append(self._switched_loss) rt_outputs = self.exe.run(train_program, feed=feed, fetch_list=fetch_list) - while mask.pop() == False: - rt_outputs.pop() + rt_loss = rt_outputs.pop() rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} - rt_task_id = np.squeeze(rt_outputs['__task_id']).tolist() - rt_task_id = rt_task_id[0] if isinstance(rt_task_id, list) else rt_task_id - cur_task = instances[rt_task_id] + cur_task = instances[id] - backbone_rt_outputs = {k:v for k,v in rt_outputs.items() if '/' not in k} - backbone_buffer.append(backbone.postprocess(backbone_rt_outputs)) + # backbone_rt_outputs = {k:v for k,v in rt_outputs.items() if '/' not in k} + # backbone_buffer.append(backbone.postprocess(backbone_rt_outputs)) - task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')} - instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs) + # task_rt_outputs = {k[len(cur_task.name+'/'):]: v for k,v in rt_outputs.items() if k.startswith(cur_task.name+'/')} + # instances[rt_task_id].task_layer['train'].postprocess(task_rt_outputs) global_step += 1 cur_task.cur_train_step += 1 cur_task_global_step = cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch if cur_task.is_target and cur_task.save_infermodel_every_n_steps > 0 and cur_task_global_step % cur_task.save_infermodel_every_n_steps == 0: - cur_task.save(suffix='.step'+str(cur_task_global_step)) + cur_task.save(suffix='.step'+str(cur_task_global_step), prog=self._pred_prog) if global_step % main_conf.get('print_every_n_steps', 5) == 0: - loss = rt_outputs[cur_task.name+'/loss'] + loss = rt_loss loss = np.mean(np.squeeze(loss)).tolist() time_end = time.time() @@ -640,7 +664,7 @@ class Controller(object): if cur_task.train_finish and cur_task.cur_train_step + cur_task.cur_train_epoch * cur_task.steps_pur_epoch == cur_task.expected_train_steps: print(cur_task.name+': train finished!') - cur_task.save() + cur_task.save(prog=self._pred_prog) if 'save_ckpt_every_n_steps' in main_conf and global_step % main_conf['save_ckpt_every_n_steps'] == 0: save_path = os.path.join(main_conf['save_path'], 'ckpt', @@ -686,37 +710,26 @@ class Controller(object): print('predicting...') feed_batch_process_fn = create_feed_batch_process_fn(inst.pred_input) - distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1) + distribute_feeder = data_feeder(inst.reader['pred'].iterator, feed_batch_process_fn, prefetch_steps=1, phase='pred') buf = [] - for feed, mask in distribute_feeder: - print('before run') + for feed, mask, id in distribute_feeder: + rt_outputs = self.exe.run(pred_prog, feed, fetch_vars) - print('after run') - splited_rt_outputs = [] - for item in rt_outputs: - splited_rt_outputs.append(np.split(item, len(mask))) - - # assert len(rt_outputs) == len(mask), [len(rt_outputs), len(mask)] - print(mask) - - while mask.pop() == False: - print(mask) - for item in splited_rt_outputs: + + num_fakes = decode_fake(len(rt_outputs[0]), mask, self.batch_size) + for _ in range(num_fakes): + for item in rt_outputs: item.pop() - rt_outputs = [] - print('cancat') - for item in splited_rt_outputs: - rt_outputs.append(np.concatenate(item)) - + rt_outputs = {k:v for k,v in zip(fetch_names, rt_outputs)} inst.postprocess(rt_outputs, phase='pred') - print('leave feeder') + if inst.task_layer['pred'].epoch_inputs_attrs: reader_outputs = inst.reader['pred'].get_epoch_outputs() else: reader_outputs = None - print('epoch postprocess') + inst.epoch_postprocess({'reader':reader_outputs}, phase='pred') @@ -731,6 +744,3 @@ if __name__ == '__main__': __all__ = ["Controller"] - - - diff --git a/paddlepalm/multihead_trainer.py b/paddlepalm/multihead_trainer.py new file mode 100644 index 0000000000000000000000000000000000000000..d78b34bd7135e30b0153e735c92aabd0fe79ceb4 --- /dev/null +++ b/paddlepalm/multihead_trainer.py @@ -0,0 +1,91 @@ + +from paddlepalm.distribute import gpu_dev_count, cpu_dev_count +from paddlepalm import Trainer + +dev_count = 1 if gpu_dev_count <= 1 else gpu_dev_count +VERBOSE=False + + +class MultiHeadTrainer(Trainer): + + def __init__(self, trainers, reuse_flags=None): + assert len(trainers) == len(mix_ratios) + if reuse_flags is not None: + assert len(reuse_flags) == len(trainers) + + self._trainers = trainers + + def build_forward(self, backbone, head_dict): + + num_heads = len(self._trainers) + assert len(head_dict) == num_heads + + for t in trainers: + assert t.name in head_dict + + train_prog = fluid.Program() + train_init_prog = fluid.Program() + + def get_loss(i): + head = head_dict[self._trainers[i].name] + loss_var = self._trainers[i].build_forward(backbone, head, train_prog, train_init_prog) + return loss_var + + task_fns = {} + for i in range(num_heads): + def task_loss(): + task_id = i + return lambda: get_loss(task_id) + task_fns[i] = task_loss() + + head_id_var = fluid.data(name="branch",shape=[1],dtype='int64') + loss_var = layers.switch_case( + branch_index=head_id_var, + branch_fns=task_fns + ) + self._head_id_var = head_id_var + return loss_var + + def fit_readers(self, reader_dict, mix_ratio, ): + + num_heads = len(self._trainers) + assert len(head_dict) == num_heads + + name_to_position = [] + joint_shape_and_dtypes = [] + iterators = [] + prefixes = [] + mrs = [] + net_inputs = [] + for t in trainers: + assert t.name in reader_dict + t.fit_reader(reader_dict[t.name]) + net_inputs.append(t._net_inputs) + prefixes.append(t.name) + mrs.append(t.mix_ratio) + iterators.append(t._raw_iterator_fn()) + name_to_position.append(t._name_to_position) + joint_shape_and_dtypes.append(t._shape_and_dtypes) + + iterator_fn = create_joint_iterator_fn(iterators, prefixes, joint_shape_and_dtypes, mrs, name_to_position, dev_count=dev_count, verbose=VERBOSE, return_type='dict') + feed_batch_process_fn = reader_helper.create_multihead_feed_batch_process_fn(net_inputs) + + if gpu_dev_count > 1: + distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn) + else: + distribute_feeder_fn = iterator_fn + + if phase == 'train': + self._train_reader = distribute_feeder_fn() + self._feed_batch_process_fn = feed_batch_process_fn + elif phase == 'predict': + self._predict_reader = distribute_feeder_fn() + self._pred_feed_batch_process_fn = feed_batch_process_fn + + + def train(self): + pass + + def train_one_step(self): + pass + diff --git a/paddlepalm/optimizer/adam.py b/paddlepalm/optimizer/adam.py index 51343f36889632ce3b625b60b37cb1ba4953e254..bdae6a749a9e5bc5d6d666ad319388c81017a988 100644 --- a/paddlepalm/optimizer/adam.py +++ b/paddlepalm/optimizer/adam.py @@ -37,8 +37,6 @@ class Adam(BaseOptimizer): if self._lr_schedualer is not None: self._lr = self._lr_schedualer.build(self._lr) - fluid.layers.Print(self._lr) - optimizer = fluid.optimizer.Adam(learning_rate=self._lr) if grad_clip is not None: diff --git a/paddlepalm/reader/base_reader.py b/paddlepalm/reader/base_reader.py index 03e541a8161651466c1c74a79917c4b931fc1dba..5041bbb1b441d1880c5142a0b2ad04bcc056d27d 100644 --- a/paddlepalm/reader/base_reader.py +++ b/paddlepalm/reader/base_reader.py @@ -21,6 +21,8 @@ class BaseReader(object): # assert isinstance(config, dict) # self._config = config self._phase = phase + self._batch_size = None + self._num_epochs = 1 self._register = set() self._registered_backbone = None @@ -117,4 +119,8 @@ class BaseReader(object): """数据集中的样本数量,即每个epoch中iterator所生成的样本数。注意,使用滑动窗口等可能导致数据集样本数发生变化的策略时,该接口应返回runtime阶段的实际样本数。""" raise NotImplementedError() + @property + def num_epochs(self): + """""" + raise NotImplementedError() diff --git a/paddlepalm/reader/cls.py b/paddlepalm/reader/cls.py index 044fe6ddcf60c0a0ab3749bf26f2a59fa97f420d..b1ae96be081f29cdb0febc195f94521cbf207957 100644 --- a/paddlepalm/reader/cls.py +++ b/paddlepalm/reader/cls.py @@ -32,7 +32,7 @@ class ClassifyReader(BaseReader): BaseReader.__init__(self, phase) assert lang.lower() in ['en', 'cn', 'english', 'chinese'], "supported language: en (English), cn (Chinese)." - assert phase in ['train', 'pred'], "supported phase: train, pred." + assert phase in ['train', 'predict'], "supported phase: train, predict." for_cn = lang.lower() == 'cn' or lang.lower() == 'chinese' @@ -66,10 +66,13 @@ class ClassifyReader(BaseReader): return self._get_registed_attrs(attrs) - def _load_data(self, input_file, batch_size, num_epochs=None, \ + def load_data(self, input_file, batch_size, num_epochs=None, \ file_format='csv', shuffle_train=True): - self._data_generator = self._reader.data_generator(input_file, batch_size, \ - num_epochs, shuffle=shuffle_train if self._phase == 'train' else False, \ + self._batch_size = batch_size + self._num_epochs = num_epochs + self._data_generator = self._reader.data_generator( \ + input_file, batch_size, num_epochs if phase == 'train' else 1, \ + shuffle=shuffle_train if self._phase == 'train' else False, \ phase=self._phase) def _iterator(self): @@ -92,4 +95,8 @@ class ClassifyReader(BaseReader): def num_examples(self): return self._reader.get_num_examples(phase=self._phase) + @property + def num_epochs(self): + return self._num_epochs + diff --git a/paddlepalm/reader/utils/reader4ernie.py b/paddlepalm/reader/utils/reader4ernie.py index f41200a654d415c0d6b14b9c56640201defa0b0e..3b47cb9b723a13d14167acfe93f17c3cbcfd204b 100644 --- a/paddlepalm/reader/utils/reader4ernie.py +++ b/paddlepalm/reader/utils/reader4ernie.py @@ -219,7 +219,7 @@ class BaseReader(object): qid=qid) return record - def _prepare_batch_data(self, examples, batch_size, phase=None): + def _prepare_batch_data(self, examples, batch_size, phase='train'): """generate batch records""" batch_records, max_len = [], 0 if len(examples) < batch_size: @@ -243,13 +243,11 @@ class BaseReader(object): if phase == 'pred' and batch_records: yield self._pad_batch_records(batch_records) - def get_num_examples(self, input_file=None, phase=None): - if self.examples is not None: - if phase is None: - phase = 'all' - return len(self.examples[phase]) + def get_num_examples(self, input_file=None, phase='train'): + if input_file is None: + return len(self.examples.get(phase, [])) else: - assert input_file is not None, "Argument input_file should be given or the data_generator should be created when this func is called." + # assert input_file is not None, "Argument input_file should be given or the data_generator should be created when this func is called." examples = self._read_tsv(input_file) return len(examples) diff --git a/paddlepalm/trainer.py b/paddlepalm/trainer.py index 2ae106f952f4a6917e1110d1ff98c86eadea88f8..4384bbaac2156ea9d7fc5e65b94a50e43c2e82c7 100644 --- a/paddlepalm/trainer.py +++ b/paddlepalm/trainer.py @@ -29,15 +29,13 @@ DEBUG=False class Trainer(object): - def __init__(self, name, reader, task_head, \ - mix_ratio=1.0, reuse_head_with=None, \ + def __init__(self, name, mix_ratio=1.0, reuse_head_with=None, \ silent=False): self._name = name self._verbose = not silent - self._reader = reader self._pred_reader = None - self._task_head = task_head + self._task_head = None self._pred_head = None self._train_init = False @@ -66,15 +64,12 @@ class Trainer(object): self._expected_train_steps = None self._expected_train_epochs = None self._steps_pur_epoch = None + self._pred_steps_pur_epoch = None self._cur_train_epoch = 0 self._cur_train_step = 0 self._train_finish = False - # 存放不同运行阶段(train,eval,pred)的数据集reader,key为phase,value为Reader实例 - # self._reader = {'train': reader, 'eval': None, 'pred': self._pred_reader} - # self._input_layer = None self._inputname_to_varname = {} - # self._task_layer = {'train': task_head, 'eval': None, 'pred': pred_head} self._pred_input_name_list = [] self._pred_input_varname_list = [] self._pred_fetch_name_list = [] @@ -92,7 +87,7 @@ class Trainer(object): self._lock = False self._build_forward = False - def build_predict_head(self, pred_head, pred_backbone, pred_prog=None, pred_init_prog=None): + def build_predict_forward(self, pred_backbone, pred_head, pred_prog=None, pred_init_prog=None): self._pred_head = pred_head # self._pred_reader = self._reader.clone(phase='pred') pred_task_attr_from_reader = helper.encode_inputs(self._pred_head.inputs_attrs['reader'], self.name) @@ -101,8 +96,10 @@ class Trainer(object): # _check_io(pred_backbone.inputs_attr, pred_reader.outputs_attr, in_name=bb_name+'_backbone', out_name='reader.pred') # _check_io(pred_parad.inputs_attrs['reader'], pred_reader.outputs_attr, in_name='task_paradigm.pred.reader', out_name='reader.pred') # _check_io(pred_parad.inputs_attrs['backbone'], pred_backbone.outputs_attr, in_name='task_paradigm.pred.backbone', out_name=bb_name+'_backbone') - pred_input_names, pred_shape_and_dtypes, _ = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) + pred_input_names, pred_shape_and_dtypes, pred_name_to_position = reader_helper.merge_input_attrs(pred_backbone.inputs_attr, pred_task_attr_from_reader, insert_taskid=False, insert_batchsize=False, insert_seqlen=False, insert_batchsize_x_seqlen=False) pred_input_attrs = [[i, j, k] for i, (j,k) in zip(pred_input_names, pred_shape_and_dtypes)] + self._pred_shape_and_dtypes = pred_shape_and_dtypes + self._pred_name_to_position = pred_name_to_position if pred_prog is None: pred_prog = fluid.Program() @@ -114,6 +111,7 @@ class Trainer(object): pred_net_inputs = reader_helper.create_net_inputs(pred_input_attrs) # pred_bb_output_vars = pred_backbone.build(pred_net_inputs, scope_name='__paddlepalm_') pred_bb_output_vars = pred_backbone.build(pred_net_inputs) + self._pred_net_inputs = pred_net_inputs # prepare predict vars for saving inference model with fluid.program_guard(pred_prog, pred_init_prog): @@ -128,16 +126,16 @@ class Trainer(object): output_vars = self._build_head(pred_task_inputs, phase='pred', scope=scope) if output_vars is not None: - self._pred_fetch_name_list, self._pred_fetch_var_list = zip(*output_vars.items()) + self._pred_fetch_name_list, self._pred_fetch_list = zip(*output_vars.items()) else: self._pred_fetch_name_list = [] self._pred_fetch_var_list = [] - self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel() return output_vars - def build_forward(self, backbone, pred_backbone=None, train_prog=None, train_init_prog=None, pred_prog=None, pred_init_prog=None): + def build_forward(self, backbone, task_head, train_prog=None, train_init_prog=None, pred_prog=None, pred_init_prog=None): + self._task_head = task_head # assert self._backbone is not None, "backbone is required for Trainer to build net forward to run with single task mode" self._build_forward = True @@ -220,9 +218,9 @@ class Trainer(object): with fluid.program_guard(train_prog, train_init_prog): loss_var = fluid.layers.reduce_sum(task_output_vars[self.name+'.loss']) - for _id, block in enumerate(self._train_prog.blocks): - for var in block.vars: - print("[debug] : %d, %s" % (_id, var)) + # for _id, block in enumerate(self._train_prog.blocks): + # for var in block.vars: + # print("[debug] : %d, %s" % (_id, var)) self._loss_var = loss_var return loss_var @@ -272,43 +270,69 @@ class Trainer(object): # print(self._train_prog) - def load_data(self, input_file, file_format, batch_size, num_epochs=None, shuffle_train=True): + def fit_reader(self, reader, phase='train'): # load data - print("preparing data...", end='') - self._reader._load_data(input_file=input_file, batch_size=batch_size, \ - num_epochs=num_epochs, file_format=file_format, \ - shuffle_train=shuffle_train) - self._num_examples = self._reader.num_examples + assert self._train_init_prog is not None or self._pred_init_prog is not None, "You need to build_forward or build_predict_head first to prepare input features." # 这里不确定是否要向上取整,需确认 # tail = self._num_examples % batch_size > 0 # self._steps_pur_epoch = self._num_examples // batch_size + 1 if tail else 0 - self._steps_pur_epoch = self._num_examples // batch_size + batch_size = reader._batch_size + self._num_epochs = reader.num_epochs + if phase == 'train': + self._steps_pur_epoch = reader.num_examples // batch_size + shape_and_dtypes = self._shape_and_dtypes + name_to_position = self._name_to_position + net_inputs = self._net_inputs + self._train_batch_size = batch_size + self._num_examples = reader.num_examples + elif phase == 'predict': + tail = self._num_examples % batch_size > 0 + self._pred_steps_pur_epoch = reader.num_examples // batch_size + 1 if tail else 0 + shape_and_dtypes = self._pred_shape_and_dtypes + name_to_position = self._pred_name_to_position + net_inputs = self._pred_net_inputs + self._predict_batch_size = batch_size + self._pred_num_examples = reader.num_examples + else: + raise NotImplementedError() + print('ok!') # merge dataset iterators and create net input vars - iterator = self._reader._iterator() + iterator = reader._iterator() prefix = self.name # 对yield出的数据进行runtime检查和适配 - iterator_fn = reader_helper.create_iterator_fn(iterator, prefix, self._shape_and_dtypes, self._name_to_position, return_type='dict') - feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(self._net_inputs) - self._feed_batch_process_fn = feed_batch_process_fn + iterator_fn = reader_helper.create_iterator_fn(iterator, prefix, shape_and_dtypes, name_to_position, return_type='dict') + self._raw_iterator_fn = iterator_fn + feed_batch_process_fn = reader_helper.create_feed_batch_process_fn(net_inputs) if gpu_dev_count > 1: distribute_feeder_fn = data_feeder(iterator_fn, feed_batch_process_fn) else: distribute_feeder_fn = iterator_fn - return distribute_feeder_fn() + + if phase == 'train': + self._train_reader = distribute_feeder_fn() + self._feed_batch_process_fn = feed_batch_process_fn + elif phase == 'predict': + self._predict_reader = distribute_feeder_fn() + self._pred_feed_batch_process_fn = feed_batch_process_fn + # return distribute_feeder_fn() def _init_exe_prog(self, for_train=True): - assert self._train_init_prog is not None, "train graph not foung! You should build_forward first before you random init parameters." - self._train_init = True - self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name) - on_gpu = gpu_dev_count > 0 - self._exe = helper.build_executor(on_gpu) - if not for_train: - raise NotImplementedError() + if not self._train_init and not self._predict_init: + on_gpu = gpu_dev_count > 0 + self._exe = helper.build_executor(on_gpu) + + if for_train: + assert self._train_prog is not None, "train graph not foung! You should build_forward first before you random init parameters." + self._train_init = True + else: + assert self._pred_prog is not None, "predict graph not foung! You should build_predict_head first before you random init parameters." + self._predict_init = True def random_init_params(self): + if not self._train_init: self._init_exe_prog() @@ -319,9 +343,9 @@ class Trainer(object): # load pretrain model (or ckpt) # assert self._exe is not None, "You need to random_init_params before load checkpoints." if phase == 'train' and not self._train_init: - self._init_exe_prog() + self._init_exe_prog(for_train=True) if phase == 'predict' and not self._predict_init: - pass + self._init_exe_prog(for_train=False) if phase == 'train': assert self._train_init_prog is not None, "train graph not found! You should build_forward first before load checkpoint." @@ -344,23 +368,23 @@ class Trainer(object): def load_predict_model(self, model_path): raise NotImplementedError() - def load_pretrain(self, model_path): + def load_pretrain(self, model_path, convert=False): # load pretrain model (or ckpt) assert self._exe is not None, "You need to random_init_params before load pretrain models." saver.init_pretraining_params( self._exe, model_path, + convert=convert, main_program=self._train_init_prog) - def set_predict_head(self): - pass - - def train(self, iterator, save_path=None, save_steps=None, save_type='ckpt', print_steps=5): + def train(self, save_path=None, save_steps=None, save_type='ckpt', print_steps=5): """ Argument: save_type: ckpt, predict, pretrain """ + iterator = self._train_reader + self._distribute_train_prog = fluid.CompiledProgram(self._train_prog).with_data_parallel(loss_name=self._loss_var.name) save_type = save_type.split(',') if 'predict' in save_type: @@ -412,15 +436,13 @@ class Trainer(object): # rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)} task_rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')} - self._task_head.postprocess(task_rt_outputs) + self._task_head.batch_postprocess(task_rt_outputs) # rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)} task_rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')} - self._task_head.postprocess(task_rt_outputs) + self._task_head.batch_postprocess(task_rt_outputs) - self._cur_train_step += 1 - self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch # if self._save_predict_model and self._cur_train_step % save_steps == 0: # self.save(save_path, suffix='.step'+str(self._cur_train_steps)) @@ -448,6 +470,8 @@ class Trainer(object): fluid.io.save_persistables(self._exe, os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step)), self._train_prog) print('checkpoint has been saved at '+os.path.join(save_path, 'ckpt.step'+str(self._cur_train_step))) + if self._num_epochs is None and self._cur_train_step == self._steps_pur_epoch: + break # save_path = os.path.join(main_conf['save_path'], 'ckpt', # "step_" + str(global_step)) # fluid.io.save_persistables(self.exe, save_path, saver_program) @@ -455,32 +479,83 @@ class Trainer(object): # print("ALL tasks train finished, exiting...") + def get_one_batch(self, phase='train'): + if phase == 'train': + return next(self._train_reader) + elif phase == 'predict': + return next(self._predict_reader) + else: + raise NotImplementedError() + + def predict(self, output_dir=None, print_steps=1000): + """ + Argument: + save_type: ckpt, predict, pretrain + """ + iterator = self._predict_reader + self._distribute_pred_prog = fluid.CompiledProgram(self._pred_prog).with_data_parallel() + + if output_dir is not None and not os.path.exists(output_dir): + os.makedirs(output_dir) + + time_begin = time.time() + cur_predict_step = 0 + for feed in iterator: + rt_outputs = self.predict_one_batch(feed) + # rt_outputs = {k[len(self.name+'.'):]: v for k,v in rt_outputs.items() if k.startswith(self.name+'.')} + # print(rt_outputs) + self._pred_head.batch_postprocess(rt_outputs) + + cur_predict_step += 1 + + if print_steps > 0 and cur_predict_step % print_steps == 0: + time_end = time.time() + time_cost = time_end - time_begin + + print("batch {}/{}, speed: {:.2f} steps/s".format( + cur_predict_step, self._pred_steps_pur_epoch, + print_steps / time_cost)) + time_begin = time.time() + + if self._pred_head.epoch_inputs_attrs: + reader_outputs = self._pred_reader.get_epoch_outputs() + else: + reader_outputs = None + + results = self._pred_head.epoch_postprocess({'reader':reader_outputs}, output_dir=output_dir) + return results + def train_one_step(self, batch): if gpu_dev_count > 1: feed, mask = batch rt_outputs = self.exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._fetch_list) - while mask.pop() == False: - rt_outputs.pop() + num_fakes = decode_fake(len(rt_outputs[0]), mask, self._batch_size) + for _ in range(num_fakes): + for item in rt_outputs: + item.pop() else: feed = self._feed_batch_process_fn(batch) rt_outputs = self._exe.run(self._distribute_train_prog, feed=feed, fetch_list=self._fetch_list) rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)} + self._cur_train_step += 1 + self._cur_train_epoch = (self._cur_train_step-1) // self._steps_pur_epoch return rt_outputs def predict_one_batch(self, batch): if gpu_dev_count > 1: feed, mask = batch - rt_outputs = self.exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._fetch_list) - while mask.pop() == False: - rt_outputs.pop() + rt_outputs = self.exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list) + num_fakes = decode_fake(len(rt_outputs[0]), mask, self._batch_size) + for _ in range(num_fakes): + for item in rt_outputs: + item.pop() else: - feed = self._feed_batch_process_fn(batch) - rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._fetch_list) + feed = self._pred_feed_batch_process_fn(batch) + rt_outputs = self._exe.run(self._distribute_pred_prog, feed=feed, fetch_list=self._pred_fetch_list) - rt_outputs = {k:v for k,v in zip(self._fetch_names, rt_outputs)} - - + rt_outputs = {k:v for k,v in zip(self._pred_fetch_name_list, rt_outputs)} + return rt_outputs def _build_head(self, net_inputs, phase, scope=""): if phase == 'train': @@ -488,12 +563,6 @@ class Trainer(object): if phase == 'pred': output_vars = self._pred_head.build(net_inputs, scope_name=scope) return output_vars - - def _postprocess(self, rt_outputs, phase): - return self._task_layer[phase].postprocess(rt_outputs) - - def _epoch_postprocess(self, epoch_inputs, phase): - return self._task_layer[phase].epoch_postprocess(epoch_inputs) def save(self, save_path, suffix=None): # dirpath = save_path.rstrip('/').rstrip('\\') + suffix @@ -536,20 +605,6 @@ class Trainer(object): def num_examples(self): return self._num_examples - # @property - # def _pred_input(self): - # return zip(*[self._pred_input_name_list, self._pred_input_varname_list]) - - # @_pred_input.setter - # def _pred_input(self, val): - # assert isinstance(val, dict) - # self._pred_input_name_list, self._pred_input_varname_list = \ - # zip(*[[k, v.name] for k,v in val.items()]) - - # @property - # def _pred_fetch_list(self): - # return [self._pred_fetch_name_list, self._pred_fetch_var_list] - @property def mix_ratio(self): if self._mix_ratio is not None: @@ -563,57 +618,6 @@ class Trainer(object): if self._verbose: print('{}: mix_ratio is set to {}'.format(self._name, self._mix_ratio)) - @property - def save_infermodel_every_n_steps(self): - return self._save_infermodel_every_n_steps - - @save_infermodel_every_n_steps.setter - def save_infermodel_every_n_steps(self, val): - self._save_infermodel_every_n_steps = val - - @property - def expected_train_steps(self): - return self._expected_train_steps - - @expected_train_steps.setter - def expected_train_steps(self, value): - self._expected_train_steps = value - self._expected_train_epochs = value / float(self._steps_pur_epoch) - - @property - def expected_train_epochs(self): - return self._expected_train_epochs - - @property - def cur_train_epoch(self): - return self._cur_train_epoch - - @property - def cur_train_step(self): - return self._cur_train_step - - # @cur_train_step.setter - # def _cur_train_step(self, value): - # self._cur_train_step = value - # if self._cur_train_step > self._steps_pur_epoch: - # self._cur_train_epoch += 1 - # self._cur_train_step = 1 - # if self._is_target and self._cur_train_step + self._cur_train_epoch * self._steps_pur_epoch >= self._expected_train_steps: - # self._train_finish = True - @steps_pur_epoch.setter - def steps_pur_epoch(self, value): - self._steps_pur_epoch = value - - @property - def train_finish(self): - return self._train_finish - - def tasklayer_reuse_with(self, task): - assert isinstance(task, Task) - if self._lock: - raise Exception('you can only set tasklayer reuses BEFORE Controller created.') - self._task_reuse_scope = task.name - def _set_lock(self): self._lock = True diff --git a/paddlepalm/utils/.saver.py.swp b/paddlepalm/utils/.saver.py.swp deleted file mode 100644 index 229018dbf773ff028d6b8f0e7be08c74ebdcdf2a..0000000000000000000000000000000000000000 Binary files a/paddlepalm/utils/.saver.py.swp and /dev/null differ diff --git a/paddlepalm/utils/reader_helper.py b/paddlepalm/utils/reader_helper.py index a603292cce748e48b68247cf334c50be98953f57..d0727aaa2c8a8ecd6692c15fa8c219a2a23bf1ff 100644 --- a/paddlepalm/utils/reader_helper.py +++ b/paddlepalm/utils/reader_helper.py @@ -35,6 +35,27 @@ def create_feed_batch_process_fn(net_inputs): return feed_batch_process_fn + +def create_multihead_feed_batch_process_fn(net_inputs): + + def feed_batch_process_fn(data, id=-1): + # temps = {} + # for i in range(len(net_inputs)): + temp = {} + inputs = net_inputs[id] if id != -1 else net_inputs + + for q, var in inputs.items(): + if isinstance(var, str) or isinstance(var, unicode): + temp[var] = data[q] + else: + temp[var.name] = data[q] + # temps[i] = temp + + return temp + + return feed_batch_process_fn + + def _check_and_adapt_shape_dtype(rt_val, attr, message=""): if not isinstance(rt_val, np.ndarray): rt_val = np.array(rt_val) diff --git a/paddlepalm/utils/saver.py b/paddlepalm/utils/saver.py index b4f0241e2c04d6d462a1afd1bfb6ca2e4cb2179a..c1da2883273c1b08a2f895f49f3462d036799c17 100644 --- a/paddlepalm/utils/saver.py +++ b/paddlepalm/utils/saver.py @@ -47,7 +47,9 @@ def init_checkpoint(exe, init_checkpoint_path, main_program, skip_list = []): def init_pretraining_params(exe, pretraining_params_path, convert, - main_program): + main_program, + strict=False): + assert os.path.exists(pretraining_params_path ), "[%s] cann't be found." % pretraining_params_path @@ -69,7 +71,10 @@ def init_pretraining_params(exe, if not isinstance(var, fluid.framework.Parameter): return False if not os.path.exists(os.path.join(pretraining_params_path, var.name)): - print('Warning: {} not found in {}.'.format(var.name, log_path)) + if strict: + raise Exception('Error: {} not found in {}.'.format(var.name, log_path)) + else: + print('Warning: {} not found in {}.'.format(var.name, log_path)) return os.path.exists(os.path.join(pretraining_params_path, var.name)) fluid.io.load_vars(