SKLearnWrapper

class baikal.sklearn.SKLearnWrapper(build_fn, **params)

Bases: object

Wrapper utility class that allows baikal models to be used with scikit-learn’s GridSearchCV API. It follows the style of Keras’ own wrapper.

A future release of baikal plans to remove this class and instead include a custom GridSearchCV API, based on the original scikit-learn implementation, that can handle baikal models natively.

Parameters
  • build_fn

    A function that takes no arguments and builds and returns a baikal Model.

    Note that parameters to tune are specified with a dictionary keyed by <step>__<parameter>, so every step you want to tune must be given a name when the model is built in this function.

  • params – Dictionary mapping parameter names to their values. Valid parameter names are ‘build_fn’ and any parameter the wrapped model can take (in the form <step>__<parameter>).
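The <step>__<parameter> key convention described above can be illustrated with a small standalone sketch. It does not use baikal itself: the helper `group_params_by_step` and the step names `pca` and `classifier` are hypothetical, chosen only to show how such keys split into a step name and a parameter name.

```python
from collections import defaultdict


def group_params_by_step(params):
    """Group '<step>__<parameter>' keys into {step: {parameter: value}}.

    Hypothetical helper for illustration only; it is not part of baikal.
    """
    grouped = defaultdict(dict)
    for key, value in params.items():
        if key == "build_fn":
            # 'build_fn' is a valid parameter name but does not refer to a step.
            continue
        # Split on the first double underscore only.
        step, sep, parameter = key.partition("__")
        if not sep or not step or not parameter:
            raise ValueError(f"Expected '<step>__<parameter>', got {key!r}")
        grouped[step][parameter] = value
    return dict(grouped)


# Assumed step names: these must match the names given to the steps in build_fn.
params = {
    "pca__n_components": 8,
    "classifier__C": 0.1,
    "classifier__penalty": "l2",
}
grouped = group_params_by_step(params)
```

Here `grouped` maps each step name to the parameters addressed to it, which is why a step without a name cannot be targeted by such a dictionary.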

Methods

fit(X[, y])

Fit wrapped model.

get_params([deep])

Get parameters for this estimator.

predict(X)

Predict with the wrapped model.

set_params(**params)

Set the parameters of this estimator.
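The get_params and set_params methods matter because scikit-learn's search utilities copy an estimator by reading its parameters and building a fresh instance from them. The sketch below shows that round trip with a hypothetical stand-in class, `ToyWrapper`; it is not baikal's implementation and only mirrors the constructor signature documented above.

```python
class ToyWrapper:
    """Hypothetical stand-in for illustration; not baikal's SKLearnWrapper."""

    def __init__(self, build_fn, **params):
        self.build_fn = build_fn
        self.params = dict(params)

    def get_params(self, deep=True):
        # Report 'build_fn' plus every '<step>__<parameter>' entry.
        return {"build_fn": self.build_fn, **self.params}

    def set_params(self, **params):
        for key, value in params.items():
            if key == "build_fn":
                self.build_fn = value
            else:
                self.params[key] = value
        # Returning self follows the scikit-learn estimator convention.
        return self


def build_fn():
    # Placeholder: a real build_fn would build and return a baikal Model.
    return object()


wrapper = ToyWrapper(build_fn, classifier__C=1.0)
wrapper.set_params(classifier__C=0.1, pca__n_components=4)

# Rebuild an equivalent wrapper from its parameters, as a clone would.
rebuilt = ToyWrapper(**wrapper.get_params())
```

The rebuilt object reports the same parameters as the original, which is the property a grid search relies on when it tries each parameter combination on a fresh copy.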

Attributes

model

Get the wrapped model.