Atualizando cross_validation para model_selection

  • AH Uyekita
  • Wednesday, Mar 20, 2019
  • Estimated reading time: 3 min

blog-image

As versões recentes de Scikit Learn (0.23.0) não possuem mais o módulo cross_validation, desta maneira algumas aplicações de Python podem ter o seu funcionamento afetado. O ajuste proposto nesse post é para que as aplicações que usem o train_test_split e StratifiedShuffleSplit voltem a funcionar corretamente.

Módulos Alterados

Alguns módulos foram realocados, por exemplo, o train_test_split que na versão 0.19.1 pertencia ao módulo cross_validation, ao passo que na versão 0.23.0 o train_test_split pertence ao model_selection, pois o cross_validation foi extinto.

train_test_split

O problema clássico encontrado quando se tenta usar o train_test_split numa versão atual do Scikit Learn é descrito abaixo:

# Erro quando não se encontra o módulo desejado.
ImportError: cannot import name cross_validation

A solução para esses problemas é a substituição de todas as ocorrências da importação do cross_validation pelo model_selection. Abaixo há os comparativos entre as versões e como elas devem ser substituídas.

# Será usado para importar o train_test_split.
from sklearn import cross_validation

# Antigamente o StratifiedShuffleSplit estava nesse módulo.
from sklearn.cross_validation import StratifiedShuffleSplit 

# O train_test_split foi remanejado para o model_selection.
from sklearn import model_selection

# Igualmente o train_test_split o StratifiedShuffleSplit também foi para o model_selection.
from sklearn.model_selection import StratifiedShuffleSplit

Para mariores informações acesse o site do Scikit Learn para conferir a documentação de cada módulo.



StratifiedShuffleSplit.

Além da realocação do módulo StratifiedShuffleSplit o seu comportamento foi ligeiramente alterado, nas versões anteriores ele era iterável, isto é, era possível usá-lo diretamente num for. Coloco abaixo um excerto de uma aplicação antiga anterior a versão 0.19.1.

# Até a versão 0.19.1 do Scikit Learn.
from sklearn.cross_validation import StratifiedShuffleSplit

## Dataset
# Features
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])

# Labels
y = np.array([0, 0, 0, 1, 1, 1])

# Criação do objeto sss.
sss = StratifiedShuffleSplit(y, n_splits=5, test_size=0.5, random_state=0)
     
# Iterando o objeto sss.
for train_index, test_index in sss:
   print("TRAIN:", train_index, "TEST:", test_index)
   X_train, X_test = X[train_index], X[test_index]
   y_train, y_test = y[train_index], y[test_index]

A construção acima funcionava até a versão 0.19.1, a partir dessa versão há esse problema quando se tenta usar o StratifiedShuffleSplit. O problema observado é descrito abaixo:

Traceback (most recent call last):
 File "tester.py", line 106, in <module>
   main()
 File "tester.py", line 103, in main
   test_classifier(clf, dataset, feature_list)
 File "tester.py", line 34, in test_classifier
   for train_idx, test_idx in cv:
TypeError: 'StratifiedShuffleSplit' object is not iterable

A maneira para corrigir esse problema foi com a utilização de dois métodos:

  • get_n_splits(), e;
  • split().

Baseado no exemplo exposto acima, atualiza-se esse código para ficar mais fácil a comparação.

# Para versões após 0.19.1 do Scikit Learn.
from sklearn.cross_validation import StratifiedShuffleSplit

## Dataset
# Features
X = np.array([[1, 2], [3, 4], [1, 2], [3, 4], [1, 2], [3, 4]])

# Labels
y = np.array([0, 0, 0, 1, 1, 1])

# Criação do objeto sss.
sss = StratifiedShuffleSplit(y, n_splits=5, test_size=0.5, random_state=0)

# Uso do método get_n_splits.
sss.get_n_splits(X, y)

# Iterando o objeto sss com o auxilio do split.
for train_index, test_index in sss.split(X, y):
   print("TRAIN:", train_index, "TEST:", test_index)
   X_train, X_test = X[train_index], X[test_index]
   y_train, y_test = y[train_index], y[test_index]