# Source code for lib.sedna.service.server.aggregation

# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import time
import uuid
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, WebSocket
from fastapi.routing import APIRoute
from pydantic import BaseModel
from starlette.endpoints import WebSocketEndpoint
from starlette.requests import Request
from starlette.responses import JSONResponse
from starlette.routing import WebSocketRoute
from starlette.types import ASGIApp, Receive, Scope, Send

from sedna.algorithms.aggregation import AggClient
from sedna.common.config import BaseConfig, Context
from sedna.common.class_factory import ClassFactory, ClassType
from sedna.common.log import LOGGER
from sedna.common.config import Context
from sedna.common.utils import get_host_ip
from .base import BaseServer

__all__ = ('AggregationServer', 'AggregationServerV2')


class WSClientInfo(BaseModel):  # pylint: disable=too-few-public-methods
    """
    Metadata record kept by the server for one registered websocket client.
    """
    # identifier the client is registered under
    client_id: str
    # `time.time()` value taken when the client was added to the server
    connected_at: float
    # free-form payload; created as None by `WSServerBase.add_client` and
    # used by `Aggregator` to hold the client's pending `AggClient` update
    info: Any


class WSClientInfoList(BaseModel):  # pylint: disable=too-few-public-methods
    """
    Response wrapper listing the clients known to the server.
    """
    # `AggregationServer.client_info` fills this with `server.client_list`,
    # i.e. the ids of the currently registered clients
    clients: List


class WSEventMiddleware:  # pylint: disable=too-few-public-methods
    """
    ASGI middleware that shares a single `Aggregator` with every request.

    One `Aggregator` is built from the keyword arguments when the
    middleware is created. For each scope that carries a path, the
    aggregator is stored in the scope under the path with its leading
    slashes stripped, so endpoints can look it up by server name. After
    the wrapped application returns, `scope["app"].shutdown` is set to
    True only when the aggregator reports that the exit round has been
    reached and no client is connected any more.
    """

    def __init__(self, app: ASGIApp, **kwargs):
        """
        :param app: the wrapped ASGI application
        :param kwargs: forwarded unchanged to `Aggregator`
        """
        self._app = app
        self._server = Aggregator(**kwargs)

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        if scope["type"] in ("lifespan", "http", "websocket"):
            # "lifespan" scopes have no "path" entry; indexing it directly
            # raised KeyError and aborted the lifespan protocol.
            path = scope.get("path")
            if path is not None:
                scope[path.lstrip("/")] = self._server
        await self._app(scope, receive, send)
        # exit agg server if job complete
        scope["app"].shutdown = (self._server.exit_check()
                                 and self._server.empty)


class WSServerBase:
    """
    In-memory registry of websocket clients with simple broadcast helpers.

    Two dictionaries keyed by client id are kept in step: the open
    websocket of each client and a `WSClientInfo` metadata record.
    """

    def __init__(self):
        self._clients: Dict[str, WebSocket] = {}
        self._client_meta: Dict[str, WSClientInfo] = {}

    def __len__(self) -> int:
        """Return the number of registered clients."""
        return len(self._clients)

    @property
    def empty(self) -> bool:
        """True when no client is registered."""
        return not self._clients

    @property
    def client_list(self) -> List[str]:
        """Ids of all registered clients, as a new list."""
        # TODO: if the server center is scaled out, persisting this data
        #  in a database would be more appropriate than process memory.
        return list(self._clients)

    def add_client(self, client_id: str, websocket: WebSocket):
        """
        Register `websocket` under `client_id`.

        :raises ValueError: if `client_id` is already registered
        """
        if client_id in self._clients:
            raise ValueError(f"Client {client_id} is already in the server")
        LOGGER.info(f"Adding client {client_id}")
        self._clients[client_id] = websocket
        self._client_meta[client_id] = WSClientInfo(
            client_id=client_id, connected_at=time.time(), info=None
        )

    async def kick_client(self, client_id: str):
        """
        Close the websocket of `client_id`.

        The registry entry itself is not removed here.

        :raises ValueError: if `client_id` is not registered
        """
        if client_id not in self._clients:
            raise ValueError(f"Client {client_id} is not in the server")
        await self._clients[client_id].close()

    def remove_client(self, client_id: str):
        """
        Drop the websocket and the metadata stored for `client_id`.

        :raises ValueError: if `client_id` is not registered
        """
        if client_id not in self._clients:
            raise ValueError(f"Client {client_id} is not in the server")
        LOGGER.info(f"Removing Client {client_id} from server")
        del self._clients[client_id]
        del self._client_meta[client_id]

    def get_client(self, client_id: str) -> Optional[WSClientInfo]:
        """Return the metadata of `client_id`, or None when unknown."""
        return self._client_meta.get(client_id)

    async def send_message(self, client_id: str, msg: Dict):
        """
        Send `msg` as JSON to every registered client except `client_id`.
        """
        # Iterate over a snapshot: other coroutines may add or remove
        # clients while this one is suspended in `send_json`, and mutating
        # a dict during iteration raises RuntimeError.
        for to_client, websocket in list(self._clients.items()):
            if to_client == client_id:
                continue
            LOGGER.info(f"send data to Client {to_client} from server")
            await websocket.send_json(msg)

    async def client_joined(self, client_id: str):
        """
        Tell every registered client that `client_id` has joined.
        """
        notice = {"type": "CLIENT_JOIN", "data": client_id}
        # snapshot for the same reason as in `send_message`
        for websocket in list(self._clients.values()):
            await websocket.send_json(notice)


