Decision trees are widely used machine learning algorithms that can be applied to both classification and regression tasks. These models work by repeatedly splitting the data into subsets based on feature values, with the splits forming a tree-like structure and each leaf node providing a prediction. Decision trees are popular because they are easy to interpret and visualize, which makes the decision-making process easy to understand.
In machine learning, there are various types of decision tree algorithms. In this article, we'll explore these types so that you can choose the most appropriate one for your task.
Types of Decision Tree Algorithms
The six decision tree algorithms listed below each have their own advantages and limitations. Let's understand them one by one.
1. ID3 (Iterative Dichotomiser 3)
ID3 is a classic decision tree algorithm commonly used for classification tasks. It works by greedily choosing the feature that maximizes the information gain at each node. It calculates entropy and information gain for each feature and selects the feature with the highest information gain for splitting.
Entropy: Entropy measures the impurity of a dataset. For a dataset D with n classes, where p_i is the proportion of examples in class i, it is denoted H(D) and calculated as:
H(D) = -\sum_{i=1}^{n} p_{i} \log_{2}(p_{i})
Information gain: It quantifies the reduction in entropy after splitting the dataset on a feature:
Information\;Gain = H(D) - \sum_{v=1}^{V} \frac{|D_{v}|}{|D|} H(D_{v})
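As a quick illustration, here is a minimal NumPy sketch of both quantities (the toy "play"/"outlook" data below is hypothetical, not from the article):

```python
import numpy as np

def entropy(labels):
    # H(D) = -sum_i p_i * log2(p_i) over the class proportions in `labels`
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def information_gain(labels, feature_values):
    # H(D) minus the size-weighted entropy of each subset D_v
    # induced by one categorical feature
    n = len(labels)
    weighted = sum(
        len(labels[feature_values == v]) / n * entropy(labels[feature_values == v])
        for v in np.unique(feature_values)
    )
    return entropy(labels) - weighted

# Hypothetical toy data: does "outlook" help predict "play"?
play = np.array(["yes", "yes", "no", "no", "yes", "no"])
outlook = np.array(["sunny", "rainy", "sunny", "sunny", "rainy", "rainy"])
print(information_gain(play, outlook))
```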
ID3 recursively splits the dataset using the feature with the highest information gain until all examples in a node belong to the same class or no features remain to split on. After the tree is constructed, branches that do not significantly improve accuracy can be pruned to reduce overfitting. Even so, ID3 tends to overfit the training data and cannot directly handle continuous attributes; these issues are addressed by later algorithms like C4.5 and CART.
For its implementation you can refer to the article: Iterative Dichotomiser 3 (ID3) Algorithm From Scratch
2. C4.5
C4.5 uses a modified version of information gain called the gain ratio to reduce the bias towards features with many values. The gain ratio is computed by dividing the information gain by the split information (also called the intrinsic information), which measures the amount of information required to describe an attribute's values:
Gain\;Ratio = \frac{Information\;Gain}{Split\;Information}
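Reusing the entropy and information_gain helpers from the ID3 sketch above, the gain ratio could be computed like this (again a sketch, not a reference implementation):

```python
def split_information(feature_values):
    # Intrinsic information of the split itself:
    # -sum_v (|D_v|/|D|) * log2(|D_v|/|D|)
    _, counts = np.unique(feature_values, return_counts=True)
    p = counts / counts.sum()
    return -np.sum(p * np.log2(p))

def gain_ratio(labels, feature_values):
    si = split_information(feature_values)
    # Guard against features that take only a single value (split info = 0)
    return information_gain(labels, feature_values) / si if si > 0 else 0.0

print(gain_ratio(play, outlook))  # reuses the hypothetical data above
```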
- It addresses several limitations of ID3, including its inability to handle continuous attributes and its tendency to overfit the training set. It handles continuous attributes by first sorting the attribute values and then selecting the midpoint between adjacent values as a potential split point; the split that maximizes information gain or gain ratio is chosen (see the sketch after this list).
- It can also generate rules from the decision tree by converting each path from the root to a leaf into a rule, which can be used to make predictions on new data.
- This algorithm improves accuracy and reduces overfitting by using gain ratio and post-pruning. While effective for both discrete and continuous attributes, C4.5 may still struggle with noisy data and large feature sets.
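A small sketch of the midpoint idea from the first bullet above (the temperature values are hypothetical):

```python
import numpy as np

def candidate_thresholds(values):
    # C4.5-style handling of a continuous attribute: sort the distinct
    # values and take midpoints between adjacent values as candidate splits
    v = np.unique(values)          # sorted, distinct
    return (v[:-1] + v[1:]) / 2.0

temps = np.array([64.0, 65.0, 68.0, 69.0, 70.0, 71.0])
for t in candidate_thresholds(temps):
    # In C4.5 each threshold would be scored by information gain or
    # gain ratio; here we only show the binary partition it induces.
    print(t, temps[temps <= t], temps[temps > t])
```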
C4.5 has limitations:
- It can be prone to overfitting, especially on noisy datasets, even though it uses pruning techniques.
- Performance may degrade when dealing with datasets that have many features.
3. CART (Classification and Regression Trees)
CART is a widely used decision tree algorithm that supports both classification and regression tasks.
- For classification, CART splits the data based on the Gini impurity, which measures the likelihood that a randomly selected data point would be incorrectly classified. The feature that minimizes the Gini impurity is selected for splitting at each node (see the sketch after this list). The formula is:
Gini(D) = 1 - \sum_{i=1}^{n} p_{i}^{2}
where p_{i} is the probability of class i in dataset D.
- For regression, CART builds regression trees by minimizing the variance of the target variable within each subset. The split that reduces the variance the most is chosen.
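A minimal sketch of the Gini computation from the classification bullet above:

```python
import numpy as np

def gini(labels):
    # Gini(D) = 1 - sum_i p_i^2 over the class proportions
    _, counts = np.unique(labels, return_counts=True)
    p = counts / counts.sum()
    return 1.0 - np.sum(p ** 2)

print(gini(np.array(["yes", "yes", "no"])))  # 1 - (2/3)^2 - (1/3)^2 ≈ 0.444
```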
To reduce overfitting, CART applies cost-complexity pruning after tree construction: it minimizes a cost function that combines impurity and tree complexity by adding a complexity-penalty term to the impurity measure. CART builds binary trees in which each internal node has exactly two child nodes, which simplifies the splitting process and makes the resulting tree easier to interpret.
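In practice you rarely implement CART by hand; scikit-learn's decision trees are based on an optimized version of CART, and cost-complexity pruning is exposed through the ccp_alpha parameter. A minimal sketch:

```python
from sklearn.datasets import load_iris
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

X, y = load_iris(return_X_y=True)
X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)

# criterion="gini" is the CART classification criterion;
# ccp_alpha > 0 enables cost-complexity pruning after construction
clf = DecisionTreeClassifier(criterion="gini", ccp_alpha=0.01, random_state=0)
clf.fit(X_train, y_train)
print(clf.get_depth(), clf.score(X_test, y_test))
```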
For its implementation you can refer to the article: Implementing CART (Classification And Regression Tree) in Python
4. CHAID (Chi-Square Automatic Interaction Detection)
CHAID uses chi-square tests to determine the best splits, especially for categorical variables. It recursively divides the data into smaller subsets until each subset contains only data points of the same class or within a specified range of values. At each node it splits on the feature with the highest chi-square statistic, indicating the strongest relationship with the target variable. This approach is particularly useful for analyzing large datasets with many categorical features. The chi-square statistic is calculated as:
\chi^2 = \sum_{i} \frac{(O_{i} - E_{i})^2}{E_{i}}
Where:
- O_i represents the observed frequency
- E_i represents the expected frequency in each category.
It compares the observed distribution to the expected distribution to determine whether there is a significant difference. CHAID can be applied to both classification and regression tasks. In classification, the algorithm assigns a class label to a new data point by following the tree from the root to a leaf node, and the leaf node's class label is assigned to the data point. In regression, it predicts the target variable by averaging the values at the leaf node.
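As an illustration, SciPy's chi2_contingency computes this statistic from a table of observed frequencies (the table below is made up):

```python
import numpy as np
from scipy.stats import chi2_contingency

# Hypothetical contingency table: rows = categories of one feature,
# columns = target classes, cells = observed frequencies O_i
observed = np.array([[30, 10],
                     [20, 40]])

# correction=False gives the plain chi-square statistic shown above
chi2, p_value, dof, expected = chi2_contingency(observed, correction=False)
# CHAID would favor the feature whose table yields the largest chi2
print(chi2, p_value)
```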
5. MARS (Multivariate Adaptive Regression Splines)
MARS is an extension of the CART algorithm. It uses splines to model non-linear relationships between variables, constructing a piecewise linear model in which the relationship between the inputs and the output is linear within each region but the slope is allowed to change at certain points, known as knots. It automatically selects and positions these knots based on the data distribution and the need to capture non-linearities.
Basis Functions: Each basis function in MARS is a hinge function: a simple piecewise linear function that is zero on one side of a knot t and linear on the other. MARS adds these functions in mirrored pairs:
h(x) = \max(0,\; x - t) \quad \text{and} \quad h(x) = \max(0,\; t - x)
Where:
- x is a predictor variable
- t is the knot (the location at which the function changes slope).
Knots: The knots are the points where the piecewise linear functions connect. MARS places them to best represent the data's non-linear structure.
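A tiny NumPy sketch of a mirrored hinge pair at a hypothetical knot t = 4:

```python
import numpy as np

def hinge_pair(x, t):
    # The mirrored pair of hinge basis functions MARS creates at knot t
    return np.maximum(0.0, x - t), np.maximum(0.0, t - x)

x = np.linspace(0.0, 10.0, 6)
right, left = hinge_pair(x, t=4.0)
print(right)  # nonzero only where x > 4
print(left)   # nonzero only where x < 4
```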
MARS begins with a simple model (just an intercept) and then applies forward stepwise selection to iteratively add basis-function pairs that most reduce the error. The process continues until the model reaches a desired complexity. MARS is particularly effective for modeling complex relationships in data and is widely used in regression tasks.
6. Conditional Inference Trees
Conditional Inference Trees use statistical tests to choose splits based on the relationship between features and the target variable. They use permutation tests to select the feature that best splits the data while minimizing selection bias.
The algorithm follows a recursive approach. At each node it evaluates the statistical significance of potential splits using tests like the Chi-squared test for categorical features and the F-test for continuous features. The feature with the strongest relationship to the target is selected for the split. The process continues until the data cannot be further split or meets predefined stopping criteria.
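The following sketch shows the flavor of such a permutation test, using absolute correlation as a stand-in association statistic (real implementations such as R's ctree use more general test statistics):

```python
import numpy as np

rng = np.random.default_rng(0)

def permutation_p_value(feature, target, n_perm=1000):
    # Compare the observed association against its distribution under
    # random permutations of the target (which breaks any relationship)
    observed = abs(np.corrcoef(feature, target)[0, 1])
    exceed = sum(
        abs(np.corrcoef(feature, rng.permutation(target))[0, 1]) >= observed
        for _ in range(n_perm)
    )
    return (exceed + 1) / (n_perm + 1)

x = rng.normal(size=100)
y = 2 * x + rng.normal(size=100)   # hypothetical, strongly related data
print(permutation_p_value(x, y))   # small p-value -> split on this feature
```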
Summarizing all Algorithms
Here’s a short summary of all decision tree algorithms we have learned so far:
- ID3: Uses information gain to split data and works well for classification, but it is prone to overfitting and struggles with continuous data.
- C4.5: An advanced version of ID3 that uses the gain ratio and handles both discrete and continuous data, but it can struggle with noisy data.
- CART: Used for both classification and regression tasks. It minimizes Gini impurity for classification and variance (MSE) for regression, with cost-complexity pruning to prevent overfitting.
- CHAID: Uses chi-square tests for splitting and is effective for large categorical datasets, but not for continuous data.
- MARS: An extension of CART that uses piecewise linear functions to model non-linear relationships, but it is computationally expensive.
- Conditional Inference Trees: Use statistical hypothesis testing for unbiased splits and handle various data types, but they are slower than the others.
Decision tree algorithms provide a flexible approach to both classification and regression tasks. While each algorithm brings its own strengths, understanding how each one works is important for selecting the best algorithm for a given problem and achieving better model accuracy.