提交 502a8abc 编写于 作者: B barrierye

fix bug in gpu web_service && test=serving

上级 66e030c5
...@@ -183,7 +183,6 @@ int GeneralResponseOp::inference() {
  for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
    input_blob = get_depend_argument<GeneralBlob>(pre_node_names[pi]);
    VLOG(2) << "p size for input blob: " << input_blob->p_size;
-   ModelOutput *output = res->mutable_outputs(pi);
    int profile_time_idx = -1;
    if (pi == 0) {
      profile_time_idx = 0;
......
...@@ -138,7 +138,6 @@ int GeneralTextResponseOp::inference() {
  for (uint32_t pi = 0; pi < pre_node_names.size(); ++pi) {
    input_blob = get_depend_argument<GeneralBlob>(pre_node_names[pi]);
    VLOG(2) << "p size for input blob: " << input_blob->p_size;
-   ModelOutput *output = res->mutable_outputs(pi);
    int profile_time_idx = -1;
    if (pi == 0) {
      profile_time_idx = 0;
......
...@@ -11,6 +11,7 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+# pylint: disable=doc-string-missing
 from flask import Flask, request, abort
 from paddle_serving_server_gpu import OpMaker, OpSeqMaker, Server
...@@ -103,11 +104,22 @@ class WebService(object):
             abort(400)
         if "fetch" not in request.json:
             abort(400)
-        feed, fetch = self.preprocess(request.json, request.json["fetch"])
-        fetch_map_batch = self.client.predict(feed=feed, fetch=fetch)
-        fetch_map_batch = self.postprocess(
-            feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
-        result = {"result": fetch_map_batch}
+        try:
+            feed, fetch = self.preprocess(request.json, request.json["fetch"])
+            if isinstance(feed, list):
+                fetch_map_batch = self.client.predict(
+                    feed_batch=feed, fetch=fetch)
+                fetch_map_batch = self.postprocess(
+                    feed=request.json, fetch=fetch, fetch_map=fetch_map_batch)
+                result = {"result": fetch_map_batch}
+            elif isinstance(feed, dict):
+                if "fetch" in feed:
+                    del feed["fetch"]
+                fetch_map = self.client_service.predict(feed=feed, fetch=fetch)
+                result = self.postprocess(
+                    feed=request.json, fetch=fetch, fetch_map=fetch_map)
+        except ValueError:
+            result = {"result": "Request Value Error"}
         return result

     def run_server(self):
......
Markdown is supported
0% .
You are about to add 0 people to the discussion. Proceed with caution.
先完成此消息的编辑!
想要评论请 注册