# Wrappers for the Scikit-Learn API
You can use `Sequential` Keras models (single-input only) as part of your Scikit-Learn workflow via the wrappers in the `keras.wrappers.scikit_learn` module.
There are two wrappers available:

- `keras.wrappers.scikit_learn.KerasClassifier(build_fn=None, **sk_params)`, which implements the Scikit-Learn classifier interface;
- `keras.wrappers.scikit_learn.KerasRegressor(build_fn=None, **sk_params)`, which implements the Scikit-Learn regressor interface.
### Arguments
- __build_fn__: callable function or class instance
- __sk_params__: model parameters & fitting parameters

`build_fn` should construct, compile and return a Keras model, which
will then be used to fit and predict. One of the following
three values can be passed as `build_fn`:

1. A function
2. An instance of a class that implements the `__call__` method
3. `None`. This means you implement a class that inherits from either
`KerasClassifier` or `KerasRegressor`. The `__call__` method of the
present class will then be treated as the default `build_fn`.
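
To make the three options concrete, here is a minimal, Keras-free sketch of how a wrapper could pick which callable to use. `SketchWrapper`, `resolve_build_fn`, `build_small`, `Builder` and `MyWrapper` are all hypothetical names, and plain dicts stand in for compiled Keras models; this is an illustration of the dispatch, not the library's code.

```python
import types


class SketchWrapper:
    """Hypothetical stand-in for KerasClassifier/KerasRegressor.

    It only mimics how the build function is chosen; no Keras model is built.
    """

    def __init__(self, build_fn=None, **sk_params):
        self.build_fn = build_fn
        self.sk_params = sk_params

    def resolve_build_fn(self):
        if self.build_fn is None:
            # Case 3: a subclass's own __call__ acts as the default build_fn.
            return self.__call__
        if isinstance(self.build_fn, (types.FunctionType, types.MethodType)):
            # Case 1: a plain function (or bound method).
            return self.build_fn
        # Case 2: an instance of a class that implements __call__.
        return self.build_fn.__call__


def build_small(units=8):
    # Hypothetical build function; a dict stands in for a compiled model.
    return {"kind": "function", "units": units}


class Builder:
    def __call__(self, units=16):
        return {"kind": "instance", "units": units}


class MyWrapper(SketchWrapper):
    def __call__(self, units=32):
        return {"kind": "subclass", "units": units}


print(SketchWrapper(build_fn=build_small).resolve_build_fn()())
# {'kind': 'function', 'units': 8}
print(SketchWrapper(build_fn=Builder()).resolve_build_fn()())
# {'kind': 'instance', 'units': 16}
print(MyWrapper().resolve_build_fn()())
# {'kind': 'subclass', 'units': 32}
```

In the real wrappers the resolved callable would then be invoked with the matching entries of `sk_params` and would return a compiled Keras model.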

`sk_params` takes both model parameters and fitting parameters. Legal model
parameters are the arguments of `build_fn`. Note that, like all other
estimators in scikit-learn, `build_fn` should provide default values for
its arguments, so that you can create the estimator without passing any
values to `sk_params`.
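
The rule that legal model parameters are exactly the arguments of `build_fn` can be sketched by filtering `sk_params` against the function's signature. The helper `filter_params` and the function `build_model` below are hypothetical; the sketch assumes signature-based filtering in the spirit of the wrappers, and a dict stands in for the Keras model.

```python
import inspect


def build_model(hidden_units=32, dropout=0.5):
    # Hypothetical build function; returns a dict in place of a Keras model.
    return {"hidden_units": hidden_units, "dropout": dropout}


def filter_params(fn, params):
    """Keep only the entries of `params` that `fn` accepts as arguments."""
    accepted = inspect.signature(fn).parameters
    return {name: value for name, value in params.items() if name in accepted}


sk_params = {"hidden_units": 64, "epochs": 10, "batch_size": 32}
model = build_model(**filter_params(build_model, sk_params))
print(model)  # {'hidden_units': 64, 'dropout': 0.5}

# Because every argument has a default, an empty sk_params still works.
print(build_model(**filter_params(build_model, {})))  # {'hidden_units': 32, 'dropout': 0.5}
```

`epochs` and `batch_size` are ignored here because `build_model` does not accept them; they are fitting parameters rather than model parameters.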

`sk_params` can also accept parameters for calling the `fit`, `predict`,
`predict_proba`, and `score` methods (e.g., `epochs`, `batch_size`).
Fitting (predicting) parameters are selected in the following order:

1. Values passed as keyword arguments to the
`fit`, `predict`, `predict_proba`, and `score` methods
2. Values passed to `sk_params`
3. The default values of the `keras.models.Sequential`
`fit`, `predict`, `predict_proba`, and `evaluate` methods (`score` relies on `evaluate`)
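
The precedence order can be sketched as a dictionary merge: start from the underlying method's defaults, overlay the matching `sk_params`, then overlay whatever was passed at call time. `keras_fit` is a hypothetical stand-in for `Sequential.fit` and `resolve_fit_kwargs` is a hypothetical helper; this assumes signature-based filtering and is not the wrapper's actual source.

```python
import inspect


def keras_fit(x, y, batch_size=32, epochs=1, verbose=1):
    # Hypothetical stand-in for Sequential.fit; it just echoes its settings.
    return {"batch_size": batch_size, "epochs": epochs, "verbose": verbose}


def resolve_fit_kwargs(fit_fn, sk_params, call_kwargs):
    """Merge fitting parameters: method defaults < sk_params < call-time kwargs."""
    accepted = inspect.signature(fit_fn).parameters
    # Level 2: only the sk_params entries that fit_fn understands.
    merged = {k: v for k, v in sk_params.items() if k in accepted}
    # Level 1: keyword arguments passed to fit() itself win.
    merged.update(call_kwargs)
    # Level 3: anything still missing falls back to fit_fn's own defaults.
    return merged


sk_params = {"hidden_units": 64, "epochs": 10, "batch_size": 16}
kwargs = resolve_fit_kwargs(keras_fit, sk_params, {"epochs": 3})
print(keras_fit(None, None, **kwargs))
# {'batch_size': 16, 'epochs': 3, 'verbose': 1}
```

Here `batch_size` comes from `sk_params`, `epochs` from the call-time override, and `verbose` from the method default.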

When using scikit-learn's grid search API (e.g., `GridSearchCV`), legal tunable parameters are
those you can pass to `sk_params`, including fitting parameters.
In other words, you can use grid search to find the best
`batch_size` or `epochs` as well as the best model parameters.
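
A minimal sketch of why fitting parameters are tunable: each grid point is just one `sk_params` dict, which is then split between the build function and the fit call by signature. Only the grid-expansion step is shown; scoring and cross-validation are omitted. `build_model`, `keras_fit`, `accepted` and `param_grid` are hypothetical, and `itertools.product` stands in for scikit-learn's own grid expansion.

```python
import inspect
import itertools


def build_model(hidden_units=32):
    # Hypothetical build function (model parameter: hidden_units).
    return {"hidden_units": hidden_units}


def keras_fit(batch_size=32, epochs=1):
    # Hypothetical stand-in for Sequential.fit (fitting parameters).
    return {"batch_size": batch_size, "epochs": epochs}


def accepted(fn, params):
    """Keep only the entries of `params` that `fn` accepts."""
    names = inspect.signature(fn).parameters
    return {k: v for k, v in params.items() if k in names}


param_grid = {"hidden_units": [16, 64], "batch_size": [32, 128], "epochs": [5]}

candidates = []
for values in itertools.product(*param_grid.values()):
    sk_params = dict(zip(param_grid.keys(), values))
    # Each grid point is one sk_params dict, routed to build_fn or fit by signature.
    candidates.append((accepted(build_model, sk_params), accepted(keras_fit, sk_params)))

for model_args, fit_args in candidates:
    print(model_args, fit_args)
```

The grid above yields four candidates, mixing one model parameter (`hidden_units`) with two fitting parameters (`batch_size`, `epochs`).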