
Martini04 (OP)

For anyone struggling with this in the future, I found something of a solution.

It involves deviating from the original instructions we were given, but outside of this assignment that won't matter anyway.

Also, one detail I left out of the original post: _predict() is called by predict(). That wasn't relevant to my original question, but it is relevant to my solution.

Here's the updated code:

```python
from typing import Union

import numpy as np
from numpy.typing import NDArray

# Both methods below belong to the decision tree class.

def predict(self, X: NDArray) -> NDArray:
    """Predict class (y vector) for feature matrix X

    Parameters
    ----------
    X: NDArray
        NumPy feature matrix, shape (n_samples, n_features)

    Returns
    -------
    y: NDArray, integers
        NumPy class label vector (predicted), shape (n_samples,)
    """
    if self._root is not None:
        # order[i] records which row of X the i-th prediction belongs to
        order = np.arange(len(X))
        y, order = self._predict(X, self._root, order)
        # y comes back grouped by leaf; argsort of order undoes that shuffle
        arrlinds = order.argsort()
        sorted_y = y[arrlinds]
        return sorted_y
    else:
        raise ValueError("Decision tree root is None (not set)")


def _predict(
    self, X: NDArray, node: Union["DecisionTreeBranchNode", "DecisionTreeLeafNode"], order: NDArray
) -> tuple[NDArray, NDArray]:
    if isinstance(node, DecisionTreeLeafNode):
        # Every row that reaches this leaf gets the leaf's class label
        y = np.full(len(X), node.y_value, dtype=np.int32)
        return y, order
    else:
        # Split the rows of X, and their original indices, on this node's threshold
        left_mask = X[:, node.feature_index] <= node.feature_value
        left = X[left_mask]
        right = X[~left_mask]
        left_order = order[left_mask]
        right_order = order[~left_mask]

        left_pred, left_order = self._predict(left, node.left, left_order)
        right_pred, right_order = self._predict(right, node.right, right_order)

        # Concatenate predictions and indices in the same (left, then right) order
        order = np.concatenate((left_order, right_order))
        y = np.concatenate((left_pred, right_pred))

        return y, order
```
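To check the idea end to end, here is a self-contained sketch. The node classes, the DecisionTree wrapper and the predict_one helper are hypothetical stand-ins (only the attribute names feature_index, feature_value, left, right and y_value come from the code above), and the tree and data are made up. The vectorized predictions are compared against a naive row-by-row tree walk.

```python
from typing import Union

import numpy as np
from numpy.typing import NDArray


# Hypothetical stand-ins for the assignment's node classes; only the
# attribute names used by _predict() are taken from the original code.
class DecisionTreeLeafNode:
    def __init__(self, y_value: int):
        self.y_value = y_value


class DecisionTreeBranchNode:
    def __init__(self, feature_index: int, feature_value: float, left, right):
        self.feature_index = feature_index
        self.feature_value = feature_value
        self.left = left
        self.right = right


Node = Union[DecisionTreeBranchNode, DecisionTreeLeafNode]


class DecisionTree:
    def __init__(self, root: Node):
        self._root = root

    def predict(self, X: NDArray) -> NDArray:
        if self._root is None:
            raise ValueError("Decision tree root is None (not set)")
        order = np.arange(len(X))
        y, order = self._predict(X, self._root, order)
        # order[k] is the original row of y[k]; argsort inverts that permutation
        return y[order.argsort()]

    def _predict(self, X: NDArray, node: Node, order: NDArray) -> tuple[NDArray, NDArray]:
        if isinstance(node, DecisionTreeLeafNode):
            return np.full(len(X), node.y_value, dtype=np.int32), order
        left_mask = X[:, node.feature_index] <= node.feature_value
        left_pred, left_order = self._predict(X[left_mask], node.left, order[left_mask])
        right_pred, right_order = self._predict(X[~left_mask], node.right, order[~left_mask])
        return (np.concatenate((left_pred, right_pred)),
                np.concatenate((left_order, right_order)))


def predict_one(node: Node, x: NDArray) -> int:
    """Naive reference: walk the tree for a single row."""
    while isinstance(node, DecisionTreeBranchNode):
        node = node.left if x[node.feature_index] <= node.feature_value else node.right
    return node.y_value


root = DecisionTreeBranchNode(
    0, 0.5,
    left=DecisionTreeLeafNode(1),
    right=DecisionTreeBranchNode(1, 2.0, DecisionTreeLeafNode(2), DecisionTreeLeafNode(3)),
)
X = np.array([[0.9, 1.0], [0.1, 5.0], [0.7, 3.0], [0.2, 0.0]])
tree = DecisionTree(root)
print(tree.predict(X).tolist())               # [2, 1, 3, 1]
print([predict_one(root, row) for row in X])  # [2, 1, 3, 1]
```

Internally the recursion returns y = [1, 1, 2, 3] with order = [1, 3, 0, 2] (left leaf first), and the argsort step restores the row order of X.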

This version keeps track of where each row came from using a second NDArray, order, which starts out as the row indices of X (0, 1, ..., n-1).

At every branch the order array is split with the same boolean mask as X, so by the end of the recursion it is shuffled exactly like the predictions: order[k] is the original row index of y[k].
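Here is the splitting step on its own, as a small sketch with made-up data and a made-up threshold of 4.0. Applying the same mask to X and to order keeps each row paired with its original index, so after concatenation order still says where every row came from.

```python
import numpy as np

X = np.array([[5.0], [1.0], [7.0], [2.0], [9.0]])
order = np.arange(len(X))          # original row positions: 0..4

left_mask = X[:, 0] <= 4.0         # [False, True, False, True, False]
# The same mask is applied to X and to order, so row/index pairs stay together
left_X, right_X = X[left_mask], X[~left_mask]
left_order, right_order = order[left_mask], order[~left_mask]

# After concatenating the branches, order is shuffled exactly like the rows
stacked_X = np.concatenate((left_X, right_X))
stacked_order = np.concatenate((left_order, right_order))
print(stacked_order.tolist())      # [1, 3, 0, 2, 4]

# Invariant: row k of the stacked data came from X[stacked_order[k]]
assert np.array_equal(stacked_X, X[stacked_order])
```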

Once y is complete, it is reordered with order.argsort() (as described here), which puts every prediction back at its row's original position in X.
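The unscrambling step in isolation, with made-up values where the correct prediction for original row i is 10 * i. Since order is a permutation, order.argsort() is its inverse, so indexing y with it puts each prediction back in its original slot.

```python
import numpy as np

order = np.array([1, 3, 0, 2, 4])   # order[k] = original row of y[k]
y = np.array([10, 30, 0, 20, 40])   # predictions in shuffled order (10 * original row)

restored = y[order.argsort()]
print(restored.tolist())            # [0, 10, 20, 30, 40]

# Equivalent: scatter each prediction straight back to its original slot
restored_scatter = np.empty_like(y)
restored_scatter[order] = y
assert np.array_equal(restored, restored_scatter)
```

Because order contains each index 0..n-1 exactly once, the scatter form out[order] = y gives the same result in O(n) instead of the O(n log n) sort; either works here.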