# Copyright 2021 The KubeEdge Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Copied from https://github.com/huawei-noah/vega/blob/master/zeus/common/config.py # noqa
# and modified, because vega's version goes beyond what we need.
import os
import sys
import yaml
import json
from copy import deepcopy
from importlib import import_module
from inspect import ismethod, isfunction
from .utils import singleton
# Names exported by ``from <this module> import *``.
__all__ = ("Context", "BaseConfig")
def _url2dict(arg):
if arg.endswith('.yaml') or arg.endswith('.yml'):
with open(arg) as f:
raw_dict = yaml.load(f, Loader=yaml.FullLoader)
elif arg.endswith('.py'):
module_name = os.path.basename(arg)[:-3]
config_dir = os.path.dirname(arg)
sys.path.insert(0, config_dir)
mod = import_module(module_name)
sys.path.pop(0)
raw_dict = {
name: value
for name, value in mod.__dict__.items()
if not name.startswith('__')
}
sys.modules.pop(module_name)
elif arg.endswith(".json"):
with open(arg) as f:
raw_dict = json.load(f)
else:
try:
raw_dict = json.loads(arg, encoding="utf-8")
except json.JSONDecodeError:
raise Exception('config file must be yaml or py')
return raw_dict
def _dict2config(config, dic):
"""Convert dictionary to config.
:param Config config: config
:param dict dic: dictionary
"""
if isinstance(dic, dict):
for key, value in dic.items():
if isinstance(value, dict):
config[key] = Config()
_dict2config(config[key], value)
else:
config[key] = value
class Config(dict):
    """A dict whose items can also be read and written as attributes.

    A Config can be built from yaml, json or Python config files, from inline
    JSON strings and from dictionaries.  Nested dictionaries are converted
    into nested Config objects.

    :param args: config sources, merged in order (later keys win)
    :type args: tuple of str or dict
    :param kwargs: extra entries, merged after ``args``
    :type kwargs: dict
    """

    def __init__(self, *args, **kwargs):
        """Init config class with multiple config files or dictionaries.

        :raises TypeError: if an element of ``args`` is neither str nor dict.
        """
        super(Config, self).__init__()
        for arg in args:
            if isinstance(arg, str):
                _dict2config(self, _url2dict(arg))
            elif isinstance(arg, dict):
                _dict2config(self, arg)
            else:
                raise TypeError('args is not dict or str')
        if kwargs:
            _dict2config(self, kwargs)

    def __call__(self, *args, **kwargs):
        """Return a new Config of this one's items merged with ``args``/``kwargs``.

        :return: a new Config object.
        :rtype: Config
        """
        return Config(self, *args, **kwargs)

    def __setstate__(self, state):
        """Restore items when unpickling.

        :param dict state: the plain dict produced by `__getstate__`.
        """
        _dict2config(self, state)

    def __getstate__(self):
        """Return the items as a plain dict for pickling.

        Nested Config values are converted to plain dicts recursively.

        :return: the Config as a dict.
        :rtype: dict
        """
        d = dict()
        for key, value in self.items():
            if isinstance(value, Config):
                value = value.__getstate__()
            d[key] = value
        return d

    def __getattr__(self, key):
        """Return item ``key``.

        Python calls this only when normal attribute lookup fails.

        :param str key: the item name.
        :raises AttributeError: if ``key`` is not an item.
        """
        if key in self:
            return self[key]
        raise AttributeError(key)

    def __setattr__(self, key, value):
        """Set item ``key`` to ``value``; attribute assignment writes an item.

        :param str key: the item name.
        :param value: the value to store.
        """
        self[key] = value

    def __delattr__(self, key):
        """Delete item ``key``.

        Note that a missing key raises KeyError, not AttributeError.

        :param str key: the item name.
        """
        del self[key]

    def __deepcopy__(self, memo):
        """Return a deep copy of this Config as a new Config.

        ``memo`` is forwarded to :func:`copy.deepcopy`, as the copy protocol
        requires.  Objects that are also reachable from elsewhere in an
        enclosing deep copy are therefore copied once and stay shared.  A
        Config that contains itself is not supported: copying it recurses
        until RecursionError.

        :param dict memo: the deepcopy memo dictionary.
        :return: a deep-copied Config object.
        :rtype: Config
        """
        return Config(deepcopy(dict(self), memo))
