Commit Graph

7 Commits

Author SHA1 Message Date
Francois Chollet
ca006ece74 Increase GRU + LSTM test coverage. 2023-05-16 13:25:48 -07:00
Francois Chollet
74b7675be6 Fix RNN masking 2023-05-16 13:12:40 -07:00
Francois Chollet
0dc71f879d Add JAX RNN support. 2023-05-16 10:52:14 -07:00
Neel Kovelamudi
2d40cb20b9 Adds CategoryEncoding layer, bincount op, and tests (#161)
* Adds unit normalization and tests

* Adds layer normalization and initial tests

* Fixes formatting in docstrings

* Fix type issues for JAX

* Fix nits

* Initial stash for group_normalization and spectral_normalization

* Adds spectral normalization and tests

* Adds group normalization and tests

* Formatting fixes

* Fix small nit in docstring

* Fix docstring and tests

* Adds RandomContrast and associated tests

* Remove arithmetic comment

* Adds RandomBrightness and tests

* Fix docstring and format

* Fix nits and add backend generator

* Inlines random_contrast helper

* Add bincount op

* Add CategoryEncoding layer and tests

* Fix formatting

* Fix JAX issues

* Fix JAX bincount

* Formatting and small fix

* Fix nits and docstrings

* Add args to bincount op test
2023-05-14 00:07:43 +00:00
Jonathan Bischof
ff0db1c03f Initialize expected modules and import torch early 2023-05-13 18:06:22 +00:00
Francois Chollet
cc1482d10f Add LSTM layer 2023-05-12 18:59:21 -07:00
Francois Chollet
42236e5d4e Add GRU layer. 2023-05-12 15:39:48 -07:00