    A Coding Implementation to Train Safety-Critical Reinforcement Learning Agents Offline Using Conservative Q-Learning with d3rlpy and Fixed Historical Data

    February 4, 2026 · 9 Mins Read


    In this tutorial, we build a safety-critical reinforcement learning pipeline that learns entirely from fixed, offline data rather than live exploration. We design a custom environment, generate a behavior dataset from a constrained policy, and then train both a Behavior Cloning baseline and a Conservative Q-Learning agent using d3rlpy. By structuring the workflow around offline datasets, careful evaluation, and conservative learning objectives, we demonstrate how robust decision-making policies can be trained in settings where unsafe exploration is not an option.

    !pip -q install -U "d3rlpy" "gymnasium" "numpy" "torch" "matplotlib" "scikit-learn"

    import os
    import time
    import random
    import inspect
    import numpy as np
    import matplotlib.pyplot as plt

    import gymnasium as gym
    from gymnasium import spaces

    import torch
    import d3rlpy

    # Fix every random source we use so runs are repeatable.
    SEED = 42
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    def pick_device():
        # Prefer the first CUDA GPU when one is available.
        if torch.cuda.is_available():
            return "cuda:0"
        return "cpu"

    DEVICE = pick_device()
    print("d3rlpy:", getattr(d3rlpy, "__version__", "unknown"), "| torch:", torch.__version__, "| device:", DEVICE)

    def make_config(cls, **kwargs):
        # Drop keyword arguments that this version's config class does not accept.
        sig = inspect.signature(cls.__init__)
        allowed = set(sig.parameters.keys())
        allowed.discard("self")
        filtered = {k: v for k, v in kwargs.items() if k in allowed}
        return cls(**filtered)

    We set up the environment by installing dependencies, importing libraries, and fixing the random seeds of Python, NumPy, and PyTorch for reproducibility. We select the first CUDA GPU when one is available and fall back to the CPU otherwise. We also define make_config, a helper that passes a config class only the keyword arguments its constructor accepts, so the same code runs across different d3rlpy versions. Note that unsupported arguments are dropped silently.
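
To make the filtering behavior of make_config concrete, here is a minimal sketch that runs without d3rlpy. ToyConfig and dropped_kwargs are hypothetical names introduced only for this illustration; ToyConfig stands in for a real config class, and dropped_kwargs simply reports which arguments would be discarded.

```python
import inspect
from dataclasses import dataclass


@dataclass
class ToyConfig:
    # Hypothetical stand-in for a d3rlpy config class; not part of d3rlpy.
    learning_rate: float = 1e-3
    batch_size: int = 32


def make_config(cls, **kwargs):
    # Keep only the keyword arguments that cls.__init__ accepts.
    allowed = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return cls(**{k: v for k, v in kwargs.items() if k in allowed})


def dropped_kwargs(cls, **kwargs):
    # Hypothetical helper: list the keyword arguments make_config would discard.
    allowed = set(inspect.signature(cls.__init__).parameters) - {"self"}
    return sorted(k for k in kwargs if k not in allowed)


requested = dict(learning_rate=3e-4, batch_size=256, conservative_weight=6.0)
cfg = make_config(ToyConfig, **requested)
ignored = dropped_kwargs(ToyConfig, **requested)
print(cfg, "| ignored:", ignored)
```

Logging the ignored names, as in this sketch, is one way to notice when a hyperparameter we care about never reaches the algorithm.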

    class SafetyCriticalGridWorld(gym.Env):
        metadata = {"render_modes": []}

        def __init__(
            self,
            size=15,
            max_steps=80,
            hazard_coords=None,
            start=(0, 0),
            goal=None,
            slip_prob=0.05,
            seed=0,
        ):
            super().__init__()
            self.size = int(size)
            self.max_steps = int(max_steps)
            self.start = tuple(start)
            self.goal = tuple(goal) if goal is not None else (self.size - 1, self.size - 1)
            self.slip_prob = float(slip_prob)

            if hazard_coords is None:
                # Sample hazard cells away from the grid border.
                hz = set()
                rng = np.random.default_rng(seed)
                for _ in range(max(1, self.size // 2)):
                    x = rng.integers(2, self.size - 2)
                    y = rng.integers(2, self.size - 2)
                    hz.add((int(x), int(y)))
                self.hazards = hz
            else:
                self.hazards = set(tuple(x) for x in hazard_coords)

            # Actions: 0 = +y, 1 = +x, 2 = -y, 3 = -x.
            self.action_space = spaces.Discrete(4)
            self.observation_space = spaces.Box(low=0.0, high=float(self.size - 1), shape=(2,), dtype=np.float32)

            self._rng = np.random.default_rng(seed)
            self._pos = None
            self._t = 0

        def reset(self, *, seed=None, options=None):
            if seed is not None:
                self._rng = np.random.default_rng(seed)
            self._pos = [int(self.start[0]), int(self.start[1])]
            self._t = 0
            obs = np.array(self._pos, dtype=np.float32)
            return obs, {}

        def _clip(self):
            # Keep the agent inside the grid.
            self._pos[0] = int(np.clip(self._pos[0], 0, self.size - 1))
            self._pos[1] = int(np.clip(self._pos[1], 0, self.size - 1))

        def step(self, action):
            self._t += 1

            a = int(action)
            # With probability slip_prob, replace the chosen action with a random one.
            if self._rng.random() < self.slip_prob:
                a = int(self._rng.integers(0, 4))

            if a == 0:
                self._pos[1] += 1
            elif a == 1:
                self._pos[0] += 1
            elif a == 2:
                self._pos[1] -= 1
            elif a == 3:
                self._pos[0] -= 1

            self._clip()

            x, y = int(self._pos[0]), int(self._pos[1])
            terminated = False
            truncated = self._t >= self.max_steps

            # Default per-step cost.
            reward = -1.0

            if (x, y) in self.hazards:
                reward = -100.0
                terminated = True

            if (x, y) == self.goal:
                reward = +50.0
                terminated = True

            obs = np.array([x, y], dtype=np.float32)
            return obs, float(reward), terminated, truncated, {}

    We define a safety-critical GridWorld environment with hazards, terminal states, and stochastic transitions. Each step costs -1, entering a hazard cell costs -100 and ends the episode, and reaching the goal pays +50 and ends the episode. With probability slip_prob (0.05 by default) the chosen action is replaced by a random one, and episodes are truncated after max_steps steps.
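
The reward and termination rules can be sketched in isolation as a pure function. This is an illustration under stated assumptions: outcome is a hypothetical helper that mirrors the checks in step, and the single hazard at (3, 4) is made up for the example rather than taken from the sampled layout.

```python
def outcome(pos, hazards, goal, t, max_steps=80):
    """Return (reward, terminated, truncated) for landing on `pos` at step `t`."""
    reward, terminated = -1.0, False  # default per-step cost
    if pos in hazards:
        reward, terminated = -100.0, True
    if pos == goal:
        reward, terminated = 50.0, True
    truncated = t >= max_steps
    return reward, terminated, truncated


hazards = {(3, 4)}  # illustrative hazard, not the environment's sampled layout
goal = (14, 14)

# A slip-free path: 14 moves along the bottom row, then 14 moves up the last column.
path = [(x, 0) for x in range(1, 15)] + [(14, y) for y in range(1, 15)]
total = 0.0
for t, pos in enumerate(path, start=1):
    reward, terminated, truncated = outcome(pos, hazards, goal, t)
    total += reward
print("steps:", len(path), "| return:", total)
```

On a 15 x 15 grid the shortest route from (0, 0) to (14, 14) takes 28 moves, so a slip-free run along this path collects 27 step costs plus the goal reward.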

    def safe_behavior_policy(obs, env: SafetyCriticalGridWorld, epsilon=0.15):
        x, y = int(obs[0]), int(obs[1])
        gx, gy = env.goal

        # Actions that reduce the distance to the goal.
        preferred = []
        if gx > x:
            preferred.append(1)
        elif gx < x:
            preferred.append(3)
        if gy > y:
            preferred.append(0)
        elif gy < y:
            preferred.append(2)

        if len(preferred) == 0:
            preferred = [int(env._rng.integers(0, 4))]

        # Epsilon-random action; this can still step into a hazard.
        if env._rng.random() < epsilon:
            return int(env._rng.integers(0, 4))

        # Keep only preferred moves whose next cell is not a hazard.
        candidates = []
        for a in preferred:
            nx, ny = x, y
            if a == 0:
                ny += 1
            elif a == 1:
                nx += 1
            elif a == 2:
                ny -= 1
            elif a == 3:
                nx -= 1
            nx = int(np.clip(nx, 0, env.size - 1))
            ny = int(np.clip(ny, 0, env.size - 1))
            if (nx, ny) not in env.hazards:
                candidates.append(a)

        if len(candidates) == 0:
            return preferred[0]
        # Use the environment's generator (not the global random module) so that
        # reseeding the environment makes each episode reproducible.
        return int(env._rng.choice(candidates))

    def generate_offline_episodes(env, n_episodes=400, epsilon=0.20, seed=0):
        episodes = []
        for i in range(n_episodes):
            obs, _ = env.reset(seed=int(seed + i))
            obs_list = []
            act_list = []
            rew_list = []
            done_list = []

            done = False
            while not done:
                a = safe_behavior_policy(obs, env, epsilon=epsilon)
                nxt, r, terminated, truncated, _ = env.step(a)
                done = bool(terminated or truncated)

                obs_list.append(np.array(obs, dtype=np.float32))
                act_list.append(np.array([a], dtype=np.int64))
                rew_list.append(np.array([r], dtype=np.float32))
                # Note: time-limit truncation is stored as a terminal flag here.
                done_list.append(np.array([1.0 if done else 0.0], dtype=np.float32))

                obs = nxt

            episodes.append(
                {
                    "observations": np.stack(obs_list, axis=0),
                    "actions": np.stack(act_list, axis=0),
                    "rewards": np.stack(rew_list, axis=0),
                    "terminals": np.stack(done_list, axis=0),
                }
            )
        return episodes

    def build_mdpdataset(episodes):
        # Flatten all episodes; the terminal flags mark the episode boundaries.
        obs = np.concatenate([ep["observations"] for ep in episodes], axis=0).astype(np.float32)
        acts = np.concatenate([ep["actions"] for ep in episodes], axis=0).astype(np.int64)
        rews = np.concatenate([ep["rewards"] for ep in episodes], axis=0).astype(np.float32)
        terms = np.concatenate([ep["terminals"] for ep in episodes], axis=0).astype(np.float32)

        if hasattr(d3rlpy, "dataset") and hasattr(d3rlpy.dataset, "MDPDataset"):
            return d3rlpy.dataset.MDPDataset(observations=obs, actions=acts, rewards=rews, terminals=terms)

        raise RuntimeError("d3rlpy.dataset.MDPDataset not found. Upgrade d3rlpy.")

    We design a constrained behavior policy that moves toward the goal, skips preferred moves that would land on a hazard, and takes a uniformly random action with probability epsilon. Because of these random actions and the environment's slip probability, the dataset can still contain some hazard transitions. We roll out this policy to collect trajectories, structure them into episodes, and convert them into d3rlpy's MDPDataset format.
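
Since build_mdpdataset concatenates every episode into flat arrays, the terminal flags are the only record of where one episode ends and the next begins. The sketch below illustrates that idea in plain Python; split_by_terminals is a hypothetical helper written for this example and is not how d3rlpy implements episode splitting internally.

```python
def split_by_terminals(values, terminals):
    # Cut a flat sequence into episodes wherever the terminal flag is set.
    episodes, current = [], []
    for value, done in zip(values, terminals):
        current.append(value)
        if done:
            episodes.append(current)
            current = []
    if current:  # trailing steps without an end flag form a partial episode
        episodes.append(current)
    return episodes


# Two toy episodes: one reaches the goal, one ends in a hazard.
rewards = [-1.0, -1.0, 50.0, -1.0, -100.0]
terminals = [0.0, 0.0, 1.0, 0.0, 1.0]

episodes = split_by_terminals(rewards, terminals)
returns = [sum(ep) for ep in episodes]
print(episodes, returns)
```

A missing or misplaced flag would merge two trajectories into one, which is why the data-generation loop writes a flag on the last step of every episode.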

    def _get_episodes_from_dataset(dataset):
        if hasattr(dataset, "episodes") and dataset.episodes is not None:
            return dataset.episodes
        if hasattr(dataset, "get_episodes"):
            return dataset.get_episodes()
        raise AttributeError("Could not find episodes in dataset (d3rlpy version mismatch).")

    def _iter_all_observations(dataset):
        for ep in _get_episodes_from_dataset(dataset):
            obs = getattr(ep, "observations", None)
            if obs is None:
                continue
            for o in obs:
                yield o

    def _iter_all_transitions(dataset):
        for ep in _get_episodes_from_dataset(dataset):
            obs = getattr(ep, "observations", None)
            acts = getattr(ep, "actions", None)
            rews = getattr(ep, "rewards", None)
            if obs is None or acts is None:
                continue
            n = min(len(obs), len(acts))
            for i in range(n):
                o = obs[i]
                a = acts[i]
                r = rews[i] if rews is not None and i < len(rews) else None
                yield o, a, r

    def visualize_dataset(dataset, env, title="Offline Dataset"):
        # Count how often each grid cell appears in the dataset.
        state_visits = np.zeros((env.size, env.size), dtype=np.float32)
        for obs in _iter_all_observations(dataset):
            x, y = int(obs[0]), int(obs[1])
            x = int(np.clip(x, 0, env.size - 1))
            y = int(np.clip(y, 0, env.size - 1))
            state_visits[y, x] += 1

        plt.figure(figsize=(6, 5))
        plt.imshow(state_visits, origin="lower")
        plt.colorbar(label="Visits")
        plt.scatter([env.start[0]], [env.start[1]], marker="o", label="start")
        plt.scatter([env.goal[0]], [env.goal[1]], marker="*", label="goal")
        if len(env.hazards) > 0:
            hz = np.array(list(env.hazards), dtype=np.int32)
            plt.scatter(hz[:, 0], hz[:, 1], marker="x", label="hazards")
        plt.title(f"{title} - State visitation")
        plt.xlabel("x")
        plt.ylabel("y")
        plt.legend()
        plt.show()

        # Histogram of per-step rewards.
        rewards = []
        for _, _, r in _iter_all_transitions(dataset):
            if r is not None:
                rewards.append(float(r))
        if len(rewards) > 0:
            plt.figure(figsize=(6, 4))
            plt.hist(rewards, bins=60)
            plt.title(f"{title} - Reward distribution")
            plt.xlabel("reward")
            plt.ylabel("count")
            plt.show()

    We implement dataset utilities that iterate through episodes rather than assuming flat arrays. We plot a state-visitation heatmap, with the start, goal, and hazards overlaid, to see which cells the offline data covers and where it is biased. We also plot a histogram of rewards to inspect the learning signal available to the agent.
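
Alongside the heatmap, coverage can be summarized as a single number: the fraction of grid cells that appear at least once in the data. The following is a small sketch using only the standard library; visitation_summary is a hypothetical helper and the four observations are toy values, not output from the tutorial's dataset.

```python
from collections import Counter


def visitation_summary(observations, size):
    # Count visits per integer grid cell and the fraction of cells seen at least once.
    counts = Counter((int(x), int(y)) for x, y in observations)
    coverage = len(counts) / float(size * size)
    return counts, coverage


# Toy observations on a 3 x 3 grid (illustrative values only).
obs = [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0), (1.0, 0.0)]
counts, coverage = visitation_summary(obs, size=3)
print(dict(counts), round(coverage, 3))
```

Low coverage is a warning sign for offline learning, because a learned policy has no data to support its value estimates in cells the behavior policy never visited.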

    def rollout_eval(env, algo, n_episodes=25, seed=0):
        returns = []
        lengths = []
        hazard_hits = 0
        goal_hits = 0

        for i in range(n_episodes):
            obs, _ = env.reset(seed=seed + i)
            done = False
            total = 0.0
            steps = 0
            while not done:
                # predict expects a batch, so add a leading batch dimension.
                a = int(algo.predict(np.asarray(obs, dtype=np.float32)[None, ...])[0])
                obs, r, terminated, truncated, _ = env.step(a)
                total += float(r)
                steps += 1
                done = bool(terminated or truncated)
                if terminated:
                    x, y = int(obs[0]), int(obs[1])
                    if (x, y) in env.hazards:
                        hazard_hits += 1
                    if (x, y) == env.goal:
                        goal_hits += 1

            returns.append(total)
            lengths.append(steps)

        return {
            "return_mean": float(np.mean(returns)),
            "return_std": float(np.std(returns)),
            "len_mean": float(np.mean(lengths)),
            "hazard_rate": float(hazard_hits / max(1, n_episodes)),
            "goal_rate": float(goal_hits / max(1, n_episodes)),
            "returns": np.asarray(returns, dtype=np.float32),
        }

    def action_mismatch_rate_vs_data(dataset, algo, sample_obs=7000, seed=0):
        rng = np.random.default_rng(seed)
        obs_all = []
        act_all = []
        for o, a, _ in _iter_all_transitions(dataset):
            obs_all.append(np.asarray(o, dtype=np.float32))
            act_all.append(int(np.asarray(a).reshape(-1)[0]))
            if len(obs_all) >= 80_000:
                break

        obs_all = np.stack(obs_all, axis=0)
        act_all = np.asarray(act_all, dtype=np.int64)

        # Compare policy actions with logged actions on a random subset of states.
        idx = rng.choice(len(obs_all), size=min(sample_obs, len(obs_all)), replace=False)
        obs_probe = obs_all[idx]
        act_probe_data = act_all[idx]
        act_probe_pi = algo.predict(obs_probe).astype(np.int64)

        mismatch = (act_probe_pi != act_probe_data).astype(np.float32)
        return float(mismatch.mean())

    def create_discrete_bc(device):
        if hasattr(d3rlpy.algos, "DiscreteBCConfig"):
            cls = d3rlpy.algos.DiscreteBCConfig
            cfg = make_config(
                cls,
                learning_rate=3e-4,
                batch_size=256,
            )
            return cfg.create(device=device)
        if hasattr(d3rlpy.algos, "DiscreteBC"):
            return d3rlpy.algos.DiscreteBC()
        raise RuntimeError("DiscreteBC not available in this d3rlpy version.")

    def create_discrete_cql(device, conservative_weight=6.0):
        if hasattr(d3rlpy.algos, "DiscreteCQLConfig"):
            cls = d3rlpy.algos.DiscreteCQLConfig
            # make_config silently drops names the config does not accept.
            # Assumption: d3rlpy 2.x names the discrete CQL penalty weight `alpha`,
            # so we pass the weight under both names and keep whichever is supported.
            # Verify against the installed version.
            cfg = make_config(
                cls,
                learning_rate=3e-4,
                actor_learning_rate=3e-4,
                critic_learning_rate=3e-4,
                temp_learning_rate=3e-4,
                alpha_learning_rate=3e-4,
                batch_size=256,
                alpha=float(conservative_weight),
                conservative_weight=float(conservative_weight),
                n_action_samples=10,
                rollout_interval=0,
            )
            return cfg.create(device=device)
        if hasattr(d3rlpy.algos, "DiscreteCQL"):
            algo = d3rlpy.algos.DiscreteCQL()
            if hasattr(algo, "conservative_weight"):
                try:
                    algo.conservative_weight = float(conservative_weight)
                except Exception:
                    pass
            return algo
        raise RuntimeError("DiscreteCQL not available in this d3rlpy version.")

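    We define evaluation routines that run a fixed number of greedy rollouts instead of open-ended exploration. We compute the mean and standard deviation of returns, the mean episode length, and two safety metrics: the fraction of episodes that end in a hazard and the fraction that reach the goal. We also introduce a mismatch diagnostic that measures how often the learned policy picks a different action than the dataset did at the same state. Finally, we add factory functions that build the Discrete BC and Discrete CQL agents.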
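
The mismatch diagnostic reduces to a simple ratio, which the sketch below computes on hand-written action lists. mismatch_rate is a hypothetical, dependency-free counterpart of action_mismatch_rate_vs_data, and the two action lists are made-up values for illustration.

```python
def mismatch_rate(policy_actions, data_actions):
    # Fraction of probed states where the policy's action differs from the logged action.
    if len(policy_actions) != len(data_actions):
        raise ValueError("action lists must have the same length")
    if not data_actions:
        return 0.0
    differing = sum(1 for p, d in zip(policy_actions, data_actions) if p != d)
    return differing / len(data_actions)


# Made-up actions at four probed states.
policy_actions = [1, 1, 0, 3]
data_actions = [1, 0, 0, 2]
rate = mismatch_rate(policy_actions, data_actions)
print("mismatch rate:", rate)
```

A rate near 0 means the policy mostly imitates the logged behavior. Because the behavior policy itself takes random actions with probability epsilon, some mismatch is expected even for a good policy, so the number is best read as a relative comparison between agents.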

    def main():
        env = SafetyCriticalGridWorld(
            size=15,
            max_steps=80,
            slip_prob=0.05,
            seed=SEED,
        )

        # Collect the fixed offline dataset with the behavior policy.
        raw_eps = generate_offline_episodes(env, n_episodes=500, epsilon=0.22, seed=SEED)
        dataset = build_mdpdataset(raw_eps)

        print("dataset built:", type(dataset).__name__)
        visualize_dataset(dataset, env, title="Behavior Dataset (Offline)")

        bc = create_discrete_bc(DEVICE)
        cql = create_discrete_cql(DEVICE, conservative_weight=6.0)

        print("\nTraining Discrete BC (offline)...")
        t0 = time.time()
        bc.fit(
            dataset,
            n_steps=25_000,
            n_steps_per_epoch=2_500,
            experiment_name="grid_bc_offline",
        )
        print("BC train sec:", round(time.time() - t0, 2))

        print("\nTraining Discrete CQL (offline)...")
        t0 = time.time()
        cql.fit(
            dataset,
            n_steps=80_000,
            n_steps_per_epoch=8_000,
            experiment_name="grid_cql_offline",
        )
        print("CQL train sec:", round(time.time() - t0, 2))

        print("\nControlled online evaluation (small number of rollouts):")
        bc_metrics = rollout_eval(env, bc, n_episodes=30, seed=SEED + 1000)
        cql_metrics = rollout_eval(env, cql, n_episodes=30, seed=SEED + 2000)

        print("BC :", {k: v for k, v in bc_metrics.items() if k != "returns"})
        print("CQL:", {k: v for k, v in cql_metrics.items() if k != "returns"})

        print("\nOOD-ish diagnostic (policy action mismatch vs data action at same states):")
        bc_mismatch = action_mismatch_rate_vs_data(dataset, bc, sample_obs=7000, seed=SEED + 1)
        cql_mismatch = action_mismatch_rate_vs_data(dataset, cql, sample_obs=7000, seed=SEED + 2)
        print("BC mismatch rate :", bc_mismatch)
        print("CQL mismatch rate:", cql_mismatch)

        # Mean return with standard deviation as error bars.
        plt.figure(figsize=(6, 4))
        labels = ["BC", "CQL"]
        means = [bc_metrics["return_mean"], cql_metrics["return_mean"]]
        stds = [bc_metrics["return_std"], cql_metrics["return_std"]]
        plt.bar(labels, means, yerr=stds)
        plt.ylabel("Return")
        plt.title("Online Rollout Return (Controlled)")
        plt.show()

        # Sorted per-episode returns show the whole distribution, including the worst cases.
        plt.figure(figsize=(6, 4))
        plt.plot(np.sort(bc_metrics["returns"]), label="BC")
        plt.plot(np.sort(cql_metrics["returns"]), label="CQL")
        plt.xlabel("Episode (sorted)")
        plt.ylabel("Return")
        plt.title("Return Distribution (Sorted)")
        plt.legend()
        plt.show()

        # The /content path assumes a Colab-style runtime; change it when running elsewhere.
        out_dir = "/content/offline_rl_artifacts"
        os.makedirs(out_dir, exist_ok=True)
        bc_path = os.path.join(out_dir, "grid_bc_policy.pt")
        cql_path = os.path.join(out_dir, "grid_cql_policy.pt")

        if hasattr(bc, "save_policy"):
            bc.save_policy(bc_path)
            print("Saved BC policy:", bc_path)
        if hasattr(cql, "save_policy"):
            cql.save_policy(cql_path)
            print("Saved CQL policy:", cql_path)

        print("\nDone.")

    if __name__ == "__main__":
        main()

    We train both Behavior Cloning and Conservative Q-Learning agents purely from offline data: 25,000 gradient steps for BC and 80,000 for CQL. We compare them on 30 evaluation rollouts each and on the action-mismatch diagnostic, then plot the mean returns and the sorted per-episode returns. We finalize the workflow by saving both trained policies to disk.
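
To give some intuition for what "conservative" means here, the sketch below computes the regularizer commonly associated with discrete-action CQL: the log-sum-exp of the Q-values at a state minus the Q-value of the action that was actually logged. This is a simplified illustration with made-up Q-values, not d3rlpy's implementation; cql_penalty and logsumexp are hypothetical helpers, and how the library weights this term against the temporal-difference loss is not shown.

```python
import math


def logsumexp(values):
    # Numerically stable log(sum(exp(v))) over a list of floats.
    m = max(values)
    return m + math.log(sum(math.exp(v - m) for v in values))


def cql_penalty(q_values, data_action):
    # Penalty is small when the logged action already has the highest Q-value,
    # and large when some other (possibly unseen) action is valued far above it.
    return logsumexp(q_values) - q_values[data_action]


# Made-up Q-values for one state with four actions; the dataset took action 3.
in_distribution = cql_penalty([0.0, 0.0, 0.0, 5.0], data_action=3)
out_of_distribution = cql_penalty([5.0, 0.0, 0.0, 0.0], data_action=3)
print(round(in_distribution, 4), round(out_of_distribution, 4))
```

Minimizing this term pushes Q-values down for actions the dataset did not take and up for the ones it did, which is the mechanism that discourages the learned policy from preferring actions with no data behind them.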

    In conclusion, we showed how to compare Conservative Q-Learning with simple imitation when learning from historical data in safety-sensitive environments. The returns, hazard and goal rates from the controlled rollouts, and the action-mismatch rates together indicate whether conservatism reduces risky, out-of-distribution behavior on this task. Overall, we presented a complete, reproducible offline RL workflow that we can extend to more complex domains such as robotics, healthcare, or finance, where unsafe exploration is not acceptable.
