如何使神经网络训练图表具有对数纵轴?

How to make neural network training charts have a logarithmic vertical axis?

使用MATLAB的NN训练工具(trainNetwork)时,我们得到的图表有一个线性纵轴,如下图:

这个图表应该提供一些关于训练进度的图形反馈,它 可能 用于分类问题(其中 y-axis 代表“准确性(%)"),但在回归问题中,随着训练的进行,RMSE 值可能具有截然不同的数量级 - 使得初始下降后的所有内容都无法区分并且非常无用。

我想做的是将垂直轴转换为对数,得到以下结果:

(我不介意一些图形元素在过程中四处移动或丢失,因为曲线对我来说很重要。)

我现在的做法是暂停训练过程,然后手动 运行

set(findall(findall(0,'type','figure'),'type','Axes',...
  'Tag','NNET_CNN_TRAININGPLOT_AXESVIEW_AXES_REGRESSION_RMSE'),'YScale','log');

(或其某些变体,取决于开放数字等)。

我正在寻找一种无需用户干预即可更改比例的方法,并尽可能在训练开始时这样做。此外,如果我可以选择要重新缩放的图表(RMSE and/or 损失),那就太好了。

我正在使用 R2018a。


生成这样一个图形所需的最少代码(基于标题为“Train Network for Image Regression”的 MATLAB 文档):

[XTrain,~,YTrain] = digitTrain4DArrayData;

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(12,25)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'Verbose',false, ...
    'MaxEpochs',5, ...
    'Plots','training-progress');

net = trainNetwork(XTrain,YTrain,layers,options);

我们可以利用训练选项中可用的自定义 'OutputFcn' 机制,并在那里指定一个进行重新缩放的函数。用户可以通过变量 whichAx.

控制重新缩放哪些轴

完整代码:

function net = q51762507()
[XTrain,~,YTrain] = digitTrain4DArrayData;

layers = [ ...
    imageInputLayer([28 28 1])
    convolution2dLayer(12,25)
    reluLayer
    fullyConnectedLayer(1)
    regressionLayer];

whichAx = [false, true]; % [bottom, top]

options = trainingOptions('sgdm', ...
    'InitialLearnRate',0.001, ...
    'Verbose',false, ...
    'MaxEpochs',5, ...
    'Plots','training-progress',...
    'OutputFcn', @(x)makeLogVertAx(x,whichAx) );

net = trainNetwork(XTrain,YTrain,layers,options);

function stop = makeLogVertAx(state, whichAx)
stop = false; % The function has to return a value.
% Only do this once, following the 1st iteration
if state.Iteration == 1
  % Get handles to "Training Progress" figures:
  hF  = findall(0,'type','figure','Tag','NNET_CNN_TRAININGPLOT_FIGURE');
  % Assume the latest figure (first result) is the one we want, and get its axes:
  hAx = findall(hF(1),'type','Axes');
  % Remove all irrelevant entries (identified by having an empty "Tag", R2018a)
  hAx = hAx(~cellfun(@isempty,{hAx.Tag}));
  set(hAx(whichAx),'YScale','log');
end