[16/09/2022] El aprendizaje profundo está cambiando nuestras vidas en pequeñas y grandes formas cada día. Ya sea Siri o Alexa siguiendo nuestros comandos de voz, las aplicaciones de traducción en tiempo real en nuestros teléfonos, o la tecnología de visión por computadora que permite tractores inteligentes, robots de almacén y coches de auto-conducción, cada mes parece traer nuevos avances. Y casi todas estas aplicaciones de aprendizaje profundo están escritas en uno de estos tres frameworks: TensorFlow, PyTorch y JAX.
[Reciba lo último de CIO Perú suscribiéndose a nuestro newsletter semanal]
¿Cuál de estos frameworks de aprendizaje profundo debería utilizar? En este artículo, echaremos un vistazo comparativo de alto nivel a TensorFlow, PyTorch y JAX. Intentaremos darle una idea de los tipos de aplicaciones que se ajustan a sus puntos fuertes, así como considerar factores como el apoyo de la comunidad y la facilidad de uso.
¿Debería usar TensorFlow?
"Nunca despidieron a nadie por comprar IBM" fue el grito de guerra de la informática en los años 70 y 80, y lo mismo podría decirse del uso de TensorFlow en la década del 2010 para el aprendizaje profundo. ¿Sigue siendo TensorFlow competitivo en esta nueva década, siete años después de su lanzamiento inicial en el 2015?
Bueno, ciertamente. No es que TensorFlow se haya quedado parado durante todo ese tiempo. TensorFlow 1.x era todo sobre la construcción de gráficos estáticos de una manera muy poco Python, pero con la línea TensorFlow 2.x, también puede construir modelos utilizando el modo "eager" para la evaluación inmediata de las operaciones, haciendo que las cosas se sientan mucho más como PyTorch. A alto nivel, TensorFlow le da Keras para facilitar el desarrollo, y a bajo nivel, le da el compilador de optimización XLA (Accelerated Linear Algebra) para la velocidad. XLA hace maravillas para aumentar el rendimiento en las GPU, y es el principal método para aprovechar la potencia de las TPU (Unidades de Procesamiento de Tensores) de Google, que ofrecen un rendimiento sin precedentes para el entrenamiento de modelos a escala masiva.
Luego están todas las cosas que TensorFlow ha estado haciendo bien durante años. ¿Necesita servir modelos de una manera bien definida y repetible en una plataforma madura? TensorFlow Serving está ahí para usted. ¿Necesita reorientar sus despliegues de modelos para la web, o para la computación de baja potencia, como los teléfonos inteligentes, o para los dispositivos de recursos limitados como las cosas IoT? TensorFlow.js y TensorFlow Lite están muy maduros en este momento. Y, obviamente, teniendo en cuenta que Google todavía ejecuta el 100% de sus despliegues de producción utilizando TensorFlow, puede estar seguro de que TensorFlow puede manejar su escala.
Pero... bueno, ha habido una cierta falta de energía alrededor del proyecto que es un poco difícil de ignorar en estos días. La actualización de TensorFlow 1.x a TensorFlow 2.x fue, en una palabra, brutal. Algunas empresas se fijaron en el esfuerzo que suponía actualizar su código para que funcionara correctamente en la nueva versión principal, y decidieron en cambio portar su código a PyTorch. TensorFlow también perdió fuelle en la comunidad investigadora, que empezó a preferir la flexibilidad que ofrecía PyTorch hace unos años, lo que provocó un descenso en el uso de TensorFlow en los trabajos de investigación.
El asunto de Keras tampoco ha ayudado. Keras se convirtió en una parte integrada de las versiones de TensorFlow hace dos años, pero recientemente se ha vuelto a convertir en una biblioteca separada con su propio calendario de lanzamientos. Sin duda, la separación de Keras no es algo que afecte al día a día de los desarrolladores, pero un cambio tan notorio en una revisión menor del marco de trabajo no inspira confianza.
Dicho esto, TensorFlow es un marco fiable y alberga un amplio ecosistema de aprendizaje profundo. Puede crear aplicaciones y modelos en TensorFlow que funcionen a todas las escalas, y estará en muy buena compañía si lo hace. Pero puede que TensorFlow no sea su primera opción hoy en día.
¿Deberías usar PyTorch?
Ya no es el advenedizo que le pisa los talones a TensorFlow, PyTorch es una fuerza importante en el mundo del aprendizaje profundo hoy en día, quizás principalmente para la investigación, pero también en aplicaciones de producción cada vez más. Y con el modo eager habiéndose convertido en el método de desarrollo por defecto en TensorFlow, así como en PyTorch, el enfoque más pitónico que ofrece la diferenciación automática de PyTorch (autograd) parece haber ganado la guerra contra los gráficos estáticos.
A diferencia de TensorFlow, PyTorch no ha experimentado ninguna ruptura importante en el código del núcleo desde la desaprobación de la API Variable en la versión 0.4. (Anteriormente, Variable era necesaria para usar autograd con tensores; ahora todo es un tensor). Pero eso no quiere decir que no haya habido algunos errores aquí y allá. Por ejemplo, si ha estado usando PyTorch para entrenar a través de múltiples GPUs, es probable que se haya encontrado con las diferencias entre DataParallel y el nuevo DistributedDataParallel. Debería usar siempre DistributedDataParallel, pero DataParallel no está obsoleto.
Aunque PyTorch ha ido por detrás de TensorFlow y JAX en el soporte de XLA/TPU, la situación ha mejorado mucho a partir del 2022. PyTorch ahora tiene soporte para acceder a las VMs de TPU, así como el antiguo estilo de soporte de TPU Node, junto con una fácil implementación de línea de comandos para ejecutar su código en CPUs, GPUs o TPUs sin cambios en el código. Y si no quiere lidiar con parte del código repetitivo que PyTorch a menudo le hace escribir, puede recurrir a adiciones de mayor nivel como PyTorch Lightning, que le permite concentrarse en su trabajo real en lugar de reescribir bucles de entrenamiento. En el lado negativo, aunque se sigue trabajando en PyTorch Mobile, todavía está mucho menos maduro que TensorFlow Lite.
En términos de producción, PyTorch ahora tiene integraciones con plataformas agnósticas como Kubeflow, mientras que el proyecto TorchServe puede manejar los detalles de despliegue como el escalado, las métricas y la inferencia por lotes, dándole toda la bondad de MLOps en un pequeño paquete que es mantenido por los propios desarrolladores de PyTorch. ¿Escala PyTorch? Meta ha estado utilizando PyTorch en producción durante años, así que cualquiera que le diga que PyTorch no puede manejar cargas de trabajo a escala le está mintiendo. Sin embargo, se puede argumentar que PyTorch podría no ser tan amigable como JAX para las carreras de entrenamiento muy, muy grandes que requieren bancos de GPUs o TPUs.
Por último, está el elefante en la habitación. La popularidad de PyTorch en los últimos años está casi seguramente ligada al éxito de la biblioteca Transformers de Hugging Face. Sí, Transformers ahora también es compatible con TensorFlow y JAX, pero comenzó como un proyecto de PyTorch, y sigue estrechamente ligado a este marco. Con el auge de la arquitectura de Transformers, la flexibilidad de PyTorch para la investigación y la capacidad de obtener tantos modelos nuevos a los pocos días u horas de su publicación a través del centro de modelos de Hugging Face, es fácil ver por qué PyTorch se está poniendo de moda en todas partes estos días.
¿Debería usar JAX?
Si no le gusta TensorFlow, entonces Google podría tener algo más para usted. Más o menos, al menos. JAX es un marco de aprendizaje profundo construido, mantenido y utilizado por Google, pero no es oficialmente un producto de Google. Sin embargo, si mira los documentos y publicaciones de Google/DeepMind durante el último año, no puede evitar notar que mucha de la investigación de Google se ha trasladado a JAX. Así que JAX no es un producto "oficial" de Google, pero es lo que los investigadores de Google están utilizando para ampliar los límites.
¿Qué es exactamente JAX? Una forma fácil de pensar en JAX es la siguiente: Imagine una versión de NumPy acelerada por la GPU/TPU que puede, con un movimiento de varita, vectorizar mágicamente una función de Python, y manejar todos los cálculos derivados de dichas funciones. Por último, tiene un componente JIT (Just-In-Time) que toma su código y lo optimiza para el compilador XLA, lo que resulta en mejoras significativas de rendimiento sobre TensorFlow y PyTorch. He visto la ejecución de algún código aumentar su velocidad en cuatro o cinco veces simplemente reimplementándolo en JAX sin que se produzca ningún trabajo de optimización real.
Dado que JAX trabaja a nivel de NumPy, el código JAX está escrito a un nivel mucho más bajo que TensorFlow/Keras, y, sí, incluso PyTorch. Afortunadamente, hay un pequeño pero creciente ecosistema de proyectos circundantes que añaden bits adicionales. ¿Quiere bibliotecas de redes neuronales? Está Flax de Google, y Haiku de DeepMind (también de Google). Está Optax para todas sus necesidades de optimización, y PIX para el procesamiento de imágenes, y mucho más. Una vez que se trabaja con algo como Flax, la construcción de redes neuronales es relativamente fácil de manejar. Solo hay que tener en cuenta que todavía hay algunas asperezas. Los veteranos hablan mucho de cómo JAX maneja los números aleatorios de forma diferente a muchos otros marcos, por ejemplo.
¿Debería convertir todo en JAX y montarse en esa vanguardia? Bueno, tal vez, si está metido de lleno en la investigación que implica modelos a gran escala que requieren enormes recursos para entrenar. Los avances que JAX hace en áreas como el entrenamiento determinista, y otras situaciones que requieren miles de pods TPU, probablemente valen el cambio por sí mismos.
TensorFlow vs. PyTorch vs. JAX
¿Cuál es la conclusión, entonces? ¿Qué marco de aprendizaje profundo debería usar? Lamentablemente, no creo que haya una respuesta definitiva. Todo depende del tipo de problema en el que se trabaje, de la escala a la que se pretenda desplegar los modelos, e incluso de las plataformas de computación a las que se apunte.
Sin embargo, no creo que sea controvertido decir que, si está trabajando en los dominios del texto y la imagen, y está haciendo una investigación a pequeña o mediana escala con vistas a desplegar estos modelos en producción, entonces PyTorch es probablemente su mejor apuesta ahora mismo. Es la mejor opción en este ámbito.
Sin embargo, si necesita exprimir todo el rendimiento de los dispositivos de baja computación, entonces se dirigiría a TensorFlow con su sólido paquete TensorFlow Lite. Y en el otro extremo de la escala, si está trabajando en el entrenamiento de modelos que tienen decenas o cientos de miles de millones de parámetros o más, y los está entrenando principalmente con fines de investigación, entonces tal vez sea el momento de darle una vuelta a JAX.
Basado en el artículo de Ian Pointer (InfoWorld) y editado por CIO Perú