Causal Inference Model For Breast Cancer

DANIEL ZELALEM Zewdie
9 min read · Aug 28, 2021

Introduction

Machine learning algorithms tend to focus on making the most accurate predictions or classifications rather than on providing insight into the cause-and-effect relationships between features. But uncovering those relationships can be important for decision making, especially in health care. It can also help us build more accurate models that exploit not only the data but also its cause-and-effect structure. Causal inference is the task of inferring what causes what. It lets us answer questions such as "what causes this outcome?" or "what would the effect be if we changed a given feature?" Questions of this type cannot be answered by machine learning algorithms alone. We can compute correlations or feature importances from machine learning models, but these capture association, not causation.

Correlation and causation: Correlation is a statistical measure (expressed as a number) that describes the size and direction of a relationship between two or more variables; it indicates association between variables. A correlation, however, does not automatically mean that a change in one variable causes the change in the other.
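To make the distinction concrete, here is a small simulation sketch. The variables x, y and z are made up for illustration and have nothing to do with the breast cancer data: z is a hidden common cause, so x and y end up correlated even though neither influences the other.

```python
import random


def pearson(xs, ys):
    """Pearson correlation coefficient of two equal-length sequences."""
    n = len(xs)
    mean_x, mean_y = sum(xs) / n, sum(ys) / n
    cov = sum((x - mean_x) * (y - mean_y) for x, y in zip(xs, ys))
    var_x = sum((x - mean_x) ** 2 for x in xs)
    var_y = sum((y - mean_y) ** 2 for y in ys)
    return cov / (var_x * var_y) ** 0.5


rng = random.Random(0)  # fixed seed so the run is repeatable

# Hypothetical toy variables: z is a hidden common cause of x and y.
z = [rng.gauss(0, 1) for _ in range(5000)]
x = [zi + rng.gauss(0, 0.5) for zi in z]  # x depends only on z
y = [zi + rng.gauss(0, 0.5) for zi in z]  # y depends only on z

# x and y are strongly correlated although neither causes the other.
corr_xy = pearson(x, y)

# Setting x at random (an intervention that ignores z) leaves y untouched,
# so the correlation with y disappears.
x_intervened = [rng.gauss(0, 1) for _ in range(5000)]
corr_after = pearson(x_intervened, y)

print(f"observed corr(x, y):  {corr_xy:.2f}")
print(f"corr after setting x: {corr_after:.2f}")
```

A model trained on the observed data would happily use x to predict y, but changing x in the real world would not move y at all. That gap is what causal inference tries to address.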

In this blog post, we apply a causal graph model to a breast cancer dataset to learn the cause-and-effect structure behind the diagnosis of breast cancer.

Data

The dataset we are using is the Breast Cancer Wisconsin (Diagnostic) Data Set, which can be found on Kaggle or in the UCI Machine Learning Repository. It was created by Dr. William H. Wolberg, W. Nick Street and Prof. Olvi L. Mangasarian. The dataset contains 569 samples of malignant and benign tumor cells. The first two columns store the unique ID number of each sample and the corresponding diagnosis (M=malignant, B=benign). Columns 3–32 contain 30 real-valued features computed from digitized images of a fine needle aspirate (FNA) of a breast mass. These features describe characteristics of the cell nuclei present in the image and can be used to build a model that predicts whether a tumor is benign or malignant.

The 30 features are derived from ten real-valued measurements computed for each cell nucleus. For each measurement the dataset reports the mean, the standard error (se) and the "worst" value (the mean of the three largest values), e.g. radius_mean, radius_se, radius_worst. The ten measurements are:

  1. radius (mean of distances from center to points on the perimeter)
  2. texture (standard deviation of gray-scale values)
  3. perimeter
  4. area
  5. smoothness (local variation in radius lengths)
  6. compactness (perimeter² / area - 1.0)
  7. concavity (severity of concave portions of the contour)
  8. concave points (number of concave portions of the contour)
  9. symmetry
  10. fractal dimension (“coastline approximation” - 1)
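The naming scheme can be spelled out in a few lines. This is only a sketch of how the 30 column names arise from the ten measurements; the exact spellings (for example the space in "concave points") follow the column names quoted elsewhere in this post.

```python
# The ten per-nucleus measurements, spelled as in the dataset's column names.
base_features = [
    "radius", "texture", "perimeter", "area", "smoothness",
    "compactness", "concavity", "concave points", "symmetry",
    "fractal_dimension",
]
# The three summary statistics reported for each measurement.
summary_stats = ["mean", "se", "worst"]

feature_columns = [
    f"{name}_{stat}" for stat in summary_stats for name in base_features
]

print(len(feature_columns))  # 30
print(feature_columns[:3])   # ['radius_mean', 'texture_mean', 'perimeter_mean']
```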

Data Exploration

Overview

The first five rows

Our dataset has 569 rows and 33 columns. Thirty of the columns are derived from the 10 measurements listed above. The remaining three are ‘id’, ‘diagnosis’ and ‘Unnamed: 32’. Our target column is diagnosis.

The diagnosis column has two unique values:

M=malignant

B=benign

The number of benign diagnoses is 357.

The number of malignant diagnoses is 212.

All of the columns in the dataset are numeric except the target variable.

Our dataset doesn’t have any missing values, except for the column “Unnamed: 32”, which has no values at all. We dropped both the Unnamed: 32 and id columns.
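The cleaning step might look like the sketch below. It runs on a tiny made-up frame that only mimics the column layout of the real CSV, so the values and the variable names df and clean are illustrative assumptions.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the real CSV: same kind of columns, made-up values.
df = pd.DataFrame({
    "id": [101, 102, 103, 104],
    "diagnosis": ["M", "B", "B", "M"],
    "radius_mean": [17.9, 11.4, 12.1, 20.3],
    "texture_mean": [10.4, 17.8, 15.2, 21.3],
    "Unnamed: 32": [np.nan] * 4,
})

print(df.shape)                        # rows and columns
print(df["diagnosis"].value_counts())  # class balance
print(df.isna().sum())                 # missing values per column

# Drop the identifier and the column that is entirely empty.
clean = df.drop(columns=["id", "Unnamed: 32"])
print(clean.isna().any().any())        # no missing values remain
```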

Pair scatter plot of the mean features

Correlation

Let’s look at the correlation between our numeric features.

Now let’s look at the correlation of our input features with diagnosis.

Correlation between diagnosis and the other features

As the plot above shows, the features most correlated with diagnosis are:

[‘perimeter_worst’, ‘area_worst’, ‘radius_worst’, ‘concave points_worst’, ‘concave points_mean’, ‘perimeter_mean’, ‘area_mean’, ‘radius_mean’, ‘area_se’]

The columns least correlated with diagnosis are:

[‘fractal_dimension_worst’, ‘fractal_dimension_se’, ‘texture_se’, ‘fractal_dimension_mean’, ‘symmetry_se’, ‘smoothness_se’]
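A ranking like the one above can be produced roughly as follows. The post does not show its own code for this step, so the 1/0 encoding of the target and the use of absolute correlation are assumptions, and the numbers are made up.

```python
import pandas as pd

# Made-up values; only the column names follow the real dataset.
df = pd.DataFrame({
    "diagnosis": ["M", "M", "M", "B", "B", "B"],
    "perimeter_worst": [150.0, 140.0, 160.0, 90.0, 85.0, 95.0],
    "texture_se": [1.2, 0.8, 1.0, 1.1, 0.9, 1.0],
})

# Encode the target numerically so it can enter a correlation: M -> 1, B -> 0.
encoded = df.assign(diagnosis=df["diagnosis"].map({"M": 1, "B": 0}))

# Absolute Pearson correlation of every feature with the target, largest first.
ranking = (
    encoded.corr()["diagnosis"]
    .drop("diagnosis")
    .abs()
    .sort_values(ascending=False)
)
print(ranking)
```

In this toy frame perimeter_worst separates the two classes cleanly and lands at the top, while texture_se has the same average in both classes and lands at the bottom.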

The graph above shows the distribution of each feature, arranged in decreasing order of correlation with diagnosis (i.e. the first row shows the distributions of the four features most correlated with diagnosis). In the top rows, the distributions for malignant and benign samples are very different. In the bottom rows, on the other hand, the two distributions become somewhat similar. This suggests that the most correlated features are important for classifying the diagnosis and are strongly associated with it.

Causal graph modeling using CausalNex

Feature selection

We have 30 input features, which is too many for causal graph modeling. So we selected the features most highly correlated with diagnosis: [‘perimeter_worst’, ‘area_worst’, ‘radius_worst’, ‘concave points_worst’, ‘concave points_mean’, ‘perimeter_mean’, ‘area_mean’, ‘radius_mean’, ‘area_se’]

Normalization

To create an effective causal graph model, we normalized and scaled all of the numeric columns.
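The post does not say which scaler was used. As one possible reading, the sketch below applies min-max scaling, which maps every column to the range 0 to 1; the choice of method and the sample values are assumptions.

```python
import pandas as pd

# Made-up values on very different scales.
df = pd.DataFrame({
    "area_mean": [400.0, 700.0, 1000.0],
    "radius_mean": [11.0, 14.0, 20.0],
})

# Min-max scaling: subtract each column's minimum, divide by its range.
scaled = (df - df.min()) / (df.max() - df.min())
print(scaled)
```

Putting all columns on a comparable scale matters here because the structure learning step produces edge weights that are later compared against a single threshold.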

Scaled and Normalized Data

How does CausalNex work?

First, you need to define a structural model or infer one from the data. A structural model is a graph whose edges indicate which nodes affect which other nodes. It can be defined by domain experts. For example, suppose we collaborate with a clinic and its doctors suggest that the number of pregnancies affects the diabetes outcome. The reason could be that women develop insulin resistance during pregnancy, and some cannot produce enough insulin to overcome it and so develop diabetes.

However, most of the time we are not lucky enough to have experts on hand, so inferring the structure from the data can be very helpful. To do so, CausalNex uses the NOTEARS algorithm (Non-combinatorial Optimization via Trace Exponential and Augmented lagRangian for Structure learning), published at the NIPS conference in 2018. The algorithm learns from the data how the nodes are connected to each other, in the form of a weighted adjacency matrix.
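The "trace exponential" in the name refers to the acyclicity measure from the NOTEARS paper, h(W) = tr(exp(W ∘ W)) - d, which is zero exactly when the weighted adjacency matrix W describes a graph without cycles. The sketch below evaluates it with a truncated power series on two made-up 3-node matrices. It illustrates the formula only and is not CausalNex's internal code; the function name and the number of series terms are assumptions.

```python
import numpy as np


def acyclicity(weights, terms=30):
    """Approximate h(W) = tr(exp(W * W)) - d with a truncated power series."""
    w = np.asarray(weights, dtype=float)
    d = w.shape[0]
    m = w * w            # elementwise square, so every entry is non-negative
    power = np.eye(d)    # running value of m^k / k!
    trace = 0.0
    for k in range(terms):
        trace += np.trace(power)
        power = power @ m / (k + 1)
    return trace - d


# a -> b -> c: no cycle, so h is zero.
dag = [[0.0, 1.5, 0.0],
       [0.0, 0.0, 0.7],
       [0.0, 0.0, 0.0]]

# Adding c -> a closes a cycle, so h becomes positive.
cyclic = [[0.0, 1.5, 0.0],
          [0.0, 0.0, 0.7],
          [0.5, 0.0, 0.0]]

h_dag = acyclicity(dag)
h_cycle = acyclicity(cyclic)
print(h_dag, h_cycle)
```

Because h is a smooth function of the weights, the search for a graph can be run as a continuous optimization with h(W) = 0 as a constraint, instead of enumerating graph structures.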

Now let’s build a causal structural model with CausalNex from the features most highly correlated with diagnosis.

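from IPython.display import Image  # assumed: the post does not show this import
from causalnex.structure.notears import from_pandas
from causalnex.plots import plot_structure, NODE_STYLE, EDGE_STYLE

def vis_sm(sm):
    # Render the structure model as an inline PNG image
    viz = plot_structure(sm, graph_attributes={"scale": "2.0", "size": 2.5},
                         all_node_attributes=NODE_STYLE.WEAK,
                         all_edge_attributes=EDGE_STYLE.WEAK)
    return Image(viz.draw(format='png'))

# Learn the structure with NOTEARS; diagnosis is not allowed to be a parent
sm = from_pandas(normal_df.iloc[:, :9], tabu_parent_nodes=['diagnosis'])
vis_sm(sm)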

The tabu_parent_nodes=[‘diagnosis’] argument adds a constraint that prohibits diagnosis from being a parent of any other node in the graph.

Our first causal graph looks like a mess, and we cannot infer much from it. Let’s reduce the number of edges by applying edge pruning, which removes every edge whose weight is below a defined threshold. I used it here to reduce the complexity of the graph.

# Drop weak edges, then redraw the graph
sm.remove_edges_below_threshold(0.8)
vis_sm(sm)
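Edge pruning can be illustrated without CausalNex. The sketch below uses a plain dictionary of made-up edge weights and assumes that the comparison is made on the absolute value of the weight, so that strong negative edges survive as well.

```python
# Hypothetical learned edge weights, keyed by (parent, child).
edges = {
    ("radius_worst", "diagnosis"): 1.4,
    ("area_mean", "diagnosis"): -0.95,
    ("perimeter_mean", "area_se"): 0.3,
    ("area_se", "radius_mean"): -0.05,
}


def prune(edge_weights, threshold):
    """Keep only the edges whose absolute weight reaches the threshold."""
    return {e: w for e, w in edge_weights.items() if abs(w) >= threshold}


pruned = prune(edges, 0.8)
print(sorted(pruned))
```

The threshold is a judgment call: a higher value gives a sparser, more readable graph but may discard real, weaker relationships.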

We can now see that six features directly affect diagnosis. concave_points_worst affects perimeter_worst and area_mean. radius_worst affects concave_points_mean, and concave_points_mean affects diagnosis.

Stability of causal graph

Checking the stability of a causal graph means checking whether the structure of the graph stays the same as the amount of data increases or decreases. The more stable a causal graph is, the better it captures the causal relationships in the data.

To check the stability of our causal graph, we built a causal graph from increasing fractions of the data and compared the resulting graphs using the Jaccard similarity index.

The Jaccard similarity index (sometimes called the Jaccard similarity coefficient) compares the members of two sets to see which are shared and which are distinct. It is a measure of similarity between the two sets, with a range from 0% to 100%. The higher the percentage, the more similar the two sets.

To calculate the Jaccard similarity, we used the following function:

def jaccard_similarity(g, h):
    # Edges shared by both graphs
    i = set(g).intersection(h)
    # Shared edges divided by the size of the union
    return round(len(i) / (len(g) + len(h) - len(i)), 3)
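As a quick worked example, the sketch below applies the same formula to two made-up edge lists. The node names are placeholders and the lists are not the graphs learned in this post.

```python
def jaccard_similarity(g, h):
    """Size of the intersection divided by the size of the union."""
    g, h = set(g), set(h)
    shared = g & h
    return round(len(shared) / (len(g) + len(h) - len(shared)), 3)


# Hypothetical edge lists from two graphs learned on different data fractions.
edges_a = [("a", "d"), ("b", "d"), ("c", "d"), ("a", "b"), ("b", "c"), ("c", "e")]
edges_b = edges_a + [("e", "d")]

# 6 shared edges out of 7 in the union
print(jaccard_similarity(edges_a, edges_b))  # 0.857
```

Two graphs that differ by a single edge out of seven therefore score about 86%, which gives a feel for what the percentages in this section mean.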

Causal graph with 60% of the data

# Keep the first 60% of the rows
portion = int(x_selected.shape[0] * .6)
x_portion = x_selected.head(portion)
sm2 = from_pandas(x_portion, tabu_parent_nodes=['diagnosis'])
sm2.remove_edges_below_threshold(0.8)
sm2 = sm2.get_largest_subgraph()
vis_sm(sm2)

Causal graph with 60% of the data

Causal graph with 70% of the data

# Keep the first 70% of the rows
portion = int(x_selected.shape[0] * .7)
x_portion = x_selected.head(portion)
sm3 = from_pandas(x_portion, tabu_parent_nodes=['diagnosis'])
sm3.remove_edges_below_threshold(0.8)
sm3 = sm3.get_largest_subgraph()
vis_sm(sm3)

Causal graph with 70% of the data

jaccard_similarity(sm2.edges, sm3.edges)

The Jaccard similarity between the causal graphs built from 60% and 70% of the data is 85.7%.

Causal graph with 80% of the data

# Keep the first 80% of the rows
portion = int(x_selected.shape[0] * .8)
x_portion = x_selected.head(portion)
sm4 = from_pandas(x_portion, tabu_parent_nodes=['diagnosis'])
sm4.remove_edges_below_threshold(0.8)
sm4 = sm4.get_largest_subgraph()
vis_sm(sm4)

Causal graph with 80% of the data

jaccard_similarity(sm3.edges, sm4.edges)

The Jaccard similarity between the causal graphs built from 70% and 80% of the data is 86.7%.

Causal graph with 100% of the data

In the same way, we calculated the Jaccard similarity between the graphs built from 90% and 100% of the data to be 83.3%. Since all of our Jaccard similarity indices are greater than 80%, we can say that our graph is stable.

Reducing a graph using Markov Blanket

In statistics and machine learning, when one wants to infer a random variable with a set of variables, usually a subset is enough, and other variables are useless. Such a subset that contains all the useful information is called a Markov blanket.
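In a directed graph, the Markov blanket of a node consists of its parents, its children, and the other parents of its children. The sketch below computes it from a plain edge list; the node names are placeholders and the function is an illustration, not CausalNex's implementation.

```python
def markov_blanket(edges, target):
    """Parents, children and co-parents of target in a directed graph."""
    parents = {u for u, v in edges if v == target}
    children = {v for u, v in edges if u == target}
    co_parents = {u for u, v in edges if v in children and u != target}
    return parents | children | co_parents


# Hypothetical graph: G -> P1 -> T <- P2, and T -> C <- S.
edges = [("G", "P1"), ("P1", "T"), ("P2", "T"), ("T", "C"), ("S", "C")]

blanket_nodes = markov_blanket(edges, "T")
print(sorted(blanket_nodes))  # ['C', 'P1', 'P2', 'S']
```

The grandparent G is left out because, once P1 is known, G carries no further information about T. When the target is not allowed to be a parent of any node, it has no children, so its blanket reduces to its direct parents.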

‘diagnosis’ is our variable of interest. We actually do not need all the nodes in the network but only the Markov Blanket of the target. To achieve that, we simply need to use the get_markov_blanket function from causalnex.

from causalnex.network import BayesianNetwork
from causalnex.utils.network_utils import get_markov_blanket
bn = BayesianNetwork(sm)
blanket = get_markov_blanket(bn, 'diagnosis')
edge_list = list(blanket.structure.edges)

After applying the Markov blanket, the number of input features is reduced to seven. The resulting causal graph is the following.

Fitting Bayesian Network

A Bayesian network is a probabilistic graphical model that represents the dependencies between variables and their joint distribution. It is a directed acyclic graph (DAG) in which nodes are random variables and edges are the causal connections between them; each node carries a conditional probability distribution given its parents. Once we have a structural model, we create our Bayesian network and fit the conditional probabilities.

CausalNex provides a BayesianNetwork implementation, which we use here. It expects all features to be discrete, so we discretized our features with the DecisionTreeSupervisedDiscretiserMethod provided by CausalNex.

from causalnex.discretiser.discretiser_strategy import (
    DecisionTreeSupervisedDiscretiserMethod,
)

# Fit a shallow decision tree per feature to find its cut points
tree_discretiser = DecisionTreeSupervisedDiscretiserMethod(
    mode='single',
    tree_params={'max_depth': 3, 'random_state': 27},
)
tree_discretiser.fit(
    feat_names=features,
    dataframe=x,
    target_continuous=True,
    target='diagnosis',
)

# Replace each continuous feature with its bin index
discretised_data = x_selected.copy()
for col in features:
    discretised_data[col] = tree_discretiser.transform(x_selected[[col]])

The resulting discretized data

After applying discretization on our data, all values of the features that were continuous are mapped to 8 unique values.
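The figure of 8 follows from the tree depth: a tree of depth 3 has at most 8 leaves, which means at most 7 split points and therefore 8 bins. The sketch below shows the mapping step with made-up split points; how values that fall exactly on a split point are assigned is an assumption here.

```python
from bisect import bisect_right

# Hypothetical split points for one scaled feature (7 cuts -> 8 bins).
thresholds = [0.10, 0.20, 0.30, 0.45, 0.60, 0.75, 0.90]


def to_bin(value, cuts):
    """Return the index of the interval that value falls into."""
    return bisect_right(cuts, value)


values = [0.05, 0.20, 0.33, 0.95]
bins = [to_bin(v, thresholds) for v in values]
print(bins)  # [0, 2, 3, 7]
```

In the supervised discretiser the split points are not chosen by hand as they are here; they are the thresholds a decision tree learns when it tries to predict the target from that single feature.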

The final step is to fit a Bayesian network to our data.

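from sklearn.model_selection import train_test_split

train, test = train_test_split(discretised_data, train_size=0.8, test_size=0.2, random_state=27)

bn = BayesianNetwork(blanket.structure)
# Node states come from the full data, so every state is known to the network
bn = bn.fit_node_states(discretised_data)
# Conditional probabilities are estimated from the training split only
bn = bn.fit_cpds(train, method="BayesianEstimator", bayes_prior="K2")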
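To give a feel for what fitting conditional probabilities involves, the sketch below estimates P(diagnosis | radius_bin) from a handful of made-up rows. It treats the K2 prior as one pseudo-count per state, which is how that prior is commonly defined; the function, column names and data are hypothetical and this is not CausalNex's code.

```python
from collections import Counter


def fit_cpd(rows, child, parent, child_states, pseudo_count=1):
    """Estimate P(child | parent) from counts plus a uniform pseudo-count."""
    counts = Counter((r[parent], r[child]) for r in rows)
    parent_states = sorted({r[parent] for r in rows})
    cpd = {}
    for p in parent_states:
        total = sum(counts[(p, c)] for c in child_states)
        total += pseudo_count * len(child_states)
        cpd[p] = {c: (counts[(p, c)] + pseudo_count) / total for c in child_states}
    return cpd


# Made-up discretised rows.
rows = [
    {"radius_bin": 0, "diagnosis": "B"},
    {"radius_bin": 0, "diagnosis": "B"},
    {"radius_bin": 0, "diagnosis": "B"},
    {"radius_bin": 1, "diagnosis": "M"},
    {"radius_bin": 1, "diagnosis": "M"},
    {"radius_bin": 1, "diagnosis": "B"},
]

cpd = fit_cpd(rows, child="diagnosis", parent="radius_bin", child_states=["B", "M"])
print(cpd)
```

Without the pseudo-count, bin 0 would assign probability zero to a malignant diagnosis simply because none was observed in six rows. With it, the estimate is 1/5, which keeps unseen combinations possible.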

Results

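from sklearn.metrics import accuracy_score, f1_score, precision_score, recall_score

# `true` and `pred` are not defined in the post; presumably the test-set labels and the model's predictions
print('Recall: {:.2f}'.format(recall_score(y_true=true, y_pred=pred)))
print('F1: {:.2f}'.format(f1_score(y_true=true, y_pred=pred)))
print('Accuracy: {:.2f}'.format(accuracy_score(y_true=true, y_pred=pred)))
print('Precision: {:.2f}'.format(precision_score(y_true=true, y_pred=pred)))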

The accuracy of our causal graph model is 86%.
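For readers who want to see what the four scores measure, here is a small sketch that computes them from raw labels. The labels are made up and do not reproduce the 86% above; the function name and the convention of returning 0.0 for an undefined ratio are assumptions.

```python
def classification_metrics(y_true, y_pred, positive=1):
    """Accuracy, precision, recall and F1 for a binary problem."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(t == positive and p == positive for t, p in pairs)
    fp = sum(t != positive and p == positive for t, p in pairs)
    fn = sum(t == positive and p != positive for t, p in pairs)
    tn = len(pairs) - tp - fp - fn
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    f1 = 2 * precision * recall / (precision + recall) if precision + recall else 0.0
    return {
        "accuracy": (tp + tn) / len(pairs),
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }


# Made-up labels: 1 = malignant, 0 = benign.
y_true = [1, 1, 1, 1, 0, 0, 0, 0, 0, 0]
y_pred = [1, 1, 1, 0, 0, 0, 0, 0, 0, 1]

metrics = classification_metrics(y_true, y_pred)
print(metrics)
```

In a diagnostic setting recall deserves particular attention, because a missed malignant case (a false negative) is usually more costly than a false alarm.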

Conclusion

For comparison, we trained a logistic regression model on the whole dataset with all features. The result is the following:

Although logistic regression performs better than the Bayesian network, the argument for using a Bayesian network is not performance but the ability to ask more questions and to better understand and quantify causality in your data.
