diff --git a/fitlog/__init__.py b/fitlog/__init__.py index 1eb466581cfd312b3a43aa310a5b5a0730921237..f267955ba0df95c890d66fb918f2913e247f5cbf 100755 --- a/fitlog/__init__.py +++ b/fitlog/__init__.py @@ -24,7 +24,8 @@ __all__ = [ "get_log_id", "get_commit_id", "get_fit_id", - "create_log_folder" + "create_log_folder", + "FitlogConfig" ] import os @@ -32,6 +33,7 @@ import os os.environ['GIT_PYTHON_REFRESH'] = "quiet" from fitlog.fastlog import logger as _logger +from fitlog.fastlog.logger import FitlogConfig from fitlog.fastgit import Committer, committer as _committer from typing import Union import argparse diff --git a/fitlog/fastlog/logger.py b/fitlog/fastlog/logger.py index 20ed79de6dea2664b63cf4838fdf074390df13f7..f858110e09855ad7c2ec8f73f348a7ef18881be2 100755 --- a/fitlog/fastlog/logger.py +++ b/fitlog/fastlog/logger.py @@ -18,6 +18,22 @@ import warnings import numpy as np import numbers +class FitlogConfig: + """ + 用于add_hyper函数的基类。 + 继承后无需实例化直接传入add_hyper。 + """ + pass + +def _get_config_args(conf:FitlogConfig): + """ + 读取FitlogConfig内的超参。 + """ + config_dict = { + k:v for k,v in vars(conf).items() if not k.startswith("_") + } + return config_dict + def _check_debug(func): """ @@ -460,6 +476,8 @@ class Logger: _check_dict_value(value) elif isinstance(value, ConfigParser): value = _convert_configparser_to_dict(value) # no need to check + elif issubclass(value, FitlogConfig): + value = _get_config_args(value) else: try: import dataclasses