1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20 package org.apache.mina.filter.codec.demux;
21
import java.lang.reflect.InvocationTargetException;
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;

import org.apache.mina.core.session.AttributeKey;
import org.apache.mina.core.session.IoSession;
import org.apache.mina.core.session.UnknownMessageTypeException;
import org.apache.mina.filter.codec.ProtocolEncoder;
import org.apache.mina.filter.codec.ProtocolEncoderOutput;
import org.apache.mina.util.CopyOnWriteMap;
import org.apache.mina.util.IdentityHashSet;
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48 public class DemuxingProtocolEncoder implements ProtocolEncoder {
49
50 private final AttributeKey STATE = new AttributeKey(getClass(), "state");
51
52 @SuppressWarnings("rawtypes")
53 private final Map<Class<?>, MessageEncoderFactory> type2encoderFactory = new CopyOnWriteMap<Class<?>, MessageEncoderFactory>();
54
55 private static final Class<?>[] EMPTY_PARAMS = new Class[0];
56
57 public DemuxingProtocolEncoder() {
58
59 }
60
61 @SuppressWarnings({ "rawtypes", "unchecked" })
62 public void addMessageEncoder(Class<?> messageType, Class<? extends MessageEncoder> encoderClass) {
63 if (encoderClass == null) {
64 throw new IllegalArgumentException("encoderClass");
65 }
66
67 try {
68 encoderClass.getConstructor(EMPTY_PARAMS);
69 } catch (NoSuchMethodException e) {
70 throw new IllegalArgumentException("The specified class doesn't have a public default constructor.");
71 }
72
73 boolean registered = false;
74 if (MessageEncoder.class.isAssignableFrom(encoderClass)) {
75 addMessageEncoder(messageType, new DefaultConstructorMessageEncoderFactory(encoderClass));
76 registered = true;
77 }
78
79 if (!registered) {
80 throw new IllegalArgumentException("Unregisterable type: " + encoderClass);
81 }
82 }
83
84 @SuppressWarnings({ "unchecked", "rawtypes" })
85 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoder<? super T> encoder) {
86 addMessageEncoder(messageType, new SingletonMessageEncoderFactory(encoder));
87 }
88
89 public <T> void addMessageEncoder(Class<T> messageType, MessageEncoderFactory<? super T> factory) {
90 if (messageType == null) {
91 throw new IllegalArgumentException("messageType");
92 }
93
94 if (factory == null) {
95 throw new IllegalArgumentException("factory");
96 }
97
98 synchronized (type2encoderFactory) {
99 if (type2encoderFactory.containsKey(messageType)) {
100 throw new IllegalStateException("The specified message type (" + messageType.getName()
101 + ") is registered already.");
102 }
103
104 type2encoderFactory.put(messageType, factory);
105 }
106 }
107
108 @SuppressWarnings("rawtypes")
109 public void addMessageEncoder(Iterable<Class<?>> messageTypes, Class<? extends MessageEncoder> encoderClass) {
110 for (Class<?> messageType : messageTypes) {
111 addMessageEncoder(messageType, encoderClass);
112 }
113 }
114
115 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes, MessageEncoder<? super T> encoder) {
116 for (Class<? extends T> messageType : messageTypes) {
117 addMessageEncoder(messageType, encoder);
118 }
119 }
120
121 public <T> void addMessageEncoder(Iterable<Class<? extends T>> messageTypes,
122 MessageEncoderFactory<? super T> factory) {
123 for (Class<? extends T> messageType : messageTypes) {
124 addMessageEncoder(messageType, factory);
125 }
126 }
127
128
129
130
131 public void encode(IoSession session, Object message, ProtocolEncoderOutput out) throws Exception {
132 State state = getState(session);
133 MessageEncoder<Object> encoder = findEncoder(state, message.getClass());
134 if (encoder != null) {
135 encoder.encode(session, message, out);
136 } else {
137 throw new UnknownMessageTypeException("No message encoder found for message: " + message);
138 }
139 }
140
141 protected MessageEncoder<Object> findEncoder(State state, Class<?> type) {
142 return findEncoder(state, type, null);
143 }
144
145 @SuppressWarnings("unchecked")
146 private MessageEncoder<Object> findEncoder(State state, Class<?> type, Set<Class<?>> triedClasses) {
147 @SuppressWarnings("rawtypes")
148 MessageEncoder encoder = null;
149
150 if (triedClasses != null && triedClasses.contains(type)) {
151 return null;
152 }
153
154
155
156
157 encoder = state.findEncoderCache.get(type);
158
159 if (encoder != null) {
160 return encoder;
161 }
162
163
164
165
166 encoder = state.type2encoder.get(type);
167
168 if (encoder == null) {
169
170
171
172
173 if (triedClasses == null) {
174 triedClasses = new IdentityHashSet<Class<?>>();
175 }
176
177 triedClasses.add(type);
178
179 Class<?>[] interfaces = type.getInterfaces();
180
181 for (Class<?> element : interfaces) {
182 encoder = findEncoder(state, element, triedClasses);
183
184 if (encoder != null) {
185 break;
186 }
187 }
188 }
189
190 if (encoder == null) {
191
192
193
194
195
196 Class<?> superclass = type.getSuperclass();
197
198 if (superclass != null) {
199 encoder = findEncoder(state, superclass);
200 }
201 }
202
203
204
205
206
207
208 if (encoder != null) {
209 state.findEncoderCache.put(type, encoder);
210 MessageEncoder<Object> tmpEncoder = state.findEncoderCache.putIfAbsent(type, encoder);
211
212 if (tmpEncoder != null) {
213 encoder = tmpEncoder;
214 }
215 }
216
217 return encoder;
218 }
219
220
221
222
223 public void dispose(IoSession session) throws Exception {
224 session.removeAttribute(STATE);
225 }
226
227 private State getState(IoSession session) throws Exception {
228 State state = (State) session.getAttribute(STATE);
229 if (state == null) {
230 state = new State();
231 State oldState = (State) session.setAttributeIfAbsent(STATE, state);
232 if (oldState != null) {
233 state = oldState;
234 }
235 }
236 return state;
237 }
238
239 private class State {
240 @SuppressWarnings("rawtypes")
241 private final ConcurrentHashMap<Class<?>, MessageEncoder> findEncoderCache = new ConcurrentHashMap<Class<?>, MessageEncoder>();
242
243 @SuppressWarnings("rawtypes")
244 private final Map<Class<?>, MessageEncoder> type2encoder = new ConcurrentHashMap<Class<?>, MessageEncoder>();
245
246 @SuppressWarnings("rawtypes")
247 private State() throws Exception {
248 for (Map.Entry<Class<?>, MessageEncoderFactory> e : type2encoderFactory.entrySet()) {
249 type2encoder.put(e.getKey(), e.getValue().getEncoder());
250 }
251 }
252 }
253
254 private static class SingletonMessageEncoderFactory<T> implements MessageEncoderFactory<T> {
255 private final MessageEncoder<T> encoder;
256
257 private SingletonMessageEncoderFactory(MessageEncoder<T> encoder) {
258 if (encoder == null) {
259 throw new IllegalArgumentException("encoder");
260 }
261 this.encoder = encoder;
262 }
263
264 public MessageEncoder<T> getEncoder() {
265 return encoder;
266 }
267 }
268
269 private static class DefaultConstructorMessageEncoderFactory<T> implements MessageEncoderFactory<T> {
270 private final Class<MessageEncoder<T>> encoderClass;
271
272 private DefaultConstructorMessageEncoderFactory(Class<MessageEncoder<T>> encoderClass) {
273 if (encoderClass == null) {
274 throw new IllegalArgumentException("encoderClass");
275 }
276
277 if (!MessageEncoder.class.isAssignableFrom(encoderClass)) {
278 throw new IllegalArgumentException("encoderClass is not assignable to MessageEncoder");
279 }
280 this.encoderClass = encoderClass;
281 }
282
283 public MessageEncoder<T> getEncoder() throws Exception {
284 return encoderClass.newInstance();
285 }
286 }
287 }