Abhishek Singh

Chapter 3: Building a Logistic Regression Model in PyTorch, Step by Step

Classification is one of the most widely used tasks in machine learning. By some estimates, more than 70% of data science problems are classification problems.

There are two categories of classification:

  • Binary classification (a two-class problem).

  • Multinomial classification (more than two classes are present in the target variable).

Logistic regression is one of the most basic algorithms for solving two-class classification problems.

Common examples include churn prediction and spam detection, among many others.

Earlier in this 42-day PyTorch series, we covered how to perform linear regression in PyTorch.

In this tutorial, we first cover the theory behind logistic regression and then implement it in PyTorch.

So, let us start with the definition of logistic regression.

Definition of Logistic regression

Logistic Regression is a supervised algorithm in machine learning that is used to predict the probability of a categorical response variable. In logistic regression, the predicted variable is a binary variable that contains data encoded as 1 (True) or 0 (False). In other words, the logistic regression model predicts P(Y=1) as a function of X.
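In symbols, assuming a weight vector w and a bias b (notation introduced here for illustration, not taken from the original), the model computes a linear score z and passes it through the logistic function σ:

```latex
z = w^\top x + b, \qquad
P(Y = 1 \mid X = x) = \sigma(z) = \frac{1}{1 + e^{-z}}, \qquad
P(Y = 0 \mid X = x) = 1 - \sigma(z)
```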

Properties of Logistic Regression:

  • The predicted variable in logistic regression follows a Bernoulli distribution.

  • Logistic regression uses maximum likelihood estimation to fit its parameters.

  • R-squared is not used to assess the model; instead, model fit is evaluated with measures such as concordance and the Kolmogorov-Smirnov (KS) statistic.
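Under the Bernoulli assumption, maximizing the likelihood is equivalent to minimizing the binary cross-entropy loss that is typically used when training logistic regression in PyTorch. The minimal sketch below checks this numerically. The probabilities and labels are made up purely for illustration.

```python
import torch
import torch.nn.functional as F

# Made-up predicted probabilities P(Y=1 | x) and true 0/1 labels
p = torch.tensor([0.9, 0.2, 0.7, 0.4])
y = torch.tensor([1.0, 0.0, 1.0, 1.0])

# Bernoulli log-likelihood of each label: y*log(p) + (1 - y)*log(1 - p)
log_likelihood = y * torch.log(p) + (1 - y) * torch.log(1 - p)

# Maximizing the mean log-likelihood == minimizing the mean negative log-likelihood
nll = -log_likelihood.mean()

# PyTorch's binary cross-entropy computes the same quantity
bce = F.binary_cross_entropy(p, y)
print(f"negative log-likelihood: {nll.item():.4f}, BCE: {bce.item():.4f}")
```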

Linear Regression Vs. Logistic Regression

How logistic regression works

Logistic regression is based on the logistic (sigmoid) function, σ(x) = 1 / (1 + e^(-x)), which outputs a value between 0 and 1 that can be interpreted as a probability. If we plot it, the resulting curve is S-shaped, like this.


As you can see from the graph above, the output of the logistic function always lies strictly between 0 and 1, so it can be read as a probability.
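A short sketch of the logistic function, written by hand and compared with PyTorch's built-in `torch.sigmoid`, makes the S-shape concrete: the output is 0.5 at x = 0, approaches 0 for large negative x, and approaches 1 for large positive x.

```python
import math
import torch

def logistic(x: float) -> float:
    """Logistic (sigmoid) function: 1 / (1 + e^(-x))."""
    return 1.0 / (1.0 + math.exp(-x))

xs = torch.linspace(-6, 6, steps=7)  # -6, -4, -2, 0, 2, 4, 6
manual = torch.tensor([logistic(v.item()) for v in xs])
builtin = torch.sigmoid(xs)

# Print a few points along the S-shaped curve
for x, p in zip(xs.tolist(), builtin.tolist()):
    print(f"x = {x:+.0f} -> sigma(x) = {p:.4f}")
```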

Let's take an example to make this clearer.

Suppose the x-axis denotes the number of goals scored by Lionel Messi and the y-axis denotes the probability of Barcelona winning the match. Let's also assume that the x-axis values range from 0 to 50. According to the S-curve, Barcelona has a higher probability of winning the match if Messi scores more than 3 goals. Similarly, Barcelona has a higher probability of losing the match if Messi scores fewer than 2 goals.
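The football example can be sketched in a few lines. The coefficients `w` and `b` below are hypothetical, chosen only so that the curve crosses 0.5 at 2.5 goals. They are not fitted to any real match data.

```python
import torch

# Hypothetical coefficients: P(win) = sigmoid(w * goals + b), with 0.5 reached at 2.5 goals
w, b = 2.0, -5.0

goals = torch.tensor([0.0, 1.0, 2.0, 3.0, 4.0])
p_win = torch.sigmoid(w * goals + b)

for g, p in zip(goals.tolist(), p_win.tolist()):
    print(f"{g:.0f} goals -> P(win) = {p:.3f}")
```

With these made-up numbers, 3 goals gives a win probability of about 0.73, while 1 goal gives about 0.05.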

Types of Logistic Regression


  • Binary Logistic Regression: The predicted variable has only two possible outcomes, such as cat or dog, or positive or negative.

  • Multinomial Logistic Regression: The predicted variable has three or more nominal categories, such as a dog's breed.

  • Ordinal Logistic Regression: The predicted variable has three or more ordered categories, such as education level ("High school", "Graduation", "Post-graduation", "Ph.D.").
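In PyTorch, the practical difference between the first two types is the size of the output layer and the loss function. The sketch below contrasts them on random inputs. The feature and class counts are arbitrary, and the ordinal case needs a specialized model, so it is not shown.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
n_features, n_classes, batch = 4, 3, 8
x = torch.randn(batch, n_features)

# Binary: one output (logit) per example; sigmoid turns it into P(Y=1)
binary_model = nn.Linear(n_features, 1)
y_binary = torch.randint(0, 2, (batch, 1)).float()
binary_loss = nn.BCEWithLogitsLoss()(binary_model(x), y_binary)  # applies sigmoid internally

# Multinomial: one logit per class; softmax turns them into class probabilities
multi_model = nn.Linear(n_features, n_classes)
y_multi = torch.randint(0, n_classes, (batch,))
multi_logits = multi_model(x)
multi_loss = nn.CrossEntropyLoss()(multi_logits, y_multi)  # applies log-softmax internally
probs = torch.softmax(multi_logits, dim=1)
```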

Where to use logistic regression

  1. We have a binary or dichotomous target variable.

  2. We have predictor X-variables that we think are related to the Y-variable.

Use Cases of logistic regression

Financial sector

In the financial industry, logistic regression is used to predict loan defaults, for credit scoring and loan distribution, and much more. Many large companies, such as Morgan Stanley, use these methods.

Medical sector

In the medical industry, logistic regression is used to predict whether a patient has diabetes. Other applications include breast cancer prediction and tumor prediction, among many more.

Telecommunication sector

In telecommunications, logistic regression is used to predict customer churn, so that companies can offer better plans to customers who are at risk of leaving.

Network security

In network security, logistic regression is used to predict whether a network packet was delivered successfully.

Implementation of logistic regression in PyTorch

The dataset comes from the UCI Machine Learning Repository and relates to economics (it is commonly known as the Adult, or Census Income, dataset). The classification goal is to predict whether a person's income exceeds $50K per year (<=50K or >50K). You can download the dataset from here.

The dataset contains census information about individuals. It includes 48,842 records and 15 columns.

Input variables

  1. age is a numeric variable.

  2. workclass: type of workclass (categorical: "Private", "Self-emp-not-inc", "Local-gov", "?", "State-gov", "Self-emp-inc", "Federal-gov", "Without-pay", "Never-worked").

  3. fnlwgt is a numeric variable.

  4. education: educational level (categorical: "HS-grad", "Some-college", "Bachelors", "Masters", "Assoc-voc", "11th", "Assoc-acdm", "10th", "7th-8th", "Prof-school", "9th", "12th", "Doctorate", "5th-6th", "1st-4th", "Preschool").

  5. educational-num is a numeric variable.

  6. marital-status: marital status (categorical: "Married-civ-spouse", "Never-married", "Divorced", "Separated", "Widowed", "Married-spouse-absent", "Married-AF-spouse").

  7. occupation: type of occupation (categorical: "Prof-specialty", "Craft-repair", "Exec-managerial", "Adm-clerical", "Sales", "Other-service", "Machine-op-inspct", "?", "Transport-moving", "Handlers-cleaners", "Farming-fishing", "Tech-support", "Protective-serv", "Priv-house-serv", "Armed-Forces").

  8. relationship: type of relationship (categorical: "Husband", "Not-in-family", "Own-child", "Unmarried", "Wife", "Other-relative").

  9. race: race (categorical: "White", "Black", "Asian-Pac-Islander", "Amer-Indian-Eskimo", "Other").

  10. gender: gender (categorical: "Male", "Female").

  11. capital-gain is a numeric variable.

  12. capital-loss is a numeric variable.

  13. hours-per-week is a numeric variable.

  14. native-country: native country (categorical: "United-States", "Mexico", "?", "Philippines", "Germany", "Puerto-Rico", "Canada", "El-Salvador", "India", "Cuba", "England", "China", "South", "Jamaica", "Italy", "Dominican-Republic", "Japan", "Guatemala", "Poland", "Vietnam", "Columbia", "Haiti", "Portugal", "Taiwan", "Iran", "Greece", "Nicaragua", "Peru", "Ecuador", "France", "Ireland", "Hong", "Thailand", "Cambodia", "Trinadad&Tobago", "Outlying-US(Guam-USVI-etc)", "Laos", "Yugoslavia", "Scotland", "Honduras", "Hungary", "Holand-Netherlands").

  15. income: income level (categorical: "<=50K", ">50K").
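In this dataset, "?" marks unknown values (for example in workclass, occupation, and native-country). The minimal loading sketch below, assuming pandas, converts "?" to NaN at read time. The helper name `load_adult` is hypothetical. So that the example runs without the real file, it reads two made-up rows in the same 15-column layout. With the downloaded file you would pass its path instead (for example a local `adult.csv`, also an assumption).

```python
import io
import pandas as pd

def load_adult(path_or_buffer) -> pd.DataFrame:
    # Hypothetical helper: treat "?" as missing and ignore spaces after commas
    return pd.read_csv(path_or_buffer, na_values="?", skipinitialspace=True)

# Two made-up rows (not real records) using the column layout described above
sample = io.StringIO(
    "age,workclass,fnlwgt,education,educational-num,marital-status,occupation,"
    "relationship,race,gender,capital-gain,capital-loss,hours-per-week,native-country,income\n"
    "30,Private,100000,HS-grad,9,Never-married,Sales,Not-in-family,White,Female,0,0,40,?,<=50K\n"
    "45,?,200000,Masters,14,Married-civ-spouse,?,Husband,Asian-Pac-Islander,Male,5000,0,50,India,>50K\n"
)
df = load_adult(sample)
print(df.isna().sum())
```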

Data exploration

Next, we write a function named plotPerColumnDistribution that plots the distribution of each column.
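The function name comes from the text, but its body is not shown, so the version below is only a guess at what it might look like. It assumes pandas and matplotlib and draws a histogram for numeric columns and a bar chart of value counts for categorical ones. The small demo frame is made up.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import pandas as pd

def plotPerColumnDistribution(df: pd.DataFrame, n_cols: int = 3):
    """Plot one histogram (numeric) or bar chart (categorical) per column."""
    columns = list(df.columns)
    n_rows = -(-len(columns) // n_cols)  # ceiling division
    fig, axes = plt.subplots(n_rows, n_cols, figsize=(5 * n_cols, 3 * n_rows), squeeze=False)
    for ax, col in zip(axes.flat, columns):
        series = df[col].dropna()
        if pd.api.types.is_numeric_dtype(series):
            ax.hist(series, bins=20)
        else:
            series.value_counts().plot.bar(ax=ax)
        ax.set_title(col)
    # Hide subplot slots that have no column assigned
    for ax in list(axes.flat)[len(columns):]:
        ax.set_visible(False)
    fig.tight_layout()
    return fig

demo = pd.DataFrame({"age": [25, 38, 28, 44, 18],
                     "gender": ["Male", "Male", "Female", "Male", "Female"]})
fig = plotPerColumnDistribution(demo, n_cols=2)
```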


Next steps

  1. Check for missing values in the columns.

  2. Encode all the categorical columns in the dataset.

  3. Define the feature and target variables, and then split the dataset into training and test sets.
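The three steps above can be sketched with pandas and torch alone. The original may well have used scikit-learn helpers instead; that is not shown in the text. Everything here is illustrative: the tiny frame is made up, dropping incomplete rows is just one way to handle missing values, and the 75/25 split ratio is an assumption.

```python
import pandas as pd
import torch

# Tiny made-up frame with a few of the census columns (not real data)
df = pd.DataFrame({
    "age": [25, 38, 28, 44, 18, 34, 29, 63],
    "workclass": ["Private", "Private", "Local-gov", None, "Private",
                  "Self-emp-inc", "Private", "Self-emp-not-inc"],
    "gender": ["Male", "Male", "Male", "Female", "Female", "Male", "Female", "Male"],
    "income": ["<=50K", "<=50K", ">50K", ">50K", "<=50K", ">50K", "<=50K", ">50K"],
})

# 1. Check for missing values in each column; this sketch simply drops incomplete rows
print(df.isna().sum())
df = df.dropna().copy()

# 2. Encode the target as 0/1 and the categorical features as integer codes
df["income"] = (df["income"] == ">50K").astype(int)
for col in ["workclass", "gender"]:
    df[col] = df[col].astype("category").cat.codes

# 3. Define features and target, then make a random 75/25 train/test split
X = torch.tensor(df.drop(columns="income").values, dtype=torch.float32)
y = torch.tensor(df["income"].values, dtype=torch.float32).unsqueeze(1)

generator = torch.Generator().manual_seed(0)
perm = torch.randperm(len(X), generator=generator)
n_train = int(0.75 * len(X))
X_train, X_test = X[perm[:n_train]], X[perm[n_train:]]
y_train, y_test = y[perm[:n_train]], y[perm[n_train:]]
```

Integer codes impose an arbitrary order on categories. One-hot encoding (for example with `pd.get_dummies`) is a common alternative.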

Next, we define the model class, which implements logistic regression.

In the __init__ constructor, we instantiate two nn.Linear modules. In the forward function, we accept a tensor of input data and must return a tensor of output data; we can use the modules defined in the constructor as well as arbitrary operations on tensors. (Older PyTorch code wraps tensors in Variable; since PyTorch 0.4, Variable has been merged into Tensor, so plain tensors are enough.)
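The text mentions two nn.Linear modules, but the author's code is not shown. The sketch below instead uses the textbook form of logistic regression: a single nn.Linear layer followed by a sigmoid. Two stacked linear layers with no non-linearity between them would reduce to a single linear map anyway. The input size of 14 assumes all 14 input variables are used as features.

```python
import torch
import torch.nn as nn

class LogisticRegression(nn.Module):
    def __init__(self, n_features: int):
        super().__init__()
        # A single linear layer computes z = w^T x + b
        self.linear = nn.Linear(n_features, 1)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # The sigmoid maps z to P(Y=1 | x), a value between 0 and 1
        return torch.sigmoid(self.linear(x))

torch.manual_seed(0)
model = LogisticRegression(n_features=14)
probs = model(torch.randn(5, 14))  # five random example rows
```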

Now we instantiate our model and construct the loss function and the optimizer.
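The text does not say which loss or optimizer was used. A common pairing for a sigmoid-output model is binary cross-entropy with stochastic gradient descent, sketched below. The learning rate is an illustrative choice, and `nn.Sequential` stands in for the model class.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Stand-in for the model class: one linear layer followed by a sigmoid
model = nn.Sequential(nn.Linear(14, 1), nn.Sigmoid())

# Binary cross-entropy expects probabilities, which the sigmoid output provides
criterion = nn.BCELoss()
# Plain stochastic gradient descent over the weight and bias
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
```

If the model returned raw scores (logits) instead of probabilities, `nn.BCEWithLogitsLoss` would be the numerically safer choice.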

Let us start the training loop.

We run 500 iterations. In each iteration, we compute the predicted y by passing x to the model, calculate the loss, zero the gradients, perform a backward pass, and update the weights.
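A minimal version of that loop is sketched below, assuming full-batch updates. To keep it self-contained, it trains on synthetic data generated from hypothetical true weights in place of the encoded census features. The learning rate of 0.1 is also an assumption.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Synthetic stand-in data: labels come from the sign of a hypothetical linear score
n_samples, n_features = 200, 3
X_train = torch.randn(n_samples, n_features)
true_w = torch.tensor([[1.5], [-2.0], [0.5]])
y_train = (X_train @ true_w > 0).float()

model = nn.Sequential(nn.Linear(n_features, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

losses = []
for step in range(500):
    y_pred = model(X_train)            # forward pass: predicted probabilities
    loss = criterion(y_pred, y_train)  # binary cross-entropy
    optimizer.zero_grad()              # zero the gradients from the previous step
    loss.backward()                    # backward pass: compute gradients
    optimizer.step()                   # update the weights
    losses.append(loss.item())
    if (step + 1) % 100 == 0:
        print(f"iteration {step + 1}: loss = {loss.item():.4f}")
```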

Finally, we test the model and make predictions. The accuracy we get is 76%.
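Testing usually means thresholding the predicted probabilities (0.5 is the conventional cut-off) and comparing the result with the true labels. The sketch below reuses the synthetic setup from a standalone training run, so its accuracy says nothing about the 76% reported above for the census data.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
# Synthetic stand-in data split into 200 training and 100 test rows
true_w = torch.tensor([[1.5], [-2.0], [0.5]])
X = torch.randn(300, 3)
y = (X @ true_w > 0).float()
X_train, y_train, X_test, y_test = X[:200], y[:200], X[200:], y[200:]

model = nn.Sequential(nn.Linear(3, 1), nn.Sigmoid())
criterion = nn.BCELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
for _ in range(500):
    optimizer.zero_grad()
    criterion(model(X_train), y_train).backward()
    optimizer.step()

# Evaluation: no gradients needed
model.eval()
with torch.no_grad():
    probs = model(X_test)
    y_hat = (probs >= 0.5).float()  # threshold the probabilities at 0.5
    accuracy = (y_hat == y_test).float().mean().item()
print(f"test accuracy: {accuracy:.2%}")
```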

Advantages and disadvantages of logistic regression


Advantages:

  • Logistic regression models are fast to train.

  • It works well on simple datasets.

  • A logistic regression model can also be extended to predict multiple classes.

  • It does not rely on the assumptions of ordinary least squares (OLS), such as normally distributed errors with constant variance.

  • It can handle polytomous data (more than two distinct categories).


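Disadvantages:

  • Because the model is non-linear in the probability, the effect of a one-unit change in a predictor on the predicted probability is not constant; it depends on the values of the other predictors.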

  • It sometimes fails to capture complex relationships between variables.

  • It requires large datasets for stable results.

  • It runs into problems when the distributions of the groups have little overlap: with (nearly) perfectly separated classes, the maximum likelihood estimates may fail to converge.

Wrap up the Session

We have made it to the end of the tutorial. You should now know:

  • what logistic regression is,

  • the properties of logistic regression,

  • the differences between linear regression and logistic regression,

  • the types of logistic regression,

  • where to use logistic regression,

  • logistic regression assumptions,

  • use cases of logistic regression, and

  • how to implement logistic regression in PyTorch.

You’ve come a long way in understanding one of the most important areas of machine learning! If you have questions or comments, then please put them in the comments section below.

You can also join our Telegram channel to get free cheat sheets, projects, e-books, and study material related to machine learning, deep learning, data science, natural language processing, Python programming, R programming, big data, and more.