Source code for chainerui.contrib.ignite.handler

from ignite.contrib.handlers.base_logger import BaseLogger
from ignite.contrib.handlers.base_logger import BaseOutputHandler
from ignite.engine import Events

import chainerui
from chainerui.utils.log_report import _get_time


class OutputHandler(BaseOutputHandler):
    """Output handler that forwards engine results to ChainerUI.

    A helper handler for logging an engine's output, specialized for
    ChainerUI. Each record it builds carries ``'epoch'``, ``'iteration'``
    and ``'elapsed_time'`` entries next to the monitored values; these
    are the default x axes shown by ChainerUI.

    .. code-block:: python

        from chainerui.contrib.ignite.handler import OutputHandler

        train_handler = OutputHandler(
            'train', output_transform=lambda o: {'param': o})
        val_handler = OutputHandler('val', metric_names='all')

    Args:
        tag (str): prefix of every parameter name; values are reported
            as ``{tag}/{param}``.
        metric_names (str or list): names of the metrics to monitor. Set
            ``'all'`` to get all metrics monitored by the engine.
        output_transform (func): if set, use this function to convert
            the output taken from ``engine.state.output``.
        another_engine (``ignite.engine.Engine``): if set, use it for
            getting the global step. This option is deprecated from 0.3.
        global_step_transform (func): if set, use this to get the global
            step.
        interval_step (int): interval step for posting metrics to the
            ChainerUI server. With a value of 1 or less (the default is
            ``-1``) every record is posted as soon as it is built.
    """

    def __init__(
            self, tag, metric_names=None, output_transform=None,
            another_engine=None, global_step_transform=None,
            interval_step=-1):
        super(OutputHandler, self).__init__(
            tag, metric_names, output_transform, another_engine,
            global_step_transform)
        self.interval = interval_step

    def __call__(self, engine, logger, event_name):
        """Build one metrics record and post it, or cache it, via ``logger``.

        Args:
            engine: engine whose output and step counters are read.
            logger (ChainerUILogger): logger that posts (or caches) the
                record.
            event_name: event this handler was attached to; used to
                obtain the counter compared against the interval.

        Raises:
            RuntimeError: if ``logger`` is not a ``ChainerUILogger``.
        """
        if not isinstance(logger, ChainerUILogger):
            raise RuntimeError(
                '`chainerui.contrib.ignite.handler.OutputHandler` works only '
                'with ChainerUILogger, but set {}'.format(type(logger)))

        monitored = self._setup_output_metrics(engine)
        if not monitored:
            return

        step_of = self.global_step_transform
        iteration = step_of(engine, Events.ITERATION_COMPLETED)
        epoch = step_of(engine, Events.EPOCH_COMPLETED)

        # Prefix every metric name with the handler's tag.
        record = {
            '{}/{}'.format(self.tag, name): value
            for name, value in monitored.items()
        }
        record['iteration'] = iteration
        record['epoch'] = epoch
        if 'elapsed_time' not in record:
            record['elapsed_time'] = _get_time() - logger.start_at

        if self.interval <= 1:
            # No interval configured: send the single record right away.
            logger.post_log([record])
            return

        # Interval enabled: accumulate records under this handler's tag.
        pending = logger.cache.setdefault(self.tag, [])
        pending.append(record)

        # Use the counter of the event the handler was attached to.
        if step_of(engine, event_name) % self.interval == 0:
            logger.post_log(pending)
            pending.clear()


class ChainerUILogger(BaseLogger):
    """Logger handler for ChainerUI

    A helper logger to post metrics to ChainerUI server. Attached handlers
    are expected to be ``chainerui.contrib.ignite.handler.OutputHandler``.
    The tag name of each handler must be unique when several handlers are
    attached.

    .. code-block:: python

        from chainerui.contrib.ignite.handler import OutputHandler
        train_handler = OutputHandler(...)
        val_handler = OutputHandler(...)

        from ignite.engine import Events
        from ignite.engine.engine import Engine
        train_engine = Engine(...)
        eval_engine = Engine(...)

        from chainerui.contrib.ignite.handler import ChainerUILogger
        logger = ChainerUILogger()
        logger.attach(
            train_engine, log_handler=train_handler,
            event_name=Events.EPOCH_COMPLETED)
        logger.attach(
            eval_engine, log_handler=val_handler,
            event_name=Events.EPOCH_COMPLETED)

    Raises:
        RuntimeError: on construction, if the ChainerUI web client has not
            been set up.
    """

    def __init__(self):
        web_client = chainerui.client.client._client
        if not web_client:
            raise RuntimeError(
                'fail to setup ChainerUI logger, please check '
                '`chainerui.init()` is success to execute.')

        self.attached_tags = set()
        self.client = web_client
        self.cache = {}  # key is tag name, value is list of metrics
        self.start_at = _get_time()

    def attach(self, engine, log_handler, event_name):
        """Attach ``log_handler`` to ``engine`` for ``event_name``.

        Raises:
            RuntimeError: if a handler with the same tag is already attached.
        """
        if log_handler.tag in self.attached_tags:
            raise RuntimeError('attached handlers must have unique tag name')
        self.attached_tags.add(log_handler.tag)
        super(ChainerUILogger, self).attach(engine, log_handler, event_name)

    def post_log(self, metrics):
        """Post a list of metrics records through the ChainerUI client."""
        self.client.post_log(metrics)

    def __enter__(self):
        """Restart the elapsed-time clock and return this logger.

        Returning ``self`` makes ``with ChainerUILogger() as logger:`` bind
        the logger instead of ``None``.
        """
        # overwrite start timestamp when engine is used with 'with' block
        self.start_at = _get_time()
        super(ChainerUILogger, self).__enter__()
        return self

    def close(self):
        """Post every metrics record still held in the cache.

        Each flushed tag gets a fresh empty list, so calling ``close`` again
        (e.g. explicitly and then on leaving a ``with`` block) does not post
        the same records twice.
        """
        for tag, pending in self.cache.items():
            if pending:
                self.post_log(pending)
                # Rebinding an existing key does not resize the dict, so it
                # is safe during iteration; the posted list is left intact.
                self.cache[tag] = []