Sparkify Churn Prediction

Millions of users stream their favorite songs on the Sparkify service every day. Sparkify offers both a free account and a premium account.

Problem introduction

Users can cancel their account at any time, which costs Sparkify income. The company needs me to predict which users are at risk of churning, that is, canceling their service. If I can accurately identify these users before they leave, the business can offer them discounts and incentives.

Customer churn is the percentage of customers that stopped using your company’s product or service during a certain time frame. You can calculate churn rate by dividing the number of customers you lost during that time period — say a quarter — by the number of customers you had at the beginning of that time period.
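The churn rate formula above can be sketched in a few lines of Python. The function name and the example numbers are made up for illustration; they are not Sparkify figures.

```python
def churn_rate(customers_at_start: int, customers_lost: int) -> float:
    """Fraction of customers lost during a period, relative to the starting count."""
    if customers_at_start <= 0:
        raise ValueError("customers_at_start must be positive")
    return customers_lost / customers_at_start


# Hypothetical quarter: 200 customers at the start, 10 of them cancel.
quarterly_rate = churn_rate(customers_at_start=200, customers_lost=10)
print(f"{quarterly_rate:.1%}")  # 5.0%
```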

Predicting churn is an essential task for companies whose business relies on service subscriptions. It is a supervised machine learning problem.

In the dataset, rows with a null userId are removed. A null userId means the user is a guest who has not created an account.

In this project I predict churn to identify the users who are at risk of canceling their service, focusing on customer behavior. I chose three algorithms: LogisticRegression, DecisionTreeClassifier, and RandomForestClassifier. All three are basic supervised machine learning algorithms for classification. For modeling, Spark provides CrossValidator to implement cross validation.

The definition of CrossValidator: K-fold cross validation performs model selection by splitting the dataset into a set of non-overlapping, randomly partitioned folds which are used as separate training and test datasets. For example, with k=3 folds, K-fold cross validation will generate 3 (training, test) dataset pairs, each of which uses 2/3 of the data for training and 1/3 for testing. Each fold is used as the test set exactly once.
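To make the fold mechanics concrete, here is a minimal plain-Python sketch of how k (training, test) pairs can be built. It only illustrates the idea and is not Spark's implementation; the function name and the seed are assumptions.

```python
import random


def k_fold_splits(n_rows, k, seed=42):
    """Return k (train_indices, test_indices) pairs over range(n_rows)."""
    indices = list(range(n_rows))
    random.Random(seed).shuffle(indices)
    # Deal the shuffled indices into k non-overlapping folds.
    folds = [indices[i::k] for i in range(k)]
    splits = []
    for i, test_fold in enumerate(folds):
        # Every fold except the i-th one goes into the training set.
        train = [idx for j, fold in enumerate(folds) if j != i for idx in fold]
        splits.append((train, test_fold))
    return splits


splits = k_fold_splits(n_rows=9, k=3)
for train, test in splits:
    print(len(train), len(test))  # prints "6 3" three times
```

With k=3, each pair trains on 2/3 of the rows and tests on the remaining 1/3, and every row lands in a test fold exactly once.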

The steps I will follow:

Load and clean the data

Explore the data

Engineer the features




The dataset, provided by Udacity, is 128 MB in size.

Evaluation Metrics

In statistical analysis of binary classification, the F-score or F-measure is a measure of a test’s accuracy. It is calculated from the precision and recall of the test, where the precision is the number of true positive results divided by the number of all positive results, including those not identified correctly, and the recall is the number of true positive results divided by the number of all samples that should have been identified as positive. Precision is also known as the positive predictive value, and recall is also known as sensitivity in diagnostic binary classification.

The F1 score is the harmonic mean of precision and recall. It ranges from 0.0 to 1.0, and 1.0 is the best possible value.

Figure 1. F1 function
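For reference, the standard definitions can be written out as follows, where TP, FP, and FN are the numbers of true positives, false positives, and false negatives:

```latex
\[
\text{precision} = \frac{TP}{TP + FP}, \qquad
\text{recall} = \frac{TP}{TP + FN}, \qquad
F_1 = 2 \cdot \frac{\text{precision} \cdot \text{recall}}{\text{precision} + \text{recall}}
\]
```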

I use the F1 score to evaluate the models and find out which one scores highest.

Load data

In the raw dataset, there are 286500 records and 18 columns.

Exploratory data and data engineering

Data exploration is the initial step in data analysis. It helps us uncover initial patterns, characteristics, and points of interest.

Figure 2. Columns

There are 286500 records and 18 fields in the dataset, and the fields fall into three groups. Four columns (artist, auth, song, length) describe music characteristics and are ignored from here on. Six columns (first name, gender, last name, userId, location, userAgent) describe the user. Of these six, I keep gender and userId. The distribution of gender is shown below.

Figure 3

For userId, I find 225 distinct users in the dataset.

Apart from the 10 columns above, there are 8 more (itemInSession, level, method, page, registration, sessionId, status, ts). The registration and ts columns are timestamps, and itemInSession and sessionId are also of no use for the prediction, so I ignore these four.

Besides page, that leaves three columns (level, method, status); the figures below show their distributions. A 404 status means that an error occurred during the request. Because 404 responses are so rare, the status column is heavily unbalanced, so I ignore it. The method column gives the type of HTTP request (GET or PUT), and I do not use it either.

Figure 3

To create a churn label, I look at the values of ‘page’ in detail.

Figure 4

The number of ‘Cancel’ events equals the number of ‘Cancellation Confirmation’ events, because users receive the cancellation confirmation right after they cancel. The two events therefore carry the same information, so I use ‘Cancellation Confirmation’ to create the label.
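The labeling rule can be sketched in plain Python as follows. This is an illustration under assumptions, not the project's Spark code: the function name, the (userId, page) tuple layout, and the sample events are all hypothetical.

```python
def label_churn(events, churn_pages=("Cancellation Confirmation",)):
    """Map each userId to 1 if the user ever visited a churn page, else 0."""
    labels = {}
    for user_id, page in events:
        churned = 1 if page in churn_pages else 0
        # Once a user is marked as churned, later events cannot undo it.
        labels[user_id] = max(labels.get(user_id, 0), churned)
    return labels


# Made-up event log of (userId, page) pairs.
events = [
    ("10", "NextSong"),
    ("10", "Cancel"),
    ("10", "Cancellation Confirmation"),
    ("20", "NextSong"),
    ("20", "Thumbs Up"),
]
print(label_churn(events))  # {'10': 1, '20': 0}
```

Keeping the churn pages as a parameter makes it easy to try a different definition of churn without touching the rest of the logic.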

These columns are categorical, so they need to be encoded before training. Example code for the ‘gender’ column:

from pyspark.sql.functions import udf
from pyspark.sql.types import IntegerType

# encode gender as an integer: 1 for "M", 0 otherwise
define_gender_new = udf(lambda x: 1 if x == "M" else 0, IntegerType())
# add the encoded column GenderNew
data_new = data_new.withColumn("GenderNew", define_gender_new("gender"))

As noted earlier, I focus on customer behavior for prediction. Here I create new features from the customer’s actions in the page column.

These records show the features after encoding.

Figure 5

The label in this dataset is unbalanced.

the amount of majority : 173.   the amount of minority : 52 
the ratio is 3

Class (label) 1 has only 52 records, while class 0 has 173. We can either undersample class 0 or oversample class 1; both are ways of resampling the dataset. Here I chose oversampling, which duplicates samples from the under-represented class until its count is roughly level with the dominant class.

After oversampling, there are 329 records.
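The counts above are consistent with repeating every minority row by the integer ratio 3: 52 × 3 = 156, and 173 + 156 = 329. A minimal plain-Python sketch of that arithmetic, assuming this is how the duplication was done (the function name and the stand-in rows are hypothetical):

```python
def oversample(majority, minority):
    """Repeat every minority row by the integer majority/minority ratio."""
    ratio = len(majority) // len(minority)
    return majority + minority * ratio, ratio


# Stand-in rows: only the class sizes (173 and 52) come from the text.
majority = [0] * 173
minority = [1] * 52
balanced, ratio = oversample(majority, minority)
print(ratio, len(balanced), balanced.count(1))  # 3 329 156
```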

The numerical variables are assembled with a VectorAssembler, which converts the feature columns into a single vector:

Figure 6


The data is randomly split into training, validation, and test sets: 70% for training, 15% for validation, and 15% for testing.

Churn prediction is a classification problem, so I choose three machine learning algorithms for modeling and compare them. Different transformers and estimators are combined in a pipeline. The code is shown below.
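A 70/15/15 split can be sketched in plain Python as shown here. This is only an illustration with exact cut points; in Spark, DataFrame.randomSplit takes weights and assigns rows at random, so the resulting sizes are only approximately proportional. The function name and the seed are assumptions.

```python
import random


def split_rows(rows, fractions=(0.70, 0.15, 0.15), seed=42):
    """Shuffle rows and cut them into train, validation, and test lists."""
    shuffled = list(rows)
    random.Random(seed).shuffle(shuffled)
    n_train = int(len(shuffled) * fractions[0])
    n_valid = int(len(shuffled) * fractions[1])
    train = shuffled[:n_train]
    valid = shuffled[n_train:n_train + n_valid]
    test = shuffled[n_train + n_valid:]  # the remainder goes to the test set
    return train, valid, test


# 329 is the record count after oversampling; the rows here are just indices.
train, valid, test = split_rows(range(329))
print(len(train), len(valid), len(test))  # 230 49 50
```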




Figure 7


The evaluation of three classifiers:

Figure 8

The decision tree and random forest classification models reached F1 scores of 0.597689 and 0.634263 on the validation data, respectively.

In the tuning stage, I trained the decision tree and random forest models again, using ParamGridBuilder with CrossValidator to find the best model.
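Conceptually, a parameter grid is every combination of the candidate values, and cross validation scores each combination so that the best one can be kept. The plain-Python sketch below shows only the grid expansion; it is not the Spark API. The helper name is made up, and the maxDepth values are hypothetical, while the numTrees candidates 10 and 17 are the ones compared in this post.

```python
from itertools import product


def build_param_grid(options):
    """Expand {param: [values]} into one dict per combination."""
    names = list(options)
    value_lists = [options[name] for name in names]
    return [dict(zip(names, values)) for values in product(*value_lists)]


# numTrees values come from the text; the maxDepth values are hypothetical.
grid = build_param_grid({"numTrees": [10, 17], "maxDepth": [5, 10]})
for params in grid:
    print(params)
print(len(grid))  # 4
```

With k folds, each of these combinations is trained and evaluated k times, so the grid size directly multiplies the training cost.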

Figure 9

The RandomForestClassifier with 17 trees performs better than the decision tree classifier. A decision tree combines a sequence of decisions, whereas a random forest combines several decision trees, which reduces the overfitting that a single decision tree is prone to. The disadvantage of a random forest is that it takes longer to train, but time cost is not important in this project. I therefore choose RandomForestClassifier.

For the random forest classifier, 17 trees worked better than 10. This model reaches an F1 score of 0.635848 on the test data.

I find that the model's performance is unstable across the validation and test data. In the future, the model needs more records for training.


When creating the churn label, I used only ‘Cancellation Confirmation’. The ‘Downgrade’ value of ‘page’ can also be considered. When a user decides to downgrade from the premium to the free account, I conjecture that the user is not satisfied with the premium service, so there is a risk of losing this customer. I need to identify such users before they decide to cancel their account.

Figure 9

The code above creates the new churn label.

Figure 10

The table above compares the models. With the new churn label, the random forest classifier (17 trees) performs better than the same classifier with the old churn label: the F1 scores are 0.843201 on the validation data and 0.915445 on the test data.

The improvement is substantial. The F1 score rises by roughly 21 percentage points on validation data and by roughly 44% in relative terms on test data (from 0.635848 to 0.915445).

The confusion matrix also gives a visual evaluation. I plot it as a heatmap with the Seaborn library.
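The size of the change can be checked directly from the scores reported in this post. A small sketch, assuming the baselines are the random forest scores quoted earlier for the old churn label (0.634263 on validation, 0.635848 on test):

```python
def improvement(old, new):
    """Return the absolute and relative change from old to new."""
    absolute = new - old
    return absolute, absolute / old


# Scores quoted in the text; pairing them as before/after is an assumption.
scores = [("validation", 0.634263, 0.843201), ("test", 0.635848, 0.915445)]
for name, old, new in scores:
    absolute, relative = improvement(old, new)
    print(f"{name}: +{absolute:.3f} absolute, +{relative:.1%} relative")
# validation: +0.209 absolute, +32.9% relative
# test: +0.280 absolute, +44.0% relative
```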

Figure 11

True positives (TP): the actual value is positive and the prediction is also positive.

True negatives (TN): the actual value is negative and the prediction is also negative.

False positives (FP): the actual value is negative but the prediction is positive. Also known as a Type I error.

False negatives (FN): the actual value is positive but the prediction is negative. Also known as a Type II error.
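These four counts are all that is needed to compute precision, recall, and the F1 score. Here is a minimal plain-Python sketch; the labels are made up for illustration and the function names are hypothetical.

```python
def confusion_counts(actual, predicted, positive=1):
    """Count TP, TN, FP, FN for a binary problem."""
    tp = tn = fp = fn = 0
    for a, p in zip(actual, predicted):
        if p == positive and a == positive:
            tp += 1
        elif p != positive and a != positive:
            tn += 1
        elif p == positive:
            fp += 1  # predicted positive, actually negative
        else:
            fn += 1  # predicted negative, actually positive
    return tp, tn, fp, fn


def f1_score(tp, fp, fn):
    """Harmonic mean of precision and recall; 0.0 when there are no true positives."""
    if tp == 0:
        return 0.0
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)


# Made-up labels for illustration only (1 = churned, 0 = stayed).
actual = [1, 1, 1, 0, 0, 0, 0, 1]
predicted = [1, 1, 0, 0, 0, 1, 0, 1]
tp, tn, fp, fn = confusion_counts(actual, predicted)
print(tp, tn, fp, fn)  # 3 3 1 1
print(round(f1_score(tp, fp, fn), 2))  # 0.75
```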

For the best model, the figure below shows the feature importances. The most important feature is sum_NextSong: more NextSong events mean that people stay on Sparkify and listen to more music.

Figure 12

The final model is a random forest classifier with 17 trees, trained on the data whose churn label is created from the ‘Cancellation Confirmation’ and ‘Downgrade’ actions of ‘page’.


In this project, I worked with the Sparkify dataset and predicted whether a user churns.

I first loaded and cleaned the data, then explored it and engineered features. In the modeling stage, I trained three models based on three different machine learning algorithms. The final RandomForestClassifier predicts user churn well and gives a good F1 score.

An interesting point in PySpark is that the features need to be converted to vectors, which is not required in scikit-learn.

If you’re interested in more info, please click the link, thank you.



