Flux.jl 中的交叉熵损失 - Julia

Cross Entropy Loss in Flux.jl - Julia

我想设置我的模型以在 Flux.jl 中使用交叉熵损失。我该怎么做以及我将损失函数本身传递到哪里?

Flux.jl 通过 Flux.Losses 模块提供了一个具有许多常见损失函数的内置模块,您可以通过 using Flux.Losses 访问该模块。模块中内置了交叉熵损失函数,使用方法如下:

julia> y_label = Flux.onehotbatch([0, 1, 2, 1, 0], 0:2)
3×5 Flux.OneHotArray{3,2,Vector{UInt32}}:
 1  0  0  0  1
 0  1  0  1  0
 0  0  1  0  0

julia> y_model = softmax(reshape(-7:7, 3, 5) .* 1f0)
3×5 Matrix{Float32}:
 0.0900306  0.0900306  0.0900306  0.0900306  0.0900306
 0.244728   0.244728   0.244728   0.244728   0.244728
 0.665241   0.665241   0.665241   0.665241   0.665241

julia> sum(y_model; dims=1)
1×5 Matrix{Float32}:
 1.0  1.0  1.0  1.0  1.0

julia> Flux.crossentropy(y_model, y_label)
1.6076053f0

您可以在此处找到完整的 Flux.crossentropy 函数定义:https://fluxml.ai/Flux.jl/stable/models/losses/#Flux.Losses.crossentropy

定义损失函数后,您可以将其传递给内置训练函数:Flux.train!(loss, params(model), data, opt)或在您的自定义训练循环中使用它。