jax-skills

v0.1.0

High-performance numerical computing and machine learning workflows using JAX. Supports array operations, automatic differentiation, JIT compilation, and RNN-style scans.


Install

OpenClaw Prompt Flow

Install with OpenClaw

Best for remote or guided setup. Copy the exact prompt, then paste it into OpenClaw for wu-uk/jax-computing-basics-jax-skills.

Prompt preview: Install & Setup
Install the skill "jax-skills" (wu-uk/jax-computing-basics-jax-skills) from ClawHub.
Skill page: https://clawhub.ai/wu-uk/jax-computing-basics-jax-skills
Keep the work scoped to this skill only.
After install, inspect the skill metadata and help me finish setup.
Use only the metadata you can verify from ClawHub; do not invent missing requirements.
Ask before making any broader environment changes.

Command Line

CLI Commands

Use the direct CLI path if you want to install manually and keep every step visible.

OpenClaw CLI

Bare skill slug

openclaw skills install jax-computing-basics-jax-skills

ClawHub CLI

Package manager switcher

npx clawhub@latest install jax-computing-basics-jax-skills
Security Scan
VirusTotal
Benign
View report →
OpenClaw
Benign
high confidence
Purpose & Capability
The name, description, SKILL.md, and the included jax_skills.py functions (load/save, map_op, reduce_op, logistic_grad, rnn_scan, jit_run) match each other and are appropriate for JAX numerical and ML workflows. There are no unrelated credentials, binaries, or config paths requested. Minor metadata mismatch: SKILL.md claims a Proprietary LICENSE.txt but no LICENSE file is present in the manifest; source and homepage are unknown.
Instruction Scope
Runtime instructions stay within scope: they describe local array operations, reading/writing .npy/.npz files, using JAX primitives, and validating shapes. The SKILL.md does not instruct reading other system files, contacting external endpoints, or collecting unrelated data. Note: the doc mentions pickle/json as saving options but the provided save implementation only uses np.save (npy); this is an implementation/documentation inconsistency but not a security concern.
Install Mechanism
No install spec is provided (instruction-only with a small code file), so nothing is downloaded or written by an installer. This is low-risk from an install perspective. The package does depend on JAX and NumPy being present at runtime; the skill does not declare an install to fetch those.
Credentials
The skill requires no environment variables, credentials, or config paths. That is proportionate for the declared functionality (local numerical computing).
Persistence & Privilege
The skill is not forced-always (always:false) and does not request persistent system privileges. Model invocation is allowed (platform default) but the skill does not request elevated or cross-skill configuration changes.
Assessment
This skill appears internally consistent and limited to local JAX/NumPy operations. Before installing: ensure you trust the owner (source/homepage are missing) and have JAX/NumPy available; test it in a safe environment since it reads/writes local files (.npy/.npz) and could overwrite files if given paths you care about; verify the license (SKILL.md references LICENSE.txt which is not included); and if you want to avoid any autonomous actions, keep model-invocation disabled at the agent/platform level. The code contains no network calls, credential access, or other unexpected behaviors.


latest: vk9765hy9kzfhyw5bwfa00gpers84xfde
71 downloads
0 stars
1 version
Updated 1w ago
v0.1.0
MIT-0

Requirements for Outputs

General Guidelines

Arrays

  • All arrays MUST be compatible with JAX (jnp.array) or convertible from Python lists.
  • Use .npy or .npz for saving arrays (the bundled save() writes .npy only; the JSON/pickle options mentioned elsewhere are not implemented).

Operations

  • Validate input types and shapes for all functions.
  • Maintain numerical stability for all operations.
  • Provide meaningful error messages for unsupported operations or invalid inputs.

JAX Skills

1. Loading and Saving Arrays

load(path)

Description: Load a JAX-compatible array from a file. Supports .npy and .npz.
Parameters:

  • path (str): Path to the input file.

Returns: JAX array or dict of arrays if .npz.

import jax_skills as jx

arr = jx.load("data.npy")
arr_dict = jx.load("data.npz")

save(data, path)

Description: Save a JAX or NumPy-convertible array to .npy. Parameters:

  • data (array): Array to save.
  • path (str): File path to save.
jx.save(arr, "output.npy")

2. Map and Reduce Operations

map_op(array, op)

Description: Apply elementwise operations on an array using JAX vmap. Parameters:

  • array (array): Input array.
  • op (str): Operation name ("square" supported).
squared = jx.map_op(arr, "square")
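A minimal sketch of what map_op might look like internally, assuming a name-to-function table and jax.vmap over the leading axis (the table name `_ELEMENTWISE_OPS` is illustrative):

```python
import jax
import jax.numpy as jnp

_ELEMENTWISE_OPS = {"square": lambda x: x * x}

def map_op(array, op):
    """Apply a named elementwise op across the leading axis with vmap."""
    if op not in _ELEMENTWISE_OPS:
        raise ValueError(f"unsupported op: {op!r}")
    return jax.vmap(_ELEMENTWISE_OPS[op])(jnp.asarray(array))
```

For a purely elementwise op like "square", vmap is equivalent to applying the function directly, but the vmap structure generalizes cleanly to per-row operations.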

reduce_op(array, op, axis)

Description: Reduce array along a given axis. Parameters:

  • array (array): Input array.
  • op (str): Operation name ("mean" supported).
  • axis (int): Axis along which to reduce.
mean_vals = jx.reduce_op(arr, "mean", axis=0)
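Correspondingly, reduce_op can be sketched as a lookup into jnp reduction functions (again an assumption about the implementation, not a copy of it):

```python
import jax.numpy as jnp

_REDUCTIONS = {"mean": jnp.mean}

def reduce_op(array, op, axis):
    """Reduce the array along the given axis with a named jnp reduction."""
    if op not in _REDUCTIONS:
        raise ValueError(f"unsupported op: {op!r}")
    return _REDUCTIONS[op](jnp.asarray(array), axis=axis)
```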

3. Gradients and Optimization

logistic_grad(x, y, w)

Description: Compute the gradient of logistic loss with respect to weights. Parameters:

  • x (array): Input features.
  • y (array): Labels.
  • w (array): Weight vector.
grad_w = jx.logistic_grad(X_train, y_train, w_init)

Notes:

  • Uses jax.grad for automatic differentiation.
  • Logistic loss: mean(log(1 + exp(-y * (x @ w)))).
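Given the loss formula above, a sketch of the gradient computation with jax.grad (using logaddexp for the numerically stable form of log(1 + exp(-m)); the helper `logistic_loss` is an assumed internal name):

```python
import jax
import jax.numpy as jnp

def logistic_loss(w, x, y):
    """Mean logistic loss: mean(log(1 + exp(-y * (x @ w))))."""
    margins = y * (x @ w)
    # logaddexp(0, -m) == log(1 + exp(-m)), stable for large |m|
    return jnp.mean(jnp.logaddexp(0.0, -margins))

def logistic_grad(x, y, w):
    """Gradient of the loss w.r.t. the weights, via jax.grad."""
    return jax.grad(logistic_loss)(w, jnp.asarray(x), jnp.asarray(y))
```

jax.grad differentiates with respect to the first argument, which is why the weights come first in the loss signature.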

4. Recurrent Scan

rnn_scan(seq, Wx, Wh, b)

Description: Apply an RNN-style scan over a sequence using JAX lax.scan. Parameters:

  • seq (array): Input sequence.
  • Wx (array): Input-to-hidden weight matrix.
  • Wh (array): Hidden-to-hidden weight matrix.
  • b (array): Bias vector.
hseq = jx.rnn_scan(sequence, Wx, Wh, b)

Notes:

  • Returns sequence of hidden states.
  • Uses tanh activation.
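The shape conventions are not documented; a sketch assuming seq is (T, d_in), Wx is (d_in, d_hidden), Wh is (d_hidden, d_hidden), and the initial hidden state is zeros:

```python
import jax
import jax.numpy as jnp

def rnn_scan(seq, Wx, Wh, b):
    """Vanilla RNN: h_t = tanh(x_t @ Wx + h_{t-1} @ Wh + b), via lax.scan."""
    def step(h_prev, x_t):
        h_t = jnp.tanh(x_t @ Wx + h_prev @ Wh + b)
        return h_t, h_t  # carry the new state and also emit it as output
    h0 = jnp.zeros(Wh.shape[0])
    _, hseq = jax.lax.scan(step, h0, jnp.asarray(seq))
    return hseq
```

lax.scan keeps the whole recurrence inside one traced loop, so it compiles once regardless of sequence length.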

5. JIT Compilation

jit_run(fn, args)

Description: JIT compile and run a function using JAX. Parameters:

  • fn (callable): Function to compile.
  • args (tuple): Arguments for the function.
result = jx.jit_run(my_function, (arg1, arg2))

Notes:

  • Speeds up repeated function calls.
  • Input shapes must be consistent across calls.
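A minimal sketch of the wrapper, assuming it simply applies jax.jit and unpacks the argument tuple:

```python
import jax
import jax.numpy as jnp

def jit_run(fn, args):
    """JIT-compile fn with jax.jit and call it on the argument tuple."""
    return jax.jit(fn)(*args)

# Example: elementwise product then sum over [0, 1, 2] -> 0 + 1 + 4
result = jit_run(lambda a, b: (a * b).sum(), (jnp.arange(3.0), jnp.arange(3.0)))
```

Note that the shape-consistency caveat above exists because JAX retraces and recompiles whenever input shapes or dtypes change.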

Best Practices

  • Prefer JAX arrays (jnp.array) for all operations; convert to NumPy only when saving.
  • Avoid side effects inside functions passed to vmap or scan.
  • Validate input shapes for map_op, reduce_op, and rnn_scan.
  • Use JIT compilation (jit_run) for compute-heavy functions.
  • Save arrays using .npy (the format save() writes) to avoid system-specific issues.

Example Workflow

import jax.numpy as jnp
import jax_skills as jx

# Load array
arr = jx.load("data.npy")

# Square elements
arr2 = jx.map_op(arr, "square")

# Reduce along axis
mean_arr = jx.reduce_op(arr2, "mean", axis=0)

# Compute logistic gradient
grad_w = jx.logistic_grad(X_train, y_train, w_init)

# RNN scan
hseq = jx.rnn_scan(sequence, Wx, Wh, b)

# Save result
jx.save(hseq, "hseq.npy")

Notes

  • This skill set is designed for scientific computing, ML model prototyping, and dynamic array transformations.

  • Emphasizes JAX-native operations, automatic differentiation, and JIT compilation.

  • Avoid unnecessary conversions to NumPy; only convert when interacting with external file formats.