class Aggregator(WSServerBase):
    """
    Client registry that aggregates weight updates sent by its clients.

    Keyword arguments read by the constructor:

    * ``exit_round`` (default 3): number of aggregation rounds after
      which `exit_check` reports completion.
    * ``aggregation`` (default ``"FedAvg"``): name looked up in
      `ClassFactory` under `ClassType.FL_AGG`; the result is called to
      create the aggregation object when it is callable.
    * ``participants_count`` (default 1): number of client updates that
      must be pending before a round is aggregated.
    """

    def __init__(self, **kwargs):
        super().__init__()
        self.exit_round = int(kwargs.get("exit_round", 3))
        aggregation = kwargs.get("aggregation", "FedAvg")
        self.aggregation = ClassFactory.get_cls(ClassType.FL_AGG, aggregation)
        if callable(self.aggregation):
            self.aggregation = self.aggregation()
        self.participants_count = int(kwargs.get("participants_count", "1"))
        self.current_round = 0

    async def send_message(self, client_id: str, msg: Dict):
        """
        Handle a message from `client_id` and broadcast the outcome.

        An ``update_weight`` message with non-empty ``data`` is stored as
        the sender's pending update. While fewer than
        `participants_count` updates are pending nothing is sent. Once
        enough are pending they are aggregated, the round counter is
        incremented and `msg` is rewritten in place into a
        ``recv_weight`` message carrying the aggregated weights.

        Whatever `msg` then contains is sent to every registered client,
        the sender included. A failed send is logged and does not stop
        the broadcast. After a successful ``recv_weight`` delivery the
        receiver's pending update is cleared.
        """
        data = msg.get("data")
        if data and msg.get("type", "") == "update_weight":
            info = AggClient()
            info.num_samples = int(data["num_samples"])
            info.weights = data["weights"]
            self._client_meta[client_id].info = info
            pending_updates = [
                meta.info for meta in self._client_meta.values() if meta.info
            ]
            # not enough participants have reported yet: wait for more
            if len(pending_updates) < self.participants_count:
                return
            self.current_round += 1
            weights = self.aggregation.aggregate(pending_updates)
            exit_flag = "ok" if self.exit_check() else "continue"

            msg["type"] = "recv_weight"
            msg["round_number"] = self.current_round
            msg["data"] = {
                "total_sample": self.aggregation.total_size,
                "round_number": self.current_round,
                "weights": weights,
                "exit_flag": exit_flag
            }

        # `.get` because a message is not required to carry a "type" key;
        # indexing it raised KeyError after the first successful send.
        is_aggregate = msg.get("type") == "recv_weight"
        # Snapshot the registry: clients may join or leave while this
        # coroutine is suspended in `send_json`.
        for to_client, websocket in list(self._clients.items()):
            try:
                await websocket.send_json(msg)
            except Exception as err:
                LOGGER.error(err)
            else:
                if is_aggregate:
                    # the receiver may have been removed during the send
                    meta = self._client_meta.get(to_client)
                    if meta is not None:
                        meta.info = None

    def exit_check(self):
        """Return True once `current_round` has reached `exit_round`."""
        return self.current_round >= self.exit_round