class ConfigSerializable(object):
    """Serializable config base class.

    The serializable attributes of an instance are the names in ``dir(self)``
    that do not start with ``__`` and are not methods or functions.  These
    names are returned by ``__allattr__``.
    """

    # NOTE(review): not read anywhere in this class; confirm it is still used.
    __original__value__ = None

    @property
    def __allattr__(self):
        """List the serializable attribute names of this instance."""
        names = []
        for name in dir(self):
            if name.startswith("__"):
                continue
            # Fetch once (the original looked each attribute up twice).
            value = getattr(self, name)
            if ismethod(value) or isfunction(value):
                continue
            names.append(name)
        return names

    def update(self, **kwargs):
        """Set the given keyword values on existing serializable attributes.

        Keywords that do not name a serializable attribute are ignored.
        """
        for attr in self.__allattr__:
            if attr in kwargs:
                setattr(self, attr, kwargs[attr])

    def to_json(self):
        """Serialize config to a ``Config`` (a dict with attribute access).

        Nested values are serialized recursively.  A value may be either a
        ConfigSerializable instance or a ConfigSerializable subclass; a
        subclass is first instantiated with no arguments.  Other values,
        including unrelated classes, are copied as-is.
        """
        attr_dict = {}
        for attr in self.__allattr__:
            value = getattr(self, attr)
            # issubclass instead of instantiating every class-valued attribute.
            if isinstance(value, type) and issubclass(value, ConfigSerializable):
                value = value().to_json()
            elif isinstance(value, ConfigSerializable):
                value = value.to_json()
            attr_dict[attr] = value
        return Config(deepcopy(attr_dict))

    def dict(self):
        """Return the serializable attributes as a plain ``dict``.

        Nested ConfigSerializable values are expanded the same way as in
        ``to_json``.  Unlike ``to_json``, values are not deep-copied.
        """
        attr_dict = {}
        for attr in self.__allattr__:
            value = getattr(self, attr)
            if isinstance(value, type) and issubclass(value, ConfigSerializable):
                value = value().dict()
            elif isinstance(value, ConfigSerializable):
                value = value.dict()
            attr_dict[attr] = value
        return attr_dict

    def __getitem__(self, item):
        """Return attribute ``item``, or None if it does not exist."""
        return getattr(self, item, None)

    def get(self, item, default=""):
        """Return attribute ``item``, or ``default``.

        NOTE: ``default`` is also returned when the attribute exists but is
        falsy (``0``, ``False``, ``""``, ``[]``, ``None``).
        """
        return self.__getitem__(item) or default

    @classmethod
    def from_json(cls, data):
        """Restore config from a dictionary or a file.

        The data is applied as attributes of the class ``cls`` itself, so it
        affects every instance.  Empty ``data`` and the base class itself are
        returned unchanged.

        :param data: a dict, or any source accepted by ``Config``.
        :return: ``cls``.
        """
        if not data:
            return cls
        if cls.__name__ == "ConfigSerializable":
            return cls
        config = Config(deepcopy(data))
        for attr in config:
            if not hasattr(cls, attr):
                setattr(cls, attr, config[attr])
                continue
            class_value = getattr(cls, attr)
            config_value = config[attr]
            # NOTE(review): recursion only happens when the incoming value
            # itself has ``from_json``.  A nested mapping (a Config) usually
            # does not, so it replaces the attribute wholesale.  Confirm this
            # is intended.
            if isinstance(class_value, ConfigSerializable) and hasattr(
                    config_value, 'from_json'):
                setattr(cls, attr, class_value.from_json(config_value))
            else:
                setattr(cls, attr, config_value)
        return cls
