# Source code for lib.sedna.service.server.inference

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Optional

from pydantic import BaseModel
from fastapi import FastAPI
from fastapi.routing import APIRoute
from starlette.responses import JSONResponse

from .base import BaseServer


__all__ = ('InferenceServer', )


class ServeModelInfoResult(BaseModel):  # pylint: disable=too-few-public-methods
    """
    Response payload describing the model service: the list of
    URLs/endpoints the server exposes.
    """

    infos: List


class ServePredictResult(BaseModel):  # pylint: disable=too-few-public-methods
    """
    Response payload wrapping the inference output returned by
    the `/predict` endpoint.
    """

    result: List


class InferenceItem(BaseModel):  # pylint: disable=too-few-public-methods
    """
    Request payload for the `/predict` endpoint.
    """

    # Samples to run inference on; forwarded to `model.inference` as-is.
    data: List
    # Optional post-process hook name passed through as `post_process`.
    callback: Optional[str] = None


class InferenceServer(BaseServer):  # pylint: disable=too-many-arguments
    """
    REST API server for model inference.

    Exposes two routes under ``/{servername}``:

    - ``GET /{servername}`` -> :class:`ServeModelInfoResult` (list of URLs)
    - ``POST /{servername}/predict`` -> :class:`ServePredictResult`

    :param model: model object providing an ``inference(data, post_process=...)``
        method; called by :meth:`predict`.
    :param servername: name used both for server registration and as the
        URL path prefix for the routes.
    :param host: interface to bind, default ``127.0.0.1``.
    :param http_port: port to listen on, default ``8080``.
    :param max_buffer_size: request buffer size in bytes, default 100 MiB.
    :param workers: number of worker processes, default 1.
    """

    def __init__(
            self,
            model,
            servername,
            host: str = '127.0.0.1',
            http_port: int = 8080,
            max_buffer_size: int = 104857600,
            workers: int = 1):
        super().__init__(
            servername=servername,
            host=host,
            http_port=http_port,
            workers=workers)
        self.model = model
        self.max_buffer_size = max_buffer_size
        self.app = FastAPI(
            routes=[
                APIRoute(
                    f"/{servername}",
                    self.model_info,
                    response_model=ServeModelInfoResult,
                    response_class=JSONResponse,
                    methods=["GET"],
                ),
                APIRoute(
                    f"/{servername}/predict",
                    self.predict,
                    response_model=ServePredictResult,
                    response_class=JSONResponse,
                    methods=["POST"],
                ),
            ],
            # NOTE(review): `log_level` and `timeout` are not FastAPI
            # constructor parameters; FastAPI stores unknown kwargs in
            # `extra` and they have no effect here — confirm intent.
            log_level="trace",
            timeout=600,
        )

    def start(self):
        """Start serving the FastAPI app via the base server's runner."""
        return self.run(self.app)

    def model_info(self):
        """Handle ``GET /{servername}``: report the registered URLs."""
        return ServeModelInfoResult(infos=self.get_all_urls())

    def predict(self, data: InferenceItem):
        """Handle ``POST /{servername}/predict``: run model inference.

        Forwards ``data.data`` to ``self.model.inference`` with
        ``data.callback`` as the ``post_process`` argument.
        """
        inference_res = self.model.inference(
            data.data, post_process=data.callback)
        return ServePredictResult(result=inference_res)