ArgumentError: invalid index: false of type Bool in Julia

ArgumentError: invalid index: false of type Bool in Julia

我有一段代码是用julia写的,目的是实现ID3算法。但它有一些错误。我不知道如何解决它。希望得到您的帮助。

function entropy(counts, n_samples)
    filter!(c -> c != 0, counts) # remove '0' elements in counts
    prob = counts / float(n_samples)
    #entropy: H(S) = -sum(prob_i*log2(prob_i))
    sum = 0
    for p in prob
        sum += -(p * log2(p))
    end
    return sum
end

function entropy_of_one_division(division)
    n_samples = len(division)
    n_classes = set(division)

    #count samples in each class then store it to list counts
    counts = []
    for _class in n_classes
        count = 0 
        for sample in division
            if sample == _class 
                count += 1
                counts.append(count)
            end
        end
    end
    return entropy(counts, n_samples), n_samples
end

function get_entropy(y_predict, y)
    n = length(y)
    entropy_true, n_true = entropy_of_one_division(y[y_predict]) # left hand side entropy
    entropy_false, n_false = entropy_of_one_division(y[~y_predict]) # right hand side entropy
    s = (n_true / n) * entropy_true + (n_false / n) * entropy_false
    return s
end
struct DecisionTreeClassifier
    tree
    depth
end

function fit(dt::DecisionTreeClassifier, X, y, node = Dict(), depth = 0)
    if all(y .== y[1])
        return Dict("val" => y[1])
    else
        col_idx, cutoff, entropy = find_best_split_of_all(X, y)
        y_left = y[X[:, col_idx]<cutoff]
        y_right = y[X[:, col_idx]>=cutoff]
        node = Dict("index_col" => col_idx, "cutoff" => cutoff, "val" => Statistics.mean(y))
        node["left"] = fit(X[X[:, col_idx]<cutoff], y_left, Dict(), depth + 1)
        node["right"] = fit(X[X[:, col_idx]>=cutoff], y_right, Dict(), depth + 1)
        dt.depth += 1
        dt.tree = node
        return node
    end
end

function find_best_split_of_all(X, y)
    col_idx = Nothing
    min_entropy = 1
    cutoff = Nothing
    for (i, col_data) in enumerate(X')
        entropy, cur_cutoff = find_best_split(col_data, y)
        if entropy == 0
            return i, cur_cutoff, entropy
        elseif entropy <= min_entropy
            min_entropy = entropy
            col_idx = i
            cutoff = cur_cutoff
        end
    end
    return col_idx, cutoff, min_entropy
end

function find_best_split(col_data, y)
    min_entropy = 10

    for value in Set(col_data)
        y_predict = col_data < value
        my_entropy = get_entropy(y_predict, y)

        if my_entropy < min_entropy
            min_entropy = my_entropy
            cutoff = value
        end
    end
    print(cutoff)

    return min_entropy, cutoff
end

model = DecisionTreeClassifier(Nothing, 0)

tree = fit(model, X, y)
pred = model.predict(X_train)
print("Accuracy of your decision tree model on training data:", accuracy_score(y_train, pred))
pred = model.predict(X_test)
print("Accuracy of your decision tree model:", accuracy_score(y_test, pred))

错误:

ERROR: LoadError: ArgumentError: invalid index: false of type Bool
Stacktrace:
  [1] to_index(i::Bool)
    @ Base .\indices.jl:293
  [2] to_index(A::Vector{String15}, i::Bool)
    @ Base .\indices.jl:277
  [3] to_indices
    @ .\indices.jl:333 [inlined]
  [4] to_indices
    @ .\indices.jl:330 [inlined]
  [5] getindex
    @ .\abstractarray.jl:1221 [inlined]
  [6] get_entropy(y_predict::Bool, y::Vector{String15})
    @ Main d:\test.jl:53
  [7] find_best_split(col_data::Float64, y::Vector{String15})
    @ Main d:\test.jl:102
  [8] find_best_split_of_all
    @ d:\test.jl:85 [inlined]
  [9] fit(dt::DecisionTreeClassifier, X::Matrix{Float64}, y::Vector{String15}, node::Dict{Any, Any}, depth::Int64)
    @ Main d:\test.jl:68
 [10] fit(dt::DecisionTreeClassifier, X::Matrix{Float64}, y::Vector{String15})
    @ Main d:\test.jl:65
 [11] top-level scope
    @ d:\test.jl:138

我有一段代码是用julia写的,目的是实现ID3算法。但它有一些错误。我不知道如何解决它。希望得到您的帮助。

你似乎包含了很多不必要的代码,但不是重要的部分。

错误发生在 get_entropy 中,您没有将其包括在内,可能是因为您正在尝试使用 Bool.

对数组进行索引
y_predict = col_data < value
my_entropy = get_entropy(y_predict, y)

y_predict 在这里是布尔值 (true/false) 但为了索引,您需要一个常规整数。