@singleton
class BaseConfig(ConfigSerializable):
    """The base config, with every field taken from an environment variable.

    The class body runs once, when this module is imported, so later changes
    to the environment are not picked up.  Fields without a fallback value
    are ``None`` when their variable is unset.  ``singleton`` comes from
    ``.utils``; presumably it makes this a single shared instance (confirm).
    """

    # Device category and ML framework backend.
    device_category = os.environ.get('DEVICE_CATEGORY', 'CPU')
    backend_type = os.environ.get('BACKEND_TYPE', 'TENSORFLOW')

    # Local control server endpoint.
    lc_server = os.environ.get("LC_SERVER", "http://127.0.0.1:9100")

    # Dataset locations.
    original_dataset_url = os.environ.get("ORIGINAL_DATASET_URL")
    train_dataset_url = os.environ.get("TRAIN_DATASET_URL")
    test_dataset_url = os.environ.get("TEST_DATASET_URL")
    data_path_prefix = os.environ.get("DATA_PATH_PREFIX", "/home/data")

    # Kubernetes CRD information.
    namespace = os.environ.get("NAMESPACE", "")
    worker_name = os.environ.get("WORKER_NAME", "")
    # Name of the JointInferenceService (or other Service).
    service_name = os.environ.get("SERVICE_NAME", "")
    # Name of the FederatedLearningJob (or other Job).
    job_name = os.environ.get("JOB_NAME", "sedna")

    # Models.
    pretrained_model_url = os.environ.get("PRETRAINED_MODEL_URL", "./")
    model_url = os.environ.get("MODEL_URL")
    model_name = os.environ.get("MODEL_NAME")

    # Miscellaneous runtime settings.
    log_level = os.environ.get("LOG_LEVEL", "INFO")
    transmitter = os.environ.get("TRANSMITTER", "ws")
    agg_data_path = os.environ.get("AGG_DATA_PATH", "./")

    # S3 endpoint and credentials.
    s3_endpoint_url = os.environ.get("S3_ENDPOINT_URL", "")
    access_key_id = os.environ.get("ACCESS_KEY_ID", "")
    secret_access_key = os.environ.get("SECRET_ACCESS_KEY", "")

    # User parameters: a config file path or an inline JSON string.
    parameters = os.environ.get("PARAMETERS")

    def __init__(self):
        """Decode ``parameters`` (when set) into the ``parameter`` attribute."""
        raw = self.parameters
        if not raw:
            return
        # NOTE(review): stored as ``parameter`` (singular); confirm callers use
        # this name.  A PARAMETERS value ending in ``.py`` is imported and
        # executed by _url2dict.
        self.parameter = _url2dict(raw)
class Context:
    """The Context provides the capability of obtaining the context.

    Values are looked up in ``parameters``, which defaults to the live
    process environment (``os.environ``).
    """

    # Mapping consulted by get_parameters().
    parameters = os.environ

    @classmethod
    def get_parameters(cls, param, default=None):
        """Get the value of the key `param` in `PARAMETERS`.

        The exact key is tried first, then its upper-cased form.  ``default``
        is returned when neither lookup yields a truthy value.
        """
        value = cls.parameters.get(
            param) or cls.parameters.get(str(param).upper())
        return value if value else default

    @classmethod
    def get_algorithm_from_api(cls, algorithm, **param) -> dict:
        """Get the algorithm and its parameters from the API settings.

        Reads ``<algorithm>_NAME`` and ``<algorithm>_PARAMETERS`` through
        ``get_parameters``.  ``<algorithm>_PARAMETERS`` is expected to be a JSON
        list of ``{"key": ..., "value": ...}`` objects.  Entries without
        ``"key"`` are skipped, and a missing ``"value"`` becomes ``""``.  If
        the parameters are absent or cannot be decoded, none are used.
        Keyword arguments in ``param`` override decoded parameters.

        :param str algorithm: prefix shared by the two setting names.
        :return: ``{"method": <name>, "param": <dict>}``, or ``{}`` when
            ``<algorithm>_NAME`` is not set.
        """
        algorithm_name = cls.get_parameters(f'{algorithm}_NAME')
        if not algorithm_name:
            return {}
        raw_parameters = cls.get_parameters(f'{algorithm}_PARAMETERS')
        algorithm_parameters = {}
        if raw_parameters:
            try:
                entries = json.loads(raw_parameters)
                algorithm_parameters = {
                    p["key"]: p.get("value", "")
                    for p in entries if "key" in p
                }
            except (TypeError, ValueError):
                # Best effort: malformed JSON (JSONDecodeError is a ValueError)
                # or entries that are not key/value objects -> no parameters.
                algorithm_parameters = {}
        algorithm_parameters.update(param)
        return {
            "method": algorithm_name,
            "param": algorithm_parameters,
        }