class BroadcastWs(WebSocketEndpoint):
    """
    Websocket endpoint connecting one client to the shared `Aggregator`.

    The aggregator is read from the connection scope, where it is stored
    under the request path with its leading slashes stripped.
    """
    encoding: str = "json"
    session_name: str = ""
    count: int = 0

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.server: Optional[Aggregator] = None
        # set once a "subscribe" message has been registered successfully
        self.client_id: Optional[str] = None

    async def on_connect(self, websocket: WebSocket):
        """
        Bind the shared aggregator and accept the connection.

        :raises RuntimeError: if no aggregator is stored in the scope
        """
        servername = websocket.scope['path'].lstrip("/")
        LOGGER.info("Connecting new client...")
        server: Optional[Aggregator] = self.scope.get(servername)
        if server is None:
            raise RuntimeError("HOST `client` instance unavailable!")
        self.server = server
        await websocket.accept()

    async def on_disconnect(self, _websocket: WebSocket, _close_code: int):
        """
        Unregister this connection's client, if it ever subscribed.
        """
        if self.client_id is None:
            # The peer left before a successful subscribe, so nothing was
            # registered and there is nothing to remove.
            LOGGER.info("Client disconnected before subscribing")
            return
        self.server.remove_client(self.client_id)

    async def on_receive(self, _websocket: WebSocket, msg: Dict):
        """
        Register the client on "subscribe", then forward `msg`.

        A "subscribe" message uses its ``client_id`` value, or a random
        hex id when that is missing or empty. Clients that are already
        registered are notified before the new one is added. Every
        message, the subscribe message included, is then passed to the
        aggregator's `send_message`.

        :raises RuntimeError: if a message arrives before any subscribe
        :raises ValueError: if the requested client id is already in use
        """
        command = msg.get("type", "")
        if command == "subscribe":
            client_id = msg.get("client_id", "") or uuid.uuid4().hex
            await self.server.client_joined(client_id)
            self.server.add_client(client_id, _websocket)
            # Remember the id only after registration succeeded. If it were
            # set before, a rejected duplicate id would make
            # `on_disconnect` remove the entry of the client that
            # legitimately owns that id.
            self.client_id = client_id
        if self.client_id is None:
            raise RuntimeError(
                "on_receive() called without a valid client_id")
        await self.server.send_message(self.client_id, msg)


