function out = mlq( t, x, flag, arg)
% mlq: multiple linear-quadratic controllers
%
% State lives in the globals E (environment), T (trace buffers),
% A (agent parameters/state) and mm (struct array of local models).
%
% To be called by ode:
%   mlq( t, x, 'init');     % new trial: reset per-trial variables
%   mlq( t, x, '');         % one step: predict, weight modules, act, record
%   mlq( t, x, 'done');     % end of trial: visualize, re-solve controllers
%
% Direct calls (a char first argument is taken as the flag):
%   mlq( 'new', nm);        % create parameters and empty models
%   mlq( 'setup', x);       % set local models around the pivots
%   mlq( 'mlqc');           % re-solve the LQ controller of every module
%   mlq( 'talloc', s);      % grow the mlq-specific trace buffers by s rows
%
% NOTE(review): the module count is only taken from the 4th argument
% (nargin >= 4). With the documented two-argument form mlq( 'new', nm)
% nm is not used; the pivots returned by feval( E.fun, 'pivot') define
% the number of modules instead.
%
% 'out' is only assigned for flag '' (it returns E.stop).
%
% Linear dynamics model:
%   xdot = A*x + B*u + c = A*(x-xd) + B*u
%   c  = -A*xd,  i.e.  xd = -A\c
% Quadratic reward model:
%   r = q0 + q1'*x - x'*Q*x/2 - u'*R*u/2
%     = r0 - (x-xr)'*Q*(x-xr)/2 - u'*R*u/2
%   xr = Q\q1
%   r0 = q0 + xr'*Q*xr/2
% Value function:
%   V = v0 - (x-xv)'*P*(x-xv)/2
% Riccati eq.:
%   0 = P/tau - P*A - A'*P + P*B*R^-1*B'*P - Q
%   xv = (Q + P*A)\(q1 - P*c) = (Q + P*A)\(Q*xr + P*A*xd)
%   v0 = tau*(r0 - (xv-xr)'*Q*(xv-xr)/2)
% Optimal control:
%   u = -R^-1*B'*P*(x-xv)
% Equilibrium:
%   x0 = (A - B*R^-1*B'*P)\(A*xd - B*R^-1*B'*P*xv)

global E T A mm

if ischar( t)
    flag = t;                       % use first arg as flag
else
    if nargin<3, flag = ''; end     % ode standard
end

switch( flag)

case 'new'      % arg: number of models
    % standard parameters
    A.rate = [ 0, 0];       % learning rates: dynamic, reward
    A.noise = 0.01;         % noise amplitude
    A.tau = 10;             % discount time scale
    A.ntr = 0;              % training epochs
    A.ttr = 0;              % training time
    % specific parameters
    A.R = 0.01;             % action cost
    A.robust = 4;           % robust control: multiply R in Riccati
    if nargin >= 4          % number of modules
        A.nm = arg;
        % responsibility prior: exp( -((x-xp)*sp)^2)
        A.xp = zeros( E.ns, A.nm);
        A.sp = zeros( E.ns, A.nm);
    else
        [ A.xp, A.sp] = feval( E.fun, 'pivot');
        A.nm = size(A.xp,2);
    end
    A.col = hsv( A.nm+2);   % color codes
    A.sigma = 1.0;          % error variance
    A.tresp = 0.05;         % time const for responsibility
    % linear quadratic models
    for i=1:A.nm
        % linear models: xdot=A*x+B*u+c
        mm(i).A = zeros( E.ns, E.ns);
        % mm(i).A = 0.01*rand( E.ns, E.ns);
        mm(i).B = zeros( E.ns, E.ni);
        mm(i).c = zeros( E.ns,1);       % instead of xd
        % quadratic reward models: r=q0+q1*x-x'*Q*x/2-u'*R*u/2
        mm(i).Q = zeros( E.ns, E.ns);
        mm(i).R = A.R*eye( E.ni);
        mm(i).q1 = zeros( E.ns, 1);     % instead of xr
        mm(i).q0 = 0;
        % quadratic value functions: V=v0-(x-xv)'*P*(x-xv)/2
        mm(i).tau = A.tau;
        mm(i).P = zeros( E.ns, E.ns);
        mm(i).K = zeros( E.ni, E.ns);
        mm(i).xv = zeros( E.ns, 1);
        mm(i).v0 = 0;
        % least mean square: inverse covariances of the regressors
        mm(i).Cxu = eye( E.ns+E.ni+1);
        mm(i).Cxx = eye( E.ns^2+E.ns+1);
    end
    mlq( [], [], 'init');
    T.mdot=[]; T.mact=[]; T.resp=[]; T.pdot=[];

case 'init'     % new trial: called by ode
    % reset variables
    A.mxdot = zeros( E.ns, A.nm);       % modular prediction
    A.merr = zeros( E.ns, A.nm);        % modular prediction error
    A.mcerr = zeros( 1, A.nm);          % modular cum. prediction error
    A.cerr = 0;                         % error in a trial
    A.resp = repmat( 1/A.nm, 1, A.nm);  % responsibility
    A.xdot = A.mxdot*A.resp';
    A.mact = zeros( E.ni, A.nm);        % modular output
    A.act = A.mact*A.resp';
    E.input = max( E.imin, min( E.imax, A.act));    % limited
    mlq_vis( t, x, 'init');

case 'setup'    % setup local models around x(ni,nm)
    [ A.xp, A.sp] = feval( E.fun, 'pivot');
    for i = 1:A.nm
        [mm(i).A,mm(i).B,mm(i).c] = feval( E.fun, 'linear', A.xp(:,i));
        [mm(i).Q,mm(i).q1,mm(i).q0] = feval( E.fun, 'quadra', A.xp(:,i));
        mm(i).R = A.R*eye( E.ni);
        mm(i).tau = A.tau;
    end
    mlq_vis( 'pivot');
    mlq( 'mlqc');

case 'mlqc'     % update LQC
    for i=1:A.nm
        %disp( [i, mm(i).tau, eig(mm(i).Q)']);
        % Only the module index goes into the format: a surplus argument
        % would make sprintf re-apply the format and print the prefix twice.
        disp( [sprintf( 'Module %1d: eig(Q)=[ ', i), ...
               num2str(eig(mm(i).Q)'), ']']);
        % Robust RL: solve with an inflated action cost
        mm(i).R = A.R*eye( E.ni);
        RR = A.robust*mm(i).R;
        [ mm(i).K, mm(i).P, mm(i).xv, mm(i).v0,mm(i).tau] =...
            lqc( mm(i).A, mm(i).B, mm(i).c, mm(i).Q, RR,...
                 mm(i).q1, mm(i).q0, mm(i).tau);
        % K is ni-by-ns; flatten it row-wise so the concatenation below
        % also works when ni > 1 (identical output for ni == 1).
        disp( [sprintf( ' tau=%4.2f, K=[ ', mm(i).tau), ...
               num2str(reshape( mm(i).K', 1, [])), ']']);
    end
    mlq_vis( 'mat');

case ''
    %
    % Multiple points: process them one at a time
    %
    if length(t)>1
        disp( 'ODE step too large!');
        for s=1:length(t), out = mlq( t(s), x(:,s)); end
        return
    end
    %
    % Priors
    %
    dxp = repmat( x, 1, A.nm) - A.xp;
    dxp = torus( dxp, E.smin, E.smax, E.torus);
    % Sum over the state dimension explicitly; the default dimension
    % would collapse the 1-by-nm case (E.ns == 1) into a scalar.
    dist = sum( (dxp.*A.sp).^2, 1);
    A.prior = exp( -(dist-min(dist)));
    %
    % Predictions
    %
    for i = 1:A.nm
        mm(i).x = A.xp(:,i) + dxp(:,i);     % around xp
        A.mxdot(:,i) = mm(i).A*mm(i).x+mm(i).B*E.input+mm(i).c;
        %A.mrewardi(:,i) = mm(i).q0 + mm(i).q1'*x - x'*mm(i).Q*x/2;
    end
    %
    % Observation
    %
    E.xdot = feval( E.fun, t, x);       % assume direct measurability
    E.reward = feval( E.fun, t, x, 'reward');
    %
    % Likelihood
    %
    A.merr = A.mxdot - repmat(E.xdot,1,A.nm);
    % low-pass filtered squared error per module (time constant A.tresp)
    A.mcerr = A.mcerr + (-A.mcerr + sum(A.merr.^2, 1))*E.dt/A.tresp;
    A.like = exp( -(A.mcerr-min(A.mcerr))./A.sigma.^2);
    %
    % Posterior
    %
    resp = A.prior.*A.like;
    A.resp = resp./sum(resp);
    [l,A.symb] = max( A.resp);
    %
    % Weighted prediction
    %
    A.xdot = A.mxdot*A.resp';
    A.err = sum((A.xdot - E.xdot).^2);  % weighted error
    A.cerr = A.cerr + A.err*E.dt;       % cumulative error
    %
    % Update linear dynamic model
    %
    if A.rate(1)>0
        for i = 1:A.nm
            [dABc,mm(i).Cxu] = lms( [ mm(i).A, mm(i).B, mm(i).c], mm(i).Cxu,...
                E.xdot, [ mm(i).x; E.input; 1], A.resp(i));
            mm(i).A = mm(i).A + A.rate(1)*dABc(:,1:E.ns);
            mm(i).B = mm(i).B + A.rate(1)*dABc(:,E.ns+(1:E.ni));
            mm(i).c = mm(i).c + A.rate(1)*dABc(:,end);
        end
        A.ttr = A.ttr + E.dt;
    end
    %
    % Update quadratic reward model
    %
    if A.rate(2)>0
        for i = 1:A.nm
            xx = mm(i).x*mm(i).x';
            % The reward model is r = q0 + q1'*x - x'*Q*x/2, so the
            % regressor paired with Q(:) must be -xx(:)/2; then
            % Q(:)'*(-xx(:)/2) = -x'*Q*x/2 and the fitted coefficient is
            % the same Q that lqc consumes.
            [dQq,mm(i).Cxx] = lms( [ mm(i).Q(:)', mm(i).q1', mm(i).q0], mm(i).Cxx,...
                E.reward, [ -xx(:)/2; mm(i).x; 1], A.resp(i));
            dQ = reshape( dQq(1:E.ns^2), E.ns, E.ns);
            mm(i).Q = mm(i).Q + A.rate(2)*dQ;
            mm(i).q1 = mm(i).q1 + A.rate(2)*dQq(end-E.ns:end-1)';
            mm(i).q0 = mm(i).q0 + A.rate(2)*dQq(end);
        end
    end
    %
    % Action
    %
    for i = 1:A.nm
        dxv = mm(i).x - mm(i).xv;
        A.mact(:,i) = -mm(i).K*dxv;
    end
    % responsibility-weighted action plus exploration noise
    A.act = A.mact*A.resp' + A.noise/sqrt(E.dt)*randn(E.ni,1);
    E.input = max( E.imin, min( E.imax, A.act));    % limited
    %
    % Visualization
    %
    color = A.col(A.symb,:);
    if mod( t, E.vt) < E.dt
        feval( E.vis, t, x, '', color);     % Animation
    end
    out = E.stop;
    %
    % Record variables
    %
    T.i = T.i + 1;
    T.t(T.i) = t;
    T.stat(T.i,:) = x';
    T.act(T.i,:) = E.input';
    T.dot(T.i,:) = E.xdot';
    T.rew(T.i,:) = E.reward;
    T.col(T.i,:) = color;
    % mlq specific
    T.mdot(T.i,:) = reshape( A.mxdot', 1, A.nm*E.ns);   % stretch
    T.mact(T.i,:) = reshape( A.mact', 1, A.nm*E.ni);
    T.resp(T.i,:) = A.resp';
    T.pdot(T.i,:) = A.xdot';

case 'done'
    mlq_vis( t, x, 'done');
    mlq_vis( 'wave');
    if any( A.rate>0)       % update controllers
        mlq( 'mlqc');
        A.ntr = A.ntr + 1;
    end

case 'talloc'   % allocate trace buffer
    s = x;      % increment (number of rows to append)
    T.mdot = [ T.mdot; zeros(s,A.nm*E.ns)];
    T.mact = [ T.mact; zeros(s,A.nm*E.ni)];
    T.resp = [ T.resp; zeros(s,A.nm)];
    T.pdot = [ T.pdot; zeros(s,E.ns)];

otherwise
    error( [ ' invalid flag ', flag]);
end

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [K,P,xv,v0,tau] = lqc( A, B, c, Q, R, q1, q0, tau)
% Linear-quadratic controller
%
% Inputs:  A, B, c    linear dynamics model
%          Q, q1, q0  quadratic reward model
%          R          action cost
%          tau        discount time scale
% Outputs: K          feedback gain,  u = -K*(x-xv)
%          P, xv, v0  value function parameters
%          tau        time scale actually used (may be reduced, see below)
%
% Linear dynamics model:
%   xdot = A*(x-xd) + B*u = A*x + B*u + c
%   c = -A*xd
% Quadratic reward model:
%   r = r0 - (x-xr)'*Q*(x-xr)/2 - u'*R*u/2
%     = q0 + q1'*x - x'*Q*x/2 - u'*R*u/2
%   q0 = r0 - xr'*Q*xr/2
%   q1 = Q*xr
% Value function:
%   V = v0 - (x-xv)'*P*(x-xv)/2
% Riccati eq.:
%   P/tau = P*A + A'*P - P*B*R^-1*B'*P + Q
%   xv = (P*A + Q)\(-P*c + q1) = (P*A + Q)\(P*A*xd + Q*xr)
%   v0 = tau*(r0-(xv-xr)'*Q*(xv-xr)/2) = tau*(q0+q1'*xv-xv'*Q*xv/2)
% Optimal control:
%   u = -R^-1*B'*P*(x-xv)

[ns,ni] = size( B);
if min(real(eig(Q))) >= 0   % normal routine
    % Shifting A by -I/(2*tau) turns the standard Riccati equation
    % solved by lqr into the discounted one above.
    [K,P] = lqr( A-eye(ns)/(2*tau), B, Q, R);
else    % destabilizing controller
    % Fixed-point iteration on the discounted Riccati equation; the
    % time scale t is ramped up to tau over 100 steps, then held.
    P = 0;
    %a = 0.1;
    for t = [ (1:100)/100*tau, repmat( tau, 1, 100)]
        P = t*( P*A + A'*P - P*B*R^-1*B'*P + Q);
        %P = (eye(ns)/t - A')\( P*A - P*B*R^-1*B'*P + Q);
        %P = P + a*(-P + t*( P*A + A'*P - P*B*R^-1*B'*P + Q));
        % disp([t, max(abs(P(:)))]);
        if max(abs(P(:))) > 10000
            % diverging: retry with a shorter time scale
            tau = 0.8*tau;
            [K,P,xv,v0,tau] = lqc( A, B, c, Q, R, q1, q0, tau);
            return
        end
    end
    K = R\B'*P;
end
xv = (P*A + Q)\(-P*c + q1);
v0 = tau*(q0+q1'*xv-xv'*Q*xv/2);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function [ dA, C] = lms( A, C, y, x, w)
% least mean square regression
%   y = A*x
%   C: inverse covariance of x (updated and returned)
%   w: weight of the data
%   dA: weighted correction for A; the caller scales it by a learning rate
Cx = C*x;
C = C - w*Cx*Cx'/(1 + w*x'*Cx);
dA = w*( y - A*x)*x'*C;

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
function x = torus( x, min, max, flag)
% wrap x into [min, max) along the dimensions marked as a torus
%   x:        one column per point
%   min, max: column vectors of range limits, one entry per row of x
%   flag:     >0 for rows that wrap around; default: none
% Rows with flag <= 0 get a zero range; mod( a, 0) returns a, so those
% rows are returned unchanged (they are not clipped).
if nargin < 4, flag = zeros(size(min)); end
m = size(x,2);
mmin = repmat( min, 1, m);
mran = repmat( (max-min).*(flag>0), 1, m);
x = mmin + mod( x - mmin, mran);

%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%