import os
import shutil
from chainer.serializers import npz
from chainer.training import extension
from chainer.training import trigger as trigger_module
from chainer.training.triggers import IntervalTrigger
import six
from chainerui.utils.command_item import CommandItem
from chainerui.utils.commands_state import CommandsState
from chainerui.utils.tempdir import tempdir
def take_snapshot(trainer, body):
    """Save a snapshot of ``trainer`` as ``snapshot_iter_<iteration>``.

    The snapshot is first serialized into a temporary directory created
    inside ``trainer.out``. Only after ``npz.save_npz`` returns is it moved
    to its final location, ``<trainer.out>/snapshot_iter_<iteration>``.
    This mirrors Chainer v6's ``SimpleWriter``.

    Args:
        trainer: The running trainer being snapshotted.
        body: Command payload; not used by this receiver.

    Returns:
        None.
    """
    destination_dir = trainer.out
    snapshot_name = 'snapshot_iter_{.updater.iteration}'.format(trainer)
    with tempdir(prefix='snapshot', dir=destination_dir) as staging_dir:
        staged_path = os.path.join(staging_dir, snapshot_name)
        npz.save_npz(staged_path, trainer)
        # Move only once serialization has completed.
        final_path = os.path.join(destination_dir, snapshot_name)
        shutil.move(staged_path, final_path)
def adjust_hyperparams(trainer, body):
    """Update hyperparameters of the trainer's ``'main'`` optimizer.

    Args:
        trainer: The running trainer.
        body (dict): Command payload. ``body['optimizer']`` must equal the
            class name of the ``'main'`` optimizer. ``body['hyperparam']``
            (optional) maps hyperparameter names to new values. Names the
            optimizer does not already have, and ``None`` values, are
            skipped.

    Returns:
        ``None`` if the optimizer class name does not match the request, or
        if the optimizer has no ``hyperparam`` attribute. Otherwise, a dict
        with the optimizer class name and the hyperparameter dict read back
        after the update.
    """
    optimizer = trainer.updater.get_optimizer('main')
    name = optimizer.__class__.__name__
    requested_name = body.get('optimizer', None)
    if name != requested_name:
        # The request targets a different optimizer class; ignore it.
        return None

    hyperparam = getattr(optimizer, 'hyperparam', None)
    if hyperparam is None:
        return None

    requested = body.get('hyperparam', {})
    # Snapshot of the existing hyperparameter names, taken before any update.
    known = hyperparam.get_dict()
    for key, value in six.iteritems(requested):
        if key in known and value is not None:
            setattr(hyperparam, key, value)

    return {
        'optimizer': name,
        'hyperparam': hyperparam.get_dict(),
    }
# NOTE: Chainer plans to let triggers report the training length
# (PR#4079). Once that is merged, the two trigger classes below
# can be merged into a single class.
class _CommandIntervalTrigger(IntervalTrigger):
    """Wrap an ``IntervalTrigger`` so training can also be stopped on demand.

    Calling the wrapper returns ``True`` when the wrapped trigger fires, or
    on every call after :meth:`stop` has been invoked. Attribute reads that
    normal lookup cannot resolve, and all attribute writes, are forwarded to
    the wrapped trigger.

    ``IntervalTrigger.__init__`` is intentionally not called. Subclassing
    ``IntervalTrigger`` presumably exists so that code checking
    ``isinstance(stop_trigger, IntervalTrigger)`` still treats the wrapped
    trigger as one.
    """

    def __init__(self, trigger):
        # Go through the base __setattr__: our own __setattr__ forwards
        # writes to the wrapped trigger.
        super(_CommandIntervalTrigger, self).__setattr__(
            '_trigger', trigger)
        super(_CommandIntervalTrigger, self).__setattr__(
            '_loop_stop', False)

    def __call__(self, trainer):
        # The wrapped trigger is always evaluated, even after stop().
        if self._trigger(trainer):
            return True
        return self._loop_stop

    def stop(self):
        """Make every subsequent call to this trigger return ``True``."""
        super(_CommandIntervalTrigger, self).__setattr__(
            '_loop_stop', True)

    def __getattr__(self, attr_name):
        # Only reached when normal attribute lookup fails. An instance
        # created without __init__ (e.g. by copy.copy or unpickling, which
        # probe for __setstate__) has no '_trigger'. Without this guard,
        # reading self._trigger below would re-enter __getattr__ and recurse
        # forever instead of raising AttributeError.
        if attr_name == '_trigger':
            raise AttributeError(attr_name)
        return getattr(self._trigger, attr_name)

    def __setattr__(self, attr_name, value):
        setattr(self._trigger, attr_name, value)
