A decision tree algorithm (ID3) was introduced in a previous article. A brief review: at each step, ID3 selects the best feature on which to split the data, where "best" is judged by information gain. Once the data has been split on a feature, that feature is never used again on the resulting subsets, so the splitting tends to happen too quickly. ID3 also cannot handle continuous features. This article covers another algorithm:

CART classification regression tree

CART, short for Classification And Regression Trees, can handle both classification and regression tasks.

The CART construction algorithm is similar to that of the ID3 decision tree, so we go straight to the construction process. As in ID3, each tree node is represented by a simple data structure containing the following 4 elements (a minimal sketch of such a node class follows the list):

  • The feature to split on
  • The feature value to split on
  • The right subtree; it can also be a single value (a leaf) when no further splitting is needed
  • The left subtree, analogous to the right subtree
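As a concrete sketch, the node can be written as a small class. The field names below are taken from the Tree(...) constructor calls that appear in buildDecisionTree and prune later in this article; the class itself is a minimal reconstruction, not the article's original listing:

class Tree:
    # A node of the CART tree: either an internal split node or a leaf.
    def __init__(self, col=-1, value=None, trueBranch=None, falseBranch=None,
                 results=None, summary=None, data=None):
        self.col = col                  # index of the feature to split on
        self.value = value              # feature value used for the split
        self.trueBranch = trueBranch    # right subtree (rows where the split condition holds)
        self.falseBranch = falseBranch  # left subtree (rows where it does not)
        self.results = results          # class counts at a leaf, None for internal nodes
        self.summary = summary          # bookkeeping: impurity and sample count
        self.data = data                # rows stored at a leaf (used by pruning)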

The process is as follows:

  1. Find the most suitable feature and value to split on.
  2. If the data set can no longer be split, store it as a leaf node.
  3. Split the data set into two parts.
  4. Repeat steps 1, 2, 3 on the first subset to create the right subtree.
  5. Repeat steps 1, 2, 3 on the second subset to create the left subtree.

This is clearly a recursive algorithm.

splitDatas splits the data set by filtering on a feature value and returns two subsets.

def splitDatas(rows, value, column):
    # Split rows into two lists by comparing the given column against value.
    # Numeric features are split by >=, discrete features by equality.
    # return 2 parts (list1, list2)
    list1 = []
    list2 = []

    if isinstance(value, int) or isinstance(value, float):
        for row in rows:
            if row[column] >= value:
                list1.append(row)
            else:
                list2.append(row)
    else:
        for row in rows:
            if row[column] == value:
                list1.append(row)
            else:
                list2.append(row)
    return list1, list2

Dividing data points

Creating a binary decision tree is essentially a recursive partitioning of the input space.

CART measures the quality of a split with Gini impurity; the code is as follows:

def gini(rows):
    # Calculate the Gini impurity: 1 minus the sum of squared class proportions.
    length = len(rows)
    results = calculateDiffCount(rows)
    imp = 0.0
    for i in results:
        imp += results[i] / length * results[i] / length
    return 1 - imp
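The gini function relies on a helper, calculateDiffCount, which the article does not list (it is part of the complete code linked at the end). A minimal sketch of what it plausibly does is to count how many rows belong to each class label, with the label assumed to be the last element of each row:

def calculateDiffCount(rows):
    # Count how many rows belong to each class label.
    # The class label is assumed to be the last element of a row.
    results = {}
    for row in rows:
        label = row[-1]
        results[label] = results.get(label, 0) + 1
    return results

With this helper, gini([[1, 'A'], [2, 'A'], [3, 'B'], [4, 'B']]) evaluates to 0.5, while a pure set such as [[1, 'A'], [2, 'A']] gives 0.0.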

Build a tree

def buildDecisionTree(rows, evaluationFunction=gini):
    # Recursively build the decision tree; stop recursing when gain = 0.
    # build decision tree by recursive function
    # stop recursive function when gain = 0
    # return tree
    currentGain = evaluationFunction(rows)
    column_length = len(rows[0])
    rows_length = len(rows)
    best_gain = 0.0
    best_value = None
    best_set = None

    # choose the best gain
    for col in range(column_length - 1):
        col_value_set = set([x[col] for x in rows])
        for value in col_value_set:
            list1, list2 = splitDatas(rows, value, col)
            p = len(list1) / rows_length
            gain = currentGain - p * evaluationFunction(list1) - (1 - p) * evaluationFunction(list2)
            if gain > best_gain:
                best_gain = gain
                best_value = (col, value)
                best_set = (list1, list2)

    dcY = {'impurity': '%.3f' % currentGain, 'sample': '%d' % rows_length}

    # stop or not stop
    if best_gain > 0:
        trueBranch = buildDecisionTree(best_set[0], evaluationFunction)
        falseBranch = buildDecisionTree(best_set[1], evaluationFunction)
        return Tree(col=best_value[0], value=best_value[1], trueBranch=trueBranch, falseBranch=falseBranch, summary=dcY)
    else:
        return Tree(results=calculateDiffCount(rows), summary=dcY, data=rows)

The function above first finds the best split point for the data set and splits the data into two subsets, then recursively builds the whole tree from those subsets.
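As a quick illustration (a hand-made toy example, not from the original article), the builder can be run on a tiny data set whose last column is the class label:

# Toy data: two numeric features, class label in the last column.
toy_rows = [
    [2.0, 1.0, 'no'],
    [2.5, 1.5, 'no'],
    [7.0, 3.0, 'yes'],
    [8.0, 3.5, 'yes'],
]

toy_tree = buildDecisionTree(toy_rows, evaluationFunction=gini)
print(toy_tree.summary)              # {'impurity': '0.500', 'sample': '4'}
print(toy_tree.col, toy_tree.value)  # feature index and value of the root split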

Pruning

In decision tree learning, a tree sometimes grows too many branches, and some of them need to be removed to reduce overfitting. Controlling the complexity of the decision tree to avoid overfitting is called pruning. Post-pruning first grows a complete decision tree from the training set and then examines the non-leaf nodes bottom-up, using a held-out test set to decide whether to replace the subtree rooted at a node with a leaf node. The code is as follows:

def prune(tree, miniGain, evaluationFunction=gini):
    # When the gain from a split is less than miniGain, merge trueBranch and falseBranch.
    if tree.trueBranch.results == None:
        prune(tree.trueBranch, miniGain, evaluationFunction)
    if tree.falseBranch.results == None:
        prune(tree.falseBranch, miniGain, evaluationFunction)

    if tree.trueBranch.results != None and tree.falseBranch.results != None:
        len1 = len(tree.trueBranch.data)
        len2 = len(tree.falseBranch.data)
        p = float(len1) / (len1 + len2)

        gain = evaluationFunction(tree.trueBranch.data + tree.falseBranch.data) - p * evaluationFunction(tree.trueBranch.data) - (1 - p) * evaluationFunction(tree.falseBranch.data)

        if gain < miniGain:
            tree.data = tree.trueBranch.data + tree.falseBranch.data
            tree.results = calculateDiffCount(tree.data)
            tree.trueBranch = None
            tree.falseBranch = None

When the gain of a node's split is less than the given miniGain threshold, its two child nodes are merged back into a single leaf.
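As a small worked example (hypothetical data, not from the article), the merge criterion can be checked by hand for two leaves:

# Two hypothetical leaves produced by a split; the label is the last column.
true_rows = [[1, 'A'], [2, 'A'], [3, 'B']]   # gini = 1 - (4/9 + 1/9) = 0.444...
false_rows = [[4, 'A'], [5, 'B']]            # gini = 0.5

p = len(true_rows) / (len(true_rows) + len(false_rows))   # 0.6
gain = gini(true_rows + false_rows) - p * gini(true_rows) - (1 - p) * gini(false_rows)
# combined gini = 1 - (0.6**2 + 0.4**2) = 0.48, so gain = 0.48 - 0.6*0.444... - 0.4*0.5 = 0.013...
# With miniGain = 0.1 this split would be merged back into a single leaf.
print(gain)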

Finally, the code that ties everything together (loadCSV, test_data, and classify come from the complete code linked at the end of the article):

if __name__ == '__main__':
    dataSet = loadCSV()
    decisionTree = buildDecisionTree(dataSet, evaluationFunction=gini)
    prune(decisionTree, 0.4)  # 0.4 is an assumed miniGain threshold; the original value was garbled
    r = classify(test_data, decisionTree)
    print(r)
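classify is also not listed in the article. A plausible minimal version walks the tree from the root, following trueBranch or falseBranch according to each node's splitting feature, until it reaches a leaf and returns that leaf's class counts:

def classify(data, tree):
    # Walk down the tree until a leaf (results != None) is reached,
    # then return the class counts stored at that leaf.
    if tree.results != None:
        return tree.results
    v = data[tree.col]
    if isinstance(v, int) or isinstance(v, float):
        branch = tree.trueBranch if v >= tree.value else tree.falseBranch
    else:
        branch = tree.trueBranch if v == tree.value else tree.falseBranch
    return classify(data, branch)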

Printing decisionTree shows the structure of the tree that was built. We can then run some test data through it to check that the classification is correct.

See github: CART for the complete code and data set

Conclusion:

  • CART decision tree
  • Split data set
  • Recursively create a tree

A Python implementation of the CART decision tree.