class AggregationServer(BaseServer):
    """
    HTTP and websocket server exposing one aggregation service.

    A FastAPI application is created with two routes on the same path
    ``/<aggregation>``: an HTTP route answered by `client_info` and a
    websocket route served by `BroadcastWs`.

    :param aggregation: name of the aggregation algorithm; also used as
        the server name and as the route path
    :param host: bind address; when empty, the ``AGG_BIND_IP`` parameter
        is used, falling back to `get_host_ip()`
    :param http_port: bind port; when empty, the ``AGG_BIND_PORT``
        parameter is used, falling back to 7363
    :param exit_round: number of aggregation rounds before exit; values
        below 1 are raised to 1
    :param participants_count: number of client updates needed per round
    :param ws_size: passed to the base server as ``ws_size`` and to
        `run` as ``ws_max_size``
    """

    def __init__(
            self,
            aggregation: str,
            host: Optional[str] = None,
            http_port: Optional[int] = None,
            exit_round: int = 1,
            participants_count: int = 1,
            ws_size: int = 10 * 1024 * 1024):
        if not host:
            host = Context.get_parameters("AGG_BIND_IP", get_host_ip())
        if not http_port:
            http_port = int(Context.get_parameters("AGG_BIND_PORT", 7363))
        super().__init__(
            servername=aggregation,
            host=host,
            http_port=http_port,
            ws_size=ws_size)
        self.aggregation = aggregation
        self.participants_count = participants_count
        self.exit_round = max(int(exit_round), 1)
        self.app = FastAPI(
            routes=[
                APIRoute(
                    f"/{aggregation}",
                    self.client_info,
                    response_class=JSONResponse,
                ),
                WebSocketRoute(
                    f"/{aggregation}",
                    BroadcastWs
                )
            ],
        )
        # set by WSEventMiddleware after each request once the job is done
        self.app.shutdown = False

    def start(self):
        """
        Start the server
        """
        # define the aggregation method and exit condition
        self.app.add_middleware(
            WSEventMiddleware,
            exit_round=self.exit_round,
            aggregation=self.aggregation,
            participants_count=self.participants_count
        )
        self.run(self.app, ws_max_size=self.ws_size)

    async def client_info(self, request: Request):
        """
        Report one client's metadata or the list of all client ids.

        When the request body is a JSON object with a non-empty
        ``client_id``, the result of `get_client` for that id is returned
        (None for an unknown id). Otherwise a `WSClientInfoList` with all
        registered client ids is returned.
        """
        server: Optional[Aggregator] = request.get(self.server_name)
        try:
            data = await request.json()
        except Exception:
            # Missing or malformed body: fall back to listing all clients.
            # `Exception` rather than `BaseException`, so task cancellation
            # is not swallowed.
            data = {}
        # a JSON body that is not an object cannot carry a client id
        client_id = data.get("client_id", "") if isinstance(data, dict) else ""
        if client_id:
            return server.get_client(client_id)
        return WSClientInfoList(clients=server.client_list)
class AggregationServerV2():
    """
    Aggregation server backed by the `plato` framework.

    The constructor copies the `plato` configuration sections into
    dictionaries, overrides them with the parameters of the objects it is
    given, writes them back to `Config()` and finally obtains the server
    from the `plato.servers` registry.

    :param data: optional object whose `parameters` update the data
        section
    :param estimator: optional object supplying `model`, `pretrained`,
        `saved` and `hyperparameters`
    :param aggregation: optional object whose `parameters` become the
        algorithm section
    :param transmitter: optional object whose `parameters` update the
        server section
    :param chooser: optional object whose ``parameters["per_round"]``
        sets the clients' ``per_round``
    """

    def __init__(self, data=None, estimator=None,
                 aggregation=None, transmitter=None,
                 chooser=None) -> None:
        # imported lazily so that `plato` is only needed for this class
        from plato.config import Config

        # set parameters
        server = Config().server._asdict()
        clients = Config().clients._asdict()
        datastore = Config().data._asdict()
        train = Config().trainer._asdict()

        if data is not None:
            datastore.update(data.parameters)
            Config().data = Config.namedtuple_from_dict(datastore)

        self.model = None
        if estimator is not None:
            self.model = estimator.model
            if estimator.pretrained is not None:
                Config().params['pretrained_model_dir'] = estimator.pretrained
            if estimator.saved is not None:
                Config().params['model_dir'] = estimator.saved
            train.update(estimator.hyperparameters)
            Config().trainer = Config.namedtuple_from_dict(train)

        server["address"] = Context.get_parameters("AGG_BIND_IP", "0.0.0.0")
        server["port"] = int(Context.get_parameters("AGG_BIND_PORT", 7363))
        if transmitter is not None:
            server.update(transmitter.parameters)

        if aggregation is not None:
            Config().algorithm = Config.namedtuple_from_dict(
                aggregation.parameters)
            # NOTE(review): the indentation of this branch was rebuilt from
            # a flattened listing; the `else` is assumed to pair with the
            # "mistnet" check -- confirm against the upstream source.
            if aggregation.parameters["type"] == "mistnet":
                clients["type"] = "mistnet"
                server["type"] = "mistnet"
            else:
                clients["do_test"] = True

        if chooser is not None:
            clients["per_round"] = chooser.parameters["per_round"]

        LOGGER.info("address %s, port %s", server["address"], server["port"])

        Config().server = Config.namedtuple_from_dict(server)
        Config().clients = Config.namedtuple_from_dict(clients)

        from plato.servers import registry as server_registry
        self.server = server_registry.get(model=self.model)

    def start(self):
        """Run the underlying `plato` server."""
        self.server.run()