class _CommandTrigger(object):
def __init__(self, trigger):
super(_CommandTrigger, self).__setattr__(
'_trigger', trigger)
super(_CommandTrigger, self).__setattr__(
'_loop_stop', False)
def __call__(self, trainer):
if self._trigger(trainer):
return True
return self._loop_stop
def stop(self):
super(_CommandTrigger, self).__setattr__(
'_loop_stop', True)
def __getattr__(self, attr_name):
return getattr(self._trigger, attr_name)
def __setattr__(self, attr_name, value):
setattr(self._trigger, attr_name, value)
def _stop_training(trainer, body):
    """Receiver for the ``stop`` command: flag the stop trigger to fire.

    After this call, the wrapped ``trainer.stop_trigger`` returns ``True``
    every time it is evaluated.

    Args:
        trainer: The running trainer. Its ``stop_trigger`` must already have
            been wrapped in ``_CommandTrigger`` or
            ``_CommandIntervalTrigger``.
        body: Command payload; not used by this receiver.

    Returns:
        None.

    Raises:
        TypeError: If ``trainer.stop_trigger`` is not one of the command
            trigger wrappers.
    """
    stop_trigger = trainer.stop_trigger
    # An explicit check (instead of ``assert``) keeps this guard active when
    # Python runs with -O.
    if not isinstance(stop_trigger,
                      (_CommandTrigger, _CommandIntervalTrigger)):
        raise TypeError(
            'stop_trigger is not wrapped by CommandsExtension: '
            '{}'.format(type(stop_trigger).__name__))
    stop_trigger.stop()
    return None
class CommandsExtension(extension.Extension):
    """Trainer extension that executes commands written to a file.

    Commands are read from a file under ``trainer.out``. When the extension
    is called and its internal trigger fires, it does the following for each
    loaded command that ``should_execute``:

    * dispatch the command to the receiver registered under its name;
    * record the response on the command.

    If any command was executed, the commands are written back with
    ``CommandItem.dump_commands``.

    Args:
        trigger: Trigger, or ``(period, unit)`` tuple, that decides how often
            the commands are checked. Defaults to every iteration.
        receivers (dict): Extra ``{command_name: callable}`` receivers,
            merged over :attr:`default_receivers`. ``None`` (the default)
            means no extra receivers.
        file_name (str): Stored as ``self._file_name``. It is not otherwise
            used within this class. NOTE(review): ``CommandItem`` is called
            without it; confirm intent.
    """

    priority = extension.PRIORITY_READER

    default_receivers = {
        'take_snapshot': take_snapshot,
        'adjust_hyperparams': adjust_hyperparams,
        'stop': _stop_training,
    }

    def __init__(self, trigger=(1, 'iteration'), receivers=None,
                 file_name='commands'):
        self._trigger = trigger_module.get_trigger(trigger)
        self._file_name = file_name
        # Copy so per-instance receivers never mutate the class-level dict.
        self._receivers = self.default_receivers.copy()
        if receivers is not None:
            self._receivers.update(receivers)
        self._out = ''

    def initialize(self, trainer):
        """Prepare the command state and wrap ``trainer.stop_trigger``.

        The stop trigger is wrapped so that the ``stop`` command can end
        training.
        """
        self._out = trainer.out
        # Presumably discards a commands file left over from an earlier run.
        CommandItem.remove_commands_file(trainer.out)
        CommandsState.run(trainer.out)
        # IntervalTrigger instances get the IntervalTrigger subclass wrapper,
        # so isinstance checks against IntervalTrigger still succeed.
        if isinstance(trainer.stop_trigger, IntervalTrigger):
            trainer.stop_trigger = _CommandIntervalTrigger(
                trainer.stop_trigger)
        else:
            trainer.stop_trigger = _CommandTrigger(trainer.stop_trigger)

    def __call__(self, trainer):
        """Execute pending commands when the internal trigger fires."""
        if not self._trigger(trainer):
            return
        commands = CommandItem.load_commands(trainer.out)
        is_updated = False
        for command in commands:
            if not command.should_execute(trainer):
                continue
            body, status = self._execute_command(
                trainer, command.name, command.request)
            command.set_response(trainer, status, body)
            is_updated = True
        # Rewrite the commands only when at least one response was recorded.
        if is_updated:
            CommandItem.dump_commands(commands, trainer.out)

    def finalize(self):
        """Mark the command state as stopped if ``initialize`` has run."""
        if self._out != '':
            CommandsState.stop(self._out)

    def add_receiver(self, command_name, function):
        """Register ``function`` as the receiver for ``command_name``.

        Raises:
            ValueError: If ``command_name`` is ``None`` or ``function`` is
                not callable.
        """
        if command_name is None:
            raise ValueError('command name is not given')
        if not callable(function):
            raise ValueError('receiver is not callable')
        self._receivers[command_name] = function

    def _execute_command(self, trainer, command_name, request):
        """Dispatch one command to its receiver.

        Args:
            trainer: The running trainer, passed to the receiver.
            command_name (str): Key looked up in the registered receivers.
            request (dict): Request dict; its ``'body'`` entry (default
                ``None``) is passed to the receiver.

        Returns:
            tuple: ``(response_body, response_status)``. An unknown command
            yields a message body and ``RESPONSE_FAILURE``. A receiver that
            raises yields ``None`` and ``RESPONSE_FAILURE``.
        """
        receiver = self._receivers.get(command_name, None)
        if receiver is None:
            message = '%s command is not available or supported' % command_name
            response_body = {'message': message}
            response_status = CommandItem.RESPONSE_FAILURE
        else:
            try:
                response_body = receiver(trainer, request.get('body', None))
                response_status = CommandItem.RESPONSE_SUCCESS
            except Exception as e:
                # Boundary: a failing receiver must not abort training; the
                # failure is reported back through the command response.
                print('caught an exception from receiver:', e.args)
                response_body = None
                response_status = CommandItem.RESPONSE_FAILURE
        return response_body, response_status